# jiggle-physics / app.py
# Justin Wood
# Stringify landmark dict keys before pose response
# fa899d5
"""
Jiggle Physics Simulator — HuggingFace Space backend.
ZeroGPU requires demo.launch() (static @spaces.GPU detection isn't enough).
But demo.launch() creates a fresh FastAPI app inside Gradio's App.create_app,
discarding anything we registered on demo.app beforehand.
Solution: monkey-patch gradio.routes.App.create_app so we can inject our
/jiggle/* routes (and CORS middleware) into the *new* app the moment
Gradio creates it, before it starts handling requests.
"""
import base64
import collections
import functools
import io
import json
import sys
import threading
import time
from typing import Optional

import gradio as gr
import gradio.routes
import spaces
import numpy as np
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from PIL import Image, ImageOps
# ── In-process log ring buffer ───────────────────────────────────────────────
# Bounded history (newest 100 entries) of non-blank stdout writes, served by
# GET /jiggle/logs. Each entry is a dict with "id", "t" and "msg" keys.
_log_buffer: collections.deque = collections.deque((), 100)
# Id of the most recently recorded entry; 0 means nothing has been logged yet.
_log_cursor: int = 0
class _TeeStream:
    """Write-through stream wrapper that also records output in ``_log_buffer``.

    Every write is forwarded unchanged to the wrapped stream. Writes that
    contain non-whitespace text are additionally recorded, right-stripped, as
    ``{"id", "t", "msg"}`` entries whose ``id`` comes from the module-level
    ``_log_cursor`` counter (the pagination cursor used by ``/jiggle/logs``).
    Any other attribute access is delegated to the wrapped stream.
    """

    # FastAPI runs the plain ``def`` routes in a worker thread pool, so several
    # threads may print at once; the lock keeps id allocation and append order
    # consistent so ids stay unique and increasing in the buffer.
    _lock = threading.Lock()

    def __init__(self, real):
        self._real = real

    def write(self, s):
        """Forward *s* to the real stream and record it if it is not blank.

        Returns whatever the wrapped stream's ``write`` returned (normally the
        number of characters written), matching the ``io.TextIOBase`` contract.
        """
        global _log_cursor
        written = self._real.write(s)
        if s.strip():
            with self._lock:
                _log_cursor += 1
                _log_buffer.append(
                    {"id": _log_cursor, "t": time.time(), "msg": s.rstrip()}
                )
        return written

    def flush(self):
        self._real.flush()

    def __getattr__(self, name):
        # Only reached for attributes not defined here (encoding, fileno, ...).
        return getattr(self._real, name)
# Route all print() output through the tee: it still reaches the original
# stdout, and non-blank lines are also captured for /jiggle/logs.
sys.stdout = _TeeStream(real=sys.stdout)
def _load_image(upload: UploadFile, max_dim: int = 1024) -> Image.Image:
    """Decode an uploaded file into an RGB PIL image whose longest side <= *max_dim*.

    The EXIF orientation tag (if any) is applied first, so pixel axes match the
    way the photo is displayed. Images larger than *max_dim* are downscaled with
    LANCZOS, preserving aspect ratio; smaller images are returned at full size.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image. The
            routes call this outside their own try blocks, so without this a
            bad upload would surface as an opaque 500.
    """
    data = upload.file.read()
    try:
        img = Image.open(io.BytesIO(data))
        img = ImageOps.exif_transpose(img).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        # UnidentifiedImageError (unknown/empty data) and truncated files are
        # OSError subclasses.
        raise HTTPException(status_code=400, detail=f"Invalid image upload: {e}") from e

    longest = max(img.size)
    if longest > max_dim:
        ratio = max_dim / longest
        # Clamp to >= 1 px: a very thin image could otherwise round a side to 0,
        # which resize() rejects.
        new_size = (max(1, int(img.width * ratio)), max(1, int(img.height * ratio)))
        img = img.resize(new_size, Image.LANCZOS)
    return img
# ── GPU implementations (standalone so ZeroGPU can detect them) ──────────────
@spaces.GPU
def _gpu_segment(img: Image.Image, region_list: list, clicks) -> dict:
    """Run ``segmentation.segment_regions`` inside a ZeroGPU-decorated call.

    The segmentation module is imported lazily here rather than at module
    import time. The caller reads the result as ``{region: {"mask", "bbox"}}``.
    """
    from segmentation import segment_regions as _segment

    per_region = _segment(img, region_list, clicks)
    return per_region
