Spaces:
Sleeping
Sleeping
File size: 4,723 Bytes
6c1cc1e 5c59863 6c1cc1e 5c59863 6c1cc1e 5c59863 6c1cc1e 5c59863 6c1cc1e 5c59863 6c1cc1e 5c59863 6c1cc1e 5c59863 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 | from fastapi import FastAPI, File, UploadFile, HTTPException, Header
from fastapi.responses import Response
from fastapi.middleware.cors import CORSMiddleware
import onnxruntime as ort
import numpy as np
from PIL import Image
import cv2
import io
from datetime import datetime, timedelta
from collections import defaultdict
# ASGI application object; the route decorators in this module register on it.
app = FastAPI(title="MODNet API", version="1.0.0")

# Fully permissive CORS: every origin, method and request header is accepted.
# NOTE(review): the FastAPI CORS documentation discourages combining wildcard
# origins with allow_credentials=True. The visible endpoints identify callers
# through the X-User-ID header rather than cookies, so credentials may not be
# needed at all - confirm before tightening.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Relative path of the MODNet ONNX weights (the model file sits in the
# repository root, so this resolves against the process working directory).
MODEL_PATH = "modnet.onnx"
# Width and height every image is resized to before it is fed to the model
# (used as the default target size of preprocess_image).
MODEL_WIDTH = 512
MODEL_HEIGHT = 512
# The inference session is created once, at import time, and shared by all
# requests; only the CPU execution provider is requested.
print("🔄 Loading MODNet ONNX model...")
onnx_session = ort.InferenceSession(MODEL_PATH, providers=['CPUExecutionProvider'])
print("✅ MODNet model loaded successfully!")
# Per-user daily usage, keyed by the caller-supplied user id. Each value is a
# dict with "count" (images processed) and "date" (the day the count belongs
# to). Missing keys are created on first subscript access with count 0 and
# today's date. The store lives only in process memory: it is lost on restart
# and nothing in the visible code ever removes entries.
user_quotas = defaultdict(lambda: {"count": 0, "date": datetime.now().date()})
# Maximum number of images a single user id may process per calendar day.
MAX_DAILY_IMAGES = 5
def check_and_update_quota(user_id: str) -> bool:
    """Try to consume one unit of the user's daily image quota.

    The user's record in ``user_quotas`` is created on first use. If the
    record belongs to a day other than today, its counter is reset first.

    Args:
        user_id: Key identifying the caller in ``user_quotas``.

    Returns:
        True when a unit was available (the counter has been incremented),
        False when the user has already used ``MAX_DAILY_IMAGES`` today.
    """
    current_day = datetime.now().date()
    record = user_quotas[user_id]

    # A stored date different from today means the counter is stale.
    if record["date"] != current_day:
        record.update(count=0, date=current_day)

    has_allowance = record["count"] < MAX_DAILY_IMAGES
    if has_allowance:
        record["count"] += 1
    return has_allowance
def preprocess_image(image: Image.Image, target_size=(MODEL_WIDTH, MODEL_HEIGHT)):
    """Turn a PIL image into the float tensor fed to the ONNX session.

    The image is converted to RGB if necessary, resized to ``target_size``
    with Lanczos resampling (aspect ratio is not preserved), scaled to the
    range [0, 1] as float32, moved to channels-first layout and given a
    leading batch axis of length 1.

    NOTE(review): the pixel range used here is [0, 1]. Some MODNet exports
    expect inputs normalized to [-1, 1] instead - confirm against the
    model file that is actually deployed.

    Args:
        image: Source image in any PIL mode convertible to RGB.
        target_size: ``(width, height)`` to resize to.

    Returns:
        A tuple ``(tensor, (orig_width, orig_height))`` where ``tensor`` is
        the batched channels-first float32 array and the second element is
        the size of the input image before resizing.
    """
    rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
    orig_width, orig_height = rgb_image.size

    resized = rgb_image.resize(target_size, Image.LANCZOS)
    pixels = np.array(resized).astype(np.float32) / 255.0

    # (H, W, C) -> (C, H, W), then prepend the batch axis -> (1, C, H, W).
    tensor = np.transpose(pixels, (2, 0, 1))[np.newaxis, ...]
    return tensor, (orig_width, orig_height)
def postprocess_mask(mask: np.ndarray, original_size):
    """Convert the raw model output into a binary 8-bit alpha mask.

    Args:
        mask: Model output; only element ``[0, 0]`` (first batch item, first
            channel) is used. Values are presumably in [0, 1] - the code
            multiplies by 255 and casts to uint8 without clipping.
        original_size: ``(width, height)`` the mask is resized to, passed to
            ``cv2.resize`` as its destination size.

    Returns:
        A 2-D uint8 array containing only 0 and 255: pixels whose resized
        matte value exceeds 127 become 255, everything else becomes 0.
    """
    matte = mask[0, 0]
    matte_u8 = np.round(matte * 255).astype(np.uint8)
    resized = cv2.resize(matte_u8, original_size, interpolation=cv2.INTER_LINEAR)
    # Hard threshold: the soft matte is deliberately collapsed to on/off alpha.
    return (resized > 127).astype(np.uint8) * 255
