mobisoft commited on
Commit
36df28d
·
verified ·
1 Parent(s): 448805e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -111
app.py CHANGED
@@ -1,8 +1,9 @@
1
  import os
 
2
  import torch
3
  import numpy as np
4
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException
5
- from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse
6
  from PIL import Image
7
  from io import BytesIO
8
  import requests
@@ -10,84 +11,58 @@ from transformers import AutoModelForImageSegmentation
10
  import uvicorn
11
 
12
  # ---------------------------------------------------------
13
- # HEIC SUPPORT
14
  # ---------------------------------------------------------
15
  try:
16
  import pillow_heif
17
  pillow_heif.register_heif_opener()
18
- except:
19
  pass
20
 
21
  # ---------------------------------------------------------
22
- # CPU OPTIMIZATION
23
  # ---------------------------------------------------------
24
- CPU_THREADS = min(4, os.cpu_count() or 2)
25
- os.environ["OMP_NUM_THREADS"] = str(CPU_THREADS)
26
- os.environ["MKL_NUM_THREADS"] = str(CPU_THREADS)
27
-
28
- torch.set_num_threads(CPU_THREADS)
29
- torch.set_num_interop_threads(1)
30
 
31
  # ---------------------------------------------------------
32
- # SETTINGS
33
  # ---------------------------------------------------------
34
- TARGET_SIZE = (512, 512)
35
- MAX_SIDE = 1800
36
 
37
  # ---------------------------------------------------------
38
- # LOAD MODEL
39
  # ---------------------------------------------------------
40
  MODEL_DIR = "models/BiRefNet"
41
  os.makedirs(MODEL_DIR, exist_ok=True)
42
 
43
- print("Loading model...")
 
44
 
45
- model = AutoModelForImageSegmentation.from_pretrained(
 
46
  "ZhengPeng7/BiRefNet",
47
  cache_dir=MODEL_DIR,
48
- trust_remote_code=True
 
49
  )
50
-
51
- # ✅ CRITICAL FIX
52
- model = model.float()
53
-
54
- # ✅ channels last (CPU boost)
55
- model = model.to(memory_format=torch.channels_last)
56
-
57
- model.eval()
58
-
59
- # ---------------------------------------------------------
60
- # TORCHSCRIPT (BIG BOOST)
61
- # ---------------------------------------------------------
62
- print("Compiling model (TorchScript)...")
63
-
64
- dummy = torch.randn(1, 3, 512, 512).to(memory_format=torch.channels_last)
65
-
66
- with torch.no_grad():
67
- model = torch.jit.trace(model, dummy)
68
-
69
  print("Model ready.")
70
 
71
- # ---------------------------------------------------------
72
- # WARMUP
73
- # ---------------------------------------------------------
74
- def warmup():
75
- dummy = torch.randn(1, 3, 512, 512).to(memory_format=torch.channels_last)
76
- with torch.no_grad():
77
- _ = model(dummy)
78
-
79
- warmup()
80
 
81
  # ---------------------------------------------------------
82
- # HELPERS
83
  # ---------------------------------------------------------
84
  def load_image_from_url(url: str) -> Image.Image:
85
  try:
86
  r = requests.get(url, timeout=10)
87
  r.raise_for_status()
88
  return Image.open(BytesIO(r.content)).convert("RGB")
89
- except:
90
- raise HTTPException(400, "Invalid image URL")
91
 
92
 
93
  def auto_downscale(img: Image.Image) -> Image.Image:
@@ -96,57 +71,58 @@ def auto_downscale(img: Image.Image) -> Image.Image:
96
  return img
97
 
98
  scale = MAX_SIDE / max(w, h)
99
- return img.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
100
-
101
 
102
- def transform(img: Image.Image):
103
- img = img.resize(TARGET_SIZE)
104
 
105
- arr = np.asarray(img, dtype=np.float32) / 255.0
106
- arr -= np.array([0.485, 0.456, 0.406], dtype=np.float32)
107
- arr /= np.array([0.229, 0.224, 0.225], dtype=np.float32)
108
 
109
- arr = arr.transpose(2, 0, 1)
 
110
 
111
- tensor = torch.from_numpy(arr).unsqueeze(0).float()
 
 
 
 
112
 
113
- # channels last
114
- return tensor.to(memory_format=torch.channels_last)
115
 
116
 
117
  def run_inference(img: Image.Image) -> Image.Image:
118
  orig_size = img.size
119
-
120
  tensor = transform(img)