@spaces.GPU
def _gpu_depth(img: Image.Image) -> dict:
    """Run ``depth.estimate_depth`` inside a ZeroGPU-decorated call.

    The depth module is imported lazily here. The caller reads the keys
    "depth", "width", "height", "min" and "max" from the returned dict.
    """
    from depth import estimate_depth as _estimate

    depth_info = _estimate(img)
    return depth_info
@spaces.GPU
def _gpu_reconstruct(
    img: Image.Image,
    mask: np.ndarray,
    bbox_list: list,
    use_triposr: bool,
) -> bytes:
    """Build mesh bytes for one masked region inside a ZeroGPU-decorated call.

    With ``use_triposr`` the region goes through
    ``reconstruction.reconstruct_region``. Otherwise a depth map is estimated
    for the whole image and turned into a mesh by
    ``reconstruction.depth_to_mesh``. Both helpers receive the mask as nested
    Python lists (``mask.tolist()``). The /jiggle/reconstruct route serves the
    returned bytes as the GLB payload.
    """
    from reconstruction import reconstruct_region, depth_to_mesh
    from depth import estimate_depth

    if not use_triposr:
        depth_map = estimate_depth(img)["depth"]
        return depth_to_mesh(depth_map, mask.tolist(), img)
    return reconstruct_region(img, mask.tolist(), bbox_list)
@spaces.GPU
def _gpu_pose(
    tpose_img: Image.Image,
    target_img: Image.Image,
    region_list: list,
) -> tuple:
    """Detect landmarks on both images and compute a transform per region.

    Returns:
        ``(transforms, tpose_landmarks, target_landmarks)``.
        ``transforms`` maps each requested region name to the result of
        ``pose.compute_region_transform(tpose_landmarks, target_landmarks, region)``.
    """
    from pose import detect_landmarks, compute_region_transform

    reference_landmarks = detect_landmarks(tpose_img)
    posed_landmarks = detect_landmarks(target_img)

    per_region = {}
    for name in region_list:
        per_region[name] = compute_region_transform(
            reference_landmarks, posed_landmarks, name
        )
    return per_region, reference_landmarks, posed_landmarks
# ── Route registration — called against the live app inside create_app ──────
def _parse_regions(regions: str) -> list:
    """Split a comma-separated region string, dropping blanks and whitespace."""
    return [r.strip() for r in regions.split(",") if r.strip()]


def _rle_encode(mask) -> tuple:
    """Run-length encode a boolean mask in row-major (C) order.

    Returns ``(runs, start)``: ``runs`` holds the lengths of consecutive equal
    values as Python ints, and ``start`` is the value of the first run. An
    empty mask encodes as ``([], False)``.
    """
    flat = np.asarray(mask, dtype=bool).ravel()
    if flat.size == 0:
        return [], False
    # A new run begins wherever a value differs from its predecessor.
    run_starts = np.flatnonzero(flat[1:] != flat[:-1]) + 1
    bounds = np.concatenate(([0], run_starts, [flat.size]))
    return np.diff(bounds).tolist(), bool(flat[0])


def _rle_decode(runs, shape, start: bool) -> np.ndarray:
    """Inverse of ``_rle_encode``: rebuild a boolean mask with the given shape.

    Raises:
        ValueError: if ``runs`` is not a flat list of non-negative integers or
            the decoded length does not fit ``shape``.
        TypeError: if ``runs``/``shape`` have unusable types.
    """
    lengths = np.asarray(runs, dtype=np.int64)
    if lengths.ndim != 1 or (lengths < 0).any():
        raise ValueError("mask_rle must be a flat list of non-negative run lengths")
    # Run i holds ``start`` for even i and ``not start`` for odd i.
    values = (np.arange(lengths.size) % 2 == 1) ^ bool(start)
    return np.repeat(values, lengths).reshape(shape)


