Spaces:
Sleeping
Sleeping
| import os | |
| import uuid | |
| import io | |
| import traceback | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| from PIL import Image, ImageFilter | |
| from fastapi import FastAPI, Request, UploadFile, File, Form | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| from webui.runner import ModelRunner | |
| from webui.weights import get_weights_dir | |
| from fastapi.middleware.cors import CORSMiddleware | |
# FastAPI application object; all middleware/mounts below attach to it.
app = FastAPI()
# CORS for local frontend
# Allows a dev frontend served from localhost:8000 (and file:// pages, which
# send Origin "null") to call this API; credentials are deliberately disabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:8000", "http://127.0.0.1:8000", "null"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
PROJECT_ROOT = Path(__file__).resolve().parents[1]  # repo root
WEBUI_DIR = Path(__file__).resolve().parent
# On-disk locations for raw uploads and rendered results. Created eagerly so
# the StaticFiles mount below never points at a missing directory.
UPLOAD_DIR = WEBUI_DIR / "uploads"
RESULT_DIR = WEBUI_DIR / "results"
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
RESULT_DIR.mkdir(parents=True, exist_ok=True)
# Serve result images back to the browser at /results/<job_id>/<name>.jpg.
app.mount("/results", StaticFiles(directory=str(RESULT_DIR)), name="results")
def health():
    """Liveness probe: always reports the service as up.

    NOTE(review): no route decorator is visible in this file — confirm this
    is registered as an endpoint (e.g. ``@app.get("/health")``) elsewhere.
    """
    return {"ok": True}
