videopix commited on
Commit
a1dd89f
·
verified ·
1 Parent(s): 490a080

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -232
app.py CHANGED
@@ -1,27 +1,34 @@
1
  import os
2
- from io import BytesIO
3
-
4
- import numpy as np
5
- import requests
6
  import torch
7
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Query
 
8
  from fastapi.responses import StreamingResponse, HTMLResponse
9
- from fastapi.middleware.cors import CORSMiddleware
10
  from PIL import Image
 
 
11
  from transformers import AutoModelForImageSegmentation
 
12
 
13
  # ---------------------------------------------------------
14
- # Optional HEIC Support
15
  # ---------------------------------------------------------
16
  try:
17
  import pillow_heif
18
  pillow_heif.register_heif_opener()
19
- print("HEIC/HEIF supported")
20
- except:
21
- print("Install pillow-heif for HEIC support")
22
 
23
  # ---------------------------------------------------------
24
- # Load Model
 
 
 
 
 
 
 
25
  # ---------------------------------------------------------
26
  MODEL_DIR = "models/BiRefNet"
27
  os.makedirs(MODEL_DIR, exist_ok=True)
@@ -29,7 +36,7 @@ os.makedirs(MODEL_DIR, exist_ok=True)
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
31
 
32
- print("Loading BiRefNet...")
33
  birefnet = AutoModelForImageSegmentation.from_pretrained(
34
  "ZhengPeng7/BiRefNet",
35
  cache_dir=MODEL_DIR,
@@ -37,284 +44,183 @@ birefnet = AutoModelForImageSegmentation.from_pretrained(
37
  revision="main"
38
  )
39
  birefnet.to(device, dtype=dtype).eval()
40
- print("Model Ready")
 
 
 
41
 
42
  # ---------------------------------------------------------
43
- # FastAPI App
44
  # ---------------------------------------------------------
45
  app = FastAPI(title="Background Remover API")
46
 
47
- # Allow requests from any app
48
- app.add_middleware(
49
- CORSMiddleware,
50
- allow_origins=["*"],
51
- allow_methods=["*"],
52
- allow_headers=["*"],
53
- )
54
-
55
  # ---------------------------------------------------------
56
- # Utility Functions
57
  # ---------------------------------------------------------
58
- def load_image_from_url(url: str):
59
  try:
60
- resp = requests.get(url, timeout=10)
61
- resp.raise_for_status()
62
- return Image.open(BytesIO(resp.content)).convert("RGB")
63
  except Exception as e:
64
- raise HTTPException(status_code=400, detail=f"Error loading image URL: {str(e)}")
 
65
 
66
- def transform_image(image: Image.Image, resolution: int):
67
  image = image.resize((resolution, resolution))
68
- img = np.array(image).astype("float32") / 255.0
69
- mean = np.array([0.485, 0.456, 0.406], dtype="float32")
70
- std = np.array([0.229, 0.224, 0.225], dtype="float32")
71
- img = (img - mean) / std
72
- img = img.transpose((2, 0, 1))
73
- return torch.from_numpy(img).unsqueeze(0).to(device=device, dtype=dtype)
74
 
75
- def process_image(image: Image.Image, resolution: int):
76
- original_size = image.size
77
- tensor = transform_image(image, resolution)
78
 
79
- with torch.no_grad():
80
- mask_pred = birefnet(tensor)[-1].sigmoid().cpu()[0, 0]
81
 
82
- mask = Image.fromarray((mask_pred.numpy() * 255).astype("uint8"))
83
- mask = mask.resize(original_size)
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  image = image.convert("RGBA")
86
  image.putalpha(mask)
87
  return image
88
 
89
  # ---------------------------------------------------------
90
- # GET + POST endpoint
91
  # ---------------------------------------------------------
92
- @app.api_route("/remove-background", methods=["GET", "POST"])
93
  async def remove_background(
94
  file: UploadFile = File(None),
95
  image_url: str = Form(None),
96
- resolution: int = Form(512),
97
- get_url: str = Query(None),
98
- get_res: int = Query(512),
99
  ):
100
  try:
101
- # GET mode: /remove-background?get_url=...&get_res=512
102
- if get_url:
103
- img = load_image_from_url(get_url)
104
- resolution = get_res
105
-
106
- # POST mode - file upload
107
- elif file:
108
  data = await file.read()
109
- if not data:
110
- raise HTTPException(status_code=400, detail="Empty file")
111
- img = Image.open(BytesIO(data)).convert("RGB")
112
-
113
- # POST mode - URL
114
  elif image_url:
115
- img = load_image_from_url(image_url)
116
-
117
  else:
118
- raise HTTPException(status_code=400, detail="No image provided")
119
 
120
- result = process_image(img, resolution)
121
 
122
- buffer = BytesIO()
123
- result.save(buffer, format="PNG")
124
- buffer.seek(0)
125
 
126
- return StreamingResponse(
127
- buffer,
128
- media_type="image/png",
129
- headers={"Content-Disposition": "inline; filename=result.png"},
130
- )
131
 
132
  except HTTPException:
133
  raise
134
  except Exception as e:
135
- raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
136
 
137
- # ---------------------------------------------------------
138
- # Favicon (stop 404 logs)
139
- # ---------------------------------------------------------
140
- @app.get("/favicon.ico")
141
- async def favicon():
142
- return HTMLResponse("")
143
 
144
  # ---------------------------------------------------------
145
- # UI for testing (POST method)
146
  # ---------------------------------------------------------
147
  @app.get("/", response_class=HTMLResponse)
148
  async def index():
149
- return """
150
- <!DOCTYPE html>
151
- <html>
152
- <head>
153
- <title>Background Remover API</title>
154
- <meta name="viewport" content="width=device-width, initial-scale=1">
155
-
156
- <style>
157
- body {
158
- font-family: Arial, sans-serif;
159
- padding: 20px;
160
- margin: 0;
161
- background: #f5f5f5;
162
- }
163
- h2 {
164
- text-align: center;
165
- margin-bottom: 25px;
166
- }
167
-
168
- .container {
169
- max-width: 700px;
170
- margin: auto;
171
- background: white;
172
- padding: 25px;
173
- border-radius: 12px;
174
- box-shadow: 0 4px 10px rgba(0,0,0,0.1);
175
- }
176
-
177
- .section {
178
- margin-bottom: 30px;
179
- }
180
-
181
- label {
182
- font-weight: bold;
183
- }
184
-
185
- input[type="file"],
186
- input[type="text"],
187
- input[type="number"] {
188
- width: 100%;
189
- padding: 10px;
190
- margin-top: 6px;
191
- border-radius: 6px;
192
- border: 1px solid #ccc;
193
- }
194
-
195
- button {
196
- padding: 12px 18px;
197
- margin-top: 12px;
198
- width: 100%;
199
- border: none;
200
- background: #007bff;
201
- color: white;
202
- border-radius: 6px;
203
- cursor: pointer;
204
- font-size: 16px;
205
- }
206
-
207
- button:hover {
208
- background: #005dc4;
209
- }
210
-
211
- #resultWrapper {
212
- text-align: center;
213
- margin-top: 20px;
214
- }
215
-
216
- img {
217
- max-width: 100%;
218
- border-radius: 10px;
219
- margin-top: 15px;
220
- }
221
-
222
- /* Responsive fixes */
223
- @media (max-width: 600px) {
224
- .container {
225
- padding: 15px;
226
- }
227
- button {
228
- font-size: 15px;
229
- }
230
- }
231
- </style>
232
-
233
- </head>
234
- <body>
235
-
236
- <h2>Background Remover API Tester</h2>
237
-
238
- <div class="container">
239
-
240
- <!-- Upload Section -->
241
- <div class="section">
242
- <label>Upload Image</label>
243
- <form id="uploadForm" enctype="multipart/form-data">
244
- <input type="file" id="file" name="file" accept="image/*">
245
-
246
- <label>Resolution</label>
247
- <input type="number" id="resFile" value="512" min="64" max="2048">
248
-
249
- <button type="submit">Remove Background</button>
250
  </form>
251
- </div>
252
 
253
- <hr>
254
 
255
- <!-- URL Section -->
256
- <div class="section">
257
- <label>Image URL</label>
258
  <form id="urlForm">
259
- <input type="text" id="imgUrl" placeholder="https://example.com/image.jpg">
260
-
261
- <label>Resolution</label>
262
- <input type="number" id="resUrl" value="512" min="64" max="2048">
263
-
264
- <button type="submit">Remove Background</button>
 
265
  </form>
266
- </div>
267
 
268
- <div id="resultWrapper">
269
- <h3>Result</h3>
270
- <img id="result" />
271
  </div>
272
 
273
- </div>
274
-
275
- <script>
276
- const resultImg = document.getElementById("result");
277
-
278
- // Upload File
279
- document.getElementById("uploadForm").onsubmit = async (e) => {
280
- e.preventDefault();
281
- const file = document.getElementById("file").files[0];
282
- const res = document.getElementById("resFile").value;
283
-
284
- if (!file) return alert("Please select a file");
285
 
286
- const fd = new FormData();
287
- fd.append("file", file);
288
- fd.append("resolution", res);
 
289
 
290
- const r = await fetch("/remove-background", { method: "POST", body: fd });
291
- resultImg.src = URL.createObjectURL(await r.blob());
292
- };
 
293
 
294
- // URL Mode
295
- document.getElementById("urlForm").onsubmit = async (e) => {
296
- e.preventDefault();
297
- const url = document.getElementById("imgUrl").value;
298
- const res = document.getElementById("resUrl").value;
299
 
300
- if (!url.trim()) return alert("Please enter an image URL");
 
 
 
301
 
302
- const fd = new FormData();
303
- fd.append("image_url", url);
304
- fd.append("resolution", res);
 
305
 
306
- const r = await fetch("/remove-background", { method: "POST", body: fd });
307
- resultImg.src = URL.createObjectURL(await r.blob());
308
- };
309
- </script>
 
 
 
 
 
310
 
311
- </body>
312
- </html>
313
- """
314
 
315
  # ---------------------------------------------------------
316
- # Run App
317
  # ---------------------------------------------------------
318
  if __name__ == "__main__":
319
- import uvicorn
320
- uvicorn.run("app:app", host="0.0.0.0", port=7860)
 
1
  import os
2
+ import threading
 
 
 
3
  import torch
4
+ import numpy as np
5
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
6
  from fastapi.responses import StreamingResponse, HTMLResponse
 
7
  from PIL import Image
8
+ from io import BytesIO
9
+ import requests
10
  from transformers import AutoModelForImageSegmentation
11
+ import uvicorn
12
 
13
  # ---------------------------------------------------------
14
+ # Optional HEIC/HEIF support
15
  # ---------------------------------------------------------
16
  try:
17
  import pillow_heif
18
  pillow_heif.register_heif_opener()
19
+ print("HEIC/HEIF supported.")
20
+ except ImportError:
21
+ print("Install pillow-heif for HEIC support.")
22
 
23
  # ---------------------------------------------------------
24
+ # Performance settings (especially for CPU on HF Spaces)
25
+ # ---------------------------------------------------------
26
+ os.environ["OMP_NUM_THREADS"] = "1"
27
+ os.environ["MKL_NUM_THREADS"] = "1"
28
+ torch.set_num_threads(1)
29
+
30
+ # ---------------------------------------------------------
31
+ # Load model
32
  # ---------------------------------------------------------
33
  MODEL_DIR = "models/BiRefNet"
34
  os.makedirs(MODEL_DIR, exist_ok=True)
 
36
  device = "cuda" if torch.cuda.is_available() else "cpu"
37
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
38
 
39
+ print("Loading BiRefNet model...")
40
  birefnet = AutoModelForImageSegmentation.from_pretrained(
41
  "ZhengPeng7/BiRefNet",
42
  cache_dir=MODEL_DIR,
 
44
  revision="main"
45
  )
46
  birefnet.to(device, dtype=dtype).eval()
47
+ print("Model loaded.")
48
+
49
+ # Global lock to protect the model during inference
50
+ inference_lock = threading.Lock()
51
 
52
  # ---------------------------------------------------------
53
+ # FastAPI app
54
  # ---------------------------------------------------------
55
  app = FastAPI(title="Background Remover API")
56
 
 
 
 
 
 
 
 
 
57
  # ---------------------------------------------------------
58
+ # Helper functions
59
  # ---------------------------------------------------------
60
+ def load_image_from_url(url: str) -> Image.Image:
61
  try:
62
+ r = requests.get(url, timeout=10)
63
+ r.raise_for_status()
64
+ return Image.open(BytesIO(r.content)).convert("RGB")
65
  except Exception as e:
66
+ raise HTTPException(status_code=400, detail=f"Cannot load image from URL: {str(e)}")
67
+
68
 
69
+ def transform_image(image: Image.Image, resolution: int = 512) -> torch.Tensor:
70
  image = image.resize((resolution, resolution))
71
+ arr = np.array(image).astype(np.float32) / 255.0
72
+ mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
73
+ std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
74
+
75
+ arr = (arr - mean) / std
76
+ arr = np.transpose(arr, (2, 0, 1))
77
 
78
+ tensor = torch.from_numpy(arr).unsqueeze(0).to(device=device, dtype=dtype)
79
+ return tensor
 
80
 
 
 
81
 
82
+ def run_inference(image: Image.Image, resolution: int = 512) -> Image.Image:
83
+ orig_size = image.size
84
+ input_tensor = transform_image(image, resolution)
85
+
86
+ with inference_lock: # prevents CPU thread crashes
87
+ with torch.no_grad():
88
+ try:
89
+ preds = birefnet(input_tensor)[-1].sigmoid().cpu()
90
+ except Exception as e:
91
+ raise RuntimeError(f"Inference error: {str(e)}")
92
+
93
+ pred = preds[0, 0]
94
+ mask = Image.fromarray((pred.numpy() * 255).astype(np.uint8)).resize(orig_size)
95
 
96
  image = image.convert("RGBA")
97
  image.putalpha(mask)
98
  return image
99
 
100
  # ---------------------------------------------------------
101
+ # /remove-background endpoint
102
  # ---------------------------------------------------------
103
+ @app.post("/remove-background")
104
  async def remove_background(
105
  file: UploadFile = File(None),
106
  image_url: str = Form(None),
107
+ resolution: int = Form(512)
 
 
108
  ):
109
  try:
110
+ if file:
 
 
 
 
 
 
111
  data = await file.read()
112
+ image = Image.open(BytesIO(data)).convert("RGB")
 
 
 
 
113
  elif image_url:
114
+ image = load_image_from_url(image_url)
 
115
  else:
116
+ raise HTTPException(status_code=400, detail="Provide either file or image_url.")
117
 
118
+ result = run_inference(image, resolution)
119
 
120
+ buf = BytesIO()
121
+ result.save(buf, format="PNG")
122
+ buf.seek(0)
123
 
124
+ return StreamingResponse(buf, media_type="image/png")
 
 
 
 
125
 
126
  except HTTPException:
127
  raise
128
  except Exception as e:
129
+ raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
130
 
 
 
 
 
 
 
131
 
132
  # ---------------------------------------------------------
133
+ # Test UI
134
  # ---------------------------------------------------------
135
  @app.get("/", response_class=HTMLResponse)
136
  async def index():
137
+ html = """
138
+ <!DOCTYPE html>
139
+ <html>
140
+ <head>
141
+ <meta charset="UTF-8">
142
+ <title>Background Remover</title>
143
+ <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet">
144
+ <style>
145
+ body { background: #f8f9fa; padding-top: 40px; }
146
+ .container { max-width: 700px; }
147
+ img { max-width: 100%; margin-top: 20px; border-radius: 10px; }
148
+ </style>
149
+ </head>
150
+
151
+ <body>
152
+ <div class="container text-center">
153
+ <h2 class="mb-4">Background Remover API</h2>
154
+
155
+ <form id="uploadForm" class="mb-4" enctype="multipart/form-data">
156
+ <div class="mb-3">
157
+ <input class="form-control" type="file" id="fileInput" name="file" accept="image/*">
158
+ </div>
159
+ <div class="mb-3">
160
+ <input class="form-control" type="number" id="resInput" name="resolution" value="512" min="64" max="2048">
161
+ </div>
162
+ <button class="btn btn-primary">Upload</button>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  </form>
 
164
 
165
+ <div class="mb-4">OR</div>
166
 
 
 
 
167
  <form id="urlForm">
168
+ <div class="mb-3">
169
+ <input class="form-control" type="text" id="urlInput" placeholder="https://example.com/img.jpg">
170
+ </div>
171
+ <div class="mb-3">
172
+ <input class="form-control" type="number" id="urlResInput" value="512" min="64" max="2048">
173
+ </div>
174
+ <button class="btn btn-success">Use URL</button>
175
  </form>
 
176
 
177
+ <h5 class="mt-4">Result:</h5>
178
+ <img id="resultImg" />
 
179
  </div>
180
 
181
+ <script>
182
+ const uploadForm = document.getElementById("uploadForm");
183
+ const urlForm = document.getElementById("urlForm");
184
+ const resultImg = document.getElementById("resultImg");
 
 
 
 
 
 
 
 
185
 
186
+ uploadForm.addEventListener("submit", async e => {
187
+ e.preventDefault();
188
+ const file = document.getElementById("fileInput").files[0];
189
+ if (!file) return alert("Choose an image");
190
 
191
+ const res = document.getElementById("resInput").value;
192
+ const form = new FormData();
193
+ form.append("file", file);
194
+ form.append("resolution", res);
195
 
196
+ const r = await fetch("/remove-background", { method: "POST", body: form });
197
+ const blob = await r.blob();
198
+ resultImg.src = URL.createObjectURL(blob);
199
+ });
 
200
 
201
+ urlForm.addEventListener("submit", async e => {
202
+ e.preventDefault();
203
+ const url = document.getElementById("urlInput").value.trim();
204
+ if (!url) return alert("Enter URL");
205
 
206
+ const res = document.getElementById("urlResInput").value;
207
+ const form = new FormData();
208
+ form.append("image_url", url);
209
+ form.append("resolution", res);
210
 
211
+ const r = await fetch("/remove-background", { method: "POST", body: form });
212
+ const blob = await r.blob();
213
+ resultImg.src = URL.createObjectURL(blob);
214
+ });
215
+ </script>
216
+ </body>
217
+ </html>
218
+ """
219
+ return HTMLResponse(html)
220
 
 
 
 
221
 
222
  # ---------------------------------------------------------
223
+ # Run server
224
  # ---------------------------------------------------------
225
  if __name__ == "__main__":
226
+ uvicorn.run(app, host="0.0.0.0", port=7860)