mobisoft commited on
Commit
777f1e3
·
verified ·
1 Parent(s): 49f1016

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -70
app.py CHANGED
@@ -1,9 +1,8 @@
1
  import os
2
- import threading
3
  import torch
4
  import numpy as np
5
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
6
- from fastapi.responses import StreamingResponse, HTMLResponse, RedirectResponse, JSONResponse
7
  from PIL import Image
8
  from io import BytesIO
9
  import requests
@@ -11,7 +10,7 @@ from transformers import AutoModelForImageSegmentation
11
  import uvicorn
12
 
13
  # ---------------------------------------------------------
14
- # Optional HEIC/HEIF
15
  # ---------------------------------------------------------
16
  try:
17
  import pillow_heif
@@ -20,41 +19,47 @@ except ImportError:
20
  pass
21
 
22
  # ---------------------------------------------------------
23
- # Performance settings for HF CPU
24
  # ---------------------------------------------------------
25
- os.environ["OMP_NUM_THREADS"] = "1"
26
- os.environ["MKL_NUM_THREADS"] = "1"
27
- torch.set_num_threads(1)
 
28
 
29
  # ---------------------------------------------------------
30
- # Constants
31
  # ---------------------------------------------------------
32
- TARGET_SIZE = (512, 512) # Faster inference
33
- MAX_SIDE = 3000 # Auto-downscale for huge uploads
34
 
35
  # ---------------------------------------------------------
36
- # Load model
37
  # ---------------------------------------------------------
38
  MODEL_DIR = "models/BiRefNet"
39
  os.makedirs(MODEL_DIR, exist_ok=True)
40
 
41
- device = "cuda" if torch.cuda.is_available() else "cpu"
42
- dtype = torch.float16 if torch.cuda.is_available() else torch.float32
43
-
44
- print("Loading BiRefNet…")
45
- birefnet = AutoModelForImageSegmentation.from_pretrained(
46
  "ZhengPeng7/BiRefNet",
47
  cache_dir=MODEL_DIR,
48
- trust_remote_code=True,
49
- revision="main",
50
  )
51
- birefnet.to(device, dtype=dtype).eval()
52
  print("Model ready.")
53
 
54
- lock = threading.Lock()
 
 
 
 
 
 
 
 
 
55
 
56
  # ---------------------------------------------------------
57
- # Helpers
58
  # ---------------------------------------------------------
59
  def load_image_from_url(url: str) -> Image.Image:
60
  try:
@@ -62,7 +67,7 @@ def load_image_from_url(url: str) -> Image.Image:
62
  r.raise_for_status()
63
  return Image.open(BytesIO(r.content)).convert("RGB")
64
  except Exception:
65
- raise HTTPException(status_code=400, detail="Invalid image URL")
66
 
67
 
68
  def auto_downscale(img: Image.Image) -> Image.Image:
@@ -71,35 +76,27 @@ def auto_downscale(img: Image.Image) -> Image.Image:
71
  return img
72
 
73
  scale = MAX_SIDE / max(w, h)
74
- new_w = int(w * scale)
75
- new_h = int(h * scale)
76
-
77
- print(f"[INFO] Downscaling {w}×{h} → {new_w}×{new_h}")
78
- return img.resize((new_w, new_h), Image.LANCZOS)
79
 
80
 
81
- def transform(img: Image.Image) -> torch.Tensor:
82
  img = img.resize(TARGET_SIZE)
83
 
84
- arr = np.array(img).astype(np.float32) / 255.0
85
- mean = np.array([0.485, 0.456, 0.406])
86
- std = np.array([0.229, 0.224, 0.225])
87
- arr = (arr - mean) / std
88
- arr = np.transpose(arr, (2, 0, 1))
89
 
90
- t = torch.from_numpy(arr).unsqueeze(0).to(device=device, dtype=dtype)
91
- return t
92
 
93
 
94
  def run_inference(img: Image.Image) -> Image.Image:
95
  orig_size = img.size
96
  tensor = transform(img)
97
 
98
- with lock:
99
- with torch.no_grad():
100
- pred = birefnet(tensor)[-1].sigmoid().cpu()[0, 0]
101
 
102
- mask = Image.fromarray((pred.numpy() * 255).astype(np.uint8)).resize(orig_size)
103
 
104
  img = img.convert("RGBA")
105
  img.putalpha(mask)
