Spaces:
Sleeping
Sleeping
| """HTTP service for the Princess dress-up game. | |
| Serves both the API (background removal, auto-segment) and the static | |
| frontend from a single FastAPI process. Everything is deployed together | |
| to a Hugging Face Docker Space — one commit, one cold start, no CORS. | |
| Design notes: | |
- Both the rembg session and the SAM pipeline are created once at
  import time. HF Spaces keeps the container warm between requests,
  so model weights stay resident in process memory. The first request
  after a cold start is slow (~45s with torch+SAM) because the
  container has to boot, which includes importing and loading both
  models; after that, no request pays a model-loading cost.
| - CORS stays wide open for now even though frontend and backend share | |
| an origin. It's harmless and lets you hit the API from a second | |
| client (curl, another deploy) without surprise. | |
| - We accept multipart form upload with a `file` field rather than raw | |
| bytes so the client can use standard FormData from the browser | |
| without any custom headers. | |
- StaticFiles is mounted LAST at "/", after every explicit API route.
  Starlette matches routes in the order they were registered, so
  `/remove-bg` hits the handler below rather than looking for a file
  named `remove-bg` in the static dir.
| """ | |
import asyncio
import base64
import io
import logging
import threading
from pathlib import Path

import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response
from fastapi.staticfiles import StaticFiles
from PIL import Image
from rembg import new_session, remove
from transformers import pipeline
# rembg model passed to new_session() below; also echoed by health().
MODEL_NAME = "isnet-general-use"
# Hugging Face model id for the transformers "mask-generation" pipeline.
SAM_MODEL = "Zigeng/SlimSAM-uniform-77"
# Configure root logging at import time so the INFO lines emitted while the
# models load (below) are actually printed.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("princess")
# 20 MB — plenty for a downscaled photo from the client (which already
# caps at ~1280 on the long edge). Rejecting bigger uploads protects us
# from accidental full-res uploads burning CPU on the free tier.
MAX_UPLOAD_BYTES = 20 * 1024 * 1024
# SAM is quadratically sensitive to input resolution. 512 on the long
# edge gives good-quality masks on kid drawings while keeping a full
# pass under ~5s on the free CPU tier.
SAM_INPUT_DIM = 512
# Masks are returned to the client at this resolution. Segment-review
# uses them for tap-hit-testing and polygon adjust, neither of which
# needs pixel-perfect alignment — 128x128 keeps the JSON payload small
# (~16KB per mask pre-base64: 1 byte per pixel, 128 * 128 = 16,384 bytes)
# without visibly hurting the overlay.
MASK_OUT_DIM = 128
# Area bounds, as fractions of the full-resolution image area: drop tiny
# masks (noise) and very large masks (full image / bg).
SAM_MIN_AREA_FRAC = 0.005
SAM_MAX_AREA_FRAC = 0.85
# Non-max suppression IoU threshold — masks overlapping more than this
# are treated as duplicates and the lower-scoring one gets dropped.
SAM_NMS_IOU = 0.7
# Static assets live next to app.py inside the container. The Dockerfile
# copies index.html, style.css, and js/ into /app alongside this file.
# NOTE(review): because this is the directory containing app.py itself,
# the static mount also serves app.py (and anything else copied next to
# it). Consider moving the frontend into a dedicated subdirectory.
STATIC_ROOT = Path(__file__).parent
# Both models are built once, at import time, and shared by every request.
# Importing this module therefore blocks until both are ready.
log.info("Loading rembg model: %s", MODEL_NAME)
_bg_session = new_session(MODEL_NAME)
log.info("rembg ready")
log.info("Loading SAM pipeline: %s", SAM_MODEL)
# device=-1 selects CPU for a transformers pipeline.
_sam_pipeline = pipeline("mask-generation", model=SAM_MODEL, device=-1)
log.info("SAM ready")
app = FastAPI(title="Princess", version="1.0.0")
# Open CORS policy (rationale in the module docstring). Credentials are not
# allowed cross-origin; methods are limited to GET, POST and the OPTIONS
# preflight.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["*"],
)
def health() -> JSONResponse:
    """Health check: report that the service is up and which rembg model it uses."""
    payload = {"status": "ok", "service": "princess", "bg_model": MODEL_NAME}
    return JSONResponse(payload)
