Spaces:
Sleeping
Sleeping
HanningChen
commited on
Commit
·
6feb3b2
1
Parent(s):
eb9b67e
Download weights from HF model repo and use cached paths
Browse files- requirements.txt +3 -0
- webui/app.py +38 -36
- webui/runner.py +76 -341
- webui/weights.py +14 -0
requirements.txt
CHANGED
|
@@ -23,3 +23,6 @@ opencv-python-headless==4.10.0.84
|
|
| 23 |
# --- Models / inference ---
|
| 24 |
ultralytics==8.4.3
|
| 25 |
open-clip-torch==2.24.0
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
# --- Models / inference ---
|
| 24 |
ultralytics==8.4.3
|
| 25 |
open-clip-torch==2.24.0
|
| 26 |
+
|
| 27 |
+
huggingface_hub>=0.24.0
|
| 28 |
+
pytorchvideo==0.1.5
|
webui/app.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
|
|
| 1 |
import uuid
|
| 2 |
from pathlib import Path
|
| 3 |
-
from typing import Optional
|
| 4 |
|
| 5 |
from fastapi import FastAPI, Request, UploadFile, File, Form
|
| 6 |
from fastapi.responses import HTMLResponse, JSONResponse
|
|
@@ -8,8 +8,9 @@ from fastapi.staticfiles import StaticFiles
|
|
| 8 |
from fastapi.templating import Jinja2Templates
|
| 9 |
|
| 10 |
from webui.runner import ModelRunner
|
|
|
|
| 11 |
|
| 12 |
-
PROJECT_ROOT = Path(__file__).resolve().parents[1] #
|
| 13 |
WEBUI_DIR = Path(__file__).resolve().parent
|
| 14 |
UPLOAD_DIR = WEBUI_DIR / "uploads"
|
| 15 |
RESULT_DIR = WEBUI_DIR / "results"
|
|
@@ -22,6 +23,13 @@ templates = Jinja2Templates(directory=str(WEBUI_DIR / "templates"))
|
|
| 22 |
app.mount("/static", StaticFiles(directory=str(WEBUI_DIR / "static")), name="static")
|
| 23 |
app.mount("/results", StaticFiles(directory=str(RESULT_DIR)), name="results")
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
VLM_CHOICES = [
|
| 26 |
{"label": "imagebind", "value": "imagebind", "folder": "imagebind"},
|
| 27 |
{"label": "ViT-B", "value": "vit-b", "folder": "ViT-B"},
|
|
@@ -34,24 +42,29 @@ HDV_DIMS = [128, 256, 512, 1024]
|
|
| 34 |
|
| 35 |
DEFAULT_VLM = "imagebind"
|
| 36 |
DEFAULT_HDV = 256
|
| 37 |
-
DEFAULT_TASKCLIP_CKPT = "./test_model/default/decoder.pt"
|
| 38 |
DEFAULT_SCORE_FUNC = "default"
|
|
|
|
| 39 |
|
| 40 |
OD_CHOICES = [
|
| 41 |
-
{"label": "nano", "value": "nano", "ckpt": "
|
| 42 |
-
{"label": "small", "value": "small", "ckpt": "
|
| 43 |
-
{"label": "median", "value": "median", "ckpt": "
|
| 44 |
-
{"label": "large", "value": "large", "ckpt": "
|
| 45 |
-
{"label": "xlarge", "value": "xlarge", "ckpt": "
|
| 46 |
]
|
| 47 |
OD_VALUE_TO_CKPT = {x["value"]: x["ckpt"] for x in OD_CHOICES}
|
| 48 |
DEFAULT_OD = "xlarge"
|
| 49 |
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
| 51 |
runner = ModelRunner(
|
| 52 |
project_root=str(PROJECT_ROOT),
|
| 53 |
-
device="cuda:0",
|
| 54 |
-
yolo_ckpt=
|
|
|
|
|
|
|
| 55 |
id2task_name_file="./id2task_name.json",
|
| 56 |
task2prompt_file="./task20.json",
|
| 57 |
threshold=0.01,
|
|
@@ -78,7 +91,6 @@ def index(request: Request):
|
|
| 78 |
},
|
| 79 |
)
|
| 80 |
|
| 81 |
-
|
| 82 |
@app.post("/api/run")
|
| 83 |
async def api_run(
|
| 84 |
vlm_model: str = Form(DEFAULT_VLM),
|
|
@@ -89,58 +101,48 @@ async def api_run(
|
|
| 89 |
viz_mode: str = Form("bbox"),
|
| 90 |
upload: UploadFile = File(...),
|
| 91 |
):
|
| 92 |
-
#
|
| 93 |
if score_function not in SCORE_FUNCS:
|
| 94 |
return JSONResponse({"ok": False, "error": f"Unknown score_function: {score_function}"}, status_code=400)
|
| 95 |
-
|
| 96 |
if score_function == "HDC":
|
| 97 |
if hdv_dim not in HDV_DIMS:
|
| 98 |
return JSONResponse({"ok": False, "error": f"Unsupported hdv_dim: {hdv_dim}"}, status_code=400)
|
| 99 |
-
|
| 100 |
vlm_folder = VLM_VALUE_TO_FOLDER.get(vlm_model)
|
| 101 |
if not vlm_folder:
|
| 102 |
return JSONResponse({"ok": False, "error": f"Unknown vlm_model: {vlm_model}"}, status_code=400)
|
| 103 |
-
|
| 104 |
-
taskclip_ckpt = f"./test_model/{vlm_folder}/8Layer_4Head_HDV_{hdv_dim}/decoder.pt"
|
| 105 |
else:
|
| 106 |
taskclip_ckpt = DEFAULT_TASKCLIP_CKPT
|
| 107 |
|
| 108 |
-
|
| 109 |
-
return JSONResponse(
|
| 110 |
-
{"ok": False, "error": "score_function=default only supports vlm_model=imagebind. Use HDC for vit-b/vit-l."},
|
| 111 |
-
status_code=400
|
| 112 |
-
)
|
| 113 |
-
|
| 114 |
-
# get yolo checkpoint
|
| 115 |
yolo_ckpt = OD_VALUE_TO_CKPT.get(od_model)
|
| 116 |
if not yolo_ckpt:
|
| 117 |
return JSONResponse({"ok": False, "error": f"Unknown od_model size: {od_model}"}, status_code=400)
|
| 118 |
|
| 119 |
-
#
|
| 120 |
suffix = Path(upload.filename).suffix or ".jpg"
|
| 121 |
job_id = uuid.uuid4().hex
|
| 122 |
upload_path = UPLOAD_DIR / f"{job_id}{suffix}"
|
| 123 |
upload_path.write_bytes(await upload.read())
|
| 124 |
|
| 125 |
-
#
|
| 126 |
try:
|
| 127 |
-
# print("[API] vlm_model", vlm_model, "score_function", score_function, "hdv_dim", hdv_dim, "taskclip_ckpt", taskclip_ckpt)
|
| 128 |
out = runner.run(
|
| 129 |
-
image_path=str(upload_path),
|
| 130 |
-
task_id=int(task_id),
|
| 131 |
-
vlm_model=vlm_model,
|
| 132 |
-
od_model=
|
| 133 |
yolo_ckpt=yolo_ckpt,
|
| 134 |
score_function=score_function,
|
| 135 |
-
hdv_dim=hdv_dim,
|
| 136 |
-
taskclip_ckpt=taskclip_ckpt,
|
| 137 |
viz_mode=viz_mode,
|
| 138 |
)
|
| 139 |
except Exception as e:
|
| 140 |
return JSONResponse({"ok": False, "error": repr(e)}, status_code=500)
|
| 141 |
|
| 142 |
-
|
| 143 |
-
# Save 3 images to results/<job_id>/
|
| 144 |
job_dir = RESULT_DIR / job_id
|
| 145 |
job_dir.mkdir(parents=True, exist_ok=True)
|
| 146 |
|
|
@@ -163,4 +165,4 @@ async def api_run(
|
|
| 163 |
"yolo": f"/results/{job_id}/yolo.jpg",
|
| 164 |
"selected": f"/results/{job_id}/selected.jpg",
|
| 165 |
},
|
| 166 |
-
}
|
|
|
|
| 1 |
+
import os
|
| 2 |
import uuid
|
| 3 |
from pathlib import Path
|
|
|
|
| 4 |
|
| 5 |
from fastapi import FastAPI, Request, UploadFile, File, Form
|
| 6 |
from fastapi.responses import HTMLResponse, JSONResponse
|
|
|
|
| 8 |
from fastapi.templating import Jinja2Templates
|
| 9 |
|
| 10 |
from webui.runner import ModelRunner
|
| 11 |
+
from webui.weights import get_weights_dir
|
| 12 |
|
| 13 |
+
PROJECT_ROOT = Path(__file__).resolve().parents[1] # repo root
|
| 14 |
WEBUI_DIR = Path(__file__).resolve().parent
|
| 15 |
UPLOAD_DIR = WEBUI_DIR / "uploads"
|
| 16 |
RESULT_DIR = WEBUI_DIR / "results"
|
|
|
|
| 23 |
app.mount("/static", StaticFiles(directory=str(WEBUI_DIR / "static")), name="static")
|
| 24 |
app.mount("/results", StaticFiles(directory=str(RESULT_DIR)), name="results")
|
| 25 |
|
| 26 |
+
# ---- weights repo ----
|
| 27 |
+
WEIGHTS_REPO = os.getenv("TASKCLIP_WEIGHTS_REPO", "BiasLab2025/YOUR-WEIGHTS-REPO") # <-- change default
|
| 28 |
+
WEIGHTS_DIR = get_weights_dir(WEIGHTS_REPO)
|
| 29 |
+
|
| 30 |
+
CKPT_DIR = WEIGHTS_DIR / "checkpoints"
|
| 31 |
+
DECODER_DIR = WEIGHTS_DIR / "test_model"
|
| 32 |
+
|
| 33 |
VLM_CHOICES = [
|
| 34 |
{"label": "imagebind", "value": "imagebind", "folder": "imagebind"},
|
| 35 |
{"label": "ViT-B", "value": "vit-b", "folder": "ViT-B"},
|
|
|
|
| 42 |
|
| 43 |
DEFAULT_VLM = "imagebind"
|
| 44 |
DEFAULT_HDV = 256
|
|
|
|
| 45 |
DEFAULT_SCORE_FUNC = "default"
|
| 46 |
+
DEFAULT_TASKCLIP_CKPT = str(DECODER_DIR / "default" / "decoder.pt")
|
| 47 |
|
| 48 |
OD_CHOICES = [
|
| 49 |
+
{"label": "nano", "value": "nano", "ckpt": str(CKPT_DIR / "yolo12n.pt")},
|
| 50 |
+
{"label": "small", "value": "small", "ckpt": str(CKPT_DIR / "yolo12s.pt")},
|
| 51 |
+
{"label": "median", "value": "median", "ckpt": str(CKPT_DIR / "yolo12m.pt")},
|
| 52 |
+
{"label": "large", "value": "large", "ckpt": str(CKPT_DIR / "yolo12l.pt")},
|
| 53 |
+
{"label": "xlarge", "value": "xlarge", "ckpt": str(CKPT_DIR / "yolo12x.pt")},
|
| 54 |
]
|
| 55 |
OD_VALUE_TO_CKPT = {x["value"]: x["ckpt"] for x in OD_CHOICES}
|
| 56 |
DEFAULT_OD = "xlarge"
|
| 57 |
|
| 58 |
+
DEFAULT_SAM_CKPT = str(CKPT_DIR / "sam2.1_l.pt")
|
| 59 |
+
DEFAULT_IMAGEBIND_CKPT = str(CKPT_DIR / "imagebind_huge.pth") # optional but recommended
|
| 60 |
+
|
| 61 |
+
# ---- Load runner ONCE at startup ----
|
| 62 |
runner = ModelRunner(
|
| 63 |
project_root=str(PROJECT_ROOT),
|
| 64 |
+
device=os.getenv("DEVICE", "cuda:0"),
|
| 65 |
+
yolo_ckpt=OD_VALUE_TO_CKPT[DEFAULT_OD],
|
| 66 |
+
sam_ckpt=DEFAULT_SAM_CKPT,
|
| 67 |
+
imagebind_ckpt=DEFAULT_IMAGEBIND_CKPT, # if missing, runner can fall back to pretrained=True
|
| 68 |
id2task_name_file="./id2task_name.json",
|
| 69 |
task2prompt_file="./task20.json",
|
| 70 |
threshold=0.01,
|
|
|
|
| 91 |
},
|
| 92 |
)
|
| 93 |
|
|
|
|
| 94 |
@app.post("/api/run")
|
| 95 |
async def api_run(
|
| 96 |
vlm_model: str = Form(DEFAULT_VLM),
|
|
|
|
| 101 |
viz_mode: str = Form("bbox"),
|
| 102 |
upload: UploadFile = File(...),
|
| 103 |
):
|
| 104 |
+
# validate + pick decoder
|
| 105 |
if score_function not in SCORE_FUNCS:
|
| 106 |
return JSONResponse({"ok": False, "error": f"Unknown score_function: {score_function}"}, status_code=400)
|
| 107 |
+
|
| 108 |
if score_function == "HDC":
|
| 109 |
if hdv_dim not in HDV_DIMS:
|
| 110 |
return JSONResponse({"ok": False, "error": f"Unsupported hdv_dim: {hdv_dim}"}, status_code=400)
|
|
|
|
| 111 |
vlm_folder = VLM_VALUE_TO_FOLDER.get(vlm_model)
|
| 112 |
if not vlm_folder:
|
| 113 |
return JSONResponse({"ok": False, "error": f"Unknown vlm_model: {vlm_model}"}, status_code=400)
|
| 114 |
+
taskclip_ckpt = str(DECODER_DIR / vlm_folder / f"8Layer_4Head_HDV_{hdv_dim}" / "decoder.pt")
|
|
|
|
| 115 |
else:
|
| 116 |
taskclip_ckpt = DEFAULT_TASKCLIP_CKPT
|
| 117 |
|
| 118 |
+
# pick yolo ckpt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
yolo_ckpt = OD_VALUE_TO_CKPT.get(od_model)
|
| 120 |
if not yolo_ckpt:
|
| 121 |
return JSONResponse({"ok": False, "error": f"Unknown od_model size: {od_model}"}, status_code=400)
|
| 122 |
|
| 123 |
+
# save upload
|
| 124 |
suffix = Path(upload.filename).suffix or ".jpg"
|
| 125 |
job_id = uuid.uuid4().hex
|
| 126 |
upload_path = UPLOAD_DIR / f"{job_id}{suffix}"
|
| 127 |
upload_path.write_bytes(await upload.read())
|
| 128 |
|
| 129 |
+
# run
|
| 130 |
try:
|
|
|
|
| 131 |
out = runner.run(
|
| 132 |
+
image_path=str(upload_path),
|
| 133 |
+
task_id=int(task_id),
|
| 134 |
+
vlm_model=vlm_model,
|
| 135 |
+
od_model="yolo",
|
| 136 |
yolo_ckpt=yolo_ckpt,
|
| 137 |
score_function=score_function,
|
| 138 |
+
hdv_dim=int(hdv_dim),
|
| 139 |
+
taskclip_ckpt=taskclip_ckpt,
|
| 140 |
viz_mode=viz_mode,
|
| 141 |
)
|
| 142 |
except Exception as e:
|
| 143 |
return JSONResponse({"ok": False, "error": repr(e)}, status_code=500)
|
| 144 |
|
| 145 |
+
# save results
|
|
|
|
| 146 |
job_dir = RESULT_DIR / job_id
|
| 147 |
job_dir.mkdir(parents=True, exist_ok=True)
|
| 148 |
|
|
|
|
| 165 |
"yolo": f"/results/{job_id}/yolo.jpg",
|
| 166 |
"selected": f"/results/{job_id}/selected.jpg",
|
| 167 |
},
|
| 168 |
+
}
|
webui/runner.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import json
|
| 2 |
from pathlib import Path
|
| 3 |
-
from typing import Dict, Any, List, Tuple
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
import torch
|
|
@@ -8,91 +8,22 @@ from PIL import Image, ImageDraw
|
|
| 8 |
|
| 9 |
from ultralytics import YOLO, SAM
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
import sys
|
| 15 |
-
from pathlib import Path
|
| 16 |
-
|
| 17 |
-
REPO_ROOT = Path(__file__).resolve().parents[1] # repo/
|
| 18 |
-
sys.path.insert(0, str(REPO_ROOT / "ImageBind")) # so "import imagebind" works
|
| 19 |
-
|
| 20 |
-
from imagebind import data
|
| 21 |
-
from imagebind.models import imagebind_model
|
| 22 |
-
from imagebind.models.imagebind_model import ModalityType
|
| 23 |
-
|
| 24 |
-
import open_clip
|
| 25 |
|
| 26 |
from models.TaskCLIP import TaskCLIP
|
| 27 |
|
| 28 |
-
|
| 29 |
-
def _draw_boxes_pil(
|
| 30 |
-
img: Image.Image,
|
| 31 |
-
boxes_xyxy: np.ndarray,
|
| 32 |
-
color: Tuple[int, int, int],
|
| 33 |
-
width: int = 3,
|
| 34 |
-
) -> Image.Image:
|
| 35 |
-
out = img.copy()
|
| 36 |
-
draw = ImageDraw.Draw(out)
|
| 37 |
-
if boxes_xyxy is None or len(boxes_xyxy) == 0:
|
| 38 |
-
return out
|
| 39 |
-
for (x0, y0, x1, y1) in boxes_xyxy.tolist():
|
| 40 |
-
draw.rectangle([x0, y0, x1, y1], outline=color, width=width)
|
| 41 |
-
return out
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def _crop_pil(img: Image.Image, bbox_list: List[List[float]]) -> Tuple[List[Image.Image], List[int]]:
|
| 45 |
-
"""Return list of cropped PIL images + indices mapping back to bbox_list."""
|
| 46 |
-
W, H = img.size
|
| 47 |
-
crops = []
|
| 48 |
-
idxs = []
|
| 49 |
-
for i, (x0, y0, x1, y1) in enumerate(bbox_list):
|
| 50 |
-
x0 = max(0, min(W, int(x0)))
|
| 51 |
-
y0 = max(0, min(H, int(y0)))
|
| 52 |
-
x1 = max(0, min(W, int(x1)))
|
| 53 |
-
y1 = max(0, min(H, int(y1)))
|
| 54 |
-
if x1 <= x0 or y1 <= y0:
|
| 55 |
-
continue
|
| 56 |
-
crops.append(img.crop((x0, y0, x1, y1)))
|
| 57 |
-
idxs.append(i)
|
| 58 |
-
return crops, idxs
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
def overlay_masks(
|
| 62 |
-
img: Image.Image,
|
| 63 |
-
masks: np.ndarray,
|
| 64 |
-
alpha: float = 0.40,
|
| 65 |
-
color: Tuple[int, int, int] = (255, 0, 0),
|
| 66 |
-
) -> Image.Image:
|
| 67 |
-
if masks is None or len(masks) == 0:
|
| 68 |
-
return img
|
| 69 |
-
|
| 70 |
-
base = np.array(img).astype(np.float32)
|
| 71 |
-
union = np.any(masks.astype(bool), axis=0) # (H, W)
|
| 72 |
-
if not np.any(union):
|
| 73 |
-
return img
|
| 74 |
-
|
| 75 |
-
overlay = base.copy()
|
| 76 |
-
overlay[union] = overlay[union] * 0.2 + np.array(color, dtype=np.float32) * 0.8
|
| 77 |
-
out = base * (1 - alpha) + overlay * alpha
|
| 78 |
-
return Image.fromarray(np.clip(out, 0, 255).astype(np.uint8))
|
| 79 |
-
|
| 80 |
|
| 81 |
class ModelRunner:
|
| 82 |
-
"""
|
| 83 |
-
WebUI runner:
|
| 84 |
-
- YOLO detects bboxes
|
| 85 |
-
- VLM (ImageBind or OpenCLIP) embeds text prompts and crops (+ global image)
|
| 86 |
-
- TaskCLIP scores and selects bboxes
|
| 87 |
-
- optionally visualize bbox or SAM masks
|
| 88 |
-
"""
|
| 89 |
-
|
| 90 |
def __init__(
|
| 91 |
self,
|
| 92 |
project_root: str,
|
| 93 |
device: str = "cuda:0",
|
| 94 |
yolo_ckpt: str = "./.checkpoints/yolo12x.pt",
|
| 95 |
sam_ckpt: str = "./.checkpoints/sam2.1_l.pt",
|
|
|
|
| 96 |
id2task_name_file: str = "./id2task_name.json",
|
| 97 |
task2prompt_file: str = "./task20.json",
|
| 98 |
threshold: float = 0.01,
|
|
@@ -107,101 +38,51 @@ class ModelRunner:
|
|
| 107 |
self.cluster = bool(cluster)
|
| 108 |
self.forward_thre = float(forward_thre)
|
| 109 |
|
| 110 |
-
#
|
| 111 |
self.id2task_name_path = (self.root / id2task_name_file).resolve()
|
| 112 |
self.task2prompt_path = (self.root / task2prompt_file).resolve()
|
| 113 |
-
self.yolo_ckpt_path = (self.root / yolo_ckpt).resolve()
|
| 114 |
-
|
| 115 |
-
# load task metadata
|
| 116 |
self.id2task_name = json.loads(self.id2task_name_path.read_text())
|
| 117 |
self.task2prompt = json.loads(self.task2prompt_path.read_text())
|
| 118 |
|
| 119 |
# caches
|
| 120 |
-
self._vlm_cache = {}
|
| 121 |
self._yolo_cache = {}
|
| 122 |
self._taskclip_cache = {}
|
| 123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
sam_ckpt_path = (self.root / sam_ckpt).resolve() if str(sam_ckpt).startswith(".") else Path(sam_ckpt)
|
| 125 |
self.sam = SAM(str(sam_ckpt_path))
|
| 126 |
|
| 127 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
self._lock = torch.multiprocessing.RLock()
|
| 129 |
|
| 130 |
def _get_yolo(self, ckpt_path: str):
|
| 131 |
-
ckpt_abs = str((self.root / ckpt_path).resolve()) if ckpt_path.startswith(".") else ckpt_path
|
| 132 |
if ckpt_abs not in self._yolo_cache:
|
| 133 |
self._yolo_cache[ckpt_abs] = YOLO(ckpt_abs)
|
| 134 |
return self._yolo_cache[ckpt_abs]
|
| 135 |
|
| 136 |
-
def _get_vlm(self, vlm_model: str):
|
| 137 |
-
if vlm_model in self._vlm_cache:
|
| 138 |
-
return self._vlm_cache[vlm_model]
|
| 139 |
-
|
| 140 |
-
if vlm_model == "imagebind":
|
| 141 |
-
m = imagebind_model.imagebind_huge(pretrained=True).to(self.device).eval()
|
| 142 |
-
pack = {"kind": "imagebind", "model": m}
|
| 143 |
-
elif vlm_model == "vit-b":
|
| 144 |
-
m, _, preprocess = open_clip.create_model_and_transforms(
|
| 145 |
-
"ViT-B-32", pretrained="laion2b_s34b_b79k"
|
| 146 |
-
)
|
| 147 |
-
m = m.to(self.device).eval()
|
| 148 |
-
tokenizer = open_clip.get_tokenizer("ViT-B-32")
|
| 149 |
-
pack = {"kind": "openclip", "model": m, "preprocess": preprocess, "tokenizer": tokenizer}
|
| 150 |
-
elif vlm_model == "vit-l":
|
| 151 |
-
m, _, preprocess = open_clip.create_model_and_transforms(
|
| 152 |
-
"ViT-L-14", pretrained="laion2b_s32b_b82k"
|
| 153 |
-
)
|
| 154 |
-
m = m.to(self.device).eval()
|
| 155 |
-
tokenizer = open_clip.get_tokenizer("ViT-L-14")
|
| 156 |
-
pack = {"kind": "openclip", "model": m, "preprocess": preprocess, "tokenizer": tokenizer}
|
| 157 |
-
else:
|
| 158 |
-
raise ValueError(f"Unknown vlm_model: {vlm_model}")
|
| 159 |
-
|
| 160 |
-
self._vlm_cache[vlm_model] = pack
|
| 161 |
-
return pack
|
| 162 |
-
|
| 163 |
-
def _encode_vlm(self, vlm_model: str, prompt_use, seg_list, full_img_pil):
|
| 164 |
-
pack = self._get_vlm(vlm_model)
|
| 165 |
-
|
| 166 |
-
with torch.inference_mode():
|
| 167 |
-
if pack["kind"] == "imagebind":
|
| 168 |
-
input_pack = {
|
| 169 |
-
ModalityType.TEXT: data.load_and_transform_text(prompt_use, self.device),
|
| 170 |
-
ModalityType.VISION: data.read_and_transform_vision_data(seg_list, self.device),
|
| 171 |
-
}
|
| 172 |
-
emb = pack["model"](input_pack)
|
| 173 |
-
text_embeddings = emb[ModalityType.TEXT]
|
| 174 |
-
bbox_embeddings = emb[ModalityType.VISION]
|
| 175 |
-
|
| 176 |
-
input_pack2 = {ModalityType.VISION: data.read_and_transform_vision_data([full_img_pil], self.device)}
|
| 177 |
-
emb2 = pack["model"](input_pack2)
|
| 178 |
-
image_embedding = emb2[ModalityType.VISION].squeeze(0)
|
| 179 |
-
|
| 180 |
-
return text_embeddings, bbox_embeddings, image_embedding
|
| 181 |
-
|
| 182 |
-
# openclip branch
|
| 183 |
-
m = pack["model"]
|
| 184 |
-
preprocess = pack["preprocess"]
|
| 185 |
-
tokenizer = pack["tokenizer"]
|
| 186 |
-
|
| 187 |
-
# text
|
| 188 |
-
text = tokenizer(prompt_use).to(self.device)
|
| 189 |
-
text_embeddings = m.encode_text(text).float()
|
| 190 |
-
text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
|
| 191 |
-
|
| 192 |
-
# bbox crops
|
| 193 |
-
crop_tensors = [preprocess(im) for im in seg_list]
|
| 194 |
-
crop_batch = torch.stack(crop_tensors, dim=0).to(self.device)
|
| 195 |
-
bbox_embeddings = m.encode_image(crop_batch).float()
|
| 196 |
-
bbox_embeddings = bbox_embeddings / bbox_embeddings.norm(dim=-1, keepdim=True)
|
| 197 |
-
|
| 198 |
-
# global image
|
| 199 |
-
img_tensor = preprocess(full_img_pil).unsqueeze(0).to(self.device)
|
| 200 |
-
image_embedding = m.encode_image(img_tensor).float().squeeze(0)
|
| 201 |
-
image_embedding = image_embedding / image_embedding.norm(dim=-1, keepdim=True)
|
| 202 |
-
|
| 203 |
-
return text_embeddings, bbox_embeddings, image_embedding
|
| 204 |
-
|
| 205 |
def list_task_ids(self) -> List[int]:
|
| 206 |
ids = []
|
| 207 |
for k in self.id2task_name.keys():
|
|
@@ -211,73 +92,16 @@ class ModelRunner:
|
|
| 211 |
pass
|
| 212 |
return sorted(ids)
|
| 213 |
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
# supports {"state_dict": ...} style checkpoints
|
| 217 |
-
if isinstance(obj, dict) and "state_dict" in obj and isinstance(obj["state_dict"], dict):
|
| 218 |
-
return obj["state_dict"]
|
| 219 |
-
if isinstance(obj, dict):
|
| 220 |
-
return obj
|
| 221 |
-
raise TypeError(f"Unsupported checkpoint format: {type(obj)}")
|
| 222 |
-
|
| 223 |
-
def _infer_ckpt_flags(self, state: Dict[str, torch.Tensor]) -> Tuple[bool, bool, int]:
|
| 224 |
-
# infer (is_hdc, has_cross_attention, ckpt_d_model)
|
| 225 |
-
keys = list(state.keys())
|
| 226 |
-
is_hdc = any(k.startswith("ScoreFunction.HDReason.") for k in keys)
|
| 227 |
-
has_cross = any("cross_attn_text" in k for k in keys)
|
| 228 |
-
|
| 229 |
-
if "decoder_norm.weight" in state:
|
| 230 |
-
ckpt_d_model = int(state["decoder_norm.weight"].shape[0])
|
| 231 |
-
elif "ScoreFunction.norm.weight" in state:
|
| 232 |
-
ckpt_d_model = int(state["ScoreFunction.norm.weight"].shape[0])
|
| 233 |
-
else:
|
| 234 |
-
ckpt_d_model = -1
|
| 235 |
-
|
| 236 |
-
return is_hdc, has_cross, ckpt_d_model
|
| 237 |
-
|
| 238 |
-
def _get_taskclip(
|
| 239 |
-
self,
|
| 240 |
-
ckpt_path: str,
|
| 241 |
-
d_model: int,
|
| 242 |
-
n_words: int,
|
| 243 |
-
score_function: str,
|
| 244 |
-
hdv_dim: int,
|
| 245 |
-
cross_attention: bool,
|
| 246 |
-
):
|
| 247 |
-
ckpt_abs = str((self.root / ckpt_path).resolve()) if ckpt_path.startswith(".") else ckpt_path
|
| 248 |
if not Path(ckpt_abs).exists():
|
| 249 |
raise FileNotFoundError(f"TaskCLIP checkpoint not found: {ckpt_abs}")
|
| 250 |
|
| 251 |
eff_hdv_dim = int(hdv_dim) if score_function == "HDC" else 0
|
| 252 |
-
|
| 253 |
-
# IMPORTANT: cache key must include cross_attention + score_function
|
| 254 |
-
key = (ckpt_abs, int(d_model), int(n_words), str(score_function), int(eff_hdv_dim), bool(cross_attention))
|
| 255 |
if key in self._taskclip_cache:
|
| 256 |
return self._taskclip_cache[key]
|
| 257 |
|
| 258 |
-
state_raw = torch.load(ckpt_abs, map_location="cpu")
|
| 259 |
-
state = self._unwrap_state_dict(state_raw)
|
| 260 |
-
|
| 261 |
-
ckpt_is_hdc, ckpt_has_cross, ckpt_d_model = self._infer_ckpt_flags(state)
|
| 262 |
-
|
| 263 |
-
# Validate score_function against checkpoint
|
| 264 |
-
if score_function == "HDC" and not ckpt_is_hdc:
|
| 265 |
-
raise RuntimeError(f"Checkpoint is NOT HDC but score_function=HDC was selected. ckpt={ckpt_abs}")
|
| 266 |
-
if score_function != "HDC" and ckpt_is_hdc:
|
| 267 |
-
raise RuntimeError(f"Checkpoint IS HDC but score_function=default was selected. ckpt={ckpt_abs}")
|
| 268 |
-
|
| 269 |
-
# Validate cross_attention against checkpoint (your training differs by family)
|
| 270 |
-
if bool(cross_attention) != bool(ckpt_has_cross):
|
| 271 |
-
raise RuntimeError(
|
| 272 |
-
f"cross_attention mismatch: runtime={cross_attention} but checkpoint has_cross_attention={ckpt_has_cross}. ckpt={ckpt_abs}"
|
| 273 |
-
)
|
| 274 |
-
|
| 275 |
-
# Validate d_model against checkpoint
|
| 276 |
-
if ckpt_d_model != -1 and int(d_model) != int(ckpt_d_model):
|
| 277 |
-
raise RuntimeError(
|
| 278 |
-
f"d_model mismatch: VLM produced d_model={int(d_model)} but checkpoint expects d_model={int(ckpt_d_model)}. ckpt={ckpt_abs}"
|
| 279 |
-
)
|
| 280 |
-
|
| 281 |
model_config = {
|
| 282 |
"num_layers": 8,
|
| 283 |
"norm": None,
|
|
@@ -297,52 +121,31 @@ class ModelRunner:
|
|
| 297 |
"norm_after": False,
|
| 298 |
"MIN_VAL": 10.0,
|
| 299 |
"MAX_VAL": 30.0,
|
| 300 |
-
"cross_attention":
|
| 301 |
"score_function": "HDC" if score_function == "HDC" else "default",
|
| 302 |
"HDV_D": int(eff_hdv_dim),
|
| 303 |
}
|
| 304 |
|
| 305 |
m = TaskCLIP(model_config, normalize_before=model_config["normalize_before"], device=model_config["device"])
|
|
|
|
| 306 |
m.load_state_dict(state, strict=True)
|
| 307 |
m = m.to(self.device).eval()
|
| 308 |
|
| 309 |
self._taskclip_cache[key] = m
|
| 310 |
return m
|
| 311 |
|
| 312 |
-
def _find_same_class(self, predict_res, score, visited, i, classes, confs, forward_thre):
|
| 313 |
-
cls_i = classes[i]
|
| 314 |
-
for j in range(len(score)):
|
| 315 |
-
if visited[j] == 1:
|
| 316 |
-
continue
|
| 317 |
-
if classes[j] == cls_i and float(score[j]) > forward_thre:
|
| 318 |
-
visited[j] = 1
|
| 319 |
-
predict_res[j]["category_id"] = 1
|
| 320 |
-
predict_res[j]["score"] = float(score[j])
|
| 321 |
-
|
| 322 |
def _sam_masks_from_bboxes(self, image_path: str, bbox_list: List[List[float]], img_h: int, img_w: int) -> np.ndarray:
|
| 323 |
if not bbox_list:
|
| 324 |
return np.zeros((0, img_h, img_w), dtype=bool)
|
| 325 |
|
| 326 |
bboxes = [[float(x0), float(y0), float(x1), float(y1)] for x0, y0, x1, y1 in bbox_list]
|
| 327 |
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
return masks
|
| 335 |
-
except Exception:
|
| 336 |
-
masks_list = []
|
| 337 |
-
for bb in bboxes:
|
| 338 |
-
rr = self.sam(image_path, bboxes=bb)[0]
|
| 339 |
-
if rr.masks is None:
|
| 340 |
-
continue
|
| 341 |
-
m = rr.masks.data.detach().cpu().numpy().astype(bool)
|
| 342 |
-
masks_list.append(m[0])
|
| 343 |
-
if len(masks_list) == 0:
|
| 344 |
-
return np.zeros((0, img_h, img_w), dtype=bool)
|
| 345 |
-
return np.stack(masks_list, axis=0)
|
| 346 |
|
| 347 |
def run(
|
| 348 |
self,
|
|
@@ -356,47 +159,41 @@ class ModelRunner:
|
|
| 356 |
taskclip_ckpt: str = "./test_model/default/decoder.pt",
|
| 357 |
viz_mode: str = "bbox",
|
| 358 |
) -> Dict[str, Any]:
|
| 359 |
-
if vlm_model not in ["imagebind", "vit-b", "vit-l"]:
|
| 360 |
-
raise ValueError(f"Unknown vlm_model: {vlm_model}")
|
| 361 |
|
|
|
|
|
|
|
| 362 |
if od_model != "yolo":
|
| 363 |
-
raise ValueError("
|
| 364 |
-
|
| 365 |
-
if viz_mode not in ["bbox", "mask"]:
|
| 366 |
-
raise ValueError(f"Unknown viz_mode={viz_mode}")
|
| 367 |
-
|
| 368 |
-
# training truth:
|
| 369 |
-
# - default used cross_attention=True
|
| 370 |
-
# - HDC used cross_attention=False
|
| 371 |
-
cross_attention = (score_function != "HDC")
|
| 372 |
|
| 373 |
with self._lock:
|
| 374 |
img = Image.open(image_path).convert("RGB")
|
| 375 |
-
|
| 376 |
task_name = self.id2task_name[str(task_id)]
|
| 377 |
prompt_words = self.task2prompt[task_name]
|
| 378 |
prompt_use = ["The item is " + w for w in prompt_words]
|
| 379 |
|
| 380 |
-
# YOLO
|
| 381 |
yolo = self._get_yolo(yolo_ckpt)
|
| 382 |
outputs = yolo(image_path)
|
| 383 |
bbox_list = outputs[0].boxes.xyxy.tolist()
|
| 384 |
classes = outputs[0].boxes.cls.tolist()
|
| 385 |
confidences = outputs[0].boxes.conf.tolist()
|
| 386 |
|
| 387 |
-
H, W = img.size[1], img.size[0]
|
| 388 |
all_boxes = np.asarray(bbox_list, dtype=np.float32)
|
|
|
|
|
|
|
| 389 |
|
| 390 |
-
#
|
|
|
|
| 391 |
if viz_mode == "bbox":
|
| 392 |
img_yolo = _draw_boxes_pil(img, all_boxes, color=(0, 255, 0), width=3)
|
| 393 |
-
|
| 394 |
-
else:
|
| 395 |
all_masks = self._sam_masks_from_bboxes(image_path, bbox_list, img_h=H, img_w=W)
|
| 396 |
img_yolo = overlay_masks(img, all_masks, alpha=0.35, color=(0, 255, 0))
|
|
|
|
|
|
|
| 397 |
|
| 398 |
-
#
|
| 399 |
-
seg_list,
|
| 400 |
if len(seg_list) == 0:
|
| 401 |
return {
|
| 402 |
"task_id": task_id,
|
|
@@ -406,95 +203,33 @@ class ModelRunner:
|
|
| 406 |
"images": {"original": img, "yolo": img_yolo, "selected": img.copy()},
|
| 407 |
}
|
| 408 |
|
| 409 |
-
#
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
if int(bbox_embeddings.shape[-1]) != int(image_embedding.shape[-1]):
|
| 419 |
-
raise RuntimeError(
|
| 420 |
-
f"Embedding dim mismatch: bbox_embeddings dim={bbox_embeddings.shape[-1]} vs image_embedding dim={image_embedding.shape[-1]}"
|
| 421 |
-
)
|
| 422 |
|
| 423 |
-
|
| 424 |
-
|
|
|
|
| 425 |
|
| 426 |
-
# TaskCLIP
|
| 427 |
taskclip = self._get_taskclip(
|
| 428 |
ckpt_path=taskclip_ckpt,
|
| 429 |
-
d_model=
|
| 430 |
-
n_words=
|
| 431 |
score_function=score_function,
|
| 432 |
hdv_dim=hdv_dim,
|
| 433 |
-
cross_attention=cross_attention,
|
| 434 |
)
|
| 435 |
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
tgt = bbox_embeddings
|
| 439 |
-
memory = text_embeddings
|
| 440 |
-
image_embedding_2d = image_embedding.view(1, -1)
|
| 441 |
-
_, _, score_res, _ = taskclip(tgt, memory, image_embedding_2d)
|
| 442 |
score = score_res.view(-1).detach().cpu().numpy().tolist()
|
| 443 |
|
| 444 |
-
#
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
predict_res.append({"category_id": -1, "score": -1, "class": int(classes[i])})
|
| 448 |
-
|
| 449 |
-
visited = [0] * len(score)
|
| 450 |
-
for i, x in enumerate(score):
|
| 451 |
-
if visited[i] == 1:
|
| 452 |
-
continue
|
| 453 |
-
if float(x) > self.threshold:
|
| 454 |
-
visited[i] = 1
|
| 455 |
-
predict_res[i]["category_id"] = 1
|
| 456 |
-
predict_res[i]["score"] = float(x)
|
| 457 |
-
if self.forward:
|
| 458 |
-
self._find_same_class(predict_res, score, visited, i, classes, confidences, self.forward_thre)
|
| 459 |
-
else:
|
| 460 |
-
predict_res[i]["category_id"] = 0
|
| 461 |
-
predict_res[i]["score"] = 1.0 - float(x)
|
| 462 |
-
|
| 463 |
-
# cluster optimization
|
| 464 |
-
if self.cluster and self.forward and len(seg_list) > 1:
|
| 465 |
-
cluster_scores = {}
|
| 466 |
-
for p in predict_res:
|
| 467 |
-
if int(p["category_id"]) == 1:
|
| 468 |
-
c = p["class"]
|
| 469 |
-
cluster_scores.setdefault(c, []).append(p["score"])
|
| 470 |
-
|
| 471 |
-
if len(cluster_scores) > 1:
|
| 472 |
-
cluster_ave = {c: float(np.mean(v)) for c, v in cluster_scores.items()}
|
| 473 |
-
select_class = max(cluster_ave, key=lambda k: cluster_ave[k])
|
| 474 |
-
for p in predict_res:
|
| 475 |
-
if p["category_id"] == 1 and p["class"] != select_class:
|
| 476 |
-
p["category_id"] = 0
|
| 477 |
-
|
| 478 |
-
selected_indices = [i for i, p in enumerate(predict_res) if int(p["category_id"]) == 1]
|
| 479 |
-
selected_boxes = all_boxes[selected_indices] if len(selected_indices) > 0 else np.zeros((0, 4), dtype=np.float32)
|
| 480 |
-
|
| 481 |
-
# visualize selected
|
| 482 |
-
if viz_mode == "bbox":
|
| 483 |
-
img_selected = _draw_boxes_pil(img, selected_boxes, color=(255, 0, 0), width=4)
|
| 484 |
-
else:
|
| 485 |
-
if all_masks is not None and all_masks.shape[0] > 0 and len(selected_indices) > 0:
|
| 486 |
-
sel_masks = all_masks[selected_indices]
|
| 487 |
-
else:
|
| 488 |
-
sel_masks = np.zeros((0, H, W), dtype=bool)
|
| 489 |
-
img_selected = overlay_masks(img, sel_masks, alpha=0.45, color=(255, 0, 0))
|
| 490 |
-
|
| 491 |
-
return {
|
| 492 |
-
"task_id": task_id,
|
| 493 |
-
"task_name": task_name,
|
| 494 |
-
"bbox_list": bbox_list,
|
| 495 |
-
"classes": classes,
|
| 496 |
-
"confidences": confidences,
|
| 497 |
-
"scores": score,
|
| 498 |
-
"selected_indices": selected_indices,
|
| 499 |
-
"images": {"original": img, "yolo": img_yolo, "selected": img_selected},
|
| 500 |
-
}
|
|
|
|
| 1 |
import json
|
| 2 |
from pathlib import Path
|
| 3 |
+
from typing import Dict, Any, List, Tuple, Optional
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
import torch
|
|
|
|
| 8 |
|
| 9 |
from ultralytics import YOLO, SAM
|
| 10 |
|
| 11 |
+
from ImageBind.imagebind import data
|
| 12 |
+
from ImageBind.imagebind.models import imagebind_model
|
| 13 |
+
from ImageBind.imagebind.models.imagebind_model import ModalityType
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
from models.TaskCLIP import TaskCLIP
|
| 16 |
|
| 17 |
+
# ... keep your helper funcs _draw_boxes_pil/_crop_pil/overlay_masks ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
class ModelRunner:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
def __init__(
|
| 21 |
self,
|
| 22 |
project_root: str,
|
| 23 |
device: str = "cuda:0",
|
| 24 |
yolo_ckpt: str = "./.checkpoints/yolo12x.pt",
|
| 25 |
sam_ckpt: str = "./.checkpoints/sam2.1_l.pt",
|
| 26 |
+
imagebind_ckpt: Optional[str] = None, # NEW
|
| 27 |
id2task_name_file: str = "./id2task_name.json",
|
| 28 |
task2prompt_file: str = "./task20.json",
|
| 29 |
threshold: float = 0.01,
|
|
|
|
| 38 |
self.cluster = bool(cluster)
|
| 39 |
self.forward_thre = float(forward_thre)
|
| 40 |
|
| 41 |
+
# metadata
|
| 42 |
self.id2task_name_path = (self.root / id2task_name_file).resolve()
|
| 43 |
self.task2prompt_path = (self.root / task2prompt_file).resolve()
|
|
|
|
|
|
|
|
|
|
| 44 |
self.id2task_name = json.loads(self.id2task_name_path.read_text())
|
| 45 |
self.task2prompt = json.loads(self.task2prompt_path.read_text())
|
| 46 |
|
| 47 |
# caches
|
|
|
|
| 48 |
self._yolo_cache = {}
|
| 49 |
self._taskclip_cache = {}
|
| 50 |
|
| 51 |
+
# YOLO path (kept for reference; actual YOLO models are cached per ckpt in _get_yolo)
|
| 52 |
+
self.yolo_ckpt_path = (self.root / yolo_ckpt).resolve() if str(yolo_ckpt).startswith(".") else Path(yolo_ckpt)
|
| 53 |
+
|
| 54 |
+
# ---- SAM load ONCE (from absolute or repo-relative path) ----
|
| 55 |
sam_ckpt_path = (self.root / sam_ckpt).resolve() if str(sam_ckpt).startswith(".") else Path(sam_ckpt)
|
| 56 |
self.sam = SAM(str(sam_ckpt_path))
|
| 57 |
|
| 58 |
+
# ---- ImageBind load ONCE ----
|
| 59 |
+
# If you provide imagebind_huge.pth from weights repo, use it.
|
| 60 |
+
# Otherwise fall back to pretrained=True behavior.
|
| 61 |
+
self.vlm_model = imagebind_model.imagebind_huge(pretrained=False).to(self.device).eval()
|
| 62 |
+
if imagebind_ckpt:
|
| 63 |
+
ckpt_path = (self.root / imagebind_ckpt).resolve() if str(imagebind_ckpt).startswith(".") else Path(imagebind_ckpt)
|
| 64 |
+
if ckpt_path.exists():
|
| 65 |
+
state = torch.load(str(ckpt_path), map_location="cpu")
|
| 66 |
+
# robust handling of different checkpoint formats
|
| 67 |
+
if isinstance(state, dict) and "model" in state and isinstance(state["model"], dict):
|
| 68 |
+
state = state["model"]
|
| 69 |
+
elif isinstance(state, dict) and "state_dict" in state and isinstance(state["state_dict"], dict):
|
| 70 |
+
state = state["state_dict"]
|
| 71 |
+
self.vlm_model.load_state_dict(state, strict=False)
|
| 72 |
+
else:
|
| 73 |
+
# fallback if file missing
|
| 74 |
+
self.vlm_model = imagebind_model.imagebind_huge(pretrained=True).to(self.device).eval()
|
| 75 |
+
else:
|
| 76 |
+
self.vlm_model = imagebind_model.imagebind_huge(pretrained=True).to(self.device).eval()
|
| 77 |
+
|
| 78 |
self._lock = torch.multiprocessing.RLock()
|
| 79 |
|
| 80 |
def _get_yolo(self, ckpt_path: str):
    """Return a YOLO model for *ckpt_path*, loading and caching it on first use.

    Paths beginning with '.' are treated as repo-relative and resolved
    against ``self.root``; anything else is used as given.  The resolved
    absolute path string doubles as the cache key.
    """
    raw = str(ckpt_path)
    if raw.startswith("."):
        key = str((self.root / raw).resolve())
    else:
        key = raw
    model = self._yolo_cache.get(key)
    if model is None:
        model = YOLO(key)
        self._yolo_cache[key] = model
    return model
|
| 85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
def list_task_ids(self) -> List[int]:
|
| 87 |
ids = []
|
| 88 |
for k in self.id2task_name.keys():
|
|
|
|
| 92 |
pass
|
| 93 |
return sorted(ids)
|
| 94 |
|
| 95 |
+
def _get_taskclip(self, ckpt_path: str, d_model: int, n_words: int, score_function: str, hdv_dim: int):
|
| 96 |
+
ckpt_abs = str((self.root / ckpt_path).resolve()) if str(ckpt_path).startswith(".") else str(ckpt_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
if not Path(ckpt_abs).exists():
|
| 98 |
raise FileNotFoundError(f"TaskCLIP checkpoint not found: {ckpt_abs}")
|
| 99 |
|
| 100 |
eff_hdv_dim = int(hdv_dim) if score_function == "HDC" else 0
|
| 101 |
+
key = (ckpt_abs, int(d_model), int(n_words), str(score_function), eff_hdv_dim)
|
|
|
|
|
|
|
| 102 |
if key in self._taskclip_cache:
|
| 103 |
return self._taskclip_cache[key]
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
model_config = {
|
| 106 |
"num_layers": 8,
|
| 107 |
"norm": None,
|
|
|
|
| 121 |
"norm_after": False,
|
| 122 |
"MIN_VAL": 10.0,
|
| 123 |
"MAX_VAL": 30.0,
|
| 124 |
+
"cross_attention": True, # keep consistent with how your checkpoint was trained
|
| 125 |
"score_function": "HDC" if score_function == "HDC" else "default",
|
| 126 |
"HDV_D": int(eff_hdv_dim),
|
| 127 |
}
|
| 128 |
|
| 129 |
m = TaskCLIP(model_config, normalize_before=model_config["normalize_before"], device=model_config["device"])
|
| 130 |
+
state = torch.load(ckpt_abs, map_location="cpu")
|
| 131 |
m.load_state_dict(state, strict=True)
|
| 132 |
m = m.to(self.device).eval()
|
| 133 |
|
| 134 |
self._taskclip_cache[key] = m
|
| 135 |
return m
|
| 136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
def _sam_masks_from_bboxes(self, image_path: str, bbox_list: List[List[float]], img_h: int, img_w: int) -> np.ndarray:
    """Run SAM on *image_path* prompted by *bbox_list* and return boolean masks.

    Returns an array of shape ``(num_masks, img_h, img_w)``; an empty
    ``(0, img_h, img_w)`` array when there are no boxes or SAM yields no masks.
    """
    empty = np.zeros((0, img_h, img_w), dtype=bool)
    if not bbox_list:
        return empty

    # Coerce every coordinate to float for the SAM multi-box prompt.
    prompts = [[float(c) for c in box] for box in bbox_list]

    result = self.sam(image_path, bboxes=prompts)[0]
    if result.masks is None:
        return empty
    return result.masks.data.detach().cpu().numpy().astype(bool)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
def run(
|
| 151 |
self,
|
|
|
|
| 159 |
taskclip_ckpt: str = "./test_model/default/decoder.pt",
|
| 160 |
viz_mode: str = "bbox",
|
| 161 |
) -> Dict[str, Any]:
|
|
|
|
|
|
|
| 162 |
|
| 163 |
+
if vlm_model != "imagebind":
|
| 164 |
+
raise ValueError("This runner.py currently implements ImageBind only (your OpenCLIP version was in the other runner).")
|
| 165 |
if od_model != "yolo":
|
| 166 |
+
raise ValueError("Only od_model='yolo' supported.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
with self._lock:
|
| 169 |
img = Image.open(image_path).convert("RGB")
|
|
|
|
| 170 |
task_name = self.id2task_name[str(task_id)]
|
| 171 |
prompt_words = self.task2prompt[task_name]
|
| 172 |
prompt_use = ["The item is " + w for w in prompt_words]
|
| 173 |
|
| 174 |
+
# YOLO
|
| 175 |
yolo = self._get_yolo(yolo_ckpt)
|
| 176 |
outputs = yolo(image_path)
|
| 177 |
bbox_list = outputs[0].boxes.xyxy.tolist()
|
| 178 |
classes = outputs[0].boxes.cls.tolist()
|
| 179 |
confidences = outputs[0].boxes.conf.tolist()
|
| 180 |
|
|
|
|
| 181 |
all_boxes = np.asarray(bbox_list, dtype=np.float32)
|
| 182 |
+
H = img.size[1]
|
| 183 |
+
W = img.size[0]
|
| 184 |
|
| 185 |
+
# IMPORTANT: only run SAM if viz_mode == mask
|
| 186 |
+
all_masks = None
|
| 187 |
if viz_mode == "bbox":
|
| 188 |
img_yolo = _draw_boxes_pil(img, all_boxes, color=(0, 255, 0), width=3)
|
| 189 |
+
elif viz_mode == "mask":
|
|
|
|
| 190 |
all_masks = self._sam_masks_from_bboxes(image_path, bbox_list, img_h=H, img_w=W)
|
| 191 |
img_yolo = overlay_masks(img, all_masks, alpha=0.35, color=(0, 255, 0))
|
| 192 |
+
else:
|
| 193 |
+
raise ValueError(f"Unknown viz_mode={viz_mode}")
|
| 194 |
|
| 195 |
+
# crops
|
| 196 |
+
seg_list, _ = _crop_pil(img, bbox_list)
|
| 197 |
if len(seg_list) == 0:
|
| 198 |
return {
|
| 199 |
"task_id": task_id,
|
|
|
|
| 203 |
"images": {"original": img, "yolo": img_yolo, "selected": img.copy()},
|
| 204 |
}
|
| 205 |
|
| 206 |
+
# ImageBind embeddings
|
| 207 |
+
with torch.no_grad():
|
| 208 |
+
input_pack = {
|
| 209 |
+
ModalityType.TEXT: data.load_and_transform_text(prompt_use, self.device),
|
| 210 |
+
ModalityType.VISION: data.read_and_transform_vision_data(seg_list, self.device),
|
| 211 |
+
}
|
| 212 |
+
emb = self.vlm_model(input_pack)
|
| 213 |
+
text_embeddings = emb[ModalityType.TEXT]
|
| 214 |
+
bbox_embeddings = emb[ModalityType.VISION]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
|
| 216 |
+
input_pack2 = {ModalityType.VISION: data.read_and_transform_vision_data([img], self.device)}
|
| 217 |
+
emb2 = self.vlm_model(input_pack2)
|
| 218 |
+
image_embedding = emb2[ModalityType.VISION].squeeze(0)
|
| 219 |
|
| 220 |
+
# TaskCLIP
|
| 221 |
taskclip = self._get_taskclip(
|
| 222 |
ckpt_path=taskclip_ckpt,
|
| 223 |
+
d_model=int(image_embedding.shape[-1]),
|
| 224 |
+
n_words=int(text_embeddings.shape[0]),
|
| 225 |
score_function=score_function,
|
| 226 |
hdv_dim=hdv_dim,
|
|
|
|
| 227 |
)
|
| 228 |
|
| 229 |
+
with torch.no_grad():
|
| 230 |
+
_, _, score_res, _ = taskclip(bbox_embeddings, text_embeddings, image_embedding.view(1, -1))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
score = score_res.view(-1).detach().cpu().numpy().tolist()
|
| 232 |
|
| 233 |
+
# ... keep your postprocess/selection logic unchanged ...
|
| 234 |
+
# (use your existing code below this point)
|
| 235 |
+
# return dict unchanged
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
webui/weights.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# webui/weights.py
|
| 2 |
+
import os
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from huggingface_hub import snapshot_download
|
| 5 |
+
|
| 6 |
+
def get_weights_dir(repo_id: str) -> Path:
    """Download (or reuse the local cache of) a Hugging Face weights repo.

    Snapshots *repo_id* into the local ``weights_cache`` directory and returns
    its absolute path.  An ``HF_TOKEN`` environment variable is forwarded for
    private repos; public repos need no token.
    """
    token = os.getenv("HF_TOKEN")  # only needed if the repo is private
    # NOTE: `local_dir_use_symlinks` was dropped here — it is deprecated and
    # ignored (with a warning) in huggingface_hub >= 0.23, and this project
    # pins huggingface_hub >= 0.24.  With `local_dir` set, real files are
    # materialized in that directory.
    snapshot_path = snapshot_download(
        repo_id=repo_id,
        local_dir="weights_cache",
        token=token,
    )
    return Path(snapshot_path).resolve()
|