Spaces:
Sleeping
Sleeping
| """ | |
| Jiggle Physics Simulator — HuggingFace Space backend. | |
| ZeroGPU requires demo.launch() (static @spaces.GPU detection isn't enough). | |
| But demo.launch() creates a fresh FastAPI app inside Gradio's App.create_app, | |
| discarding anything we registered on demo.app beforehand. | |
| Solution: monkey-patch gradio.routes.App.create_app so we can inject our | |
| /jiggle/* routes (and CORS middleware) into the *new* app the moment | |
| Gradio creates it, before it starts handling requests. | |
| """ | |
| import base64 | |
| import collections | |
| import io | |
| import json | |
| import sys | |
| import time | |
| from typing import Optional | |
| import gradio as gr | |
| import gradio.routes | |
| import spaces | |
| import numpy as np | |
| from fastapi import FastAPI, File, Form, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import Response | |
| from PIL import Image | |
| # ── In-process log ring buffer ─────────────────────────────────────────────── | |
# Bounded history of captured stdout writes: once 100 entries are held, each
# new append drops the oldest one.
_log_buffer: collections.deque = collections.deque([], 100)
# Id given to the most recently captured entry (0 until something is captured).
_log_cursor = 0
class _TeeStream:
    """Stream wrapper that forwards writes and records them in ``_log_buffer``.

    Every write goes to the wrapped stream unchanged. Writes that contain
    something other than whitespace are additionally appended to the
    module-level ``_log_buffer`` under a fresh id taken from ``_log_cursor``.
    Attributes this class does not define are looked up on the wrapped stream.
    """

    def __init__(self, real):
        self._real = real

    def write(self, s):
        """Write *s* to the wrapped stream and record it when it is not blank.

        Returns:
            Whatever the wrapped stream's ``write`` returned (for text streams
            the number of characters written), so callers that check the
            result behave as they would with the unwrapped stream.
        """
        global _log_cursor
        written = self._real.write(s)
        if s.strip():
            _log_cursor += 1
            _log_buffer.append(
                {"id": _log_cursor, "t": time.time(), "msg": s.rstrip()}
            )
        return written

    def flush(self):
        """Flush the wrapped stream."""
        self._real.flush()

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If ``_real`` itself
        # is missing (instance created without __init__), delegating would
        # re-enter this method forever, so fail with a plain AttributeError.
        if name == "_real":
            raise AttributeError(name)
        return getattr(self._real, name)
