# zoo3d/exts/ov_features.py
'''
This script extracts open-vocabulary visual features for each mask, following OpenMask3D.
For each mask, the image is cropped at CROP_SCALES=3 scales around the mask's bounding box.
Each crop is encoded with a CLIP model, and the per-crop features are averaged to form the mask feature.
'''
import open_clip
import os
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import cv2
import argparse
import json
import sys
WD = None
_MK_PATH = None
def load(MK_PATH: str) -> None:
    '''Add MK_PATH to sys.path and import WildDataset from it.'''
    global _MK_PATH
    _MK_PATH = MK_PATH
    if MK_PATH not in sys.path:
        sys.path.insert(0, MK_PATH)
    from dataset.scannet import WildDataset
    global WD
    WD = WildDataset
def get_dataset(seq_name, root):
    dataset = WD(seq_name, root=root)
    return dataset
CROP_SCALES = 3 # follow OpenMask3D
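# With CROP_SCALES = 3 and expansion_ratio = 0.1 (see mask2box_multi_level below),
# level 0 is the tight mask bounding box, level 1 expands it by ~10% of its width/height,
# and level 2 by ~20%, each clamped to the image bounds.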
class CroppedImageDataset(Dataset):
    def __init__(self, seq_name_list, frame_id_list, mask_id_list, rgb_path_list, segmentation_path_list, preprocess):
        '''
        Given a list of masks, we calculate the open-vocabulary features for each mask.
        Args:
            seq_name_list: sequence name for each mask
            frame_id_list: frame id for each mask
            mask_id_list: mask id for each mask
            rgb_path_list: rgb path for each mask
            segmentation_path_list: segmentation path for each mask
            preprocess: image preprocessing function
        '''
        self.seq_name_list = seq_name_list
        self.frame_id_list = frame_id_list
        self.mask_id_list = mask_id_list
        self.preprocess = preprocess
        self.rgb_path_list = rgb_path_list
        self.segmentation_path_list = segmentation_path_list

    def __len__(self):
        return len(self.mask_id_list)
    def __getitem__(self, idx):
        def get_cropped_image(mask, rgb):
            '''
            Given a mask and an rgb image, crop the image at CROP_SCALES scales based on the mask.
            '''
            def mask2box_multi_level(mask, level, expansion_ratio):
                '''Bounding box of the mask, expanded by level * expansion_ratio and clamped to the image.'''
                pos = np.where(mask)
                top = np.min(pos[0])
                bottom = np.max(pos[0])
                left = np.min(pos[1])
                right = np.max(pos[1])
                if level == 0:
                    return left, top, right, bottom
                shape = mask.shape
                x_exp = int(abs(right - left) * expansion_ratio) * level
                y_exp = int(abs(bottom - top) * expansion_ratio) * level
                return max(0, left - x_exp), max(0, top - y_exp), min(shape[1], right + x_exp), min(shape[0], bottom + y_exp)

            def crop_image(rgb, mask):
                multiscale_cropped_images = []
                for level in range(CROP_SCALES):
                    left, top, right, bottom = mask2box_multi_level(mask, level, 0.1)
                    cropped_image = rgb[top:bottom, left:right].copy()
                    multiscale_cropped_images.append(cropped_image)
                return multiscale_cropped_images

            # Resize the mask to the rgb resolution before computing the crop boxes.
            mask = cv2.resize(mask.astype(np.uint8), (rgb.shape[1], rgb.shape[0]), interpolation=cv2.INTER_NEAREST)
            multiscale_cropped_images = crop_image(rgb, mask)
            return multiscale_cropped_images

        def pad_into_square(image):
            '''Pad the crop to a square white canvas so the aspect ratio is preserved when resized to the CLIP input size.'''
            width, height = image.size
            new_size = max(width, height)
            new_image = Image.new("RGB", (new_size, new_size), (255, 255, 255))
            left = (new_size - width) // 2
            top = (new_size - height) // 2
            new_image.paste(image, (left, top))
            return new_image

        seq_name = self.seq_name_list[idx]
        frame_id = self.frame_id_list[idx]
        mask_id = self.mask_id_list[idx]
        rgb_path = self.rgb_path_list[idx]
        segmentation_path = self.segmentation_path_list[idx]
        rgb_image = cv2.imread(rgb_path)
        rgb_image = cv2.cvtColor(rgb_image, cv2.COLOR_BGR2RGB)
        segmentation_image = cv2.imread(segmentation_path, cv2.IMREAD_UNCHANGED)
        mask = (segmentation_image == mask_id)
        cropped_images = get_cropped_image(mask, np.array(rgb_image))
        input_images = [self.preprocess(pad_into_square(Image.fromarray(cropped_image))) for cropped_image in cropped_images]
        input_images = torch.stack(input_images)
        return input_images, seq_name, frame_id, mask_id
def main(model, preprocess, seq_name, root):
    target_seq_names = [seq_name]
    # Per-mask metadata gathered from the object dictionaries of the target sequences.
    seq_name_list, frame_id_list, mask_id_list, rgb_path_list, segmentation_path_list = [], [], [], [], []
    feature_dict = {}
    for seq_name in target_seq_names:
        dataset = get_dataset(seq_name, root)
        if not os.path.exists(os.path.join(dataset.object_dict_dir, 'wild', 'object_dict.npy')):
            continue
        object_dict = np.load(f'{dataset.object_dict_dir}/wild/object_dict.npy', allow_pickle=True).item()
        for key, value in object_dict.items():
            mask_list = value['repre_mask_list']
            if len(mask_list) == 0:
                continue
            for mask_info in mask_list:
                seq_name_list.append(seq_name)
                frame_id = mask_info[0]
                frame_id_list.append(frame_id)
                mask_id_list.append(mask_info[1])
                rgb_path, segmentation_path = dataset.get_frame_path(frame_id)
                rgb_path_list.append(rgb_path)
                segmentation_path_list.append(segmentation_path)
        feature_dict[seq_name] = {}
    dataloader = DataLoader(CroppedImageDataset(seq_name_list, frame_id_list, mask_id_list, rgb_path_list, segmentation_path_list, preprocess),
                            batch_size=64, shuffle=False, num_workers=16)
    print('[INFO] extracting features')
    for images, seq_names, frame_ids, mask_ids in tqdm(dataloader):
        # (batch, CROP_SCALES, 3, H, W) -> (batch * CROP_SCALES, 3, H, W)
        images = images.reshape(-1, 3, 224, 224)
        image_input = images.cuda()
        with torch.no_grad():
            image_features = model.encode_image(image_input).float()
            image_features /= image_features.norm(dim=-1, keepdim=True)
            image_features = image_features.cpu().numpy()
        # Average the CROP_SCALES per-crop features into a single feature per mask.
        for i in range(len(image_features) // CROP_SCALES):
            feature_dict[seq_names[i]][f'{frame_ids[i]}_{mask_ids[i]}'] = image_features[CROP_SCALES * i:CROP_SCALES * (i + 1)].mean(axis=0)
    print('[INFO] finish extracting features')
    for seq_name in target_seq_names:
        dataset = get_dataset(seq_name, root)
        if seq_name in feature_dict:
            np.save(os.path.join(dataset.object_dict_dir, 'wild', 'open-vocabulary_features.npy'), feature_dict[seq_name])
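
# Minimal usage sketch (not part of the original script; the CLIP variant, checkpoint
# tag, and argument names below are assumptions for illustration). It shows how
# load() and main() could be wired together from the command line.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Extract open-vocabulary CLIP features per mask.')
    parser.add_argument('--seq_name', required=True, help='sequence to process')
    parser.add_argument('--root', required=True, help='dataset root directory')
    parser.add_argument('--mk_path', required=True, help='directory containing dataset/scannet.py (WildDataset)')
    args = parser.parse_args()

    load(args.mk_path)
    # open_clip.create_model_and_transforms returns (model, preprocess_train, preprocess_val);
    # the validation transform is the one to use for feature extraction.
    model, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='laion2b_s32b_b82k')
    model = model.cuda().eval()
    main(model, preprocess, args.seq_name, args.root)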