# ---- weights repo ----
# Hugging Face repo hosting all model weights; overridable via env var.
WEIGHTS_REPO = os.getenv("TASKCLIP_WEIGHTS_REPO", "BiasLab2025/taskclip-weights")
WEIGHTS_DIR = get_weights_dir(WEIGHTS_REPO)  # local directory for the repo contents
CKPT_DIR = WEIGHTS_DIR / "checkpoints"   # YOLO / SAM / ImageBind checkpoints
DECODER_DIR = WEIGHTS_DIR / "test_model"  # TaskCLIP decoder checkpoints
# Vision-language backbones selectable in the UI. "value" is the form value,
# "folder" the on-disk subdirectory under DECODER_DIR (note case differences).
VLM_CHOICES = [
    {"label": "imagebind", "value": "imagebind", "folder": "imagebind"},
    {"label": "ViT-B", "value": "vit-b", "folder": "ViT-B"},
    {"label": "ViT-L", "value": "vit-l", "folder": "ViT-L"},
]
VLM_VALUE_TO_FOLDER = {x["value"]: x["folder"] for x in VLM_CHOICES}
SCORE_FUNCS = ["default", "HDC"]   # scoring modes accepted by the run endpoint
HDV_DIMS = [128, 256, 512, 1024]   # hypervector dimensions supported for HDC
DEFAULT_VLM = "imagebind"
DEFAULT_HDV = 256
DEFAULT_SCORE_FUNC = "default"
DEFAULT_TASKCLIP_CKPT = str(DECODER_DIR / "default" / "decoder.pt")
# YOLOv12 detector sizes mapped to their checkpoint files.
# NOTE(review): label "median" likely means "medium" (yolo12m) — UI wording only.
OD_CHOICES = [
    {"label": "nano", "value": "nano", "ckpt": str(CKPT_DIR / "yolo12n.pt")},
    {"label": "small", "value": "small", "ckpt": str(CKPT_DIR / "yolo12s.pt")},
    {"label": "median", "value": "median", "ckpt": str(CKPT_DIR / "yolo12m.pt")},
    {"label": "large", "value": "large", "ckpt": str(CKPT_DIR / "yolo12l.pt")},
    {"label": "xlarge", "value": "xlarge", "ckpt": str(CKPT_DIR / "yolo12x.pt")},
]
OD_VALUE_TO_CKPT = {x["value"]: x["ckpt"] for x in OD_CHOICES}
DEFAULT_OD = "xlarge"
DEFAULT_SAM_CKPT = str(CKPT_DIR / "sam2.1_l.pt")
DEFAULT_IMAGEBIND_CKPT = str(CKPT_DIR / "imagebind_huge.pth")  # optional but recommended
| def _clamp_int(x, lo=0, hi=100) -> int: | |
| try: | |
| v = int(x) | |
| except Exception: | |
| v = 0 | |
| return max(lo, min(hi, v)) | |
def apply_noise_pil(img: Image.Image, noise_type: str, strength_0_100: int) -> Image.Image:
    """
    Simple input-noise layer applied before running YOLO/TaskCLIP.
    strength_0_100: 0..100

    Supported noise_type values (case-insensitive):
      - "none" / "default" / "off": no-op (also when strength is 0)
      - "gaussian":      additive Gaussian pixel noise
      - "linear":        brightness/contrast-like linear shift
      - "adv" / "adv_rand_sign": random-sign per-pixel perturbation
      - "adv_edge_sign": sign perturbation masked by detected edges
      - "adv_patch":     random square patch of uniform noise
      - "adv_stripes":   brightened vertical stripes
      - "adv_jpeg":      JPEG re-compression artifacts
    Any unrecognized value falls through to a no-op.

    Uses the global, unseeded numpy RNG, so results differ between calls.
    NOTE(review): the 3-channel branches (adv_edge_sign, adv_patch) assume an
    RGB image; the caller in this file converts to RGB first — confirm for any
    new caller.
    """
    strength = _clamp_int(strength_0_100, 0, 100)
    t = (noise_type or "none").lower()
    # Fast path: zero strength or an explicit "off" keyword leaves the image untouched.
    if strength == 0 or t in ["none", "default", "off"]:
        return img
    # Float work buffer; each branch clips back to uint8 before re-wrapping in PIL.
    arr = np.asarray(img).astype(np.float32)
    if t == "gaussian":
        # sigma in [0, 25] roughly
        sigma = (strength / 100.0) * 25.0
        noise = np.random.normal(0.0, sigma, size=arr.shape).astype(np.float32)
        out = np.clip(arr + noise, 0, 255).astype(np.uint8)
        return Image.fromarray(out)
    if t == "linear":
        # simple brightness/contrast-like linear shift
        alpha = 1.0 + (strength / 100.0) * 0.6  # 1.0 -> 1.6
        beta = (strength / 100.0) * 20.0  # 0 -> 20
        out = np.clip(arr * alpha + beta, 0, 255).astype(np.uint8)
        return Image.fromarray(out)
    # adversarial-ish synthetic corruptions (fast, deterministic-ish)
    if t in ["adv", "adv_rand_sign"]:
        # FGSM-flavored perturbation, but with a random (not gradient) sign per pixel.
        amp = (strength / 100.0) * 18.0
        sign = np.random.choice([-1.0, 1.0], size=arr.shape).astype(np.float32)
        out = np.clip(arr + sign * amp, 0, 255).astype(np.uint8)
        return Image.fromarray(out)
    if t == "adv_edge_sign":
        # edge sign from PIL's FIND_EDGES filter, then apply sign perturbation
        gray = img.convert("L").filter(ImageFilter.FIND_EDGES)
        g = np.asarray(gray).astype(np.float32) / 255.0
        sign2d = np.where(g > 0.2, 1.0, -1.0).astype(np.float32)  # crude edge mask
        amp = (strength / 100.0) * 18.0
        sign = np.repeat(sign2d[..., None], 3, axis=2)  # broadcast mask to 3 channels
        out = np.clip(arr + sign * amp, 0, 255).astype(np.uint8)
        return Image.fromarray(out)
    if t == "adv_patch":
        # random square occlusion / noise patch
        out = arr.copy()
        w, h = img.size  # PIL size is (width, height)
        s = int(min(w, h) * (0.10 + 0.30 * (strength / 100.0)))  # 10% -> 40% of the short side
        x0 = np.random.randint(0, max(1, w - s))
        y0 = np.random.randint(0, max(1, h - s))
        patch = np.random.uniform(0, 255, size=(s, s, 3)).astype(np.float32)
        out[y0:y0 + s, x0:x0 + s, :] = patch
        return Image.fromarray(np.clip(out, 0, 255).astype(np.uint8))
    if t == "adv_stripes":
        out = arr.copy()
        h, w = out.shape[0], out.shape[1]  # numpy shape is (height, width, ...)
        period = max(4, int(40 - 30 * (strength / 100.0)))  # stripe spacing: 40 -> 10 px
        amp = (strength / 100.0) * 35.0
        # Brighten a 2px-wide column every `period` pixels.
        for x in range(0, w, period):
            out[:, x:x+2, :] = np.clip(out[:, x:x+2, :] + amp, 0, 255)
        return Image.fromarray(out.astype(np.uint8))
    if t == "adv_jpeg":
        # JPEG compression artifacts via an in-memory encode/decode round trip
        quality = int(95 - (strength / 100.0) * 75)  # 95 -> 20
        quality = max(10, min(95, quality))
        buf = io.BytesIO()
        img.save(buf, format="JPEG", quality=quality)
        buf.seek(0)
        return Image.open(buf).convert("RGB")
    # fallback: no-op
    return img
