videopix commited on
Commit
bf34fae
·
verified ·
1 Parent(s): 0887e03

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -88
app.py CHANGED
@@ -2,8 +2,8 @@ import os
2
  import threading
3
  import torch
4
  import numpy as np
5
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException
6
- from fastapi.responses import StreamingResponse, HTMLResponse
7
  from PIL import Image
8
  from io import BytesIO
9
  import requests
@@ -20,7 +20,7 @@ except ImportError:
20
  pass
21
 
22
  # ---------------------------------------------------------
23
- # Performance settings for CPU (HF Spaces)
24
  # ---------------------------------------------------------
25
  os.environ["OMP_NUM_THREADS"] = "1"
26
  os.environ["MKL_NUM_THREADS"] = "1"
@@ -29,8 +29,8 @@ torch.set_num_threads(1)
29
  # ---------------------------------------------------------
30
  # Constants
31
  # ---------------------------------------------------------
32
- TARGET_SIZE = (512, 512) # Faster inference resolution
33
- MAX_SIDE = 3000 # Auto-downscale limit for large uploads
34
 
35
  # ---------------------------------------------------------
36
  # Load model
@@ -46,67 +46,65 @@ birefnet = AutoModelForImageSegmentation.from_pretrained(
46
  "ZhengPeng7/BiRefNet",
47
  cache_dir=MODEL_DIR,
48
  trust_remote_code=True,
49
- revision="main"
50
  )
51
  birefnet.to(device, dtype=dtype).eval()
52
  print("Model ready.")
53
 
54
- # Thread lock for CPU safety
55
- inference_lock = threading.Lock()
56
 
57
  # ---------------------------------------------------------
58
- # Helper functions
59
  # ---------------------------------------------------------
60
  def load_image_from_url(url: str) -> Image.Image:
61
  try:
62
  r = requests.get(url, timeout=10)
63
  r.raise_for_status()
64
  return Image.open(BytesIO(r.content)).convert("RGB")
65
- except Exception as e:
66
- raise HTTPException(status_code=400, detail=f"Cannot load image from URL: {str(e)}")
67
 
68
 
69
- def auto_downscale(image: Image.Image) -> Image.Image:
70
- w, h = image.size
71
  if max(w, h) <= MAX_SIDE:
72
- return image
73
 
74
  scale = MAX_SIDE / max(w, h)
75
  new_w = int(w * scale)
76
  new_h = int(h * scale)
77
 
78
  print(f"[INFO] Downscaling {w}×{h} → {new_w}×{new_h}")
79
- return image.resize((new_w, new_h), Image.LANCZOS)
80
-
81
 
82
- def transform_image(image: Image.Image) -> torch.Tensor:
83
- image = image.resize(TARGET_SIZE)
84
 
85
- arr = np.array(image).astype(np.float32) / 255.0
86
- mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
87
- std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
88
 
 
 
 
89
  arr = (arr - mean) / std
90
  arr = np.transpose(arr, (2, 0, 1))
91
 
92
- tensor = torch.from_numpy(arr).unsqueeze(0).to(device=device, dtype=dtype)
93
- return tensor
94
 
95
 
96
- def run_inference(image: Image.Image) -> Image.Image:
97
- orig_size = image.size
98
- input_tensor = transform_image(image)
99
 
100
- with inference_lock:
101
  with torch.no_grad():
102
- preds = birefnet(input_tensor)[-1].sigmoid().cpu()
103
 
104
- pred = preds[0, 0]
105
  mask = Image.fromarray((pred.numpy() * 255).astype(np.uint8)).resize(orig_size)
106
 
107
- image = image.convert("RGBA")
108
- image.putalpha(mask)
109
- return image
 
110
 
111
  # ---------------------------------------------------------
112
  # FastAPI app
@@ -114,103 +112,122 @@ def run_inference(image: Image.Image) -> Image.Image:
114
  app = FastAPI(title="Background Remover API")
