import json
import os
from ImageBind.imagebind import data
from ImageBind.imagebind.models import imagebind_model
from ImageBind.imagebind.models.imagebind_model import ModalityType
from collections import OrderedDict
import torch
import argparse
from utils import crop_image, draw_bboxes, save_image, find_same_class, open_image_follow_symlink
from ultralytics import YOLO
from PIL import Image
import numpy as np
from models.TaskCLIP import TaskCLIP
id2task_name_file = './id2task_name.json'
task2prompt_file = './task20.json'
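# Mapping files: id2task_name.json turns the numeric -task_id into a task name, and task20.json
# maps each task name to the list of keyword phrases used as text prompts further below.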
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-vlm_model', default='imagebind', help='Set front CLIP model')
    parser.add_argument('-od_model', default='yolox', help='Set object detection model')
    parser.add_argument('-device', default='cuda:0', help='Set running environment')
    parser.add_argument('-task_id', type=int, default=1, help='Set task id')
    parser.add_argument('-image_path', type=str, default='./images/demo_image_1.jpg', help='Set input image path')
    parser.add_argument('-activation', type=str, default='relu')
    parser.add_argument('-ratio_text', type=float, default=0.3)
    parser.add_argument('-ratio_image', type=float, default=0.3)
    parser.add_argument('-ratio_glob', type=float, default=0.3)
    parser.add_argument('-norm_before', action='store_true', default=False)
    parser.add_argument('-norm_after', action='store_true', default=False)
    parser.add_argument('-norm_range', type=str, default='10|30')
    parser.add_argument('-cross_attention', action='store_true', default=False)
    parser.add_argument('-eval_model_path', default='./test_model/decoder_epoch19.pt', help='set path for loading trained TaskCLIP model')
    parser.add_argument('-threshold', type=float, default=0.01, help='Set threshold for positive detection')
    parser.add_argument('-forward', action='store_true', default=True)
    parser.add_argument('-cluster', action='store_true', default=True)
    parser.add_argument('-forward_thre', type=float, default=0.1, help='Set threshold for positive detection during forward optimization')
    args = parser.parse_args()
    device = args.device
    threshold = args.threshold
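    # Example invocation (script name assumed for illustration; unspecified flags fall back to the defaults above):
    #   python demo.py -task_id 1 -image_path ./images/demo_image_1.jpg -od_model yolox -threshold 0.01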
    # prepare task name and key words
    with open(id2task_name_file, 'r') as f:
        id2task_name = json.load(f)
    task_id = str(args.task_id)
    task_name = id2task_name[task_id]
    # prepare input image
    image_path = args.image_path
    image_name = args.image_path.split('/')[-1].split('.')[0]
    image = open_image_follow_symlink(image_path).convert('RGB')
    # load vision-language model
    vlm_model_name = args.vlm_model
    if vlm_model_name == 'imagebind':
        vlm_model = imagebind_model.imagebind_huge(pretrained=True).to(device)
        vlm_model.eval()
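    # ImageBind-huge serves only as a frozen feature extractor here; with pretrained=True the
    # backbone weights are typically fetched automatically on first run if not already cached.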
    # load object detection model
    if args.od_model == 'yolox':
        od_model = YOLO('./.checkpoints/yolo12x.pt')
    elif args.od_model == 'yolol':
        od_model = YOLO('./.checkpoints/yolo12l.pt')
    elif args.od_model == 'yolom':
        od_model = YOLO('./.checkpoints/yolo12m.pt')
    elif args.od_model == 'yolos':
        od_model = YOLO('./.checkpoints/yolo12s.pt')
    elif args.od_model == 'yolon':
        od_model = YOLO('./.checkpoints/yolo12n.pt')
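    # The x/l/m/s/n suffixes select progressively smaller YOLO12 checkpoints (same detector
    # family, trading accuracy for speed); all are expected under ./.checkpoints/.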
    # get key words prompt
    with open(task2prompt_file, 'r') as f:
        prompt = json.load(f)
    prompt_use = []
    for x in range(len(prompt[task_name])):
        prompt_use.append('The item is ' + prompt[task_name][x])
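    # Each keyword is wrapped in a fixed template, e.g. an illustrative keyword "mug" would
    # become the prompt "The item is mug" before being encoded by the ImageBind text tower.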
    # get bbox image
    outputs = od_model(image_path)
    img = np.array(image)
    ocvimg = img[:, :, ::-1].copy()
    bbox_list = outputs[0].boxes.xyxy.tolist()
    classes = outputs[0].boxes.cls.tolist()
    names = outputs[0].names
    confidences = outputs[0].boxes.conf.tolist()
    predict_res = []
    json_entry = {}
    json_entry['class'] = classes
    json_entry['confidences'] = confidences
    json_entry['bbox'] = bbox_list
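    # json_entry now mirrors the raw detector output: xyxy pixel boxes, numeric class ids, and
    # per-box confidences; ocvimg is the same image flipped to BGR channel order for OpenCV-style drawing.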
    # crop bbox images
    seg_dic = crop_image(ocvimg, bbox_list)
    seg_list = []
    for box_id in seg_dic.keys():
        seg_list.append(seg_dic[box_id])
    if (len(seg_list) == 0):
        print("*" * 100)
        print("Didn't detect any object in the image.")
        print("*" * 100)
        exit(0)  # nothing to score, so stop before the embedding stage
    N_seg = len(seg_list)
    # NOTE: test without reasoning model
    img_with_bbox = draw_bboxes(ocvimg, bbox_list, (0, 255, 0))
    save_image(img_with_bbox, f'./res/{task_id}/{image_name}_no_reasoning.jpg')
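    # Baseline visualization: all raw detections are drawn and saved before TaskCLIP reasoning,
    # for comparison with the filtered "_reasoning" image written at the end of the script.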
    # encode bbox image and prompt keywords
    with torch.no_grad():
        if vlm_model_name == 'imagebind':
            inputs = {
                ModalityType.TEXT: data.load_and_transform_text(prompt_use, device),
                ModalityType.VISION: data.read_and_transform_vision_data(seg_list, device),
            }
            embeddings = vlm_model(inputs)
            text_embeddings = embeddings[ModalityType.TEXT]
            bbox_embeddings = embeddings[ModalityType.VISION]
            inputs = {
                ModalityType.VISION: data.read_and_transform_vision_data([image], device),
            }
            embeddings = vlm_model(inputs)
            image_embedding = embeddings[ModalityType.VISION].squeeze(dim=0)
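    # At this point text_embeddings is [N_keywords, D], bbox_embeddings is [N_seg, D], and
    # image_embedding is a single global D-dim vector (D = 1024 for ImageBind-huge);
    # D is read back below as the decoder's d_model.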
    # prepare TaskCLIP model
    num_layers = 8
    nhead = 4
    model_config = {}
    model_config['num_layers'] = num_layers
    model_config['norm'] = None
    model_config['return_intermediate'] = False
    model_config['d_model'] = image_embedding.shape[-1]
    model_config['nhead'] = nhead
    model_config['dim_feedforward'] = 2048
    model_config['dropout'] = 0.1
    model_config['N_words'] = text_embeddings.shape[0]
    model_config['activation'] = args.activation
    model_config['normalize_before'] = False
    model_config['device'] = device
    model_config['ratio_text'] = args.ratio_text
    model_config['ratio_image'] = args.ratio_image
    model_config['ratio_glob'] = args.ratio_glob
    model_config['norm_before'] = args.norm_before
    model_config['norm_after'] = args.norm_after
    model_config['MIN_VAL'] = float(args.norm_range.split('|')[0])
    model_config['MAX_VAL'] = float(args.norm_range.split('|')[1])
    model_config['cross_attention'] = args.cross_attention
    task_clip_model = TaskCLIP(model_config, normalize_before=model_config['normalize_before'], device=model_config['device'])
    task_clip_model.load_state_dict(torch.load(args.eval_model_path, map_location=device))
    task_clip_model.to(device)
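    # The trained TaskCLIP decoder consumes three streams (keyword text embeddings as memory,
    # per-box crop embeddings as tgt, and the global image embedding) and mixes them with the
    # ratio_text / ratio_image / ratio_glob weights to produce one task-suitability score per box.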
    # feed text, bbox, and image embeddings into HDC model
    with torch.no_grad():
        task_clip_model.eval()
        tgt = bbox_embeddings
        memory = text_embeddings
        image_embedding = image_embedding.view(1, -1)
        tgt_new, memory_new, score_res, score_raw = task_clip_model(tgt, memory, image_embedding)
        score = score_res.view(-1)
        score = score.cpu().squeeze().detach().numpy().tolist()
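    # score is now a plain Python list with one value per detected box, or a single float when
    # only one box was found (.tolist() on a 0-d array returns a scalar); the branches below
    # handle both cases.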
    # post-processing and optimization
    predict_res = []
    for i in range(len(bbox_list)):
        predict_res.append({})
        predict_res[i]["category_id"] = -1
        predict_res[i]["score"] = -1
        predict_res[i]["class"] = int(json_entry['class'][i])
    # same class forward optimization
    if isinstance(score, list):
        visited = [0] * len(score)
        for i, x in enumerate(score):
            if visited[i] == 1:
                continue
            if x > threshold:
                visited[i] = 1
                predict_res[i]["category_id"] = 1
                predict_res[i]["score"] = float(x)
                if args.forward:
                    find_same_class(predict_res, score, visited, i, json_entry['class'], json_entry['confidences'], args.forward_thre)
            else:
                predict_res[i]["category_id"] = 0
                predict_res[i]["score"] = 1 - float(x)
    else:
        if score > threshold:
            predict_res[0]["category_id"] = 1
            predict_res[0]["score"] = float(score)
        else:
            predict_res[0]["category_id"] = 0
            predict_res[0]["score"] = 1 - float(score)
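    # find_same_class (imported from utils) presumably spreads the positive label to other boxes
    # of the same detector class whose confidence clears forward_thre; this is the "same class
    # forward optimization" referred to in the comment above.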
    # cluster bbox optimization
    if args.cluster and args.forward and N_seg > 1:
        cluster = {}
        for p in predict_res:
            if int(p["category_id"]) == 1:
                if p["class"] in cluster.keys():
                    cluster[p["class"]].append(p["score"])
                else:
                    cluster[p["class"]] = [p["score"]]
        # choose one cluster
        if len(cluster.keys()) > 1:
            cluster_ave = {}
            for c in cluster.keys():
                cluster_ave[c] = np.sum(cluster[c]) / len(cluster[c])
            select_class = max(cluster_ave, key=lambda k: cluster_ave[k])
            # remove lower score class
            for p in predict_res:
                if p["category_id"] == 1 and p["class"] != select_class:
                    p["category_id"] = 0
    score_final = [x["category_id"] for x in predict_res]
    # mask = score > threshold
    mask = np.array(score_final) == 1
    bbox_arr = np.asarray(bbox_list)
    bbox_select = bbox_arr[mask]
    img_with_bbox = draw_bboxes(ocvimg, bbox_select, (255, 0, 0))
    save_image(img_with_bbox, f'./res/{task_id}/{image_name}_reasoning.jpg')