zoo3d / MaskClustering /semantics /c_open-voc_query_optimized.py
bulatko's picture
adding real MK
55e58d1
raw
history blame
15.5 kB
'''
Optimized version of the script that generates semantic labels for objects in a scene.
Optimizations:
- Multithreaded handling of I/O operations
- Caching of loaded images
- Batch processing
- Optimized NumPy operations
'''
from utils.config import get_dataset, get_args
from dataset.scannet import ScanNetDataset
import os
import numpy as np
import open_clip
import cv2
from PIL import Image
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
from utils.config import update_args
import argparse
from datetime import datetime
import shutil
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from typing import Dict, List, Tuple
LEVELS = 3
# Size of the global lru_cache used for loaded images and masks
CACHE_SIZE = 1000
def load_clip(device):
    """Instantiate the ViT-H-14 OpenCLIP model (laion2b_s32b_b79k weights) on *device*.

    Returns (model, preprocess) with the model moved to *device* and set to eval mode.
    """
    print(f'[INFO] loading CLIP model...')
    model, _, preprocess = open_clip.create_model_and_transforms(
        "ViT-H-14", pretrained="laion2b_s32b_b79k"
    )
    # nn.Module.to() moves in place and returns self, so reassignment is safe.
    model = model.to(device)
    model.eval()
    print(f'[INFO] finish loading CLIP model...')
    return model, preprocess
def box_multi_level(bbox, shape, level, expansion_ratio):
    """Expand *bbox* by ``level * expansion_ratio`` of its size on each axis.

    bbox is (left, top, right, bottom); shape is the image shape, i.e.
    (height, width, ...). Level 0 returns the box unchanged; higher levels
    grow the box proportionally, clamped to the image bounds.
    """
    left, top, right, bottom = bbox
    if level == 0:
        return left, top, right, bottom
    # Margin grows linearly with the level.
    margin_x = level * int(abs(right - left) * expansion_ratio)
    margin_y = level * int(abs(bottom - top) * expansion_ratio)
    height, width = shape[0], shape[1]
    return (
        max(0, left - margin_x),
        max(0, top - margin_y),
        min(width, right + margin_x),
        min(height, bottom + margin_y),
    )
@lru_cache(maxsize=CACHE_SIZE)
def load_image_cached(img_path: str) -> np.ndarray:
    """Memoized image load via OpenCV; returns None when the file is missing.

    NOTE: the cached ndarray is shared across callers — copy before mutating.
    """
    return cv2.imread(img_path) if os.path.exists(img_path) else None
@lru_cache(maxsize=CACHE_SIZE)
def load_mask_cached(mask_path: str) -> np.ndarray:
    """Memoized .npy mask load; returns None when the file is missing.

    NOTE: the cached array is shared across callers — copy before mutating.
    """
    return np.load(mask_path, allow_pickle=True) if os.path.exists(mask_path) else None
def process_single_frame(args_tuple):
    """Process one frame for one object: load image + mask, darken the object's
    pixels, and crop the object at LEVELS expansion levels.

    args_tuple: (frame, key, scene_id, img_base_path, mask_base_path, expansion_ratio)
        where frame is a dict with 'frame_id', 'mask_path' and 'bbox' keys.

    Returns a list of RGB uint8 arrays (one padded-to-square crop per level),
    or None when the image/mask is missing, the mask has no pixels equal to
    *key*, or any exception occurs. Safe to run inside a ThreadPoolExecutor
    (I/O + NumPy only, no CUDA).
    """
    frame, key, scene_id, img_base_path, mask_base_path, expansion_ratio = args_tuple
    img_path = f'{img_base_path}/{str(frame["frame_id"]).zfill(5)}.jpg'
    mask_path = os.path.join(mask_base_path, frame['mask_path'])
    try:
        image = load_image_cached(img_path)
        mask = load_mask_cached(mask_path)
        if image is None or mask is None:
            return None
        # Boolean mask of the pixels belonging to this object id.
        mask_indices = mask == key
        if not np.any(mask_indices):
            return None
        # Copy before mutating: the cached array is shared across calls.
        image = image.copy()
        image[mask_indices] = (image[mask_indices] * 0.8).astype(np.uint8)
        # BUG FIX: expand every level from the ORIGINAL bbox. Previously the
        # expanded coordinates were fed back into the next iteration, so the
        # expansion compounded across levels even though box_multi_level
        # already scales the margin by `level`.
        bbox = frame['bbox']
        cropped_images = []
        for level in range(LEVELS):
            x1, y1, x2, y2 = box_multi_level(bbox, image.shape, level, expansion_ratio)
            cropped = image[y1:y2, x1:x2]
            if cropped.size == 0:
                continue
            # Convert BGR -> RGB once per crop, then pad to a square canvas.
            rgb_image = cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)
            pil_image = pad_into_square(Image.fromarray(rgb_image))
            cropped_images.append(np.array(pil_image))
        return cropped_images
    except Exception as e:
        print(f"Ошибка обработки кадра {frame['frame_id']}: {e}")
        return None
