File size: 4,350 Bytes
0baa056
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""
services/image_processor.py
Central AI engine — Silueta ONNX model via rembg.

Lifecycle
---------
  startup  → load_model()  → ready = True  → warm_model()
  request  → remove_background()  → image_to_png_buffer()
  shutdown → (session GC'd automatically)

Design decisions
----------------
  * The rembg session is stored as a module-level singleton so it is
    created exactly once per Gunicorn worker process.
  * warm_model() runs a dummy inference so the first real request
    doesn't pay the ONNX JIT cost.
  * Image handling never touches disk — request/response image I/O is in-memory BytesIO.
"""

import logging
from io import BytesIO

from PIL import Image
from rembg import new_session, remove

from config.constants import MODEL_NAME, ENGINE_TAG, PROCESSING_FAILED
from fastapi import HTTPException

logger = logging.getLogger(name=__name__)

# ── Module-level state ─────────────────────────────────────────────────────────
# One copy per worker process; written by load_model(), read by the helpers below.
_model_ready: bool = False  # True once new_session() has returned successfully
_session = None             # rembg session object from new_session(); None until loaded


# ── Startup ────────────────────────────────────────────────────────────────────

def load_model() -> None:
    """
    Create the rembg session for MODEL_NAME and flag the engine as ready.

    Intended to be called once during FastAPI startup, before any request
    is served. If session creation fails, the ready flag is cleared, the
    error is logged, and the original exception is re-raised to the caller.
    """
    global _session, _model_ready
    logger.info("Loading %s model…", ENGINE_TAG)
    try:
        _session = new_session(MODEL_NAME)
    except Exception as err:
        _model_ready = False
        logger.error("Model load failed: %s", err)
        raise
    else:
        _model_ready = True
        logger.info("Model loaded successfully.")


def warm_model() -> None:
    """
    Run one throw-away inference to prime the ONNX runtime.

    Doing this at startup keeps first-call overhead off the first real user
    request. Skipped with a warning when load_model() has not succeeded.
    Any error during the dummy inference is logged (with traceback) and
    swallowed: warmup is best-effort and non-fatal.
    """
    if not _model_ready:
        logger.warning("Skipping warmup — model not loaded.")
        return
    logger.info("Warming up model…")
    try:
        # Small solid-grey RGBA canvas; its content is irrelevant, only the
        # forward pass through the session matters.
        dummy = Image.new("RGBA", (64, 64), (128, 128, 128, 255))
        _run_inference(dummy)
        logger.info("Warmup complete.")
    except Exception as exc:
        # Deliberately non-fatal, but keep the traceback so a broken model
        # is diagnosable from the logs.
        logger.warning("Warmup inference failed (non-fatal): %s", exc, exc_info=True)


# ── Status ─────────────────────────────────────────────────────────────────────

def is_model_ready() -> bool:
    """
    Return True once load_model() has created the rembg session successfully.

    Only load_model() sets or clears this flag. warm_model() never changes
    it, so True does not imply that warmup ran or succeeded.
    """
    return _model_ready


# ── Core inference ─────────────────────────────────────────────────────────────

def _run_inference(image: Image.Image) -> Image.Image:
    """
    Pass *image* through rembg's remove() using the module-level session.

    Private helper shared by warm_model() and remove_background(). It does
    no readiness check and no error translation of its own.
    """
    session = _session
    return remove(image, session=session)


def remove_background(image: Image.Image) -> Image.Image:
    """
    Public API: accept a PIL Image, return a transparent RGBA PIL Image.

    Raises:
        HTTPException(503): the model is not loaded (load_model() has not
            succeeded); detail is "Inference engine unavailable.".
        HTTPException(500): inference failed; detail is PROCESSING_FAILED and
            the underlying error is chained as ``__cause__``.
    """
    if not _model_ready:
        raise HTTPException(status_code=503, detail="Inference engine unavailable.")
    try:
        logger.info("Processing image %dx%d…", *image.size)
        result = _run_inference(image)
    except HTTPException:
        # Defensive: let any HTTP error pass through with its own status.
        raise
    except Exception as exc:
        # logger.exception keeps the traceback (logger.error dropped it), and
        # `from exc` preserves the cause on the HTTPException for debugging.
        logger.exception("Background removal failed: %s", exc)
        raise HTTPException(status_code=500, detail=PROCESSING_FAILED) from exc
    logger.info("Processing complete.")
    return result


# ── Export ─────────────────────────────────────────────────────────────────────

def image_to_png_buffer(image: Image.Image) -> BytesIO:
    """
    Encode *image* as PNG into a fresh in-memory stream, rewound to offset 0.

    Nothing is written to disk. Streaming the returned buffer to the client
    is the caller's job.
    """
    out = BytesIO()
    image.save(out, "PNG", optimize=False)
    out.seek(0)  # rewind so the consumer reads from the first byte
    return out