zoo3d/MaskClustering/semantics/get_open-voc_features.py
'''
This script extracts open-vocabulary visual features for each mask, following OpenMask3D.
For each mask, we crop the image at CROP_SCALES = 3 scales around the mask's bounding box,
extract a visual feature for each crop with a CLIP model, and average these features to
obtain the mask feature.
'''
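# Inputs:  <object_dict_dir>/<config>/object_dict.npy, produced by the earlier mask-clustering
#          stage, listing the representative (frame_id, mask_id) pairs for every object.
# Outputs: <object_dict_dir>/<config>/open-vocabulary_features.npy, one file per sequence,
#          a dict mapping '{frame_id}_{mask_id}' to a single averaged CLIP image feature
#          (1024-d for the ViT-H-14 model loaded below).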
import os

import cv2
import numpy as np
import open_clip
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from utils.config import get_args, get_dataset

CROP_SCALES = 3  # follow OpenMask3D
args = get_args()

class CroppedImageDataset(Dataset):
    def __init__(self, seq_name_list, frame_id_list, mask_id_list, rgb_path_list, segmentation_path_list, preprocess):
        '''
        Given a list of masks, we calculate the open-vocabulary features for each mask.
        Args:
            seq_name_list: sequence name for each mask
            frame_id_list: frame id for each mask
            mask_id_list: mask id for each mask
            rgb_path_list: rgb path for each mask
            segmentation_path_list: segmentation path for each mask
            preprocess: image preprocessing function
        '''
        self.seq_name_list = seq_name_list
        self.frame_id_list = frame_id_list
        self.mask_id_list = mask_id_list
        self.preprocess = preprocess
        self.rgb_path_list = rgb_path_list
        self.segmentation_path_list = segmentation_path_list

    def __len__(self):
        return len(self.mask_id_list)
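    # __getitem__ returns a stack of CROP_SCALES preprocessed crops for one mask, together
    # with the identifiers needed to key its averaged feature later.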
    def __getitem__(self, idx):
        def get_cropped_image(mask, rgb):
            '''
            Given a mask and an RGB image, crop the image at CROP_SCALES scales based on the mask's bounding box.
            '''
            def mask2box_multi_level(mask, level, expansion_ratio):
                # Tight bounding box of the mask, expanded by level * expansion_ratio on each side.
                pos = np.where(mask)
                top = np.min(pos[0])
                bottom = np.max(pos[0])
                left = np.min(pos[1])
                right = np.max(pos[1])
                if level == 0:
                    return left, top, right, bottom
                shape = mask.shape
                x_exp = int(abs(right - left) * expansion_ratio) * level
                y_exp = int(abs(bottom - top) * expansion_ratio) * level
                return max(0, left - x_exp), max(0, top - y_exp), min(shape[1], right + x_exp), min(shape[0], bottom + y_exp)

            def crop_image(rgb, mask):
                multiscale_cropped_images = []
                for level in range(CROP_SCALES):
                    left, top, right, bottom = mask2box_multi_level(mask, level, 0.1)
                    cropped_image = rgb[top:bottom, left:right].copy()
                    multiscale_cropped_images.append(cropped_image)
                return multiscale_cropped_images

            # Resize the mask to the RGB resolution (the segmentation map may differ in size)
            # before computing the crop boxes.
            mask = cv2.resize(mask.astype(np.uint8), (rgb.shape[1], rgb.shape[0]), interpolation=cv2.INTER_NEAREST)
            multiscale_cropped_images = crop_image(rgb, mask)
            return multiscale_cropped_images
        def pad_into_square(image):
            # Pad the crop to a square on a white background so the aspect ratio is preserved
            # when the CLIP preprocess later resizes it.
            width, height = image.size
            new_size = max(width, height)
            new_image = Image.new("RGB", (new_size, new_size), (255, 255, 255))
            left = (new_size - width) // 2
            top = (new_size - height) // 2
            new_image.paste(image, (left, top))
            return new_image
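        # Load the RGB frame and its mask-id segmentation map, then build CROP_SCALES square
        # crops of the selected mask and preprocess them for CLIP.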
        seq_name = self.seq_name_list[idx]
        frame_id = self.frame_id_list[idx]
        mask_id = self.mask_id_list[idx]
        rgb_path = self.rgb_path_list[idx]
        segmentation_path = self.segmentation_path_list[idx]

        rgb_image = cv2.imread(rgb_path)
        rgb_image = cv2.cvtColor(rgb_image, cv2.COLOR_BGR2RGB)
        segmentation_image = cv2.imread(segmentation_path, cv2.IMREAD_UNCHANGED)
        mask = (segmentation_image == mask_id)

        cropped_images = get_cropped_image(mask, np.array(rgb_image))
        input_images = [self.preprocess(pad_into_square(Image.fromarray(cropped_image))) for cropped_image in cropped_images]
        input_images = torch.stack(input_images)
        return input_images, seq_name, frame_id, mask_id
def load_clip():
    print('[INFO] loading CLIP model...')
    model, _, preprocess = open_clip.create_model_and_transforms("ViT-H-14", pretrained="laion2b_s32b_b79k")
    model.cuda()
    model.eval()
    print('[INFO] finished loading CLIP model')
    return model, preprocess
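# main() runs in two phases: it first gathers every representative (frame_id, mask_id) pair
# listed in object_dict.npy for each requested sequence, then batch-encodes the multi-scale
# crops with CLIP and saves one averaged feature per mask for each sequence.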
def main():
    model, preprocess = load_clip()
    seq_name_list, frame_id_list, mask_id_list, rgb_path_list, segmentation_path_list = [], [], [], [], []
    feature_dict = {}

    # Gather every representative mask of every object in every requested sequence.
    for seq_name in args.seq_name_list.split('+'):
        args.seq_name = seq_name
        dataset = get_dataset(args)
        object_dict = np.load(f'{dataset.object_dict_dir}/{args.config}/object_dict.npy', allow_pickle=True).item()
        for key, value in object_dict.items():
            mask_list = value['repre_mask_list']
            if len(mask_list) == 0:
                continue
            for mask_info in mask_list:
                seq_name_list.append(seq_name)
                frame_id = mask_info[0]
                frame_id_list.append(frame_id)
                mask_id_list.append(mask_info[1])
                rgb_path, segmentation_path = dataset.get_frame_path(frame_id)
                rgb_path_list.append(rgb_path)
                segmentation_path_list.append(segmentation_path)
        feature_dict[seq_name] = {}

    dataloader = DataLoader(
        CroppedImageDataset(seq_name_list, frame_id_list, mask_id_list, rgb_path_list, segmentation_path_list, preprocess),
        batch_size=64, shuffle=False, num_workers=16)
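    # Each batch from the loader has shape [B, CROP_SCALES, 3, 224, 224]; the crops are
    # flattened for encoding and re-grouped per mask afterwards.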
    print('[INFO] extracting features')
    for images, seq_names, frame_ids, mask_ids in tqdm(dataloader):
        images = images.reshape(-1, 3, 224, 224)
        image_input = images.cuda()
        with torch.no_grad():
            image_features = model.encode_image(image_input).float()
            image_features /= image_features.norm(dim=-1, keepdim=True)
            image_features = image_features.cpu().numpy()
        for i in range(len(image_features) // CROP_SCALES):
            # Average the CROP_SCALES per-crop features into a single feature for this mask.
            feature_dict[seq_names[i]][f'{frame_ids[i]}_{mask_ids[i]}'] = image_features[CROP_SCALES*i:CROP_SCALES*(i+1)].mean(axis=0)
    print('[INFO] finished extracting features')
    # Save one feature dict per sequence next to its object_dict.
    for seq_name in args.seq_name_list.split('+'):
        args.seq_name = seq_name
        dataset = get_dataset(args)
        if seq_name in feature_dict:
            np.save(os.path.join(dataset.object_dict_dir, f'{args.config}/open-vocabulary_features.npy'), feature_dict[seq_name])
if __name__ == '__main__':
    main()
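# Downstream usage sketch (not part of this script): the saved features can be scored against
# CLIP text embeddings for open-vocabulary retrieval. The prompt and file path below are
# illustrative placeholders, not fixed by this repository.
#
#   import numpy as np
#   import open_clip
#   import torch
#
#   model, _, _ = open_clip.create_model_and_transforms("ViT-H-14", pretrained="laion2b_s32b_b79k")
#   tokenizer = open_clip.get_tokenizer("ViT-H-14")
#   features = np.load('path/to/open-vocabulary_features.npy', allow_pickle=True).item()
#
#   with torch.no_grad():
#       text_feature = model.encode_text(tokenizer(['a photo of a chair'])).float()
#       text_feature /= text_feature.norm(dim=-1, keepdim=True)
#   text_feature = text_feature.squeeze(0).numpy()
#
#   # Dot-product score between the text embedding and each mask feature.
#   scores = {key: float(feat @ text_feature) for key, feat in features.items()}
#   best_mask = max(scores, key=scores.get)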