videopix commited on
Commit
81cb8c6
·
verified ·
1 Parent(s): baf2587

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -44
app.py CHANGED
@@ -1,17 +1,16 @@
1
  import os
2
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException
3
  from fastapi.responses import StreamingResponse, HTMLResponse
4
- from PIL import Image, ImageSequence
5
  import torch
6
  import numpy as np
7
  from transformers import AutoModelForImageSegmentation
8
  from io import BytesIO
9
  import requests
10
  import uvicorn
11
- from concurrent.futures import ThreadPoolExecutor
12
 
13
  # -------------------------
14
- # Optional HEIC Support
15
  # -------------------------
16
  try:
17
  import pillow_heif
@@ -26,6 +25,7 @@ except ImportError:
26
  MODEL_DIR = "models/BiRefNet"
27
  os.makedirs(MODEL_DIR, exist_ok=True)
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
29
 
30
  print("Loading BiRefNet model...")
31
  birefnet = AutoModelForImageSegmentation.from_pretrained(
@@ -34,17 +34,13 @@ birefnet = AutoModelForImageSegmentation.from_pretrained(
34
  trust_remote_code=True,
35
  revision="main"
36
  )
37
- birefnet.to(device)
38
- if device == "cuda":
39
- birefnet = birefnet.half() # FP16 for faster GPU inference
40
- birefnet.eval()
41
  print("Model loaded successfully.")
42
 
43
  # -------------------------
44
  # FastAPI App
45
  # -------------------------
46
  app = FastAPI(title="Background Remover API")
47
- executor = ThreadPoolExecutor(max_workers=4)
48
 
49
  # -------------------------
50
  # Utility Functions
@@ -53,27 +49,27 @@ def load_image_from_url(url: str) -> Image.Image:
53
  try:
54
  response = requests.get(url, timeout=10)
55
  response.raise_for_status()
56
- return Image.open(BytesIO(response.content))
57
  except Exception as e:
58
  raise HTTPException(status_code=400, detail=f"Error loading image from URL: {str(e)}")
59
 
60
- def transform_image(image: Image.Image, target_size=(512, 512)) -> torch.Tensor:
61
- image = image.resize(target_size)
62
  arr = np.array(image).astype(np.float32) / 255.0
63
  mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
64
  std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
65
  arr = (arr - mean) / std
66
  arr = np.transpose(arr, (2, 0, 1)) # HWC -> CHW
67
- tensor = torch.from_numpy(arr).unsqueeze(0).to(torch.float16 if device=="cuda" else torch.float32).to(device)
68
  return tensor
69
 
70
- def process_image(image: Image.Image) -> Image.Image:
71
- image_size = image.size
72
- input_tensor = transform_image(image)
73
  with torch.no_grad():
74
  preds = birefnet(input_tensor)[-1].sigmoid().cpu()
75
  pred = preds[0, 0]
76
- mask = Image.fromarray((pred.numpy() * 255).astype(np.uint8)).resize(image_size)
77
  image = image.convert("RGBA")
78
  image.putalpha(mask)
79
  return image
@@ -82,30 +78,28 @@ def process_image(image: Image.Image) -> Image.Image:
82
  # /remove-background Endpoint
83
  # -------------------------
84
  @app.post("/remove-background")
85
- async def remove_background(file: UploadFile = File(None), image_url: str = Form(None)):
 
 
 
 
86
  """
87
  Remove background from an image.
88
- Accepts file upload or image URL.
 
89
  Returns PNG with transparent background.
90
  """
91
  try:
92
  if file:
93
- image = Image.open(BytesIO(await file.read()))
94
  elif image_url:
95
  image = load_image_from_url(image_url)
96
  else:
97
- raise HTTPException(status_code=400, detail="Provide 'file' or 'image_url'.")
98
-
99
- # Handle multi-frame images (GIF, PDF)
100
- if getattr(image, "is_animated", False):
101
- frames = [process_image(frame.convert("RGBA")) for frame in ImageSequence.Iterator(image)]
102
- buf = BytesIO()
103
- frames[0].save(buf, format="PNG", save_all=True, append_images=frames[1:])
104
- else:
105
- result = process_image(image.convert("RGBA"))
106
- buf = BytesIO()
107
- result.save(buf, format="PNG")
108
 
 
 
 
109
  buf.seek(0)
110
  return StreamingResponse(buf, media_type="image/png")
111
  except Exception as e:
@@ -135,55 +129,61 @@ async def index():
135
  <h2 class="mb-4">Background Remover API Tester</h2>
136
  <form id="uploadForm" class="mb-4" enctype="multipart/form-data">
137
  <div class="mb-3">
138
- <label class="form-label">Upload Image (any format):</label>
139
  <input class="form-control" type="file" id="fileInput" name="file" accept="image/*">
140
  </div>
 
 
 
 
141
  <button class="btn btn-primary" type="submit">Remove Background</button>
142
  </form>
143
  <div class="mb-4">OR</div>
144
  <form id="urlForm" class="mb-4">
145
  <div class="mb-3">
146
- <label class="form-label">Enter Image URL:</label>
147
  <input class="form-control" type="text" id="urlInput" placeholder="https://example.com/image.jpg">
148
  </div>
 
 
 
 
149
  <button class="btn btn-success" type="submit">Remove Background</button>
150
  </form>
151
  <div id="resultContainer" class="mt-4">
152
  <h5>Result:</h5>
153
  <img id="resultImg" src="" alt="">
154
- <a id="downloadLink" class="btn btn-info mt-2" download="result.png" style="display:none;">Download PNG</a>
155
  </div>
156
  </div>
157
  <script>
158
  const uploadForm = document.getElementById("uploadForm");
159
  const urlForm = document.getElementById("urlForm");
160
  const resultImg = document.getElementById("resultImg");
161
- const downloadLink = document.getElementById("downloadLink");
162
 
163
  uploadForm.addEventListener("submit", async e => {
164
  e.preventDefault();
165
  const fileInput = document.getElementById("fileInput");
 
166
  if (!fileInput.files.length) return alert("Please select a file!");
167
  const formData = new FormData();
168
  formData.append("file", fileInput.files[0]);
169
- const res = await fetch("/remove-background", { method: "POST", body: formData });
170
- const blob = await res.blob();
 
171
  resultImg.src = URL.createObjectURL(blob);
172
- downloadLink.href = resultImg.src;
173
- downloadLink.style.display = "inline-block";
174
  });
175
 
176
  urlForm.addEventListener("submit", async e => {
177
  e.preventDefault();
178
  const url = document.getElementById("urlInput").value.trim();
 
179
  if (!url) return alert("Please enter an image URL!");
180
  const formData = new FormData();
181
  formData.append("image_url", url);
182
- const res = await fetch("/remove-background", { method: "POST", body: formData });
183
- const blob = await res.blob();
 
184
  resultImg.src = URL.createObjectURL(blob);
185
- downloadLink.href = resultImg.src;
186
- downloadLink.style.display = "inline-block";
187
  });
188
  </script>
189
  </body>
@@ -192,7 +192,7 @@ async def index():
192
  return HTMLResponse(html)
193
 
194
  # -------------------------
195
- # Run App on Spaces
196
  # -------------------------
197
  if __name__ == "__main__":
198
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import os
2
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Query
3
  from fastapi.responses import StreamingResponse, HTMLResponse
4
+ from PIL import Image
5
  import torch
6
  import numpy as np
7
  from transformers import AutoModelForImageSegmentation
8
  from io import BytesIO
9
  import requests
10
  import uvicorn
 
11
 
12
  # -------------------------
13
+ # Optional HEIC/HEIF Support
14
  # -------------------------
15
  try:
16
  import pillow_heif
 
25
  MODEL_DIR = "models/BiRefNet"
26
  os.makedirs(MODEL_DIR, exist_ok=True)
27
  device = "cuda" if torch.cuda.is_available() else "cpu"
28
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
29
 
30
  print("Loading BiRefNet model...")
31
  birefnet = AutoModelForImageSegmentation.from_pretrained(
 
34
  trust_remote_code=True,
35
  revision="main"
36
  )
37
+ birefnet.to(device, dtype=dtype).eval()
 
 
 
38
  print("Model loaded successfully.")
39
 
40
  # -------------------------
41
  # FastAPI App
42
  # -------------------------
43
  app = FastAPI(title="Background Remover API")
 
44
 
45
  # -------------------------
46
  # Utility Functions
 
49
  try:
50
  response = requests.get(url, timeout=10)
51
  response.raise_for_status()
52
+ return Image.open(BytesIO(response.content)).convert("RGB")
53
  except Exception as e:
54
  raise HTTPException(status_code=400, detail=f"Error loading image from URL: {str(e)}")
55
 
56
+ def transform_image(image: Image.Image, resolution: int = 512) -> torch.Tensor:
57
+ image = image.resize((resolution, resolution))
58
  arr = np.array(image).astype(np.float32) / 255.0
59
  mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
60
  std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
61
  arr = (arr - mean) / std
62
  arr = np.transpose(arr, (2, 0, 1)) # HWC -> CHW
63
+ tensor = torch.from_numpy(arr).unsqueeze(0).to(dtype).to(device)
64
  return tensor
65
 
66
+ def process_image(image: Image.Image, resolution: int = 512) -> Image.Image:
67
+ orig_size = image.size
68
+ input_tensor = transform_image(image, resolution)
69
  with torch.no_grad():
70
  preds = birefnet(input_tensor)[-1].sigmoid().cpu()
71
  pred = preds[0, 0]
72
+ mask = Image.fromarray((pred.numpy() * 255).astype(np.uint8)).resize(orig_size)
73
  image = image.convert("RGBA")
74
  image.putalpha(mask)
75
  return image
 
78
  # /remove-background Endpoint
79
  # -------------------------
80
  @app.post("/remove-background")
81
+ async def remove_background(
82
+ file: UploadFile = File(None),
83
+ image_url: str = Form(None),
84
+ resolution: int = Form(512)
85
+ ):
86
  """
87
  Remove background from an image.
88
+ Accepts a file upload or image URL.
89
+ Optional resolution (default 512) for faster inference.
90
  Returns PNG with transparent background.
91
  """
92
  try:
93
  if file:
94
+ image = Image.open(BytesIO(await file.read())).convert("RGB")
95
  elif image_url:
96
  image = load_image_from_url(image_url)
97
  else:
98
+ raise HTTPException(status_code=400, detail="Provide either 'file' or 'image_url'.")
 
 
 
 
 
 
 
 
 
 
99
 
100
+ result = process_image(image, resolution)
101
+ buf = BytesIO()
102
+ result.save(buf, format="PNG")
103
  buf.seek(0)
104
  return StreamingResponse(buf, media_type="image/png")
105
  except Exception as e:
 
129
  <h2 class="mb-4">Background Remover API Tester</h2>
130
  <form id="uploadForm" class="mb-4" enctype="multipart/form-data">
131
  <div class="mb-3">
132
+ <label for="fileInput" class="form-label">Upload Image (any format, e.g. JPG, PNG, HEIC):</label>
133
  <input class="form-control" type="file" id="fileInput" name="file" accept="image/*">
134
  </div>
135
+ <div class="mb-3">
136
+ <label for="resInput" class="form-label">Resolution (default 512):</label>
137
+ <input class="form-control" type="number" id="resInput" name="resolution" value="512" min="64" max="2048">
138
+ </div>
139
  <button class="btn btn-primary" type="submit">Remove Background</button>
140
  </form>
141
  <div class="mb-4">OR</div>
142
  <form id="urlForm" class="mb-4">
143
  <div class="mb-3">
144
+ <label for="urlInput" class="form-label">Enter Image URL:</label>
145
  <input class="form-control" type="text" id="urlInput" placeholder="https://example.com/image.jpg">
146
  </div>
147
+ <div class="mb-3">
148
+ <label for="urlResInput" class="form-label">Resolution (default 512):</label>
149
+ <input class="form-control" type="number" id="urlResInput" name="resolution" value="512" min="64" max="2048">
150
+ </div>
151
  <button class="btn btn-success" type="submit">Remove Background</button>
152
  </form>
153
  <div id="resultContainer" class="mt-4">
154
  <h5>Result:</h5>
155
  <img id="resultImg" src="" alt="">
 
156
  </div>
157
  </div>
158
  <script>
159
  const uploadForm = document.getElementById("uploadForm");
160
  const urlForm = document.getElementById("urlForm");
161
  const resultImg = document.getElementById("resultImg");
 
162
 
163
  uploadForm.addEventListener("submit", async e => {
164
  e.preventDefault();
165
  const fileInput = document.getElementById("fileInput");
166
+ const res = document.getElementById("resInput").value || 512;
167
  if (!fileInput.files.length) return alert("Please select a file!");
168
  const formData = new FormData();
169
  formData.append("file", fileInput.files[0]);
170
+ formData.append("resolution", res);
171
+ const response = await fetch("/remove-background", { method: "POST", body: formData });
172
+ const blob = await response.blob();
173
  resultImg.src = URL.createObjectURL(blob);
 
 
174
  });
175
 
176
  urlForm.addEventListener("submit", async e => {
177
  e.preventDefault();
178
  const url = document.getElementById("urlInput").value.trim();
179
+ const res = document.getElementById("urlResInput").value || 512;
180
  if (!url) return alert("Please enter an image URL!");
181
  const formData = new FormData();
182
  formData.append("image_url", url);
183
+ formData.append("resolution", res);
184
+ const response = await fetch("/remove-background", { method: "POST", body: formData });
185
+ const blob = await response.blob();
186
  resultImg.src = URL.createObjectURL(blob);
 
 
187
  });
188
  </script>
189
  </body>
 
192
  return HTMLResponse(html)
193
 
194
  # -------------------------
195
+ # Run App
196
  # -------------------------
197
  if __name__ == "__main__":
198
  uvicorn.run(app, host="0.0.0.0", port=7860)