115
 
116
  # ---------------------------------------------------------
117
- # POST endpoint only (no GET processing)
 
 
 
 
 
 
 
 
 
 
118
  # ---------------------------------------------------------
119
  @app.post("/remove-background")
120
- async def remove_background(
121
- file: UploadFile = File(None),
122
- image_url: str = Form(None)
123
- ):
124
  try:
125
- # load image
126
  if file:
127
  raw = await file.read()
128
- image = Image.open(BytesIO(raw)).convert("RGB")
129
  elif image_url:
130
- image = load_image_from_url(image_url)
131
  else:
132
- raise HTTPException(status_code=400, detail="Provide file or image_url.")
133
-
134
- # auto shrink large inputs
135
- image = auto_downscale(image)
136
 
137
- # remove background
138
- result = run_inference(image)
139
 
140
- # return PNG
141
  buf = BytesIO()
142
- result.save(buf, format="PNG", optimize=True)
143
  buf.seek(0)
144
 
145
  return StreamingResponse(buf, media_type="image/png")
146
 
147
- except HTTPException:
148
- raise
149
  except Exception as e:
150
- raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
 
151
 
152
  # ---------------------------------------------------------
153
- # UI for POST method testing only
154
  # ---------------------------------------------------------
155
  @app.get("/", response_class=HTMLResponse)
156
  async def ui():
157
  return """
158
  <html>
159
  <head>
160
- <title>Background Remover – Test Tool</title>
161
- <link rel='stylesheet' href='https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css'>
 
162
  </head>
163
-
164
- <body style='background:#f8f9fa;padding-top:40px;'>
165
- <div class='container text-center'>
166
- <h2>POST Method Test Panel</h2>
167
- <p>This UI only sends POST requests to <code>/remove-background</code>.</p>
168
-
169
- <h5>Test with File Upload:</h5>
170
- <form id='uploadForm' enctype='multipart/form-data'>
171
- <input class='form-control mb-2' type='file' id='fileInput'>
 
 
 
 
 
 
 
 
 
 
 
 
172
  <button class='btn btn-primary'>Send POST</button>
173
  </form>
174
 
175
- <div class='my-4'>OR</div>
176
 
177
- <h5>Test with Image URL:</h5>
178
  <form id='urlForm'>
179
- <input class='form-control mb-2' id='urlInput' placeholder='https://example.com/image.jpg'>
180
  <button class='btn btn-success'>Send POST</button>
181
  </form>
182
-
183
- <h4 class='mt-4'>Output:</h4>
184
- <img id='outputImg' style='max-width:90%;border-radius:10px;'/>
185
  </div>
186
 
187
  <script>
188
- const outputImg = document.getElementById("outputImg");
 
189
 
190
- document.getElementById("uploadForm").addEventListener("submit", async e => {
191
- e.preventDefault();
192
- const file = document.getElementById("fileInput").files[0];
193
- if (!file) return alert("Select a file first");
 
194
 
195
- const fd = new FormData();
196
- fd.append("file", file);
197
 
198
- const r = await fetch("/remove-background", {method:"POST", body:fd});
199
- outputImg.src = URL.createObjectURL(await r.blob());
200
- });
201
 
202
- document.getElementById("urlForm").addEventListener("submit", async e => {
203
- e.preventDefault();
204
- const url = document.getElementById("urlInput").value.trim();
205
- if (!url) return alert("Enter an image URL");
206
 
207
- const fd = new FormData();
208
- fd.append("image_url", url);
 
 
 
209
 
210
- const r = await fetch("/remove-background", {method:"POST", body:fd});
211
- outputImg.src = URL.createObjectURL(await r.blob());
212
- });
 
 
 
 
 
213
  </script>
 
214
  </body>
215
  </html>
216
  """
 
2
  import threading
3
  import torch
4
  import numpy as np
