Spaces:
Sleeping
Sleeping
| # runner.py | |
| import json | |
| import sys | |
| from pathlib import Path | |
| from typing import Dict, Any, List, Tuple, Optional | |
| from threading import RLock | |
| import numpy as np | |
| import torch | |
| from PIL import Image, ImageDraw | |
| from ultralytics import YOLO, SAM | |
| import open_clip | |
| # --- ImageBind import: robust for both "pip install -e ImageBind" and local folder --- | |
| try: | |
| # preferred: ImageBind installed as "imagebind" | |
| from imagebind import data | |
| from imagebind.models import imagebind_model | |
| from imagebind.models.imagebind_model import ModalityType | |
| except ModuleNotFoundError: | |
| # fallback: repo has ./ImageBind/imagebind/ | |
| REPO_ROOT = Path(__file__).resolve().parents[1] # repo/ | |
| sys.path.insert(0, str(REPO_ROOT / "ImageBind")) | |
| from imagebind import data | |
| from imagebind.models import imagebind_model | |
| from imagebind.models.imagebind_model import ModalityType | |
| from models.TaskCLIP import TaskCLIP | |
def _draw_boxes_pil(
    img: Image.Image,
    boxes_xyxy: np.ndarray,
    color: Tuple[int, int, int],
    width: int = 3,
) -> Image.Image:
    """Outline every box on a copy of ``img`` and return that copy.

    Args:
        img: Source image; it is copied, never drawn on directly.
        boxes_xyxy: Array-like of ``[x0, y0, x1, y1]`` rows. ``None`` or an
            empty array yields an unmodified copy.
        color: Outline colour passed to ``ImageDraw.rectangle``.
        width: Outline thickness passed to ``ImageDraw.rectangle``.

    Returns:
        A new image with one rectangle per row of ``boxes_xyxy``.
    """
    canvas = img.copy()
    pen = ImageDraw.Draw(canvas)
    has_boxes = boxes_xyxy is not None and len(boxes_xyxy) > 0
    if has_boxes:
        for left, top, right, bottom in boxes_xyxy.tolist():
            pen.rectangle([left, top, right, bottom], outline=color, width=width)
    return canvas
def _crop_pil(img: Image.Image, bbox_list: List[List[float]]) -> Tuple[List[Image.Image], List[int]]:
    """Crop each usable box out of ``img``.

    Coordinates are truncated to ints and clamped to the image bounds. A box
    that has no area after clamping is skipped.

    Returns:
        ``(crops, idxs)`` where ``idxs[k]`` is the position in ``bbox_list``
        that produced ``crops[k]``.
    """
    img_w, img_h = img.size

    def _clamp(value: float, upper: int) -> int:
        # Truncate first, then pin into [0, upper].
        return max(0, min(upper, int(value)))

    kept_crops: List[Image.Image] = []
    kept_idxs: List[int] = []
    for position, box in enumerate(bbox_list):
        raw_x0, raw_y0, raw_x1, raw_y1 = box
        left = _clamp(raw_x0, img_w)
        top = _clamp(raw_y0, img_h)
        right = _clamp(raw_x1, img_w)
        bottom = _clamp(raw_y1, img_h)
        if right > left and bottom > top:
            kept_crops.append(img.crop((left, top, right, bottom)))
            kept_idxs.append(position)
    return kept_crops, kept_idxs
