import torch
import detectron2
import cv2
import numpy as np
import glob
import os
import json
from PIL import Image
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from ImageBind.imagebind import data
from ImageBind.imagebind.models import imagebind_model
from ImageBind.imagebind.models.imagebind_model import ModalityType
#from .CoCoTask_Model import CoCoTask_Model
from .test_model2 import CoCoTask_Model

# Default location of the trained BackEnd (decoder) state_dict. This path is
# environment-specific; pass ``backend_weights=`` to TriStageModel to override it.
DEFAULT_BACKEND_WEIGHTS = '/home/hanningchen/IJCAI24/models/saved_models/decoder_score_task1_epoch14.pt'


class TriStageModel(torch.nn.Module):
    """Three-stage inference pipeline: detect -> embed -> score.

    * FrontEnd: a Detectron2 ``DefaultPredictor`` built from a model-zoo Faster
      R-CNN config; it proposes bounding boxes for the input image.
    * MiddleEnd: ImageBind (``imagebind_huge``), which embeds every box crop
      (vision modality) and a list of textual prompts (text modality).
    * BackEnd: a ``CoCoTask_Model`` that takes the vision embeddings as ``tgt``
      and the text embeddings as ``memory`` and returns one score per crop.
      Scores at or above ``BackEnd.threshold`` are labelled category 1.
    """

    # Maximum number of textual prompts read from a reason file.
    NUM_REASONS = 10

    def __init__(self, model_path, backend_weights=DEFAULT_BACKEND_WEIGHTS) -> None:
        """Build all three stages and load the BackEnd weights.

        Args:
            model_path: currently unused. NOTE(review): the BackEnd weights are
                loaded from ``backend_weights``, not from this argument; confirm
                what callers intend this parameter to be.
            backend_weights: path to the BackEnd ``state_dict`` file. Defaults to
                the original hard-coded location.
        """
        super().__init__()
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        self.fast_rcnn_path = "COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml"
        # Alternative detector config:
        # self.fast_rcnn_path = "Detectron1-Comparisons/faster_rcnn_R_50_FPN_noaug_1x.yaml"
        cfg = get_cfg()
        cfg.merge_from_file(model_zoo.get_config_file(self.fast_rcnn_path))
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
        cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(self.fast_rcnn_path)
        # Run the detector on the same device as the rest of the pipeline
        # (otherwise Detectron2 uses its own default device).
        cfg.MODEL.DEVICE = self.device

        # BackEnd hyper-parameters.
        self.num_layers = 4
        self.d_model = 1024
        self.nhead = 4
        self.dim_feedforward = 2048
        self.dropout = 0.1
        self.activation = "relu"
        self.normalize_before = False
        self.return_intermediate = False

        self.FrontEnd = DefaultPredictor(cfg)
        self.MiddleEnd = imagebind_model.imagebind_huge(pretrained=True)
        self.BackEnd = self.Construct_BackEnd(
            self.num_layers, self.d_model, self.nhead, self.dim_feedforward,
            self.dropout, self.activation, self.device,
            self.normalize_before, self.return_intermediate,
        )
        # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
        # host; load_state_dict then copies the tensors into the module's params.
        self.BackEnd.load_state_dict(torch.load(backend_weights, map_location="cpu"))

    def Construct_BackEnd(self, num_layers, d_model, nhead, dim_feedforward, dropout,
                          activation, device, normalize_before=False, return_intermediate=False):
        """Instantiate the ``CoCoTask_Model`` scoring head (with ``norm=None``)."""
        return CoCoTask_Model(
            num_layers=num_layers,
            norm=None,
            return_intermediate=return_intermediate,
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            normalize_before=normalize_before,
            device=device,
        )

    @staticmethod
    def _clamp_span(lo, hi, size):
        """Clamp the half-open pixel span [lo, hi) into [0, size), keeping it at least 1 px wide."""
        lo = min(max(lo, 0), size - 1)
        hi = min(max(hi, lo + 1), size)
        return lo, hi

    def crop_image(self, input_image, bbx_list, results, img_id):
        """Crop each detected box out of ``input_image`` as an RGB PIL image.

        For every box a placeholder entry (``score`` -1, ``category_id`` -1) is
        appended to ``results``. Its ``bbox`` is ``[x, y, w, h]``, built from the
        detector's original, unclamped float coordinates. ``forward`` later fills
        these entries by index, so the returned crops and the appended entries are
        aligned one-to-one.

        Args:
            input_image: BGR image array indexed as ``[row, col, channel]``.
            bbx_list: Detectron2 boxes; ``bbx_list[i].tensor`` holds ``(x0, y0, x1, y1)``.
            results: list that receives one placeholder dict per box (mutated).
            img_id: value stored under ``"image_id"`` in each entry.

        Returns:
            A list of ``PIL.Image`` crops, one per box.
        """
        img_h, img_w = input_image.shape[0], input_image.shape[1]
        seg_list = []
        for i in range(len(bbx_list)):
            bbx_tensor = bbx_list[i].tensor.squeeze()
            x0, y0, x1, y1 = (int(bbx_tensor[k]) for k in range(4))
            if not (x0 >= 0 and x1 <= img_w and y0 >= 0 and y1 <= img_h):
                print("************************")
                print("The bbx exceed the image")
                print("************************")
            # Clamping is a no-op for in-range boxes. It also widens a sub-pixel
            # box so the crop is never empty (cv2.cvtColor rejects empty input).
            x0, x1 = self._clamp_span(x0, x1, img_w)
            y0, y1 = self._clamp_span(y0, y1, img_h)
            crop = input_image[y0:y1, x0:x1, :]
            seg_list.append(Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)))

            x = float(bbx_tensor[0])
            y = float(bbx_tensor[1])
            w = float(bbx_tensor[2]) - x
            h = float(bbx_tensor[3]) - y
            results.append({"image_id": img_id, "bbox": [x, y, w, h], "score": -1, "category_id": -1})
        return seg_list

    def _load_reasons(self, reason_path):
        """Return up to ``NUM_REASONS`` prompts from the JSON file's ``'visual_features'`` list.

        Each prompt is prefixed with ``'The item is '``.

        Raises:
            ValueError: if the file provides no prompts.
        """
        with open(reason_path, encoding="utf-8") as f:
            features = json.load(f)['visual_features']
        # NOTE(review): the BackEnd was presumably trained with NUM_REASONS
        # prompts; fewer are accepted here instead of failing with IndexError.
        reasons = ['The item is ' + feat for feat in features[:self.NUM_REASONS]]
        if not reasons:
            raise ValueError(f"No 'visual_features' prompts found in {reason_path}")
        return reasons

    def forward(self, inputs, img_id, reason_path):
        """Detect, embed and score the objects in a single image.

        Args:
            inputs: an image accepted by ``np.array`` with channels on the last
                axis. The channel order is reversed before detection, so it
                presumably arrives as RGB (e.g. a PIL image) — TODO confirm.
            img_id: stored as ``"image_id"`` in every result entry.
            reason_path: path to a JSON file whose ``'visual_features'`` key holds
                a list of prompt strings.

        Returns:
            One dict per detected box with ``image_id``, ``bbox`` (``[x, y, w, h]``),
            ``category_id`` and ``score``. ``category_id`` is 1 when the BackEnd
            score is >= ``BackEnd.threshold`` (``score`` is then the raw value),
            otherwise 0 (``score`` is then ``1 - value``). Returns ``[]`` when the
            detector finds no box.

        Raises:
            ValueError: if the reason file provides no prompts.
        """
        predict_res = []
        self.MiddleEnd.eval()
        self.MiddleEnd.to(self.device)
        self.BackEnd.eval()
        self.BackEnd.to(self.device)

        img = np.array(inputs)
        # Reverse the channel axis; from here on ocvimg is treated as BGR.
        ocvimg = img[:, :, ::-1].copy()
        outputs = self.FrontEnd(ocvimg)
        seg_list = self.crop_image(ocvimg, outputs["instances"].pred_boxes, predict_res, img_id)
        if not seg_list:
            print("*******************")
            print("Detecron didn't find object in image {}".format(img_id))
            print("*******************")
            return []

        reason_list = self._load_reasons(reason_path)

        # NOTE: data.read_and_transform_vision_data is a locally modified
        # ImageBind helper (modified by Hanning) that receives the PIL crops.
        middle_input = {
            ModalityType.TEXT: data.load_and_transform_text(reason_list, self.device),
            ModalityType.VISION: data.read_and_transform_vision_data(seg_list, self.device),
        }
        with torch.no_grad():
            embeddings = self.MiddleEnd(middle_input)
            tgt = embeddings[ModalityType.VISION]
            memory = embeddings[ModalityType.TEXT]
            _, _, score, _ = self.BackEnd(tgt, memory)

        scores = score.cpu().squeeze().detach().numpy().tolist()
        # squeeze() turns a single detection into a bare float; normalise to a list.
        if not isinstance(scores, list):
            scores = [scores]

        threshold = self.BackEnd.threshold
        for i, value in enumerate(scores):
            entry = predict_res[i]
            if value >= threshold:
                entry["category_id"] = 1
                entry["score"] = float(value)
            else:
                entry["category_id"] = 0
                entry["score"] = float(1 - value)
        return predict_res