mobisoft commited on
Commit
7160c3c
·
verified ·
1 Parent(s): 777f1e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -45
app.py CHANGED
@@ -10,27 +10,29 @@ from transformers import AutoModelForImageSegmentation
10
  import uvicorn
11
 
12
  # ---------------------------------------------------------
13
- # HEIC/HEIF SUPPORT
14
  # ---------------------------------------------------------
15
  try:
16
  import pillow_heif
17
  pillow_heif.register_heif_opener()
18
- except ImportError:
19
  pass
20
 
21
  # ---------------------------------------------------------
22
- # CPU PERFORMANCE (FIXED)
23
  # ---------------------------------------------------------
24
  CPU_THREADS = min(4, os.cpu_count() or 2)
25
  os.environ["OMP_NUM_THREADS"] = str(CPU_THREADS)
26
  os.environ["MKL_NUM_THREADS"] = str(CPU_THREADS)
 
27
  torch.set_num_threads(CPU_THREADS)
 
28
 
29
  # ---------------------------------------------------------
30
  # SETTINGS
31
  # ---------------------------------------------------------
32
  TARGET_SIZE = (512, 512)
33
- MAX_SIDE = 2000
34
 
35
  # ---------------------------------------------------------
36
  # LOAD MODEL
@@ -39,24 +41,42 @@ MODEL_DIR = "models/BiRefNet"
39
  os.makedirs(MODEL_DIR, exist_ok=True)
40
 
41
  print("Loading model...")
 
42
  model = AutoModelForImageSegmentation.from_pretrained(
43
  "ZhengPeng7/BiRefNet",
44
  cache_dir=MODEL_DIR,
45
  trust_remote_code=True
46
  )
 
 
 
 
 
 
 
47
  model.eval()
 
 
 
 
 
 
 
 
 
 
 
48
  print("Model ready.")
49
 
50
  # ---------------------------------------------------------
51
  # WARMUP
52
  # ---------------------------------------------------------
53
  def warmup():
54
- dummy = torch.randn(1, 3, 512, 512)
55
  with torch.no_grad():
56
  _ = model(dummy)
57
 
58
  warmup()
59
- print("Warmup done.")
60
 
61
  # ---------------------------------------------------------
62
  # HELPERS
@@ -66,7 +86,7 @@ def load_image_from_url(url: str) -> Image.Image:
66
  r = requests.get(url, timeout=10)
67
  r.raise_for_status()
68
  return Image.open(BytesIO(r.content)).convert("RGB")
69
- except Exception:
70
  raise HTTPException(400, "Invalid image URL")
71
 
72
 
@@ -83,14 +103,20 @@ def transform(img: Image.Image):
83
  img = img.resize(TARGET_SIZE)
84
 
85
  arr = np.asarray(img, dtype=np.float32) / 255.0
