File size: 4,723 Bytes
6c1cc1e
 
 
 
 
 
 
 
 
 
 
5c59863
6c1cc1e
 
 
5c59863
6c1cc1e
 
 
 
 
5c59863
 
6c1cc1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c59863
 
 
6c1cc1e
 
 
5c59863
6c1cc1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c59863
6c1cc1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c59863
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from fastapi import FastAPI, File, UploadFile, HTTPException, Header
from fastapi.responses import Response
from fastapi.middleware.cors import CORSMiddleware
import onnxruntime as ort
import numpy as np
from PIL import Image
import cv2
import io
from datetime import datetime, timedelta
from collections import defaultdict

# ASGI application object; every route in this module is registered on it.
app = FastAPI(title="MODNet API", version="1.0.0")

# CORS is configured fully open: any origin, any method and any header.
# NOTE(review): FastAPI's CORS documentation states that allow_origins,
# allow_methods and allow_headers should not be ["*"] when
# allow_credentials=True. The endpoints here identify callers through the
# X-User-ID header rather than cookies, so credentials may not be needed at
# all - confirm with the client teams before tightening.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Relative path to the ONNX model file (expected in the root folder, i.e. it
# is resolved against the process's working directory).
MODEL_PATH = "modnet.onnx"
# Width/height (in pixels) every input image is resized to before inference.
MODEL_WIDTH = 512
MODEL_HEIGHT = 512

# The model is loaded once, at import time, and the session is shared by all
# requests. Only the CPU execution provider is requested.
print("🔄 Loading MODNet ONNX model...")
onnx_session = ort.InferenceSession(MODEL_PATH, providers=['CPUExecutionProvider'])
print("✅ MODNet model loaded successfully!")

# In-memory quota store: user id -> {"count": images used, "date": day the
# count belongs to}. A missing key is created on first access with count 0 and
# the current local date. Being a plain in-process dict, its contents do not
# survive a restart and entries are never evicted.
user_quotas = defaultdict(lambda: {"count": 0, "date": datetime.now().date()})
# Maximum number of images a single user id may process per calendar day.
MAX_DAILY_IMAGES = 5

def check_and_update_quota(user_id: str) -> bool:
    """Try to consume one unit of ``user_id``'s daily quota.

    The user's record in ``user_quotas`` is created on first access. If the
    record is stamped with a day other than today (local time), its counter
    is reset before the check.

    Args:
        user_id: Key identifying the caller in ``user_quotas``.

    Returns:
        True when a unit was available and has been counted; False when the
        user has already reached ``MAX_DAILY_IMAGES`` for today (the counter
        is left unchanged in that case).
    """
    current_day = datetime.now().date()
    record = user_quotas[user_id]

    # A record from another day is stale: start a fresh count for today.
    if record["date"] != current_day:
        record.update(count=0, date=current_day)

    has_capacity = record["count"] < MAX_DAILY_IMAGES
    if has_capacity:
        record["count"] += 1
    return has_capacity

def preprocess_image(image: Image.Image, target_size=(MODEL_WIDTH, MODEL_HEIGHT)):
    """Turn a PIL image into the float32 batch tensor fed to the model.

    Args:
        image: Source image; converted to RGB when it is in any other mode.
        target_size: ``(width, height)`` passed to ``Image.resize``.

    Returns:
        A ``(tensor, original_size)`` pair where ``tensor`` is a float32
        array of shape ``(1, 3, height, width)`` with values scaled to
        ``[0, 1]``, and ``original_size`` is the input's ``(width, height)``.
    """
    rgb = image if image.mode == 'RGB' else image.convert('RGB')
    original_size = rgb.size  # PIL reports size as (width, height)

    resized = rgb.resize(target_size, Image.LANCZOS)

    # uint8 HWC pixels -> float32 in [0, 1].
    # NOTE(review): the reference MODNet ONNX inference script appears to
    # normalise to [-1, 1] ((x - 127.5) / 127.5) - confirm which range the
    # bundled model was exported for.
    pixels = np.array(resized).astype(np.float32) / 255.0

    # HWC -> CHW, then prepend the batch axis.
    batch = pixels.transpose(2, 0, 1)[np.newaxis, ...]
    return batch, original_size

def postprocess_mask(mask: np.ndarray, original_size):
    """Convert the raw model output into a binary alpha mask.

    Args:
        mask: Model output; element ``[0, 0]`` is taken as the 2-D matte.
            Presumably shaped ``(1, 1, H, W)`` with values in ``[0, 1]`` -
            TODO confirm against the exported model.
        original_size: ``(width, height)`` to resize the matte to (passed as
            ``dsize`` to ``cv2.resize``).

    Returns:
        A uint8 array at the requested size containing only 0 and 255.
    """
    matte = np.rint(mask[0, 0] * 255).astype(np.uint8)
    resized = cv2.resize(matte, original_size, interpolation=cv2.INTER_LINEAR)
    # Hard threshold: anything above mid-grey becomes fully opaque, the rest
    # fully transparent.
    return (resized > 127).astype(np.uint8) * 255

