"""This script is used to generate the semantic labels for the objects in the scene.

For every predicted 3D instance, crops around its 2D detections are encoded with
CLIP and compared against the dataset's label text features; the best-matching
label and its softmax probability are stored next to the instance's point mask.
"""
import argparse
import os

import cv2
import numpy as np
import open_clip
import torch
from PIL import Image
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

from dataset.scannet import ScanNetDataset
from utils.config import get_dataset, get_args
from utils.config import update_args

# Number of progressively larger crops taken around each detection box.
LEVELS = 3

# Default input locations; the first two can be overridden from the command line.
DEFAULT_IMAGES_ROOT = '/home/jovyan/users/lemeshko/mmdetection3d/data/scannet/posed_images'
DEFAULT_SCENES_FILE = '/home/jovyan/users/bulat/workspace/3drec/Indoor/MaskClustering/splits/scannet200_subset.txt'
DEFAULT_PREDICTIONS_PATH = '/home/jovyan/users/bulat/workspace/3drec/Indoor/Grounded-SAM-2/results/gsam_result/scannet200/yolo/memory_150_classes_198'

# Temperature applied to cosine similarities before the softmax over labels.
LOGIT_SCALE = 100
# Maximum number of crops encoded by CLIP in one forward pass.
BATCH_SIZE = 32


def load_clip(device):
    """Load the OpenCLIP ViT-H-14 (laion2b_s32b_b79k) model onto ``device`` in eval mode.

    Returns:
        (model, preprocess): the CLIP model and its image preprocessing transform.
    """
    print('[INFO] loading CLIP model...')
    model, _, preprocess = open_clip.create_model_and_transforms("ViT-H-14", pretrained="laion2b_s32b_b79k")
    model.to(device)
    model.eval()
    print('[INFO]', ' finish loading CLIP model...')
    return model, preprocess


def box_multi_level(bbox, shape, level, expansion_ratio):
    """Grow ``bbox`` by ``level * expansion_ratio`` of its width/height on every side.

    Args:
        bbox: ``(left, top, right, bottom)`` pixel coordinates.
        shape: image shape; ``shape[0]`` is the height and ``shape[1]`` the width.
        level: expansion level; 0 returns the box itself (clipped to the image).
        expansion_ratio: fraction of the box size added per level.

    Returns:
        ``(left, top, right, bottom)`` clipped to ``[0, width] x [0, height]``.
    """
    left, top, right, bottom = bbox
    x_exp = int(abs(right - left) * expansion_ratio) * level
    y_exp = int(abs(bottom - top) * expansion_ratio) * level
    # Clip at every level, including 0: negative coordinates would otherwise
    # wrap around when the box is later used for numpy slicing.
    return (max(0, left - x_exp), max(0, top - y_exp),
            min(shape[1], right + x_exp), min(shape[0], bottom + y_exp))


def pad_into_square(image):
    """Centre a PIL image on a white square canvas whose side is its longer edge."""
    width, height = image.size
    new_size = max(width, height)
    new_image = Image.new("RGB", (new_size, new_size), (255, 255, 255))
    left = (new_size - width) // 2
    top = (new_size - height) // 2
    new_image.paste(image, (left, top))
    return new_image