async def remove_bg(file: UploadFile = File(...)) -> Response:
    """Remove the background from an uploaded image and return it as PNG.

    Expects a multipart upload in the ``file`` field. The response body is
    rembg's PNG output, sent with ``Cache-Control: no-store``.

    Raises:
        HTTPException(400): empty upload, or data PIL cannot identify as an
            image.
        HTTPException(413): upload larger than MAX_UPLOAD_BYTES.
        HTTPException(500): rembg raised while processing the image.
    """
    # Read at most one byte past the limit: enough to detect an oversized
    # upload without pulling an arbitrarily large file into memory.
    data = await file.read(MAX_UPLOAD_BYTES + 1)
    if not data:
        raise HTTPException(status_code=400, detail="empty upload")
    if len(data) > MAX_UPLOAD_BYTES:
        raise HTTPException(
            status_code=413,
            detail=f"upload too large (max {MAX_UPLOAD_BYTES} bytes)",
        )
    # Reject non-images with a 400 (as auto_segment does) instead of letting
    # rembg fail with a 500. Image.open is lazy: it only identifies the
    # format from the header and does not decode pixel data.
    try:
        with Image.open(io.BytesIO(data)):
            pass
    except Exception as err:  # noqa: BLE001
        raise HTTPException(status_code=400, detail=f"bad image: {err}") from err
    # rembg inference is synchronous and CPU-bound; run it in a worker thread
    # so it doesn't stall the event loop (and with it every other request).
    try:
        output_bytes = await asyncio.to_thread(remove, data, session=_bg_session)
    except Exception as err:  # noqa: BLE001 — we want to surface anything
        log.exception("rembg failed")
        raise HTTPException(status_code=500, detail=f"rembg failed: {err}") from err
    # rembg returns a PNG (bytes) with the alpha channel already applied.
    return Response(
        content=output_bytes,
        media_type="image/png",
        headers={"Cache-Control": "no-store"},
    )
def _resize_long_edge(img: Image.Image, target: int) -> Image.Image:
    """Downscale ``img`` so its long edge is at most ``target`` pixels.

    Aspect ratio is preserved and Lanczos resampling is used. An image that
    already fits is returned unchanged (the same object, not a copy).

    Each output side is clamped to at least 1 px. Without the clamp, extreme
    aspect ratios (e.g. a 5000x2 strip scaled to 512) would round the short
    side down to 0, and PIL's resize rejects that.
    """
    w, h = img.size
    long_edge = max(w, h)
    if long_edge <= target:
        return img
    ratio = target / long_edge
    new_size = (max(1, round(w * ratio)), max(1, round(h * ratio)))
    return img.resize(new_size, Image.LANCZOS)
| def _downsample_mask(mask: np.ndarray, out_w: int, out_h: int) -> np.ndarray: | |
| """Nearest-neighbor downsample a bool mask to (out_h, out_w). | |
| We don't need anti-aliasing — segment-review just hit-tests pixels | |
| and draws them as rects. | |
| """ | |
| h, w = mask.shape | |
| ys = (np.linspace(0, h - 1, out_h)).astype(np.int32) | |
| xs = (np.linspace(0, w - 1, out_w)).astype(np.int32) | |
| return mask[ys[:, None], xs[None, :]] | |
| def _mask_iou(a: np.ndarray, b: np.ndarray) -> float: | |
| inter = np.logical_and(a, b).sum() | |
| union = np.logical_or(a, b).sum() | |
| return float(inter) / float(union) if union else 0.0 | |
def _nms(masks, scores, iou_thresh: float):
    """Greedy non-max suppression over masks.

    Candidates are visited from highest to lowest score. A candidate is kept
    only if its IoU with every already-kept mask is at most ``iou_thresh``.

    Args:
        masks: sequence of (small) bool arrays, index-aligned with ``scores``.
        scores: 1-D array-like of per-mask scores.
        iou_thresh: IoU above which a candidate counts as a duplicate.

    Returns:
        Kept indices, as Python ints, in descending score order.
    """
    survivors: list[int] = []
    for cand in np.argsort(scores)[::-1]:
        duplicate = any(
            _mask_iou(masks[cand], masks[prev]) > iou_thresh for prev in survivors
        )
        if not duplicate:
            survivors.append(int(cand))
    return survivors
