Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, HTTPException, Header | |
| from fastapi.responses import Response | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import onnxruntime as ort | |
| import numpy as np | |
| from PIL import Image | |
| import cv2 | |
| import io | |
| from datetime import datetime, timedelta | |
| from collections import defaultdict | |
# ASGI application object; the request handlers defined further down in this
# file are meant to be served by it.
app = FastAPI(title="MODNet API", version="1.0.0")

# Cross-origin policy: the API is open to every origin, method and header.
# Credentials are disabled because a wildcard origin must not be combined with
# `allow_credentials=True` (see the FastAPI CORS documentation). Callers
# identify themselves with the `X-User-ID` request header, which is not a
# browser credential, so nothing in this file relies on credentialed CORS.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Relative path to the ONNX export of MODNet; it is resolved against the
# process working directory, so the model file must sit next to where the
# server is started (the repository root).
MODEL_PATH = "modnet.onnx"

# Spatial size the image is resized to before inference.
MODEL_WIDTH = 512
MODEL_HEIGHT = 512

# The session is created once at import time and shared by every request.
# The startup messages are plain ASCII: the previous strings began with
# mis-encoded (mojibake) characters.
print("Loading MODNet ONNX model...")
onnx_session = ort.InferenceSession(MODEL_PATH, providers=['CPUExecutionProvider'])
print("MODNet model loaded successfully!")
# Maximum number of images a single user ID may process per calendar day.
MAX_DAILY_IMAGES = 5


def _new_quota_entry():
    """Build the usage record for a user ID that has not been seen before."""
    return {"count": 0, "date": datetime.now().date()}


# In-memory usage table: user ID -> {"count": images used, "date": day counted}.
# Looking up an unknown key creates a fresh record for it.
user_quotas = defaultdict(_new_quota_entry)
def check_and_update_quota(user_id: str) -> bool:
    """Charge one image against ``user_id``'s daily allowance if any is left.

    The user's record in ``user_quotas`` is created on first use and its
    counter is reset whenever the stored date differs from today's date
    (local time, as returned by ``datetime.now()``).

    Returns:
        True when the request was counted, False when the user has already
        reached ``MAX_DAILY_IMAGES`` for today (the counter is left as is).
    """
    current_day = datetime.now().date()
    record = user_quotas[user_id]

    # A record dated on another day is stale: start today's count from zero.
    if record["date"] != current_day:
        record.update(count=0, date=current_day)

    if record["count"] < MAX_DAILY_IMAGES:
        record["count"] += 1
        return True
    return False
def preprocess_image(image: Image.Image, target_size=(MODEL_WIDTH, MODEL_HEIGHT)):
    """Turn a PIL image into the float tensor fed to the ONNX session.

    Args:
        image: Source picture; any mode other than RGB is converted to RGB.
        target_size: ``(width, height)`` the picture is resized to.

    Returns:
        A tuple ``(tensor, (width, height))`` where ``tensor`` is a float32
        array of shape ``(1, 3, target_height, target_width)`` scaled to
        ``[0, 1]`` and the second item is the size of the input image.
    """
    rgb = image if image.mode == 'RGB' else image.convert('RGB')
    width, height = rgb.size

    scaled = rgb.resize(target_size, Image.LANCZOS)

    # NOTE(review): pixels are only scaled to [0, 1]; confirm the exported
    # model does not expect a different normalization (e.g. [-1, 1]).
    pixels = np.array(scaled).astype(np.float32) / 255.0

    # HWC -> CHW, then prepend the batch axis.
    tensor = pixels.transpose(2, 0, 1)[np.newaxis, ...]
    return tensor, (width, height)
def postprocess_mask(mask: np.ndarray, original_size):
    """Convert the raw model output into a binary alpha mask.

    Args:
        mask: Model output; element ``[0, 0]`` is taken as the 2-D matte.
            Assumes values lie in ``[0, 1]`` - TODO confirm against the model.
        original_size: ``(width, height)`` the mask is resized to.

    Returns:
        A uint8 array resized to ``original_size`` containing only 0 and 255.
    """
    matte = mask[0, 0]
    matte_u8 = np.round(matte * 255).astype(np.uint8)

    resized = cv2.resize(matte_u8, original_size, interpolation=cv2.INTER_LINEAR)

    # Hard threshold at mid-grey: pixels above 127 become opaque, the rest
    # fully transparent.
    return ((resized > 127) * 255).astype(np.uint8)
