TaskCLIP / models /models.py
HanningChen
Initial HF Space: FastAPI + HTML (no weights yet)
f2f112a
import torch
import detectron2
import cv2
import numpy as np
import glob
import os
import json
from PIL import Image
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from ImageBind.imagebind import data
from ImageBind.imagebind.models import imagebind_model
from ImageBind.imagebind.models.imagebind_model import ModalityType
#from .CoCoTask_Model import CoCoTask_Model
from .test_model2 import CoCoTask_Model
class TriStageModel(torch.nn.Module):
def __init__(self,model_path) -> None:
super().__init__()
self.fast_rcnn_path = "COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml"
#self.fast_rcnn_path = "Detectron1-Comparisons/faster_rcnn_R_50_FPN_noaug_1x.yaml"
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(self.fast_rcnn_path))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(self.fast_rcnn_path)
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
self.num_layers = 4
self.d_model = 1024
self.nhead = 4
self.dim_feedforward=2048
self.dropout=0.1
self.activation="relu"
self.normalize_before=False
self.return_intermediate = False
self.FrontEnd = DefaultPredictor(cfg)
self.MiddleEnd = imagebind_model.imagebind_huge(pretrained=True)
self.BackEnd = self.Construct_BackEnd(self.num_layers,
self.d_model,
self.nhead,
self.dim_feedforward,
self.dropout,
self.activation,
self.device,
self.normalize_before,
self.return_intermediate)
self.BackEnd.load_state_dict(torch.load('/home/hanningchen/IJCAI24/models/saved_models/decoder_score_task1_epoch14.pt'))
def Construct_BackEnd(self,
num_layers,
d_model,
nhead,
dim_feedforward,
dropout,
activation,
device,
normalize_before=False,
return_intermediate=False):
return CoCoTask_Model(num_layers=num_layers,
norm=None,
return_intermediate=return_intermediate,
d_model = d_model,
nhead = nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation=activation,
normalize_before=normalize_before,
device = device)
def crop_image(self,input_image,bbx_list,results,img_id):
seg_list = []
for i in range(len(bbx_list)):
bbx_id = i
bbx_tensor = bbx_list[i].tensor.squeeze()
x0 = int(bbx_tensor[0])
y0 = int(bbx_tensor[1])
x1 = int(bbx_tensor[2])
y1 = int(bbx_tensor[3])
if x0 >= 0 and x1 <= input_image.shape[1] and y0 >= 0 and y1 <= input_image.shape[0]:
pil_image = Image.fromarray(cv2.cvtColor(input_image[y0:y1,x0:x1,:], cv2.COLOR_BGR2RGB))
seg_list.append(pil_image)
x = float(bbx_tensor[0])
y = float(bbx_tensor[1])
w = float(bbx_tensor[2]) - float(bbx_tensor[0])
h = float(bbx_tensor[3]) - float(bbx_tensor[1])
results.append({"image_id":img_id,
"bbox":[x,y,w,h],
"score":-1,
"category_id": -1})
else:
print("************************")
print("The bbx exceed the image")
print("************************")
if x0 < 0:
x0 = 0
if x1 > input_image.shape[1]:
x1 = int(input_image.shape[1])
if y0 < 0:
y0 = 0
if y1 > input_image.shape[0]:
y1 = int(input_image.shape[0])
pil_image = Image.fromarray(cv2.cvtColor(input_image[y0:y1,x0:x1,:], cv2.COLOR_BGR2RGB))
seg_list.append(pil_image)
x = float(bbx_tensor[0])
y = float(bbx_tensor[1])
w = float(bbx_tensor[2]) - float(bbx_tensor[0])
h = float(bbx_tensor[3]) - float(bbx_tensor[1])
results.append({"image_id":img_id,
"bbox":[x,y,w,h],
"score":-1,
"category_id": -1})
return seg_list
def forward(self,inputs,img_id,reason_path):
predict_res = []
self.MiddleEnd.eval()
self.MiddleEnd.to(self.device)
self.BackEnd.eval()
self.BackEnd.to(self.device)
img = np.array(inputs)
ocvimg = img[:, :, ::-1].copy()
outputs = self.FrontEnd(ocvimg)
List_bbx = outputs["instances"].pred_boxes
List_class = outputs["instances"].pred_classes.cpu().tolist()
List_score = outputs["instances"].scores.cpu().tolist()
seg_list = self.crop_image(ocvimg,List_bbx,predict_res,img_id)
if len(seg_list) == 0:
print("*******************")
print("Detecron didn't find object in image {}".format(img_id))
print("*******************")
return []
#NOTE: Prepare reason list
#prompt_file = glob.glob(os.path.join(reason_path,"*.json"))
prompt_file = reason_path
reason_list = []
with open(prompt_file) as f:
prompt = json.load(f)['visual_features']
for x in range(len(prompt)):
prompt[x] = 'The item is ' + prompt[x]
for i in range(10):
reason_list.append(prompt[i])
"""
for json_file in prompt_file:
with open(json_file,'r') as f:
prompt = json.load(f)
for object_ent in prompt['reasons']:
reason_list.append(object_ent['description'])
"""
#NOTE: Here data.read_and_transform_vision_data is modified by Hanning
middle_input = {
ModalityType.TEXT: data.load_and_transform_text(reason_list, self.device),
ModalityType.VISION: data.read_and_transform_vision_data(seg_list, self.device),
}
with torch.no_grad():
embeddings = self.MiddleEnd(middle_input)
tgt = embeddings[ModalityType.VISION]
memory = embeddings[ModalityType.TEXT]
_, _, score, _ = self.BackEnd(tgt, memory)
score = score.cpu().squeeze().detach().numpy().tolist()
if isinstance(score,list):
visited = [0]*len(score)
for i, x in enumerate(score):
if visited[i] == 1:
continue
if x >= self.BackEnd.threshold:
visited[i] = 1
predict_res[i]["category_id"] = 1
predict_res[i]["score"] = float(x)
#NOTE: Chek the same class
# self.find_same_class(predict_res,score,visited,i, List_class,List_score)
else:
predict_res[i]["category_id"] = 0
predict_res[i]["score"] = float(1- x)
else:
if score >= self.BackEnd.threshold:
predict_res[0]["category_id"] = 1
predict_res[0]["score"] = float(score)
else:
predict_res[0]["category_id"] = 0
predict_res[0]["score"] = float(1- score)
return predict_res