Spaces:
Sleeping
Sleeping
HanningChen
commited on
Commit
·
c5d818e
1
Parent(s):
629ac00
Fix runner bug
Browse files- webui/runner.py +349 -76
webui/runner.py
CHANGED
|
@@ -1,19 +1,35 @@
|
|
|
|
|
| 1 |
import json
|
|
|
|
| 2 |
from pathlib import Path
|
| 3 |
from typing import Dict, Any, List, Tuple, Optional
|
| 4 |
|
|
|
|
|
|
|
| 5 |
import numpy as np
|
| 6 |
import torch
|
| 7 |
from PIL import Image, ImageDraw
|
| 8 |
|
| 9 |
from ultralytics import YOLO, SAM
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
from models.TaskCLIP import TaskCLIP
|
| 16 |
|
|
|
|
| 17 |
def _draw_boxes_pil(
|
| 18 |
img: Image.Image,
|
| 19 |
boxes_xyxy: np.ndarray,
|
|
@@ -32,8 +48,8 @@ def _draw_boxes_pil(
|
|
| 32 |
def _crop_pil(img: Image.Image, bbox_list: List[List[float]]) -> Tuple[List[Image.Image], List[int]]:
|
| 33 |
"""Return list of cropped PIL images + indices mapping back to bbox_list."""
|
| 34 |
W, H = img.size
|
| 35 |
-
crops = []
|
| 36 |
-
idxs = []
|
| 37 |
for i, (x0, y0, x1, y1) in enumerate(bbox_list):
|
| 38 |
x0 = max(0, min(W, int(x0)))
|
| 39 |
y0 = max(0, min(H, int(y0)))
|
|
@@ -65,14 +81,23 @@ def overlay_masks(
|
|
| 65 |
out = base * (1 - alpha) + overlay * alpha
|
| 66 |
return Image.fromarray(np.clip(out, 0, 255).astype(np.uint8))
|
| 67 |
|
|
|
|
| 68 |
class ModelRunner:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
def __init__(
|
| 70 |
self,
|
| 71 |
project_root: str,
|
| 72 |
device: str = "cuda:0",
|
| 73 |
yolo_ckpt: str = "./.checkpoints/yolo12x.pt",
|
| 74 |
sam_ckpt: str = "./.checkpoints/sam2.1_l.pt",
|
| 75 |
-
imagebind_ckpt: Optional[str] = None,
|
| 76 |
id2task_name_file: str = "./id2task_name.json",
|
| 77 |
task2prompt_file: str = "./task20.json",
|
| 78 |
threshold: float = 0.01,
|
|
@@ -87,53 +112,143 @@ class ModelRunner:
|
|
| 87 |
self.cluster = bool(cluster)
|
| 88 |
self.forward_thre = float(forward_thre)
|
| 89 |
|
| 90 |
-
# metadata
|
| 91 |
self.id2task_name_path = (self.root / id2task_name_file).resolve()
|
| 92 |
self.task2prompt_path = (self.root / task2prompt_file).resolve()
|
| 93 |
self.id2task_name = json.loads(self.id2task_name_path.read_text())
|
| 94 |
self.task2prompt = json.loads(self.task2prompt_path.read_text())
|
| 95 |
|
| 96 |
# caches
|
| 97 |
-
self.
|
| 98 |
-
self.
|
|
|
|
| 99 |
|
| 100 |
-
#
|
| 101 |
self.yolo_ckpt_path = (self.root / yolo_ckpt).resolve() if str(yolo_ckpt).startswith(".") else Path(yolo_ckpt)
|
| 102 |
|
| 103 |
-
#
|
| 104 |
sam_ckpt_path = (self.root / sam_ckpt).resolve() if str(sam_ckpt).startswith(".") else Path(sam_ckpt)
|
| 105 |
self.sam = SAM(str(sam_ckpt_path))
|
| 106 |
|
| 107 |
-
#
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
if ckpt_path.exists():
|
|
|
|
| 114 |
state = torch.load(str(ckpt_path), map_location="cpu")
|
| 115 |
-
#
|
| 116 |
if isinstance(state, dict) and "model" in state and isinstance(state["model"], dict):
|
| 117 |
state = state["model"]
|
| 118 |
-
|
| 119 |
state = state["state_dict"]
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
self.vlm_model = imagebind_model.imagebind_huge(pretrained=True).to(self.device).eval()
|
| 124 |
-
else:
|
| 125 |
-
self.vlm_model = imagebind_model.imagebind_huge(pretrained=True).to(self.device).eval()
|
| 126 |
|
| 127 |
-
|
|
|
|
|
|
|
| 128 |
|
| 129 |
-
def
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
def list_task_ids(self) -> List[int]:
|
| 136 |
-
ids = []
|
| 137 |
for k in self.id2task_name.keys():
|
| 138 |
try:
|
| 139 |
ids.append(int(k))
|
|
@@ -141,16 +256,81 @@ class ModelRunner:
|
|
| 141 |
pass
|
| 142 |
return sorted(ids)
|
| 143 |
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
ckpt_abs = str((self.root / ckpt_path).resolve()) if str(ckpt_path).startswith(".") else str(ckpt_path)
|
| 146 |
if not Path(ckpt_abs).exists():
|
| 147 |
raise FileNotFoundError(f"TaskCLIP checkpoint not found: {ckpt_abs}")
|
| 148 |
|
| 149 |
eff_hdv_dim = int(hdv_dim) if score_function == "HDC" else 0
|
| 150 |
-
|
|
|
|
|
|
|
| 151 |
if key in self._taskclip_cache:
|
| 152 |
return self._taskclip_cache[key]
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
model_config = {
|
| 155 |
"num_layers": 8,
|
| 156 |
"norm": None,
|
|
@@ -170,31 +350,53 @@ class ModelRunner:
|
|
| 170 |
"norm_after": False,
|
| 171 |
"MIN_VAL": 10.0,
|
| 172 |
"MAX_VAL": 30.0,
|
| 173 |
-
"cross_attention":
|
| 174 |
"score_function": "HDC" if score_function == "HDC" else "default",
|
| 175 |
"HDV_D": int(eff_hdv_dim),
|
| 176 |
}
|
| 177 |
|
| 178 |
m = TaskCLIP(model_config, normalize_before=model_config["normalize_before"], device=model_config["device"])
|
| 179 |
-
state = torch.load(ckpt_abs, map_location="cpu")
|
| 180 |
m.load_state_dict(state, strict=True)
|
| 181 |
m = m.to(self.device).eval()
|
| 182 |
|
| 183 |
self._taskclip_cache[key] = m
|
| 184 |
return m
|
| 185 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
def _sam_masks_from_bboxes(self, image_path: str, bbox_list: List[List[float]], img_h: int, img_w: int) -> np.ndarray:
|
| 187 |
if not bbox_list:
|
| 188 |
return np.zeros((0, img_h, img_w), dtype=bool)
|
| 189 |
|
| 190 |
bboxes = [[float(x0), float(y0), float(x1), float(y1)] for x0, y0, x1, y1 in bbox_list]
|
| 191 |
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
def run(
|
| 200 |
self,
|
|
@@ -208,40 +410,45 @@ class ModelRunner:
|
|
| 208 |
taskclip_ckpt: str = "./test_model/default/decoder.pt",
|
| 209 |
viz_mode: str = "bbox",
|
| 210 |
) -> Dict[str, Any]:
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
raise ValueError("This runner.py currently implements ImageBind only (your OpenCLIP version was in the other runner).")
|
| 214 |
if od_model != "yolo":
|
| 215 |
-
raise ValueError("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
|
| 217 |
with self._lock:
|
| 218 |
img = Image.open(image_path).convert("RGB")
|
|
|
|
| 219 |
task_name = self.id2task_name[str(task_id)]
|
| 220 |
prompt_words = self.task2prompt[task_name]
|
| 221 |
prompt_use = ["The item is " + w for w in prompt_words]
|
| 222 |
|
| 223 |
-
# YOLO
|
| 224 |
yolo = self._get_yolo(yolo_ckpt)
|
| 225 |
outputs = yolo(image_path)
|
| 226 |
bbox_list = outputs[0].boxes.xyxy.tolist()
|
| 227 |
classes = outputs[0].boxes.cls.tolist()
|
| 228 |
confidences = outputs[0].boxes.conf.tolist()
|
| 229 |
|
|
|
|
| 230 |
all_boxes = np.asarray(bbox_list, dtype=np.float32)
|
| 231 |
-
H = img.size[1]
|
| 232 |
-
W = img.size[0]
|
| 233 |
|
| 234 |
-
#
|
| 235 |
-
all_masks = None
|
| 236 |
if viz_mode == "bbox":
|
| 237 |
img_yolo = _draw_boxes_pil(img, all_boxes, color=(0, 255, 0), width=3)
|
| 238 |
-
|
|
|
|
| 239 |
all_masks = self._sam_masks_from_bboxes(image_path, bbox_list, img_h=H, img_w=W)
|
| 240 |
img_yolo = overlay_masks(img, all_masks, alpha=0.35, color=(0, 255, 0))
|
| 241 |
-
else:
|
| 242 |
-
raise ValueError(f"Unknown viz_mode={viz_mode}")
|
| 243 |
|
| 244 |
-
#
|
| 245 |
seg_list, _ = _crop_pil(img, bbox_list)
|
| 246 |
if len(seg_list) == 0:
|
| 247 |
return {
|
|
@@ -252,33 +459,99 @@ class ModelRunner:
|
|
| 252 |
"images": {"original": img, "yolo": img_yolo, "selected": img.copy()},
|
| 253 |
}
|
| 254 |
|
| 255 |
-
#
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
text_embeddings = emb[ModalityType.TEXT]
|
| 263 |
-
bbox_embeddings = emb[ModalityType.VISION]
|
| 264 |
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
|
| 269 |
-
# TaskCLIP
|
| 270 |
taskclip = self._get_taskclip(
|
| 271 |
ckpt_path=taskclip_ckpt,
|
| 272 |
-
d_model=
|
| 273 |
-
n_words=
|
| 274 |
score_function=score_function,
|
| 275 |
hdv_dim=hdv_dim,
|
|
|
|
| 276 |
)
|
| 277 |
|
| 278 |
-
|
| 279 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
score = score_res.view(-1).detach().cpu().numpy().tolist()
|
| 281 |
|
| 282 |
-
#
|
| 283 |
-
|
| 284 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# runner.py
|
| 2 |
import json
|
| 3 |
+
import sys
|
| 4 |
from pathlib import Path
|
| 5 |
from typing import Dict, Any, List, Tuple, Optional
|
| 6 |
|
| 7 |
+
from threading import RLock
|
| 8 |
+
|
| 9 |
import numpy as np
|
| 10 |
import torch
|
| 11 |
from PIL import Image, ImageDraw
|
| 12 |
|
| 13 |
from ultralytics import YOLO, SAM
|
| 14 |
+
import open_clip
|
| 15 |
+
|
| 16 |
+
# --- ImageBind import: robust for both "pip install -e ImageBind" and local folder ---
|
| 17 |
+
try:
|
| 18 |
+
# preferred: ImageBind installed as "imagebind"
|
| 19 |
+
from imagebind import data
|
| 20 |
+
from imagebind.models import imagebind_model
|
| 21 |
+
from imagebind.models.imagebind_model import ModalityType
|
| 22 |
+
except ModuleNotFoundError:
|
| 23 |
+
# fallback: repo has ./ImageBind/imagebind/
|
| 24 |
+
REPO_ROOT = Path(__file__).resolve().parents[1] # repo/
|
| 25 |
+
sys.path.insert(0, str(REPO_ROOT / "ImageBind"))
|
| 26 |
+
from imagebind import data
|
| 27 |
+
from imagebind.models import imagebind_model
|
| 28 |
+
from imagebind.models.imagebind_model import ModalityType
|
| 29 |
|
| 30 |
from models.TaskCLIP import TaskCLIP
|
| 31 |
|
| 32 |
+
|
| 33 |
def _draw_boxes_pil(
|
| 34 |
img: Image.Image,
|
| 35 |
boxes_xyxy: np.ndarray,
|
|
|
|
| 48 |
def _crop_pil(img: Image.Image, bbox_list: List[List[float]]) -> Tuple[List[Image.Image], List[int]]:
|
| 49 |
"""Return list of cropped PIL images + indices mapping back to bbox_list."""
|
| 50 |
W, H = img.size
|
| 51 |
+
crops: List[Image.Image] = []
|
| 52 |
+
idxs: List[int] = []
|
| 53 |
for i, (x0, y0, x1, y1) in enumerate(bbox_list):
|
| 54 |
x0 = max(0, min(W, int(x0)))
|
| 55 |
y0 = max(0, min(H, int(y0)))
|
|
|
|
| 81 |
out = base * (1 - alpha) + overlay * alpha
|
| 82 |
return Image.fromarray(np.clip(out, 0, 255).astype(np.uint8))
|
| 83 |
|
| 84 |
+
|
| 85 |
class ModelRunner:
|
| 86 |
+
"""
|
| 87 |
+
WebUI runner:
|
| 88 |
+
- YOLO detects bboxes
|
| 89 |
+
- VLM (ImageBind or OpenCLIP) embeds text prompts and crops (+ global image)
|
| 90 |
+
- TaskCLIP scores and selects bboxes
|
| 91 |
+
- optionally visualize bbox or SAM masks
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
def __init__(
|
| 95 |
self,
|
| 96 |
project_root: str,
|
| 97 |
device: str = "cuda:0",
|
| 98 |
yolo_ckpt: str = "./.checkpoints/yolo12x.pt",
|
| 99 |
sam_ckpt: str = "./.checkpoints/sam2.1_l.pt",
|
| 100 |
+
imagebind_ckpt: Optional[str] = None, # optional local weights path
|
| 101 |
id2task_name_file: str = "./id2task_name.json",
|
| 102 |
task2prompt_file: str = "./task20.json",
|
| 103 |
threshold: float = 0.01,
|
|
|
|
| 112 |
self.cluster = bool(cluster)
|
| 113 |
self.forward_thre = float(forward_thre)
|
| 114 |
|
| 115 |
+
# load task metadata
|
| 116 |
self.id2task_name_path = (self.root / id2task_name_file).resolve()
|
| 117 |
self.task2prompt_path = (self.root / task2prompt_file).resolve()
|
| 118 |
self.id2task_name = json.loads(self.id2task_name_path.read_text())
|
| 119 |
self.task2prompt = json.loads(self.task2prompt_path.read_text())
|
| 120 |
|
| 121 |
# caches
|
| 122 |
+
self._vlm_cache: Dict[str, Dict[str, Any]] = {}
|
| 123 |
+
self._yolo_cache: Dict[str, YOLO] = {}
|
| 124 |
+
self._taskclip_cache: Dict[Tuple[Any, ...], TaskCLIP] = {}
|
| 125 |
|
| 126 |
+
# default ckpt paths (not required; YOLO is cached per-run ckpt)
|
| 127 |
self.yolo_ckpt_path = (self.root / yolo_ckpt).resolve() if str(yolo_ckpt).startswith(".") else Path(yolo_ckpt)
|
| 128 |
|
| 129 |
+
# SAM loaded once
|
| 130 |
sam_ckpt_path = (self.root / sam_ckpt).resolve() if str(sam_ckpt).startswith(".") else Path(sam_ckpt)
|
| 131 |
self.sam = SAM(str(sam_ckpt_path))
|
| 132 |
|
| 133 |
+
# ImageBind weights path (optional)
|
| 134 |
+
self.imagebind_ckpt = imagebind_ckpt
|
| 135 |
+
|
| 136 |
+
# lock for single-GPU servers
|
| 137 |
+
self._lock = RLock()
|
| 138 |
+
|
| 139 |
+
def _get_yolo(self, ckpt_path: str) -> YOLO:
|
| 140 |
+
ckpt_abs = str((self.root / ckpt_path).resolve()) if str(ckpt_path).startswith(".") else str(ckpt_path)
|
| 141 |
+
if ckpt_abs not in self._yolo_cache:
|
| 142 |
+
self._yolo_cache[ckpt_abs] = YOLO(ckpt_abs)
|
| 143 |
+
return self._yolo_cache[ckpt_abs]
|
| 144 |
+
|
| 145 |
+
def _load_imagebind(self) -> Any:
|
| 146 |
+
"""
|
| 147 |
+
Load ImageBind once and cache it.
|
| 148 |
+
- If self.imagebind_ckpt provided and exists: load pretrained=False then load_state_dict
|
| 149 |
+
- Else: pretrained=True (may download)
|
| 150 |
+
"""
|
| 151 |
+
if "imagebind" in self._vlm_cache:
|
| 152 |
+
return self._vlm_cache["imagebind"]["model"]
|
| 153 |
+
|
| 154 |
+
if self.imagebind_ckpt:
|
| 155 |
+
ckpt_path = (self.root / self.imagebind_ckpt).resolve() if str(self.imagebind_ckpt).startswith(".") else Path(self.imagebind_ckpt)
|
| 156 |
if ckpt_path.exists():
|
| 157 |
+
m = imagebind_model.imagebind_huge(pretrained=False).to(self.device).eval()
|
| 158 |
state = torch.load(str(ckpt_path), map_location="cpu")
|
| 159 |
+
# common wrappers
|
| 160 |
if isinstance(state, dict) and "model" in state and isinstance(state["model"], dict):
|
| 161 |
state = state["model"]
|
| 162 |
+
if isinstance(state, dict) and "state_dict" in state and isinstance(state["state_dict"], dict):
|
| 163 |
state = state["state_dict"]
|
| 164 |
+
m.load_state_dict(state, strict=False)
|
| 165 |
+
self._vlm_cache["imagebind"] = {"kind": "imagebind", "model": m}
|
| 166 |
+
return m
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
+
m = imagebind_model.imagebind_huge(pretrained=True).to(self.device).eval()
|
| 169 |
+
self._vlm_cache["imagebind"] = {"kind": "imagebind", "model": m}
|
| 170 |
+
return m
|
| 171 |
|
| 172 |
+
def _get_vlm(self, vlm_model: str) -> Dict[str, Any]:
|
| 173 |
+
if vlm_model in self._vlm_cache and vlm_model != "imagebind":
|
| 174 |
+
return self._vlm_cache[vlm_model]
|
| 175 |
+
|
| 176 |
+
if vlm_model == "imagebind":
|
| 177 |
+
m = self._load_imagebind()
|
| 178 |
+
return {"kind": "imagebind", "model": m}
|
| 179 |
+
|
| 180 |
+
if vlm_model == "vit-b":
|
| 181 |
+
m, _, preprocess = open_clip.create_model_and_transforms(
|
| 182 |
+
"ViT-B-32", pretrained="laion2b_s34b_b79k"
|
| 183 |
+
)
|
| 184 |
+
m = m.to(self.device).eval()
|
| 185 |
+
tokenizer = open_clip.get_tokenizer("ViT-B-32")
|
| 186 |
+
pack = {"kind": "openclip", "model": m, "preprocess": preprocess, "tokenizer": tokenizer}
|
| 187 |
+
self._vlm_cache[vlm_model] = pack
|
| 188 |
+
return pack
|
| 189 |
+
|
| 190 |
+
if vlm_model == "vit-l":
|
| 191 |
+
m, _, preprocess = open_clip.create_model_and_transforms(
|
| 192 |
+
"ViT-L-14", pretrained="laion2b_s32b_b82k"
|
| 193 |
+
)
|
| 194 |
+
m = m.to(self.device).eval()
|
| 195 |
+
tokenizer = open_clip.get_tokenizer("ViT-L-14")
|
| 196 |
+
pack = {"kind": "openclip", "model": m, "preprocess": preprocess, "tokenizer": tokenizer}
|
| 197 |
+
self._vlm_cache[vlm_model] = pack
|
| 198 |
+
return pack
|
| 199 |
+
|
| 200 |
+
raise ValueError(f"Unknown vlm_model: {vlm_model}")
|
| 201 |
+
|
| 202 |
+
def _encode_vlm(
|
| 203 |
+
self,
|
| 204 |
+
vlm_model: str,
|
| 205 |
+
prompt_use: List[str],
|
| 206 |
+
seg_list: List[Image.Image],
|
| 207 |
+
full_img_pil: Image.Image,
|
| 208 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 209 |
+
pack = self._get_vlm(vlm_model)
|
| 210 |
+
|
| 211 |
+
with torch.inference_mode():
|
| 212 |
+
if pack["kind"] == "imagebind":
|
| 213 |
+
m = pack["model"]
|
| 214 |
+
input_pack = {
|
| 215 |
+
ModalityType.TEXT: data.load_and_transform_text(prompt_use, self.device),
|
| 216 |
+
ModalityType.VISION: data.read_and_transform_vision_data(seg_list, self.device),
|
| 217 |
+
}
|
| 218 |
+
emb = m(input_pack)
|
| 219 |
+
text_embeddings = emb[ModalityType.TEXT]
|
| 220 |
+
bbox_embeddings = emb[ModalityType.VISION]
|
| 221 |
+
|
| 222 |
+
input_pack2 = {ModalityType.VISION: data.read_and_transform_vision_data([full_img_pil], self.device)}
|
| 223 |
+
emb2 = m(input_pack2)
|
| 224 |
+
image_embedding = emb2[ModalityType.VISION].squeeze(0)
|
| 225 |
+
return text_embeddings, bbox_embeddings, image_embedding
|
| 226 |
+
|
| 227 |
+
# openclip branch
|
| 228 |
+
m = pack["model"]
|
| 229 |
+
preprocess = pack["preprocess"]
|
| 230 |
+
tokenizer = pack["tokenizer"]
|
| 231 |
+
|
| 232 |
+
# text
|
| 233 |
+
text = tokenizer(prompt_use).to(self.device)
|
| 234 |
+
text_embeddings = m.encode_text(text).float()
|
| 235 |
+
text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
|
| 236 |
+
|
| 237 |
+
# bbox crops
|
| 238 |
+
crop_tensors = [preprocess(im) for im in seg_list]
|
| 239 |
+
crop_batch = torch.stack(crop_tensors, dim=0).to(self.device)
|
| 240 |
+
bbox_embeddings = m.encode_image(crop_batch).float()
|
| 241 |
+
bbox_embeddings = bbox_embeddings / bbox_embeddings.norm(dim=-1, keepdim=True)
|
| 242 |
+
|
| 243 |
+
# global image
|
| 244 |
+
img_tensor = preprocess(full_img_pil).unsqueeze(0).to(self.device)
|
| 245 |
+
image_embedding = m.encode_image(img_tensor).float().squeeze(0)
|
| 246 |
+
image_embedding = image_embedding / image_embedding.norm(dim=-1, keepdim=True)
|
| 247 |
+
|
| 248 |
+
return text_embeddings, bbox_embeddings, image_embedding
|
| 249 |
|
| 250 |
def list_task_ids(self) -> List[int]:
|
| 251 |
+
ids: List[int] = []
|
| 252 |
for k in self.id2task_name.keys():
|
| 253 |
try:
|
| 254 |
ids.append(int(k))
|
|
|
|
| 256 |
pass
|
| 257 |
return sorted(ids)
|
| 258 |
|
| 259 |
+
@staticmethod
|
| 260 |
+
def _unwrap_state_dict(obj: Any) -> Dict[str, torch.Tensor]:
|
| 261 |
+
# supports {"state_dict": ...} style checkpoints
|
| 262 |
+
if isinstance(obj, dict) and "state_dict" in obj and isinstance(obj["state_dict"], dict):
|
| 263 |
+
return obj["state_dict"]
|
| 264 |
+
if isinstance(obj, dict):
|
| 265 |
+
return obj
|
| 266 |
+
raise TypeError(f"Unsupported checkpoint format: {type(obj)}")
|
| 267 |
+
|
| 268 |
+
def _infer_ckpt_flags(self, state: Dict[str, torch.Tensor]) -> Tuple[bool, bool, int]:
|
| 269 |
+
"""
|
| 270 |
+
Infer:
|
| 271 |
+
- is_hdc: whether checkpoint contains HDC submodule keys
|
| 272 |
+
- has_cross_attention: whether checkpoint contains cross-attn keys
|
| 273 |
+
- ckpt_d_model: best-effort inferred d_model from weights
|
| 274 |
+
"""
|
| 275 |
+
keys = list(state.keys())
|
| 276 |
+
is_hdc = any(k.startswith("ScoreFunction.HDReason.") for k in keys)
|
| 277 |
+
# NOTE: adjust this if your TaskCLIP names cross-attn differently
|
| 278 |
+
has_cross = any("cross_attn" in k or "cross_attn_text" in k for k in keys)
|
| 279 |
+
|
| 280 |
+
if "decoder_norm.weight" in state:
|
| 281 |
+
ckpt_d_model = int(state["decoder_norm.weight"].shape[0])
|
| 282 |
+
elif "ScoreFunction.norm.weight" in state:
|
| 283 |
+
ckpt_d_model = int(state["ScoreFunction.norm.weight"].shape[0])
|
| 284 |
+
else:
|
| 285 |
+
ckpt_d_model = -1
|
| 286 |
+
|
| 287 |
+
return is_hdc, has_cross, ckpt_d_model
|
| 288 |
+
|
| 289 |
+
def _get_taskclip(
|
| 290 |
+
self,
|
| 291 |
+
ckpt_path: str,
|
| 292 |
+
d_model: int,
|
| 293 |
+
n_words: int,
|
| 294 |
+
score_function: str,
|
| 295 |
+
hdv_dim: int,
|
| 296 |
+
cross_attention: bool,
|
| 297 |
+
) -> TaskCLIP:
|
| 298 |
ckpt_abs = str((self.root / ckpt_path).resolve()) if str(ckpt_path).startswith(".") else str(ckpt_path)
|
| 299 |
if not Path(ckpt_abs).exists():
|
| 300 |
raise FileNotFoundError(f"TaskCLIP checkpoint not found: {ckpt_abs}")
|
| 301 |
|
| 302 |
eff_hdv_dim = int(hdv_dim) if score_function == "HDC" else 0
|
| 303 |
+
|
| 304 |
+
# cache key must include cross_attention + score_function + dimensions
|
| 305 |
+
key = (ckpt_abs, int(d_model), int(n_words), str(score_function), int(eff_hdv_dim), bool(cross_attention))
|
| 306 |
if key in self._taskclip_cache:
|
| 307 |
return self._taskclip_cache[key]
|
| 308 |
|
| 309 |
+
state_raw = torch.load(ckpt_abs, map_location="cpu")
|
| 310 |
+
state = self._unwrap_state_dict(state_raw)
|
| 311 |
+
|
| 312 |
+
ckpt_is_hdc, ckpt_has_cross, ckpt_d_model = self._infer_ckpt_flags(state)
|
| 313 |
+
|
| 314 |
+
# Validate score_function against checkpoint
|
| 315 |
+
if score_function == "HDC" and not ckpt_is_hdc:
|
| 316 |
+
raise RuntimeError(f"Checkpoint is NOT HDC but score_function=HDC was selected. ckpt={ckpt_abs}")
|
| 317 |
+
if score_function != "HDC" and ckpt_is_hdc:
|
| 318 |
+
raise RuntimeError(f"Checkpoint IS HDC but score_function=default was selected. ckpt={ckpt_abs}")
|
| 319 |
+
|
| 320 |
+
# Validate cross_attention against checkpoint (if we can infer it)
|
| 321 |
+
# If your checkpoints don't contain cross-attn keys, ckpt_has_cross may be False even when the arch uses cross-attn.
|
| 322 |
+
# In that case, either update inference or remove this validation.
|
| 323 |
+
if bool(cross_attention) != bool(ckpt_has_cross):
|
| 324 |
+
raise RuntimeError(
|
| 325 |
+
f"cross_attention mismatch: runtime={cross_attention} but checkpoint has_cross_attention={ckpt_has_cross}. ckpt={ckpt_abs}"
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
# Validate d_model against checkpoint (if inferred)
|
| 329 |
+
if ckpt_d_model != -1 and int(d_model) != int(ckpt_d_model):
|
| 330 |
+
raise RuntimeError(
|
| 331 |
+
f"d_model mismatch: VLM produced d_model={int(d_model)} but checkpoint expects d_model={int(ckpt_d_model)}. ckpt={ckpt_abs}"
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
model_config = {
|
| 335 |
"num_layers": 8,
|
| 336 |
"norm": None,
|
|
|
|
| 350 |
"norm_after": False,
|
| 351 |
"MIN_VAL": 10.0,
|
| 352 |
"MAX_VAL": 30.0,
|
| 353 |
+
"cross_attention": bool(cross_attention),
|
| 354 |
"score_function": "HDC" if score_function == "HDC" else "default",
|
| 355 |
"HDV_D": int(eff_hdv_dim),
|
| 356 |
}
|
| 357 |
|
| 358 |
m = TaskCLIP(model_config, normalize_before=model_config["normalize_before"], device=model_config["device"])
|
|
|
|
| 359 |
m.load_state_dict(state, strict=True)
|
| 360 |
m = m.to(self.device).eval()
|
| 361 |
|
| 362 |
self._taskclip_cache[key] = m
|
| 363 |
return m
|
| 364 |
|
| 365 |
+
def _find_same_class(self, predict_res, score, visited, i, classes, confs, forward_thre):
|
| 366 |
+
cls_i = classes[i]
|
| 367 |
+
for j in range(len(score)):
|
| 368 |
+
if visited[j] == 1:
|
| 369 |
+
continue
|
| 370 |
+
if classes[j] == cls_i and float(score[j]) > forward_thre:
|
| 371 |
+
visited[j] = 1
|
| 372 |
+
predict_res[j]["category_id"] = 1
|
| 373 |
+
predict_res[j]["score"] = float(score[j])
|
| 374 |
+
|
| 375 |
def _sam_masks_from_bboxes(self, image_path: str, bbox_list: List[List[float]], img_h: int, img_w: int) -> np.ndarray:
|
| 376 |
if not bbox_list:
|
| 377 |
return np.zeros((0, img_h, img_w), dtype=bool)
|
| 378 |
|
| 379 |
bboxes = [[float(x0), float(y0), float(x1), float(y1)] for x0, y0, x1, y1 in bbox_list]
|
| 380 |
|
| 381 |
+
try:
|
| 382 |
+
res = self.sam(image_path, bboxes=bboxes)
|
| 383 |
+
r0 = res[0]
|
| 384 |
+
if r0.masks is None:
|
| 385 |
+
return np.zeros((0, img_h, img_w), dtype=bool)
|
| 386 |
+
masks = r0.masks.data.detach().cpu().numpy().astype(bool)
|
| 387 |
+
return masks
|
| 388 |
+
except Exception:
|
| 389 |
+
# fallback per-box
|
| 390 |
+
masks_list = []
|
| 391 |
+
for bb in bboxes:
|
| 392 |
+
rr = self.sam(image_path, bboxes=bb)[0]
|
| 393 |
+
if rr.masks is None:
|
| 394 |
+
continue
|
| 395 |
+
m = rr.masks.data.detach().cpu().numpy().astype(bool)
|
| 396 |
+
masks_list.append(m[0])
|
| 397 |
+
if len(masks_list) == 0:
|
| 398 |
+
return np.zeros((0, img_h, img_w), dtype=bool)
|
| 399 |
+
return np.stack(masks_list, axis=0)
|
| 400 |
|
| 401 |
def run(
|
| 402 |
self,
|
|
|
|
| 410 |
taskclip_ckpt: str = "./test_model/default/decoder.pt",
|
| 411 |
viz_mode: str = "bbox",
|
| 412 |
) -> Dict[str, Any]:
|
| 413 |
+
if vlm_model not in ["imagebind", "vit-b", "vit-l"]:
|
| 414 |
+
raise ValueError(f"Unknown vlm_model: {vlm_model}")
|
|
|
|
| 415 |
if od_model != "yolo":
|
| 416 |
+
raise ValueError("Currently only od_model='yolo' is supported.")
|
| 417 |
+
if viz_mode not in ["bbox", "mask"]:
|
| 418 |
+
raise ValueError(f"Unknown viz_mode={viz_mode}")
|
| 419 |
+
|
| 420 |
+
# Training convention you stated:
|
| 421 |
+
# - default => cross_attention True
|
| 422 |
+
# - HDC => cross_attention False
|
| 423 |
+
# If your actual training differs, change this rule OR pass it from app.py.
|
| 424 |
+
cross_attention = (score_function != "HDC")
|
| 425 |
|
| 426 |
with self._lock:
|
| 427 |
img = Image.open(image_path).convert("RGB")
|
| 428 |
+
|
| 429 |
task_name = self.id2task_name[str(task_id)]
|
| 430 |
prompt_words = self.task2prompt[task_name]
|
| 431 |
prompt_use = ["The item is " + w for w in prompt_words]
|
| 432 |
|
| 433 |
+
# YOLO detect
|
| 434 |
yolo = self._get_yolo(yolo_ckpt)
|
| 435 |
outputs = yolo(image_path)
|
| 436 |
bbox_list = outputs[0].boxes.xyxy.tolist()
|
| 437 |
classes = outputs[0].boxes.cls.tolist()
|
| 438 |
confidences = outputs[0].boxes.conf.tolist()
|
| 439 |
|
| 440 |
+
H, W = img.size[1], img.size[0]
|
| 441 |
all_boxes = np.asarray(bbox_list, dtype=np.float32)
|
|
|
|
|
|
|
| 442 |
|
| 443 |
+
# visualize all detections
|
|
|
|
| 444 |
if viz_mode == "bbox":
|
| 445 |
img_yolo = _draw_boxes_pil(img, all_boxes, color=(0, 255, 0), width=3)
|
| 446 |
+
all_masks = None
|
| 447 |
+
else:
|
| 448 |
all_masks = self._sam_masks_from_bboxes(image_path, bbox_list, img_h=H, img_w=W)
|
| 449 |
img_yolo = overlay_masks(img, all_masks, alpha=0.35, color=(0, 255, 0))
|
|
|
|
|
|
|
| 450 |
|
| 451 |
+
# crop bboxes
|
| 452 |
seg_list, _ = _crop_pil(img, bbox_list)
|
| 453 |
if len(seg_list) == 0:
|
| 454 |
return {
|
|
|
|
| 459 |
"images": {"original": img, "yolo": img_yolo, "selected": img.copy()},
|
| 460 |
}
|
| 461 |
|
| 462 |
+
# VLM embeddings
|
| 463 |
+
text_embeddings, bbox_embeddings, image_embedding = self._encode_vlm(
|
| 464 |
+
vlm_model=vlm_model,
|
| 465 |
+
prompt_use=prompt_use,
|
| 466 |
+
seg_list=seg_list,
|
| 467 |
+
full_img_pil=img,
|
| 468 |
+
)
|
|
|
|
|
|
|
| 469 |
|
| 470 |
+
# Ensure dims are consistent
|
| 471 |
+
if int(bbox_embeddings.shape[-1]) != int(image_embedding.shape[-1]):
|
| 472 |
+
raise RuntimeError(
|
| 473 |
+
f"Embedding dim mismatch: bbox_embeddings dim={bbox_embeddings.shape[-1]} "
|
| 474 |
+
f"vs image_embedding dim={image_embedding.shape[-1]}"
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
# IMPORTANT: d_model should come from bbox_embeddings (tgt), not global image
|
| 478 |
+
d_model = int(bbox_embeddings.shape[-1])
|
| 479 |
+
n_words = int(text_embeddings.shape[0])
|
| 480 |
|
| 481 |
+
# TaskCLIP (load correct arch)
|
| 482 |
taskclip = self._get_taskclip(
|
| 483 |
ckpt_path=taskclip_ckpt,
|
| 484 |
+
d_model=d_model,
|
| 485 |
+
n_words=n_words,
|
| 486 |
score_function=score_function,
|
| 487 |
hdv_dim=hdv_dim,
|
| 488 |
+
cross_attention=cross_attention,
|
| 489 |
)
|
| 490 |
|
| 491 |
+
# Score
|
| 492 |
+
with torch.inference_mode():
|
| 493 |
+
tgt = bbox_embeddings
|
| 494 |
+
memory = text_embeddings
|
| 495 |
+
image_embedding_2d = image_embedding.view(1, -1)
|
| 496 |
+
_, _, score_res, _ = taskclip(tgt, memory, image_embedding_2d)
|
| 497 |
score = score_res.view(-1).detach().cpu().numpy().tolist()
|
| 498 |
|
| 499 |
+
# post-process
|
| 500 |
+
predict_res = []
|
| 501 |
+
for i in range(len(bbox_list)):
|
| 502 |
+
predict_res.append({"category_id": -1, "score": -1, "class": int(classes[i])})
|
| 503 |
+
|
| 504 |
+
visited = [0] * len(score)
|
| 505 |
+
for i, x in enumerate(score):
|
| 506 |
+
if visited[i] == 1:
|
| 507 |
+
continue
|
| 508 |
+
if float(x) > self.threshold:
|
| 509 |
+
visited[i] = 1
|
| 510 |
+
predict_res[i]["category_id"] = 1
|
| 511 |
+
predict_res[i]["score"] = float(x)
|
| 512 |
+
if self.forward:
|
| 513 |
+
self._find_same_class(predict_res, score, visited, i, classes, confidences, self.forward_thre)
|
| 514 |
+
else:
|
| 515 |
+
predict_res[i]["category_id"] = 0
|
| 516 |
+
predict_res[i]["score"] = 1.0 - float(x)
|
| 517 |
+
|
| 518 |
+
# cluster optimization
|
| 519 |
+
if self.cluster and self.forward and len(seg_list) > 1:
|
| 520 |
+
cluster_scores: Dict[int, List[float]] = {}
|
| 521 |
+
for p in predict_res:
|
| 522 |
+
if int(p["category_id"]) == 1:
|
| 523 |
+
c = int(p["class"])
|
| 524 |
+
cluster_scores.setdefault(c, []).append(float(p["score"]))
|
| 525 |
+
|
| 526 |
+
if len(cluster_scores) > 1:
|
| 527 |
+
cluster_ave = {c: float(np.mean(v)) for c, v in cluster_scores.items()}
|
| 528 |
+
select_class = max(cluster_ave, key=lambda k: cluster_ave[k])
|
| 529 |
+
for p in predict_res:
|
| 530 |
+
if int(p["category_id"]) == 1 and int(p["class"]) != int(select_class):
|
| 531 |
+
p["category_id"] = 0
|
| 532 |
+
|
| 533 |
+
selected_indices = [i for i, p in enumerate(predict_res) if int(p["category_id"]) == 1]
|
| 534 |
+
selected_boxes = (
|
| 535 |
+
all_boxes[selected_indices] if len(selected_indices) > 0 else np.zeros((0, 4), dtype=np.float32)
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
# visualize selected
|
| 539 |
+
if viz_mode == "bbox":
|
| 540 |
+
img_selected = _draw_boxes_pil(img, selected_boxes, color=(255, 0, 0), width=4)
|
| 541 |
+
else:
|
| 542 |
+
if all_masks is not None and all_masks.shape[0] > 0 and len(selected_indices) > 0:
|
| 543 |
+
sel_masks = all_masks[selected_indices]
|
| 544 |
+
else:
|
| 545 |
+
sel_masks = np.zeros((0, H, W), dtype=bool)
|
| 546 |
+
img_selected = overlay_masks(img, sel_masks, alpha=0.45, color=(255, 0, 0))
|
| 547 |
+
|
| 548 |
+
return {
|
| 549 |
+
"task_id": task_id,
|
| 550 |
+
"task_name": task_name,
|
| 551 |
+
"bbox_list": bbox_list,
|
| 552 |
+
"classes": classes,
|
| 553 |
+
"confidences": confidences,
|
| 554 |
+
"scores": score,
|
| 555 |
+
"selected_indices": selected_indices,
|
| 556 |
+
"images": {"original": img, "yolo": img_yolo, "selected": img_selected},
|
| 557 |
+
}
|