5
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
6
+ from fastapi.responses import StreamingResponse, HTMLResponse, RedirectResponse, JSONResponse
7
  from PIL import Image
8
  from io import BytesIO
9
  import requests
 
20
  pass
21
 
22
  # ---------------------------------------------------------
23
+ # Performance settings for HF CPU
24
  # ---------------------------------------------------------
25
  os.environ["OMP_NUM_THREADS"] = "1"
26
  os.environ["MKL_NUM_THREADS"] = "1"
 
29
  # ---------------------------------------------------------
30
  # Constants
31
  # ---------------------------------------------------------
32
+ TARGET_SIZE = (512, 512) # Faster inference
33
+ MAX_SIDE = 3000 # Auto-downscale for huge uploads
34
 
35
  # ---------------------------------------------------------
36
  # Load model
 
46
  "ZhengPeng7/BiRefNet",
47
  cache_dir=MODEL_DIR,
48
  trust_remote_code=True,
49
+ revision="main",
50
  )
51
  birefnet.to(device, dtype=dtype).eval()
52
  print("Model ready.")
53
 
54
+ lock = threading.Lock()
 
55
 
56
  # ---------------------------------------------------------
57
+ # Helpers
58
  # ---------------------------------------------------------
59
  def load_image_from_url(url: str) -> Image.Image:
60
  try:
61
  r = requests.get(url, timeout=10)
62
  r.raise_for_status()
63
  return Image.open(BytesIO(r.content)).convert("RGB")
64
+ except Exception:
65
+ raise HTTPException(status_code=400, detail="Invalid image URL")
66
 
67
 
68
+ def auto_downscale(img: Image.Image) -> Image.Image:
69
+ w, h = img.size
70
  if max(w, h) <= MAX_SIDE:
71
+ return img
72
 
73
  scale = MAX_SIDE / max(w, h)
74
  new_w = int(w * scale)
75
  new_h = int(h * scale)
76
 
77
  print(f"[INFO] Downscaling {w}×{h} → {new_w}×{new_h}")
78
+ return img.resize((new_w, new_h), Image.LANCZOS)
 
79
 
 
 
80
 
81
+ def transform(img: Image.Image) -> torch.Tensor:
82
+ img = img.resize(TARGET_SIZE)
 
83
 
84
+ arr = np.array(img).astype(np.float32) / 255.0
85
+ mean = np.array([0.485, 0.456, 0.406])
86
+ std = np.array([0.229, 0.224, 0.225])
87
  arr = (arr - mean) / std
88
  arr = np.transpose(arr, (2, 0, 1))
89
 
90
+ t = torch.from_numpy(arr).unsqueeze(0).to(device=device, dtype=dtype)
91
+ return t
92
 
93
 
94
+ def run_inference(img: Image.Image) -> Image.Image:
95
+ orig_size = img.size
96
+ tensor = transform(img)
97
 
98
+ with lock:
99
  with torch.no_grad():
100
+ pred = birefnet(tensor)[-1].sigmoid().cpu()[0, 0]
101
 
 
102
  mask = Image.fromarray((pred.numpy() * 255).astype(np.uint8)).resize(orig_size)
103
 
104
+ img = img.convert("RGBA")
105
+ img.putalpha(mask)
106
+ return img
107
+
108
 
109
  # ---------------------------------------------------------
110
  # FastAPI app
 
112
  app = FastAPI(title="Background Remover API")
113
 
114
  # ---------------------------------------------------------
115
+ # Redirect GET POST logic
116
+ # ---------------------------------------------------------
117
+ @app.get("/remove-background")
118
+ async def redirect_to_post():
119
+ return JSONResponse(
120
+ {"detail": "This endpoint only supports POST. Use POST /remove-background"},
121
+ status_code=405
122
+ )
123
+
124
+ # ---------------------------------------------------------
125
+ # Main POST endpoint
126
  # ---------------------------------------------------------
127
  @app.post("/remove-background")