def _register_jiggle_routes(app: FastAPI) -> None:
    """Attach the /jiggle/* API to *app* and move it ahead of Gradio's routes.

    Called from the patched ``App.create_app`` with the app Gradio just built,
    before that app serves any requests.
    """

    @app.get("/jiggle/health")
    def jiggle_health():
        """Report, per model, whether it is loaded in memory and cached on disk."""
        import os

        def _warm(module_name: str, cache_attr: str) -> bool:
            # Only look at modules that are already imported: a health probe
            # should not trigger importing (and initialising) a model module.
            try:
                mod = sys.modules.get(module_name)
                return mod is not None and getattr(mod, cache_attr, None) is not None
            except Exception:
                return False

        def _is_cached(model_id: str) -> bool:
            # Approximates huggingface_hub's cache location: HF_HUB_CACHE, else
            # $HF_HOME/hub, else ~/.cache/huggingface/hub.
            try:
                cache_root = os.environ.get("HF_HUB_CACHE") or os.path.join(
                    os.environ.get("HF_HOME", "~/.cache/huggingface"), "hub"
                )
                slug = "models--" + model_id.replace("/", "--")
                return os.path.isdir(os.path.join(os.path.expanduser(cache_root), slug))
            except Exception:
                return False

        try:
            models = {
                "sam2": {"warm": _warm("segmentation", "_sam2_cache"), "cached_on_disk": _is_cached("facebook/sam2-hiera-large")},
                "triposr": {"warm": _warm("reconstruction", "_triposr_cache"), "cached_on_disk": _is_cached("stabilityai/TripoSR")},
                "depth_pro": {"warm": _warm("depth", "_depth_cache"), "cached_on_disk": _is_cached("apple/DepthPro-hf")},
            }
            return {"status": "ok", "models": models}
        except Exception as e:
            return {"status": "error", "detail": str(e), "models": {}}

    @app.get("/jiggle/logs")
    def get_logs(since: int = 0):
        """Return buffered log lines with id > ``since`` plus the new cursor."""
        # Copy first: iterating the live deque while another thread prints would
        # raise "deque mutated during iteration".
        snapshot = list(_log_buffer)
        lines = [entry for entry in snapshot if entry["id"] > since]
        # Take the cursor from the snapshot itself, so a line logged after the
        # copy is not skipped by the client's next ``since``.
        cursor = snapshot[-1]["id"] if snapshot else _log_cursor
        return {"lines": lines, "cursor": cursor}

    @app.post("/jiggle/segment")
    def segment(
        image: UploadFile = File(...),
        regions: str = Form("breast_left,breast_right,buttocks"),
        click_points: Optional[str] = Form(None),
    ):
        """Segment the requested regions; each mask is returned run-length encoded."""
        img = _load_image(image)
        region_list = _parse_regions(regions)
        try:
            clicks = json.loads(click_points) if click_points else None
        except ValueError as e:
            raise HTTPException(status_code=400, detail=f"click_points is not valid JSON: {e}") from e
        try:
            result = _gpu_segment(img, region_list, clicks)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e

        encoded = {}
        for region, data in result.items():
            mask_arr = np.asarray(data["mask"], dtype=bool)
            rle, rle_start = _rle_encode(mask_arr)
            encoded[region] = {
                "rle": rle,
                "rle_start": rle_start,
                "shape": list(mask_arr.shape),
                "bbox": data["bbox"],
            }
        return {"regions": encoded, "image_size": [img.width, img.height]}

    @app.post("/jiggle/depth")
    def depth(image: UploadFile = File(...)):
        """Estimate depth; the map is returned as base64 little-endian float32."""
        img = _load_image(image)
        try:
            result = _gpu_depth(img)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e
        # Explicit little-endian float32 so the bytes match the advertised
        # "float32" dtype regardless of host byte order.
        arr = np.asarray(result["depth"], dtype="<f4")
        b64 = base64.b64encode(arr.tobytes()).decode("ascii")
        return {
            "depth_b64": b64,
            "width": result["width"],
            "height": result["height"],
            "min": result["min"],
            "max": result["max"],
            "dtype": "float32",
        }

    @app.post("/jiggle/reconstruct")
    def reconstruct(
        image: UploadFile = File(...),
        mask_rle: str = Form(...),
        mask_shape: str = Form(...),
        mask_rle_start: str = Form("false"),
        bbox: str = Form(...),
        use_triposr: str = Form("true"),
    ):
        """Rebuild the RLE mask and return a reconstructed mesh as binary GLB."""
        img = _load_image(image)
        try:
            rle = json.loads(mask_rle)
            shape = json.loads(mask_shape)
            bbox_list = json.loads(bbox)
            mask = _rle_decode(rle, shape, mask_rle_start.lower() == "true")
        except (ValueError, TypeError) as e:
            # JSONDecodeError is a ValueError; so are reshape/run-length errors.
            raise HTTPException(status_code=400, detail=f"Invalid mask or bbox: {e}") from e
        try:
            glb_bytes = _gpu_reconstruct(img, mask, bbox_list, use_triposr.lower() == "true")
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e
        return Response(content=glb_bytes, media_type="application/octet-stream")

    @app.post("/jiggle/pose")
    def pose_detect(
        tpose_image: UploadFile = File(...),
        target_image: UploadFile = File(...),
        regions: str = Form("breast_left,breast_right,buttocks"),
    ):
        """Compute per-region transforms from a T-pose image to a target image."""
        tpose_img = _load_image(tpose_image)
        target_img = _load_image(target_image)
        region_list = _parse_regions(regions)
        try:
            transforms, tpose_lm, target_lm = _gpu_pose(tpose_img, target_img, region_list)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e
        # orjson (Gradio's response renderer) rejects non-string dict keys —
        # detect_landmarks returns {int: {...}} so we stringify before returning.
        return {
            "transforms": transforms,
            "tpose_landmarks": {str(k): v for k, v in tpose_lm.items()},
            "target_landmarks": {str(k): v for k, v in target_lm.items()},
        }

    # Promote our routes ahead of Gradio's catch-all (GET /{path:path} SPA route):
    # Starlette matches routes in list order.
    def _is_jiggle(route) -> bool:
        return getattr(route, "path", "").startswith("/jiggle/")

    routes = app.router.routes
    ours = [r for r in routes if _is_jiggle(r)]
    others = [r for r in routes if not _is_jiggle(r)]
    app.router.routes = ours + others
    print(f"Injected {len(ours)} /jiggle/* routes into Gradio's App")
