'''
Optimized script for generating semantic labels for the objects of a scene.

Optimizations:
- Multithreaded handling of I/O operations
- Caching of loaded images
- Batch processing
- Optimized NumPy operations
'''
# NOTE: the import order of the original script is kept on purpose, in case
# any project module relies on import-time side effects.
from utils.config import get_dataset, get_args
from dataset.scannet import ScanNetDataset
import os
import numpy as np
import open_clip
import cv2
from PIL import Image
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
from utils.config import update_args
import argparse
from datetime import datetime
import shutil
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from typing import Dict, List, Optional, Tuple

# Number of crop scales extracted per frame (level 0 is the raw bbox).
LEVELS = 3

# Maximum number of entries kept by each of the image / mask LRU caches.
CACHE_SIZE = 1000

# Dataset locations (previously repeated inline at every use site).
POSED_IMAGES_ROOT = '/home/jovyan/users/lemeshko/mmdetection3d/data/scannet/posed_images'
MASK_RESULTS_ROOT = '/home/jovyan/users/lemeshko/scripts/gsam_result/scannet200/yolo'

# Multiplier applied to the image/text similarity before the softmax.
SIMILARITY_SCALE = 100


def load_clip(device):
    """Create the OpenCLIP ViT-H-14 model, move it to `device`, set eval mode.

    Returns:
        (model, preprocess): the model and the validation preprocessing
        transform returned by `open_clip.create_model_and_transforms`.
    """
    print('[INFO] loading CLIP model...')
    model, _, preprocess = open_clip.create_model_and_transforms(
        "ViT-H-14", pretrained="laion2b_s32b_b79k")
    model.to(device)
    model.eval()
    print('[INFO] finish loading CLIP model...')
    return model, preprocess


def box_multi_level(bbox, shape, level, expansion_ratio):
    """Expand a `(left, top, right, bottom)` box for the given crop level.

    Level 0 returns the box unchanged. For higher levels each side is pushed
    outwards by `int(side_length * expansion_ratio) * level` pixels and the
    result is clamped to the image bounds (`shape[1]` wide, `shape[0]` high).
    """
    left, top, right, bottom = bbox
    if level == 0:
        return left, top, right, bottom
    x_exp = int(abs(right - left) * expansion_ratio) * level
    y_exp = int(abs(bottom - top) * expansion_ratio) * level
    return (
        max(0, left - x_exp),
        max(0, top - y_exp),
        min(shape[1], right + x_exp),
        min(shape[0], bottom + y_exp),
    )


@lru_cache(maxsize=CACHE_SIZE)
def load_image_cached(img_path: str) -> Optional[np.ndarray]:
    """Load an image with `cv2.imread`, memoized per path.

    Returns None when the file does not exist. The returned array is shared
    between all callers through the cache, so it must be copied before being
    modified.
    """
    if not os.path.exists(img_path):
        return None
    return cv2.imread(img_path)


@lru_cache(maxsize=CACHE_SIZE)
def load_mask_cached(mask_path: str) -> Optional[np.ndarray]:
    """Load a mask with `np.load`, memoized per path.

    Returns None when the file does not exist. The returned array is shared
    between all callers through the cache and must be treated as read-only.
    """
    if not os.path.exists(mask_path):
        return None
    # NOTE(review): allow_pickle=True executes pickled payloads on load; only
    # use this with mask files from a trusted source.
    return np.load(mask_path, allow_pickle=True)


