""" Jiggle Physics Simulator — HuggingFace Space backend. ZeroGPU requires demo.launch() (static @spaces.GPU detection isn't enough). But demo.launch() creates a fresh FastAPI app inside Gradio's App.create_app, discarding anything we registered on demo.app beforehand. Solution: monkey-patch gradio.routes.App.create_app so we can inject our /jiggle/* routes (and CORS middleware) into the *new* app the moment Gradio creates it, before it starts handling requests. """ import base64 import collections import io import json import sys import time from typing import Optional import gradio as gr import gradio.routes import spaces import numpy as np from fastapi import FastAPI, File, Form, UploadFile, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import Response from PIL import Image # ── In-process log ring buffer ─────────────────────────────────────────────── _log_buffer: collections.deque = collections.deque(maxlen=100) _log_cursor = 0 class _TeeStream: def __init__(self, real): self._real = real def write(self, s): self._real.write(s) if s.strip(): global _log_cursor _log_cursor += 1 _log_buffer.append({"id": _log_cursor, "t": time.time(), "msg": s.rstrip()}) def flush(self): self._real.flush() def __getattr__(self, name): return getattr(self._real, name) sys.stdout = _TeeStream(sys.stdout) def _load_image(upload: UploadFile) -> Image.Image: data = upload.file.read() img = Image.open(io.BytesIO(data)).convert("RGB") max_dim = 1024 if max(img.size) > max_dim: ratio = max_dim / max(img.size) img = img.resize( (int(img.width * ratio), int(img.height * ratio)), Image.LANCZOS ) return img # ── GPU implementations (standalone so ZeroGPU can detect them) ────────────── @spaces.GPU def _gpu_segment(img: Image.Image, region_list: list, clicks) -> dict: from segmentation import segment_regions return segment_regions(img, region_list, clicks) @spaces.GPU def _gpu_depth(img: Image.Image) -> dict: from depth import estimate_depth return 
estimate_depth(img) @spaces.GPU def _gpu_reconstruct( img: Image.Image, mask: np.ndarray, bbox_list: list, use_triposr: bool, ) -> bytes: from reconstruction import reconstruct_region, depth_to_mesh from depth import estimate_depth if use_triposr: return reconstruct_region(img, mask.tolist(), bbox_list) depth_result = estimate_depth(img) return depth_to_mesh(depth_result["depth"], mask.tolist(), img) @spaces.GPU def _gpu_pose( tpose_img: Image.Image, target_img: Image.Image, region_list: list, ) -> tuple: from pose import detect_landmarks, compute_region_transform tpose_lm = detect_landmarks(tpose_img) target_lm = detect_landmarks(target_img) transforms = { region: compute_region_transform(tpose_lm, target_lm, region) for region in region_list } return transforms, tpose_lm, target_lm # ── Route registration — called against the live app inside create_app ────── def _register_jiggle_routes(app: FastAPI) -> None: @app.get("/jiggle/health") def jiggle_health(): def _warm(module_name: str, cache_attr: str) -> bool: try: import importlib mod = importlib.import_module(module_name) return getattr(mod, cache_attr) is not None except Exception: return False def _is_cached(model_id: str) -> bool: try: import os cache_root = os.path.expanduser("~/.cache/huggingface/hub") slug = "models--" + model_id.replace("/", "--") return os.path.isdir(os.path.join(cache_root, slug)) except Exception: return False try: models = { "sam2": {"warm": _warm("segmentation", "_sam2_cache"), "cached_on_disk": _is_cached("facebook/sam2-hiera-large")}, "triposr": {"warm": _warm("reconstruction", "_triposr_cache"), "cached_on_disk": _is_cached("stabilityai/TripoSR")}, "depth_pro": {"warm": _warm("depth", "_depth_cache"), "cached_on_disk": _is_cached("apple/DepthPro-hf")}, } return {"status": "ok", "models": models} except Exception as e: return {"status": "error", "detail": str(e), "models": {}} @app.get("/jiggle/logs") def get_logs(since: int = 0): lines = [e for e in _log_buffer if e["id"] > since] 
return {"lines": lines, "cursor": _log_cursor} @app.post("/jiggle/segment") def segment( image: UploadFile = File(...), regions: str = Form("breast_left,breast_right,buttocks"), click_points: Optional[str] = Form(None), ): img = _load_image(image) region_list = [r.strip() for r in regions.split(",") if r.strip()] clicks = json.loads(click_points) if click_points else None try: result = _gpu_segment(img, region_list, clicks) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) encoded = {} for region, data in result.items(): mask_arr = np.array(data["mask"], dtype=bool) flat = mask_arr.flatten() rle: list[int] = [] current = bool(flat[0]) count = 0 for val in flat: if bool(val) == current: count += 1 else: rle.append(count) count = 1 current = bool(val) rle.append(count) encoded[region] = { "rle": rle, "rle_start": bool(flat[0]), "shape": list(mask_arr.shape), "bbox": data["bbox"], } return {"regions": encoded, "image_size": [img.width, img.height]} @app.post("/jiggle/depth") def depth(image: UploadFile = File(...)): img = _load_image(image) try: result = _gpu_depth(img) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) arr = np.array(result["depth"], dtype=np.float32) b64 = base64.b64encode(arr.tobytes()).decode("ascii") return { "depth_b64": b64, "width": result["width"], "height": result["height"], "min": result["min"], "max": result["max"], "dtype": "float32", } @app.post("/jiggle/reconstruct") def reconstruct( image: UploadFile = File(...), mask_rle: str = Form(...), mask_shape: str = Form(...), mask_rle_start: str = Form("false"), bbox: str = Form(...), use_triposr: str = Form("true"), ): img = _load_image(image) rle = json.loads(mask_rle) shape = json.loads(mask_shape) start_val = mask_rle_start.lower() == "true" flat: list[bool] = [] current = start_val for run in rle: flat.extend([current] * run) current = not current mask = np.array(flat, dtype=bool).reshape(shape) bbox_list = json.loads(bbox) try: glb_bytes = 
_gpu_reconstruct(img, mask, bbox_list, use_triposr.lower() == "true") except Exception as e: raise HTTPException(status_code=500, detail=str(e)) return Response(content=glb_bytes, media_type="application/octet-stream") @app.post("/jiggle/pose") def pose_detect( tpose_image: UploadFile = File(...), target_image: UploadFile = File(...), regions: str = Form("breast_left,breast_right,buttocks"), ): tpose_img = _load_image(tpose_image) target_img = _load_image(target_image) region_list = [r.strip() for r in regions.split(",") if r.strip()] try: transforms, tpose_lm, target_lm = _gpu_pose(tpose_img, target_img, region_list) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) # orjson (Gradio's response renderer) rejects non-string dict keys — # detect_landmarks returns {int: {...}} so we stringify before returning. return { "transforms": transforms, "tpose_landmarks": {str(k): v for k, v in tpose_lm.items()}, "target_landmarks": {str(k): v for k, v in target_lm.items()}, } # Promote our routes ahead of Gradio's catch-all (GET /{path:path} SPA route) routes = app.router.routes ours = [r for r in routes if getattr(r, "path", "").startswith("/jiggle/")] others = [r for r in routes if not getattr(r, "path", "").startswith("/jiggle/")] app.router.routes = ours + others print(f"Injected {len(ours)} /jiggle/* routes into Gradio's App") # ── Monkey-patch gradio.routes.App.create_app ─────────────────────────────── # Gradio's create_app builds the FastAPI app fresh each time demo.launch() # runs. We wrap it to (1) add CORS middleware before the app starts handling # requests, and (2) register our /jiggle/* routes on the same app instance # Gradio is about to serve. 
_original_create_app = gradio.routes.App.create_app


