Spaces: Running on A10G
import glob
import json
import os

import cv2
import numpy as np
import torch
from PIL import Image

import detectron2
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer

from ImageBind.imagebind import data
from ImageBind.imagebind.models import imagebind_model
from ImageBind.imagebind.models.imagebind_model import ModalityType

#from .CoCoTask_Model import CoCoTask_Model
from .test_model2 import CoCoTask_Model
class TriStageModel(torch.nn.Module):
    """Three-stage detection/scoring pipeline.

    FrontEnd  : Detectron2 Faster R-CNN — proposes object boxes.
    MiddleEnd : ImageBind (huge) — embeds box crops and text prompts.
    BackEnd   : CoCoTask_Model transformer decoder — scores each crop
                against the text prompts.
    """

    def __init__(self, model_path) -> None:
        """Build the three stages.

        Args:
            model_path: Path to the saved BackEnd (decoder) state dict.
        """
        super().__init__()
        # COCO-pretrained Faster R-CNN used as the region-proposal front end.
        self.fast_rcnn_path = "COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml"
        cfg = get_cfg()
        cfg.merge_from_file(model_zoo.get_config_file(self.fast_rcnn_path))
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # keep boxes scoring >= 0.5
        cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(self.fast_rcnn_path)

        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        # BackEnd (decoder) hyper-parameters.
        self.num_layers = 4
        self.d_model = 1024  # matches the ImageBind-huge embedding width
        self.nhead = 4
        self.dim_feedforward = 2048
        self.dropout = 0.1
        self.activation = "relu"
        self.normalize_before = False
        self.return_intermediate = False

        self.FrontEnd = DefaultPredictor(cfg)
        self.MiddleEnd = imagebind_model.imagebind_huge(pretrained=True)
        self.BackEnd = self.Construct_BackEnd(self.num_layers,
                                              self.d_model,
                                              self.nhead,
                                              self.dim_feedforward,
                                              self.dropout,
                                              self.activation,
                                              self.device,
                                              self.normalize_before,
                                              self.return_intermediate)
        # BUGFIX: previously `model_path` was accepted but ignored and a
        # hard-coded absolute checkpoint path was loaded instead; honor the
        # caller's argument.
        self.BackEnd.load_state_dict(torch.load(model_path))

    def Construct_BackEnd(self,
                          num_layers,
                          d_model,
                          nhead,
                          dim_feedforward,
                          dropout,
                          activation,
                          device,
                          normalize_before=False,
                          return_intermediate=False):
        """Instantiate the decoder scoring head (no normalization layer)."""
        return CoCoTask_Model(num_layers=num_layers,
                              norm=None,
                              return_intermediate=return_intermediate,
                              d_model=d_model,
                              nhead=nhead,
                              dim_feedforward=dim_feedforward,
                              dropout=dropout,
                              activation=activation,
                              normalize_before=normalize_before,
                              device=device)

    def crop_image(self, input_image, bbx_list, results, img_id):
        """Crop every predicted box out of `input_image`.

        For each box, appends a COCO-style result dict (score/category_id
        are placeholders filled in by `forward`) to `results`, and collects
        the crop as an RGB PIL image. Boxes extending past the image are
        clamped for cropping, but the reported bbox keeps the raw
        (un-clamped) coordinates, as in the original implementation.

        Args:
            input_image: BGR image array (H, W, 3).
            bbx_list:    Detectron2 `Boxes` of predicted boxes.
            results:     List mutated in place with per-box result dicts.
            img_id:      COCO image id recorded into each result dict.

        Returns:
            List of RGB PIL image crops, one per box.
        """
        seg_list = []
        img_h = input_image.shape[0]
        img_w = input_image.shape[1]
        for i in range(len(bbx_list)):
            bbx_tensor = bbx_list[i].tensor.squeeze()
            x0 = int(bbx_tensor[0])
            y0 = int(bbx_tensor[1])
            x1 = int(bbx_tensor[2])
            y1 = int(bbx_tensor[3])
            if x0 < 0 or x1 > img_w or y0 < 0 or y1 > img_h:
                print("************************")
                print("The bbx exceed the image")
                print("************************")
                # Clamp the crop window to the image bounds.
                x0 = max(x0, 0)
                y0 = max(y0, 0)
                x1 = min(x1, img_w)
                y1 = min(y1, img_h)
            # OpenCV arrays are BGR; PIL expects RGB.
            pil_image = Image.fromarray(
                cv2.cvtColor(input_image[y0:y1, x0:x1, :], cv2.COLOR_BGR2RGB))
            seg_list.append(pil_image)
            # COCO bbox format is [x, y, width, height].
            x = float(bbx_tensor[0])
            y = float(bbx_tensor[1])
            w = float(bbx_tensor[2]) - x
            h = float(bbx_tensor[3]) - y
            results.append({"image_id": img_id,
                            "bbox": [x, y, w, h],
                            "score": -1,
                            "category_id": -1})
        return seg_list

    def forward(self, inputs, img_id, reason_path):
        """Run the full pipeline on one image.

        Args:
            inputs:      PIL image (RGB).
            img_id:      COCO image id recorded into each result dict.
            reason_path: Path to a JSON file containing a 'visual_features'
                         list of text descriptions.

        Returns:
            List of COCO-style detection dicts with binary `category_id`
            (1 if the decoder score clears its threshold, else 0), or []
            when the detector finds no objects.
        """
        predict_res = []
        self.MiddleEnd.eval()
        self.MiddleEnd.to(self.device)
        self.BackEnd.eval()
        self.BackEnd.to(self.device)

        img = np.array(inputs)
        ocvimg = img[:, :, ::-1].copy()  # RGB -> BGR for Detectron2/OpenCV
        outputs = self.FrontEnd(ocvimg)
        List_bbx = outputs["instances"].pred_boxes
        # NOTE: classes/scores are currently unused; kept for the disabled
        # find_same_class post-processing hook.
        List_class = outputs["instances"].pred_classes.cpu().tolist()
        List_score = outputs["instances"].scores.cpu().tolist()

        seg_list = self.crop_image(ocvimg, List_bbx, predict_res, img_id)
        if len(seg_list) == 0:
            print("*******************")
            print("Detecron didn't find object in image {}".format(img_id))
            print("*******************")
            return []

        # NOTE: Prepare reason list — build text prompts from the first
        # (up to) 10 visual features. Slicing avoids the IndexError the old
        # `for i in range(10)` loop raised when fewer than 10 were present.
        with open(reason_path) as f:
            prompt = json.load(f)['visual_features']
        reason_list = ['The item is ' + p for p in prompt[:10]]

        # NOTE: Here data.read_and_transform_vision_data is modified by Hanning
        middle_input = {
            ModalityType.TEXT: data.load_and_transform_text(reason_list, self.device),
            ModalityType.VISION: data.read_and_transform_vision_data(seg_list, self.device),
        }
        with torch.no_grad():
            embeddings = self.MiddleEnd(middle_input)
        tgt = embeddings[ModalityType.VISION]
        memory = embeddings[ModalityType.TEXT]
        _, _, score, _ = self.BackEnd(tgt, memory)
        score = score.cpu().squeeze().detach().numpy().tolist()
        if not isinstance(score, list):
            score = [score]  # a single detection squeezes to a bare scalar
        # Binary decision per crop. (The old `visited` bookkeeping was dead
        # code: it was only ever set at the index being checked.)
        for i, x in enumerate(score):
            if x >= self.BackEnd.threshold:
                predict_res[i]["category_id"] = 1
                predict_res[i]["score"] = float(x)
                # NOTE: Check the same class
                # self.find_same_class(predict_res, score, visited, i, List_class, List_score)
            else:
                predict_res[i]["category_id"] = 0
                predict_res[i]["score"] = float(1 - x)
        return predict_res