# Route everything written to stdout through the tee: the previous stream
# still receives the output, and non-blank writes are kept in _log_buffer.
sys.stdout = _TeeStream(sys.stdout)
def _load_image(upload: UploadFile) -> Image.Image:
    """Decode an uploaded file into an RGB image of at most 1024 px per side.

    Images whose longer side exceeds 1024 px are downscaled with LANCZOS
    resampling, keeping the aspect ratio; smaller images are returned at
    their original size.

    Args:
        upload: The uploaded file; its whole content is read into memory.

    Returns:
        The decoded image converted to RGB.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    data = upload.file.read()
    try:
        img = Image.open(io.BytesIO(data)).convert("RGB")
    except OSError as exc:
        # Pillow reports unreadable or truncated data with OSError subclasses;
        # that is a client error, not a server fault.
        raise HTTPException(
            status_code=400, detail=f"Could not decode image: {exc}"
        ) from exc
    max_dim = 1024
    longest = max(img.size)
    if longest > max_dim:
        ratio = max_dim / longest
        # Clamp to 1 px: for very elongated images the short side would
        # otherwise truncate to 0, which resize() rejects.
        new_size = (
            max(1, int(img.width * ratio)),
            max(1, int(img.height * ratio)),
        )
        img = img.resize(new_size, Image.LANCZOS)
    return img
| # ── GPU implementations (standalone so ZeroGPU can detect them) ────────────── | |
# NOTE(review): the decorator was absent in this copy of the file although the
# section is described as ZeroGPU-detectable; the default GPU duration is
# assumed — confirm whether a longer `duration=` is needed.
@spaces.GPU
def _gpu_segment(img: Image.Image, region_list: list, clicks) -> dict:
    """Run ``segmentation.segment_regions`` inside a ZeroGPU allocation.

    The ``segmentation`` module is imported at call time rather than at
    module import.

    Args:
        img: Image to segment.
        region_list: Region names to segment.
        clicks: Optional click data, passed through unchanged (may be None).

    Returns:
        The value returned by ``segment_regions``.
    """
    from segmentation import segment_regions

    return segment_regions(img, region_list, clicks)
# NOTE(review): the decorator was absent in this copy of the file although the
# section is described as ZeroGPU-detectable; the default GPU duration is
# assumed — confirm whether a longer `duration=` is needed.
@spaces.GPU
def _gpu_depth(img: Image.Image) -> dict:
    """Run ``depth.estimate_depth`` inside a ZeroGPU allocation.

    The ``depth`` module is imported at call time rather than at module
    import.

    Args:
        img: Image to estimate depth for.

    Returns:
        The value returned by ``estimate_depth``.
    """
    from depth import estimate_depth

    return estimate_depth(img)
# NOTE(review): the decorator was absent in this copy of the file although the
# section is described as ZeroGPU-detectable; the default GPU duration is
# assumed — confirm whether a longer `duration=` is needed.
@spaces.GPU
def _gpu_reconstruct(
    img: Image.Image,
    mask: np.ndarray,
    bbox_list: list,
    use_triposr: bool,
) -> bytes:
    """Build a mesh for the masked region inside a ZeroGPU allocation.

    Args:
        img: Source image.
        mask: Region mask; it is converted to nested lists before being
            handed to the reconstruction helpers.
        bbox_list: Bounding box forwarded to ``reconstruct_region``; unused
            on the depth fallback path.
        use_triposr: When true, ``reconstruction.reconstruct_region`` is
            used. Otherwise a depth map is estimated with
            ``depth.estimate_depth`` and converted with
            ``reconstruction.depth_to_mesh``.

    Returns:
        The value returned by the selected reconstruction helper.
    """
    from reconstruction import reconstruct_region, depth_to_mesh
    from depth import estimate_depth

    mask_rows = mask.tolist()
    if use_triposr:
        return reconstruct_region(img, mask_rows, bbox_list)
    depth_result = estimate_depth(img)
    return depth_to_mesh(depth_result["depth"], mask_rows, img)
# NOTE(review): the decorator was absent in this copy of the file although the
# section is described as ZeroGPU-detectable; the default GPU duration is
# assumed — confirm whether a longer `duration=` is needed.
@spaces.GPU
def _gpu_pose(
    tpose_img: Image.Image,
    target_img: Image.Image,
    region_list: list,
) -> tuple:
    """Detect landmarks in both images and derive one transform per region.

    Runs inside a ZeroGPU allocation; the ``pose`` module is imported at call
    time rather than at module import.

    Args:
        tpose_img: Reference image.
        target_img: Image whose pose is compared against the reference.
        region_list: Region names to compute a transform for.

    Returns:
        ``(transforms, tpose_landmarks, target_landmarks)`` where
        ``transforms`` maps each region name to the result of
        ``compute_region_transform``.
    """
    from pose import detect_landmarks, compute_region_transform

    tpose_lm = detect_landmarks(tpose_img)
    target_lm = detect_landmarks(target_img)
    transforms = {
        region: compute_region_transform(tpose_lm, target_lm, region)
        for region in region_list
    }
    return transforms, tpose_lm, target_lm
| # ── Route registration — called against the live app inside create_app ────── | |
def _rle_encode(mask_arr: np.ndarray) -> tuple:
    """Run-length encode a boolean mask in flattened (row-major) order.

    Returns:
        ``(runs, start)``: the list of consecutive run lengths and the value
        of the first run. An empty mask yields ``([], False)``.
    """
    flat = mask_arr.ravel()
    if flat.size == 0:
        return [], False
    # Every index whose value differs from its predecessor starts a new run.
    run_starts = np.flatnonzero(flat[1:] != flat[:-1]) + 1
    edges = np.concatenate(([0], run_starts, [flat.size]))
    return np.diff(edges).tolist(), bool(flat[0])


def _rle_decode(runs, shape, start: bool) -> np.ndarray:
    """Rebuild a boolean mask from run lengths produced by ``_rle_encode``.

    Raises:
        ValueError: if *runs* or *shape* is not a list of non-negative
            integers, or the runs do not add up to the size implied by
            *shape*.
    """
    if not isinstance(runs, list) or not all(
        isinstance(r, int) and r >= 0 for r in runs
    ):
        raise ValueError("mask_rle must be a JSON list of non-negative integers")
    if not isinstance(shape, list) or not all(
        isinstance(d, int) and d >= 0 for d in shape
    ):
        raise ValueError("mask_shape must be a JSON list of non-negative integers")
    expected = 1
    for dim in shape:
        expected *= dim
    # Checked before allocating, so a mismatching payload is rejected cheaply.
    if sum(runs) != expected:
        raise ValueError("mask_rle run lengths do not match mask_shape")
    # NOTE(review): the total size is still client-controlled; consider
    # bounding it (e.g. against the uploaded image's dimensions).
    values = np.zeros(len(runs), dtype=bool)
    values[0::2] = start
    values[1::2] = not start
    return np.repeat(values, np.asarray(runs, dtype=np.intp)).reshape(shape)


def _register_jiggle_routes(app: FastAPI) -> None:
    """Attach the ``/jiggle/*`` endpoints to *app* and move them to the front.

    After registration the app's route list is reordered so every route whose
    path starts with ``/jiggle/`` precedes all others, keeping them ahead of
    Gradio's catch-all route.

    The health, segment, depth, reconstruct and pose paths are the ones
    advertised in this file's Gradio markdown. NOTE(review): the handlers
    carried no route decorators in this copy of the file; ``/jiggle/logs``
    and the HTTP methods (GET for the two read-only handlers, POST for the
    upload handlers) are assumed — confirm against the client.
    """

    @app.get("/jiggle/health")
    def jiggle_health():
        """Report, per model, whether it is loaded and whether it is on disk."""

        def _warm(module_name: str, cache_attr: str) -> bool:
            # Best-effort probe: any import or attribute failure counts as
            # "not warm".
            try:
                import importlib

                mod = importlib.import_module(module_name)
                return getattr(mod, cache_attr) is not None
            except Exception:
                return False

        def _is_cached(model_id: str) -> bool:
            # Best-effort probe of the Hugging Face hub cache directory.
            try:
                import os

                # Honour the hub's cache overrides before the default path.
                cache_root = os.environ.get("HF_HUB_CACHE") or os.path.join(
                    os.environ.get("HF_HOME") or "~/.cache/huggingface", "hub"
                )
                slug = "models--" + model_id.replace("/", "--")
                return os.path.isdir(
                    os.path.join(os.path.expanduser(cache_root), slug)
                )
            except Exception:
                return False

        # name -> (module, cache attribute, hub repo id)
        specs = {
            "sam2": ("segmentation", "_sam2_cache", "facebook/sam2-hiera-large"),
            "triposr": ("reconstruction", "_triposr_cache", "stabilityai/TripoSR"),
            "depth_pro": ("depth", "_depth_cache", "apple/DepthPro-hf"),
        }
        models = {
            name: {"warm": _warm(module, attr), "cached_on_disk": _is_cached(repo)}
            for name, (module, attr, repo) in specs.items()
        }
        return {"status": "ok", "models": models}

    @app.get("/jiggle/logs")
    def get_logs(since: int = 0):
        """Return buffered log entries whose id is greater than *since*."""
        # Copy first so a concurrent stdout write cannot mutate the deque
        # while it is being filtered.
        snapshot = list(_log_buffer)
        lines = [entry for entry in snapshot if entry["id"] > since]
        return {"lines": lines, "cursor": _log_cursor}

    @app.post("/jiggle/segment")
    def segment(
        image: UploadFile = File(...),
        regions: str = Form("breast_left,breast_right,buttocks"),
        click_points: Optional[str] = Form(None),
    ):
        """Segment the requested regions and return each mask run-length encoded."""
        img = _load_image(image)
        region_list = [r.strip() for r in regions.split(",") if r.strip()]
        try:
            clicks = json.loads(click_points) if click_points else None
        except ValueError as exc:
            raise HTTPException(
                status_code=400, detail=f"Invalid click_points JSON: {exc}"
            ) from exc
        try:
            result = _gpu_segment(img, region_list, clicks)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e
        encoded = {}
        for region, data in result.items():
            mask_arr = np.array(data["mask"], dtype=bool)
            rle, rle_start = _rle_encode(mask_arr)
            encoded[region] = {
                "rle": rle,
                "rle_start": rle_start,
                "shape": list(mask_arr.shape),
                "bbox": data["bbox"],
            }
        return {"regions": encoded, "image_size": [img.width, img.height]}

    @app.post("/jiggle/depth")
    def depth(image: UploadFile = File(...)):
        """Estimate depth and return it as base64-encoded float32 bytes."""
        img = _load_image(image)
        try:
            result = _gpu_depth(img)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e
        arr = np.array(result["depth"], dtype=np.float32)
        b64 = base64.b64encode(arr.tobytes()).decode("ascii")
        return {
            "depth_b64": b64,
            "width": result["width"],
            "height": result["height"],
            "min": result["min"],
            "max": result["max"],
            "dtype": "float32",
        }

    @app.post("/jiggle/reconstruct")
    def reconstruct(
        image: UploadFile = File(...),
        mask_rle: str = Form(...),
        mask_shape: str = Form(...),
        mask_rle_start: str = Form("false"),
        bbox: str = Form(...),
        use_triposr: str = Form("true"),
    ):
        """Decode the RLE mask, reconstruct the region and return the mesh bytes."""
        img = _load_image(image)
        start_val = mask_rle_start.lower() == "true"
        try:
            mask = _rle_decode(
                json.loads(mask_rle), json.loads(mask_shape), start_val
            )
            bbox_list = json.loads(bbox)
        except (ValueError, TypeError, OverflowError) as exc:
            # Malformed JSON or an inconsistent mask is a client error.
            raise HTTPException(
                status_code=400, detail=f"Invalid mask payload: {exc}"
            ) from exc
        try:
            glb_bytes = _gpu_reconstruct(
                img, mask, bbox_list, use_triposr.lower() == "true"
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e
        return Response(content=glb_bytes, media_type="application/octet-stream")

    @app.post("/jiggle/pose")
    def pose_detect(
        tpose_image: UploadFile = File(...),
        target_image: UploadFile = File(...),
        regions: str = Form("breast_left,breast_right,buttocks"),
    ):
        """Return per-region transforms plus the landmarks of both images."""
        tpose_img = _load_image(tpose_image)
        target_img = _load_image(target_image)
        region_list = [r.strip() for r in regions.split(",") if r.strip()]
        try:
            transforms, tpose_lm, target_lm = _gpu_pose(
                tpose_img, target_img, region_list
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e
        # orjson (Gradio's response renderer) rejects non-string dict keys —
        # detect_landmarks returns {int: {...}} so we stringify before returning.
        return {
            "transforms": transforms,
            "tpose_landmarks": {str(k): v for k, v in tpose_lm.items()},
            "target_landmarks": {str(k): v for k, v in target_lm.items()},
        }

    # Promote our routes ahead of Gradio's catch-all (GET /{path:path} SPA route)
    routes = app.router.routes
    ours = [r for r in routes if getattr(r, "path", "").startswith("/jiggle/")]
    others = [r for r in routes if not getattr(r, "path", "").startswith("/jiggle/")]
    app.router.routes = ours + others
    print(f"Injected {len(ours)} /jiggle/* routes into Gradio's App")
| # ── Monkey-patch gradio.routes.App.create_app ─────────────────────────────── | |
| # Gradio's create_app builds the FastAPI app fresh each time demo.launch() | |
| # runs. We wrap it to (1) add CORS middleware before the app starts handling | |
| # requests, and (2) register our /jiggle/* routes on the same app instance | |
| # Gradio is about to serve. | |
# Keep a handle on Gradio's own factory so the wrapper can delegate to it.
_original_create_app = gradio.routes.App.create_app


def _patched_create_app(*args, **kwargs):
    """Build the app with Gradio's factory, then add CORS and /jiggle/* routes.

    All arguments are forwarded untouched to the original ``create_app``. The
    app it returns gets a wide-open CORS middleware (any origin, method and
    header) and the jiggle routes, and is then returned to the caller.
    """
    gradio_app = _original_create_app(*args, **kwargs)
    cors_options = {
        "allow_origins": ["*"],
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }
    gradio_app.add_middleware(CORSMiddleware, **cors_options)
    _register_jiggle_routes(gradio_app)
    return gradio_app


# Install the wrapper in place of Gradio's factory.
gradio.routes.App.create_app = staticmethod(_patched_create_app)
| # ── Gradio Blocks (UI shown at /) ─────────────────────────────────────────── | |
# Minimal landing page: a single Markdown panel listing the API endpoints.
demo = gr.Blocks(title="Jiggle Physics API")
with demo:
    gr.Markdown(
        "## Jiggle Physics ML API\n"
        "Endpoints: `/jiggle/health` `/jiggle/segment` `/jiggle/depth` "
        "`/jiggle/reconstruct` `/jiggle/pose`"
    )
if __name__ == "__main__":
    # ssr_mode=False: keep Gradio's Node.js SSR layer off so GET requests
    # reach FastAPI directly.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "ssr_mode": False,
    }
    demo.launch(**launch_options)