# bgr / main.py
# Rajhuggingface4253's picture
# Update main.py
# 738ee5e verified
import io
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, HTTPException, Response
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.concurrency import run_in_threadpool
from rembg import remove, new_session
from PIL import Image
# Process-wide registry for loaded ML sessions; populated/cleared by the
# lifespan handler so every request reuses the same ISNet session.
ml_models = {}
def warmup_model() -> bool:
    """Run one tiny inference pass so the first real request is not slow.

    Returns True when a warm-up inference succeeded, False when the model is
    not loaded or the inference raised.
    """
    session = ml_models.get("isnet")
    if session is None:
        return False
    try:
        # A 64x64 white square is enough to force lazy model initialization.
        buffer = io.BytesIO()
        Image.new('RGB', (64, 64), color=(255, 255, 255)).save(buffer, format='PNG')
        remove(buffer.getvalue(), session=session)
    except Exception as exc:
        print(f"Warm-up failed: {exc}")
        return False
    return True
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan Context Manager:
    Loads the AI model ONLY when the server starts.
    Cleans it up when the server stops.
    """
    print("Server starting: Loading ISNet model into memory...")
    # Load the model once and store it in the dictionary; all request handlers
    # read this shared session via ml_models.get("isnet").
    ml_models["isnet"] = new_session("isnet-general-use")
    # Everything before `yield` runs at startup, everything after at shutdown.
    yield
    # Cleanup logic (runs on shutdown): dropping the reference lets the
    # session be garbage collected.
    ml_models.clear()
    print("Server shutting down: Model memory released.")
# FastAPI application; the lifespan handler above loads/unloads the model.
app = FastAPI(
    title="AI Background Removal Service",
    description="High-precision background removal using ISNet and Alpha Matting.",
    lifespan=lifespan
)
# Browser origins allowed to call this API: local dev servers plus the
# production site.
origins = [
    "http://127.0.0.1:5500",
    "http://localhost:5500",
    "https://toolboxesai.com"
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def process_image_sync(input_bytes: bytes, params: dict) -> bytes:
    """Blocking background removal; run this inside a threadpool.

    Recognized keys in *params*:
      alpha_matting -- "true" (case-insensitive) enables alpha-matting edges
      af / ab       -- foreground / background matting thresholds (ints)

    Raises RuntimeError when the model session has not been loaded yet.
    """
    session = ml_models.get("isnet")
    if session is None:
        raise RuntimeError("Model not loaded")
    # Query parameters arrive as strings; normalize/convert them here.
    matting_enabled = params.get("alpha_matting", "false").lower() == "true"
    foreground_threshold = int(params.get("af", 240))
    background_threshold = int(params.get("ab", 10))
    return remove(
        input_bytes,
        session=session,
        alpha_matting=matting_enabled,
        alpha_matting_foreground_threshold=foreground_threshold,
        alpha_matting_background_threshold=background_threshold,
    )
@app.post(
    "/api/remove-background",
    summary="Remove Background with Advanced Controls",
    responses={
        200: {"content": {"image/png": {}, "image/webp": {}}},
        400: {"description": "Invalid input."},
        500: {"description": "Processing error."},
    }
)
async def http_remove_background(request: Request):
    """
    Removes background using ISNet model.

    Query params:
      format        -- "png" (default) or "webp" output encoding
      alpha_matting -- "true" for high-quality edge preservation
      af / ab       -- integer foreground/background matting thresholds

    Raises HTTPException 400 on bad input and 500 on processing failure.
    """
    # 1. Validate Query Params
    query_params = request.query_params
    output_format = query_params.get("format", "png").lower()
    if output_format not in ["png", "webp"]:
        raise HTTPException(400, "Unsupported format. Use 'png' or 'webp'.")
    # 2. Validate Content-Type
    content_type = request.headers.get("content-type")
    if not content_type or not content_type.startswith("image/"):
        raise HTTPException(400, "Invalid Content-Type. Must be image/*")
    try:
        # 3. Read Input (Async I/O)
        input_bytes = await request.body()
        if not input_bytes:
            raise HTTPException(400, "No image data provided.")
        # 4. Process Image (CPU Bound - Offloaded to Threadpool)
        # This prevents the server from freezing while calculating alpha values.
        output_bytes = await run_in_threadpool(
            process_image_sync,
            input_bytes,
            query_params
        )
        # 5. Handle Format Conversion
        final_bytes = output_bytes
        media_type = "image/png"
        if output_format == "webp":
            # WebP conversion (also CPU bound, ideally offloaded, but fast enough here)
            try:
                pil_image = Image.open(io.BytesIO(output_bytes))
                webp_buffer = io.BytesIO()
                pil_image.save(webp_buffer, format="WEBP", quality=90)
                final_bytes = webp_buffer.getvalue()
                media_type = "image/webp"
            except Exception as e:
                # Fallback to PNG if WebP fails
                print(f"WebP conversion failed: {e}")
                final_bytes = output_bytes
                media_type = "image/png"
                output_format = "png"
        # 6. Return Stream
        return StreamingResponse(
            io.BytesIO(final_bytes),
            media_type=media_type,
            headers={
                "Content-Disposition": f"attachment; filename=bg_removed.{output_format}"
            }
        )
    except HTTPException:
        # BUG FIX: without this clause, the generic handler below caught the
        # 400 "No image data provided." HTTPException and rewrapped it as 500.
        raise
    except ValueError as e:
        # BUG FIX: non-integer 'af'/'ab' query values are a client error (400),
        # not a server failure (500).
        raise HTTPException(400, f"Invalid threshold value: {str(e)}")
    except Exception as e:
        print(f"Error: {e}")
        raise HTTPException(500, f"Processing failed: {str(e)}")
@app.get("/health")
async def health_check():
    """Health check with model warm-up - call this to keep model hot.

    Returns 200 with warm-up status when the model is loaded, 503 otherwise.
    """
    # Guard clause: without a loaded session the service is not usable.
    if ml_models.get("isnet") is None:
        return Response(
            content='{"status": "unhealthy", "model_loaded": false}',
            media_type="application/json",
            status_code=503
        )
    # Warm-up is CPU bound, so keep it off the event loop.
    warmed = await run_in_threadpool(warmup_model)
    return {
        "status": "healthy",
        "model_loaded": True,
        "model_warmed": warmed
    }
@app.get("/ping")
def get_ping():
    # Lightweight liveness probe; returns plain-text "pong" with no model work.
    return Response(content="pong", media_type="text/plain")
if __name__ == "__main__":
    # Local development entry point; reload=True restarts the server on code
    # edits (port 7860 matches the Hugging Face Spaces convention).
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)