sdragly's picture
Segmentation model
f952998
"""HTTP service for the Princess dress-up game.
Serves both the API (background removal, auto-segment) and the static
frontend from a single FastAPI process. Everything is deployed together
to a Hugging Face Docker Space — one commit, one cold start, no CORS.
Design notes:
- Both the rembg session and the SAM pipeline are created once at
import time. HF Spaces keeps the container warm between requests,
so model weights stay resident in process memory. First request
after a cold start is slow because the container itself needs to
boot (~45s with torch+SAM), not because of model loading.
- CORS stays wide open for now even though frontend and backend share
an origin. It's harmless and lets you hit the API from a second
client (curl, another deploy) without surprise.
- We accept multipart form upload with a `file` field rather than raw
bytes so the client can use standard FormData from the browser
without any custom headers.
- StaticFiles is mounted LAST at "/", after every explicit API route.
Routes are matched in registration order and a mount at "/" matches
every path, so registering it last lets `/remove-bg` hit the handler
below rather than looking for a file named `remove-bg` in the static
dir.
"""
import asyncio
import base64
import io
import logging
from pathlib import Path

import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response
from fastapi.staticfiles import StaticFiles
from PIL import Image
from rembg import new_session, remove
from transformers import pipeline
# rembg model used for background removal; also reported by /healthz.
MODEL_NAME = "isnet-general-use"
# Model id handed to the transformers "mask-generation" pipeline below.
SAM_MODEL = "Zigeng/SlimSAM-uniform-77"
# Configure logging at import time so the model-loading messages further
# down are emitted at INFO level.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("princess")
# 20 MB — plenty for a downscaled photo from the client (which already
# caps at ~1280 on the long edge). Rejecting bigger uploads protects us
# from accidental full-res uploads burning CPU on the free tier.
MAX_UPLOAD_BYTES = 20 * 1024 * 1024
# SAM is quadratically sensitive to input resolution. 512 on the long
# edge gives good-quality masks on kid drawings while keeping a full
# pass under ~5s on the free CPU tier.
SAM_INPUT_DIM = 512
# Masks are returned to the client at this resolution. Segment-review
# uses them for tap-hit-testing and polygon adjust, neither of which
# needs pixel-perfect alignment — 128x128 keeps the JSON payload small
# (~16KB per mask pre-base64) without visibly hurting the overlay.
MASK_OUT_DIM = 128
# Drop tiny masks (noise) and very large masks (full image / bg).
# Both are fractions of the full-resolution image area.
SAM_MIN_AREA_FRAC = 0.005
SAM_MAX_AREA_FRAC = 0.85
# Non-max suppression IoU threshold — masks overlapping more than this
# are treated as duplicates and the lower-scoring one gets dropped.
SAM_NMS_IOU = 0.7
# Static assets live next to app.py inside the container. The Dockerfile
# copies index.html, style.css, and js/ into /app alongside this file.
# NOTE(review): this is the directory that contains this module, so
# mounting it with StaticFiles also makes this source file downloadable —
# confirm that is acceptable or move the assets into a subdirectory.
STATIC_ROOT = Path(__file__).parent
# Both models are loaded eagerly at import time and kept in module
# globals, so every request handler reuses the same loaded instances.
log.info("Loading rembg model: %s", MODEL_NAME)
_bg_session = new_session(MODEL_NAME)
log.info("rembg ready")
log.info("Loading SAM pipeline: %s", SAM_MODEL)
# device=-1 asks transformers to run the pipeline on CPU.
_sam_pipeline = pipeline("mask-generation", model=SAM_MODEL, device=-1)
log.info("SAM ready")
# Single application object: the API routes are registered on it below
# and the static frontend is mounted on it at the end of the module.
app = FastAPI(version="1.0.0", title="Princess")

# CORS is deliberately permissive (any origin, no credentials) so a
# second client such as curl or another deploy can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_credentials=False,
    allow_origins=["*"],
)
@app.get("/healthz")
def health() -> JSONResponse:
    """Liveness probe: report service status and the rembg model in use."""
    payload = {
        "status": "ok",
        "service": "princess",
        "bg_model": MODEL_NAME,
    }
    return JSONResponse(payload)
