"""MODNet background-removal API with a simple in-memory, per-device daily quota."""

import io
import logging
from collections import defaultdict
from datetime import datetime, timedelta

import cv2
import numpy as np
import onnxruntime as ort
from fastapi import FastAPI, File, Header, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from PIL import Image, ImageOps, UnidentifiedImageError

logger = logging.getLogger(__name__)

app = FastAPI(title="MODNet API", version="1.0.0")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Relative path: the ONNX model is expected in the root folder (the process's working directory).
MODEL_PATH = "modnet.onnx"
MODEL_WIDTH = 512
MODEL_HEIGHT = 512

print("🔄 Loading MODNet ONNX model...")
onnx_session = ort.InferenceSession(MODEL_PATH, providers=['CPUExecutionProvider'])
print("✅ MODNet model loaded successfully!")

# In-memory quota store: user_id -> {"count": int, "date": date}.
# NOTE: this is a module-level dict, so it is lost on restart and not shared between worker processes.
user_quotas = defaultdict(lambda: {"count": 0, "date": datetime.now().date()})
MAX_DAILY_IMAGES = 5


def check_and_update_quota(user_id: str) -> bool:
    """Consume one unit of *user_id*'s daily quota.

    The stored counter is reset when its date is not today.

    Returns:
        True if a unit was consumed, False if the daily limit was already reached.
    """
    today = datetime.now().date()
    user_data = user_quotas[user_id]
    if user_data["date"] != today:
        user_data["count"] = 0
        user_data["date"] = today
    if user_data["count"] >= MAX_DAILY_IMAGES:
        return False
    user_data["count"] += 1
    return True


def _refund_quota(user_id: str) -> None:
    """Give back one unit of today's quota after a request that produced no result."""
    user_data = user_quotas.get(user_id)
    if user_data and user_data["date"] == datetime.now().date() and user_data["count"] > 0:
        user_data["count"] -= 1


def _used_today(user_id: str) -> int:
    """Return how many images *user_id* processed today, without creating a quota record."""
    user_data = user_quotas.get(user_id)
    if user_data is None or user_data["date"] != datetime.now().date():
        return 0
    return user_data["count"]


def preprocess_image(image: Image.Image, target_size=(MODEL_WIDTH, MODEL_HEIGHT)):
    """Convert *image* into the model's input tensor.

    Args:
        image: Any PIL image; non-RGB modes are converted to RGB.
        target_size: (width, height) the image is resized to before inference.

    Returns:
        ``(tensor, (orig_width, orig_height))`` where ``tensor`` is a float32 array of
        shape (1, 3, height, width) with values scaled to [0, 1].
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')
    orig_width, orig_height = image.size
    image_resized = image.resize(target_size, Image.LANCZOS)
    # NOTE(review): the reference MODNet ONNX demo normalizes to [-1, 1]
    # ((x - 127.5) / 127.5). Confirm which range this exported model expects.
    img_array = np.array(image_resized).astype(np.float32) / 255.0
    img_array = np.transpose(img_array, (2, 0, 1))  # HWC -> CHW
    img_array = np.expand_dims(img_array, axis=0)  # add batch dimension
    return img_array, (orig_width, orig_height)


def postprocess_mask(mask: np.ndarray, original_size):
    """Turn the raw model output into a binary alpha mask at the original image size.

    Args:
        mask: Model output; only ``mask[0, 0]`` (first batch item, first channel) is used.
        original_size: (width, height) to resize the mask to.

    Returns:
        A uint8 array of shape (height, width) containing only 0 and 255.
    """
    mask = mask[0, 0]
    # Clip before the uint8 cast; otherwise out-of-range floats would wrap around.
    mask = (np.clip(mask, 0.0, 1.0) * 255).round().astype(np.uint8)
    mask = cv2.resize(mask, original_size, interpolation=cv2.INTER_LINEAR)  # dsize is (w, h)
    mask = np.where(mask > 127, 255, 0).astype(np.uint8)
    return mask


def remove_background(image: Image.Image):
    """Return an RGBA copy of *image* whose alpha channel is the predicted foreground mask."""
    input_array, original_size = preprocess_image(image)
    input_name = onnx_session.get_inputs()[0].name
    output = onnx_session.run(None, {input_name: input_array})
    mask = postprocess_mask(output[0], original_size)
    image_array = np.array(image.convert('RGBA'))
    image_array[:, :, 3] = mask  # replaces any existing alpha channel
    return Image.fromarray(image_array, 'RGBA')


def _render_png(image_bytes: bytes) -> bytes:
    """Decode *image_bytes*, remove the background and return the result as PNG bytes.

    Raises:
        PIL.UnidentifiedImageError: if the bytes are not a recognised image format.
        PIL.Image.DecompressionBombError: if the image exceeds Pillow's pixel limit.
    """
    with Image.open(io.BytesIO(image_bytes)) as image:
        # The result is rebuilt from a NumPy array and carries no EXIF. Bake the
        # camera orientation into the pixels first so phone photos stay upright.
        upright = ImageOps.exif_transpose(image)
        result_image = remove_background(upright)
    output_buffer = io.BytesIO()
    result_image.save(output_buffer, format='PNG')
    return output_buffer.getvalue()


@app.get("/")
async def root():
    """Health check."""
    return {"status": "healthy", "service": "MODNet API", "version": "1.0.0"}


@app.get("/quota/{user_id}")
async def get_quota(user_id: str):
    """Report today's quota usage for *user_id* without consuming or creating quota."""
    today = datetime.now().date()
    used = _used_today(user_id)
    return {
        "user_id": user_id,
        "used": used,
        "remaining": max(0, MAX_DAILY_IMAGES - used),
        "limit": MAX_DAILY_IMAGES,
        "resets_at": str(today + timedelta(days=1)),
    }


@app.post("/remove-background")
async def remove_background_endpoint(
    file: UploadFile = File(...),
    user_id: str = Header(..., alias="X-User-ID"),
):
    """Remove the background from an uploaded image and return it as a PNG.

    Headers:
        X-User-ID: device identifier (at least 10 characters) used for quota tracking.

    Raises:
        HTTPException: 400 for an invalid user id, a non-image upload or an
            undecodable image; 429 when the daily quota is exhausted; 500 when
            processing fails. Failed processing does not consume quota.
    """
    if not user_id or len(user_id) < 10:
        raise HTTPException(
            status_code=400,
            detail="Invalid user ID. Please provide a valid device identifier."
        )

    # Validate the upload before touching the quota, so rejected requests are free.
    # content_type can be None when the client omits it.
    if not (file.content_type or "").startswith('image/'):
        raise HTTPException(
            status_code=400,
            detail="Invalid file type. Please upload an image (JPEG or PNG)."
        )

    if not check_and_update_quota(user_id):
        raise HTTPException(
            status_code=429,
            detail=f"Daily quota exceeded. You can process {MAX_DAILY_IMAGES} images per day. Try again tomorrow!"
        )

    try:
        image_bytes = await file.read()
        # Inference is CPU-bound; run it off the event loop so other requests are still served.
        png_bytes = await run_in_threadpool(_render_png, image_bytes)
    except (UnidentifiedImageError, Image.DecompressionBombError) as e:
        _refund_quota(user_id)
        raise HTTPException(
            status_code=400,
            detail="Could not read the uploaded file as an image."
        ) from e
    except Exception as e:
        _refund_quota(user_id)
        logger.exception("Background removal failed for user %s", user_id)
        raise HTTPException(
            status_code=500,
            detail=f"Error processing image: {str(e)}"
        ) from e

    used = _used_today(user_id)
    return Response(
        content=png_bytes,
        media_type="image/png",
        headers={
            "X-Quota-Used": str(used),
            "X-Quota-Remaining": str(max(0, MAX_DAILY_IMAGES - used)),
        }
    )


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)