'''
    This script extracts open-vocabulary visual features for each mask, following OpenMask3D.
    For each mask, we crop the image at CROP_SCALES=3 scales based on the mask's bounding box.
    We then encode each crop with a CLIP model and average the features to obtain the mask feature.
'''
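
# Example invocation (a sketch: the script name is illustrative, and we assume
# utils.config.get_args exposes seq_name_list and config as command-line flags):
#   python extract_mask_features.py --seq_name_list seq_a+seq_b --config my_config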

import open_clip
import os
from PIL import Image
import numpy as np
import torch
from utils.config import get_args, get_dataset
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import cv2

CROP_SCALES = 3  # following OpenMask3D
args = get_args()

class CroppedImageDataset(Dataset):
    def __init__(self, seq_name_list, frame_id_list, mask_id_list, rgb_path_list, segmentation_path_list, preprocess):
        '''
            Dataset of multi-scale image crops, one sample per mask; the crops are
            later encoded with CLIP to compute the open-vocabulary feature of each mask.

            Args:
                seq_name_list: sequence name for each mask
                frame_id_list: frame id for each mask
                mask_id_list: mask id for each mask
                rgb_path_list: RGB image path for each mask
                segmentation_path_list: segmentation map path for each mask
                preprocess: CLIP image preprocessing function
        '''
        self.seq_name_list = seq_name_list
        self.frame_id_list = frame_id_list
        self.mask_id_list = mask_id_list
        self.preprocess = preprocess
        self.rgb_path_list = rgb_path_list
        self.segmentation_path_list = segmentation_path_list

    def __len__(self):
        return len(self.mask_id_list)

    def __getitem__(self, idx):
        def get_cropped_image(mask, rgb):
            '''
                Given a mask and an RGB image, crop the image at CROP_SCALES scales
                around the mask's bounding box.
            '''
            def mask2box_multi_level(mask, level, expansion_ratio):
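                '''
                    Return the mask's bounding box as (left, top, right, bottom).
                    Level 0 is the tight box; level k expands each side by
                    k * expansion_ratio of the box size, clamped to the image bounds.
                '''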
                pos = np.where(mask)
                top = np.min(pos[0])
                bottom = np.max(pos[0])
                left = np.min(pos[1])
                right = np.max(pos[1])

                if level == 0:
                    return left, top, right, bottom
                shape = mask.shape
                x_exp = int(abs(right - left)*expansion_ratio) * level
                y_exp = int(abs(bottom - top)*expansion_ratio) * level
                return max(0, left - x_exp), max(0, top - y_exp), min(shape[1], right + x_exp), min(shape[0], bottom + y_exp)

            def crop_image(rgb, mask):
                multiscale_cropped_images = []
                for level in range(CROP_SCALES):
                    left, top, right, bottom = mask2box_multi_level(mask, level, 0.1)
                    cropped_image = rgb[top:bottom, left:right].copy()
                    multiscale_cropped_images.append(cropped_image)
                return multiscale_cropped_images

            mask = cv2.resize(mask.astype(np.uint8), (rgb.shape[1], rgb.shape[0]), interpolation=cv2.INTER_NEAREST)
            multiscale_cropped_images = crop_image(rgb, mask)
            return multiscale_cropped_images
        
        def pad_into_square(image):
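            '''
                Pad a PIL image to a square with a white background so CLIP's
                square resize does not distort the crop's aspect ratio.
            '''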
            width, height = image.size
            new_size = max(width, height)
            new_image = Image.new("RGB", (new_size, new_size), (255,255,255))
            left = (new_size - width) // 2
            top = (new_size - height) // 2
            new_image.paste(image, (left, top))
            return new_image

        seq_name = self.seq_name_list[idx]
        frame_id = self.frame_id_list[idx]
        mask_id = self.mask_id_list[idx]
        rgb_path = self.rgb_path_list[idx]
        segmentation_path = self.segmentation_path_list[idx]

        rgb_image = cv2.imread(rgb_path)
        rgb_image = cv2.cvtColor(rgb_image, cv2.COLOR_BGR2RGB)
        
        segmentation_image = cv2.imread(segmentation_path, cv2.IMREAD_UNCHANGED)
        mask = (segmentation_image == mask_id)
        cropped_images = get_cropped_image(mask, np.array(rgb_image))

        input_images = [self.preprocess(pad_into_square(Image.fromarray(cropped_image))) for cropped_image in cropped_images]
        input_images = torch.stack(input_images)
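        # (CROP_SCALES, 3, 224, 224): one preprocessed crop per scale for this mask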
        return input_images, seq_name, frame_id, mask_id

def load_clip():
    print('[INFO] loading CLIP model...')
    model, _, preprocess = open_clip.create_model_and_transforms("ViT-H-14", pretrained="laion2b_s32b_b79k")
    model.cuda()
    model.eval()
    print('[INFO] finished loading CLIP model')
    return model, preprocess

def main():
    model, preprocess = load_clip()

    seq_name_list, frame_id_list, mask_id_list, rgb_path_list, segmentation_path_list = [], [], [], [], []
    feature_dict = {}
    for seq_name in args.seq_name_list.split('+'):
        args.seq_name = seq_name
        dataset = get_dataset(args)
        object_dict = np.load(f'{dataset.object_dict_dir}/{args.config}/object_dict.npy', allow_pickle=True).item()
        for value in object_dict.values():
            mask_list = value['repre_mask_list']
            if len(mask_list) == 0:
                continue
            for mask_info in mask_list:
                seq_name_list.append(seq_name)
                frame_id = mask_info[0]
                frame_id_list.append(frame_id)
                mask_id_list.append(mask_info[1])
                rgb_path, segmentation_path = dataset.get_frame_path(frame_id)
                rgb_path_list.append(rgb_path)
                segmentation_path_list.append(segmentation_path)
        feature_dict[seq_name] = {}

    dataloader = DataLoader(
        CroppedImageDataset(seq_name_list, frame_id_list, mask_id_list,
                            rgb_path_list, segmentation_path_list, preprocess),
        batch_size=64, shuffle=False, num_workers=16)
    
    print('[INFO] extracting features')
    for images, seq_names, frame_ids, mask_ids in tqdm(dataloader):
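        # (B, CROP_SCALES, 3, 224, 224) -> (B * CROP_SCALES, 3, 224, 224)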
        images = images.reshape(-1, 3, 224, 224)
        image_input = images.cuda()
        with torch.no_grad():
            image_features = model.encode_image(image_input).float()
            image_features /= image_features.norm(dim=-1, keepdim=True)
            image_features = image_features.cpu().numpy()
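        # Regroup the flattened features per mask and average over its CROP_SCALES crops.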
        for i in range(len(image_features) // CROP_SCALES):
            feature_dict[seq_names[i]][f'{frame_ids[i]}_{mask_ids[i]}'] = image_features[CROP_SCALES*i:CROP_SCALES*(i+1)].mean(axis=0)
    print('[INFO] finished extracting features')

    for seq_name in args.seq_name_list.split('+'):
        args.seq_name = seq_name
        dataset = get_dataset(args)
        if seq_name in feature_dict:
            np.save(os.path.join(dataset.object_dict_dir, args.config, 'open-vocabulary_features.npy'), feature_dict[seq_name])

if __name__ == '__main__':
    main()
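
# A minimal sketch of querying the saved features with free-form text
# (illustration only, not part of this script; the .npy path is a placeholder,
# while the model and weights match the settings used above):
#
#   model, _, _ = open_clip.create_model_and_transforms("ViT-H-14", pretrained="laion2b_s32b_b79k")
#   tokenizer = open_clip.get_tokenizer("ViT-H-14")
#   feats = np.load('.../open-vocabulary_features.npy', allow_pickle=True).item()
#   keys = list(feats.keys())  # each key is '{frame_id}_{mask_id}'
#   mask_feats = torch.from_numpy(np.stack([feats[k] for k in keys])).float()
#   mask_feats /= mask_feats.norm(dim=-1, keepdim=True)  # re-normalize the averaged features
#   with torch.no_grad():
#       text = model.encode_text(tokenizer(['a chair'])).float()
#       text /= text.norm(dim=-1, keepdim=True)
#   scores = (mask_feats @ text.T).squeeze(1)  # cosine similarity per mask
#   print(keys[scores.argmax()])               # best-matching mask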