# background-remover-api / services/image_processor.py
# Uploaded by Keramo ("Upload 6 files", commit 0baa056, 4.35 kB).
"""
services/image_processor.py
Central AI engine — Silueta ONNX model via rembg.
Lifecycle
---------
startup → load_model() (sets ready = True) → warm_model()
request → remove_background() → image_to_png_buffer()
shutdown → (session GC'd automatically)
Design decisions
----------------
* The rembg session is stored as a module-level singleton so it is
created exactly once per Gunicorn worker process.
* warm_model() runs a dummy inference so the first real request
doesn't pay the ONNX JIT cost.
* Image data never touches disk — all image I/O is in-memory BytesIO.
"""
import logging
from io import BytesIO
from PIL import Image
from rembg import new_session, remove
from config.constants import MODEL_NAME, ENGINE_TAG, PROCESSING_FAILED
from fastapi import HTTPException
logger = logging.getLogger(__name__)
# ── Module-level state ─────────────────────────────────────────────────────────
# Process-wide singletons. Within this module only load_model() assigns them;
# every other function just reads them.
_model_ready: bool = False  # flipped to True once load_model() succeeds
_session = None  # rembg session object; stays None until load_model() succeeds
# ── Startup ────────────────────────────────────────────────────────────────────
def load_model() -> None:
    """
    Instantiate the rembg session for ``MODEL_NAME`` and mark the engine ready.

    Must be called once during FastAPI startup before any request is served.

    On success, the module-level ``_session`` holds the new session and
    ``_model_ready`` is True.

    Raises:
        Exception: whatever ``new_session`` raises. Before re-raising, both
            globals are reset (``None`` / False) and the error is logged with
            its full traceback, so startup fails loudly and diagnosably.
    """
    global _session, _model_ready
    logger.info("Loading %s model…", ENGINE_TAG)
    try:
        session = new_session(MODEL_NAME)
    except Exception as exc:
        # Reset both globals so a failed (re)load never leaves a stale
        # session object paired with a "not ready" flag.
        _session = None
        _model_ready = False
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("Model load failed: %s", exc)
        raise
    _session = session
    _model_ready = True
    logger.info("Model loaded successfully.")
def warm_model() -> None:
    """
    Run one dummy inference to prime the inference session.

    This keeps the first real user request from paying cold-start latency.

    If load_model() has not succeeded, this only logs a warning and returns.
    Warmup is best-effort: any inference failure is logged, with its
    traceback, and swallowed, so it never aborts startup.
    """
    if not _model_ready:
        logger.warning("Skipping warmup — model not loaded.")
        return
    logger.info("Warming up model…")
    try:
        # Small opaque mid-grey RGBA canvas: cheap to process, yet it goes
        # through the same _run_inference() path as real requests.
        dummy = Image.new("RGBA", (64, 64), (128, 128, 128, 255))
        _run_inference(dummy)
    except Exception as exc:
        logger.warning(
            "Warmup inference failed (non-fatal): %s", exc, exc_info=True
        )
        return
    logger.info("Warmup complete.")
# ── Status ─────────────────────────────────────────────────────────────────────
def is_model_ready() -> bool:
    """
    Return True once load_model() has successfully created the session.

    Note: load_model() sets this flag itself, before warm_model() runs, and
    a failed warmup does not clear it. True therefore means "session loaded",
    not "warmup completed".
    """
    return _model_ready
# ── Core inference ─────────────────────────────────────────────────────────────
def _run_inference(image: Image.Image) -> Image.Image:
    """
    Internal helper: pass *image* through rembg's remove() using the
    module-level session created by load_model().

    Does no readiness check of its own; the callers in this module
    (warm_model, remove_background) check _model_ready first.
    """
    active_session = _session
    cutout = remove(image, session=active_session)
    return cutout
def remove_background(image: Image.Image) -> Image.Image:
    """
    Public API: accept a PIL Image and return it with the background removed.

    Args:
        image: Input PIL Image.

    Returns:
        The transparent RGBA PIL Image produced by rembg's remove() with the
        preloaded session.

    Raises:
        HTTPException(503): the model is not loaded (load_model() has not
            succeeded).
        HTTPException(500): inference failed. The detail is PROCESSING_FAILED,
            the original error is chained as ``__cause__``, and it is logged
            with its traceback.
    """
    if not _model_ready:
        raise HTTPException(status_code=503, detail="Inference engine unavailable.")
    try:
        logger.info("Processing image %dx%d…", *image.size)
        result = _run_inference(image)
        logger.info("Processing complete.")
        return result
    except HTTPException:
        # Already an HTTP-level error: propagate unchanged.
        raise
    except Exception as exc:
        # logger.exception records the traceback; "from exc" preserves the
        # root cause on the HTTPException for upstream handlers and debuggers.
        logger.exception("Background removal failed: %s", exc)
        raise HTTPException(status_code=500, detail=PROCESSING_FAILED) from exc
# ── Export ─────────────────────────────────────────────────────────────────────
def image_to_png_buffer(image: Image.Image) -> BytesIO:
    """
    Encode *image* as PNG into a fresh in-memory stream.

    The returned BytesIO is rewound to offset 0, ready to read. It is not
    closed here; the caller streams it to the client and owns its lifetime.
    """
    stream = BytesIO()
    image.save(stream, "PNG", optimize=False)
    # Rewind so the caller reads from the first PNG byte, not from EOF.
    stream.seek(0)
    return stream