def get_cropped_images(object, scene_id, preprocess, num_images=5, expansion_ratio=0.1, preloaded_images=None):
    """Build multi-level CLIP inputs for one object.

    For each of the first ``num_images`` entries of ``object['frames']`` the
    detection box is expanded ``LEVELS`` times (each level starting from the
    original box), cropped from the preloaded BGR image, padded to a white
    square and passed through ``preprocess``.

    Args:
        object: dict with a ``'frames'`` list; each frame has ``'frame_id'`` and ``'bbox'``.
        scene_id: unused; kept for interface compatibility.
        preprocess: CLIP image transform returning a tensor.
        num_images: maximum number of frames used for this object.
        expansion_ratio: per-level box expansion, see ``box_multi_level``.
        preloaded_images: mapping ``frame_id -> BGR image`` (as read by ``cv2.imread``).

    Returns:
        (crops, preview): ``crops`` is the stacked tensor of preprocessed crops and
        ``preview`` a horizontal strip of 64x64 BGR thumbnails. When no non-empty
        crop can be taken, returns ``(torch.empty(0), None)``.
    """
    crops = []
    thumbnails = []
    for frame in object['frames'][:num_images]:
        image = preloaded_images[frame['frame_id']]
        base_box = tuple(int(v) for v in frame['bbox'])
        for level in range(LEVELS):
            # Expand from the original box every time; feeding back the previous
            # level's output would compound the expansion.
            x1, y1, x2, y2 = box_multi_level(base_box, image.shape, level, expansion_ratio)
            if x2 <= x1 or y2 <= y1:
                continue  # degenerate box: cv2.cvtColor would fail on an empty crop
            pil_image = pad_into_square(Image.fromarray(cv2.cvtColor(image[y1:y2, x1:x2], cv2.COLOR_BGR2RGB)))
            crops.append(preprocess(pil_image))
            thumbnails.append(np.asarray(pil_image.resize((64, 64))))
    if not crops:
        return torch.empty(0), None
    # Thumbnails are RGB; flip back to BGR (contiguous) for cv2.imwrite.
    preview = np.ascontiguousarray(np.concatenate(thumbnails, axis=1)[..., ::-1])
    return torch.stack(crops), preview


def preload_images(scene, frame_ids, image_root=DEFAULT_IMAGES_ROOT):
    """Read the posed frames of ``scene`` into a ``frame_id -> BGR image`` dict.

    Images are read from ``{image_root}/{scene}/{frame_id zero-padded to 5}.jpg``.

    Raises:
        FileNotFoundError: if a frame cannot be read (``cv2.imread`` returns None
            rather than raising, which would otherwise fail much later).
    """
    preloaded_images = {}
    for frame_id in frame_ids:
        img_path = os.path.join(image_root, scene, f'{str(frame_id).zfill(5)}.jpg')
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f'Could not read image {img_path}')
        preloaded_images[frame_id] = image
    return preloaded_images


def _softmax(logits, axis=-1):
    """Numerically stable softmax along ``axis`` (subtracts the max before exp)."""
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)


def custom_probs(feature, label_text_features):
    """Classify an object by summing per-crop similarities over all crops.

    Args:
        feature: 2-D array ``(num_crops, dim)`` of crop embeddings.
        label_text_features: 2-D array ``(num_labels, dim)`` of label embeddings.

    Returns:
        (prob, label_index): softmax probability of the best label and its index.
    """
    raw_similarity = np.sum(feature @ label_text_features.T, axis=0, keepdims=True)
    prob = _softmax(raw_similarity, axis=1)
    probs = np.max(prob, axis=0)
    max_label_id = np.argmax(probs)
    return probs[max_label_id], max_label_id


def _encode_crops(model, crops, device, batch_size=BATCH_SIZE):
    """Encode ``crops`` with CLIP in batches of at most ``batch_size``.

    Returns:
        numpy array ``(num_crops, dim)`` of L2-normalised image features.
    """
    features = []
    with torch.no_grad():
        for batch in torch.split(crops, batch_size):
            batch_features = model.encode_image(batch.to(device)).float()
            batch_features /= batch_features.norm(dim=-1, keepdim=True)
            features.append(batch_features.cpu().numpy())
    return np.concatenate(features, axis=0)


def _classify(features, label_text_features, logit_scale=LOGIT_SCALE):
    """Pick the label whose text feature best matches the mean crop feature.

    Returns:
        (prob, label_index): scaled-softmax probability of the best label and its index.
    """
    object_feature = np.mean(features, axis=0, keepdims=True)
    raw_similarity = np.dot(object_feature, label_text_features.T)
    prob = _softmax(raw_similarity * logit_scale, axis=1)[0]
    max_label_id = int(np.argmax(prob))
    return prob[max_label_id], max_label_id