def overlay_masks(
    img: Image.Image,
    masks: np.ndarray,
    alpha: float = 0.40,
    color: Tuple[int, int, int] = (255, 0, 0),
) -> Image.Image:
    """Tint the union of ``masks`` on a copy of ``img``.

    Args:
        img: Source image. It is never modified and never returned as-is.
        masks: Stack of masks indexed along axis 0; any truthy value counts
            as "inside". ``None`` or an empty stack means nothing to draw.
            The spatial shape must match the image — assumed ``(N, H, W)``.
        alpha: Blend weight of the tinted layer over the original pixels.
        color: RGB tint colour.

    Returns:
        A new image. When there is nothing to draw it is a plain copy of
        ``img``; otherwise it is an RGB image with the masked pixels tinted.
    """
    # Always hand back a fresh image (like _draw_boxes_pil does) so callers
    # can modify the result without touching the original.
    if masks is None or len(masks) == 0:
        return img.copy()

    union = np.any(np.asarray(masks).astype(bool), axis=0)
    if not np.any(union):
        return img.copy()

    # The tint is an RGB triple, so work on an RGB view of the image.
    rgb = img if img.mode == "RGB" else img.convert("RGB")
    pixels = np.array(rgb)  # uint8 copy of the image data

    # Blend only the masked pixels. Pushing the whole image through float32
    # and truncating back can shift untouched pixels by one level.
    region = pixels[union].astype(np.float32)
    tint = np.array(color, dtype=np.float32)
    tinted = region * 0.2 + tint * 0.8
    blended = region * (1 - alpha) + tinted * alpha
    pixels[union] = np.clip(np.rint(blended), 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
class ModelRunner:
    """
    WebUI runner:
    - YOLO detects bboxes
    - VLM (ImageBind or OpenCLIP) embeds text prompts and crops (+ global image)
    - TaskCLIP scores and selects bboxes
    - optionally visualize bbox or SAM masks

    Loaded models are cached on the instance; ``run`` holds a re-entrant lock
    for its whole duration so concurrent requests are processed one at a time.
    """

    # OpenCLIP (architecture, pretrained tag) for each supported vlm_model name.
    _OPENCLIP_VARIANTS = {
        "vit-b": ("ViT-B-32", "laion2b_s34b_b79k"),
        "vit-l": ("ViT-L-14", "laion2b_s32b_b82k"),
    }

    def __init__(
        self,
        project_root: str,
        device: str = "cuda:0",
        yolo_ckpt: str = "./.checkpoints/yolo12x.pt",
        sam_ckpt: str = "./.checkpoints/sam2.1_l.pt",
        imagebind_ckpt: Optional[str] = None,  # optional local weights path
        id2task_name_file: str = "./id2task_name.json",
        task2prompt_file: str = "./task20.json",
        threshold: float = 0.01,
        forward: bool = True,
        cluster: bool = True,
        forward_thre: float = 0.1,
    ):
        """Read task metadata, load SAM and set up the model caches.

        Args:
            project_root: Directory against which paths starting with ``.``
                are resolved.
            device: Torch device string used for the VLM and TaskCLIP models.
            yolo_ckpt: Default YOLO checkpoint (only recorded here; ``run``
                receives its own checkpoint argument).
            sam_ckpt: SAM checkpoint, loaded immediately.
            imagebind_ckpt: Optional local ImageBind weights.
            id2task_name_file: JSON file mapping task id -> task name.
            task2prompt_file: JSON file mapping task name -> prompt words.
            threshold: Score above which a detection is selected.
            forward: Whether a selected detection also pulls in same-class
                detections scoring above ``forward_thre``.
            cluster: Whether to keep only the best-scoring class when several
                classes were selected (requires ``forward``).
            forward_thre: Score threshold used by the forward step.
        """
        self.root = Path(project_root).resolve()
        self.device = device
        self.threshold = float(threshold)
        self.forward = bool(forward)
        self.cluster = bool(cluster)
        self.forward_thre = float(forward_thre)

        # load task metadata
        self.id2task_name_path = (self.root / id2task_name_file).resolve()
        self.task2prompt_path = (self.root / task2prompt_file).resolve()
        self.id2task_name = json.loads(self.id2task_name_path.read_text())
        self.task2prompt = json.loads(self.task2prompt_path.read_text())

        # caches
        self._vlm_cache: Dict[str, Dict[str, Any]] = {}
        self._yolo_cache: Dict[str, "YOLO"] = {}
        self._taskclip_cache: Dict[Tuple[Any, ...], "TaskCLIP"] = {}

        # default ckpt paths (not required; YOLO is cached per-run ckpt)
        self.yolo_ckpt_path = Path(self._resolve_ckpt(yolo_ckpt))

        # SAM loaded once
        sam_ckpt_path = Path(self._resolve_ckpt(sam_ckpt))
        self.sam = SAM(str(sam_ckpt_path))

        # ImageBind weights path (optional)
        self.imagebind_ckpt = imagebind_ckpt

        # lock for single-GPU servers
        self._lock = RLock()

    def _resolve_ckpt(self, path: Any) -> str:
        """Return ``path`` as a string, anchoring dot-relative paths at the root.

        Only paths whose text starts with ``.`` are resolved against
        ``self.root``; anything else is passed through untouched.
        """
        text = str(path)
        if text.startswith("."):
            return str((self.root / path).resolve())
        return text

    def _get_yolo(self, ckpt_path: str) -> "YOLO":
        """Return the YOLO model for ``ckpt_path``, loading it on first use."""
        ckpt_abs = self._resolve_ckpt(ckpt_path)
        if ckpt_abs not in self._yolo_cache:
            self._yolo_cache[ckpt_abs] = YOLO(ckpt_abs)
        return self._yolo_cache[ckpt_abs]

    def _load_imagebind(self) -> Any:
        """
        Load ImageBind once and cache it.
        - If self.imagebind_ckpt provided and exists: load pretrained=False then load_state_dict
        - Else: pretrained=True (may download)

        A configured checkpoint path that does not exist is not an error; the
        pretrained weights are used instead.
        """
        cached = self._vlm_cache.get("imagebind")
        if cached is not None:
            return cached["model"]

        ckpt_path = Path(self._resolve_ckpt(self.imagebind_ckpt)) if self.imagebind_ckpt else None
        if ckpt_path is not None and ckpt_path.exists():
            model = imagebind_model.imagebind_huge(pretrained=False).to(self.device).eval()
            # NOTE(security): torch.load unpickles the file; only point this
            # at checkpoints from a trusted source.
            state = torch.load(str(ckpt_path), map_location="cpu")
            # common wrappers, peeled in this order
            for wrapper in ("model", "state_dict"):
                if isinstance(state, dict) and isinstance(state.get(wrapper), dict):
                    state = state[wrapper]
            model.load_state_dict(state, strict=False)
        else:
            model = imagebind_model.imagebind_huge(pretrained=True).to(self.device).eval()

        self._vlm_cache["imagebind"] = {"kind": "imagebind", "model": model}
        return model

    def _get_vlm(self, vlm_model: str) -> Dict[str, Any]:
        """Return the model pack for ``vlm_model``, loading it on first use.

        Returns:
            ``{"kind": "imagebind", "model": ...}`` for ImageBind, or
            ``{"kind": "openclip", "model", "preprocess", "tokenizer"}`` for
            the OpenCLIP variants.

        Raises:
            ValueError: ``vlm_model`` is not a supported name.
        """
        if vlm_model == "imagebind":
            return {"kind": "imagebind", "model": self._load_imagebind()}

        cached = self._vlm_cache.get(vlm_model)
        if cached is not None:
            return cached

        if vlm_model not in self._OPENCLIP_VARIANTS:
            raise ValueError(f"Unknown vlm_model: {vlm_model}")
        arch, pretrained = self._OPENCLIP_VARIANTS[vlm_model]

        model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained=pretrained)
        model = model.to(self.device).eval()
        tokenizer = open_clip.get_tokenizer(arch)
        pack = {"kind": "openclip", "model": model, "preprocess": preprocess, "tokenizer": tokenizer}
        self._vlm_cache[vlm_model] = pack
        return pack

    def _encode_vlm(
        self,
        vlm_model: str,
        prompt_use: List[str],
        seg_list: List["Image.Image"],
        full_img_pil: "Image.Image",
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Embed the prompts, the crops and the full image with the chosen VLM.

        Args:
            vlm_model: Name accepted by ``_get_vlm``.
            prompt_use: Text prompts, one embedding each.
            seg_list: Non-empty list of cropped detections.
            full_img_pil: The whole input image.

        Returns:
            ``(text_embeddings, bbox_embeddings, image_embedding)``. The image
            embedding has its leading batch axis squeezed away. The OpenCLIP
            branch L2-normalises all three; the ImageBind branch returns the
            model outputs unchanged.
        """
        pack = self._get_vlm(vlm_model)
        with torch.inference_mode():
            if pack["kind"] == "imagebind":
                model = pack["model"]
                crop_inputs = {
                    ModalityType.TEXT: data.load_and_transform_text(prompt_use, self.device),
                    ModalityType.VISION: data.read_and_transform_vision_data(seg_list, self.device),
                }
                crop_out = model(crop_inputs)
                text_embeddings = crop_out[ModalityType.TEXT]
                bbox_embeddings = crop_out[ModalityType.VISION]

                full_inputs = {
                    ModalityType.VISION: data.read_and_transform_vision_data([full_img_pil], self.device)
                }
                full_out = model(full_inputs)
                image_embedding = full_out[ModalityType.VISION].squeeze(0)
                return text_embeddings, bbox_embeddings, image_embedding

            # openclip branch
            model = pack["model"]
            preprocess = pack["preprocess"]
            tokenizer = pack["tokenizer"]

            # text
            tokens = tokenizer(prompt_use).to(self.device)
            text_embeddings = model.encode_text(tokens).float()
            text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)

            # bbox crops
            crop_batch = torch.stack([preprocess(im) for im in seg_list], dim=0).to(self.device)
            bbox_embeddings = model.encode_image(crop_batch).float()
            bbox_embeddings = bbox_embeddings / bbox_embeddings.norm(dim=-1, keepdim=True)

            # global image
            img_tensor = preprocess(full_img_pil).unsqueeze(0).to(self.device)
            image_embedding = model.encode_image(img_tensor).float().squeeze(0)
            image_embedding = image_embedding / image_embedding.norm(dim=-1, keepdim=True)
            return text_embeddings, bbox_embeddings, image_embedding

    def list_task_ids(self) -> List[int]:
        """Return the integer task ids found in ``id2task_name``, sorted.

        Keys that cannot be parsed as integers are ignored.
        """
        ids: List[int] = []
        for key in self.id2task_name:
            try:
                ids.append(int(key))
            except (TypeError, ValueError):
                # Non-numeric key: not a task id, skip it.
                continue
        return sorted(ids)

    @staticmethod
    def _unwrap_state_dict(obj: Any) -> Dict[str, torch.Tensor]:
        """Return the state dict held by a loaded checkpoint object.

        Supports ``{"state_dict": ...}`` style checkpoints as well as a bare
        state dict. Declared static because it takes no ``self`` yet is called
        through the instance.

        Raises:
            TypeError: ``obj`` is not a dict.
        """
        if not isinstance(obj, dict):
            raise TypeError(f"Unsupported checkpoint format: {type(obj)}")
        inner = obj.get("state_dict")
        return inner if isinstance(inner, dict) else obj

    def _infer_ckpt_flags(self, state: Dict[str, torch.Tensor]) -> Tuple[bool, bool, int]:
        """
        Infer:
        - is_hdc: whether checkpoint contains HDC submodule keys
        - has_cross_attention: whether checkpoint contains cross-attn keys
        - ckpt_d_model: best-effort inferred d_model from weights (-1 if unknown)
        """
        keys = list(state.keys())
        is_hdc = any(k.startswith("ScoreFunction.HDReason.") for k in keys)
        # NOTE: adjust this if your TaskCLIP names cross-attn differently
        # ("cross_attn_text" keys are matched too, being a superstring).
        has_cross = any("cross_attn" in k for k in keys)

        ckpt_d_model = -1
        for norm_key in ("decoder_norm.weight", "ScoreFunction.norm.weight"):
            if norm_key in state:
                ckpt_d_model = int(state[norm_key].shape[0])
                break
        return is_hdc, has_cross, ckpt_d_model

    def _get_taskclip(
        self,
        ckpt_path: str,
        d_model: int,
        n_words: int,
        score_function: str,
        hdv_dim: int,
        cross_attention: bool,
    ) -> "TaskCLIP":
        """Build (or fetch from cache) a TaskCLIP model matching the checkpoint.

        Args:
            ckpt_path: Checkpoint file; dot-relative paths are resolved
                against the project root.
            d_model: Embedding width produced by the VLM.
            n_words: Number of text prompts.
            score_function: ``"HDC"`` selects the HDC score function; any
                other value selects ``"default"``.
            hdv_dim: Hypervector dimension; ignored unless ``score_function``
                is ``"HDC"``.
            cross_attention: Whether the model is built with cross-attention.

        Raises:
            FileNotFoundError: The checkpoint file does not exist.
            RuntimeError: ``score_function`` or ``d_model`` disagrees with
                what the checkpoint contains.
        """
        ckpt_abs = self._resolve_ckpt(ckpt_path)
        if not Path(ckpt_abs).exists():
            raise FileNotFoundError(f"TaskCLIP checkpoint not found: {ckpt_abs}")

        use_hdc = score_function == "HDC"
        eff_hdv_dim = int(hdv_dim) if use_hdc else 0

        # cache key must include cross_attention + score_function + dimensions
        key = (ckpt_abs, int(d_model), int(n_words), str(score_function), int(eff_hdv_dim), bool(cross_attention))
        if key in self._taskclip_cache:
            return self._taskclip_cache[key]

        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        state_raw = torch.load(ckpt_abs, map_location="cpu")
        state = self._unwrap_state_dict(state_raw)
        # The cross-attention flag is inferred but not enforced here; the
        # strict load_state_dict below fails on any key mismatch.
        ckpt_is_hdc, _ckpt_has_cross, ckpt_d_model = self._infer_ckpt_flags(state)

        # Validate score_function against checkpoint
        if use_hdc and not ckpt_is_hdc:
            raise RuntimeError(f"Checkpoint is NOT HDC but score_function=HDC was selected. ckpt={ckpt_abs}")
        if not use_hdc and ckpt_is_hdc:
            raise RuntimeError(f"Checkpoint IS HDC but score_function=default was selected. ckpt={ckpt_abs}")

        # Validate d_model against checkpoint (if inferred)
        if ckpt_d_model != -1 and int(d_model) != int(ckpt_d_model):
            raise RuntimeError(
                f"d_model mismatch: VLM produced d_model={int(d_model)} but checkpoint expects d_model={int(ckpt_d_model)}. ckpt={ckpt_abs}"
            )

        model_config = {
            "num_layers": 8,
            "norm": None,
            "return_intermediate": False,
            "d_model": int(d_model),
            "nhead": 4,
            "dim_feedforward": 2048,
            "dropout": 0.1,
            "N_words": int(n_words),
            "activation": "gelu",
            "normalize_before": False,
            "device": self.device,
            "ratio_text": 0.3,
            "ratio_image": 0.3,
            "ratio_glob": 0.3,
            "norm_before": True,
            "norm_after": False,
            "MIN_VAL": 10.0,
            "MAX_VAL": 30.0,
            "cross_attention": bool(cross_attention),
            "score_function": "HDC" if use_hdc else "default",
            "HDV_D": int(eff_hdv_dim),
        }
        model = TaskCLIP(
            model_config,
            normalize_before=model_config["normalize_before"],
            device=model_config["device"],
        )
        model.load_state_dict(state, strict=True)
        model = model.to(self.device).eval()
        self._taskclip_cache[key] = model
        return model

    def _find_same_class(self, predict_res, score, visited, i, classes, confs, forward_thre):
        """Select every unvisited detection that shares detection ``i``'s class.

        A detection ``j`` is marked visited and selected (``category_id`` 1,
        ``score`` set to its own score) when ``classes[j] == classes[i]`` and
        ``score[j] > forward_thre``. ``predict_res`` and ``visited`` are
        updated in place. ``confs`` is accepted but not used.
        """
        anchor_class = classes[i]
        for j, raw in enumerate(score):
            if visited[j] == 1:
                continue
            if classes[j] == anchor_class and float(raw) > forward_thre:
                visited[j] = 1
                predict_res[j]["category_id"] = 1
                predict_res[j]["score"] = float(raw)

    def _label_detections(
        self,
        score: List[float],
        classes: List[Any],
        confidences: List[float],
    ) -> List[Dict[str, Any]]:
        """Turn per-detection scores into selected / rejected labels.

        ``score``, ``classes`` and ``confidences`` must be index-aligned.

        Returns:
            One dict per entry of ``classes`` with keys ``category_id``
            (1 selected, 0 rejected, -1 never scored), ``score`` and ``class``.
        """
        predictions: List[Dict[str, Any]] = [
            {"category_id": -1, "score": -1, "class": int(c)} for c in classes
        ]
        visited = [0] * len(score)
        for i, raw in enumerate(score):
            if visited[i] == 1:
                continue
            value = float(raw)
            if value > self.threshold:
                visited[i] = 1
                predictions[i]["category_id"] = 1
                predictions[i]["score"] = value
                if self.forward:
                    self._find_same_class(
                        predictions, score, visited, i, classes, confidences, self.forward_thre
                    )
            else:
                predictions[i]["category_id"] = 0
                predictions[i]["score"] = 1.0 - value

        # cluster optimization: keep only the class with the best mean score
        if self.cluster and self.forward and len(predictions) > 1:
            cluster_scores: Dict[int, List[float]] = {}
            for pred in predictions:
                if int(pred["category_id"]) == 1:
                    cluster_scores.setdefault(int(pred["class"]), []).append(float(pred["score"]))
            if len(cluster_scores) > 1:
                cluster_ave = {c: float(np.mean(v)) for c, v in cluster_scores.items()}
                select_class = max(cluster_ave, key=lambda k: cluster_ave[k])
                for pred in predictions:
                    if int(pred["category_id"]) == 1 and int(pred["class"]) != int(select_class):
                        pred["category_id"] = 0
        return predictions

    def _sam_masks_from_bboxes(self, image_path: str, bbox_list: List[List[float]], img_h: int, img_w: int) -> np.ndarray:
        """Run SAM with the boxes as prompts and return boolean masks.

        All boxes are tried in a single call first; if that raises, each box
        is retried on its own.

        Returns:
            Boolean array with masks stacked along axis 0. It is empty, with
            shape ``(0, img_h, img_w)``, when there are no boxes or SAM
            returned no masks. In the per-box fallback, a box without a mask
            gets an all-False row so that row ``k`` still belongs to
            ``bbox_list[k]``.
        """
        empty = np.zeros((0, img_h, img_w), dtype=bool)
        if not bbox_list:
            return empty

        boxes = [[float(x0), float(y0), float(x1), float(y1)] for x0, y0, x1, y1 in bbox_list]
        try:
            first = self.sam(image_path, bboxes=boxes)[0]
            if first.masks is None:
                return empty
            return first.masks.data.detach().cpu().numpy().astype(bool)
        except Exception:
            # Deliberate best-effort: fall back to prompting one box at a time.
            per_box: List[Optional[np.ndarray]] = []
            for box in boxes:
                result = self.sam(image_path, bboxes=box)[0]
                if result.masks is None:
                    per_box.append(None)
                else:
                    per_box.append(result.masks.data.detach().cpu().numpy().astype(bool)[0])

            found = [m for m in per_box if m is not None]
            if not found:
                return empty
            # Placeholder rows keep mask indices aligned with box indices.
            blank = np.zeros_like(found[0])
            return np.stack([blank if m is None else m for m in per_box], axis=0)

    def run(
        self,
        image_path: str,
        task_id: int,
        vlm_model: str = "imagebind",
        od_model: str = "yolo",
        yolo_ckpt: str = "./.checkpoints/yolo12x.pt",
        score_function: str = "default",
        hdv_dim: int = 256,
        taskclip_ckpt: str = "./test_model/default/decoder.pt",
        viz_mode: str = "bbox",
        hw_noise_dist: str = "none",
        hw_noise_width: int = 0,
        hw_noise_strength: int = 0,
        hdc_bits: int = 32
    ) -> Dict[str, Any]:
        """Detect objects, score them for the task and render the result.

        Args:
            image_path: Path of the image to process.
            task_id: Task id; looked up as ``str(task_id)`` in ``id2task_name``.
            vlm_model: ``"imagebind"``, ``"vit-b"`` or ``"vit-l"``.
            od_model: Object detector; only ``"yolo"`` is supported.
            yolo_ckpt: YOLO checkpoint to use for this call.
            score_function: ``"HDC"`` or anything else for the default scorer.
            hdv_dim: Hypervector dimension for the HDC scorer.
            taskclip_ckpt: TaskCLIP checkpoint to use for this call.
            viz_mode: ``"bbox"`` draws rectangles, ``"mask"`` overlays SAM masks.
            hw_noise_dist, hw_noise_width, hw_noise_strength, hdc_bits:
                Forwarded to the TaskCLIP forward call.

        Returns:
            Dict with ``task_id``, ``task_name``, ``bbox_list``, ``classes``,
            ``confidences``, ``scores`` (one per usable crop),
            ``scored_indices`` (the ``bbox_list`` index behind each score),
            ``selected_indices`` (indices into ``bbox_list``) and ``images``
            (``original``, ``yolo``, ``selected``).

        Raises:
            ValueError: Unsupported ``vlm_model``, ``od_model`` or ``viz_mode``.
            KeyError: Unknown task id or a task without prompts.
            RuntimeError: Embedding sizes disagree, or the TaskCLIP checkpoint
                does not match the requested configuration.
        """
        if vlm_model not in ("imagebind", "vit-b", "vit-l"):
            raise ValueError(f"Unknown vlm_model: {vlm_model}")
        if od_model != "yolo":
            raise ValueError("Currently only od_model='yolo' is supported.")
        if viz_mode not in ("bbox", "mask"):
            raise ValueError(f"Unknown viz_mode={viz_mode}")

        # Training convention you stated:
        # - default => cross_attention True
        # - HDC => cross_attention False
        # If your actual training differs, change this rule OR pass it from app.py.
        cross_attention = score_function != "HDC"

        with self._lock:
            # Close the file handle as soon as the pixels are decoded.
            with Image.open(image_path) as source:
                img = source.convert("RGB")

            task_name = self.id2task_name[str(task_id)]
            prompt_words = self.task2prompt[task_name]
            prompt_use = ["The item is " + w for w in prompt_words]

            # YOLO detect
            yolo = self._get_yolo(yolo_ckpt)
            outputs = yolo(image_path)
            bbox_list = outputs[0].boxes.xyxy.tolist()
            classes = outputs[0].boxes.cls.tolist()
            confidences = outputs[0].boxes.conf.tolist()
            W, H = img.size
            all_boxes = np.asarray(bbox_list, dtype=np.float32)

            # visualize all detections
            if viz_mode == "bbox":
                img_yolo = _draw_boxes_pil(img, all_boxes, color=(0, 255, 0), width=3)
                all_masks = None
            else:
                all_masks = self._sam_masks_from_bboxes(image_path, bbox_list, img_h=H, img_w=W)
                img_yolo = overlay_masks(img, all_masks, alpha=0.35, color=(0, 255, 0))

            # crop bboxes; `kept` maps each crop back to its bbox_list index
            seg_list, kept = _crop_pil(img, bbox_list)
            if len(seg_list) == 0:
                return {
                    "task_id": task_id,
                    "task_name": task_name,
                    "bbox_list": bbox_list,
                    "classes": classes,
                    "confidences": confidences,
                    "scores": [],
                    "scored_indices": [],
                    "selected_indices": [],
                    "images": {"original": img, "yolo": img_yolo, "selected": img.copy()},
                }

            # Degenerate boxes yield no crop and therefore no score, so keep
            # the per-detection metadata aligned with the crops.
            kept_classes = [classes[i] for i in kept]
            kept_confidences = [confidences[i] for i in kept]

            # VLM embeddings
            text_embeddings, bbox_embeddings, image_embedding = self._encode_vlm(
                vlm_model=vlm_model,
                prompt_use=prompt_use,
                seg_list=seg_list,
                full_img_pil=img,
            )

            # Ensure dims are consistent
            if int(bbox_embeddings.shape[-1]) != int(image_embedding.shape[-1]):
                raise RuntimeError(
                    f"Embedding dim mismatch: bbox_embeddings dim={bbox_embeddings.shape[-1]} "
                    f"vs image_embedding dim={image_embedding.shape[-1]}"
                )

            # IMPORTANT: d_model should come from bbox_embeddings (tgt), not global image
            d_model = int(bbox_embeddings.shape[-1])
            n_words = int(text_embeddings.shape[0])

            # TaskCLIP (load correct arch)
            taskclip = self._get_taskclip(
                ckpt_path=taskclip_ckpt,
                d_model=d_model,
                n_words=n_words,
                score_function=score_function,
                hdv_dim=hdv_dim,
                cross_attention=cross_attention,
            )

            # Score
            with torch.inference_mode():
                tgt = bbox_embeddings
                memory = text_embeddings
                image_embedding_2d = image_embedding.view(1, -1)
                _, _, score_res, _ = taskclip(
                    tgt,
                    memory,
                    image_embedding_2d,
                    hw_noise_dist=hw_noise_dist,
                    hw_noise_width=int(hw_noise_width),
                    hw_noise_strength=int(hw_noise_strength),
                    hdc_bits=hdc_bits)
                score = score_res.view(-1).detach().cpu().numpy().tolist()

            # post-process (indices here are crop positions)
            predict_res = self._label_detections(score, kept_classes, kept_confidences)

            # translate crop positions back to bbox_list indices
            selected_indices = [
                kept[pos] for pos, pred in enumerate(predict_res) if int(pred["category_id"]) == 1
            ]
            if selected_indices:
                selected_boxes = all_boxes[selected_indices]
            else:
                selected_boxes = np.zeros((0, 4), dtype=np.float32)

            # visualize selected
            if viz_mode == "bbox":
                img_selected = _draw_boxes_pil(img, selected_boxes, color=(255, 0, 0), width=4)
            else:
                mask_rows: List[int] = []
                if all_masks is not None and all_masks.shape[0] > 0:
                    # Skip indices for which SAM produced no mask row.
                    mask_rows = [i for i in selected_indices if i < all_masks.shape[0]]
                if mask_rows:
                    sel_masks = all_masks[mask_rows]
                else:
                    sel_masks = np.zeros((0, H, W), dtype=bool)
                img_selected = overlay_masks(img, sel_masks, alpha=0.45, color=(255, 0, 0))

            return {
                "task_id": task_id,
                "task_name": task_name,
                "bbox_list": bbox_list,
                "classes": classes,
                "confidences": confidences,
                "scores": score,
                "scored_indices": kept,
                "selected_indices": selected_indices,
                "images": {"original": img, "yolo": img_yolo, "selected": img_selected},
            }