HanningChen committed on
Commit
c5d818e
·
1 Parent(s): 629ac00

Fix runner bug

Browse files
Files changed (1) hide show
  1. webui/runner.py +349 -76
webui/runner.py CHANGED
@@ -1,19 +1,35 @@
 
1
  import json
 
2
  from pathlib import Path
3
  from typing import Dict, Any, List, Tuple, Optional
4
 
 
 
5
  import numpy as np
6
  import torch
7
  from PIL import Image, ImageDraw
8
 
9
  from ultralytics import YOLO, SAM
10
-
11
- from ImageBind.imagebind import data
12
- from ImageBind.imagebind.models import imagebind_model
13
- from ImageBind.imagebind.models.imagebind_model import ModalityType
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  from models.TaskCLIP import TaskCLIP
16
 
 
17
  def _draw_boxes_pil(
18
  img: Image.Image,
19
  boxes_xyxy: np.ndarray,
@@ -32,8 +48,8 @@ def _draw_boxes_pil(
32
  def _crop_pil(img: Image.Image, bbox_list: List[List[float]]) -> Tuple[List[Image.Image], List[int]]:
33
  """Return list of cropped PIL images + indices mapping back to bbox_list."""
34
  W, H = img.size
35
- crops = []
36
- idxs = []
37
  for i, (x0, y0, x1, y1) in enumerate(bbox_list):
38
  x0 = max(0, min(W, int(x0)))
39
  y0 = max(0, min(H, int(y0)))
@@ -65,14 +81,23 @@ def overlay_masks(
65
  out = base * (1 - alpha) + overlay * alpha
66
  return Image.fromarray(np.clip(out, 0, 255).astype(np.uint8))
67
 
 
68
  class ModelRunner:
 
 
 
 
 
 
 
 
69
  def __init__(
70
  self,
71
  project_root: str,
72
  device: str = "cuda:0",
73
  yolo_ckpt: str = "./.checkpoints/yolo12x.pt",
74
  sam_ckpt: str = "./.checkpoints/sam2.1_l.pt",
75
- imagebind_ckpt: Optional[str] = None, # NEW
76
  id2task_name_file: str = "./id2task_name.json",
77
  task2prompt_file: str = "./task20.json",
78
  threshold: float = 0.01,
@@ -87,53 +112,143 @@ class ModelRunner:
87
  self.cluster = bool(cluster)
88
  self.forward_thre = float(forward_thre)
89
 
90
- # metadata
91
  self.id2task_name_path = (self.root / id2task_name_file).resolve()
92
  self.task2prompt_path = (self.root / task2prompt_file).resolve()
93
  self.id2task_name = json.loads(self.id2task_name_path.read_text())
94
  self.task2prompt = json.loads(self.task2prompt_path.read_text())
95
 
96
  # caches
97
- self._yolo_cache = {}
98
- self._taskclip_cache = {}
 
99
 
100
- # YOLO path (kept for reference; actual YOLO models are cached per ckpt in _get_yolo)
101
  self.yolo_ckpt_path = (self.root / yolo_ckpt).resolve() if str(yolo_ckpt).startswith(".") else Path(yolo_ckpt)
102
 
103
- # ---- SAM load ONCE (from absolute or repo-relative path) ----
104
  sam_ckpt_path = (self.root / sam_ckpt).resolve() if str(sam_ckpt).startswith(".") else Path(sam_ckpt)
105
  self.sam = SAM(str(sam_ckpt_path))
106
 
107
- # ---- ImageBind load ONCE ----
108
- # If you provide imagebind_huge.pth from weights repo, use it.
109
- # Otherwise fall back to pretrained=True behavior.
110
- self.vlm_model = imagebind_model.imagebind_huge(pretrained=False).to(self.device).eval()
111
- if imagebind_ckpt:
112
- ckpt_path = (self.root / imagebind_ckpt).resolve() if str(imagebind_ckpt).startswith(".") else Path(imagebind_ckpt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  if ckpt_path.exists():
 
114
  state = torch.load(str(ckpt_path), map_location="cpu")
115
- # robust handling of different checkpoint formats
116
  if isinstance(state, dict) and "model" in state and isinstance(state["model"], dict):
117
  state = state["model"]
118
- elif isinstance(state, dict) and "state_dict" in state and isinstance(state["state_dict"], dict):
119
  state = state["state_dict"]
120
- self.vlm_model.load_state_dict(state, strict=False)
121
- else:
122
- # fallback if file missing
123
- self.vlm_model = imagebind_model.imagebind_huge(pretrained=True).to(self.device).eval()
124
- else:
125
- self.vlm_model = imagebind_model.imagebind_huge(pretrained=True).to(self.device).eval()
126
 
127
- self._lock = torch.multiprocessing.RLock()
 
 
128
 
129
- def _get_yolo(self, ckpt_path: str):
130
- ckpt_abs = str((self.root / ckpt_path).resolve()) if str(ckpt_path).startswith(".") else str(ckpt_path)
131
- if ckpt_abs not in self._yolo_cache:
132
- self._yolo_cache[ckpt_abs] = YOLO(ckpt_abs)
133
- return self._yolo_cache[ckpt_abs]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  def list_task_ids(self) -> List[int]:
136
- ids = []
137
  for k in self.id2task_name.keys():
138
  try:
139
  ids.append(int(k))
@@ -141,16 +256,81 @@ class ModelRunner:
141
  pass
142
  return sorted(ids)
143
 
144
- def _get_taskclip(self, ckpt_path: str, d_model: int, n_words: int, score_function: str, hdv_dim: int):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  ckpt_abs = str((self.root / ckpt_path).resolve()) if str(ckpt_path).startswith(".") else str(ckpt_path)
146
  if not Path(ckpt_abs).exists():
147
  raise FileNotFoundError(f"TaskCLIP checkpoint not found: {ckpt_abs}")
148
 
149
  eff_hdv_dim = int(hdv_dim) if score_function == "HDC" else 0
150
- key = (ckpt_abs, int(d_model), int(n_words), str(score_function), eff_hdv_dim)
 
 
151
  if key in self._taskclip_cache:
152
  return self._taskclip_cache[key]
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  model_config = {
155
  "num_layers": 8,
156
  "norm": None,
@@ -170,31 +350,53 @@ class ModelRunner:
170
  "norm_after": False,
171
  "MIN_VAL": 10.0,
172
  "MAX_VAL": 30.0,
173
- "cross_attention": True, # keep consistent with how your checkpoint was trained
174
  "score_function": "HDC" if score_function == "HDC" else "default",
175
  "HDV_D": int(eff_hdv_dim),
176
  }
177
 
178
  m = TaskCLIP(model_config, normalize_before=model_config["normalize_before"], device=model_config["device"])
179
- state = torch.load(ckpt_abs, map_location="cpu")
180
  m.load_state_dict(state, strict=True)
181
  m = m.to(self.device).eval()
182
 
183
  self._taskclip_cache[key] = m
184
  return m
185
 
 
 
 
 
 
 
 
 
 
 
186
  def _sam_masks_from_bboxes(self, image_path: str, bbox_list: List[List[float]], img_h: int, img_w: int) -> np.ndarray:
187
  if not bbox_list:
188
  return np.zeros((0, img_h, img_w), dtype=bool)
189
 
190
  bboxes = [[float(x0), float(y0), float(x1), float(y1)] for x0, y0, x1, y1 in bbox_list]
191
 
192
- # multi-box call
193
- res = self.sam(image_path, bboxes=bboxes)
194
- r0 = res[0]
195
- if r0.masks is None:
196
- return np.zeros((0, img_h, img_w), dtype=bool)
197
- return r0.masks.data.detach().cpu().numpy().astype(bool)
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
  def run(
200
  self,
@@ -208,40 +410,45 @@ class ModelRunner:
208
  taskclip_ckpt: str = "./test_model/default/decoder.pt",
209
  viz_mode: str = "bbox",
210
  ) -> Dict[str, Any]:
211
-
212
- if vlm_model != "imagebind":
213
- raise ValueError("This runner.py currently implements ImageBind only (your OpenCLIP version was in the other runner).")
214
  if od_model != "yolo":
215
- raise ValueError("Only od_model='yolo' supported.")
 
 
 
 
 
 
 
 
216
 
217
  with self._lock:
218
  img = Image.open(image_path).convert("RGB")
 
219
  task_name = self.id2task_name[str(task_id)]
220
  prompt_words = self.task2prompt[task_name]
221
  prompt_use = ["The item is " + w for w in prompt_words]
222
 
223
- # YOLO
224
  yolo = self._get_yolo(yolo_ckpt)
225
  outputs = yolo(image_path)
226
  bbox_list = outputs[0].boxes.xyxy.tolist()
227
  classes = outputs[0].boxes.cls.tolist()
228
  confidences = outputs[0].boxes.conf.tolist()
229
 
 
230
  all_boxes = np.asarray(bbox_list, dtype=np.float32)
231
- H = img.size[1]
232
- W = img.size[0]
233
 
234
- # IMPORTANT: only run SAM if viz_mode == mask
235
- all_masks = None
236
  if viz_mode == "bbox":
237
  img_yolo = _draw_boxes_pil(img, all_boxes, color=(0, 255, 0), width=3)
238
- elif viz_mode == "mask":
 
239
  all_masks = self._sam_masks_from_bboxes(image_path, bbox_list, img_h=H, img_w=W)
240
  img_yolo = overlay_masks(img, all_masks, alpha=0.35, color=(0, 255, 0))
241
- else:
242
- raise ValueError(f"Unknown viz_mode={viz_mode}")
243
 
244
- # crops
245
  seg_list, _ = _crop_pil(img, bbox_list)
246
  if len(seg_list) == 0:
247
  return {
@@ -252,33 +459,99 @@ class ModelRunner:
252
  "images": {"original": img, "yolo": img_yolo, "selected": img.copy()},
253
  }
254
 
255
- # ImageBind embeddings
256
- with torch.no_grad():
257
- input_pack = {
258
- ModalityType.TEXT: data.load_and_transform_text(prompt_use, self.device),
259
- ModalityType.VISION: data.read_and_transform_vision_data(seg_list, self.device),
260
- }
261
- emb = self.vlm_model(input_pack)
262
- text_embeddings = emb[ModalityType.TEXT]
263
- bbox_embeddings = emb[ModalityType.VISION]
264
 
265
- input_pack2 = {ModalityType.VISION: data.read_and_transform_vision_data([img], self.device)}
266
- emb2 = self.vlm_model(input_pack2)
267
- image_embedding = emb2[ModalityType.VISION].squeeze(0)
 
 
 
 
 
 
 
268
 
269
- # TaskCLIP
270
  taskclip = self._get_taskclip(
271
  ckpt_path=taskclip_ckpt,
272
- d_model=int(image_embedding.shape[-1]),
273
- n_words=int(text_embeddings.shape[0]),
274
  score_function=score_function,
275
  hdv_dim=hdv_dim,
 
276
  )
277
 
278
- with torch.no_grad():
279
- _, _, score_res, _ = taskclip(bbox_embeddings, text_embeddings, image_embedding.view(1, -1))
 
 
 
 
280
  score = score_res.view(-1).detach().cpu().numpy().tolist()
281
 
282
- # ... keep your postprocess/selection logic unchanged ...
283
- # (use your existing code below this point)
284
- # return dict unchanged
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # runner.py
2
  import json
3
+ import sys
4
  from pathlib import Path
5
  from typing import Dict, Any, List, Tuple, Optional
6
 
7
+ from threading import RLock
8
+
9
  import numpy as np
10
  import torch
11
  from PIL import Image, ImageDraw
12
 
13
  from ultralytics import YOLO, SAM
14
+ import open_clip
15
+
16
+ # --- ImageBind import: robust for both "pip install -e ImageBind" and local folder ---
17
+ try:
18
+ # preferred: ImageBind installed as "imagebind"
19
+ from imagebind import data
20
+ from imagebind.models import imagebind_model
21
+ from imagebind.models.imagebind_model import ModalityType
22
+ except ModuleNotFoundError:
23
+ # fallback: repo has ./ImageBind/imagebind/
24
+ REPO_ROOT = Path(__file__).resolve().parents[1] # repo/
25
+ sys.path.insert(0, str(REPO_ROOT / "ImageBind"))
26
+ from imagebind import data
27
+ from imagebind.models import imagebind_model
28
+ from imagebind.models.imagebind_model import ModalityType
29
 
30
  from models.TaskCLIP import TaskCLIP
31
 
32
+
33
  def _draw_boxes_pil(
34
  img: Image.Image,
35
  boxes_xyxy: np.ndarray,
 
48
  def _crop_pil(img: Image.Image, bbox_list: List[List[float]]) -> Tuple[List[Image.Image], List[int]]:
49
  """Return list of cropped PIL images + indices mapping back to bbox_list."""
50
  W, H = img.size
51
+ crops: List[Image.Image] = []
52
+ idxs: List[int] = []
53
  for i, (x0, y0, x1, y1) in enumerate(bbox_list):
54
  x0 = max(0, min(W, int(x0)))
55
  y0 = max(0, min(H, int(y0)))
 
81
  out = base * (1 - alpha) + overlay * alpha
82
  return Image.fromarray(np.clip(out, 0, 255).astype(np.uint8))
83
 
84
+
85
  class ModelRunner:
86
+ """
87
+ WebUI runner:
88
+ - YOLO detects bboxes
89
+ - VLM (ImageBind or OpenCLIP) embeds text prompts and crops (+ global image)
90
+ - TaskCLIP scores and selects bboxes
91
+ - optionally visualize bbox or SAM masks
92
+ """
93
+
94
  def __init__(
95
  self,
96
  project_root: str,
97
  device: str = "cuda:0",
98
  yolo_ckpt: str = "./.checkpoints/yolo12x.pt",
99
  sam_ckpt: str = "./.checkpoints/sam2.1_l.pt",
100
+ imagebind_ckpt: Optional[str] = None, # optional local weights path
101
  id2task_name_file: str = "./id2task_name.json",
102
  task2prompt_file: str = "./task20.json",
103
  threshold: float = 0.01,
 
112
  self.cluster = bool(cluster)
113
  self.forward_thre = float(forward_thre)
114
 
115
+ # load task metadata
116
  self.id2task_name_path = (self.root / id2task_name_file).resolve()
117
  self.task2prompt_path = (self.root / task2prompt_file).resolve()
118
  self.id2task_name = json.loads(self.id2task_name_path.read_text())
119
  self.task2prompt = json.loads(self.task2prompt_path.read_text())
120
 
121
  # caches
122
+ self._vlm_cache: Dict[str, Dict[str, Any]] = {}
123
+ self._yolo_cache: Dict[str, YOLO] = {}
124
+ self._taskclip_cache: Dict[Tuple[Any, ...], TaskCLIP] = {}
125
 
126
+ # default ckpt paths (not required; YOLO is cached per-run ckpt)
127
  self.yolo_ckpt_path = (self.root / yolo_ckpt).resolve() if str(yolo_ckpt).startswith(".") else Path(yolo_ckpt)
128
 
129
+ # SAM loaded once
130
  sam_ckpt_path = (self.root / sam_ckpt).resolve() if str(sam_ckpt).startswith(".") else Path(sam_ckpt)
131
  self.sam = SAM(str(sam_ckpt_path))
132
 
133
+ # ImageBind weights path (optional)
134
+ self.imagebind_ckpt = imagebind_ckpt
135
+
136
+ # lock for single-GPU servers
137
+ self._lock = RLock()
138
+
139
+ def _get_yolo(self, ckpt_path: str) -> YOLO:
140
+ ckpt_abs = str((self.root / ckpt_path).resolve()) if str(ckpt_path).startswith(".") else str(ckpt_path)
141
+ if ckpt_abs not in self._yolo_cache:
142
+ self._yolo_cache[ckpt_abs] = YOLO(ckpt_abs)
143
+ return self._yolo_cache[ckpt_abs]
144
+
145
+ def _load_imagebind(self) -> Any:
146
+ """
147
+ Load ImageBind once and cache it.
148
+ - If self.imagebind_ckpt provided and exists: load pretrained=False then load_state_dict
149
+ - Else: pretrained=True (may download)
150
+ """
151
+ if "imagebind" in self._vlm_cache:
152
+ return self._vlm_cache["imagebind"]["model"]
153
+
154
+ if self.imagebind_ckpt:
155
+ ckpt_path = (self.root / self.imagebind_ckpt).resolve() if str(self.imagebind_ckpt).startswith(".") else Path(self.imagebind_ckpt)
156
  if ckpt_path.exists():
157
+ m = imagebind_model.imagebind_huge(pretrained=False).to(self.device).eval()
158
  state = torch.load(str(ckpt_path), map_location="cpu")
159
+ # common wrappers
160
  if isinstance(state, dict) and "model" in state and isinstance(state["model"], dict):
161
  state = state["model"]
162
+ if isinstance(state, dict) and "state_dict" in state and isinstance(state["state_dict"], dict):
163
  state = state["state_dict"]
164
+ m.load_state_dict(state, strict=False)
165
+ self._vlm_cache["imagebind"] = {"kind": "imagebind", "model": m}
166
+ return m
 
 
 
167
 
168
+ m = imagebind_model.imagebind_huge(pretrained=True).to(self.device).eval()
169
+ self._vlm_cache["imagebind"] = {"kind": "imagebind", "model": m}
170
+ return m
171
 
172
+ def _get_vlm(self, vlm_model: str) -> Dict[str, Any]:
173
+ if vlm_model in self._vlm_cache and vlm_model != "imagebind":
174
+ return self._vlm_cache[vlm_model]
175
+
176
+ if vlm_model == "imagebind":
177
+ m = self._load_imagebind()
178
+ return {"kind": "imagebind", "model": m}
179
+
180
+ if vlm_model == "vit-b":
181
+ m, _, preprocess = open_clip.create_model_and_transforms(
182
+ "ViT-B-32", pretrained="laion2b_s34b_b79k"
183
+ )
184
+ m = m.to(self.device).eval()
185
+ tokenizer = open_clip.get_tokenizer("ViT-B-32")
186
+ pack = {"kind": "openclip", "model": m, "preprocess": preprocess, "tokenizer": tokenizer}
187
+ self._vlm_cache[vlm_model] = pack
188
+ return pack
189
+
190
+ if vlm_model == "vit-l":
191
+ m, _, preprocess = open_clip.create_model_and_transforms(
192
+ "ViT-L-14", pretrained="laion2b_s32b_b82k"
193
+ )
194
+ m = m.to(self.device).eval()
195
+ tokenizer = open_clip.get_tokenizer("ViT-L-14")
196
+ pack = {"kind": "openclip", "model": m, "preprocess": preprocess, "tokenizer": tokenizer}
197
+ self._vlm_cache[vlm_model] = pack
198
+ return pack
199
+
200
+ raise ValueError(f"Unknown vlm_model: {vlm_model}")
201
+
202
    def _encode_vlm(
        self,
        vlm_model: str,
        prompt_use: List[str],
        seg_list: List[Image.Image],
        full_img_pil: Image.Image,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Embed the text prompts, the bbox crops, and the full image with one VLM.

        Args:
            vlm_model: VLM key understood by ``_get_vlm`` ("imagebind", "vit-b", "vit-l").
            prompt_use: text prompts, one per task word.
            seg_list: cropped PIL images, one per detected bbox.
            full_img_pil: the full source image.

        Returns:
            (text_embeddings, bbox_embeddings, image_embedding). The OpenCLIP
            branch L2-normalizes each output; the ImageBind branch returns the
            model's embeddings as-is (no extra normalization here).
        """
        pack = self._get_vlm(vlm_model)

        # inference_mode disables autograd bookkeeping for the whole encode pass.
        with torch.inference_mode():
            if pack["kind"] == "imagebind":
                m = pack["model"]
                # ImageBind takes a modality->tensor dict and returns one per modality.
                input_pack = {
                    ModalityType.TEXT: data.load_and_transform_text(prompt_use, self.device),
                    ModalityType.VISION: data.read_and_transform_vision_data(seg_list, self.device),
                }
                emb = m(input_pack)
                text_embeddings = emb[ModalityType.TEXT]
                bbox_embeddings = emb[ModalityType.VISION]

                # Second forward pass for the global image; squeeze(0) drops the batch dim.
                input_pack2 = {ModalityType.VISION: data.read_and_transform_vision_data([full_img_pil], self.device)}
                emb2 = m(input_pack2)
                image_embedding = emb2[ModalityType.VISION].squeeze(0)
                return text_embeddings, bbox_embeddings, image_embedding

            # openclip branch
            m = pack["model"]
            preprocess = pack["preprocess"]
            tokenizer = pack["tokenizer"]

            # text: tokenize, encode, L2-normalize per prompt
            text = tokenizer(prompt_use).to(self.device)
            text_embeddings = m.encode_text(text).float()
            text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)

            # bbox crops: preprocess each crop, batch them, encode, L2-normalize
            crop_tensors = [preprocess(im) for im in seg_list]
            crop_batch = torch.stack(crop_tensors, dim=0).to(self.device)
            bbox_embeddings = m.encode_image(crop_batch).float()
            bbox_embeddings = bbox_embeddings / bbox_embeddings.norm(dim=-1, keepdim=True)

            # global image: single-item batch, then squeeze back to a vector
            img_tensor = preprocess(full_img_pil).unsqueeze(0).to(self.device)
            image_embedding = m.encode_image(img_tensor).float().squeeze(0)
            image_embedding = image_embedding / image_embedding.norm(dim=-1, keepdim=True)

            return text_embeddings, bbox_embeddings, image_embedding
249
 
250
  def list_task_ids(self) -> List[int]:
251
+ ids: List[int] = []
252
  for k in self.id2task_name.keys():
253
  try:
254
  ids.append(int(k))
 
256
  pass
257
  return sorted(ids)
258
 
259
+ @staticmethod
260
+ def _unwrap_state_dict(obj: Any) -> Dict[str, torch.Tensor]:
261
+ # supports {"state_dict": ...} style checkpoints
262
+ if isinstance(obj, dict) and "state_dict" in obj and isinstance(obj["state_dict"], dict):
263
+ return obj["state_dict"]
264
+ if isinstance(obj, dict):
265
+ return obj
266
+ raise TypeError(f"Unsupported checkpoint format: {type(obj)}")
267
+
268
+ def _infer_ckpt_flags(self, state: Dict[str, torch.Tensor]) -> Tuple[bool, bool, int]:
269
+ """
270
+ Infer:
271
+ - is_hdc: whether checkpoint contains HDC submodule keys
272
+ - has_cross_attention: whether checkpoint contains cross-attn keys
273
+ - ckpt_d_model: best-effort inferred d_model from weights
274
+ """
275
+ keys = list(state.keys())
276
+ is_hdc = any(k.startswith("ScoreFunction.HDReason.") for k in keys)
277
+ # NOTE: adjust this if your TaskCLIP names cross-attn differently
278
+ has_cross = any("cross_attn" in k or "cross_attn_text" in k for k in keys)
279
+
280
+ if "decoder_norm.weight" in state:
281
+ ckpt_d_model = int(state["decoder_norm.weight"].shape[0])
282
+ elif "ScoreFunction.norm.weight" in state:
283
+ ckpt_d_model = int(state["ScoreFunction.norm.weight"].shape[0])
284
+ else:
285
+ ckpt_d_model = -1
286
+
287
+ return is_hdc, has_cross, ckpt_d_model
288
+
289
+ def _get_taskclip(
290
+ self,
291
+ ckpt_path: str,
292
+ d_model: int,
293
+ n_words: int,
294
+ score_function: str,
295
+ hdv_dim: int,
296
+ cross_attention: bool,
297
+ ) -> TaskCLIP:
298
  ckpt_abs = str((self.root / ckpt_path).resolve()) if str(ckpt_path).startswith(".") else str(ckpt_path)
299
  if not Path(ckpt_abs).exists():
300
  raise FileNotFoundError(f"TaskCLIP checkpoint not found: {ckpt_abs}")
301
 
302
  eff_hdv_dim = int(hdv_dim) if score_function == "HDC" else 0
303
+
304
+ # cache key must include cross_attention + score_function + dimensions
305
+ key = (ckpt_abs, int(d_model), int(n_words), str(score_function), int(eff_hdv_dim), bool(cross_attention))
306
  if key in self._taskclip_cache:
307
  return self._taskclip_cache[key]
308
 
309
+ state_raw = torch.load(ckpt_abs, map_location="cpu")
310
+ state = self._unwrap_state_dict(state_raw)
311
+
312
+ ckpt_is_hdc, ckpt_has_cross, ckpt_d_model = self._infer_ckpt_flags(state)
313
+
314
+ # Validate score_function against checkpoint
315
+ if score_function == "HDC" and not ckpt_is_hdc:
316
+ raise RuntimeError(f"Checkpoint is NOT HDC but score_function=HDC was selected. ckpt={ckpt_abs}")
317
+ if score_function != "HDC" and ckpt_is_hdc:
318
+ raise RuntimeError(f"Checkpoint IS HDC but score_function=default was selected. ckpt={ckpt_abs}")
319
+
320
+ # Validate cross_attention against checkpoint (if we can infer it)
321
+ # If your checkpoints don't contain cross-attn keys, ckpt_has_cross may be False even when the arch uses cross-attn.
322
+ # In that case, either update inference or remove this validation.
323
+ if bool(cross_attention) != bool(ckpt_has_cross):
324
+ raise RuntimeError(
325
+ f"cross_attention mismatch: runtime={cross_attention} but checkpoint has_cross_attention={ckpt_has_cross}. ckpt={ckpt_abs}"
326
+ )
327
+
328
+ # Validate d_model against checkpoint (if inferred)
329
+ if ckpt_d_model != -1 and int(d_model) != int(ckpt_d_model):
330
+ raise RuntimeError(
331
+ f"d_model mismatch: VLM produced d_model={int(d_model)} but checkpoint expects d_model={int(ckpt_d_model)}. ckpt={ckpt_abs}"
332
+ )
333
+
334
  model_config = {
335
  "num_layers": 8,
336
  "norm": None,
 
350
  "norm_after": False,
351
  "MIN_VAL": 10.0,
352
  "MAX_VAL": 30.0,
353
+ "cross_attention": bool(cross_attention),
354
  "score_function": "HDC" if score_function == "HDC" else "default",
355
  "HDV_D": int(eff_hdv_dim),
356
  }
357
 
358
  m = TaskCLIP(model_config, normalize_before=model_config["normalize_before"], device=model_config["device"])
 
359
  m.load_state_dict(state, strict=True)
360
  m = m.to(self.device).eval()
361
 
362
  self._taskclip_cache[key] = m
363
  return m
364
 
365
+ def _find_same_class(self, predict_res, score, visited, i, classes, confs, forward_thre):
366
+ cls_i = classes[i]
367
+ for j in range(len(score)):
368
+ if visited[j] == 1:
369
+ continue
370
+ if classes[j] == cls_i and float(score[j]) > forward_thre:
371
+ visited[j] = 1
372
+ predict_res[j]["category_id"] = 1
373
+ predict_res[j]["score"] = float(score[j])
374
+
375
  def _sam_masks_from_bboxes(self, image_path: str, bbox_list: List[List[float]], img_h: int, img_w: int) -> np.ndarray:
376
  if not bbox_list:
377
  return np.zeros((0, img_h, img_w), dtype=bool)
378
 
379
  bboxes = [[float(x0), float(y0), float(x1), float(y1)] for x0, y0, x1, y1 in bbox_list]
380
 
381
+ try:
382
+ res = self.sam(image_path, bboxes=bboxes)
383
+ r0 = res[0]
384
+ if r0.masks is None:
385
+ return np.zeros((0, img_h, img_w), dtype=bool)
386
+ masks = r0.masks.data.detach().cpu().numpy().astype(bool)
387
+ return masks
388
+ except Exception:
389
+ # fallback per-box
390
+ masks_list = []
391
+ for bb in bboxes:
392
+ rr = self.sam(image_path, bboxes=bb)[0]
393
+ if rr.masks is None:
394
+ continue
395
+ m = rr.masks.data.detach().cpu().numpy().astype(bool)
396
+ masks_list.append(m[0])
397
+ if len(masks_list) == 0:
398
+ return np.zeros((0, img_h, img_w), dtype=bool)
399
+ return np.stack(masks_list, axis=0)
400
 
401
  def run(
402
  self,
 
410
  taskclip_ckpt: str = "./test_model/default/decoder.pt",
411
  viz_mode: str = "bbox",
412
  ) -> Dict[str, Any]:
413
+ if vlm_model not in ["imagebind", "vit-b", "vit-l"]:
414
+ raise ValueError(f"Unknown vlm_model: {vlm_model}")
 
415
  if od_model != "yolo":
416
+ raise ValueError("Currently only od_model='yolo' is supported.")
417
+ if viz_mode not in ["bbox", "mask"]:
418
+ raise ValueError(f"Unknown viz_mode={viz_mode}")
419
+
420
+ # Training convention you stated:
421
+ # - default => cross_attention True
422
+ # - HDC => cross_attention False
423
+ # If your actual training differs, change this rule OR pass it from app.py.
424
+ cross_attention = (score_function != "HDC")
425
 
426
  with self._lock:
427
  img = Image.open(image_path).convert("RGB")
428
+
429
  task_name = self.id2task_name[str(task_id)]
430
  prompt_words = self.task2prompt[task_name]
431
  prompt_use = ["The item is " + w for w in prompt_words]
432
 
433
+ # YOLO detect
434
  yolo = self._get_yolo(yolo_ckpt)
435
  outputs = yolo(image_path)
436
  bbox_list = outputs[0].boxes.xyxy.tolist()
437
  classes = outputs[0].boxes.cls.tolist()
438
  confidences = outputs[0].boxes.conf.tolist()
439
 
440
+ H, W = img.size[1], img.size[0]
441
  all_boxes = np.asarray(bbox_list, dtype=np.float32)
 
 
442
 
443
+ # visualize all detections
 
444
  if viz_mode == "bbox":
445
  img_yolo = _draw_boxes_pil(img, all_boxes, color=(0, 255, 0), width=3)
446
+ all_masks = None
447
+ else:
448
  all_masks = self._sam_masks_from_bboxes(image_path, bbox_list, img_h=H, img_w=W)
449
  img_yolo = overlay_masks(img, all_masks, alpha=0.35, color=(0, 255, 0))
 
 
450
 
451
+ # crop bboxes
452
  seg_list, _ = _crop_pil(img, bbox_list)
453
  if len(seg_list) == 0:
454
  return {
 
459
  "images": {"original": img, "yolo": img_yolo, "selected": img.copy()},
460
  }
461
 
462
+ # VLM embeddings
463
+ text_embeddings, bbox_embeddings, image_embedding = self._encode_vlm(
464
+ vlm_model=vlm_model,
465
+ prompt_use=prompt_use,
466
+ seg_list=seg_list,
467
+ full_img_pil=img,
468
+ )
 
 
469
 
470
+ # Ensure dims are consistent
471
+ if int(bbox_embeddings.shape[-1]) != int(image_embedding.shape[-1]):
472
+ raise RuntimeError(
473
+ f"Embedding dim mismatch: bbox_embeddings dim={bbox_embeddings.shape[-1]} "
474
+ f"vs image_embedding dim={image_embedding.shape[-1]}"
475
+ )
476
+
477
+ # IMPORTANT: d_model should come from bbox_embeddings (tgt), not global image
478
+ d_model = int(bbox_embeddings.shape[-1])
479
+ n_words = int(text_embeddings.shape[0])
480
 
481
+ # TaskCLIP (load correct arch)
482
  taskclip = self._get_taskclip(
483
  ckpt_path=taskclip_ckpt,
484
+ d_model=d_model,
485
+ n_words=n_words,
486
  score_function=score_function,
487
  hdv_dim=hdv_dim,
488
+ cross_attention=cross_attention,
489
  )
490
 
491
+ # Score
492
+ with torch.inference_mode():
493
+ tgt = bbox_embeddings
494
+ memory = text_embeddings
495
+ image_embedding_2d = image_embedding.view(1, -1)
496
+ _, _, score_res, _ = taskclip(tgt, memory, image_embedding_2d)
497
  score = score_res.view(-1).detach().cpu().numpy().tolist()
498
 
499
+ # post-process
500
+ predict_res = []
501
+ for i in range(len(bbox_list)):
502
+ predict_res.append({"category_id": -1, "score": -1, "class": int(classes[i])})
503
+
504
+ visited = [0] * len(score)
505
+ for i, x in enumerate(score):
506
+ if visited[i] == 1:
507
+ continue
508
+ if float(x) > self.threshold:
509
+ visited[i] = 1
510
+ predict_res[i]["category_id"] = 1
511
+ predict_res[i]["score"] = float(x)
512
+ if self.forward:
513
+ self._find_same_class(predict_res, score, visited, i, classes, confidences, self.forward_thre)
514
+ else:
515
+ predict_res[i]["category_id"] = 0
516
+ predict_res[i]["score"] = 1.0 - float(x)
517
+
518
+ # cluster optimization
519
+ if self.cluster and self.forward and len(seg_list) > 1:
520
+ cluster_scores: Dict[int, List[float]] = {}
521
+ for p in predict_res:
522
+ if int(p["category_id"]) == 1:
523
+ c = int(p["class"])
524
+ cluster_scores.setdefault(c, []).append(float(p["score"]))
525
+
526
+ if len(cluster_scores) > 1:
527
+ cluster_ave = {c: float(np.mean(v)) for c, v in cluster_scores.items()}
528
+ select_class = max(cluster_ave, key=lambda k: cluster_ave[k])
529
+ for p in predict_res:
530
+ if int(p["category_id"]) == 1 and int(p["class"]) != int(select_class):
531
+ p["category_id"] = 0
532
+
533
+ selected_indices = [i for i, p in enumerate(predict_res) if int(p["category_id"]) == 1]
534
+ selected_boxes = (
535
+ all_boxes[selected_indices] if len(selected_indices) > 0 else np.zeros((0, 4), dtype=np.float32)
536
+ )
537
+
538
+ # visualize selected
539
+ if viz_mode == "bbox":
540
+ img_selected = _draw_boxes_pil(img, selected_boxes, color=(255, 0, 0), width=4)
541
+ else:
542
+ if all_masks is not None and all_masks.shape[0] > 0 and len(selected_indices) > 0:
543
+ sel_masks = all_masks[selected_indices]
544
+ else:
545
+ sel_masks = np.zeros((0, H, W), dtype=bool)
546
+ img_selected = overlay_masks(img, sel_masks, alpha=0.45, color=(255, 0, 0))
547
+
548
+ return {
549
+ "task_id": task_id,
550
+ "task_name": task_name,
551
+ "bbox_list": bbox_list,
552
+ "classes": classes,
553
+ "confidences": confidences,
554
+ "scores": score,
555
+ "selected_indices": selected_indices,
556
+ "images": {"original": img, "yolo": img_yolo, "selected": img_selected},
557
+ }