def main(args):
    """Label every predicted instance of each scene assigned to this worker.

    Scenes are read from ``args.scenes_file`` (default ``DEFAULT_SCENES_FILE``)
    and sharded as ``scenes[job_id::num_workers]``. For each scene,
    ``{path_to_predictions}/{scene}/infos.npy`` is loaded and
    ``data/prediction/{config}/{exp_name}/{scene}.npz`` is written with
    ``pred_masks``, ``pred_score`` and ``pred_classes``. Scenes whose output
    already exists are skipped.
    """
    scenes_file = getattr(args, 'scenes_file', DEFAULT_SCENES_FILE)
    images_root = getattr(args, 'images_root', DEFAULT_IMAGES_ROOT)
    # ndmin=1: a split file with a single scene would otherwise load as a 0-d array.
    scenes = np.loadtxt(scenes_file, dtype=str, ndmin=1)
    scenes = scenes[args.job_id::args.num_workers]
    pred_dir = os.path.join('data/prediction', args.config, args.exp_name)
    os.makedirs(pred_dir, exist_ok=True)

    for scene in tqdm(scenes):
        args.seq_name = scene
        out_path = f'{pred_dir}/{args.seq_name}.npz'
        if os.path.exists(out_path):
            continue

        dataset = ScanNetDataset(scene)
        total_point_num = dataset.get_scene_points().shape[0]
        label_features_dict = dataset.get_label_features()
        label_text_features = np.stack(list(label_features_dict.values()))
        descriptions = list(label_features_dict.keys())
        scene_name = dataset.seq_name
        logs_path = f'logs/{args.exp_name}/{scene}'
        # NOTE: allow_pickle executes arbitrary code from the file; infos.npy must be trusted.
        object_dict = np.load(os.path.join(args.path_to_predictions, scene_name, 'infos.npy'), allow_pickle=True).item()
        label_ids = dataset.get_label_id()
        label2id, id2label = label_ids[0], label_ids[1]
        if args.debug:
            print(label2id)
            print(id2label)

        num_instance = len(object_dict)
        pred_dict = {
            "pred_masks": np.zeros((total_point_num, num_instance), dtype=bool),
            "pred_score": np.ones(num_instance),
            "pred_classes": np.zeros(num_instance, dtype=np.int32),
        }

        # Only the first num_images frames of each object are cropped, so only load those.
        frame_ids = {frame['frame_id']
                     for obj in object_dict.values()
                     for frame in obj['frames'][:args.num_images]}
        preloaded_images = preload_images(scene, list(frame_ids), image_root=images_root)
        print(scene_name)

        for idx, (key, obj) in enumerate(object_dict.items()):
            binary_mask = np.zeros(total_point_num, dtype=bool)
            binary_mask[list(obj['mask'])] = True
            pred_dict['pred_masks'][:, idx] = binary_mask

            croped_images, saved_images = get_cropped_images(
                obj, scene_name, args.preprocess, args.num_images, preloaded_images=preloaded_images)
            if len(croped_images) == 0:
                # No usable crop: keep the mask, leave class 0 and mark it with zero confidence.
                pred_dict['pred_score'][idx] = 0.0
                print(f'[WARN] {scene_name}: object {key} has no valid crops; class 0, score 0')
                continue

            features = _encode_crops(args.model, croped_images, args.device)
            prob, max_label_id = _classify(features, label_text_features)
            label_id = label2id[descriptions[max_label_id]]
            pred_dict['pred_score'][idx] = prob
            pred_dict['pred_classes'][idx] = label_id

            if args.debug:
                os.makedirs(logs_path, exist_ok=True)
                label_name = id2label[label_id]
                cv2.imwrite(f'{logs_path}/{str(key).zfill(5)}_{label_name}({prob:.2f}).jpg', saved_images)
                print(key, label_id, label_name, "confidence:", prob)

        np.savez(out_path, **pred_dict)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='scannet')
    parser.add_argument('--debug', action="store_true")
    parser.add_argument('--exp_name', type=str, default='baseline')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--num_images', type=int, default=5)
    parser.add_argument('--num_workers', '-n', type=int, default=1,
                        help='number of shards the scene list is split into')
    parser.add_argument('--path_to_predictions', type=str, default=DEFAULT_PREDICTIONS_PATH)
    parser.add_argument('--job_id', '-i', type=int, default=0,
                        help='index of the shard processed by this run')
    parser.add_argument('--scenes_file', type=str, default=DEFAULT_SCENES_FILE)
    parser.add_argument('--images_root', type=str, default=DEFAULT_IMAGES_ROOT)
    args = parser.parse_args()
    args = update_args(args)
    model, preprocess = load_clip(args.device)
    args.model = model
    args.preprocess = preprocess
    main(args)