"""Princess dress-up game: HTTP backend and static frontend in one process.

A single FastAPI app serves the API (background removal and automatic
segmentation) next to the static frontend. Everything ships together as one
Hugging Face Docker Space: one commit, one cold start, no cross-origin calls.

Design notes:

* The rembg session and the SAM pipeline are each built exactly once, while
  this module is being imported. HF Spaces keeps the container warm between
  requests, so the model weights stay resident in process memory. The slow
  first request after a cold start is spent booting the container
  (~45s with torch+SAM), not loading models per request.
* CORS is left wide open even though frontend and backend share an origin.
  It costs nothing and lets a second client (curl, another deploy) call the
  API without surprises.
* Uploads arrive as multipart form data in a ``file`` field instead of raw
  bytes, so the browser can send a plain ``FormData`` with no custom headers.
* StaticFiles is mounted LAST at "/", after every explicit API route. Routes
  are matched in registration order, so ``/remove-bg`` reaches its handler
  instead of being looked up as a file called ``remove-bg`` in the static
  directory.
"""

import base64
import io
import logging
from pathlib import Path

import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response
from fastapi.staticfiles import StaticFiles
from PIL import Image
from rembg import new_session, remove
from transformers import pipeline

# Model identifiers: the rembg background-removal model, and the Hugging Face
# checkpoint handed to the SAM "mask-generation" pipeline.
MODEL_NAME = "isnet-general-use"
SAM_MODEL = "Zigeng/SlimSAM-uniform-77"

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("princess")

# Upload ceiling of 20 MiB. The client already shrinks photos to roughly
# 1280 px on the long edge, so real uploads sit far below this. The cap keeps
# an accidental full-resolution upload from burning CPU on the free tier.
MAX_UPLOAD_BYTES = 20 * 2**20

# SAM is quadratically sensitive to input resolution.
# 512 px on the long edge gives good-quality masks on kid drawings while
# keeping a full pass under ~5s on the free CPU tier.
SAM_INPUT_DIM = 512

# Masks are returned to the client at this resolution. Segment-review
# uses them for tap-hit-testing and polygon adjust, neither of which
# needs pixel-perfect alignment — 128x128 keeps the JSON payload small
# (~16KB per mask pre-base64) without visibly hurting the overlay.
MASK_OUT_DIM = 128

# Drop tiny masks (noise) and very large masks (full image / bg). Both are
# fractions of the full-resolution image area.
SAM_MIN_AREA_FRAC = 0.005
SAM_MAX_AREA_FRAC = 0.85

# Non-max suppression IoU threshold — masks overlapping more than this
# are treated as duplicates and the lower-scoring one gets dropped.
SAM_NMS_IOU = 0.7

# Static assets live next to app.py inside the container. The Dockerfile
# copies index.html, style.css, and js/ into /app alongside this file.
# NOTE(review): because this is the application directory itself, the "/"
# mount at the bottom of this file also serves app.py (and anything else
# copied into /app) over HTTP. Consider a dedicated frontend subdirectory.
STATIC_ROOT = Path(__file__).parent

# Both models are created once at import time and shared by all requests.
log.info("Loading rembg model: %s", MODEL_NAME)
_bg_session = new_session(MODEL_NAME)
log.info("rembg ready")

log.info("Loading SAM pipeline: %s", SAM_MODEL)
_sam_pipeline = pipeline("mask-generation", model=SAM_MODEL, device=-1)
log.info("SAM ready")

app = FastAPI(title="Princess", version="1.0.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["*"],
)


async def _run_blocking(func, /, *args, **kwargs):
    """Await ``func(*args, **kwargs)`` executed in a worker thread.

    The endpoints below are ``async def``, so FastAPI runs them on the event
    loop itself. Calling rembg, SAM, or PIL decoding inline would freeze
    every other request (health checks, static files) for the whole
    CPU-bound pass. Exceptions raised by ``func`` propagate to the awaiter.
    """
    import asyncio  # function-scope so this block needs no top-level import

    return await asyncio.to_thread(func, *args, **kwargs)


async def _read_upload(file: UploadFile) -> bytes:
    """Return the uploaded bytes after enforcing the size limits.

    At most ``MAX_UPLOAD_BYTES + 1`` bytes are read. That is enough to detect
    an oversized upload without materialising the whole payload as a single
    bytes object.

    Raises:
        HTTPException: 400 if the upload is empty, 413 if it exceeds
            ``MAX_UPLOAD_BYTES``.
    """
    data = await file.read(MAX_UPLOAD_BYTES + 1)
    if not data:
        raise HTTPException(status_code=400, detail="empty upload")
    if len(data) > MAX_UPLOAD_BYTES:
        # Only a prefix was read, so the exact total size is unknown here.
        raise HTTPException(
            status_code=413,
            detail=f"upload too large (over {MAX_UPLOAD_BYTES} bytes)",
        )
    return data