def process_single_frame(args_tuple):
    """Produce the multi-level square crops of one object in one frame.

    Designed to be mapped over a thread pool, hence the single tuple argument
    `(frame, key, scene_id, img_base_path, mask_base_path, expansion_ratio)`.

    Returns:
        A list of RGB uint8 arrays (one per non-empty crop level), or None
        when the image or mask is missing, the object id `key` does not occur
        in the mask, or any error is raised while processing the frame.
    """
    frame, key, scene_id, img_base_path, mask_base_path, expansion_ratio = args_tuple

    img_path = f'{img_base_path}/{str(frame["frame_id"]).zfill(5)}.jpg'
    mask_path = os.path.join(mask_base_path, frame['mask_path'])

    try:
        image = load_image_cached(img_path)
        mask = load_mask_cached(mask_path)

        if image is None or mask is None:
            return None

        mask_indices = mask == key
        if not np.any(mask_indices):
            return None

        # The cached image is shared, so work on a private copy before
        # scaling the masked pixels to 80% of their intensity.
        image = image.copy()
        image[mask_indices] = (image[mask_indices] * 0.8).astype(np.uint8)

        x1, y1, x2, y2 = frame['bbox']

        cropped_images = []
        for level in range(LEVELS):
            # The box is reassigned on every iteration, so each level expands
            # the box produced by the previous level.
            x1, y1, x2, y2 = box_multi_level(
                (x1, y1, x2, y2), image.shape, level, expansion_ratio)

            cropped = image[y1:y2, x1:x2]
            if cropped.size == 0:
                continue

            # OpenCV loads BGR; PIL expects RGB.
            rgb_image = cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)
            pil_image = pad_into_square(Image.fromarray(rgb_image))
            cropped_images.append(np.array(pil_image))

        return cropped_images

    except Exception as e:
        # Deliberate best effort: a broken frame is reported and skipped so
        # the remaining frames of the object can still be used.
        print(f"Ошибка обработки кадра {frame['frame_id']}: {e}")
        return None


def get_cropped_images_parallel(key, object_data, scene_id, preprocess,
                                num_images=5, expansion_ratio=0.1, num_workers=4):
    """Collect preprocessed crops of one object from its first frames.

    The first `num_images` entries of `object_data['frames']` are processed
    concurrently with a thread pool of `num_workers` threads.

    Returns:
        (crops, preview):
        - crops: the stacked outputs of `preprocess`, or `torch.empty(0)`
          when no crop could be produced;
        - preview: all crops resized to 64x64, concatenated horizontally with
          the channel axis reversed, or an empty array when there are none.
    """
    img_base_path = f'{POSED_IMAGES_ROOT}/{scene_id}'
    mask_base_path = f'{MASK_RESULTS_ROOT}/{scene_id}'

    frames = object_data['frames'][:num_images]

    args_list = [
        (frame, key, scene_id, img_base_path, mask_base_path, expansion_ratio)
        for frame in frames
    ]

    all_cropped_images = []
    all_display_images = []

    # Threads only perform loading / cropping; the model is not touched here.
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        results = list(executor.map(process_single_frame, args_list))

    for result in results:
        if result is None:
            continue
        for img_array in result:
            pil_img = Image.fromarray(img_array)
            all_cropped_images.append(preprocess(pil_img))
            # Small thumbnail used only for the debug preview strip.
            all_display_images.append(np.array(pil_img.resize((64, 64))))

    if not all_cropped_images:
        return torch.empty(0), np.array([])

    # Reversing the last axis turns RGB back into BGR for cv2.imwrite.
    preview = np.concatenate(all_display_images, axis=1)[..., ::-1]
    return torch.stack(all_cropped_images), preview


def pad_into_square(image):
    """Center a PIL image on a white square canvas of side max(width, height).

    Images that are already square are returned unchanged.
    """
    width, height = image.size
    if width == height:
        return image

    new_size = max(width, height)
    new_image = Image.new("RGB", (new_size, new_size), (255, 255, 255))
    left = (new_size - width) // 2
    top = (new_size - height) // 2
    new_image.paste(image, (left, top))
    return new_image


def batch_encode_images(model, images, device, batch_size=32):
    """Encode images with `model.encode_image` in chunks of `batch_size`.

    Returns:
        A NumPy array with one L2-normalized feature row per input image, or
        an empty array when `images` is empty.
    """
    if len(images) == 0:
        return np.array([])

    features = []
    for i in range(0, len(images), batch_size):
        batch = images[i:i + batch_size].to(device)
        with torch.no_grad():
            batch_features = model.encode_image(batch).float()
            batch_features /= batch_features.norm(dim=-1, keepdim=True)
            features.append(batch_features.cpu().numpy())

    return np.vstack(features) if features else np.array([])