def remove_background(image: Image.Image):
    """Run the matting model on ``image`` and return it with a new alpha.

    The image is preprocessed, passed through the shared ``onnx_session``,
    and the resulting mask replaces the alpha channel of an RGBA copy of the
    original (full-resolution) image.

    Args:
        image: Source PIL image.

    Returns:
        A new RGBA ``PIL.Image`` whose alpha channel is the predicted mask.
    """
    model_input, size_wh = preprocess_image(image)

    # The session's first declared input receives the tensor.
    feed = {onnx_session.get_inputs()[0].name: model_input}
    outputs = onnx_session.run(None, feed)
    alpha = postprocess_mask(outputs[0], size_wh)

    rgba = np.array(image.convert('RGBA'))
    rgba[..., 3] = alpha  # overwrite the alpha channel only
    return Image.fromarray(rgba, 'RGBA')

@app.get("/")
async def root():
    """Health-check endpoint: report a static service status payload."""
    payload = {
        "status": "healthy",
        "service": "MODNet API",
        "version": "1.0.0",
    }
    return payload

@app.get("/quota/{user_id}")
async def get_quota(user_id: str):
    """Report how much of today's quota ``user_id`` has used.

    This is a read-only view: it never creates or modifies quota records.

    Args:
        user_id: Identifier taken from the URL path.

    Returns:
        A dict with the user id, the number of images used today, the number
        remaining (never negative), the daily limit, and the date (as a
        string) on which the quota next resets.
    """
    today = datetime.now().date()

    # Use .get() rather than indexing: indexing a defaultdict inserts a record
    # for every id ever queried, letting arbitrary GET requests grow the
    # in-memory store without bound.
    user_data = user_quotas.get(user_id)

    # No record, or a record stamped with a previous day, both mean nothing
    # has been used today. A stale record is reset lazily by the next
    # quota-consuming call, so it does not need to be rewritten here.
    if user_data is None or user_data["date"] != today:
        used = 0
    else:
        used = user_data["count"]

    return {
        "user_id": user_id,
        "used": used,
        "remaining": max(0, MAX_DAILY_IMAGES - used),
        "limit": MAX_DAILY_IMAGES,
        "resets_at": str(today + timedelta(days=1))
    }

@app.post("/remove-background")
async def remove_background_endpoint(
    file: UploadFile = File(...),
    user_id: str = Header(..., alias="X-User-ID")
):
    """Remove the background from an uploaded image and return a PNG.

    Args:
        file: The uploaded image (multipart form field ``file``).
        user_id: Caller identifier from the ``X-User-ID`` header; must be at
            least 10 characters long.

    Returns:
        A ``image/png`` response containing the RGBA result, with
        ``X-Quota-Used`` and ``X-Quota-Remaining`` headers reflecting the
        caller's quota after this request.

    Raises:
        HTTPException: 400 for an invalid user id or a non-image upload,
            429 when the daily quota is exhausted, 500 when decoding or
            processing the image fails.
    """
    if not user_id or len(user_id) < 10:
        raise HTTPException(
            status_code=400,
            detail="Invalid user ID. Please provide a valid device identifier."
        )

    # Validate the upload BEFORE charging quota so a rejected request does not
    # cost the user one of their daily images. content_type can be None when
    # the client omits it, so guard against that instead of crashing on
    # None.startswith().
    if not file.content_type or not file.content_type.startswith('image/'):
        raise HTTPException(
            status_code=400,
            detail="Invalid file type. Please upload an image (JPEG or PNG)."
        )

    if not check_and_update_quota(user_id):
        raise HTTPException(
            status_code=429,
            detail=f"Daily quota exceeded. You can process {MAX_DAILY_IMAGES} images per day. Try again tomorrow!"
        )

    # Only the decoding/inference/encoding steps are wrapped; the response is
    # built outside the try so its construction cannot be misreported as an
    # image-processing failure.
    try:
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes))
        result_image = remove_background(image)
        output_buffer = io.BytesIO()
        result_image.save(output_buffer, format='PNG')
    except Exception as e:
        # Top-level boundary: convert any failure into a 500 while keeping the
        # original exception chained for server-side tracebacks.
        raise HTTPException(
            status_code=500,
            detail=f"Error processing image: {str(e)}"
        ) from e

    used = user_quotas[user_id]["count"]
    return Response(
        content=output_buffer.getvalue(),
        media_type="image/png",
        headers={
            "X-Quota-Used": str(used),
            "X-Quota-Remaining": str(MAX_DAILY_IMAGES - used)
        }
    )

# Script entry point: when this file is executed directly, serve the app with
# uvicorn, listening on all network interfaces at port 7860.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)