@app.post("/remove-bg")
async def remove_bg(file: UploadFile = File(...)) -> Response:
    """Strip the background from an uploaded image and return it as a PNG.

    Expects a multipart form upload with the image in the `file` field.

    Returns:
        An `image/png` response holding rembg's output, sent with
        `Cache-Control: no-store`.

    Raises:
        HTTPException: 400 for an empty upload, 413 when the upload is
            larger than MAX_UPLOAD_BYTES, 500 when rembg raises.
    """
    data = await file.read()
    if not data:
        raise HTTPException(status_code=400, detail="empty upload")
    if len(data) > MAX_UPLOAD_BYTES:
        raise HTTPException(
            status_code=413,
            detail=f"upload too large ({len(data)} bytes, max {MAX_UPLOAD_BYTES})",
        )
    try:
        # rembg inference is synchronous and CPU-bound. Calling it directly
        # inside this coroutine would stall the event loop — and with it
        # every other request, /healthz included — for the whole run, so
        # it is handed to a worker thread instead.
        output_bytes = await asyncio.to_thread(remove, data, session=_bg_session)
    except Exception as err:  # noqa: BLE001 — we want to surface anything
        log.exception("rembg failed")
        raise HTTPException(status_code=500, detail=f"rembg failed: {err}") from err
    # rembg returns a PNG (bytes) with the alpha channel already applied.
    return Response(
        content=output_bytes,
        media_type="image/png",
        headers={"Cache-Control": "no-store"},
    )
def _resize_long_edge(img: Image.Image, target: int) -> Image.Image:
    """Shrink `img` so that its longer side is at most `target` pixels.

    The aspect ratio is preserved up to rounding. An image whose longer
    side is already within `target` is returned as-is (same object); the
    function never upscales.

    Each output dimension is clamped to at least 1 px: for a very
    elongated image the short side would otherwise round down to 0 and
    the resize call would fail on the zero-sized target.
    """
    w, h = img.size
    long_edge = max(w, h)
    if long_edge <= target:
        return img
    ratio = target / long_edge
    new_size = (max(1, round(w * ratio)), max(1, round(h * ratio)))
    return img.resize(new_size, Image.LANCZOS)
def _downsample_mask(mask: np.ndarray, out_w: int, out_h: int) -> np.ndarray:
    """Resample a 2-D mask onto an (out_h, out_w) grid by point sampling.

    Each output pixel copies the source pixel found at an evenly spaced
    position between the first and last source row/column, with the
    position truncated to an integer index. No smoothing is applied —
    segment-review only hit-tests pixels and draws them as rects, so
    anti-aliasing would buy nothing.
    """
    src_h, src_w = mask.shape
    row_idx = np.linspace(0, src_h - 1, out_h).astype(np.int32)
    col_idx = np.linspace(0, src_w - 1, out_w).astype(np.int32)
    # np.ix_ builds the open mesh pairing every picked row with every
    # picked column.
    return mask[np.ix_(row_idx, col_idx)]
def _mask_iou(a: np.ndarray, b: np.ndarray) -> float:
    """Return the intersection-over-union of two masks.

    Defined as 0.0 when neither mask has any set pixel, so the caller
    never divides by zero.
    """
    union = np.count_nonzero(np.logical_or(a, b))
    if union == 0:
        return 0.0
    overlap = np.count_nonzero(np.logical_and(a, b))
    return float(overlap) / float(union)
def _nms(masks, scores, iou_thresh: float):
    """Greedy non-max suppression over masks.

    Candidates are visited from highest to lowest score; one is kept only
    if its IoU with every mask kept so far does not exceed `iou_thresh`.
    `masks` is a list of (small) bool arrays aligned with `scores`.
    Returns the kept indices as plain ints, in score order.
    """
    survivors: list[int] = []
    for cand in np.argsort(scores)[::-1]:
        overlaps_kept = any(
            _mask_iou(masks[cand], masks[prev]) > iou_thresh for prev in survivors
        )
        if not overlaps_kept:
            survivors.append(int(cand))
    return survivors
