Mr7Explorer commited on
Commit
6c1cc1e
·
verified ·
1 Parent(s): ace8d73

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -0
app.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Header
2
+ from fastapi.responses import Response
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ import onnxruntime as ort
5
+ import numpy as np
6
+ from PIL import Image
7
+ import cv2
8
+ import io
9
+ from datetime import datetime, timedelta
10
+ from collections import defaultdict
11
+ import os
12
+
13
+ app = FastAPI(title="Backdrop Studio API", version="2.0.0")
14
+
15
+ app.add_middleware(
16
+ CORSMiddleware,
17
+ allow_origins=["*"], # Restrict for production!
18
+ allow_credentials=True,
19
+ allow_methods=["*"],
20
+ allow_headers=["*"],
21
+ )
22
+
23
+ MODEL_PATH = "models/modnet.onnx"
24
+ MODEL_WIDTH = 512 # Official MODNet ONNX input size (for best speed, use 512; 768 or 1024 for higher res if model supports)
25
+ MODEL_HEIGHT = 512
26
+
27
+ print("🔄 Loading MODNet ONNX model...")
28
+ onnx_session = ort.InferenceSession(MODEL_PATH, providers=['CPUExecutionProvider'])
29
+ print("✅ MODNet model loaded successfully!")
30
+
31
+ user_quotas = defaultdict(lambda: {"count": 0, "date": datetime.now().date()})
32
+ MAX_DAILY_IMAGES = 5
33
+
34
+ def check_and_update_quota(user_id: str) -> bool:
35
+ today = datetime.now().date()
36
+ user_data = user_quotas[user_id]
37
+ if user_data["date"] != today:
38
+ user_data["count"] = 0
39
+ user_data["date"] = today
40
+ if user_data["count"] >= MAX_DAILY_IMAGES:
41
+ return False
42
+ user_data["count"] += 1
43
+ return True
44
+
45
+ def preprocess_image(image: Image.Image, target_size=(MODEL_WIDTH, MODEL_HEIGHT)):
46
+ if image.mode != 'RGB':
47
+ image = image.convert('RGB')
48
+ orig_width, orig_height = image.size
49
+ image_resized = image.resize(target_size, Image.LANCZOS)
50
+ img_array = np.array(image_resized).astype(np.float32) / 255.0 # shape (512, 512, 3)
51
+ img_array = np.transpose(img_array, (2, 0, 1)) # (3, 512, 512)
52
+ img_array = np.expand_dims(img_array, axis=0) # (1, 3, 512, 512)
53
+ return img_array, (orig_width, orig_height)
54
+
55
+ def postprocess_mask(mask: np.ndarray, original_size):
56
+ # MODNet returns (1,1,H,W) float in [0,1]
57
+ mask = mask[0, 0] # (H,W)
58
+ mask = (mask * 255).round().astype(np.uint8)
59
+ mask = cv2.resize(mask, original_size, interpolation=cv2.INTER_LINEAR)
60
+ # Optional: Apply threshold to get crisp mask
61
+ mask = np.where(mask > 127, 255, 0).astype(np.uint8)
62
+ return mask
63
+
64
+ def remove_background(image: Image.Image):
65
+ input_array, original_size = preprocess_image(image)
66
+ input_name = onnx_session.get_inputs()[0].name
67
+ output = onnx_session.run(None, {input_name: input_array})
68
+ mask = postprocess_mask(output[0], original_size)
69
+ image_array = np.array(image.convert('RGBA'))
70
+ image_array[:, :, 3] = mask
71
+ result_image = Image.fromarray(image_array, 'RGBA')
72
+ return result_image
73
+
74
+ @app.get("/")
75
+ async def root():
76
+ return {"status": "healthy", "service": "Backdrop Studio MODNet API", "version": "2.0.0"}
77
+
78
+ @app.get("/quota/{user_id}")
79
+ async def get_quota(user_id: str):
80
+ today = datetime.now().date()
81
+ user_data = user_quotas[user_id]
82
+ if user_data["date"] != today:
83
+ user_data["count"] = 0
84
+ user_data["date"] = today
85
+ remaining = MAX_DAILY_IMAGES - user_data["count"]
86
+ return {
87
+ "user_id": user_id,
88
+ "used": user_data["count"],
89
+ "remaining": max(0, remaining),
90
+ "limit": MAX_DAILY_IMAGES,
91
+ "resets_at": str(today + timedelta(days=1))
92
+ }
93
+
94
+ @app.post("/remove-background")
95
+ async def remove_background_endpoint(
96
+ file: UploadFile = File(...),
97
+ user_id: str = Header(..., alias="X-User-ID")
98
+ ):
99
+ if not user_id or len(user_id) < 10:
100
+ raise HTTPException(
101
+ status_code=400,
102
+ detail="Invalid user ID. Please provide a valid device identifier."
103
+ )
104
+ if not check_and_update_quota(user_id):
105
+ raise HTTPException(
106
+ status_code=429,
107
+ detail=f"Daily quota exceeded. You can process {MAX_DAILY_IMAGES} images per day. Try again tomorrow!"
108
+ )
109
+ if not file.content_type.startswith('image/'):
110
+ raise HTTPException(
111
+ status_code=400,
112
+ detail="Invalid file type. Please upload an image (JPEG or PNG)."
113
+ )
114
+ try:
115
+ image_bytes = await file.read()
116
+ image = Image.open(io.BytesIO(image_bytes))
117
+ result_image = remove_background(image)
118
+ output_buffer = io.BytesIO()
119
+ result_image.save(output_buffer, format='PNG')
120
+ output_buffer.seek(0)
121
+ return Response(
122
+ content=output_buffer.getvalue(),
123
+ media_type="image/png",
124
+ headers={
125
+ "X-Quota-Used": str(user_quotas[user_id]["count"]),
126
+ "X-Quota-Remaining": str(MAX_DAILY_IMAGES - user_quotas[user_id]["count"])
127
+ }
128
+ )
129
+ except Exception as e:
130
+ raise HTTPException(
131
+ status_code=500,
132
+ detail=f"Error processing image: {str(e)}"
133
+ )
134
+ finally:
135
+ if 'image_bytes' in locals(): del image_bytes
136
+ if 'image' in locals(): del image
137
+
138
+ if __name__ == "__main__":
139
+ import uvicorn
140
+ port = int(os.environ.get("PORT", 8080))
141
+ uvicorn.run(app, host="0.0.0.0", port=port)