128
+ async def remove_bg(file: UploadFile = File(None), image_url: str = Form(None)):
 
 
 
129
  try:
 
130
  if file:
131
  raw = await file.read()
132
+ img = Image.open(BytesIO(raw)).convert("RGB")
133
  elif image_url:
134
+ img = load_image_from_url(image_url)
135
  else:
136
+ raise HTTPException(status_code=400, detail="Upload file or image_url required")
 
 
 
137
 
138
+ img = auto_downscale(img)
139
+ result = run_inference(img)
140
 
 
141
  buf = BytesIO()
142
+ result.save(buf, format="PNG")
143
  buf.seek(0)
144
 
145
  return StreamingResponse(buf, media_type="image/png")
146
 
 
 
147
  except Exception as e:
148
+ raise HTTPException(status_code=500, detail=str(e))
149
+
150
 
151
  # ---------------------------------------------------------
152
+ # UI: Show INPUT + OUTPUT (big preview)
153
  # ---------------------------------------------------------
154
  @app.get("/", response_class=HTMLResponse)
155
  async def ui():
156
  return """
157
  <html>
158
  <head>
159
+ <title>Background Remover – Test UI</title>
160
+ <link rel='stylesheet'
161
+ href='https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css'>
162
  </head>
163
+ <body class='bg-light'>
164
+ <div class='container py-4 text-center'>
165
+
166
+ <h2 class='mb-4'>API Test Panel (POST Only)</h2>
167
+
168
+ <div class='row'>
169
+ <div class='col-md-6'>
170
+ <h5>Input Image</h5>
171
+ <img id='inputImg' style='max-width:100%; border-radius:10px;'>
172
+ </div>
173
+ <div class='col-md-6'>
174
+ <h5>Output Image</h5>
175
+ <img id='outputImg' style='max-width:100%; border-radius:10px;'>
176
+ </div>
177
+ </div>
178
+
179
+ <hr>
180
+
181
+ <h4>Upload Test</h4>
182
+ <form id="uploadForm" enctype='multipart/form-data'>
183
+ <input type='file' id='fileInput' class='form-control mb-3'>
184
  <button class='btn btn-primary'>Send POST</button>
185
  </form>
186
 
187
+ <hr>
188
 
189
+ <h4>URL Test</h4>
190
  <form id='urlForm'>
191
+ <input id='urlInput' class='form-control mb-3' placeholder='https://example.com/image.jpg'>
192
  <button class='btn btn-success'>Send POST</button>
193
  </form>
 
 
 
194
  </div>
195
 
196
  <script>
197
+ const inputImg = document.getElementById("inputImg");
198
+ const outputImg = document.getElementById("outputImg");
199
 
200
+ // FILE TEST
201
+ document.getElementById("uploadForm").addEventListener("submit", async e => {
202
+ e.preventDefault();
203
+ const file = document.getElementById("fileInput").files[0];
204
+ if (!file) return alert("Select a file first.");
205
 
206
+ inputImg.src = URL.createObjectURL(file);
 
207
 
208
+ const fd = new FormData();
209
+ fd.append("file", file);
 
210
 
211
+ const r = await fetch("/remove-background", { method:"POST", body:fd });
212
+ outputImg.src = URL.createObjectURL(await r.blob());
213
+ });
 
214
 
215
+ // URL TEST
216
+ document.getElementById("urlForm").addEventListener("submit", async e => {
217
+ e.preventDefault();
218
+ const url = document.getElementById("urlInput").value.trim();
219
+ if (!url) return alert("Enter an image URL first.");
220
 
221
+ inputImg.src = url;
222
+
223
+ const fd = new FormData();
224
+ fd.append("image_url", url);
225
+
226
+ const r = await fetch("/remove-background", { method:"POST", body:fd });
227
+ outputImg.src = URL.createObjectURL(await r.blob());
228
+ });
229
  </script>
230
+
231
  </body>
232
  </html>
233
  """