import json
import argparse

import numpy as np
import torch
from ultralytics import YOLO

from ImageBind.imagebind import data
from ImageBind.imagebind.models import imagebind_model
from ImageBind.imagebind.models.imagebind_model import ModalityType
from models.TaskCLIP import TaskCLIP
from utils import crop_image, draw_bboxes, save_image, find_same_class, open_image_follow_symlink

id2task_name_file = './id2task_name.json'
task2prompt_file = './task20.json'

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-vlm_model', default='imagebind', help='Set front CLIP model')
    parser.add_argument('-od_model', default='yolox', help='Set object detection model')
    parser.add_argument('-device', default='cuda:0', help='Set running environment')
    parser.add_argument('-task_id', type=int, default=1, help='Set task id')
    parser.add_argument('-image_path', type=str, default='./images/demo_image_1.jpg', help='Set input image path')
    parser.add_argument('-activation', type=str, default='relu')
    parser.add_argument('-ratio_text', type=float, default=0.3)
    parser.add_argument('-ratio_image', type=float, default=0.3)
    parser.add_argument('-ratio_glob', type=float, default=0.3)
    parser.add_argument('-norm_before', action='store_true', default=False)
    parser.add_argument('-norm_after', action='store_true', default=False)
    parser.add_argument('-norm_range', type=str, default='10|30')
    parser.add_argument('-cross_attention', action='store_true', default=False)
    parser.add_argument('-eval_model_path', default='./test_model/decoder_epoch19.pt',
                        help='Set path for loading trained TaskCLIP model')
    parser.add_argument('-threshold', type=float, default=0.01, help='Set threshold for positive detection')
    parser.add_argument('-forward', action='store_true', default=True)
    parser.add_argument('-cluster', action='store_true', default=True)
    parser.add_argument('-forward_thre', type=float, default=0.1,
                        help='Set threshold for positive detection during forward optimization')
    args = parser.parse_args()

    device = args.device
    threshold = args.threshold

    # prepare task name and key words
    with open(id2task_name_file, 'r') as f:
        id2task_name = json.load(f)
    task_id = str(args.task_id)
    task_name = id2task_name[task_id]

    # prepare input image
    image_path = args.image_path
    image_name = image_path.split('/')[-1].split('.')[0]
    image = open_image_follow_symlink(image_path).convert('RGB')

    # load vision-language model
    vlm_model_name = args.vlm_model
    if vlm_model_name == 'imagebind':
        vlm_model = imagebind_model.imagebind_huge(pretrained=True).to(device)
        vlm_model.eval()

    # load object detection model (checkpoint matches the requested model size)
    if args.od_model == 'yolox':
        od_model = YOLO('./.checkpoints/yolo12x.pt')
    elif args.od_model == 'yolol':
        od_model = YOLO('./.checkpoints/yolo12l.pt')
    elif args.od_model == 'yolom':
        od_model = YOLO('./.checkpoints/yolo12m.pt')
    elif args.od_model == 'yolos':
        od_model = YOLO('./.checkpoints/yolo12s.pt')
    elif args.od_model == 'yolon':
        od_model = YOLO('./.checkpoints/yolo12n.pt')

    # get key words prompt
    with open(task2prompt_file, 'r') as f:
        prompt = json.load(f)
    prompt_use = ['The item is ' + p for p in prompt[task_name]]

    # run object detection and collect bounding boxes
    outputs = od_model(image_path)
    img = np.array(image)
    ocvimg = img[:, :, ::-1].copy()  # RGB -> BGR for OpenCV-style drawing
    bbox_list = outputs[0].boxes.xyxy.tolist()
    classes = outputs[0].boxes.cls.tolist()
    names = outputs[0].names
    confidences = outputs[0].boxes.conf.tolist()
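    # For quick inspection, the numeric class ids can be mapped back to label
    # strings via the id -> name dict that ultralytics attaches to each result
    # (a small sanity-check sketch; the rest of the pipeline only needs the ids):
    detected_labels = [names[int(c)] for c in classes]
    print(f'Detected {len(bbox_list)} boxes: {detected_labels}')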
    json_entry = {}
    json_entry['class'] = classes
    json_entry['confidences'] = confidences
    json_entry['bbox'] = bbox_list

    # crop bbox images
    seg_dic = crop_image(ocvimg, bbox_list)
    seg_list = list(seg_dic.values())
    if len(seg_list) == 0:
        print("*" * 100)
        print("Didn't detect any object in the image.")
        print("*" * 100)
        raise SystemExit(0)  # nothing to encode or score downstream
    N_seg = len(seg_list)

    # NOTE: test without reasoning model
    img_with_bbox = draw_bboxes(ocvimg, bbox_list, (0, 255, 0))
    save_image(img_with_bbox, f'./res/{task_id}/{image_name}_no_reasoning.jpg')

    # encode bbox crops and prompt keywords
    with torch.no_grad():
        if vlm_model_name == 'imagebind':
            inputs = {
                ModalityType.TEXT: data.load_and_transform_text(prompt_use, device),
                ModalityType.VISION: data.read_and_transform_vision_data(seg_list, device),
            }
            embeddings = vlm_model(inputs)
            text_embeddings = embeddings[ModalityType.TEXT]
            bbox_embeddings = embeddings[ModalityType.VISION]

            inputs = {
                ModalityType.VISION: data.read_and_transform_vision_data([image], device),
            }
            embeddings = vlm_model(inputs)
            image_embedding = embeddings[ModalityType.VISION].squeeze(dim=0)

    # prepare TaskCLIP model
    model_config = {
        'num_layers': 8,
        'nhead': 4,
        'norm': None,
        'return_intermediate': False,
        'd_model': image_embedding.shape[-1],
        'dim_feedforward': 2048,
        'dropout': 0.1,
        'N_words': text_embeddings.shape[0],
        'activation': args.activation,
        'normalize_before': False,
        'device': device,
        'ratio_text': args.ratio_text,
        'ratio_image': args.ratio_image,
        'ratio_glob': args.ratio_glob,
        'norm_before': args.norm_before,
        'norm_after': args.norm_after,
        'MIN_VAL': float(args.norm_range.split('|')[0]),
        'MAX_VAL': float(args.norm_range.split('|')[1]),
        'cross_attention': args.cross_attention,
    }
    task_clip_model = TaskCLIP(model_config,
                               normalize_before=model_config['normalize_before'],
                               device=model_config['device'])
    task_clip_model.load_state_dict(torch.load(args.eval_model_path, map_location=device))
    task_clip_model.to(device)

    # feed text, bbox, and image embeddings into the TaskCLIP model
    with torch.no_grad():
        task_clip_model.eval()
        tgt = bbox_embeddings
        memory = text_embeddings
        image_embedding = image_embedding.view(1, -1)
        tgt_new, memory_new, score_res, score_raw = task_clip_model(tgt, memory, image_embedding)
        score = score_res.view(-1)
        score = score.cpu().squeeze().detach().numpy().tolist()

    # post-processing and optimization
    predict_res = []
    for i in range(len(bbox_list)):
        predict_res.append({})
        predict_res[i]["category_id"] = -1
        predict_res[i]["score"] = -1
        predict_res[i]["class"] = int(json_entry['class'][i])

    # same-class forward optimization
    if isinstance(score, list):
        visited = [0] * len(score)
        for i, x in enumerate(score):
            if visited[i] == 1:
                continue
            if x > threshold:
                visited[i] = 1
                predict_res[i]["category_id"] = 1
                predict_res[i]["score"] = float(x)
                if args.forward:
                    find_same_class(predict_res, score, visited, i,
                                    json_entry['class'], json_entry['confidences'],
                                    args.forward_thre)
            else:
                predict_res[i]["category_id"] = 0
                predict_res[i]["score"] = 1 - float(x)
    else:
        # a single detection: `score` collapses to a scalar
        if score > threshold:
            predict_res[0]["category_id"] = 1
            predict_res[0]["score"] = float(score)
        else:
            predict_res[0]["category_id"] = 0
            predict_res[0]["score"] = 1 - float(score)
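    # As a TaskCLIP-independent sanity check, the raw ImageBind embeddings can be
    # compared directly, following the usual ImageBind usage pattern (a hedged
    # sketch, commented out; it is not part of this pipeline):
    #
    #   with torch.no_grad():
    #       sims = torch.softmax(bbox_embeddings @ text_embeddings.T, dim=-1)
    #   # sims[i, j] ~ affinity of crop i to prompt j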
    # cluster bbox optimization: if positives span several classes, keep only
    # the class with the highest average score
    if args.cluster and args.forward and N_seg > 1:
        cluster = {}
        for p in predict_res:
            if int(p["category_id"]) == 1:
                if p["class"] in cluster:
                    cluster[p["class"]].append(p["score"])
                else:
                    cluster[p["class"]] = [p["score"]]

        # choose one cluster
        if len(cluster) > 1:
            cluster_ave = {c: float(np.mean(cluster[c])) for c in cluster}
            select_class = max(cluster_ave, key=cluster_ave.get)
            # demote every positive that belongs to a lower-scoring class
            for p in predict_res:
                if p["category_id"] == 1 and p["class"] != select_class:
                    p["category_id"] = 0

    # draw the surviving boxes
    score_final = [x["category_id"] for x in predict_res]
    mask = np.array(score_final) == 1
    bbox_arr = np.asarray(bbox_list)
    bbox_select = bbox_arr[mask]
    img_with_bbox = draw_bboxes(ocvimg, bbox_select, (255, 0, 0))
    save_image(img_with_bbox, f'./res/{task_id}/{image_name}_reasoning.jpg')
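    # Optionally persist the raw detections and final predictions next to the
    # rendered images (a sketch; this output layout is an assumption and not
    # part of the original pipeline):
    #
    #   with open(f'./res/{task_id}/{image_name}_result.json', 'w') as f:
    #       json.dump({'detections': json_entry, 'predictions': predict_res}, f, indent=2)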