@app.get("/healthz")
def health() -> JSONResponse:
    """Liveness probe; also reports which models this process loaded."""
    return JSONResponse(
        {
            "status": "ok",
            "service": "princess",
            "bg_model": MODEL_NAME,
            "sam_model": SAM_MODEL,
        }
    )


@app.post("/remove-bg")
async def remove_bg(file: UploadFile = File(...)) -> Response:
    """Remove the background from an uploaded image and return it as PNG.

    Responds 400/413 for empty or oversized uploads and 500 if rembg raises.
    """
    data = await _read_upload(file)
    try:
        output_bytes = await _run_blocking(remove, data, session=_bg_session)
    except Exception as err:  # noqa: BLE001 — we want to surface anything
        log.exception("rembg failed")
        raise HTTPException(status_code=500, detail=f"rembg failed: {err}") from err

    # rembg returns a PNG (bytes) with the alpha channel already applied.
    return Response(
        content=output_bytes,
        media_type="image/png",
        headers={"Cache-Control": "no-store"},
    )


def _resize_long_edge(img: Image.Image, target: int) -> Image.Image:
    """Downscale *img* so its longer side is at most *target* pixels.

    Returns *img* unchanged when it already fits. Each side is clamped to at
    least 1 px so an extreme aspect ratio (e.g. a 1x4000 strip) cannot round
    a dimension down to 0, which PIL refuses to resize to.
    """
    w, h = img.size
    long_edge = max(w, h)
    if long_edge <= target:
        return img
    ratio = target / long_edge
    new_size = (max(1, round(w * ratio)), max(1, round(h * ratio)))
    return img.resize(new_size, Image.LANCZOS)


def _downsample_mask(mask: np.ndarray, out_w: int, out_h: int) -> np.ndarray:
    """Nearest-neighbor resample a 2-D mask to shape (out_h, out_w).

    Despite the name this works in either direction; auto-segment also uses
    it to upsample SAM's working-resolution masks to full resolution. We
    don't need anti-aliasing — segment-review just hit-tests pixels and
    draws them as rects.
    """
    h, w = mask.shape
    ys = np.linspace(0, h - 1, out_h).astype(np.int32)
    xs = np.linspace(0, w - 1, out_w).astype(np.int32)
    return mask[ys[:, None], xs[None, :]]


def _mask_iou(a: np.ndarray, b: np.ndarray) -> float:
    """Intersection-over-union of two same-shaped masks (0.0 if both empty)."""
    inter = np.logical_and(a, b).sum()
    union = np.logical_or(a, b).sum()
    return float(inter) / float(union) if union else 0.0


def _nms(masks: list[np.ndarray], scores: np.ndarray, iou_thresh: float) -> list[int]:
    """Greedy non-max suppression over masks.

    `masks` is a list of (small) bool arrays aligned with `scores`. Masks are
    visited from highest to lowest score. A mask is kept unless its IoU with
    an already-kept mask exceeds `iou_thresh`. Returns kept indices in score
    order.
    """
    order = np.argsort(scores)[::-1]
    kept: list[int] = []
    for idx in order:
        if all(_mask_iou(masks[idx], masks[k]) <= iou_thresh for k in kept):
            kept.append(int(idx))
    return kept


def _crop_with_mask(
    rgba: Image.Image,
    mask_full: np.ndarray,
    pad: int = 4,
) -> tuple[bytes, dict, int]:
    """Crop *rgba* to the mask's bounding box and apply the mask as alpha.

    *mask_full* is a 2-D mask expected to match *rgba*'s (height, width).
    Pixels where it is False become fully transparent; the rest keep the
    source alpha. The bounding box is grown by *pad* pixels on every side,
    clamped to the image.

    Returns:
        (png_bytes, normalized_bbox, pixel_area). normalized_bbox holds
        x / y / w / h as fractions of the full image. pixel_area is the number
        of True mask pixels.

    Raises:
        ValueError: if the mask has no True pixels.
    """
    H, W = mask_full.shape
    ys, xs = np.nonzero(mask_full)
    if ys.size == 0:
        raise ValueError("empty mask")

    y0 = max(0, int(ys.min()) - pad)
    y1 = min(H, int(ys.max()) + 1 + pad)
    x0 = max(0, int(xs.min()) - pad)
    x1 = min(W, int(xs.max()) + 1 + pad)

    arr = np.array(rgba.crop((x0, y0, x1, y1)).convert("RGBA"), dtype=np.uint8)
    # Zero alpha outside the mask; inside, keep whatever alpha the source had.
    outside = ~mask_full[y0:y1, x0:x1].astype(bool)
    arr[outside, 3] = 0

    buf = io.BytesIO()
    Image.fromarray(arr).save(buf, format="PNG", optimize=False)
    bbox = {
        "x": x0 / W,
        "y": y0 / H,
        "w": (x1 - x0) / W,
        "h": (y1 - y0) / H,
    }
    return buf.getvalue(), bbox, int(ys.size)