121
 
122
- with torch.no_grad():
123
- pred = model(tensor)[-1].sigmoid()[0, 0].cpu().numpy()
 
124
 
125
- mask = Image.fromarray((pred * 255).astype(np.uint8)).resize(orig_size)
126
 
127
  img = img.convert("RGBA")
128
  img.putalpha(mask)
129
-
130
  return img
131
 
132
 
133
  # ---------------------------------------------------------
134
- # FASTAPI
135
  # ---------------------------------------------------------
136
- app = FastAPI(title="Fast Background Remover")
137
 
138
  # ---------------------------------------------------------
139
- # GET redirect
140
  # ---------------------------------------------------------
141
  @app.get("/remove-background")
142
- def redirect():
143
  return JSONResponse(
144
- {"detail": "Use POST /remove-background"},
145
  status_code=405
146
  )
147
 
148
  # ---------------------------------------------------------
149
- # MAIN ENDPOINT (UNCHANGED)
150
  # ---------------------------------------------------------
151
  @app.post("/remove-background")
152
  async def remove_bg(file: UploadFile = File(None), image_url: str = Form(None)):
@@ -154,104 +130,102 @@ async def remove_bg(file: UploadFile = File(None), image_url: str = Form(None)):
154
  if file:
155
  raw = await file.read()
156
  img = Image.open(BytesIO(raw)).convert("RGB")
157
-
158
  elif image_url:
159
  img = load_image_from_url(image_url)
160
-
161
  else:
162
- raise HTTPException(400, "Provide file or image_url")
163
 
164
  img = auto_downscale(img)
165
-
166
  result = run_inference(img)
167
 
168
  buf = BytesIO()
169
- result.save(buf, format="PNG", optimize=True)
170
  buf.seek(0)
171
 
172
  return StreamingResponse(buf, media_type="image/png")
173
 
174
  except Exception as e:
175
- raise HTTPException(500, str(e))
176
 
177
 
178
  # ---------------------------------------------------------
179
- # UI (UNCHANGED BUT CLEAN)
180
  # ---------------------------------------------------------
181
  @app.get("/", response_class=HTMLResponse)
182
- def ui():
183
  return """
184
  <html>
185
  <head>
186
- <title>Background Remover</title>
187
  <link rel='stylesheet'
188
- href='https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css'>
189
  </head>
190
  <body class='bg-light'>
191
  <div class='container py-4 text-center'>
192
 
193
- <h2>Background Remover</h2>
194
 
195
  <div class='row'>
196
  <div class='col-md-6'>
197
- <h5>Input</h5>
198
- <img id='inputImg' style='max-width:100%'>
199
  </div>
200
  <div class='col-md-6'>
201
- <h5>Output</h5>
202
- <img id='outputImg' style='max-width:100%'>
203
  </div>
204
  </div>
205
 
206
  <hr>
207
 
208
- <form id="uploadForm">
 
209
  <input type='file' id='fileInput' class='form-control mb-3'>
210
- <button class='btn btn-primary'>Upload</button>
211
  </form>
212
 
213
  <hr>
214
 
 
215
  <form id='urlForm'>
216
- <input id='urlInput' class='form-control mb-3' placeholder='Image URL'>
217
- <button class='btn btn-success'>Send URL</button>
218
  </form>
219
-
220
  </div>
221
 
222
  <script>
223
  const inputImg = document.getElementById("inputImg");
224
  const outputImg = document.getElementById("outputImg");
225
 
226
- async function send(fd){
227
- const r = await fetch("/remove-background", {
228
- method:"POST",
229
- body:fd
230
- });
231
-
232
- const blob = await r.blob();
233
- outputImg.src = URL.createObjectURL(blob);
234
- }
235
-
236
- document.getElementById("uploadForm").onsubmit = async e=>{
237
  e.preventDefault();
238
- const file = fileInput.files[0];
 
 
239
  inputImg.src = URL.createObjectURL(file);
240
 
241
  const fd = new FormData();
242
  fd.append("file", file);
243
- send(fd);
244
- };
245
 
246
- document.getElementById("urlForm").onsubmit = async e=>{
 
 
 
 
 
247
  e.preventDefault();
248
- const url = urlInput.value;
 
 
249
  inputImg.src = url;
250
 
251
  const fd = new FormData();
252
  fd.append("image_url", url);
253
- send(fd);
254
- };
 
 
255
  </script>
256
 
257
  </body>
@@ -259,7 +233,7 @@ def ui():
259
  """