def _crop_with_mask(
    rgba: Image.Image,
    mask_full: np.ndarray,
) -> tuple[bytes, dict, int]:
    """Cut one segment out of the image as a tightly cropped, masked PNG.

    `mask_full` is a bool array indexed as (row, column) on the same
    pixel grid as `rgba`. The crop rectangle is the mask's bounding box
    grown by 4 px on every side and clipped to the mask's extent. Inside
    the crop, pixels where the mask is False get alpha 0; the others
    keep their original alpha.

    Returns (png_bytes, bbox, pixel_area). `bbox` holds x / y / w / h as
    fractions of the mask's width and height; `pixel_area` is the number
    of set pixels in the mask.

    Raises ValueError if the mask has no set pixels.
    """
    mask_h, mask_w = mask_full.shape
    rows, cols = np.nonzero(mask_full)
    if rows.size == 0:
        raise ValueError("empty mask")

    margin = 4
    top = max(0, int(rows.min()) - margin)
    bottom = min(mask_h, int(rows.max()) + 1 + margin)
    left = max(0, int(cols.min()) - margin)
    right = min(mask_w, int(cols.max()) + 1 + margin)

    tile = rgba.crop((left, top, right, bottom)).convert("RGBA")
    pixels = np.array(tile, dtype=np.uint8)
    # Multiply alpha by the 0/1 mask in a wider integer type, then narrow
    # back to 8 bits for the image buffer.
    keep = mask_full[top:bottom, left:right].astype(np.uint16)
    alpha = pixels[..., 3].astype(np.uint16)
    pixels[..., 3] = (alpha * keep).astype(np.uint8)

    png = io.BytesIO()
    Image.fromarray(pixels, mode="RGBA").save(png, format="PNG", optimize=False)

    bbox = {
        "x": left / mask_w,
        "y": top / mask_h,
        "w": (right - left) / mask_w,
        "h": (bottom - top) / mask_h,
    }
    return png.getvalue(), bbox, int(rows.size)
def _collect_mask_candidates(raw_masks, raw_scores, work_size, full_size) -> list[dict]:
    """Bring raw SAM masks to full resolution and drop implausibly sized ones.

    `work_size` and `full_size` are (width, height) of the image SAM saw
    and of the original upload. Each surviving candidate is a dict with
    `full_mask`, `small_mask` (MASK_OUT_DIM square), `score` and `area`
    (set pixels at full resolution).
    """
    work_w, work_h = work_size
    full_w, full_h = full_size
    min_area_px = int(full_w * full_h * SAM_MIN_AREA_FRAC)
    max_area_px = int(full_w * full_h * SAM_MAX_AREA_FRAC)

    candidates = []
    for mask, score in zip(raw_masks, raw_scores):
        mask_arr = np.asarray(mask, dtype=bool)
        if mask_arr.shape == (work_h, work_w):
            # The pipeline returns masks at the *input* (downscaled)
            # resolution; lift them to full res so crops are sharp.
            # Despite its name, _downsample_mask is a plain
            # nearest-neighbor resampler and upsamples just as well.
            full_mask = _downsample_mask(mask_arr, full_w, full_h)
        elif mask_arr.shape == (full_h, full_w):
            # Some pipelines return (H, W) at the *original* size already.
            full_mask = mask_arr
        else:
            # Unrecognized geometry: skip rather than guess an alignment.
            continue
        area = int(full_mask.sum())
        if area < min_area_px or area > max_area_px:
            continue
        candidates.append(
            {
                "full_mask": full_mask,
                # Small copy used for NMS and client-side hit-testing.
                "small_mask": _downsample_mask(full_mask, MASK_OUT_DIM, MASK_OUT_DIM),
                "score": float(score),
                "area": area,
            }
        )
    return candidates


