File size: 6,738 Bytes
4eeefd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
'''
    This script extracts open-vocabulary visual features for each mask following OpenMask3D.
    For each mask, we crop the image with CROP_SCALES=3 scales based on the mask.
    Then we extract the visual features using CLIP model and average these features as the mask feature.
'''

import open_clip
import os
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import cv2
import argparse
import json
import sys

# Module-global handle to the project's WildDataset class (bound by load()).
WD = None
# Remembered project root path, kept for debugging/inspection.
_MK_PATH = None


def load(MK_PATH: str) -> None:
    """Register the project root on sys.path and bind the dataset class.

    Prepends ``MK_PATH`` to ``sys.path`` (only once), records it in
    ``_MK_PATH``, then imports ``dataset.scannet.WildDataset`` into the
    module-global ``WD`` so that ``get_dataset`` can construct it later.
    """
    global _MK_PATH, WD
    _MK_PATH = MK_PATH
    if MK_PATH not in sys.path:
        sys.path.insert(0, MK_PATH)
    # Deferred import: only resolvable once MK_PATH is on sys.path.
    from dataset.scannet import WildDataset
    WD = WildDataset

def get_dataset(seq_name, root):
    """Construct a WildDataset for *seq_name*; requires load() to have run."""
    return WD(seq_name, root=root)



CROP_SCALES = 3  # number of multi-scale crops taken per mask, following OpenMask3D

