# TaskCLIP / models/models.py
# Author: HanningChen
# Initial HF Space: FastAPI + HTML (no weights yet)
# Commit: f2f112a
import torch
import detectron2
import cv2
import numpy as np
import glob
import os
import json
from PIL import Image
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from ImageBind.imagebind import data
from ImageBind.imagebind.models import imagebind_model
from ImageBind.imagebind.models.imagebind_model import ModalityType
#from .CoCoTask_Model import CoCoTask_Model
from .test_model2 import CoCoTask_Model
class TriStageModel(torch.nn.Module):
    """Three-stage, task-driven object detection pipeline.

    Stage 1 (FrontEnd)  : Detectron2 Faster R-CNN proposes candidate boxes.
    Stage 2 (MiddleEnd) : ImageBind embeds the cropped boxes (vision) and the
                          task-reason prompts (text) into a shared space.
    Stage 3 (BackEnd)   : a transformer-decoder head (CoCoTask_Model) scores
                          each crop against the task prompts.
    """

    # Legacy checkpoint location, kept only as a fallback for callers that
    # pass a falsy model_path (the original code ignored model_path entirely).
    _DEFAULT_BACKEND_WEIGHTS = (
        '/home/hanningchen/IJCAI24/models/saved_models/decoder_score_task1_epoch14.pt'
    )

    def __init__(self, model_path) -> None:
        """Build all three stages.

        Args:
            model_path: path to the BackEnd decoder checkpoint (.pt). Falls
                back to the legacy hard-coded path when falsy.
        """
        super().__init__()
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        # --- Stage 1: Detectron2 Faster R-CNN detector -------------------
        self.fast_rcnn_path = "COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml"
        cfg = get_cfg()
        cfg.merge_from_file(model_zoo.get_config_file(self.fast_rcnn_path))
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
        cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(self.fast_rcnn_path)
        # Without this, DefaultPredictor defaults to CUDA and crashes on
        # CPU-only hosts even though we detect the device above.
        cfg.MODEL.DEVICE = self.device
        self.FrontEnd = DefaultPredictor(cfg)

        # --- Stage 2: pretrained ImageBind encoder -----------------------
        self.MiddleEnd = imagebind_model.imagebind_huge(pretrained=True)

        # --- Stage 3: transformer-decoder scoring head -------------------
        self.num_layers = 4
        self.d_model = 1024
        self.nhead = 4
        self.dim_feedforward = 2048
        self.dropout = 0.1
        self.activation = "relu"
        self.normalize_before = False
        self.return_intermediate = False
        self.BackEnd = self.Construct_BackEnd(self.num_layers,
                                              self.d_model,
                                              self.nhead,
                                              self.dim_feedforward,
                                              self.dropout,
                                              self.activation,
                                              self.device,
                                              self.normalize_before,
                                              self.return_intermediate)
        # map_location keeps a GPU-saved checkpoint loadable on CPU-only hosts.
        weights_path = model_path or self._DEFAULT_BACKEND_WEIGHTS
        self.BackEnd.load_state_dict(
            torch.load(weights_path, map_location=self.device))

    def Construct_BackEnd(self,
                          num_layers,
                          d_model,
                          nhead,
                          dim_feedforward,
                          dropout,
                          activation,
                          device,
                          normalize_before=False,
                          return_intermediate=False):
        """Instantiate the CoCoTask_Model decoder head with the given
        transformer hyper-parameters. Returns the (un-initialized) module."""
        return CoCoTask_Model(num_layers=num_layers,
                              norm=None,
                              return_intermediate=return_intermediate,
                              d_model=d_model,
                              nhead=nhead,
                              dim_feedforward=dim_feedforward,
                              dropout=dropout,
                              activation=activation,
                              normalize_before=normalize_before,
                              device=device)

    def _record_crop(self, input_image, bbx_tensor, x0, y0, x1, y1,
                     seg_list, results, img_id):
        """Crop [y0:y1, x0:x1] from the BGR image, append the RGB PIL crop to
        seg_list, and append a COCO-style placeholder entry (score and
        category_id filled in later by forward()) to results.

        The recorded bbox keeps the detector's original (possibly unclamped)
        coordinates, matching the original implementation.
        """
        pil_image = Image.fromarray(
            cv2.cvtColor(input_image[y0:y1, x0:x1, :], cv2.COLOR_BGR2RGB))
        seg_list.append(pil_image)
        x = float(bbx_tensor[0])
        y = float(bbx_tensor[1])
        w = float(bbx_tensor[2]) - float(bbx_tensor[0])
        h = float(bbx_tensor[3]) - float(bbx_tensor[1])
        results.append({"image_id": img_id,
                        "bbox": [x, y, w, h],
                        "score": -1,
                        "category_id": -1})

    def crop_image(self, input_image, bbx_list, results, img_id):
        """Crop every detected box out of a BGR image.

        Args:
            input_image: H x W x 3 BGR numpy image (OpenCV convention).
            bbx_list: Detectron2 Boxes; each entry's .tensor is (x0, y0, x1, y1).
            results: list to which a COCO-style dict is appended per box
                (mutated in place).
            img_id: image identifier stored in each result entry.

        Returns:
            List of RGB PIL.Image crops, one per box. Boxes that exceed the
            image bounds are clamped (with a console warning) before cropping.
        """
        seg_list = []
        for bbx in bbx_list:
            bbx_tensor = bbx.tensor.squeeze()
            x0 = int(bbx_tensor[0])
            y0 = int(bbx_tensor[1])
            x1 = int(bbx_tensor[2])
            y1 = int(bbx_tensor[3])
            out_of_bounds = (x0 < 0 or y0 < 0 or
                             x1 > input_image.shape[1] or
                             y1 > input_image.shape[0])
            if out_of_bounds:
                print("************************")
                print("The bbx exceed the image")
                print("************************")
                # Clamp to the image so the slice below stays valid.
                x0 = max(x0, 0)
                y0 = max(y0, 0)
                x1 = min(x1, int(input_image.shape[1]))
                y1 = min(y1, int(input_image.shape[0]))
            self._record_crop(input_image, bbx_tensor, x0, y0, x1, y1,
                              seg_list, results, img_id)
        return seg_list

    def forward(self, inputs, img_id, reason_path):
        """Run the full pipeline on one image.

        Args:
            inputs: input image as a PIL.Image (RGB) or RGB numpy array.
            img_id: image identifier recorded in every result entry.
            reason_path: path to a JSON file whose 'visual_features' list
                holds the task-reason phrases (up to 10 are used).

        Returns:
            List of COCO-style dicts {image_id, bbox, score, category_id},
            category_id 1 = task-relevant, 0 = not. Empty list when the
            detector finds no objects.
        """
        predict_res = []
        self.MiddleEnd.eval()
        self.MiddleEnd.to(self.device)
        self.BackEnd.eval()
        self.BackEnd.to(self.device)

        # Stage 1: detect candidate objects. RGB -> BGR for Detectron2/OpenCV.
        ocvimg = np.array(inputs)[:, :, ::-1].copy()
        outputs = self.FrontEnd(ocvimg)
        seg_list = self.crop_image(ocvimg, outputs["instances"].pred_boxes,
                                   predict_res, img_id)
        if not seg_list:
            print("*******************")
            print("Detecron didn't find object in image {}".format(img_id))
            print("*******************")
            return []

        # Build the task-reason prompts: first (up to) 10 visual features.
        # Slicing avoids the IndexError the original hit when fewer than 10
        # features were present.
        with open(reason_path) as f:
            prompt = json.load(f)['visual_features']
        reason_list = ['The item is ' + p for p in prompt[:10]]

        # Stage 2 + 3: embed crops/prompts, then score each crop.
        # NOTE: data.read_and_transform_vision_data is a project-modified
        # ImageBind helper that accepts in-memory PIL images.
        middle_input = {
            ModalityType.TEXT: data.load_and_transform_text(reason_list, self.device),
            ModalityType.VISION: data.read_and_transform_vision_data(seg_list, self.device),
        }
        with torch.no_grad():
            embeddings = self.MiddleEnd(middle_input)
            _, _, score, _ = self.BackEnd(embeddings[ModalityType.VISION],
                                          embeddings[ModalityType.TEXT])

        scores = score.cpu().squeeze().detach().numpy().tolist()
        if not isinstance(scores, list):
            # squeeze() collapses a single-box batch to a bare scalar.
            scores = [scores]
        threshold = self.BackEnd.threshold
        for entry, s in zip(predict_res, scores):
            if s >= threshold:
                entry["category_id"] = 1
                entry["score"] = float(s)
            else:
                entry["category_id"] = 0
                entry["score"] = float(1 - s)
        return predict_res