'''
Optimized version of the script that generates semantic labels for objects in a scene.

Optimizations:
- Multithreaded handling of I/O operations
- Caching of loaded images
- Batched processing
- Optimized NumPy operations
'''
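
# Example invocation, as a sketch: the flags come from the argparse definition
# in __main__ below, and the script path is assumed from the shutil.copy call
# in main() (run from the repository root, since splits/ is a relative path):
#
#   python semantics/c_open-voc_query_optimized.py \
#       --config scannet --exp_name baseline --device cuda:0 \
#       --num_images 5 --batch_processing --batch_size 10 --image_workers 4
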
import argparse
import os
import shutil
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from functools import lru_cache
from typing import Optional

import cv2
import numpy as np
import open_clip
import torch
from PIL import Image
from tqdm import tqdm

from dataset.scannet import ScanNetDataset
from utils.config import update_args

# Number of multi-scale crops taken around each object bounding box.
LEVELS = 3

# Maximum number of entries kept by each lru_cache-backed image/mask loader.
CACHE_SIZE = 1000


def load_clip(device):
    """Load the open_clip ViT-H-14 model and its preprocessing transform."""
    print('[INFO] loading CLIP model...')
    model, _, preprocess = open_clip.create_model_and_transforms("ViT-H-14", pretrained="laion2b_s32b_b79k")
    model.to(device)
    model.eval()
    print('[INFO] finished loading CLIP model')
    return model, preprocess


def box_multi_level(bbox, shape, level, expansion_ratio):
    """Expand `bbox` by `expansion_ratio * level` per side, clipped to the image `shape` (H, W, ...)."""
    left, top, right, bottom = bbox
    if level == 0:
        return left, top, right, bottom

    x_exp = int(abs(right - left) * expansion_ratio) * level
    y_exp = int(abs(bottom - top) * expansion_ratio) * level
    return max(0, left - x_exp), max(0, top - y_exp), min(shape[1], right + x_exp), min(shape[0], bottom + y_exp)
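
# Worked example with hypothetical numbers: for bbox (100, 100, 200, 200) in a
# 480x640 image with expansion_ratio=0.1, level 1 expands each side by 10 px to
# (90, 90, 210, 210) and level 2 by 20 px to (80, 80, 220, 220). Note that
# process_single_frame below feeds each level's output back in as the next
# level's input, so in practice the expansion compounds across levels.
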
@lru_cache(maxsize=CACHE_SIZE)
def load_image_cached(img_path: str) -> Optional[np.ndarray]:
    """Load an image with caching; returns None if the file does not exist."""
    if not os.path.exists(img_path):
        return None
    return cv2.imread(img_path)


@lru_cache(maxsize=CACHE_SIZE)
def load_mask_cached(mask_path: str) -> Optional[np.ndarray]:
    """Load a mask with caching; returns None if the file does not exist."""
    if not os.path.exists(mask_path):
        return None
    return np.load(mask_path, allow_pickle=True)
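
# The cached arrays are shared across calls, so callers must copy before
# mutating them (process_single_frame does). If memory pressure becomes an
# issue across many scenes, both caches can be dropped between scenes via the
# standard functools.lru_cache API -- a minimal sketch:
#
#   load_image_cached.cache_clear()
#   load_mask_cached.cache_clear()
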
def process_single_frame(args_tuple):
    """Process one frame; designed to run on the I/O thread pool."""
    frame, key, scene_id, img_base_path, mask_base_path, expansion_ratio = args_tuple

    img_path = f'{img_base_path}/{str(frame["frame_id"]).zfill(5)}.jpg'
    mask_path = os.path.join(mask_base_path, frame['mask_path'])

    try:
        image = load_image_cached(img_path)
        mask = load_mask_cached(mask_path)

        if image is None or mask is None:
            return None

        # Skip frames in which this object id never appears.
        mask_indices = mask == key
        if not np.any(mask_indices):
            return None

        # Copy before mutating: the cached array is shared across calls.
        image = image.copy()
        # Darken the object's pixels so it stands out in the crop.
        image[mask_indices] = (image[mask_indices] * 0.8).astype(np.uint8)

        x1, y1, x2, y2 = frame['bbox']
        cropped_images = []

        for level in range(LEVELS):
            # The box is reassigned each iteration, so crops grow cumulatively.
            x1, y1, x2, y2 = box_multi_level((x1, y1, x2, y2), image.shape, level, expansion_ratio)

            cropped = image[y1:y2, x1:x2]
            if cropped.size == 0:
                continue

            # cv2 loads BGR; convert to RGB before handing off to PIL/CLIP.
            rgb_image = cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)
            pil_image = pad_into_square(Image.fromarray(rgb_image))

            cropped_images.append(np.array(pil_image))

        return cropped_images

    except Exception as e:
        print(f"Error while processing frame {frame['frame_id']}: {e}")
        return None


def get_cropped_images_parallel(key, object_data, scene_id, preprocess, num_images=5, expansion_ratio=0.1, num_workers=4):
    """Crop object views from multiple frames in parallel using a ThreadPoolExecutor."""
    img_base_path = f'/home/jovyan/users/lemeshko/mmdetection3d/data/scannet/posed_images/{scene_id}'
    mask_base_path = f'/home/jovyan/users/lemeshko/scripts/gsam_result/scannet200/yolo/{scene_id}'

    frames = object_data['frames'][:num_images]

    args_list = [
        (frame, key, scene_id, img_base_path, mask_base_path, expansion_ratio)
        for frame in frames
    ]

    all_cropped_images = []
    all_display_images = []

    # Threads are appropriate here: the per-frame work is dominated by disk
    # I/O (image and mask loading), which releases the GIL.
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        results = list(executor.map(process_single_frame, args_list))

    for result in results:
        if result is not None:
            for img_array in result:
                pil_img = Image.fromarray(img_array)
                processed_img = preprocess(pil_img)
                all_cropped_images.append(processed_img)

                # Keep a small thumbnail for optional debug visualization.
                display_img = np.array(pil_img.resize((64, 64)))
                all_display_images.append(display_img)

    if not all_cropped_images:
        return torch.empty(0), np.array([])

    # Stack CLIP inputs; tile thumbnails horizontally and flip RGB -> BGR for cv2.imwrite.
    return torch.stack(all_cropped_images), np.concatenate(all_display_images, axis=1)[..., ::-1]


def pad_into_square(image):
    """Pad a PIL image to a centered square on a white background."""
    width, height = image.size
    if width == height:
        return image

    new_size = max(width, height)
    new_image = Image.new("RGB", (new_size, new_size), (255, 255, 255))
    left = (new_size - width) // 2
    top = (new_size - height) // 2
    new_image.paste(image, (left, top))
    return new_image
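
# Worked example with hypothetical sizes: a 100x60 input becomes a 100x100
# canvas with the original pasted at (0, 20), i.e. centered with white bars
# above and below.
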
def batch_encode_images(model, images, device, batch_size=32):
    """Encode images through CLIP in batches; returns L2-normalized features."""
    if len(images) == 0:
        return np.array([])

    features = []
    for i in range(0, len(images), batch_size):
        batch = images[i:i + batch_size].to(device)
        with torch.no_grad():
            batch_features = model.encode_image(batch).float()
            batch_features /= batch_features.norm(dim=-1, keepdim=True)
            features.append(batch_features.cpu().numpy())

    return np.vstack(features) if features else np.array([])
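
# Shape note (assuming the ViT-H-14 checkpoint loaded in load_clip above,
# whose image embedding is 1024-dimensional): the result is an (N, 1024)
# float32 array of unit-norm rows.
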
def process_objects_batch(object_items, args, scene_name, label_text_features, descriptions, label2id, total_point_num, logs_path, dataset):
    """Process a batch of objects for better throughput."""
    batch_results = []

    for idx, (key, object_data) in object_items:
        try:
            cropped_images, saved_images = get_cropped_images_parallel(
                key, object_data, scene_name, args.preprocess, args.num_images,
                num_workers=args.image_workers
            )

            if len(cropped_images) == 0:
                batch_results.append(None)
                continue

            features = batch_encode_images(args.model, cropped_images, args.device, batch_size=32)

            if features.size == 0:
                batch_results.append(None)
                continue

            # Average the per-view features into a single object embedding.
            object_feature = np.mean(features, axis=0, keepdims=True)

            # Softmax over cosine similarities against the label text features;
            # the x100 temperature mirrors the logit scale commonly used with CLIP.
            raw_similarity = np.dot(object_feature, label_text_features.T)
            exp_sim = np.exp(raw_similarity * 100)
            prob = exp_sim / np.sum(exp_sim, axis=1, keepdims=True)
            probs = np.max(prob, axis=0)
            max_label_id = np.argmax(probs)
            prob_score = probs[max_label_id]

            label_id = label2id[descriptions[max_label_id]]

            # Scatter the object's point ids into a dense boolean mask.
            point_ids = object_data['mask']
            binary_mask = np.zeros(total_point_num, dtype=bool)
            binary_mask[list(point_ids)] = True

            result = {
                'idx': idx,
                'key': key,
                'binary_mask': binary_mask,
                'prob_score': prob_score,
                'label_id': label_id,
                'saved_images': saved_images if args.debug else None
            }

            batch_results.append(result)

            if args.debug and saved_images is not None:
                os.makedirs(logs_path, exist_ok=True)
                cv2.imwrite(f'{logs_path}/{str(key).zfill(5)}_{dataset.get_label_id()[1][label_id]}({prob_score:.2f}).jpg', saved_images)

        except Exception as e:
            print(f"Error while processing object {key}: {e}")
            batch_results.append(None)

    return batch_results
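
# Label scoring intuition, with hypothetical numbers: for cosine similarities
# (0.30, 0.28, 0.20) against three label prompts, the x100 temperature yields
# softmax probabilities of roughly (0.88, 0.12, 0.00), so small similarity
# gaps turn into confident scores.
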
def main(args):
    scenes = np.loadtxt('splits/scannet200_subset.txt', dtype=str)
    # Shard the scene list across jobs: job i of n takes scenes[i::n].
    scenes = scenes[args.job_id::args.num_workers]
    pred_dir = os.path.join('data/prediction', args.config, args.exp_name)
    os.makedirs(pred_dir, exist_ok=True)
    # Snapshot this script alongside the predictions for reproducibility.
    shutil.copy("semantics/c_open-voc_query_optimized.py", os.path.join(pred_dir, "c_open-voc_query_optimized.py"))

    for scene in tqdm(scenes):
        args.seq_name = scene

        # Skip scenes that already have saved predictions.
        if os.path.exists(f'{pred_dir}/{args.seq_name}.npz'):
            continue

        dataset = ScanNetDataset(scene)
        total_point_num = dataset.get_scene_points().shape[0]

        label_features_dict = dataset.get_label_features()
        label_text_features = np.stack(list(label_features_dict.values()))
        descriptions = list(label_features_dict.keys())

        scene_name = dataset.seq_name
        logs_path = os.path.join('logs', args.exp_name, scene)
        object_dict = np.load(f'/home/jovyan/users/lemeshko/scripts/gsam_result/scannet200/yolo/{scene_name}/infos.npy', allow_pickle=True).item()

        label2id = dataset.get_label_id()[0]
        num_instance = len(object_dict)

        pred_dict = {
            "pred_masks": np.zeros((total_point_num, num_instance), dtype=bool),
            "pred_score": np.ones(num_instance),
            "pred_classes": np.zeros(num_instance, dtype=np.int32)
        }

        print(f"Processing scene {scene_name} with {num_instance} objects")

        if args.batch_processing and num_instance > args.batch_size:
            object_items = list(enumerate(object_dict.items()))

            for batch_start in range(0, len(object_items), args.batch_size):
                batch_end = min(batch_start + args.batch_size, len(object_items))
                batch = object_items[batch_start:batch_end]

                print(f"Processing batch {batch_start // args.batch_size + 1}/{(len(object_items) + args.batch_size - 1) // args.batch_size}")

                batch_results = process_objects_batch(
                    batch, args, scene_name, label_text_features, descriptions,
                    label2id, total_point_num, logs_path, dataset
                )

                for result in batch_results:
                    if result is not None:
                        idx = result['idx']
                        pred_dict['pred_masks'][:, idx] = result['binary_mask']
                        pred_dict['pred_score'][idx] = result['prob_score']
                        pred_dict['pred_classes'][idx] = result['label_id']
        else:
            # Fallback: process objects one at a time without batching.
            for idx, (key, object_data) in enumerate(object_dict.items()):
                cropped_images, saved_images = get_cropped_images_parallel(
                    key, object_data, scene_name, args.preprocess, args.num_images,
                    num_workers=args.image_workers
                )

                if len(cropped_images) == 0:
                    continue

                features = batch_encode_images(args.model, cropped_images, args.device, batch_size=32)
                object_feature = np.mean(features, axis=0, keepdims=True)

                raw_similarity = np.dot(object_feature, label_text_features.T)
                exp_sim = np.exp(raw_similarity * 100)
                prob = exp_sim / np.sum(exp_sim, axis=1, keepdims=True)
                probs = np.max(prob, axis=0)
                max_label_id = np.argmax(probs)
                prob_score = probs[max_label_id]

                pred_dict['pred_score'][idx] = prob_score
                label_id = label2id[descriptions[max_label_id]]
                pred_dict['pred_classes'][idx] = label_id

                point_ids = object_data['mask']
                binary_mask = np.zeros(total_point_num, dtype=bool)
                binary_mask[list(point_ids)] = True
                pred_dict['pred_masks'][:, idx] = binary_mask

                if args.debug:
                    os.makedirs(logs_path, exist_ok=True)
                    cv2.imwrite(f'{logs_path}/{str(key).zfill(5)}_{dataset.get_label_id()[1][label_id]}({prob_score:.2f}).jpg', saved_images)

        np.savez(f'{pred_dir}/{args.seq_name}.npz', **pred_dict)
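
# Output format: each <scene>.npz stores pred_masks with shape
# (num_points, num_instances) as bool, pred_score with shape (num_instances,),
# and pred_classes with shape (num_instances,) as int32.
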
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='scannet')
    parser.add_argument('--debug', action="store_true")
    parser.add_argument('--exp_name', type=str, default='baseline')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--num_images', type=int, default=5)
    parser.add_argument('--num_workers', '-n', type=int, default=1)
    parser.add_argument('--job_id', '-i', type=int, default=0)

    parser.add_argument('--batch_processing', action="store_true", help="Process objects in batches")
    parser.add_argument('--batch_size', type=int, default=10, help="Number of objects per batch")
    parser.add_argument('--image_workers', type=int, default=4, help="Number of image-loading threads")

    args = parser.parse_args()
    args = update_args(args)

    model, preprocess = load_clip(args.device)
    args.model = model
    args.preprocess = preprocess

    # Timestamp kept for optional experiment naming (e.g. f'{args.exp_name}_{date}');
    # exp_name is currently used as-is.
    date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    main(args)
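
# Sharded launch example, as a sketch: main() splits the scene list as
# scenes[job_id::num_workers], so four parallel jobs could be started with:
#
#   for i in 0 1 2 3; do
#       python semantics/c_open-voc_query_optimized.py -n 4 -i $i &
#   done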