'''
This script extracts open-vocabulary visual features for each mask, following OpenMask3D.
For each mask, we crop the image at CROP_SCALES=3 scales based on the mask. We then
extract visual features with a CLIP model and average them to obtain the mask feature.
'''
import argparse
import json
import os
import sys

import cv2
import numpy as np
import open_clip
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

WD = None
_MK_PATH = None


def load(MK_PATH: str) -> None:
    '''Register the repo that provides dataset.scannet.WildDataset and import it.'''
    global _MK_PATH
    _MK_PATH = MK_PATH
    if MK_PATH not in sys.path:
        sys.path.insert(0, MK_PATH)
    from dataset.scannet import WildDataset
    global WD
    WD = WildDataset


def get_dataset(seq_name, root):
    dataset = WD(seq_name, root=root)
    return dataset


CROP_SCALES = 3  # follow OpenMask3D


class CroppedImageDataset(Dataset):
    def __init__(self, seq_name_list, frame_id_list, mask_id_list,
                 rgb_path_list, segmentation_path_list, preprocess):
        '''
        Given a list of masks, compute the open-vocabulary features for each mask.

        Args:
            seq_name_list: sequence name for each mask
            frame_id_list: frame id for each mask
            mask_id_list: mask id for each mask
            rgb_path_list: rgb path for each mask
            segmentation_path_list: segmentation path for each mask
            preprocess: image preprocessing function
        '''
        self.seq_name_list = seq_name_list
        self.frame_id_list = frame_id_list
        self.mask_id_list = mask_id_list
        self.preprocess = preprocess
        self.rgb_path_list = rgb_path_list
        self.segmentation_path_list = segmentation_path_list

    def __len__(self):
        return len(self.mask_id_list)

    def __getitem__(self, idx):
        def get_cropped_image(mask, rgb):
            '''Crop the rgb image at CROP_SCALES scales around the mask.'''
            def mask2box_multi_level(mask, level, expansion_ratio):
                # Tight bounding box of the mask, expanded by level * expansion_ratio
                # on each side and clamped to the image bounds.
                pos = np.where(mask)
                top = np.min(pos[0])
                bottom = np.max(pos[0])
                left = np.min(pos[1])
                right = np.max(pos[1])
                if level == 0:
                    return left, top, right, bottom
                shape = mask.shape
                x_exp = int(abs(right - left) * expansion_ratio) * level
                y_exp = int(abs(bottom - top) * expansion_ratio) * level
                return (max(0, left - x_exp), max(0, top - y_exp),
                        min(shape[1], right + x_exp), min(shape[0], bottom + y_exp))

            def crop_image(rgb, mask):
                multiscale_cropped_images = []
                for level in range(CROP_SCALES):
                    left, top, right, bottom = mask2box_multi_level(mask, level, 0.1)
                    cropped_image = rgb[top:bottom, left:right].copy()
                    multiscale_cropped_images.append(cropped_image)
                return multiscale_cropped_images

            # Resize the mask to the rgb resolution before computing the boxes.
            mask = cv2.resize(mask.astype(np.uint8), (rgb.shape[1], rgb.shape[0]),
                              interpolation=cv2.INTER_NEAREST)
            multiscale_cropped_images = crop_image(rgb, mask)
            return multiscale_cropped_images

        def pad_into_square(image):
            # Pad the crop into a white square so CLIP preprocessing does not distort it.
            width, height = image.size
            new_size = max(width, height)
            new_image = Image.new("RGB", (new_size, new_size), (255, 255, 255))
            left = (new_size - width) // 2
            top = (new_size - height) // 2
            new_image.paste(image, (left, top))
            return new_image

        seq_name = self.seq_name_list[idx]
        frame_id = self.frame_id_list[idx]
        mask_id = self.mask_id_list[idx]
        rgb_path = self.rgb_path_list[idx]
        segmentation_path = self.segmentation_path_list[idx]

        rgb_image = cv2.imread(rgb_path)
        rgb_image = cv2.cvtColor(rgb_image, cv2.COLOR_BGR2RGB)
        segmentation_image = cv2.imread(segmentation_path, cv2.IMREAD_UNCHANGED)
        mask = (segmentation_image == mask_id)

        cropped_images = get_cropped_image(mask, np.array(rgb_image))
        input_images = [self.preprocess(pad_into_square(Image.fromarray(cropped_image)))
                        for cropped_image in cropped_images]
        input_images = torch.stack(input_images)
        return input_images, seq_name, frame_id, mask_id


def main(model, preprocess, seq_name, root):
    seq_names = [seq_name]
    # Per-mask lists: one entry per representative mask across all sequences.
    seq_name_list, frame_id_list, mask_id_list, rgb_path_list, segmentation_path_list = [], [], [], [], []
    feature_dict = {}
    for seq_name in seq_names:
        dataset = get_dataset(seq_name, root)
        if not os.path.exists(os.path.join(dataset.object_dict_dir, 'wild', 'object_dict.npy')):
            continue
        object_dict = np.load(f'{dataset.object_dict_dir}/wild/object_dict.npy', allow_pickle=True).item()
        for key, value in object_dict.items():
            mask_list = value['repre_mask_list']
            if len(mask_list) == 0:
                continue
            for mask_info in mask_list:
                seq_name_list.append(seq_name)
                frame_id = mask_info[0]
                frame_id_list.append(frame_id)
                mask_id_list.append(mask_info[1])
                rgb_path, segmentation_path = dataset.get_frame_path(frame_id)
                rgb_path_list.append(rgb_path)
                segmentation_path_list.append(segmentation_path)
        feature_dict[seq_name] = {}

    dataloader = DataLoader(
        CroppedImageDataset(seq_name_list, frame_id_list, mask_id_list,
                            rgb_path_list, segmentation_path_list, preprocess),
        batch_size=64, shuffle=False, num_workers=16)

    print('[INFO] extracting features')
    for images, batch_seq_names, frame_ids, mask_ids in tqdm(dataloader):
        # Each sample stacks CROP_SCALES crops, so flatten them into one CLIP batch.
        images = images.reshape(-1, 3, 224, 224)
        image_input = images.cuda()
        with torch.no_grad():
            image_features = model.encode_image(image_input).float()
            image_features /= image_features.norm(dim=-1, keepdim=True)
        image_features = image_features.cpu().numpy()
        for i in range(len(image_features) // CROP_SCALES):
            # Average the CROP_SCALES per-crop features into a single mask feature.
            feature_dict[batch_seq_names[i]][f'{frame_ids[i]}_{mask_ids[i]}'] = \
                image_features[CROP_SCALES * i:CROP_SCALES * (i + 1)].mean(axis=0)
    print('[INFO] finished extracting features')

    for seq_name in seq_names:
        dataset = get_dataset(seq_name, root)
        if seq_name in feature_dict:
            np.save(os.path.join(dataset.object_dict_dir, 'wild', 'open-vocabulary_features.npy'),
                    feature_dict[seq_name])
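

# Usage sketch: a minimal entry point, not part of the original script. The CLIP
# backbone name and checkpoint tag below are assumptions (any open_clip model whose
# eval preprocess produces 224x224 inputs works with the reshape in main), and the
# --mk_path/--seq_name/--root argument names are hypothetical, for illustration only.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--mk_path', required=True,
                        help='repo path that provides dataset/scannet.py (assumed layout)')
    parser.add_argument('--seq_name', required=True)
    parser.add_argument('--root', required=True)
    args = parser.parse_args()

    # Make dataset.scannet.WildDataset importable before building datasets.
    load(args.mk_path)

    # create_model_and_transforms returns (model, train_preprocess, eval_preprocess);
    # the eval preprocess resizes each padded square crop to the 224x224 CLIP input.
    model, _, preprocess = open_clip.create_model_and_transforms(
        'ViT-L-14', pretrained='laion2b_s32b_b82k')
    model = model.cuda().eval()

    main(model, preprocess, args.seq_name, args.root)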