def _predict_label(features, label_text_features, descriptions, label2id):
    """Return `(label_id, prob_score)` for the mean of `features`.

    The mean feature is compared with every row of `label_text_features`; the
    scaled similarities go through a softmax and the most probable
    description is mapped to its id via `label2id`.
    """
    object_feature = np.mean(features, axis=0, keepdims=True)
    logits = np.dot(object_feature, label_text_features.T) * SIMILARITY_SCALE
    # Subtracting the row maximum leaves the softmax unchanged but keeps
    # np.exp from overflowing (float32 overflows above roughly exp(88)).
    logits = logits - np.max(logits, axis=1, keepdims=True)
    exp_sim = np.exp(logits)
    prob = exp_sim / np.sum(exp_sim, axis=1, keepdims=True)
    probs = np.max(prob, axis=0)
    max_label_id = np.argmax(probs)
    return label2id[descriptions[max_label_id]], probs[max_label_id]


def _points_to_mask(point_ids, total_point_num):
    """Return a boolean vector of length `total_point_num`, True at `point_ids`."""
    binary_mask = np.zeros(total_point_num, dtype=bool)
    binary_mask[list(point_ids)] = True
    return binary_mask


def _write_debug_image(logs_path, key, label_id, prob_score, saved_images, dataset):
    """Write the preview strip of an object to `logs_path`, named by key, label and score."""
    os.makedirs(logs_path, exist_ok=True)
    label_name = dataset.get_label_id()[1][label_id]
    cv2.imwrite(
        f'{logs_path}/{str(key).zfill(5)}_{label_name}({prob_score:.2f}).jpg',
        saved_images)


def _label_object(key, object_data, args, scene_name, label_text_features,
                  descriptions, label2id, total_point_num):
    """Crop, encode and classify one object.

    Returns:
        None when no crop or no feature could be produced, otherwise a dict
        with the keys 'key', 'binary_mask', 'prob_score', 'label_id' and
        'saved_images' (the preview strip; None unless `args.debug` is set).
    """
    cropped_images, saved_images = get_cropped_images_parallel(
        key, object_data, scene_name, args.preprocess,
        args.num_images, num_workers=args.image_workers
    )
    if len(cropped_images) == 0:
        return None

    features = batch_encode_images(args.model, cropped_images, args.device, batch_size=32)
    if features.size == 0:
        return None

    label_id, prob_score = _predict_label(
        features, label_text_features, descriptions, label2id)

    return {
        'key': key,
        'binary_mask': _points_to_mask(object_data['mask'], total_point_num),
        'prob_score': prob_score,
        'label_id': label_id,
        'saved_images': saved_images if args.debug else None
    }


def process_objects_batch(object_items, args, scene_name, label_text_features,
                          descriptions, label2id, total_point_num, logs_path, dataset):
    """Label a batch of objects, isolating failures per object.

    Args:
        object_items: iterable of `(idx, (key, object_data))` entries.

    Returns:
        A list with exactly one entry per input object: a result dict (see
        `_label_object`, plus the 'idx' key) or None when the object produced
        no crops / features or raised before a result was computed.
    """
    batch_results = []

    for idx, (key, object_data) in object_items:
        result = None
        try:
            labelled = _label_object(
                key, object_data, args, scene_name, label_text_features,
                descriptions, label2id, total_point_num)
            if labelled is not None:
                result = {'idx': idx, **labelled}
                if args.debug and result['saved_images'] is not None:
                    _write_debug_image(
                        logs_path, key, result['label_id'],
                        result['prob_score'], result['saved_images'], dataset)
        except Exception as e:
            # Best effort: report and continue. If only the debug write
            # failed, the already computed result is still returned.
            print(f"Ошибка обработки объекта {key}: {e}")

        batch_results.append(result)

    return batch_results


def _store_result(pred_dict, idx, result):
    """Copy one object's result into column / slot `idx` of the prediction arrays."""
    pred_dict['pred_masks'][:, idx] = result['binary_mask']
    pred_dict['pred_score'][idx] = result['prob_score']
    pred_dict['pred_classes'][idx] = result['label_id']