86
- arr = (arr - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
 
 
87
  arr = arr.transpose(2, 0, 1)
88
 
89
- return torch.from_numpy(arr).unsqueeze(0)
 
 
 
90
 
91
 
92
  def run_inference(img: Image.Image) -> Image.Image:
93
  orig_size = img.size
 
94
  tensor = transform(img)
95
 
96
  with torch.no_grad():
@@ -100,13 +126,14 @@ def run_inference(img: Image.Image) -> Image.Image:
100
 
101
  img = img.convert("RGBA")
102
  img.putalpha(mask)
 
103
  return img
104
 
105
 
106
  # ---------------------------------------------------------
107
  # FASTAPI
108
  # ---------------------------------------------------------
109
- app = FastAPI(title="Background Remover API")
110
 
111
  # ---------------------------------------------------------
112
  # GET redirect
@@ -135,6 +162,7 @@ async def remove_bg(file: UploadFile = File(None), image_url: str = Form(None)):
135
  raise HTTPException(400, "Provide file or image_url")
136
 
137
  img = auto_downscale(img)
 
138
  result = run_inference(img)
139
 
140
  buf = BytesIO()
@@ -148,7 +176,7 @@ async def remove_bg(file: UploadFile = File(None), image_url: str = Form(None)):
148
 
149
 
150
  # ---------------------------------------------------------
151
- # UI (IMPROVED BUT SAME LOGIC)
152
  # ---------------------------------------------------------
153
  @app.get("/", response_class=HTMLResponse)
154
  def ui():
@@ -157,38 +185,36 @@ def ui():
157
  <head>
158
  <title>Background Remover</title>
159
  <link rel='stylesheet'
160
- href='https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css'>
161
  </head>
162
  <body class='bg-light'>
163
  <div class='container py-4 text-center'>
164
 
165
- <h2 class='mb-4'>Background Remover</h2>
166
 
167
  <div class='row'>
168
  <div class='col-md-6'>
169
  <h5>Input</h5>
170
- <img id='inputImg' style='max-width:100%; border-radius:10px;'>
171
  </div>
172
  <div class='col-md-6'>
173
  <h5>Output</h5>
174
- <img id='outputImg' style='max-width:100%; border-radius:10px;'>
175
  </div>
176
  </div>
177
 
178
  <hr>
179
 
180
- <h4>Upload Image</h4>
181
  <form id="uploadForm">
182
  <input type='file' id='fileInput' class='form-control mb-3'>
183
- <button class='btn btn-primary'>Remove Background</button>
184
  </form>
185
 
186
  <hr>
187
 
188
- <h4>Image URL</h4>
189
  <form id='urlForm'>
190
- <input id='urlInput' class='form-control mb-3' placeholder='https://image.jpg'>
191
- <button class='btn btn-success'>Remove Background</button>
192
  </form>
193
 
194
  </div>
@@ -197,53 +223,41 @@ def ui():
197
  const inputImg = document.getElementById("inputImg");
198
  const outputImg = document.getElementById("outputImg");
199
 
200
- async function sendRequest(formData) {
201
- const res = await fetch("/remove-background", {
202
- method: "POST",
203
- body: formData
204
  });
205
 
206
- if (!res.ok) {
207
- alert("Error processing image");
208
- return;
209
- }
210
-
211
- const blob = await res.blob();
212
  outputImg.src = URL.createObjectURL(blob);
213
  }
214
 
215
- document.getElementById("uploadForm").addEventListener("submit", async e => {
216
  e.preventDefault();
217
- const file = document.getElementById("fileInput").files[0];
218
- if (!file) return alert("Select file");
219
-
220
  inputImg.src = URL.createObjectURL(file);
221
 
222
  const fd = new FormData();
223
  fd.append("file", file);
 
 
224
 
225
- sendRequest(fd);
226
- });
227
-
228
- document.getElementById("urlForm").addEventListener("submit", async e => {
229
  e.preventDefault();
230
- const url = document.getElementById("urlInput").value.trim();
231
- if (!url) return alert("Enter URL");
232
-
233
  inputImg.src = url;
234
 
235
  const fd = new FormData();
236
  fd.append("image_url", url);
237
-
238
- sendRequest(fd);
239
- });
240
  </script>
241
 
242
  </body>
243
  </html>
244
  """
245
 
246
-
247
  # ---------------------------------------------------------
248
  # RUN
249
  # ---------------------------------------------------------
 
10
  import uvicorn
11
 
12
  # ---------------------------------------------------------
13
+ # HEIC SUPPORT
14
  # ---------------------------------------------------------
15
  try:
16
  import pillow_heif
17
  pillow_heif.register_heif_opener()
18
+ except:
19
  pass
20
 
21
  # ---------------------------------------------------------
22
+ # CPU OPTIMIZATION
23
  # ---------------------------------------------------------
24
  CPU_THREADS = min(4, os.cpu_count() or 2)
25
  os.environ["OMP_NUM_THREADS"] = str(CPU_THREADS)
26
  os.environ["MKL_NUM_THREADS"] = str(CPU_THREADS)
27
+
28
  torch.set_num_threads(CPU_THREADS)
29
+ torch.set_num_interop_threads(1)
30
 
31
  # ---------------------------------------------------------
32
  # SETTINGS
33
  # ---------------------------------------------------------
34
  TARGET_SIZE = (512, 512)
35
+ MAX_SIDE = 1800
36
 
37
  # ---------------------------------------------------------
38
  # LOAD MODEL
 
41
  os.makedirs(MODEL_DIR, exist_ok=True)
42
 
43
  print("Loading model...")
44
+
45
  model = AutoModelForImageSegmentation.from_pretrained(
46
  "ZhengPeng7/BiRefNet",
47
  cache_dir=MODEL_DIR,
48
  trust_remote_code=True
49
  )
50
+
51
+ # ✅ CRITICAL FIX
52
+ model = model.float()
53
+
54
+ # ✅ channels last (CPU boost)
55
+ model = model.to(memory_format=torch.channels_last)
56
+
57
  model.eval()
58
+
59
+ # ---------------------------------------------------------
60
+ # TORCHSCRIPT (BIG BOOST)
61
+ # ---------------------------------------------------------
62
+ print("Compiling model (TorchScript)...")
63
+
64
+ dummy = torch.randn(1, 3, 512, 512).to(memory_format=torch.channels_last)
65
+
66
+ with torch.no_grad():
67
+ model = torch.jit.trace(model, dummy)
68
+
69
  print("Model ready.")
70
 
71
  # ---------------------------------------------------------
72
  # WARMUP
73
  # ---------------------------------------------------------
74
  def warmup():
75
+ dummy = torch.randn(1, 3, 512, 512).to(memory_format=torch.channels_last)
76
  with torch.no_grad():
77
  _ = model(dummy)
78
 
79
  warmup()
 
80
 
81
  # ---------------------------------------------------------
82
  # HELPERS
 
86
  r = requests.get(url, timeout=10)
87
  r.raise_for_status()
88
  return Image.open(BytesIO(r.content)).convert("RGB")
89
+ except:
90
  raise HTTPException(400, "Invalid image URL")
91
 
92
 
 
103
  img = img.resize(TARGET_SIZE)
104
 
105
  arr = np.asarray(img, dtype=np.float32) / 255.0
106
+ arr -= np.array([0.485, 0.456, 0.406], dtype=np.float32)
107
+ arr /= np.array([0.229, 0.224, 0.225], dtype=np.float32)
108
+
109
  arr = arr.transpose(2, 0, 1)
110
 
111
+ tensor = torch.from_numpy(arr).unsqueeze(0).float()
112
+
113
+ # ✅ channels last
114
+ return tensor.to(memory_format=torch.channels_last)
115
 
116
 
117
  def run_inference(img: Image.Image) -> Image.Image:
118
  orig_size = img.size
119
+
120
  tensor = transform(img)
121
 
122
  with torch.no_grad():
 
126
 
127
  img = img.convert("RGBA")
128
  img.putalpha(mask)
129
+
130
  return img
131
 
132
 
133
  # ---------------------------------------------------------
134
  # FASTAPI
135
  # ---------------------------------------------------------
136
+ app = FastAPI(title="Fast Background Remover")
137
 
138
  # ---------------------------------------------------------
139
  # GET redirect
 
162
  raise HTTPException(400, "Provide file or image_url")
163
 
164
  img = auto_downscale(img)
165
+
166
  result = run_inference(img)
167
 
168
  buf = BytesIO()
 
176
 
177
 
178
  # ---------------------------------------------------------
179
+ # UI (UNCHANGED BUT CLEAN)
180
  # ---------------------------------------------------------
181
  @app.get("/", response_class=HTMLResponse)
182
  def ui():
 
185
  <head>
186
  <title>Background Remover</title>
187
  <link rel='stylesheet'
188
+ href='https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css'>
189
  </head>
190
  <body class='bg-light'>
191
  <div class='container py-4 text-center'>
192
 
193
+ <h2>Background Remover</h2>
194
 
195
  <div class='row'>
196
  <div class='col-md-6'>
197
  <h5>Input</h5>
198
+ <img id='inputImg' style='max-width:100%'>
199
  </div>
200
  <div class='col-md-6'>
201
  <h5>Output</h5>
202
+ <img id='outputImg' style='max-width:100%'>
203
  </div>
204
  </div>
205
 
206
  <hr>
207
 
 
208
  <form id="uploadForm">
209
  <input type='file' id='fileInput' class='form-control mb-3'>
210
+ <button class='btn btn-primary'>Upload</button>
211
  </form>
212
 
213
  <hr>
214
 
 
215
  <form id='urlForm'>
216
+ <input id='urlInput' class='form-control mb-3' placeholder='Image URL'>
217
+ <button class='btn btn-success'>Send URL</button>
218
  </form>
219
 
220
  </div>
 
223
  const inputImg = document.getElementById("inputImg");
224
  const outputImg = document.getElementById("outputImg");
225
 
226
+ async function send(fd){
227
+ const r = await fetch("/remove-background", {
228
+ method:"POST",
229
+ body:fd
230
  });
231
 
232
+ const blob = await r.blob();
 
 
 
 
 
233
  outputImg.src = URL.createObjectURL(blob);
234
  }
235
 
236
+ document.getElementById("uploadForm").onsubmit = async e=>{
237
  e.preventDefault();
238
+ const file = fileInput.files[0];
 
 
239
  inputImg.src = URL.createObjectURL(file);
240
 
241
  const fd = new FormData();
242
  fd.append("file", file);
243
+ send(fd);
244
+ };
245
 
246
+ document.getElementById("urlForm").onsubmit = async e=>{
 
 
 
247
  e.preventDefault();
248
+ const url = urlInput.value;
 
 
249
  inputImg.src = url;
250
 
251
  const fd = new FormData();
252
  fd.append("image_url", url);
253
+ send(fd);
254
+ };
 
255
  </script>
256
 
257
  </body>
258
  </html>
259
  """
260
 
 
261
  # ---------------------------------------------------------
262
  # RUN
263
  # ---------------------------------------------------------