@@ -107,22 +104,22 @@ def run_inference(img: Image.Image) -> Image.Image:
107
 
108
 
109
  # ---------------------------------------------------------
110
- # FastAPI app
111
  # ---------------------------------------------------------
112
  app = FastAPI(title="Background Remover API")
113
 
114
  # ---------------------------------------------------------
115
- # Redirect GET → POST logic
116
  # ---------------------------------------------------------
117
  @app.get("/remove-background")
118
- async def redirect_to_post():
119
  return JSONResponse(
120
- {"detail": "This endpoint only supports POST. Use POST /remove-background"},
121
  status_code=405
122
  )
123
 
124
  # ---------------------------------------------------------
125
- # Main POST endpoint
126
  # ---------------------------------------------------------
127
  @app.post("/remove-background")
128
  async def remove_bg(file: UploadFile = File(None), image_url: str = Form(None)):
@@ -130,101 +127,115 @@ async def remove_bg(file: UploadFile = File(None), image_url: str = Form(None)):
130
  if file:
131
  raw = await file.read()
132
  img = Image.open(BytesIO(raw)).convert("RGB")
 
133
  elif image_url:
134
  img = load_image_from_url(image_url)
 
135
  else:
136
- raise HTTPException(status_code=400, detail="Upload file or image_url required")
137
 
138
  img = auto_downscale(img)
139
  result = run_inference(img)
140
 
141
  buf = BytesIO()
142
- result.save(buf, format="PNG")
143
  buf.seek(0)
144
 
145
  return StreamingResponse(buf, media_type="image/png")
146
 
147
  except Exception as e:
148
- raise HTTPException(status_code=500, detail=str(e))
149
 
150
 
151
  # ---------------------------------------------------------
152
- # UI: Show INPUT + OUTPUT (big preview)
153
  # ---------------------------------------------------------
154
  @app.get("/", response_class=HTMLResponse)
155
- async def ui():
156
  return """
157
  <html>
158
  <head>
159
- <title>Background Remover – Test UI</title>
160
  <link rel='stylesheet'
161
  href='https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css'>
162
  </head>
163
  <body class='bg-light'>
164
  <div class='container py-4 text-center'>
165
 
166
- <h2 class='mb-4'>API Test Panel (POST Only)</h2>
167
 
168
  <div class='row'>
169
  <div class='col-md-6'>
170
- <h5>Input Image</h5>
171
  <img id='inputImg' style='max-width:100%; border-radius:10px;'>
172
  </div>
173
  <div class='col-md-6'>
174
- <h5>Output Image</h5>
175
  <img id='outputImg' style='max-width:100%; border-radius:10px;'>
176
  </div>
177
  </div>
178
 
179
  <hr>
180
 
181
- <h4>Upload Test</h4>
182
- <form id="uploadForm" enctype='multipart/form-data'>
183
  <input type='file' id='fileInput' class='form-control mb-3'>
184
- <button class='btn btn-primary'>Send POST</button>
185
  </form>
186
 
187
  <hr>
188
 
189
- <h4>URL Test</h4>
190
  <form id='urlForm'>
191
- <input id='urlInput' class='form-control mb-3' placeholder='https://example.com/image.jpg'>
192
- <button class='btn btn-success'>Send POST</button>
193
  </form>
 
194
  </div>
195
 
196
  <script>
197
  const inputImg = document.getElementById("inputImg");
198
  const outputImg = document.getElementById("outputImg");
199
 
200
- // FILE TEST
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  document.getElementById("uploadForm").addEventListener("submit", async e => {
202
  e.preventDefault();
203
  const file = document.getElementById("fileInput").files[0];
204
- if (!file) return alert("Select a file first.");
205
 
206
  inputImg.src = URL.createObjectURL(file);
207
 
208
  const fd = new FormData();
209
  fd.append("file", file);
210
 
211
- const r = await fetch("/remove-background", { method:"POST", body:fd });
212
- outputImg.src = URL.createObjectURL(await r.blob());
213
  });
214
 
215
- // URL TEST
216
  document.getElementById("urlForm").addEventListener("submit", async e => {
217
  e.preventDefault();
218
  const url = document.getElementById("urlInput").value.trim();
219
- if (!url) return alert("Enter an image URL first.");
220
 
221
  inputImg.src = url;
222
 
223
  const fd = new FormData();
224
  fd.append("image_url", url);
225
 
226
- const r = await fetch("/remove-background", { method:"POST", body:fd });
227
- outputImg.src = URL.createObjectURL(await r.blob());
228
  });
229
  </script>
230
 
@@ -232,8 +243,9 @@ async def ui():
232
  </html>
233
  """