def _decode_rgba(data: bytes) -> Image.Image:
    """Decode upload bytes into an RGBA image, mapping any failure to HTTP 400."""
    try:
        return Image.open(io.BytesIO(data)).convert("RGBA")
    except Exception as err:  # noqa: BLE001
        raise HTTPException(status_code=400, detail=f"bad image: {err}") from err


def _segment_image(src: Image.Image) -> list[dict]:
    """Run SAM on *src* and build the JSON-ready segment list.

    Blocking and CPU-heavy, so call it through ``_run_blocking``. Steps:
    run SAM on a downscaled RGB copy, upsample the masks to full resolution,
    filter by area, apply NMS on small masks, then crop and alpha-mask each
    survivor.

    Returns:
        Segment dicts sorted largest-area first (empty if nothing survived).

    Raises:
        HTTPException: 500 if the SAM pipeline itself raises.
    """
    full_w, full_h = src.size

    # SAM wants a 3-channel image; we pass the downscaled RGB view.
    work_rgb = _resize_long_edge(src, SAM_INPUT_DIM).convert("RGB")
    work_w, work_h = work_rgb.size

    # points_per_side=8 → 64 grid points. Default 32 gives 1024 which
    # is ~16× slower and overkill for a drawing with maybe 5-10 parts.
    try:
        sam_out = _sam_pipeline(
            work_rgb,
            points_per_side=8,
            pred_iou_thresh=0.85,
            stability_score_thresh=0.85,
        )
    except Exception as err:  # noqa: BLE001
        log.exception("SAM failed")
        raise HTTPException(status_code=500, detail=f"sam failed: {err}") from err

    raw_masks = sam_out.get("masks", [])
    raw_scores = sam_out.get("scores", [])
    # len() rather than truthiness: an array/tensor of masks has no
    # unambiguous truth value.
    if len(raw_masks) == 0:
        return []

    min_area_px = int(full_w * full_h * SAM_MIN_AREA_FRAC)
    max_area_px = int(full_w * full_h * SAM_MAX_AREA_FRAC)

    candidates = []
    for mask, score in zip(raw_masks, raw_scores):
        mask_arr = np.asarray(mask, dtype=bool)
        # The pipeline returns masks at the *input* (downscaled) resolution;
        # upsample once so crops are sharp. Some pipelines return (H, W) at
        # the original size instead, so accept that too.
        if mask_arr.shape == (work_h, work_w):
            full_mask = _downsample_mask(mask_arr, full_w, full_h)
        elif mask_arr.shape == (full_h, full_w):
            full_mask = mask_arr
        else:
            log.warning("skipping SAM mask with unexpected shape %s", mask_arr.shape)
            continue

        area = int(full_mask.sum())
        if not min_area_px <= area <= max_area_px:
            continue

        candidates.append(
            {
                "full_mask": full_mask,
                # Small copy for NMS and client-side hit-testing.
                "small_mask": _downsample_mask(full_mask, MASK_OUT_DIM, MASK_OUT_DIM),
                "score": float(score),
                "area": area,
            }
        )

    if not candidates:
        return []

    kept_idx = _nms(
        [c["small_mask"] for c in candidates],
        np.array([c["score"] for c in candidates], dtype=np.float32),
        SAM_NMS_IOU,
    )

    segments = []
    for seg_i, idx in enumerate(kept_idx):
        c = candidates[idx]
        try:
            png_bytes, bbox, _ = _crop_with_mask(src, c["full_mask"])
        except Exception as err:  # noqa: BLE001
            log.warning("crop failed for seg %d: %s", seg_i, err)
            continue
        segments.append(
            {
                "id": f"seg-{seg_i}",
                "score": c["score"],
                "area": c["area"],
                "bbox": bbox,
                "maskW": MASK_OUT_DIM,
                "maskH": MASK_OUT_DIM,
                # Pack bool mask as 1 byte per pixel, base64 for JSON transport.
                "mask": base64.b64encode(
                    c["small_mask"].astype(np.uint8).tobytes()
                ).decode("ascii"),
                "croppedPng": base64.b64encode(png_bytes).decode("ascii"),
            }
        )

    # Largest first — matches segment-review's princess-selection heuristic.
    segments.sort(key=lambda s: s["area"], reverse=True)
    return segments


@app.post("/auto-segment")
async def auto_segment(file: UploadFile = File(...)) -> JSONResponse:
    """Split an uploaded drawing into candidate parts.

    Responds with ``{"segments": [...]}``. Each segment carries a normalized
    bbox, a MASK_OUT_DIM x MASK_OUT_DIM mask (1 byte per pixel, base64) and
    a base64 PNG crop of the part.
    """
    data = await _read_upload(file)
    # Decoding and SAM are both CPU-bound; keep them off the event loop.
    src = await _run_blocking(_decode_rgba, data)
    segments = await _run_blocking(_segment_image, src)
    return JSONResponse({"segments": segments})


# Mount last so API routes above take precedence. html=True makes "/"
# serve index.html automatically.
app.mount("/", StaticFiles(directory=STATIC_ROOT, html=True), name="static")