def get_cropped_images_parallel(key, object_data, scene_id, preprocess, num_images=5, expansion_ratio=0.1, num_workers=4):
    """Gather multi-level crops of object *key* from up to *num_images* frames.

    Frame loading/cropping runs in a thread pool (I/O only — safe with CUDA);
    CLIP preprocessing happens on the calling thread. Returns a tuple of
    (stacked preprocessed tensor, horizontal BGR preview strip of 64x64
    thumbnails); (torch.empty(0), np.array([])) when nothing was cropped.
    """
    # NOTE(review): hard-coded user paths — consider making these configurable.
    img_base_path = f'/home/jovyan/users/lemeshko/mmdetection3d/data/scannet/posed_images/{scene_id}'
    mask_base_path = f'/home/jovyan/users/lemeshko/scripts/gsam_result/scannet200/yolo/{scene_id}'

    work_items = [
        (frame, key, scene_id, img_base_path, mask_base_path, expansion_ratio)
        for frame in object_data['frames'][:num_images]
    ]

    # Thread pool for the disk-bound part only (no CUDA inside workers).
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        frame_results = list(pool.map(process_single_frame, work_items))

    processed_tensors = []
    preview_tiles = []
    for crops in frame_results:
        if crops is None:
            continue
        for crop_array in crops:
            pil_crop = Image.fromarray(crop_array)
            # CLIP preprocessing on the main thread.
            processed_tensors.append(preprocess(pil_crop))
            # Small thumbnail kept for the debug preview strip.
            preview_tiles.append(np.array(pil_crop.resize((64, 64))))

    if not processed_tensors:
        return torch.empty(0), np.array([])
    # Concatenate thumbnails side by side and flip RGB -> BGR for cv2.imwrite.
    return torch.stack(processed_tensors), np.concatenate(preview_tiles, axis=1)[..., ::-1]
