# TaskCLIP — webui/runner.py
# Author: HanningChen (commit 6feb3b2)
# "Download weights from HF model repo and use cached paths"
import json
import threading
from pathlib import Path
from typing import Dict, Any, List, Tuple, Optional

import numpy as np
import torch
from PIL import Image, ImageDraw
from ultralytics import YOLO, SAM

from ImageBind.imagebind import data
from ImageBind.imagebind.models import imagebind_model
from ImageBind.imagebind.models.imagebind_model import ModalityType
from models.TaskCLIP import TaskCLIP
# ... keep your helper funcs _draw_boxes_pil/_crop_pil/overlay_masks ...
class ModelRunner:
    """Orchestrates the YOLO -> (SAM) -> ImageBind -> TaskCLIP pipeline for the web UI.

    Heavy, checkpoint-independent models (SAM, ImageBind) are loaded once in
    ``__init__``; YOLO and TaskCLIP models are cached per checkpoint path so
    the UI can switch between them without reloading.
    """

    def __init__(
        self,
        project_root: str,
        device: str = "cuda:0",
        yolo_ckpt: str = "./.checkpoints/yolo12x.pt",
        sam_ckpt: str = "./.checkpoints/sam2.1_l.pt",
        imagebind_ckpt: Optional[str] = None,  # NEW: optional local ImageBind weights
        id2task_name_file: str = "./id2task_name.json",
        task2prompt_file: str = "./task20.json",
        threshold: float = 0.01,
        forward: bool = True,
        cluster: bool = True,
        forward_thre: float = 0.1,
    ):
        """Load metadata and models.

        Args:
            project_root: Repo root; all "./"-prefixed paths resolve against it.
            device: Torch device string for ImageBind / TaskCLIP.
            yolo_ckpt: Default YOLO checkpoint (actual models cached in _get_yolo).
            sam_ckpt: SAM checkpoint, loaded once here.
            imagebind_ckpt: Optional path to imagebind_huge.pth from the weights
                repo; falls back to the pretrained download when absent/missing.
            id2task_name_file / task2prompt_file: JSON metadata files.
            threshold / forward / cluster / forward_thre: Postprocess knobs
                stored for the selection logic downstream of run().
        """
        self.root = Path(project_root).resolve()
        self.device = device
        self.threshold = float(threshold)
        self.forward = bool(forward)
        self.cluster = bool(cluster)
        self.forward_thre = float(forward_thre)

        # Task metadata: task-id -> task name, task name -> prompt word list.
        self.id2task_name_path = self._resolve(id2task_name_file)
        self.task2prompt_path = self._resolve(task2prompt_file)
        self.id2task_name = json.loads(self.id2task_name_path.read_text())
        self.task2prompt = json.loads(self.task2prompt_path.read_text())

        # Per-checkpoint model caches, filled lazily.
        self._yolo_cache: Dict[str, Any] = {}
        self._taskclip_cache: Dict[tuple, Any] = {}

        # YOLO path (kept for reference; actual YOLO models are cached per ckpt in _get_yolo)
        self.yolo_ckpt_path = self._resolve(yolo_ckpt)

        # ---- SAM: load ONCE (from absolute or repo-relative path) ----
        self.sam = SAM(str(self._resolve(sam_ckpt)))

        # ---- ImageBind: load ONCE (local checkpoint if usable, else pretrained) ----
        self.vlm_model = self._load_imagebind(imagebind_ckpt)

        # Serializes run() within this process. threading.RLock is the right
        # primitive here; the previous torch.multiprocessing.RLock was a
        # heavyweight cross-process lock used only for in-process exclusion.
        self._lock = threading.RLock()

    def _resolve(self, path_like: str) -> Path:
        """Resolve a "./"-prefixed path against the repo root; pass others through."""
        s = str(path_like)
        return (self.root / s).resolve() if s.startswith(".") else Path(s)

    def _load_imagebind(self, imagebind_ckpt: Optional[str]):
        """Build exactly one ImageBind model.

        If *imagebind_ckpt* names an existing file, its weights are loaded into
        an un-pretrained model (unwrapping {"model": ...} / {"state_dict": ...}
        checkpoint formats, non-strict to tolerate key drift). Otherwise fall
        back to ``pretrained=True``. The original code built a pretrained=False
        model unconditionally and then discarded it on the fallback paths,
        wasting memory at startup.
        """
        if imagebind_ckpt:
            ckpt_path = self._resolve(imagebind_ckpt)
            if ckpt_path.exists():
                model = imagebind_model.imagebind_huge(pretrained=False).to(self.device).eval()
                state = torch.load(str(ckpt_path), map_location="cpu")
                # Robust handling of different checkpoint formats.
                if isinstance(state, dict) and isinstance(state.get("model"), dict):
                    state = state["model"]
                elif isinstance(state, dict) and isinstance(state.get("state_dict"), dict):
                    state = state["state_dict"]
                model.load_state_dict(state, strict=False)
                return model
        # Checkpoint not provided (or file missing): pretrained download.
        return imagebind_model.imagebind_huge(pretrained=True).to(self.device).eval()

    def _get_yolo(self, ckpt_path: str):
        """Return a cached YOLO model for *ckpt_path*, loading on first use."""
        ckpt_abs = str(self._resolve(ckpt_path))
        if ckpt_abs not in self._yolo_cache:
            self._yolo_cache[ckpt_abs] = YOLO(ckpt_abs)
        return self._yolo_cache[ckpt_abs]

    def list_task_ids(self) -> List[int]:
        """Return the sorted integer task ids; non-integer keys are skipped."""
        ids = []
        for k in self.id2task_name.keys():
            try:
                ids.append(int(k))
            except (TypeError, ValueError):
                # Metadata may contain non-numeric keys; ignore them.
                pass
        return sorted(ids)

    def _get_taskclip(self, ckpt_path: str, d_model: int, n_words: int, score_function: str, hdv_dim: int):
        """Return a cached TaskCLIP model for (ckpt, d_model, n_words, score_function, hdv_dim).

        Raises:
            FileNotFoundError: if the checkpoint does not exist.
        """
        ckpt_abs = str(self._resolve(ckpt_path))
        if not Path(ckpt_abs).exists():
            raise FileNotFoundError(f"TaskCLIP checkpoint not found: {ckpt_abs}")
        # HDV dimension only matters for the HDC score function; zero it
        # otherwise so cache keys don't fragment needlessly.
        eff_hdv_dim = int(hdv_dim) if score_function == "HDC" else 0
        key = (ckpt_abs, int(d_model), int(n_words), str(score_function), eff_hdv_dim)
        if key in self._taskclip_cache:
            return self._taskclip_cache[key]
        model_config = {
            "num_layers": 8,
            "norm": None,
            "return_intermediate": False,
            "d_model": int(d_model),
            "nhead": 4,
            "dim_feedforward": 2048,
            "dropout": 0.1,
            "N_words": int(n_words),
            "activation": "gelu",
            "normalize_before": False,
            "device": self.device,
            "ratio_text": 0.3,
            "ratio_image": 0.3,
            "ratio_glob": 0.3,
            "norm_before": True,
            "norm_after": False,
            "MIN_VAL": 10.0,
            "MAX_VAL": 30.0,
            "cross_attention": True,  # keep consistent with how your checkpoint was trained
            "score_function": "HDC" if score_function == "HDC" else "default",
            "HDV_D": int(eff_hdv_dim),
        }
        m = TaskCLIP(model_config, normalize_before=model_config["normalize_before"], device=model_config["device"])
        state = torch.load(ckpt_abs, map_location="cpu")
        m.load_state_dict(state, strict=True)
        m = m.to(self.device).eval()
        self._taskclip_cache[key] = m
        return m

    def _sam_masks_from_bboxes(self, image_path: str, bbox_list: List[List[float]], img_h: int, img_w: int) -> np.ndarray:
        """Run SAM once with all boxes as prompts; return a (N, H, W) bool mask array.

        Returns an empty (0, H, W) array when there are no boxes or SAM yields
        no masks.
        """
        if not bbox_list:
            return np.zeros((0, img_h, img_w), dtype=bool)
        bboxes = [[float(x0), float(y0), float(x1), float(y1)] for x0, y0, x1, y1 in bbox_list]
        # Multi-box call: one SAM invocation for all prompts.
        res = self.sam(image_path, bboxes=bboxes)
        r0 = res[0]
        if r0.masks is None:
            return np.zeros((0, img_h, img_w), dtype=bool)
        return r0.masks.data.detach().cpu().numpy().astype(bool)

    def run(
        self,
        image_path: str,
        task_id: int,
        vlm_model: str = "imagebind",
        od_model: str = "yolo",
        yolo_ckpt: str = "./.checkpoints/yolo12x.pt",
        score_function: str = "default",
        hdv_dim: int = 256,
        taskclip_ckpt: str = "./test_model/default/decoder.pt",
        viz_mode: str = "bbox",
    ) -> Dict[str, Any]:
        """Detect objects, embed them with ImageBind, and score them with TaskCLIP.

        Args:
            image_path: Input image path.
            task_id: Key into id2task_name (stringified for lookup).
            vlm_model / od_model: Only "imagebind" / "yolo" are supported here.
            viz_mode: "bbox" draws boxes; "mask" additionally runs SAM.

        Returns:
            Result dict; when YOLO finds nothing, an early dict with empty
            ``selected_indices`` and unannotated images.

        Raises:
            ValueError: on unsupported vlm_model / od_model / viz_mode.
        """
        if vlm_model != "imagebind":
            raise ValueError("This runner.py currently implements ImageBind only (your OpenCLIP version was in the other runner).")
        if od_model != "yolo":
            raise ValueError("Only od_model='yolo' supported.")
        with self._lock:
            img = Image.open(image_path).convert("RGB")
            task_name = self.id2task_name[str(task_id)]
            prompt_words = self.task2prompt[task_name]
            prompt_use = ["The item is " + w for w in prompt_words]

            # YOLO detection.
            yolo = self._get_yolo(yolo_ckpt)
            outputs = yolo(image_path)
            bbox_list = outputs[0].boxes.xyxy.tolist()
            classes = outputs[0].boxes.cls.tolist()        # kept for the postprocess below
            confidences = outputs[0].boxes.conf.tolist()   # kept for the postprocess below
            all_boxes = np.asarray(bbox_list, dtype=np.float32)
            W, H = img.size  # PIL size is (width, height)

            # IMPORTANT: only run SAM if viz_mode == "mask" (it is expensive).
            all_masks = None
            if viz_mode == "bbox":
                img_yolo = _draw_boxes_pil(img, all_boxes, color=(0, 255, 0), width=3)
            elif viz_mode == "mask":
                all_masks = self._sam_masks_from_bboxes(image_path, bbox_list, img_h=H, img_w=W)
                img_yolo = overlay_masks(img, all_masks, alpha=0.35, color=(0, 255, 0))
            else:
                raise ValueError(f"Unknown viz_mode={viz_mode}")

            # Per-detection crops fed to ImageBind.
            seg_list, _ = _crop_pil(img, bbox_list)
            if len(seg_list) == 0:
                # Nothing detected: return early with an empty selection.
                return {
                    "task_id": task_id,
                    "task_name": task_name,
                    "bbox_list": bbox_list,
                    "selected_indices": [],
                    "images": {"original": img, "yolo": img_yolo, "selected": img.copy()},
                }

            # ImageBind embeddings: prompts, per-box crops, and the whole image.
            with torch.no_grad():
                input_pack = {
                    ModalityType.TEXT: data.load_and_transform_text(prompt_use, self.device),
                    ModalityType.VISION: data.read_and_transform_vision_data(seg_list, self.device),
                }
                emb = self.vlm_model(input_pack)
                text_embeddings = emb[ModalityType.TEXT]
                bbox_embeddings = emb[ModalityType.VISION]
                input_pack2 = {ModalityType.VISION: data.read_and_transform_vision_data([img], self.device)}
                emb2 = self.vlm_model(input_pack2)
                image_embedding = emb2[ModalityType.VISION].squeeze(0)

            # TaskCLIP: score each detection against the task prompts.
            taskclip = self._get_taskclip(
                ckpt_path=taskclip_ckpt,
                d_model=int(image_embedding.shape[-1]),
                n_words=int(text_embeddings.shape[0]),
                score_function=score_function,
                hdv_dim=hdv_dim,
            )
            with torch.no_grad():
                _, _, score_res, _ = taskclip(bbox_embeddings, text_embeddings, image_embedding.view(1, -1))
            score = score_res.view(-1).detach().cpu().numpy().tolist()
            # ... keep your postprocess/selection logic unchanged ...
            # (use your existing code below this point)
            # return dict unchanged