"""
services/image_processor.py
Central AI engine — Silueta ONNX model via rembg.

Lifecycle
---------
startup  → load_model() (sets ready = True) → warm_model()
request  → remove_background() → image_to_png_buffer()
shutdown → no explicit teardown; the session lives until the process exits

Design decisions
----------------
* The rembg session is stored as a module-level singleton so it is
  created exactly once per Gunicorn worker process.
* warm_model() runs a dummy inference so the first real request
  doesn't pay the model's cold-start (first-run initialization) cost.
* We never touch disk — all I/O is in-memory BytesIO.
"""
import logging
from io import BytesIO
from PIL import Image
from rembg import new_session, remove
from config.constants import MODEL_NAME, ENGINE_TAG, PROCESSING_FAILED
from fastapi import HTTPException
logger = logging.getLogger(__name__)

# ── Module-level state ───────────────────────────────────────────────────────
# load_model() assigns both values; the other functions only read them.
_model_ready: bool = False  # flipped to True once load_model() has built the session
_session = None  # rembg session from new_session(); None until load_model() runs
# ── Startup ──────────────────────────────────────────────────────────────────
def load_model() -> None:
    """
    Instantiate the rembg Silueta session.

    Must be called once during FastAPI startup before any request is served.

    On success the session is stored in the module-level singleton and the
    engine is flagged ready. On failure the singleton is cleared, the engine
    is flagged not ready, the error is logged with its traceback, and the
    original exception is re-raised so startup fails loudly.

    Raises:
        Exception: whatever ``new_session`` raises, propagated unchanged.
    """
    global _session, _model_ready
    logger.info("Loading %s model...", ENGINE_TAG)
    try:
        session = new_session(MODEL_NAME)
    except Exception as exc:
        # Drop any session left over from an earlier successful load so a
        # failed reload can never leave a stale session next to ready=False.
        _session = None
        _model_ready = False
        # logger.exception records the traceback; logger.error did not.
        logger.exception("Model load failed: %s", exc)
        raise
    _session = session
    _model_ready = True
    logger.info("Model loaded successfully.")
def warm_model() -> None:
    """
    Run one dummy inference to prime the model before real traffic arrives.

    This absorbs cold-start latency so the first real user request does not
    pay it. Warmup is best-effort:

    * It is skipped (with a warning) when the model is not loaded.
    * Any inference error is logged with its traceback and swallowed.
    * It never changes the readiness flag.
    """
    if not _model_ready:
        logger.warning("Skipping warmup - model not loaded.")
        return
    logger.info("Warming up model...")
    try:
        # A small opaque grey RGBA square is enough to drive one full pass.
        dummy = Image.new("RGBA", (64, 64), (128, 128, 128, 255))
        _run_inference(dummy)
    except Exception as exc:
        # Deliberately non-fatal: the server can still serve requests; the
        # first real one just pays the cold-start cost instead.
        logger.warning("Warmup inference failed (non-fatal): %s", exc, exc_info=True)
    else:
        logger.info("Warmup complete.")
# ── Status ───────────────────────────────────────────────────────────────────
def is_model_ready() -> bool:
    """
    Return True once ``load_model()`` has successfully created the session.

    The flag is set by ``load_model()`` itself, before and independently of
    ``warm_model()``. A skipped or failed warmup does not clear it. A failed
    ``load_model()`` call leaves it False.
    """
    return _model_ready
# ── Core inference ───────────────────────────────────────────────────────────
def _run_inference(image: Image.Image) -> Image.Image:
    """
    Internal helper: hand *image* to rembg ``remove()`` with the shared session.

    Callers in this module (``warm_model``, ``remove_background``) check
    ``_model_ready`` before calling, so the session is expected to be loaded.
    NOTE(review): if ``_session`` were None here, rembg may fall back to its
    own default model. Confirm callers keep that guard.
    """
    active_session = _session
    return remove(image, session=active_session)
def remove_background(image: Image.Image) -> Image.Image:
    """
    Public API: accept a PIL Image, return a transparent RGBA PIL Image.

    The result is whatever rembg ``remove()`` produces for *image* using the
    preloaded session.

    Raises:
        HTTPException: status 503 (``"Inference engine unavailable."``) when
            the model is not loaded. Status 500 (``detail=PROCESSING_FAILED``)
            when inference fails; the underlying error is chained as
            ``__cause__`` and logged with its traceback.
    """
    if not _model_ready:
        raise HTTPException(status_code=503, detail="Inference engine unavailable.")
    try:
        logger.info("Processing image %dx%d...", *image.size)
        result = _run_inference(image)
    except HTTPException:
        # Already an HTTP error: pass it through rather than re-wrap as 500.
        raise
    except Exception as exc:
        logger.exception("Background removal failed: %s", exc)
        # Chain the cause so the original error survives in tracebacks.
        raise HTTPException(status_code=500, detail=PROCESSING_FAILED) from exc
    logger.info("Processing complete.")
    return result
# ── Export ───────────────────────────────────────────────────────────────────
def image_to_png_buffer(image: Image.Image) -> BytesIO:
    """
    Encode *image* as PNG into a fresh in-memory stream and return it.

    The stream is rewound to offset 0, so the caller can read or stream it to
    the client straight away. Nothing is written to disk.
    """
    out = BytesIO()
    # optimize=False skips Pillow's extra PNG size-optimization pass.
    image.save(out, format="PNG", optimize=False)
    out.seek(0)  # rewind: save() leaves the position at the end of the data
    return out