def pad_into_square(image):
    """Center *image* on a white square canvas sized to its longer side.

    Returns the original PIL image untouched when it is already square.
    """
    width, height = image.size
    if width == height:
        return image
    side = max(width, height)
    canvas = Image.new("RGB", (side, side), (255, 255, 255))
    offset = ((side - width) // 2, (side - height) // 2)
    canvas.paste(image, offset)
    return canvas
def batch_encode_images(model, images, device, batch_size=32):
    """Encode a stacked tensor of *images* through CLIP in mini-batches.

    Returns an (N, D) numpy float array of L2-normalized features, or an
    empty array when there is nothing to encode.
    """
    if len(images) == 0:
        return np.array([])
    chunks = []
    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            batch = images[start:start + batch_size].to(device)
            feats = model.encode_image(batch).float()
            # L2-normalize so dot products are cosine similarities.
            feats /= feats.norm(dim=-1, keepdim=True)
            chunks.append(feats.cpu().numpy())
    return np.vstack(chunks) if chunks else np.array([])
def process_objects_batch(object_items, args, scene_name, label_text_features, descriptions, label2id, total_point_num, logs_path, dataset):
    """Classify a batch of objects against the label text features.

    object_items: iterable of (idx, (key, object_data)) pairs.
    Returns a list aligned with object_items: each entry is a result dict
    with idx/key/binary_mask/prob_score/label_id/saved_images, or None when
    the object produced no usable crops/features or raised an exception.
    In debug mode also writes a labeled preview image into *logs_path*.
    """
    batch_results = []
    for idx, (key, object_data) in object_items:
        try:
            cropped_images, saved_images = get_cropped_images_parallel(
                key, object_data, scene_name, args.preprocess, args.num_images,
                num_workers=args.image_workers
            )
            if len(cropped_images) == 0:
                batch_results.append(None)
                continue
            # Batched CLIP encoding (features come back L2-normalized).
            features = batch_encode_images(args.model, cropped_images, args.device, batch_size=32)
            if features.size == 0:
                batch_results.append(None)
                continue
            # Average per-crop features into a single object embedding.
            object_feature = np.mean(features, axis=0, keepdims=True)
            # Similarity vs. label texts, scaled by 100 before the softmax.
            raw_similarity = np.dot(object_feature, label_text_features.T)
            exp_sim = np.exp(raw_similarity * 100)
            prob = exp_sim / np.sum(exp_sim, axis=1, keepdims=True)
            probs = np.max(prob, axis=0)
            max_label_id = np.argmax(probs)
            prob_score = probs[max_label_id]
            label_id = label2id[descriptions[max_label_id]]
            # Scene-wide binary point mask for this object.
            point_ids = object_data['mask']
            binary_mask = np.zeros(total_point_num, dtype=bool)
            binary_mask[list(point_ids)] = True
            result = {
                'idx': idx,
                'key': key,
                'binary_mask': binary_mask,
                'prob_score': prob_score,
                'label_id': label_id,
                # Keep the preview strip only when debugging.
                'saved_images': saved_images if args.debug else None
            }
            batch_results.append(result)
            if args.debug and saved_images is not None:
                os.makedirs(logs_path, exist_ok=True)
                cv2.imwrite(f'{logs_path}/{str(key).zfill(5)}_{dataset.get_label_id()[1][label_id]}({prob_score:.2f}).jpg', saved_images)
        except Exception as e:
            # Best-effort: log the failure and keep the batch aligned with None.
            print(f"Ошибка обработки объекта {key}: {e}")
            batch_results.append(None)
    return batch_results
def main(args):
    """Run open-vocabulary semantic labeling for the scenes assigned to this job.

    For each scene: loads the point cloud size, label text features and the
    precomputed object dict, classifies every object with CLIP, and saves
    pred_masks / pred_score / pred_classes into <pred_dir>/<scene>.npz.
    Scenes whose .npz already exists are skipped, so runs are resumable.
    """
    scenes = np.loadtxt('splits/scannet200_subset.txt', dtype=str)
    # Shard scenes across jobs: this job takes every num_workers-th scene.
    scenes = scenes[args.job_id::args.num_workers]
    pred_dir = os.path.join('data/prediction', args.config, args.exp_name)
    os.makedirs(pred_dir, exist_ok=True)
    # Snapshot this script next to the predictions for reproducibility.
    shutil.copy("semantics/c_open-voc_query_optimized.py", os.path.join(pred_dir, "c_open-voc_query_optimized.py"))
    for scene in tqdm(scenes):
        args.seq_name = scene
        # Skip scenes already processed in a previous run.
        if os.path.exists(f'{pred_dir}/{args.seq_name}.npz'):
            continue
        dataset = ScanNetDataset(scene)
        total_point_num = dataset.get_scene_points().shape[0]
        label_features_dict = dataset.get_label_features()
        label_text_features = np.stack(list(label_features_dict.values()))
        descriptions = list(label_features_dict.keys())
        scene_name = dataset.seq_name
        logs_path = os.path.join('logs', args.exp_name, scene)
        # NOTE(review): hard-coded user path — consider making it configurable.
        object_dict = np.load(f'/home/jovyan/users/lemeshko/scripts/gsam_result/scannet200/yolo/{scene_name}/infos.npy', allow_pickle=True).item()
        label2id = dataset.get_label_id()[0]
        num_instance = len(object_dict)
        # Per-scene prediction arrays, indexed by object position.
        pred_dict = {
            "pred_masks": np.zeros((total_point_num, num_instance), dtype=bool),
            "pred_score": np.ones(num_instance),
            "pred_classes": np.zeros(num_instance, dtype=np.int32)
        }
        print(f"Обработка сцены {scene_name} с {num_instance} объектами")
        # Batched object processing for better efficiency.
        if args.batch_processing and num_instance > args.batch_size:
            object_items = list(enumerate(object_dict.items()))
            # Split the object list into fixed-size batches.
            for batch_start in range(0, len(object_items), args.batch_size):
                batch_end = min(batch_start + args.batch_size, len(object_items))
                batch = object_items[batch_start:batch_end]
                print(f"Обработка пакета {batch_start//args.batch_size + 1}/{(len(object_items) + args.batch_size - 1)//args.batch_size}")
                batch_results = process_objects_batch(
                    batch, args, scene_name, label_text_features, descriptions,
                    label2id, total_point_num, logs_path, dataset
                )
                # Write the batch results into the prediction arrays.
                for result in batch_results:
                    if result is not None:
                        idx = result['idx']
                        pred_dict['pred_masks'][:, idx] = result['binary_mask']
                        pred_dict['pred_score'][idx] = result['prob_score']
                        pred_dict['pred_classes'][idx] = result['label_id']
        else:
            # Sequential path (still uses the optimized helper functions).
            for idx, (key, object_data) in enumerate(object_dict.items()):
                cropped_images, saved_images = get_cropped_images_parallel(
                    key, object_data, scene_name, args.preprocess, args.num_images,
                    num_workers=args.image_workers
                )
                if len(cropped_images) == 0:
                    continue
                features = batch_encode_images(args.model, cropped_images, args.device, batch_size=32)
                # Mean crop feature, then softmax over label similarities.
                object_feature = np.mean(features, axis=0, keepdims=True)
                raw_similarity = np.dot(object_feature, label_text_features.T)
                exp_sim = np.exp(raw_similarity * 100)
                prob = exp_sim / np.sum(exp_sim, axis=1, keepdims=True)
                probs = np.max(prob, axis=0)
                max_label_id = np.argmax(probs)
                prob_score = probs[max_label_id]
                pred_dict['pred_score'][idx] = prob_score
                label_id = label2id[descriptions[max_label_id]]
                pred_dict['pred_classes'][idx] = label_id
                # Scene-wide binary point mask for this object.
                point_ids = object_data['mask']
                binary_mask = np.zeros(total_point_num, dtype=bool)
                binary_mask[list(point_ids)] = True
                pred_dict['pred_masks'][:, idx] = binary_mask
                if args.debug:
                    os.makedirs(logs_path, exist_ok=True)
                    cv2.imwrite(f'{logs_path}/{str(key).zfill(5)}_{dataset.get_label_id()[1][label_id]}({prob_score:.2f}).jpg', saved_images)
        np.savez(f'{pred_dir}/{args.seq_name}.npz', **pred_dict)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='scannet')
    parser.add_argument('--debug', action="store_true")
    parser.add_argument('--exp_name', type=str, default='baseline')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--num_images', type=int, default=5)  # frames sampled per object
    parser.add_argument('--num_workers', '-n', type=int, default=1)  # total parallel jobs (scene sharding)
    parser.add_argument('--job_id', '-i', type=int, default=0)  # index of this job within num_workers
    # Optimization options (safe with CUDA).
    parser.add_argument('--batch_processing', action="store_true", help="Пакетная обработка объектов")
    parser.add_argument('--batch_size', type=int, default=10, help="Размер пакета объектов")
    parser.add_argument('--image_workers', type=int, default=4, help="Количество потоков для загрузки изображений")
    args = parser.parse_args()
    args = update_args(args)
    # Attach the CLIP model/preprocess to args so the helpers can reach them.
    model, preprocess = load_clip(args.device)
    args.model = model
    args.preprocess = preprocess
    # NOTE(review): `date` is computed but never used, and the f-string below
    # is a no-op — presumably the run name was meant to include the timestamp
    # (e.g. f'{args.exp_name}_{date}'). Confirm intent before changing, since
    # exp_name determines the output directory.
    date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    args.exp_name = f'{args.exp_name}'
    main(args)