234
 
 
235
  # ---------------------------------------------------------
236
- # Run app
237
  # ---------------------------------------------------------
238
  if __name__ == "__main__":
239
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import os
 
2
  import torch
3
  import numpy as np
4
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
5
+ from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse
6
  from PIL import Image
7
  from io import BytesIO
8
  import requests
 
10
  import uvicorn
11
 
12
  # ---------------------------------------------------------
13
+ # HEIC/HEIF SUPPORT
14
  # ---------------------------------------------------------
15
  try:
16
  import pillow_heif
 
19
  pass
20
 
21
  # ---------------------------------------------------------
22
+ # CPU PERFORMANCE (FIXED)
23
  # ---------------------------------------------------------
24
+ CPU_THREADS = min(4, os.cpu_count() or 2)
25
+ os.environ["OMP_NUM_THREADS"] = str(CPU_THREADS)
26
+ os.environ["MKL_NUM_THREADS"] = str(CPU_THREADS)
27
+ torch.set_num_threads(CPU_THREADS)
28
 
29
  # ---------------------------------------------------------
30
+ # SETTINGS
31
  # ---------------------------------------------------------
32
+ TARGET_SIZE = (512, 512)
33
+ MAX_SIDE = 2000
34
 
35
  # ---------------------------------------------------------
36
+ # LOAD MODEL
37
  # ---------------------------------------------------------
38
  MODEL_DIR = "models/BiRefNet"
39
  os.makedirs(MODEL_DIR, exist_ok=True)
40
 
41
+ print("Loading model...")
42
+ model = AutoModelForImageSegmentation.from_pretrained(
 
 
 
43
  "ZhengPeng7/BiRefNet",
44
  cache_dir=MODEL_DIR,
45
+ trust_remote_code=True
 
46
  )
47
+ model.eval()
48
  print("Model ready.")
49
 
50
+ # ---------------------------------------------------------
51
+ # WARMUP
52
+ # ---------------------------------------------------------
53
+ def warmup():
54
+ dummy = torch.randn(1, 3, 512, 512)
55
+ with torch.no_grad():
56
+ _ = model(dummy)
57
+
58
+ warmup()
59
+ print("Warmup done.")
60
 
61
  # ---------------------------------------------------------
62
+ # HELPERS
63
  # ---------------------------------------------------------
64
  def load_image_from_url(url: str) -> Image.Image:
65
  try:
 
67
  r.raise_for_status()
68
  return Image.open(BytesIO(r.content)).convert("RGB")
69
  except Exception:
70
+ raise HTTPException(400, "Invalid image URL")
71
 
72
 
73
  def auto_downscale(img: Image.Image) -> Image.Image:
 
76
  return img
77
 
78
  scale = MAX_SIDE / max(w, h)
79
+ return img.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
 
 
 
 
80
 
81
 
82
+ def transform(img: Image.Image):
83
  img = img.resize(TARGET_SIZE)
84
 