def _crop_with_mask(
    rgba: Image.Image,
    mask_full: np.ndarray,
    pad: int = 4,
) -> tuple[bytes, dict, int]:
    """Crop ``rgba`` to the mask's padded bbox and apply the mask as alpha.

    ``mask_full`` must already be at ``rgba``'s resolution (shape ``(H, W)``).
    This function does not resample it. Callers upsample low-res SAM masks
    first, so the crop comes from the full-resolution image and stays sharp.

    Args:
        rgba: source image; the crop is converted to RGBA.
        mask_full: (H, W) bool mask aligned with ``rgba``.
        pad: margin in pixels added around the mask's bbox, clamped to the
            image bounds.

    Returns:
        ``(png_bytes, bbox, pixel_area)``:

        - ``png_bytes``: the masked crop encoded as PNG.
        - ``bbox``: the padded crop box as fractions of W/H (keys x, y, w, h).
        - ``pixel_area``: the number of True pixels in ``mask_full``.

    Raises:
        ValueError: if the mask has no True pixels.
    """
    H, W = mask_full.shape
    ys, xs = np.nonzero(mask_full)
    if ys.size == 0:
        raise ValueError("empty mask")
    y0 = max(0, int(ys.min()) - pad)
    y1 = min(H, int(ys.max()) + 1 + pad)
    x0 = max(0, int(xs.min()) - pad)
    x1 = min(W, int(xs.max()) + 1 + pad)
    cropped = rgba.crop((x0, y0, x1, y1)).convert("RGBA")
    crop_mask = mask_full[y0:y1, x0:x1]
    arr = np.array(cropped, dtype=np.uint8)
    # Keep the source alpha inside the mask; make everything else transparent.
    arr[..., 3] = np.where(crop_mask, arr[..., 3], 0)
    # fromarray infers "RGBA" from an (h, w, 4) uint8 array. Passing mode=
    # explicitly is deprecated in recent Pillow releases.
    out_img = Image.fromarray(arr)
    buf = io.BytesIO()
    out_img.save(buf, format="PNG", optimize=False)
    return (
        buf.getvalue(),
        {
            "x": x0 / W,
            "y": y0 / H,
            "w": (x1 - x0) / W,
            "h": (y1 - y0) / H,
        },
        int(ys.size),
    )
# Serializes SAM passes. Inference now runs in worker threads instead of on
# the event loop, so two requests could otherwise drive the shared pipeline
# at the same time. We don't rely on transformers pipelines being
# thread-safe, and on the CPU tier parallel passes would only compete for
# the same cores anyway.
_SAM_LOCK = threading.Lock()


async def auto_segment(file: UploadFile = File(...)) -> JSONResponse:
    """Split an uploaded drawing into candidate parts using SAM.

    Expects a multipart upload in the ``file`` field. Returns
    ``{"segments": [...]}``, sorted largest area first. Each segment dict has:

    - ``id``: identifier string.
    - ``score``: SAM's score for the mask.
    - ``area``: full-resolution pixel area.
    - ``bbox``: padded crop box as fractions of the image size.
    - ``maskW`` / ``maskH``: dimensions of ``mask``.
    - ``mask``: the base64-encoded bool mask, 1 byte per pixel.
    - ``croppedPng``: base64 PNG of the crop, transparent outside the mask.

    Raises:
        HTTPException(400): empty upload or undecodable image.
        HTTPException(413): upload larger than MAX_UPLOAD_BYTES.
        HTTPException(500): the SAM pipeline raised.
    """
    # Read at most one byte past the limit so an oversized upload is detected
    # without pulling an arbitrarily large file into memory.
    data = await file.read(MAX_UPLOAD_BYTES + 1)
    if not data:
        raise HTTPException(status_code=400, detail="empty upload")
    if len(data) > MAX_UPLOAD_BYTES:
        raise HTTPException(
            status_code=413,
            detail=f"upload too large (max {MAX_UPLOAD_BYTES} bytes)",
        )
    # Decoding, SAM inference and PNG encoding are synchronous and CPU-bound.
    # Run them in a worker thread so the event loop stays free for other
    # requests (static files, health checks, /remove-bg).
    segments = await asyncio.to_thread(_segment_upload, data)
    return JSONResponse({"segments": segments})


def _segment_upload(data: bytes) -> list[dict]:
    """Decode uploaded image bytes and return the serialized segments.

    This is synchronous and CPU-heavy; auto_segment runs it off the event
    loop. Returns an empty list when SAM yields nothing usable.

    Raises:
        HTTPException(400): ``data`` is not a decodable image.
        HTTPException(500): the SAM pipeline raised.
    """
    try:
        src = Image.open(io.BytesIO(data)).convert("RGBA")
    except Exception as err:  # noqa: BLE001
        raise HTTPException(status_code=400, detail=f"bad image: {err}") from err
    # SAM wants a 3-channel image; we pass the downscaled RGB view.
    work_rgb = _resize_long_edge(src, SAM_INPUT_DIM).convert("RGB")
    raw_masks, raw_scores = _run_sam(work_rgb)
    # len() rather than truthiness, so this also works if masks ever come
    # back as an array/tensor instead of a list.
    if len(raw_masks) == 0:
        return []
    # The pipeline returns masks at the *input* (downscaled) resolution.
    # Upsample them to full-res once so crops are sharp, and keep a
    # small copy for NMS + client-side hit-testing.
    candidates = _mask_candidates(raw_masks, raw_scores, work_rgb.size, src.size)
    if not candidates:
        return []
    scores = np.array([c["score"] for c in candidates], dtype=np.float32)
    kept_idx = _nms([c["small_mask"] for c in candidates], scores, SAM_NMS_IOU)
    segments = _encode_segments(src, candidates, kept_idx)
    # Largest first — matches segment-review's princess-selection heuristic.
    segments.sort(key=lambda s: s["area"], reverse=True)
    return segments