# ---- Load runner ONCE at startup ----
# Device choice: an explicit DEVICE env var wins; otherwise prefer CUDA when
# torch reports it available, falling back to CPU.
device_env = os.getenv("DEVICE", "").strip()
if device_env:
    device = device_env
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"
# Single shared ModelRunner instance (weight loading is expensive); request
# handlers reuse it and pass per-request overrides into runner.run().
runner = ModelRunner(
    project_root=str(PROJECT_ROOT),
    device=device,
    yolo_ckpt=OD_VALUE_TO_CKPT[DEFAULT_OD],
    sam_ckpt=DEFAULT_SAM_CKPT,
    imagebind_ckpt=DEFAULT_IMAGEBIND_CKPT,
    # NOTE(review): relative paths — resolution depends on the CWD at launch;
    # confirm the process is started from the expected directory.
    id2task_name_file="./id2task_name.json",
    task2prompt_file="./task20.json",
    threshold=0.01,
    forward=True,
    cluster=True,
    forward_thre=0.1,
)
| """ | |
| @app.get("/", response_class=HTMLResponse) | |
| def index(request: Request): | |
| task_ids = runner.list_task_ids() | |
| task_items = [(tid, runner.id2task_name.get(str(tid), f"task_{tid}")) for tid in task_ids] | |
| return templates.TemplateResponse( | |
| "index.html", | |
| { | |
| "request": request, | |
| "vlm_choices": VLM_CHOICES, | |
| "default_vlm": DEFAULT_VLM, | |
| "score_funcs": SCORE_FUNCS, | |
| "default_score_func": DEFAULT_SCORE_FUNC, | |
| "hdv_dims": HDV_DIMS, | |
| "default_hdv_dim": DEFAULT_HDV, | |
| "od_choices": OD_CHOICES, | |
| "default_od": DEFAULT_OD, | |
| "task_ids": runner.list_task_ids(), | |
| "task_items": task_items | |
| }, | |
| ) | |
| """ | |
def root():
    """Landing payload that points callers at the API docs.

    NOTE(review): no route decorator is visible in this file — confirm this
    is registered as the "/" endpoint elsewhere.
    """
    payload = {
        "ok": True,
        "message": "Backend is running. Use POST /api/run and open /docs.",
    }
    return payload
def api_meta():
    """Expose the UI's dropdown choices, their defaults, and the task list.

    Task names come from the shared ``runner``; unknown ids fall back to a
    synthetic ``task_<id>`` label.
    """
    ids = runner.list_task_ids()
    items = [
        (tid, runner.id2task_name.get(str(tid), f"task_{tid}"))
        for tid in ids
    ]
    defaults = {
        "vlm": DEFAULT_VLM,
        "od": DEFAULT_OD,
        "hdv_dim": DEFAULT_HDV,
        "score_func": DEFAULT_SCORE_FUNC,
    }
    return {
        "vlm_choices": VLM_CHOICES,
        "od_choices": OD_CHOICES,
        "hdv_dims": HDV_DIMS,
        "score_funcs": SCORE_FUNCS,
        "defaults": defaults,
        "task_items": items,
    }