def _build_segments(src: Image.Image, candidates: list[dict]) -> list[dict]:
    """De-duplicate candidates with NMS and serialize the survivors.

    Returns JSON-ready dicts sorted by area, largest first. A candidate
    whose crop fails is logged and skipped instead of failing the request.
    """
    small_masks = [c["small_mask"] for c in candidates]
    scores_arr = np.array([c["score"] for c in candidates], dtype=np.float32)

    segments = []
    for seg_i, idx in enumerate(_nms(small_masks, scores_arr, SAM_NMS_IOU)):
        c = candidates[idx]
        try:
            png_bytes, bbox, _ = _crop_with_mask(src, c["full_mask"])
        except Exception as err:  # noqa: BLE001
            log.warning("crop failed for seg %d: %s", seg_i, err)
            continue
        segments.append(
            {
                "id": f"seg-{seg_i}",
                "score": c["score"],
                "area": c["area"],
                "bbox": bbox,
                "maskW": MASK_OUT_DIM,
                "maskH": MASK_OUT_DIM,
                # Pack bool mask as 1 byte per pixel, base64 for JSON transport.
                "mask": base64.b64encode(
                    c["small_mask"].astype(np.uint8).tobytes()
                ).decode("ascii"),
                "croppedPng": base64.b64encode(png_bytes).decode("ascii"),
            }
        )
    # Largest first — matches segment-review's princess-selection heuristic.
    segments.sort(key=lambda s: s["area"], reverse=True)
    return segments


def _segment_upload(data: bytes) -> list[dict]:
    """Decode an uploaded image, run SAM on it and return the segment list.

    Blocking and CPU-heavy: meant to be run off the event loop.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image,
            500 if the SAM pipeline raises.
    """
    try:
        src = Image.open(io.BytesIO(data)).convert("RGBA")
    except Exception as err:  # noqa: BLE001
        raise HTTPException(status_code=400, detail=f"bad image: {err}") from err

    # SAM wants a 3-channel image; we pass the downscaled RGB view.
    work_rgb = _resize_long_edge(src, SAM_INPUT_DIM).convert("RGB")
    # points_per_side=8 → 64 grid points. Default 32 gives 1024 which
    # is ~16× slower and overkill for a drawing with maybe 5-10 parts.
    try:
        sam_out = _sam_pipeline(
            work_rgb,
            points_per_side=8,
            pred_iou_thresh=0.85,
            stability_score_thresh=0.85,
        )
    except Exception as err:  # noqa: BLE001
        log.exception("SAM failed")
        raise HTTPException(status_code=500, detail=f"sam failed: {err}") from err

    raw_masks = sam_out.get("masks", [])
    raw_scores = sam_out.get("scores", [])
    # len() rather than truthiness, so an array-like container of masks
    # cannot raise on an ambiguous bool().
    if raw_masks is None or len(raw_masks) == 0:
        return []

    candidates = _collect_mask_candidates(
        raw_masks, raw_scores, work_rgb.size, src.size
    )
    if not candidates:
        return []
    return _build_segments(src, candidates)


@app.post("/auto-segment")
async def auto_segment(file: UploadFile = File(...)) -> JSONResponse:
    """Split an uploaded drawing into candidate segments with SAM.

    Expects a multipart form upload with the image in the `file` field.

    Returns:
        JSON `{"segments": [...]}`, largest area first. Each segment has
        `id`, `score`, `area` (pixels at full resolution), `bbox`
        (x / y / w / h as fractions of the image), `maskW` / `maskH`,
        `mask` (base64 of a maskH x maskW grid, one byte per pixel) and
        `croppedPng` (base64 PNG of the masked crop). The list is empty
        when SAM yields nothing usable.

    Raises:
        HTTPException: 400 for an empty or undecodable upload, 413 when
            the upload is larger than MAX_UPLOAD_BYTES, 500 when the SAM
            pipeline raises.
    """
    data = await file.read()
    if not data:
        raise HTTPException(status_code=400, detail="empty upload")
    if len(data) > MAX_UPLOAD_BYTES:
        raise HTTPException(
            status_code=413,
            detail=f"upload too large ({len(data)} bytes, max {MAX_UPLOAD_BYTES})",
        )
    # Decoding, SAM inference and mask post-processing are synchronous
    # and take seconds of CPU. Running them inside this coroutine would
    # freeze the event loop for every other request, so the whole job is
    # pushed to a worker thread; HTTPExceptions raised there propagate
    # back through the await.
    segments = await asyncio.to_thread(_segment_upload, data)
    return JSONResponse({"segments": segments})
# The catch-all static mount is registered last so the API routes
# declared above are matched first. html=True makes "/" serve
# index.html automatically.
app.mount(
    "/",
    StaticFiles(html=True, directory=STATIC_ROOT),
    name="static",
)