260
 
261
  # ---------------------------------------------------------
262
- # RUN
263
  # ---------------------------------------------------------
264
  if __name__ == "__main__":
265
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import os
2
+ import threading
3
  import torch
4
  import numpy as np
5
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
6
+ from fastapi.responses import StreamingResponse, HTMLResponse, RedirectResponse, JSONResponse
7
  from PIL import Image
8
  from io import BytesIO
9
  import requests
 
11
  import uvicorn
12
 
13
  # ---------------------------------------------------------
14
+ # Optional HEIC/HEIF
15
  # ---------------------------------------------------------
16
  try:
17
  import pillow_heif
18
  pillow_heif.register_heif_opener()
19
+ except ImportError:
20
  pass
21
 
22
  # ---------------------------------------------------------
23
+ # Performance settings for HF CPU
24
  # ---------------------------------------------------------
25
+ os.environ["OMP_NUM_THREADS"] = "1"
26
+ os.environ["MKL_NUM_THREADS"] = "1"
27
+ torch.set_num_threads(1)
 
 
 
28
 
29
  # ---------------------------------------------------------
30
+ # Constants
31
  # ---------------------------------------------------------
32
+ TARGET_SIZE = (512, 512) # Faster inference
33
+ MAX_SIDE = 3000 # Auto-downscale for huge uploads
34
 
35
  # ---------------------------------------------------------
36
+ # Load model
37
  # ---------------------------------------------------------
38
  MODEL_DIR = "models/BiRefNet"
39
  os.makedirs(MODEL_DIR, exist_ok=True)
40
 
41
+ device = "cuda" if torch.cuda.is_available() else "cpu"
42
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
43
 
44
+ print("Loading BiRefNet…")
45
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
46
  "ZhengPeng7/BiRefNet",
47
  cache_dir=MODEL_DIR,
48
+ trust_remote_code=True,
49
+ revision="main",
50
  )
51
+ birefnet.to(device, dtype=dtype).eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  print("Model ready.")
53
 
54
+ lock = threading.Lock()
 
 
 
 
 
 
 
 
55
 
56
  # ---------------------------------------------------------
57
+ # Helpers
58
  # ---------------------------------------------------------
59
  def load_image_from_url(url: str) -> Image.Image:
60
  try:
61
  r = requests.get(url, timeout=10)
62
  r.raise_for_status()
63
  return Image.open(BytesIO(r.content)).convert("RGB")
64
+ except Exception:
65
+ raise HTTPException(status_code=400, detail="Invalid image URL")
66
 
67
 
68
  def auto_downscale(img: Image.Image) -> Image.Image:
 
71
  return img
72
 
73
  scale = MAX_SIDE / max(w, h)
74
+ new_w = int(w * scale)
75
+ new_h = int(h * scale)
76
 
77
+ print(f"[INFO] Downscaling {w}×{h} → {new_w}×{new_h}")
78
+ return img.resize((new_w, new_h), Image.LANCZOS)
79
 
 
 
 
80
 
81
+ def transform(img: Image.Image) -> torch.Tensor:
82
+ img = img.resize(TARGET_SIZE)
83
 
84
+ arr = np.array(img).astype(np.float32) / 255.0
85
+ mean = np.array([0.485, 0.456, 0.406])
86
+ std = np.array([0.229, 0.224, 0.225])
87
+ arr = (arr - mean) / std
88
+ arr = np.transpose(arr, (2, 0, 1))
89
 
90
+ t = torch.from_numpy(arr).unsqueeze(0).to(device=device, dtype=dtype)
91
+ return t
92
 
93
 
94
  def run_inference(img: Image.Image) -> Image.Image:
95
  orig_size = img.size
 
96
  tensor = transform(img)
97
 
98
+ with lock:
99
+ with torch.no_grad():
100
+ pred = birefnet(tensor)[-1].sigmoid().cpu()[0, 0]
101
 
102
+ mask = Image.fromarray((pred.numpy() * 255).astype(np.uint8)).resize(orig_size)
103
 
104
  img = img.convert("RGBA")
105
  img.putalpha(mask)
 
106
  return img
107
 
108
 
109
  # ---------------------------------------------------------
110
+ # FastAPI app
111
  # ---------------------------------------------------------
112
+ app = FastAPI(title="Background Remover API")
113
 
114
  # ---------------------------------------------------------
115
+ # Redirect GET → POST logic
116
  # ---------------------------------------------------------