def remove_background(image: Image.Image):
    """Return an RGBA copy of ``image`` whose alpha channel is the model mask.

    The image is preprocessed, run through the shared ``onnx_session``, and
    the first output is post-processed into a mask at the original image
    size; that mask replaces the alpha channel of the RGBA conversion of the
    input.
    """
    model_input, source_size = preprocess_image(image)

    # The session's first declared input receives the image tensor.
    feed = {onnx_session.get_inputs()[0].name: model_input}
    outputs = onnx_session.run(None, feed)

    alpha = postprocess_mask(outputs[0], source_size)

    rgba = np.array(image.convert('RGBA'))
    rgba[..., 3] = alpha
    return Image.fromarray(rgba, 'RGBA')
async def root():
    """Return a static health/identity payload for the service.

    NOTE(review): no route decorator is visible on this handler in the
    reviewed source; confirm it is registered with ``app``.
    """
    payload = {
        "status": "healthy",
        "service": "MODNet API",
        "version": "1.0.0",
    }
    return payload
async def get_quota(user_id: str):
    """Report today's quota usage for ``user_id`` without modifying any state.

    ``user_quotas`` is a ``defaultdict``, so indexing it with an unknown key
    would insert a new record. This handler only reads, so it uses ``.get()``:
    querying arbitrary IDs must not grow the in-memory table. A missing
    record, or one dated on a previous day, is reported as zero usage; the
    stored record itself is reset by ``check_and_update_quota`` on the user's
    next processed image.

    NOTE(review): no route decorator is visible on this handler in the
    reviewed source; confirm it is registered with ``app``.

    Returns:
        A dict with ``user_id``, ``used``, ``remaining`` (never negative),
        ``limit`` and ``resets_at`` (tomorrow's date as a string).
    """
    today = datetime.now().date()

    entry = user_quotas.get(user_id)
    if entry is not None and entry["date"] == today:
        used = entry["count"]
    else:
        used = 0

    return {
        "user_id": user_id,
        "used": used,
        "remaining": max(0, MAX_DAILY_IMAGES - used),
        "limit": MAX_DAILY_IMAGES,
        "resets_at": str(today + timedelta(days=1)),
    }
async def remove_background_endpoint(
    file: UploadFile = File(...),
    user_id: str = Header(..., alias="X-User-ID")
):
    """Remove the background from an uploaded image and return it as a PNG.

    Request validation (user ID, then content type) runs before the daily
    quota is charged, so a request that is rejected as malformed does not
    cost the caller one of their daily images.

    NOTE(review): no route decorator is visible on this handler in the
    reviewed source; confirm it is registered with ``app``.
    NOTE(review): inference runs synchronously inside this coroutine and
    therefore occupies the event loop while it executes.

    Args:
        file: The uploaded image.
        user_id: Caller identifier taken from the ``X-User-ID`` header; must
            be at least 10 characters long.

    Returns:
        A ``image/png`` response carrying ``X-Quota-Used`` and
        ``X-Quota-Remaining`` headers.

    Raises:
        HTTPException: 400 for an invalid user ID or a non-image upload,
            429 when the daily quota is exhausted, 500 when processing fails.
    """
    if not user_id or len(user_id) < 10:
        raise HTTPException(
            status_code=400,
            detail="Invalid user ID. Please provide a valid device identifier."
        )

    # `content_type` is None when the client sends no Content-Type for the
    # part; treat that as "not an image" instead of crashing on None.
    if not (file.content_type or "").startswith('image/'):
        raise HTTPException(
            status_code=400,
            detail="Invalid file type. Please upload an image (JPEG or PNG)."
        )

    # Charge the quota only once the request is known to be well-formed.
    if not check_and_update_quota(user_id):
        raise HTTPException(
            status_code=429,
            detail=f"Daily quota exceeded. You can process {MAX_DAILY_IMAGES} images per day. Try again tomorrow!"
        )

    try:
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes))
        result_image = remove_background(image)

        output_buffer = io.BytesIO()
        result_image.save(output_buffer, format='PNG')
    except Exception as e:
        # Boundary handler: any decoding/inference failure becomes a 500.
        raise HTTPException(
            status_code=500,
            detail=f"Error processing image: {str(e)}"
        ) from e

    used = user_quotas[user_id]["count"]
    return Response(
        content=output_buffer.getvalue(),
        media_type="image/png",
        headers={
            "X-Quota-Used": str(used),
            "X-Quota-Remaining": str(MAX_DAILY_IMAGES - used)
        }
    )
if __name__ == "__main__":
    # Script entry point: serve `app` on every interface, port 7860.
    # uvicorn is imported here so it is only required when the file is run
    # directly.
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=7860)