async def api_run(
    request: Request,
    vlm_model: str = Form(DEFAULT_VLM),
    od_model: str = Form(DEFAULT_OD),
    task_id: int = Form(1),
    score_function: str = Form(DEFAULT_SCORE_FUNC),
    hdv_dim: int = Form(DEFAULT_HDV),
    viz_mode: str = Form("bbox"),
    upload: UploadFile = File(...),
    noise_type: str = Form("none"),
    noise_strength: int = Form(0),
    hw_noise_dist: str = Form("none"),
    hw_noise_width: int = Form(0),
    hw_noise_strength: int = Form(0),
    hdc_bits: int = Form(32),
):
    """Run the full detection/selection pipeline on one uploaded image.

    Steps: validate the form choices, resolve checkpoint paths, decode the
    upload, apply the optional input-noise layer, invoke the shared
    ``runner``, save the three rendered images under ``RESULT_DIR/<job_id>/``
    and return their public URLs.

    Returns a plain dict on success; a ``JSONResponse`` with ``ok: False``
    and HTTP 400 (bad input) or 500 (pipeline failure) otherwise.

    NOTE(review): no route decorator is visible in this file — confirm the
    endpoint is registered (e.g. ``@app.post("/api/run")``) elsewhere.
    """
    # ---- validate score function and pick the matching decoder checkpoint ----
    if score_function not in SCORE_FUNCS:
        return JSONResponse({"ok": False, "error": f"Unknown score_function: {score_function}"}, status_code=400)
    if score_function == "HDC":
        if hdv_dim not in HDV_DIMS:
            return JSONResponse({"ok": False, "error": f"Unsupported hdv_dim: {hdv_dim}"}, status_code=400)
        vlm_folder = VLM_VALUE_TO_FOLDER.get(vlm_model)
        if not vlm_folder:
            return JSONResponse({"ok": False, "error": f"Unknown vlm_model: {vlm_model}"}, status_code=400)
        taskclip_ckpt = str(DECODER_DIR / vlm_folder / f"8Layer_4Head_HDV_{hdv_dim}" / "decoder.pt")
    else:
        taskclip_ckpt = DEFAULT_TASKCLIP_CKPT
    # ---- pick the YOLO checkpoint for the requested detector size ----
    yolo_ckpt = OD_VALUE_TO_CKPT.get(od_model)
    if not yolo_ckpt:
        return JSONResponse({"ok": False, "error": f"Unknown od_model size: {od_model}"}, status_code=400)
    # ---- save upload (apply noise first) ----
    job_id = uuid.uuid4().hex
    # BUGFIX: upload.filename may be None when the client omits it; Path(None)
    # would raise TypeError. Guard with "" so we fall back to ".jpg".
    suffix = Path(upload.filename or "").suffix or ".jpg"
    upload_path = UPLOAD_DIR / f"{job_id}{suffix}"
    raw = await upload.read()
    try:
        img = Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception:
        return JSONResponse({"ok": False, "error": "Failed to decode image upload"}, status_code=400)
    img = apply_noise_pil(img, noise_type=noise_type, strength_0_100=noise_strength)
    img.save(upload_path, quality=95)  # quality only affects JPEG-like encoders
    # ---- run the pipeline ----
    try:
        out = runner.run(
            image_path=str(upload_path),
            task_id=int(task_id),
            vlm_model=vlm_model,
            od_model="yolo",  # only YOLO detectors are wired up; size chosen via yolo_ckpt
            yolo_ckpt=yolo_ckpt,
            score_function=score_function,
            hdv_dim=int(hdv_dim),
            taskclip_ckpt=taskclip_ckpt,
            viz_mode=viz_mode,
            hw_noise_dist=hw_noise_dist,
            hw_noise_width=int(hw_noise_width),
            hw_noise_strength=int(hw_noise_strength),
            hdc_bits=hdc_bits,
        )
    except Exception as e:
        # Surface the traceback to the client to ease local debugging.
        tb = traceback.format_exc()
        print(tb)
        return JSONResponse({"ok": False, "error": str(e), "traceback": tb}, status_code=500)
    # ---- persist the three result images under a per-job directory ----
    job_dir = RESULT_DIR / job_id
    job_dir.mkdir(parents=True, exist_ok=True)
    p_in = job_dir / "input.jpg"
    p_yolo = job_dir / "yolo.jpg"
    p_sel = job_dir / "selected.jpg"
    out["images"]["original"].save(p_in, quality=95)
    out["images"]["yolo"].save(p_yolo, quality=95)
    out["images"]["selected"].save(p_sel, quality=95)
    # Build absolute URLs served by the /results static mount.
    base = str(request.base_url).rstrip("/")
    return {
        "ok": True,
        "job_id": job_id,
        "task_id": out["task_id"],
        "task_name": out["task_name"],
        "selected_indices": out["selected_indices"],
        "image_urls": {
            "input": f"{base}/results/{job_id}/input.jpg",
            "yolo": f"{base}/results/{job_id}/yolo.jpg",
            "selected": f"{base}/results/{job_id}/selected.jpg",
        },
    }