# ── Monkey-patch gradio.routes.App.create_app ───────────────────────────────
# Gradio's create_app builds the FastAPI app fresh each time demo.launch()
# runs. We wrap it to (1) add CORS middleware before the app starts handling
# requests, and (2) register our /jiggle/* routes on the same app instance
# Gradio is about to serve.
_original_create_app = gradio.routes.App.create_app
# If this module is executed again in the same process, create_app is already
# our wrapper. Unwrap it so CORS and the routes are not added twice.
_original_create_app = getattr(
    _original_create_app, "_jiggle_original", _original_create_app
)


@functools.wraps(_original_create_app)
def _patched_create_app(*args, **kwargs):
    """Call Gradio's create_app, then add CORS and the /jiggle/* routes."""
    app = _original_create_app(*args, **kwargs)
    # Wide-open CORS so browsers on any origin can call the API.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )
    _register_jiggle_routes(app)
    return app


# Remember the unpatched function for the idempotency check above.
_patched_create_app._jiggle_original = _original_create_app
# create_app is a staticmethod on App; re-wrap so it is not bound to instances.
gradio.routes.App.create_app = staticmethod(_patched_create_app)
# ── Gradio Blocks (UI shown at /) ───────────────────────────────────────────
# Minimal landing page listing the API. The Blocks object must exist because
# ZeroGPU requires demo.launch(); the real work happens in the /jiggle/* routes.
with gr.Blocks(title="Jiggle Physics API") as demo:
    gr.Markdown(
        "## Jiggle Physics ML API\n"
        "Endpoints: `/jiggle/health` `/jiggle/logs` `/jiggle/segment` "
        "`/jiggle/depth` `/jiggle/reconstruct` `/jiggle/pose`"
    )
if __name__ == "__main__":
    # Listen on every interface at port 7860. ssr_mode=False keeps Gradio's
    # Node.js SSR layer out of the request path, so GET requests (including
    # the /jiggle/* routes) reach FastAPI directly.
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        ssr_mode=False,
    )
    demo.launch(**launch_options)