def main(args):
    """Label every object of the scenes assigned to this job and save `.npz` files.

    Scenes are read from `splits/scannet200_subset.txt` and sharded across
    jobs with `args.job_id` / `args.num_workers`. Scenes whose prediction file
    already exists are skipped. Objects that yield no crops keep the default
    entries (empty mask, score 1, class 0).
    """
    # atleast_1d: np.loadtxt returns a 0-d array for a single-line file, which
    # cannot be sliced.
    scenes = np.atleast_1d(np.loadtxt('splits/scannet200_subset.txt', dtype=str))
    scenes = scenes[args.job_id::args.num_workers]

    pred_dir = os.path.join('data/prediction', args.config, args.exp_name)
    os.makedirs(pred_dir, exist_ok=True)

    # Keep a copy of the script next to the predictions it produced.
    shutil.copy("semantics/c_open-voc_query_optimized.py",
                os.path.join(pred_dir, "c_open-voc_query_optimized.py"))

    for scene in tqdm(scenes):
        args.seq_name = scene
        if os.path.exists(f'{pred_dir}/{args.seq_name}.npz'):
            continue

        dataset = ScanNetDataset(scene)
        total_point_num = dataset.get_scene_points().shape[0]

        label_features_dict = dataset.get_label_features()
        label_text_features = np.stack(list(label_features_dict.values()))
        descriptions = list(label_features_dict.keys())

        scene_name = dataset.seq_name
        logs_path = os.path.join('logs', args.exp_name, scene)

        object_dict = np.load(
            f'{MASK_RESULTS_ROOT}/{scene_name}/infos.npy',
            allow_pickle=True).item()
        label2id = dataset.get_label_id()[0]

        num_instance = len(object_dict)
        pred_dict = {
            "pred_masks": np.zeros((total_point_num, num_instance), dtype=bool),
            "pred_score": np.ones(num_instance),
            "pred_classes": np.zeros(num_instance, dtype=np.int32)
        }

        print(f"Обработка сцены {scene_name} с {num_instance} объектами")

        if args.batch_processing and num_instance > args.batch_size:
            # Batched path: per-object errors are caught and reported.
            object_items = list(enumerate(object_dict.items()))

            for batch_start in range(0, len(object_items), args.batch_size):
                batch_end = min(batch_start + args.batch_size, len(object_items))
                batch = object_items[batch_start:batch_end]

                print(f"Обработка пакета {batch_start//args.batch_size + 1}/{(len(object_items) + args.batch_size - 1)//args.batch_size}")

                batch_results = process_objects_batch(
                    batch, args, scene_name, label_text_features, descriptions,
                    label2id, total_point_num, logs_path, dataset
                )

                for result in batch_results:
                    if result is not None:
                        _store_result(pred_dict, result['idx'], result)
        else:
            # Sequential path: errors propagate to the caller.
            for idx, (key, object_data) in enumerate(object_dict.items()):
                result = _label_object(
                    key, object_data, args, scene_name, label_text_features,
                    descriptions, label2id, total_point_num)
                if result is None:
                    continue

                _store_result(pred_dict, idx, result)

                if args.debug:
                    _write_debug_image(
                        logs_path, key, result['label_id'],
                        result['prob_score'], result['saved_images'], dataset)

        np.savez(f'{pred_dir}/{args.seq_name}.npz', **pred_dict)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='scannet')
    parser.add_argument('--debug', action="store_true")
    parser.add_argument('--exp_name', type=str, default='baseline')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--num_images', type=int, default=5)
    parser.add_argument('--num_workers', '-n', type=int, default=1)
    parser.add_argument('--job_id', '-i', type=int, default=0)

    # Optimization options
    parser.add_argument('--batch_processing', action="store_true",
                        help="Пакетная обработка объектов")
    parser.add_argument('--batch_size', type=int, default=10,
                        help="Размер пакета объектов")
    parser.add_argument('--image_workers', type=int, default=4,
                        help="Количество потоков для загрузки изображений")

    args = parser.parse_args()
    args = update_args(args)

    model, preprocess = load_clip(args.device)
    args.model = model
    args.preprocess = preprocess

    # Kept from the original script: formats the experiment name as a string.
    args.exp_name = f'{args.exp_name}'

    main(args)