85
+ arr = np.asarray(img, dtype=np.float32) / 255.0
86
+ arr = (arr - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
87
+ arr = arr.transpose(2, 0, 1)
 
 
88
 
89
+ return torch.from_numpy(arr).unsqueeze(0)
 
90
 
91
 
92
  def run_inference(img: Image.Image) -> Image.Image:
93
  orig_size = img.size
94
  tensor = transform(img)
95
 
96
+ with torch.no_grad():
97
+ pred = model(tensor)[-1].sigmoid()[0, 0].cpu().numpy()
 
98
 
99
+ mask = Image.fromarray((pred * 255).astype(np.uint8)).resize(orig_size)
100
 
101
  img = img.convert("RGBA")
102
  img.putalpha(mask)
 
104
 
105
 
106
  # ---------------------------------------------------------
107
+ # FASTAPI
108
  # ---------------------------------------------------------
109
  app = FastAPI(title="Background Remover API")
110
 
111
  # ---------------------------------------------------------
112
+ # GET redirect
113
  # ---------------------------------------------------------
114
  @app.get("/remove-background")
115
+ def redirect():
116
  return JSONResponse(
117
+ {"detail": "Use POST /remove-background"},
118
  status_code=405
119
  )
120
 
121
  # ---------------------------------------------------------
122
+ # MAIN ENDPOINT (UNCHANGED)
123
  # ---------------------------------------------------------
124
  @app.post("/remove-background")
125
  async def remove_bg(file: UploadFile = File(None), image_url: str = Form(None)):
 
127
  if file:
128
  raw = await file.read()
129
  img = Image.open(BytesIO(raw)).convert("RGB")
130
+
131
  elif image_url:
132
  img = load_image_from_url(image_url)
133
+
134
  else:
135
+ raise HTTPException(400, "Provide file or image_url")
136
 
137
  img = auto_downscale(img)
138
  result = run_inference(img)
139
 
140
  buf = BytesIO()
141
+ result.save(buf, format="PNG", optimize=True)
142
  buf.seek(0)
143
 
144
  return StreamingResponse(buf, media_type="image/png")
145
 
146
  except Exception as e:
147
+ raise HTTPException(500, str(e))
148
 
149
 
150
  # ---------------------------------------------------------
151
+ # UI (IMPROVED BUT SAME LOGIC)
152
  # ---------------------------------------------------------
153
  @app.get("/", response_class=HTMLResponse)
154
+ def ui():
155
  return """
156
  <html>
157
  <head>
158
+ <title>Background Remover</title>
159
  <link rel='stylesheet'
160
  href='https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css'>
161
  </head>
162
  <body class='bg-light'>
163
  <div class='container py-4 text-center'>
164
 
165
+ <h2 class='mb-4'>Background Remover</h2>
166
 
167
  <div class='row'>
168
  <div class='col-md-6'>
169
+ <h5>Input</h5>
170
  <img id='inputImg' style='max-width:100%; border-radius:10px;'>
171
  </div>
172
  <div class='col-md-6'>
173
+ <h5>Output</h5>
174
  <img id='outputImg' style='max-width:100%; border-radius:10px;'>
175
  </div>
176
  </div>
177
 
178
  <hr>
179
 
180
+ <h4>Upload Image</h4>
181
+ <form id="uploadForm">
182
  <input type='file' id='fileInput' class='form-control mb-3'>
183
+ <button class='btn btn-primary'>Remove Background</button>
184
  </form>
185
 
186
  <hr>
187
 
188
+ <h4>Image URL</h4>
189
  <form id='urlForm'>
190
+ <input id='urlInput' class='form-control mb-3' placeholder='https://image.jpg'>
191
+ <button class='btn btn-success'>Remove Background</button>
192
  </form>
193
+
194
  </div>
195
 
196
  <script>
197
  const inputImg = document.getElementById("inputImg");
198
  const outputImg = document.getElementById("outputImg");
199
 
200
+ async function sendRequest(formData) {
201
+ const res = await fetch("/remove-background", {
202
+ method: "POST",
203
+ body: formData
204
+ });
205
+
206
+ if (!res.ok) {
207
+ alert("Error processing image");
208
+ return;
209
+ }
210
+
211
+ const blob = await res.blob();
212
+ outputImg.src = URL.createObjectURL(blob);
213
+ }
214
+
215
  document.getElementById("uploadForm").addEventListener("submit", async e => {
216
  e.preventDefault();
217
  const file = document.getElementById("fileInput").files[0];
218
+ if (!file) return alert("Select file");
219
 
220
  inputImg.src = URL.createObjectURL(file);
221
 
222
  const fd = new FormData();
223
  fd.append("file", file);
224
 
225
+ sendRequest(fd);
 
226
  });
227
 
 
228
  document.getElementById("urlForm").addEventListener("submit", async e => {
229
  e.preventDefault();
230
  const url = document.getElementById("urlInput").value.trim();
231
+ if (!url) return alert("Enter URL");
232
 
233
  inputImg.src = url;
234
 
235
  const fd = new FormData();
236
  fd.append("image_url", url);
237
 
238
+ sendRequest(fd);
 
239
  });
240
  </script>
241
 
 
243
  </html>
244
  """
245
 
246
+
247
  # ---------------------------------------------------------
248
+ # RUN
249
  # ---------------------------------------------------------
250
  if __name__ == "__main__":
251
+ uvicorn.run(app, host="0.0.0.0", port=7860)