def _patched_create_app(*args, **kwargs):
    """Create Gradio's app as usual, then add CORS and the /jiggle/* routes.

    Every positional and keyword argument is forwarded unchanged to the
    original ``create_app``; the app it returns is the one handed back.
    """
    gradio_app = _original_create_app(*args, **kwargs)
    cors_settings = {
        "allow_origins": ["*"],
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }
    gradio_app.add_middleware(CORSMiddleware, **cors_settings)
    _register_jiggle_routes(gradio_app)
    return gradio_app


# Installed as a staticmethod so it is called with exactly the arguments the
# caller passes, with no implicit first argument.
gradio.routes.App.create_app = staticmethod(_patched_create_app)


# ── Gradio Blocks (UI shown at /) ───────────────────────────────────────────
with gr.Blocks(title="Jiggle Physics API") as demo:
    gr.Markdown(
        "## Jiggle Physics ML API\n"
        "Endpoints: `/jiggle/health` `/jiggle/segment` "
        "`/jiggle/depth` `/jiggle/reconstruct` `/jiggle/pose`"
    )


if __name__ == "__main__":
    # ssr_mode=False: keep Gradio's Node.js SSR layer off so GET requests
    # reach FastAPI directly.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "ssr_mode": False,
    }
    demo.launch(**launch_options)