# NOTE: the three lines below were page residue from a Hugging Face Spaces
# capture ("Spaces: Sleeping" status banner), not part of the source file.
import json
import threading
from pathlib import Path
from typing import Dict, Any, List, Tuple, Optional

import numpy as np
import torch
from PIL import Image, ImageDraw
from ultralytics import YOLO, SAM

from ImageBind.imagebind import data
from ImageBind.imagebind.models import imagebind_model
from ImageBind.imagebind.models.imagebind_model import ModalityType
from models.TaskCLIP import TaskCLIP
| # ... keep your helper funcs _draw_boxes_pil/_crop_pil/overlay_masks ... | |
class ModelRunner:
    """Loads and serves the detection/grounding pipeline (YOLO -> SAM -> ImageBind -> TaskCLIP).

    SAM and ImageBind are loaded once in the constructor; YOLO and TaskCLIP
    models are cached per checkpoint by `_get_yolo` / `_get_taskclip`.
    """

    def __init__(
        self,
        project_root: str,
        device: str = "cuda:0",
        yolo_ckpt: str = "./.checkpoints/yolo12x.pt",
        sam_ckpt: str = "./.checkpoints/sam2.1_l.pt",
        imagebind_ckpt: Optional[str] = None,  # NEW
        id2task_name_file: str = "./id2task_name.json",
        task2prompt_file: str = "./task20.json",
        threshold: float = 0.01,
        forward: bool = True,
        cluster: bool = True,
        forward_thre: float = 0.1,
    ):
        """Load metadata and heavyweight models.

        Args:
            project_root: Repository root; all "./"-prefixed paths below are
                resolved against it, absolute paths are used as-is.
            device: torch device string for ImageBind (and downstream models).
            yolo_ckpt: Default YOLO checkpoint (kept for reference only).
            sam_ckpt: SAM checkpoint, loaded once here.
            imagebind_ckpt: Optional local ImageBind weights (e.g.
                imagebind_huge.pth); falls back to the pretrained download
                when missing or not given.
            id2task_name_file: JSON mapping task-id string -> task name.
            task2prompt_file: JSON mapping task name -> prompt word list.
            threshold, forward, cluster, forward_thre: Selection/postprocess
                knobs consumed by run()'s downstream logic.
        """
        self.root = Path(project_root).resolve()
        self.device = device
        self.threshold = float(threshold)
        self.forward = bool(forward)
        self.cluster = bool(cluster)
        self.forward_thre = float(forward_thre)

        def _resolve(p: str) -> Path:
            # "./"-prefixed paths are repo-relative; anything else is used verbatim.
            return (self.root / p).resolve() if str(p).startswith(".") else Path(p)

        # metadata
        self.id2task_name_path = _resolve(id2task_name_file)
        self.task2prompt_path = _resolve(task2prompt_file)
        self.id2task_name = json.loads(self.id2task_name_path.read_text())
        self.task2prompt = json.loads(self.task2prompt_path.read_text())

        # caches (filled lazily by _get_yolo / _get_taskclip)
        self._yolo_cache = {}
        self._taskclip_cache = {}

        # YOLO path (kept for reference; actual YOLO models are cached per ckpt in _get_yolo)
        self.yolo_ckpt_path = _resolve(yolo_ckpt)

        # ---- SAM load ONCE (from absolute or repo-relative path) ----
        self.sam = SAM(str(_resolve(sam_ckpt)))

        # ---- ImageBind load ONCE ----
        # Decide the weight source FIRST so the (very large) model is
        # constructed exactly once. The previous version always built a
        # pretrained=False model and then discarded it whenever the fallback
        # (missing file / no ckpt given) constructed a pretrained=True one.
        ib_ckpt: Optional[Path] = None
        if imagebind_ckpt:
            cand = _resolve(imagebind_ckpt)
            if cand.exists():
                ib_ckpt = cand
        if ib_ckpt is not None:
            self.vlm_model = imagebind_model.imagebind_huge(pretrained=False).to(self.device).eval()
            # NOTE(security): torch.load unpickles arbitrary objects -- only
            # load trusted checkpoint files here.
            state = torch.load(str(ib_ckpt), map_location="cpu")
            # robust handling of different checkpoint formats
            # ({"model": ...} / {"state_dict": ...} / raw state dict)
            if isinstance(state, dict):
                if isinstance(state.get("model"), dict):
                    state = state["model"]
                elif isinstance(state.get("state_dict"), dict):
                    state = state["state_dict"]
            self.vlm_model.load_state_dict(state, strict=False)
        else:
            # fallback if file missing or no checkpoint was provided
            self.vlm_model = imagebind_model.imagebind_huge(pretrained=True).to(self.device).eval()

        # Serializes run() within this process. threading.RLock is the right
        # primitive for that; the previous torch.multiprocessing.RLock was a
        # heavier OS semaphore with no cross-process sharing in this design.
        self._lock = threading.RLock()
