# AI Background Removal Service — FastAPI + rembg (ISNet).
# (Hosting-UI residue "Spaces: Running Running" removed; it was not valid Python.)
| import io | |
| from contextlib import asynccontextmanager | |
| from fastapi import FastAPI, Request, HTTPException, Response | |
| from fastapi.responses import StreamingResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.concurrency import run_in_threadpool | |
| from rembg import remove, new_session | |
| from PIL import Image | |
# Module-level registry for loaded ML sessions.
# Populated by the lifespan handler at startup ("isnet" key) and cleared on shutdown.
ml_models: dict = {}
def warmup_model() -> bool:
    """Run one tiny inference so the first real request avoids cold-start latency.

    Returns:
        True when the warm-up inference completed, False when the model is not
        loaded or the inference raised.
    """
    session = ml_models.get("isnet")
    if session is None:
        return False
    try:
        # A 64x64 white square is enough to exercise the full inference path.
        buffer = io.BytesIO()
        Image.new('RGB', (64, 64), color=(255, 255, 255)).save(buffer, format='PNG')
        remove(buffer.getvalue(), session=session)
    except Exception as exc:
        print(f"Warm-up failed: {exc}")
        return False
    return True
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan Context Manager:
    Loads the AI model ONLY when the server starts.
    Cleans it up when the server stops.

    FIX: the @asynccontextmanager decorator was missing. FastAPI's
    `lifespan=` parameter expects a callable that returns an async context
    manager; a bare async generator function fails at application startup.
    (The decorator was already imported at the top of the file but unused.)
    """
    print("Server starting: Loading ISNet model into memory...")
    # Load the model once and store it in the dictionary
    ml_models["isnet"] = new_session("isnet-general-use")
    yield
    # Cleanup logic (runs on shutdown)
    ml_models.clear()
    print("Server shutting down: Model memory released.")
# Application instance; the lifespan handler loads/unloads the ISNet model.
app = FastAPI(
    title="AI Background Removal Service",
    description="High-precision background removal using ISNet and Alpha Matting.",
    lifespan=lifespan
)

# Browser origins allowed to call this API (local dev servers + production site).
origins = [
    "http://127.0.0.1:5500",
    "http://localhost:5500",
    "https://toolboxesai.com"
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,   # cookies/auth headers permitted from the origins above
    allow_methods=["*"],
    allow_headers=["*"],
)
def process_image_sync(input_bytes: bytes, params: dict) -> bytes:
    """
    Synchronous wrapper for the heavy lifting (runs in a worker thread).

    Args:
        input_bytes: Raw encoded image bytes from the request body.
        params: Query-parameter mapping; recognises 'alpha_matting' ("true"
            enables matting), 'af' (foreground threshold, default 240) and
            'ab' (background threshold, default 10).

    Returns:
        PNG-encoded image bytes with the background removed.

    Raises:
        RuntimeError: if the model session has not been loaded.
        ValueError: if 'af' or 'ab' is not an integer (clear message instead
            of the opaque `invalid literal for int()` that previously leaked
            to the client).
    """
    session = ml_models.get("isnet")
    if session is None:
        raise RuntimeError("Model not loaded")

    use_alpha_matting = params.get("alpha_matting", "false").lower() == "true"
    try:
        af_thresh = int(params.get("af", 240))
        ab_thresh = int(params.get("ab", 10))
    except (TypeError, ValueError):
        # Surface a readable message rather than the raw int() parse error.
        raise ValueError("Query parameters 'af' and 'ab' must be integers")

    return remove(
        input_bytes,
        session=session,  # Pass the retrieved session here
        alpha_matting=use_alpha_matting,
        alpha_matting_foreground_threshold=af_thresh,
        alpha_matting_background_threshold=ab_thresh
    )
async def http_remove_background(request: Request):
    """
    Removes background using ISNet model.
    Supports 'alpha_matting=true' query param for high-quality edge preservation.

    FIX: the broad `except Exception` previously swallowed the
    HTTPException(400) raised for an empty body and re-wrapped it as a 500.
    HTTPException is now re-raised untouched so clients see the intended
    status code.

    NOTE(review): no @app.post(...) route decorator is visible in this file —
    confirm this handler is registered elsewhere.
    """
    # 1. Validate Query Params
    query_params = request.query_params
    output_format = query_params.get("format", "png").lower()
    if output_format not in ["png", "webp"]:
        raise HTTPException(400, "Unsupported format. Use 'png' or 'webp'.")

    # 2. Validate Content-Type
    content_type = request.headers.get("content-type")
    if not content_type or not content_type.startswith("image/"):
        raise HTTPException(400, "Invalid Content-Type. Must be image/*")

    try:
        # 3. Read Input (Async I/O)
        input_bytes = await request.body()
        if not input_bytes:
            raise HTTPException(400, "No image data provided.")

        # 4. Process Image (CPU Bound - Offloaded to Threadpool)
        # This prevents the server from freezing while calculating alpha values.
        output_bytes = await run_in_threadpool(
            process_image_sync,
            input_bytes,
            query_params
        )

        # 5. Handle Format Conversion
        final_bytes = output_bytes
        media_type = "image/png"
        if output_format == "webp":
            # WebP conversion (also CPU bound, ideally offloaded, but fast enough here)
            try:
                pil_image = Image.open(io.BytesIO(output_bytes))
                webp_buffer = io.BytesIO()
                pil_image.save(webp_buffer, format="WEBP", quality=90)
                final_bytes = webp_buffer.getvalue()
                media_type = "image/webp"
            except Exception as e:
                # Fallback to PNG if WebP fails
                print(f"WebP conversion failed: {e}")
                final_bytes = output_bytes
                media_type = "image/png"
                output_format = "png"

        # 6. Return Stream
        return StreamingResponse(
            io.BytesIO(final_bytes),
            media_type=media_type,
            headers={
                "Content-Disposition": f"attachment; filename=bg_removed.{output_format}"
            }
        )
    except HTTPException:
        # Deliberate 4xx responses must pass through, not become 500s.
        raise
    except Exception as e:
        print(f"Error: {e}")
        raise HTTPException(500, f"Processing failed: {str(e)}")
async def health_check():
    """Health check with model warm-up - call this to keep model hot.

    Returns a JSON health payload; 503 with an 'unhealthy' body when the
    model session is absent, otherwise a dict reporting warm-up status.
    """
    # Guard clause: without a loaded session there is nothing to warm up.
    if ml_models.get("isnet") is None:
        return Response(
            content='{"status": "unhealthy", "model_loaded": false}',
            media_type="application/json",
            status_code=503
        )

    # Warm-up is CPU-bound, so push it off the event loop.
    warmed = await run_in_threadpool(warmup_model)
    return {
        "status": "healthy",
        "model_loaded": True,
        "model_warmed": warmed
    }
def get_ping():
    """Minimal liveness probe: always answers plain-text 'pong'."""
    pong = Response(content="pong", media_type="text/plain")
    return pong
| if __name__ == "__main__": | |
| import uvicorn | |
| uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True) |