File size: 9,265 Bytes
f2f112a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import json
import os
from ImageBind.imagebind import data
from ImageBind.imagebind.models import imagebind_model
from ImageBind.imagebind.models.imagebind_model import ModalityType
from collections import OrderedDict
import torch
import argparse
from utils import crop_image, draw_bboxes, save_image, find_same_class, open_image_follow_symlink
from ultralytics import YOLO
from PIL import Image
import numpy as np
from models.TaskCLIP import TaskCLIP

# Lookup tables on disk: task id -> human-readable task name, and
# task name -> list of keyword phrases used to build text prompts.
id2task_name_file = './id2task_name.json'
task2prompt_file = './task20.json'

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Model / environment selection.
    parser.add_argument('-vlm_model', default='imagebind', help='Set front CLIP model')
    parser.add_argument('-od_model', default='yolox', help='Set object detection model')
    parser.add_argument('-device', default='cuda:0', help='Set running environment')
    parser.add_argument('-task_id', type=int, default=1, help='Set task id')
    parser.add_argument('-image_path', type=str, default='./images/demo_image_1.jpg', help='Set input image path')
    # TaskCLIP decoder hyper-parameters (must match the trained checkpoint).
    parser.add_argument('-activation', type=str, default='relu')
    parser.add_argument('-ratio_text', type=float, default=0.3)
    parser.add_argument('-ratio_image', type=float, default=0.3)
    parser.add_argument('-ratio_glob', type=float, default=0.3)
    parser.add_argument('-norm_before', action='store_true', default=False)
    parser.add_argument('-norm_after', action='store_true', default=False)
    parser.add_argument('-norm_range', type=str, default='10|30', help="'MIN|MAX' range used for score normalization")
    parser.add_argument('-cross_attention', action='store_true', default=False)
    parser.add_argument('-eval_model_path', default='./test_model/decoder_epoch19.pt', help='set path for loading trained TaskCLIP model')
    parser.add_argument('-threshold', type=float, default=0.01, help='Set threshold for positive detection')
    # BUG FIX: store_true combined with default=True made -forward / -cluster
    # impossible to turn off. Keep the original flags and defaults (backward
    # compatible) and add explicit -no_forward / -no_cluster switches that
    # clear the same destinations.
    parser.add_argument('-forward', action='store_true', default=True)
    parser.add_argument('-no_forward', dest='forward', action='store_false',
                        help='Disable same-class forward optimization')
    parser.add_argument('-cluster', action='store_true', default=True)
    parser.add_argument('-no_cluster', dest='cluster', action='store_false',
                        help='Disable cluster bbox optimization')
    parser.add_argument('-forward_thre', type=float, default=0.1, help='Set threshold for positive detection during forward optimization')
    args = parser.parse_args()
    device = args.device
    threshold = args.threshold

    # prepare task name and key words
    with open(id2task_name_file, 'r') as f:
        id2task_name = json.load(f)
    task_id = str(args.task_id)
    task_name = id2task_name[task_id]

    # prepare input image
    image_path = args.image_path
    # BUG FIX: the old `split('/')[-1].split('.')[0]` broke on Windows path
    # separators and truncated file names containing extra dots; use os.path.
    image_name = os.path.splitext(os.path.basename(image_path))[0]
    image = open_image_follow_symlink(image_path).convert('RGB')

    # load vision-language model
    vlm_model_name = args.vlm_model
    if vlm_model_name == 'imagebind':
        vlm_model = imagebind_model.imagebind_huge(pretrained=True).to(device)
    else:
        # Fail fast: previously an unknown name left vlm_model undefined and
        # crashed two lines later with a confusing NameError.
        raise ValueError(f"Unsupported vlm_model: {vlm_model_name}")
    vlm_model.eval()

    # load object detection model
    # BUG FIX: the l/m/s/n checkpoints were misspelled 'tolo12*.pt'; all YOLO12
    # weights follow the 'yolo12<size>.pt' naming used by the x variant.
    od_checkpoints = {
        'yolox': './.checkpoints/yolo12x.pt',
        'yolol': './.checkpoints/yolo12l.pt',
        'yolom': './.checkpoints/yolo12m.pt',
        'yolos': './.checkpoints/yolo12s.pt',
        'yolon': './.checkpoints/yolo12n.pt',
    }
    if args.od_model not in od_checkpoints:
        raise ValueError(f"Unsupported od_model: {args.od_model}")
    od_model = YOLO(od_checkpoints[args.od_model])

    # get key words prompt: each keyword becomes a full sentence for the text encoder
    with open(task2prompt_file, 'r') as f:
        prompt = json.load(f)
    prompt_use = ['The item is ' + keyword for keyword in prompt[task_name]]

    # get bbox image
    outputs = od_model(image_path)
    img = np.array(image)
    ocvimg = img[:, :, ::-1].copy()  # RGB -> BGR for the OpenCV-style helpers
    bbox_list = outputs[0].boxes.xyxy.tolist()
    classes = outputs[0].boxes.cls.tolist()
    names = outputs[0].names
    confidences = outputs[0].boxes.conf.tolist()
    predict_res = []
    json_entry = {
        'bbox': bbox_list,
        'class': classes,
        'confidences': confidences,
    }

    # crop bbox images (dict preserves insertion order, so values() matches bbox order)
    seg_dic = crop_image(ocvimg, bbox_list)
    seg_list = list(seg_dic.values())
    N_seg = len(seg_list)

    # NOTE: test without reasoning model
    img_with_bbox = draw_bboxes(ocvimg, bbox_list, (0, 255, 0))
    save_image(img_with_bbox, f'./res/{task_id}/{image_name}_no_reasoning.jpg')

    if N_seg == 0:
        print("*"*100)
        print("Didn't detect any object in the image.")
        print("*"*100)
        # BUG FIX: the original only printed a warning and fell through into
        # the embedding/inference pipeline, which cannot run on an empty
        # batch; stop cleanly after saving the annotated image.
        raise SystemExit(0)


    # encode bbox image and prompt keywords
    with torch.no_grad():
        if vlm_model_name == 'imagebind':
            # Encode all keyword prompts and every cropped bbox in one batch.
            # (renamed from `input`, which shadowed the builtin)
            batch_inputs = {
                ModalityType.TEXT: data.load_and_transform_text(prompt_use, device),
                ModalityType.VISION: data.read_and_transform_vision_data(seg_list, device),
            }
            batch_embeddings = vlm_model(batch_inputs)
            text_embeddings = batch_embeddings[ModalityType.TEXT]
            bbox_embeddings = batch_embeddings[ModalityType.VISION]
            # Encode the full frame as one global image embedding.
            global_inputs = {
                ModalityType.VISION: data.read_and_transform_vision_data([image], device),
            }
            image_embedding = vlm_model(global_inputs)[ModalityType.VISION].squeeze(dim=0)

    # prepare TaskCLIP model configuration (decoder over the VLM embeddings)
    norm_parts = args.norm_range.split('|')
    model_config = {
        'num_layers': 8,
        'nhead': 4,
        'norm': None,
        'return_intermediate': False,
        'd_model': image_embedding.shape[-1],   # embedding width from the VLM
        'dim_feedforward': 2048,
        'dropout': 0.1,
        'N_words': text_embeddings.shape[0],    # number of keyword prompts
        'activation': args.activation,
        'normalize_before': False,
        'device': device,
        'ratio_text': args.ratio_text,
        'ratio_image': args.ratio_image,
        'ratio_glob': args.ratio_glob,
        'norm_before': args.norm_before,
        'norm_after': args.norm_after,
        'MIN_VAL': float(norm_parts[0]),
        'MAX_VAL': float(norm_parts[1]),
        'cross_attention': args.cross_attention,
    }

    task_clip_model = TaskCLIP(model_config, normalize_before=model_config['normalize_before'], device=model_config['device'])
    # BUG FIX: map_location lets a checkpoint saved on GPU load on any -device
    # (e.g. cpu) instead of failing with a CUDA deserialization error.
    task_clip_model.load_state_dict(torch.load(args.eval_model_path, map_location=device))
    task_clip_model.to(device)

    # feed text, bbox, and image embeddings into HDC model
    task_clip_model.eval()
    with torch.no_grad():
        # Flatten the global embedding to a single-row matrix for the decoder.
        global_emb = image_embedding.view(1, -1)
        tgt_new, memory_new, score_res, score_raw = task_clip_model(
            bbox_embeddings, text_embeddings, global_emb)
        score = score_res.view(-1)

    # A scalar tensor becomes a bare float here; a vector becomes a list.
    score = score.detach().cpu().squeeze().numpy().tolist()
    # post-processing and optimization: one record per detected bbox
    predict_res = [
        {"category_id": -1, "score": -1, "class": int(cls)}
        for cls in json_entry['class']
    ]

    # same class forward optimization
    if isinstance(score, list):
        visited = [0] * len(score)
        for idx, val in enumerate(score):
            if visited[idx]:
                continue
            if val > threshold:
                # Positive detection; optionally propagate to same-class boxes.
                visited[idx] = 1
                predict_res[idx]["category_id"] = 1
                predict_res[idx]["score"] = float(val)
                if args.forward:
                    find_same_class(predict_res, score, visited, idx,
                                    json_entry['class'], json_entry['confidences'],
                                    args.forward_thre)
            else:
                predict_res[idx]["category_id"] = 0
                predict_res[idx]["score"] = 1 - float(val)
    else:
        # Single detection: tolist() returned a bare float.
        positive = score > threshold
        predict_res[0]["category_id"] = 1 if positive else 0
        predict_res[0]["score"] = float(score) if positive else 1 - float(score)
    
    # cluster bbox optimization: if several classes were marked positive, keep
    # only the class whose positive boxes have the highest mean score.
    if args.cluster and args.forward and N_seg > 1:
        cluster = {}
        for p in predict_res:
            if int(p["category_id"]) == 1:
                cluster.setdefault(p["class"], []).append(p["score"])

        # choose one cluster
        if len(cluster) > 1:
            # np.mean replaces the manual np.sum(...)/len(...) computation
            cluster_ave = {c: float(np.mean(scores)) for c, scores in cluster.items()}
            select_class = max(cluster_ave, key=cluster_ave.get)

            # remove lower score classes
            for p in predict_res:
                if p["category_id"] == 1 and p["class"] != select_class:
                    p["category_id"] = 0

    # draw and save only the bboxes that survived post-processing
    score_final = [x["category_id"] for x in predict_res]
    mask = np.array(score_final) == 1
    bbox_select = np.asarray(bbox_list)[mask]
    img_with_bbox = draw_bboxes(ocvimg, bbox_select, (255, 0, 0))
    save_image(img_with_bbox, f'./res/{task_id}/{image_name}_reasoning.jpg')