class CroppedImageDataset(Dataset):
    """Dataset yielding CROP_SCALES preprocessed square crops per mask.

    Each item loads an RGB frame and its segmentation image from disk,
    selects the pixels belonging to one mask id, crops the RGB image at
    CROP_SCALES progressively expanded bounding boxes, pads each crop to a
    white square, and applies the CLIP preprocessing transform.
    """

    def __init__(self, seq_name_list, frame_id_list, mask_id_list, rgb_path_list, segmentation_path_list, preprocess):
        """Store per-mask metadata; all lists are parallel (one entry per mask).

        Args:
            seq_name_list: sequence name for each mask
            frame_id_list: frame id for each mask
            mask_id_list: mask id for each mask
            rgb_path_list: rgb path for each mask
            segmentation_path_list: segmentation path for each mask
            preprocess: image preprocessing function (e.g. CLIP transform)
        """
        self.seq_name_list = seq_name_list
        self.frame_id_list = frame_id_list
        self.mask_id_list = mask_id_list
        self.preprocess = preprocess
        self.rgb_path_list = rgb_path_list
        self.segmentation_path_list = segmentation_path_list

    def __len__(self):
        return len(self.mask_id_list)

    @staticmethod
    def _expanded_box(mask, level, expansion_ratio):
        """Bounding box of *mask*, grown by level * ratio per side (clamped).

        Level 0 returns the tight box; higher levels expand it by
        ``level * expansion_ratio`` of the box extent in each direction,
        clamped to the image bounds.
        """
        rows, cols = np.where(mask)
        top, bottom = np.min(rows), np.max(rows)
        left, right = np.min(cols), np.max(cols)
        if level == 0:
            return left, top, right, bottom
        height, width = mask.shape
        x_exp = int(abs(right - left) * expansion_ratio) * level
        y_exp = int(abs(bottom - top) * expansion_ratio) * level
        return (max(0, left - x_exp), max(0, top - y_exp),
                min(width, right + x_exp), min(height, bottom + y_exp))

    @classmethod
    def _multiscale_crops(cls, mask, rgb):
        """Resize *mask* to the RGB resolution and return CROP_SCALES crops."""
        mask = cv2.resize(mask.astype(np.uint8), (rgb.shape[1], rgb.shape[0]),
                          interpolation=cv2.INTER_NEAREST)
        crops = []
        for level in range(CROP_SCALES):
            left, top, right, bottom = cls._expanded_box(mask, level, 0.1)
            crops.append(rgb[top:bottom, left:right].copy())
        return crops

    @staticmethod
    def _pad_to_square(image):
        """Center-paste a PIL image onto a white square canvas and return it."""
        width, height = image.size
        side = max(width, height)
        canvas = Image.new("RGB", (side, side), (255, 255, 255))
        canvas.paste(image, ((side - width) // 2, (side - height) // 2))
        return canvas

    def __getitem__(self, idx):
        seq_name = self.seq_name_list[idx]
        frame_id = self.frame_id_list[idx]
        mask_id = self.mask_id_list[idx]

        rgb_image = cv2.imread(self.rgb_path_list[idx])
        rgb_image = cv2.cvtColor(rgb_image, cv2.COLOR_BGR2RGB)

        segmentation_image = cv2.imread(self.segmentation_path_list[idx], cv2.IMREAD_UNCHANGED)
        mask = (segmentation_image == mask_id)

        crops = self._multiscale_crops(mask, np.array(rgb_image))
        # One preprocessed tensor per scale, stacked along dim 0.
        tensors = [self.preprocess(self._pad_to_square(Image.fromarray(crop)))
                   for crop in crops]
        return torch.stack(tensors), seq_name, frame_id, mask_id


def main(model, preprocess, seq_name, root):
    """Extract and save per-mask open-vocabulary CLIP features for a sequence.

    For every representative mask of every object in the sequence's
    object_dict, CROP_SCALES crops are encoded with the CLIP model and the
    features are averaged into one vector per mask, then saved as
    'open-vocabulary_features.npy' under the sequence's object dict directory.

    Args:
        model: CLIP model exposing encode_image(); expected to live on CUDA.
        preprocess: CLIP image preprocessing transform (yields 3x224x224 tensors).
        seq_name: name of the sequence to process.
        root: dataset root directory forwarded to the dataset class.
    """
    # Sequences to process. BUGFIX: this list was previously named
    # seq_name_list and immediately clobbered by the per-mask accumulator
    # below, so the gather loop never ran and no features were produced.
    target_seq_names = [seq_name]

    # Parallel per-mask lists consumed by CroppedImageDataset.
    seq_name_list, frame_id_list, mask_id_list, rgb_path_list, segmentation_path_list = [], [], [], [], []
    feature_dict = {}
    for seq_name in target_seq_names:
        dataset = get_dataset(seq_name, root)
        object_dict_path = os.path.join(dataset.object_dict_dir, 'wild', 'object_dict.npy')
        if not os.path.exists(object_dict_path):
            # Sequence has no object dict yet; skip it (no entry in feature_dict).
            continue
        object_dict = np.load(object_dict_path, allow_pickle=True).item()
        for value in object_dict.values():
            for mask_info in value['repre_mask_list']:
                frame_id, mask_id = mask_info[0], mask_info[1]
                seq_name_list.append(seq_name)
                frame_id_list.append(frame_id)
                mask_id_list.append(mask_id)
                rgb_path, segmentation_path = dataset.get_frame_path(frame_id)
                rgb_path_list.append(rgb_path)
                segmentation_path_list.append(segmentation_path)
        feature_dict[seq_name] = {}

    dataloader = DataLoader(
        CroppedImageDataset(seq_name_list, frame_id_list, mask_id_list,
                            rgb_path_list, segmentation_path_list, preprocess),
        batch_size=64, shuffle=False, num_workers=16)

    print('[INFO] extracting features')
    for images, seq_names, frame_ids, mask_ids in tqdm(dataloader):
        # (B, CROP_SCALES, 3, 224, 224) -> (B * CROP_SCALES, 3, 224, 224)
        image_input = images.reshape(-1, 3, 224, 224).cuda()
        with torch.no_grad():
            image_features = model.encode_image(image_input).float()
            # L2-normalize so the averaged feature is a mean of unit vectors.
            image_features /= image_features.norm(dim=-1, keepdim=True)
            image_features = image_features.cpu().numpy()
        # Average the CROP_SCALES consecutive features belonging to each mask.
        for i in range(len(image_features) // CROP_SCALES):
            feat = image_features[CROP_SCALES * i:CROP_SCALES * (i + 1)].mean(axis=0)
            feature_dict[seq_names[i]][f'{frame_ids[i]}_{mask_ids[i]}'] = feat
    print('[INFO] finish extracting features')

    # Save once per processed sequence (previously this iterated the per-mask
    # list, attempting one redundant save per mask).
    for seq_name in target_seq_names:
        if seq_name in feature_dict:
            dataset = get_dataset(seq_name, root)
            np.save(os.path.join(dataset.object_dict_dir, 'wild', 'open-vocabulary_features.npy'),
                    feature_dict[seq_name])