117
  @app.get("/remove-background")
118
+ async def redirect_to_post():
119
  return JSONResponse(
120
+ {"detail": "This endpoint only supports POST. Use POST /remove-background"},
121
  status_code=405
122
  )
123
 
124
  # ---------------------------------------------------------
125
+ # Main POST endpoint
126
  # ---------------------------------------------------------
127
  @app.post("/remove-background")
128
  async def remove_bg(file: UploadFile = File(None), image_url: str = Form(None)):
 
130
  if file:
131
  raw = await file.read()
132
  img = Image.open(BytesIO(raw)).convert("RGB")
 
133
  elif image_url:
134
  img = load_image_from_url(image_url)
 
135
  else:
136
+ raise HTTPException(status_code=400, detail="Upload file or image_url required")
137
 
138
  img = auto_downscale(img)
 
139
  result = run_inference(img)
140
 
141
  buf = BytesIO()
142
+ result.save(buf, format="PNG")
143
  buf.seek(0)
144
 
145
  return StreamingResponse(buf, media_type="image/png")
146
 
147
  except Exception as e:
148
+ raise HTTPException(status_code=500, detail=str(e))
149
 
150
 
151
  # ---------------------------------------------------------
152
+ # UI: Show INPUT + OUTPUT (big preview)
153
  # ---------------------------------------------------------
154
  @app.get("/", response_class=HTMLResponse)
155
+ async def ui():
156
  return """
157
  <html>
158
  <head>
159
+ <title>Background Remover – Test UI</title>
160
  <link rel='stylesheet'
161
+ href='https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css'>
162
  </head>
163
  <body class='bg-light'>
164
  <div class='container py-4 text-center'>
165
 
166
+ <h2 class='mb-4'>API Test Panel (POST Only)</h2>
167
 
168
  <div class='row'>
169
  <div class='col-md-6'>
170
+ <h5>Input Image</h5>
171
+ <img id='inputImg' style='max-width:100%; border-radius:10px;'>
172
  </div>
173
  <div class='col-md-6'>
174
+ <h5>Output Image</h5>
175
+ <img id='outputImg' style='max-width:100%; border-radius:10px;'>
176
  </div>
177
  </div>
178
 
179
  <hr>
180
 
181
+ <h4>Upload Test</h4>
182
+ <form id="uploadForm" enctype='multipart/form-data'>
183
  <input type='file' id='fileInput' class='form-control mb-3'>
184
+ <button class='btn btn-primary'>Send POST</button>
185
  </form>
186
 
187
  <hr>
188
 
189
+ <h4>URL Test</h4>
190
  <form id='urlForm'>
191
+ <input id='urlInput' class='form-control mb-3' placeholder='https://example.com/image.jpg'>
192
+ <button class='btn btn-success'>Send POST</button>
193
  </form>
 
194
  </div>
195
 
196
  <script>
197
  const inputImg = document.getElementById("inputImg");
198
  const outputImg = document.getElementById("outputImg");
199
 
200
+ // FILE TEST
201
+ document.getElementById("uploadForm").addEventListener("submit", async e => {
 
 
 
 
 
 
 
 
 
202
  e.preventDefault();
203
+ const file = document.getElementById("fileInput").files[0];
204
+ if (!file) return alert("Select a file first.");
205
+
206
  inputImg.src = URL.createObjectURL(file);
207
 
208
  const fd = new FormData();
209
  fd.append("file", file);
 
 
210
 
211
+ const r = await fetch("/remove-background", { method:"POST", body:fd });
212
+ outputImg.src = URL.createObjectURL(await r.blob());
213
+ });
214
+
215
+ // URL TEST
216
+ document.getElementById("urlForm").addEventListener("submit", async e => {
217
  e.preventDefault();
218
+ const url = document.getElementById("urlInput").value.trim();
219
+ if (!url) return alert("Enter an image URL first.");
220
+
221
  inputImg.src = url;
222
 
223
  const fd = new FormData();
224
  fd.append("image_url", url);
225
+
226
+ const r = await fetch("/remove-background", { method:"POST", body:fd });
227
+ outputImg.src = URL.createObjectURL(await r.blob());
228
+ });
229
  </script>
230
 
231
  </body>
 
233
  """
234
 
235
  # ---------------------------------------------------------
236
+ # Run app
237
  # ---------------------------------------------------------
238
  if __name__ == "__main__":
239
+ uvicorn.run(app, host="0.0.0.0", port=7860)