def _run_sam(work_rgb: Image.Image) -> tuple:
    """Run SAM automatic mask generation on ``work_rgb``.

    Returns ``(masks, scores)`` taken from the pipeline output. Each defaults
    to an empty list when its key is missing.

    Raises:
        HTTPException(500): the pipeline raised.
    """
    # points_per_side=8 → 64 grid points. Default 32 gives 1024 which
    # is ~16× slower and overkill for a drawing with maybe 5-10 parts.
    try:
        with _SAM_LOCK:
            sam_out = _sam_pipeline(
                work_rgb,
                points_per_side=8,
                pred_iou_thresh=0.85,
                stability_score_thresh=0.85,
            )
    except Exception as err:  # noqa: BLE001
        log.exception("SAM failed")
        raise HTTPException(status_code=500, detail=f"sam failed: {err}") from err
    return sam_out.get("masks", []), sam_out.get("scores", [])


def _mask_candidates(raw_masks, raw_scores, work_size, full_size) -> list[dict]:
    """Turn raw SAM masks into area-filtered full-res and small-mask candidates.

    Args:
        raw_masks: pipeline masks. They are expected at ``work_size``; masks
            already at ``full_size`` are used as-is, and any other shape is
            skipped.
        raw_scores: per-mask scores, index-aligned with ``raw_masks``.
        work_size: (width, height) of the image SAM saw.
        full_size: (width, height) of the decoded upload.

    Returns:
        One dict per surviving mask, with keys ``full_mask`` (bool, full
        res), ``small_mask`` (bool, MASK_OUT_DIM x MASK_OUT_DIM), ``score``
        and ``area`` (full-res pixels). A mask survives only if its area lies
        within the SAM_MIN/MAX_AREA_FRAC band.
    """
    work_w, work_h = work_size
    full_w, full_h = full_size
    min_area_px = int(full_w * full_h * SAM_MIN_AREA_FRAC)
    max_area_px = int(full_w * full_h * SAM_MAX_AREA_FRAC)
    # Nearest-neighbor lookup from each full-res row/col back to the working
    # resolution. It is identical for every mask, so build it once.
    row_src = np.linspace(0, work_h - 1, full_h).astype(np.int32)
    col_src = np.linspace(0, work_w - 1, full_w).astype(np.int32)
    candidates = []
    for mask, score in zip(raw_masks, raw_scores):
        mask_arr = np.asarray(mask, dtype=bool)
        if mask_arr.shape == (work_h, work_w):
            full_mask = mask_arr[row_src[:, None], col_src[None, :]]
        elif mask_arr.shape == (full_h, full_w):
            # Some pipelines return (H, W) at the *original* size; handle both.
            full_mask = mask_arr
        else:
            continue
        area = int(full_mask.sum())
        if area < min_area_px or area > max_area_px:
            continue
        candidates.append(
            {
                "full_mask": full_mask,
                "small_mask": _downsample_mask(full_mask, MASK_OUT_DIM, MASK_OUT_DIM),
                "score": float(score),
                "area": area,
            }
        )
    return candidates


def _encode_segments(
    src: Image.Image,
    candidates: list[dict],
    kept_idx: list[int],
) -> list[dict]:
    """Build JSON-ready segment dicts for the NMS survivors, in ``kept_idx`` order.

    A survivor whose crop fails is logged and skipped. Ids use the survivor's
    position in ``kept_idx``, so a skipped crop leaves a gap in the numbering.
    """
    segments = []
    for seg_i, idx in enumerate(kept_idx):
        c = candidates[idx]
        try:
            png_bytes, bbox, _ = _crop_with_mask(src, c["full_mask"])
        except Exception as err:  # noqa: BLE001
            log.warning("crop failed for seg %d: %s", seg_i, err)
            continue
        segments.append(
            {
                "id": f"seg-{seg_i}",
                "score": c["score"],
                "area": c["area"],
                "bbox": bbox,
                "maskW": MASK_OUT_DIM,
                "maskH": MASK_OUT_DIM,
                # Pack bool mask as 1 byte per pixel, base64 for JSON transport.
                "mask": base64.b64encode(
                    c["small_mask"].astype(np.uint8).tobytes()
                ).decode("ascii"),
                "croppedPng": base64.b64encode(png_bytes).decode("ascii"),
            }
        )
    return segments
# Mount last so API routes above take precedence. Starlette matches routes
# in registration order, so every route registered before this line wins,
# and this catch-all mount handles all remaining paths. html=True makes "/"
# serve index.html automatically.
# NOTE(review): STATIC_ROOT is the directory containing this file, so app.py
# (and anything else copied next to it) is downloadable through this mount.
app.mount("/", StaticFiles(directory=STATIC_ROOT, html=True), name="static")