"""AI background-removal service: FastAPI + rembg (ISNet) with alpha matting."""

import io
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from PIL import Image
from rembg import new_session, remove

# Holds the loaded ISNet session for the app's lifetime (populated in `lifespan`).
ml_models = {}


def warmup_model() -> bool:
    """Warm up the model with a small image to reduce first-request latency.

    Returns:
        True if a session is loaded and a tiny inference succeeds, else False.
    """
    session = ml_models.get("isnet")
    if session is None:
        return False
    try:
        # A 64x64 white square is enough to force the model graph to execute.
        img = Image.new('RGB', (64, 64), color=(255, 255, 255))
        buf = io.BytesIO()
        img.save(buf, format='PNG')
        remove(buf.getvalue(), session=session)
        return True
    except Exception as e:
        # Best-effort: warm-up failure is reported, never fatal.
        print(f"Warm-up failed: {e}")
        return False


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan Context Manager:
    Loads the AI model ONLY when the server starts.
    Cleans it up when the server stops.
    """
    print("Server starting: Loading ISNet model into memory...")
    # Load the model once and store it in the dictionary
    ml_models["isnet"] = new_session("isnet-general-use")
    yield
    # Cleanup logic (runs on shutdown)
    ml_models.clear()
    print("Server shutting down: Model memory released.")


app = FastAPI(
    title="AI Background Removal Service",
    description="High-precision background removal using ISNet and Alpha Matting.",
    lifespan=lifespan
)

origins = [
    "http://127.0.0.1:5500",
    "http://localhost:5500",
    "https://toolboxesai.com"
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def process_image_sync(input_bytes: bytes, params: dict) -> bytes:
    """Run background removal synchronously (CPU-bound; call via threadpool).

    Args:
        input_bytes: Raw image bytes as received from the client.
        params: Query-param mapping; honors 'alpha_matting' ("true"/"false"),
            'af' (foreground threshold, default 240) and 'ab' (background
            threshold, default 10).

    Returns:
        PNG-encoded image bytes with the background removed.

    Raises:
        RuntimeError: If the model session was never loaded.
        ValueError: If 'af' or 'ab' are not valid integers.
    """
    # Retrieve the session loaded by the lifespan handler.
    session = ml_models.get("isnet")
    if session is None:
        raise RuntimeError("Model not loaded")

    use_alpha_matting = params.get("alpha_matting", "false").lower() == "true"
    # int() raises ValueError on garbage input; the route maps that to HTTP 400.
    af_thresh = int(params.get("af", 240))
    ab_thresh = int(params.get("ab", 10))

    return remove(
        input_bytes,
        session=session,  # Pass the retrieved session here
        alpha_matting=use_alpha_matting,
        alpha_matting_foreground_threshold=af_thresh,
        alpha_matting_background_threshold=ab_thresh
    )


@app.post(
    "/api/remove-background",
    summary="Remove Background with Advanced Controls",
    responses={
        200: {"content": {"image/png": {}, "image/webp": {}}},
        400: {"description": "Invalid input."},
        500: {"description": "Processing error."},
    }
)
async def http_remove_background(request: Request):
    """
    Removes background using ISNet model.
    Supports 'alpha_matting=true' query param for high-quality edge preservation.
    """
    # 1. Validate Query Params
    query_params = request.query_params
    output_format = query_params.get("format", "png").lower()
    if output_format not in ["png", "webp"]:
        raise HTTPException(400, "Unsupported format. Use 'png' or 'webp'.")

    # 2. Validate Content-Type
    content_type = request.headers.get("content-type")
    if not content_type or not content_type.startswith("image/"):
        raise HTTPException(400, "Invalid Content-Type. Must be image/*")

    try:
        # 3. Read Input (Async I/O)
        input_bytes = await request.body()
        if not input_bytes:
            raise HTTPException(400, "No image data provided.")

        # 4. Process Image (CPU Bound - Offloaded to Threadpool)
        # This prevents the server from freezing while calculating alpha values.
        output_bytes = await run_in_threadpool(
            process_image_sync, input_bytes, query_params
        )

        # 5. Handle Format Conversion
        final_bytes = output_bytes
        media_type = "image/png"

        if output_format == "webp":
            # WebP conversion (also CPU bound, ideally offloaded, but fast enough here)
            try:
                pil_image = Image.open(io.BytesIO(output_bytes))
                webp_buffer = io.BytesIO()
                pil_image.save(webp_buffer, format="WEBP", quality=90)
                final_bytes = webp_buffer.getvalue()
                media_type = "image/webp"
            except Exception as e:
                # Fallback to PNG if WebP fails
                print(f"WebP conversion failed: {e}")
                final_bytes = output_bytes
                media_type = "image/png"
                output_format = "png"

        # 6. Return Stream
        return StreamingResponse(
            io.BytesIO(final_bytes),
            media_type=media_type,
            headers={
                "Content-Disposition": f"attachment; filename=bg_removed.{output_format}"
            }
        )
    except HTTPException:
        # BUGFIX: deliberate 4xx responses raised above must not be rewritten
        # as 500 by the generic handler below.
        raise
    except ValueError as e:
        # Malformed 'af'/'ab' query params are a client error, not a server error.
        raise HTTPException(400, f"Invalid query parameter: {e}")
    except Exception as e:
        print(f"Error: {e}")
        raise HTTPException(500, f"Processing failed: {str(e)}")


@app.get("/health")
async def health_check():
    """Health check with model warm-up - call this to keep model hot."""
    model_loaded = ml_models.get("isnet") is not None
    if model_loaded:
        warmup_success = await run_in_threadpool(warmup_model)
        return {
            "status": "healthy",
            "model_loaded": True,
            "model_warmed": warmup_success
        }
    return Response(
        content='{"status": "unhealthy", "model_loaded": false}',
        media_type="application/json",
        status_code=503
    )


@app.get("/ping")
def get_ping():
    """Trivial liveness probe."""
    return Response(content="pong", media_type="text/plain")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)