videopix commited on
Commit
9c9a36a
·
verified ·
1 Parent(s): a1dd89f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -75
app.py CHANGED
@@ -16,19 +16,18 @@ import uvicorn
16
  try:
17
  import pillow_heif
18
  pillow_heif.register_heif_opener()
19
- print("HEIC/HEIF supported.")
20
  except ImportError:
21
- print("Install pillow-heif for HEIC support.")
22
 
23
  # ---------------------------------------------------------
24
- # Performance settings (especially for CPU on HF Spaces)
25
  # ---------------------------------------------------------
26
  os.environ["OMP_NUM_THREADS"] = "1"
27
  os.environ["MKL_NUM_THREADS"] = "1"
28
  torch.set_num_threads(1)
29
 
30
  # ---------------------------------------------------------
31
- # Load model
32
  # ---------------------------------------------------------
33
  MODEL_DIR = "models/BiRefNet"
34
  os.makedirs(MODEL_DIR, exist_ok=True)
@@ -36,7 +35,6 @@ os.makedirs(MODEL_DIR, exist_ok=True)
36
  device = "cuda" if torch.cuda.is_available() else "cpu"
37
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
38
 
39
- print("Loading BiRefNet model...")
40
  birefnet = AutoModelForImageSegmentation.from_pretrained(
41
  "ZhengPeng7/BiRefNet",
42
  cache_dir=MODEL_DIR,
@@ -44,16 +42,10 @@ birefnet = AutoModelForImageSegmentation.from_pretrained(
44
  revision="main"
45
  )
46
  birefnet.to(device, dtype=dtype).eval()
47
- print("Model loaded.")
48
 
49
- # Global lock to protect the model during inference
50
  inference_lock = threading.Lock()
51
 
52
- # ---------------------------------------------------------
53
- # FastAPI app
54
- # ---------------------------------------------------------
55
- app = FastAPI(title="Background Remover API")
56
-
57
  # ---------------------------------------------------------
58
  # Helper functions
59
  # ---------------------------------------------------------
@@ -66,6 +58,23 @@ def load_image_from_url(url: str) -> Image.Image:
66
  raise HTTPException(status_code=400, detail=f"Cannot load image from URL: {str(e)}")
67
 
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  def transform_image(image: Image.Image, resolution: int = 512) -> torch.Tensor:
70
  image = image.resize((resolution, resolution))
71
  arr = np.array(image).astype(np.float32) / 255.0
@@ -81,14 +90,12 @@ def transform_image(image: Image.Image, resolution: int = 512) -> torch.Tensor:
81
 
82
  def run_inference(image: Image.Image, resolution: int = 512) -> Image.Image:
83
  orig_size = image.size
 
84
  input_tensor = transform_image(image, resolution)
85
 
86
- with inference_lock: # prevents CPU thread crashes
87
  with torch.no_grad():
88
- try:
89
- preds = birefnet(input_tensor)[-1].sigmoid().cpu()
90
- except Exception as e:
91
- raise RuntimeError(f"Inference error: {str(e)}")
92
 
93
  pred = preds[0, 0]
94
  mask = Image.fromarray((pred.numpy() * 255).astype(np.uint8)).resize(orig_size)
@@ -97,9 +104,12 @@ def run_inference(image: Image.Image, resolution: int = 512) -> Image.Image:
97
  image.putalpha(mask)
98
  return image
99
 
 
100
  # ---------------------------------------------------------
101
- # /remove-background endpoint
102
  # ---------------------------------------------------------
 
 
103
  @app.post("/remove-background")
104
  async def remove_background(
105
  file: UploadFile = File(None),
@@ -113,12 +123,15 @@ async def remove_background(
113
  elif image_url:
114
  image = load_image_from_url(image_url)
115
  else:
116
- raise HTTPException(status_code=400, detail="Provide either file or image_url.")
 
 
 
117
 
118
  result = run_inference(image, resolution)
119
 
120
  buf = BytesIO()
121
- result.save(buf, format="PNG")
122
  buf.seek(0)
123
 
124
  return StreamingResponse(buf, media_type="image/png")
@@ -130,93 +143,69 @@ async def remove_background(
130
 
131
 
132
  # ---------------------------------------------------------
133
- # Test UI
134
  # ---------------------------------------------------------
135
  @app.get("/", response_class=HTMLResponse)
136
  async def index():
137
- html = """
138
- <!DOCTYPE html>
139
  <html>
140
  <head>
141
- <meta charset="UTF-8">
142
  <title>Background Remover</title>
143
- <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet">
144
- <style>
145
- body { background: #f8f9fa; padding-top: 40px; }
146
- .container { max-width: 700px; }
147
- img { max-width: 100%; margin-top: 20px; border-radius: 10px; }
148
- </style>
149
  </head>
150
-
151
- <body>
152
- <div class="container text-center">
153
- <h2 class="mb-4">Background Remover API</h2>
154
-
155
- <form id="uploadForm" class="mb-4" enctype="multipart/form-data">
156
- <div class="mb-3">
157
- <input class="form-control" type="file" id="fileInput" name="file" accept="image/*">
158
- </div>
159
- <div class="mb-3">
160
- <input class="form-control" type="number" id="resInput" name="resolution" value="512" min="64" max="2048">
161
- </div>
162
- <button class="btn btn-primary">Upload</button>
163
  </form>
164
 
165
- <div class="mb-4">OR</div>
166
 
167
- <form id="urlForm">
168
- <div class="mb-3">
169
- <input class="form-control" type="text" id="urlInput" placeholder="https://example.com/img.jpg">
170
- </div>
171
- <div class="mb-3">
172
- <input class="form-control" type="number" id="urlResInput" value="512" min="64" max="2048">
173
- </div>
174
- <button class="btn btn-success">Use URL</button>
175
  </form>
176
 
177
- <h5 class="mt-4">Result:</h5>
178
- <img id="resultImg" />
179
  </div>
180
 
181
  <script>
182
- const uploadForm = document.getElementById("uploadForm");
183
- const urlForm = document.getElementById("urlForm");
184
  const resultImg = document.getElementById("resultImg");
185
 
186
- uploadForm.addEventListener("submit", async e => {
187
  e.preventDefault();
188
  const file = document.getElementById("fileInput").files[0];
189
  if (!file) return alert("Choose an image");
190
-
191
  const res = document.getElementById("resInput").value;
192
- const form = new FormData();
193
- form.append("file", file);
194
- form.append("resolution", res);
195
-
196
- const r = await fetch("/remove-background", { method: "POST", body: form });
197
- const blob = await r.blob();
198
- resultImg.src = URL.createObjectURL(blob);
199
  });
200
 
201
- urlForm.addEventListener("submit", async e => {
202
  e.preventDefault();
203
  const url = document.getElementById("urlInput").value.trim();
204
  if (!url) return alert("Enter URL");
205
-
206
  const res = document.getElementById("urlResInput").value;
207
- const form = new FormData();
208
- form.append("image_url", url);
209
- form.append("resolution", res);
210
-
211
- const r = await fetch("/remove-background", { method: "POST", body: form });
212
- const blob = await r.blob();
213
- resultImg.src = URL.createObjectURL(blob);
214
  });
215
  </script>
216
  </body>
217
  </html>
218
  """
219
- return HTMLResponse(html)
220
 
221
 
222
  # ---------------------------------------------------------
 
16
  try:
17
  import pillow_heif
18
  pillow_heif.register_heif_opener()
 
19
  except ImportError:
20
+ pass
21
 
22
  # ---------------------------------------------------------
23
+ # Performance settings for CPU (HF Spaces)
24
  # ---------------------------------------------------------
25
  os.environ["OMP_NUM_THREADS"] = "1"
26
  os.environ["MKL_NUM_THREADS"] = "1"
27
  torch.set_num_threads(1)
28
 
29
  # ---------------------------------------------------------
30
+ # Model load
31
  # ---------------------------------------------------------
32
  MODEL_DIR = "models/BiRefNet"
33
  os.makedirs(MODEL_DIR, exist_ok=True)
 
35
  device = "cuda" if torch.cuda.is_available() else "cpu"
36
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
37
 
 
38
  birefnet = AutoModelForImageSegmentation.from_pretrained(
39
  "ZhengPeng7/BiRefNet",
40
  cache_dir=MODEL_DIR,
 
42
  revision="main"
43
  )
44
  birefnet.to(device, dtype=dtype).eval()
 
45
 
46
+ # Thread lock to protect inference on CPU
47
  inference_lock = threading.Lock()
48
 
 
 
 
 
 
49
  # ---------------------------------------------------------
50
  # Helper functions
51
  # ---------------------------------------------------------
 
58
  raise HTTPException(status_code=400, detail=f"Cannot load image from URL: {str(e)}")
59
 
60
 
61
+ def auto_downscale(image: Image.Image, max_side: int = 3000) -> Image.Image:
62
+ """Downscale very large images to speed up CPU inference."""
63
+ w, h = image.size
64
+
65
+ if max(w, h) <= max_side:
66
+ return image
67
+
68
+ scale = max_side / max(w, h)
69
+ new_w = int(w * scale)
70
+ new_h = int(h * scale)
71
+
72
+ image = image.resize((new_w, new_h), Image.LANCZOS)
73
+ print(f"[INFO] Downscaled large image from {w}x{h} to {new_w}x{new_h}")
74
+
75
+ return image
76
+
77
+
78
  def transform_image(image: Image.Image, resolution: int = 512) -> torch.Tensor:
79
  image = image.resize((resolution, resolution))
80
  arr = np.array(image).astype(np.float32) / 255.0
 
90
 
91
  def run_inference(image: Image.Image, resolution: int = 512) -> Image.Image:
92
  orig_size = image.size
93
+
94
  input_tensor = transform_image(image, resolution)
95
 
96
+ with inference_lock:
97
  with torch.no_grad():
98
+ preds = birefnet(input_tensor)[-1].sigmoid().cpu()
 
 
 
99
 
100
  pred = preds[0, 0]
101
  mask = Image.fromarray((pred.numpy() * 255).astype(np.uint8)).resize(orig_size)
 
104
  image.putalpha(mask)
105
  return image
106
 
107
+
108
  # ---------------------------------------------------------
109
+ # API endpoint
110
  # ---------------------------------------------------------
111
+ @app = FastAPI(title="Background Remover API")
112
+
113
  @app.post("/remove-background")
114
  async def remove_background(
115
  file: UploadFile = File(None),
 
123
  elif image_url:
124
  image = load_image_from_url(image_url)
125
  else:
126
+ raise HTTPException(status_code=400, detail="Provide file or image_url.")
127
+
128
+ # Auto-downscale for large images → much faster
129
+ image = auto_downscale(image)
130
 
131
  result = run_inference(image, resolution)
132
 
133
  buf = BytesIO()
134
+ result.save(buf, format="PNG", optimize=True)
135
  buf.seek(0)
136
 
137
  return StreamingResponse(buf, media_type="image/png")
 
143
 
144
 
145
  # ---------------------------------------------------------
146
+ # Developer test UI
147
  # ---------------------------------------------------------
148
  @app.get("/", response_class=HTMLResponse)
149
  async def index():
150
+ return """
 
151
  <html>
152
  <head>
153
+ <meta charset='utf-8' />
154
  <title>Background Remover</title>
155
+ <link rel='stylesheet' href='https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css'>
 
 
 
 
 
156
  </head>
157
+ <body style='background:#f8f9fa;padding-top:40px;'>
158
+ <div class='container text-center'>
159
+ <h2>Background Remover API</h2>
160
+
161
+ <form id='uploadForm' class='mb-4' enctype='multipart/form-data'>
162
+ <input class='form-control mb-2' type='file' id='fileInput' name='file'>
163
+ <input class='form-control mb-2' type='number' id='resInput' name='resolution' value='512'>
164
+ <button class='btn btn-primary'>Upload</button>
 
 
 
 
 
165
  </form>
166
 
167
+ <div class='mb-3'>OR</div>
168
 
169
+ <form id='urlForm'>
170
+ <input class='form-control mb-2' id='urlInput' placeholder='Image URL'>
171
+ <input class='form-control mb-2' id='urlResInput' type='number' value='512'>
172
+ <button class='btn btn-success'>Use URL</button>
 
 
 
 
173
  </form>
174
 
175
+ <h5 class='mt-4'>Result:</h5>
176
+ <img id='resultImg' style='max-width:100%;border-radius:10px;'/>
177
  </div>
178
 
179
  <script>
 
 
180
  const resultImg = document.getElementById("resultImg");
181
 
182
+ document.getElementById("uploadForm").addEventListener("submit", async e =>{
183
  e.preventDefault();
184
  const file = document.getElementById("fileInput").files[0];
185
  if (!file) return alert("Choose an image");
 
186
  const res = document.getElementById("resInput").value;
187
+ const f = new FormData();
188
+ f.append("file", file);
189
+ f.append("resolution", res);
190
+ const r = await fetch("/remove-background", { method:"POST", body:f });
191
+ resultImg.src = URL.createObjectURL(await r.blob());
 
 
192
  });
193
 
194
+ document.getElementById("urlForm").addEventListener("submit", async e =>{
195
  e.preventDefault();
196
  const url = document.getElementById("urlInput").value.trim();
197
  if (!url) return alert("Enter URL");
 
198
  const res = document.getElementById("urlResInput").value;
199
+ const f = new FormData();
200
+ f.append("image_url", url);
201
+ f.append("resolution", res);
202
+ const r = await fetch("/remove-background", { method:"POST", body:f });
203
+ resultImg.src = URL.createObjectURL(await r.blob());
 
 
204
  });
205
  </script>
206
  </body>
207
  </html>
208
  """
 
209
 
210
 
211
  # ---------------------------------------------------------