def remove_background(image: Image.Image):
    """Return an RGBA copy of ``image`` whose alpha channel is the model mask.

    The image is preprocessed, run through the module-level ONNX session,
    and the first output is post-processed into a mask at the original
    resolution. That mask replaces the alpha channel of the RGBA-converted
    input.

    Args:
        image: Source PIL image.

    Returns:
        A new ``PIL.Image.Image`` in RGBA mode with the same size as the input.
    """
    model_input, original_size = preprocess_image(image)

    # The session is fed through its first declared input, looked up by name.
    feed = {onnx_session.get_inputs()[0].name: model_input}
    outputs = onnx_session.run(None, feed)
    alpha = postprocess_mask(outputs[0], original_size)

    rgba = np.array(image.convert('RGBA'))
    rgba[..., 3] = alpha
    return Image.fromarray(rgba, 'RGBA')
@app.get("/")
async def root():
    """Health-check endpoint: report a static status, service name and version."""
    return dict(status="healthy", service="MODNet API", version="1.0.0")
@app.get("/quota/{user_id}")
async def get_quota(user_id: str):
    """Report how much of today's image quota a user has consumed.

    This endpoint is read-only: it never creates or modifies an entry in
    ``user_quotas``. Unknown users, and users whose stored counter belongs
    to an earlier day, are reported as having used nothing today (the stale
    counter itself is reset the next time quota is actually consumed).

    Args:
        user_id: Identifier taken from the URL path.

    Returns:
        A dict with ``user_id``, ``used``, ``remaining`` (never negative),
        ``limit`` and ``resets_at`` (tomorrow's date as an ISO string).
    """
    today = datetime.now().date()

    # Use .get() rather than user_quotas[user_id]: subscripting a defaultdict
    # inserts a record for every id that is merely queried, letting anyone
    # grow the in-memory store by probing arbitrary ids.
    record = user_quotas.get(user_id)
    if record is not None and record["date"] == today:
        used = record["count"]
    else:
        used = 0

    return {
        "user_id": user_id,
        "used": used,
        "remaining": max(0, MAX_DAILY_IMAGES - used),
        "limit": MAX_DAILY_IMAGES,
        "resets_at": str(today + timedelta(days=1)),
    }
@app.post("/remove-background")
async def remove_background_endpoint(
    file: UploadFile = File(...),
    user_id: str = Header(..., alias="X-User-ID"),
):
    """Remove the background from an uploaded image and return it as a PNG.

    The request is validated completely (user id, declared content type,
    decodable image data) before any quota is consumed, so a malformed
    upload does not cost the caller one of their daily images.

    Args:
        file: The uploaded image (multipart form field ``file``).
        user_id: Caller identifier from the ``X-User-ID`` header; must be at
            least 10 characters long.

    Returns:
        A ``image/png`` response containing the RGBA result, with
        ``X-Quota-Used`` and ``X-Quota-Remaining`` headers.

    Raises:
        HTTPException: 400 for an invalid user id, a non-image content type
            or undecodable image data; 429 when the daily quota is
            exhausted; 500 when inference or PNG encoding fails (quota has
            already been consumed at that point).
    """
    if not user_id or len(user_id) < 10:
        raise HTTPException(
            status_code=400,
            detail="Invalid user ID. Please provide a valid device identifier."
        )

    # content_type is None when the client omits the part's Content-Type
    # header; treat that the same as a non-image type instead of crashing.
    if not (file.content_type or "").startswith('image/'):
        raise HTTPException(
            status_code=400,
            detail="Invalid file type. Please upload an image (JPEG or PNG)."
        )

    image_bytes = await file.read()
    try:
        image = Image.open(io.BytesIO(image_bytes))
        # Image.open is lazy; force decoding now so truncated or corrupt
        # data is reported as a client error rather than failing mid-inference.
        image.load()
    except (OSError, ValueError, SyntaxError, EOFError,
            Image.DecompressionBombError) as exc:
        raise HTTPException(
            status_code=400,
            detail="Invalid image data. Please upload a valid image (JPEG or PNG)."
        ) from exc

    # Charge quota only once the upload is known to be a usable image.
    if not check_and_update_quota(user_id):
        raise HTTPException(
            status_code=429,
            detail=f"Daily quota exceeded. You can process {MAX_DAILY_IMAGES} images per day. Try again tomorrow!"
        )

    try:
        # NOTE(review): inference is CPU-bound and runs directly inside this
        # coroutine, so other requests wait while it executes. Consider
        # offloading it to a worker thread if concurrency matters.
        result_image = remove_background(image)
        output_buffer = io.BytesIO()
        result_image.save(output_buffer, format='PNG')
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error processing image: {str(e)}"
        ) from e

    quota_used = user_quotas[user_id]["count"]
    return Response(
        content=output_buffer.getvalue(),
        media_type="image/png",
        headers={
            "X-Quota-Used": str(quota_used),
            "X-Quota-Remaining": str(MAX_DAILY_IMAGES - quota_used),
        },
    )
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 7860.
    # uvicorn is imported here so it is only required when run as a script.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
|