| def _get_yolo(self, ckpt_path: str): | |
| ckpt_abs = str((self.root / ckpt_path).resolve()) if str(ckpt_path).startswith(".") else str(ckpt_path) | |
| if ckpt_abs not in self._yolo_cache: | |
| self._yolo_cache[ckpt_abs] = YOLO(ckpt_abs) | |
| return self._yolo_cache[ckpt_abs] | |
| def list_task_ids(self) -> List[int]: | |
| ids = [] | |
| for k in self.id2task_name.keys(): | |
| try: | |
| ids.append(int(k)) | |
| except Exception: | |
| pass | |
| return sorted(ids) | |
| def _get_taskclip(self, ckpt_path: str, d_model: int, n_words: int, score_function: str, hdv_dim: int): | |
| ckpt_abs = str((self.root / ckpt_path).resolve()) if str(ckpt_path).startswith(".") else str(ckpt_path) | |
| if not Path(ckpt_abs).exists(): | |
| raise FileNotFoundError(f"TaskCLIP checkpoint not found: {ckpt_abs}") | |
| eff_hdv_dim = int(hdv_dim) if score_function == "HDC" else 0 | |
| key = (ckpt_abs, int(d_model), int(n_words), str(score_function), eff_hdv_dim) | |
| if key in self._taskclip_cache: | |
| return self._taskclip_cache[key] | |
| model_config = { | |
| "num_layers": 8, | |
| "norm": None, | |
| "return_intermediate": False, | |
| "d_model": int(d_model), | |
| "nhead": 4, | |
| "dim_feedforward": 2048, | |
| "dropout": 0.1, | |
| "N_words": int(n_words), | |
| "activation": "gelu", | |
| "normalize_before": False, | |
| "device": self.device, | |
| "ratio_text": 0.3, | |
| "ratio_image": 0.3, | |
| "ratio_glob": 0.3, | |
| "norm_before": True, | |
| "norm_after": False, | |
| "MIN_VAL": 10.0, | |
| "MAX_VAL": 30.0, | |
| "cross_attention": True, # keep consistent with how your checkpoint was trained | |
| "score_function": "HDC" if score_function == "HDC" else "default", | |
| "HDV_D": int(eff_hdv_dim), | |
| } | |
| m = TaskCLIP(model_config, normalize_before=model_config["normalize_before"], device=model_config["device"]) | |
| state = torch.load(ckpt_abs, map_location="cpu") | |
| m.load_state_dict(state, strict=True) | |
| m = m.to(self.device).eval() | |
| self._taskclip_cache[key] = m | |
| return m | |
| def _sam_masks_from_bboxes(self, image_path: str, bbox_list: List[List[float]], img_h: int, img_w: int) -> np.ndarray: | |
| if not bbox_list: | |
| return np.zeros((0, img_h, img_w), dtype=bool) | |
| bboxes = [[float(x0), float(y0), float(x1), float(y1)] for x0, y0, x1, y1 in bbox_list] | |
| # multi-box call | |
| res = self.sam(image_path, bboxes=bboxes) | |
| r0 = res[0] | |
| if r0.masks is None: | |
| return np.zeros((0, img_h, img_w), dtype=bool) | |
| return r0.masks.data.detach().cpu().numpy().astype(bool) | |
    def run(
        self,
        image_path: str,
        task_id: int,
        vlm_model: str = "imagebind",
        od_model: str = "yolo",
        yolo_ckpt: str = "./.checkpoints/yolo12x.pt",
        score_function: str = "default",
        hdv_dim: int = 256,
        taskclip_ckpt: str = "./test_model/default/decoder.pt",
        viz_mode: str = "bbox",
    ) -> Dict[str, Any]:
        """Score task-relevant objects in one image.

        Pipeline: YOLO detection -> (SAM masks, only for viz_mode="mask")
        -> ImageBind embeddings of crops / prompts / whole image -> TaskCLIP
        scoring. Selection/postprocessing continues past the visible end of
        this method.

        Args:
            image_path: Path to the input image.
            task_id: Key into id2task_name (stringified for the lookup).
            vlm_model: Must be "imagebind"; anything else raises ValueError.
            od_model: Must be "yolo"; anything else raises ValueError.
            yolo_ckpt: YOLO checkpoint, cached per path by _get_yolo.
            score_function: "HDC" selects the HDC scoring head in TaskCLIP;
                any other value selects the default head.
            hdv_dim: HDC vector dimension (used only when score_function == "HDC").
            taskclip_ckpt: TaskCLIP decoder checkpoint path.
            viz_mode: "bbox" draws boxes only; "mask" additionally runs SAM
                and overlays masks; anything else raises ValueError.

        Returns:
            Dict with task info, boxes, selected indices and visualization
            images (the no-detection early return below shows the shape).
        """
        if vlm_model != "imagebind":
            raise ValueError("This runner.py currently implements ImageBind only (your OpenCLIP version was in the other runner).")
        if od_model != "yolo":
            raise ValueError("Only od_model='yolo' supported.")
        # Serialize the whole pipeline through the instance lock.
        with self._lock:
            img = Image.open(image_path).convert("RGB")
            task_name = self.id2task_name[str(task_id)]
            prompt_words = self.task2prompt[task_name]
            prompt_use = ["The item is " + w for w in prompt_words]
            # YOLO
            yolo = self._get_yolo(yolo_ckpt)
            outputs = yolo(image_path)
            bbox_list = outputs[0].boxes.xyxy.tolist()  # [x0, y0, x1, y1] per detection
            classes = outputs[0].boxes.cls.tolist()
            confidences = outputs[0].boxes.conf.tolist()
            all_boxes = np.asarray(bbox_list, dtype=np.float32)
            # PIL's Image.size is (width, height).
            H = img.size[1]
            W = img.size[0]
            # IMPORTANT: only run SAM if viz_mode == mask
            all_masks = None
            if viz_mode == "bbox":
                img_yolo = _draw_boxes_pil(img, all_boxes, color=(0, 255, 0), width=3)
            elif viz_mode == "mask":
                all_masks = self._sam_masks_from_bboxes(image_path, bbox_list, img_h=H, img_w=W)
                img_yolo = overlay_masks(img, all_masks, alpha=0.35, color=(0, 255, 0))
            else:
                raise ValueError(f"Unknown viz_mode={viz_mode}")
            # crops -- presumably one PIL crop per box; the discarded second
            # value looks like kept indices (TODO confirm against _crop_pil).
            seg_list, _ = _crop_pil(img, bbox_list)
            if len(seg_list) == 0:
                # No detections: empty selection, untouched original image.
                return {
                    "task_id": task_id,
                    "task_name": task_name,
                    "bbox_list": bbox_list,
                    "selected_indices": [],
                    "images": {"original": img, "yolo": img_yolo, "selected": img.copy()},
                }
            # ImageBind embeddings (crops + prompts in one batch, then the
            # whole image in a second pass for the global embedding).
            with torch.no_grad():
                input_pack = {
                    ModalityType.TEXT: data.load_and_transform_text(prompt_use, self.device),
                    ModalityType.VISION: data.read_and_transform_vision_data(seg_list, self.device),
                }
                emb = self.vlm_model(input_pack)
                text_embeddings = emb[ModalityType.TEXT]
                bbox_embeddings = emb[ModalityType.VISION]
                input_pack2 = {ModalityType.VISION: data.read_and_transform_vision_data([img], self.device)}
                emb2 = self.vlm_model(input_pack2)
                image_embedding = emb2[ModalityType.VISION].squeeze(0)
            # TaskCLIP: model dims are derived from the embeddings so the
            # cached model matches this VLM's width and prompt count.
            taskclip = self._get_taskclip(
                ckpt_path=taskclip_ckpt,
                d_model=int(image_embedding.shape[-1]),
                n_words=int(text_embeddings.shape[0]),
                score_function=score_function,
                hdv_dim=hdv_dim,
            )
            with torch.no_grad():
                _, _, score_res, _ = taskclip(bbox_embeddings, text_embeddings, image_embedding.view(1, -1))
            score = score_res.view(-1).detach().cpu().numpy().tolist()  # one score per detection
            # ... keep your postprocess/selection logic unchanged ...
            # (use your existing code below this point)
            # return dict unchanged