videopix commited on
Commit
3d7434a
·
verified ·
1 Parent(s): 9c9a36a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -37
app.py CHANGED
@@ -27,7 +27,7 @@ os.environ["MKL_NUM_THREADS"] = "1"
27
  torch.set_num_threads(1)
28
 
29
  # ---------------------------------------------------------
30
- # Model load
31
  # ---------------------------------------------------------
32
  MODEL_DIR = "models/BiRefNet"
33
  os.makedirs(MODEL_DIR, exist_ok=True)
@@ -35,6 +35,7 @@ os.makedirs(MODEL_DIR, exist_ok=True)
35
  device = "cuda" if torch.cuda.is_available() else "cpu"
36
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
37
 
 
38
  birefnet = AutoModelForImageSegmentation.from_pretrained(
39
  "ZhengPeng7/BiRefNet",
40
  cache_dir=MODEL_DIR,
@@ -42,8 +43,9 @@ birefnet = AutoModelForImageSegmentation.from_pretrained(
42
  revision="main"
43
  )
44
  birefnet.to(device, dtype=dtype).eval()
 
45
 
46
- # Thread lock to protect inference on CPU
47
  inference_lock = threading.Lock()
48
 
49
  # ---------------------------------------------------------
@@ -59,9 +61,7 @@ def load_image_from_url(url: str) -> Image.Image:
59
 
60
 
61
  def auto_downscale(image: Image.Image, max_side: int = 3000) -> Image.Image:
62
- """Downscale very large images to speed up CPU inference."""
63
  w, h = image.size
64
-
65
  if max(w, h) <= max_side:
66
  return image
67
 
@@ -69,21 +69,19 @@ def auto_downscale(image: Image.Image, max_side: int = 3000) -> Image.Image:
69
  new_w = int(w * scale)
70
  new_h = int(h * scale)
71
 
72
- image = image.resize((new_w, new_h), Image.LANCZOS)
73
- print(f"[INFO] Downscaled large image from {w}x{h} to {new_w}x{new_h}")
74
-
75
- return image
76
 
77
 
78
  def transform_image(image: Image.Image, resolution: int = 512) -> torch.Tensor:
79
  image = image.resize((resolution, resolution))
80
  arr = np.array(image).astype(np.float32) / 255.0
 
81
  mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
82
  std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
83
-
84
  arr = (arr - mean) / std
85
- arr = np.transpose(arr, (2, 0, 1))
86
 
 
87
  tensor = torch.from_numpy(arr).unsqueeze(0).to(device=device, dtype=dtype)
88
  return tensor
89
 
@@ -106,10 +104,14 @@ def run_inference(image: Image.Image, resolution: int = 512) -> Image.Image:
106
 
107
 
108
  # ---------------------------------------------------------
109
- # API endpoint
110
  # ---------------------------------------------------------
111
- @app = FastAPI(title="Background Remover API")
112
 
 
 
 
 
113
  @app.post("/remove-background")
114
  async def remove_background(
115
  file: UploadFile = File(None),
@@ -117,17 +119,19 @@ async def remove_background(
117
  resolution: int = Form(512)
118
  ):
119
  try:
 
120
  if file:
121
- data = await file.read()
122
- image = Image.open(BytesIO(data)).convert("RGB")
123
  elif image_url:
124
  image = load_image_from_url(image_url)
125
  else:
126
  raise HTTPException(status_code=400, detail="Provide file or image_url.")
127
 
128
- # Auto-downscale for large images → much faster
129
  image = auto_downscale(image)
130
 
 
131
  result = run_inference(image, resolution)
132
 
133
  buf = BytesIO()
@@ -143,14 +147,13 @@ async def remove_background(
143
 
144
 
145
  # ---------------------------------------------------------
146
- # Developer test UI
147
  # ---------------------------------------------------------
148
  @app.get("/", response_class=HTMLResponse)
149
- async def index():
150
  return """
151
  <html>
152
  <head>
153
- <meta charset='utf-8' />
154
  <title>Background Remover</title>
155
  <link rel='stylesheet' href='https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css'>
156
  </head>
@@ -158,7 +161,7 @@ async def index():
158
  <div class='container text-center'>
159
  <h2>Background Remover API</h2>
160
 
161
- <form id='uploadForm' class='mb-4' enctype='multipart/form-data'>
162
  <input class='form-control mb-2' type='file' id='fileInput' name='file'>
163
  <input class='form-control mb-2' type='number' id='resInput' name='resolution' value='512'>
164
  <button class='btn btn-primary'>Upload</button>
@@ -166,41 +169,45 @@ async def index():
166
 
167
  <div class='mb-3'>OR</div>
168
 
169
- <form id='urlForm'>
170
  <input class='form-control mb-2' id='urlInput' placeholder='Image URL'>
171
  <input class='form-control mb-2' id='urlResInput' type='number' value='512'>
172
  <button class='btn btn-success'>Use URL</button>
173
  </form>
174
 
175
  <h5 class='mt-4'>Result:</h5>
176
- <img id='resultImg' style='max-width:100%;border-radius:10px;'/>
177
  </div>
178
 
179
  <script>
180
- const resultImg = document.getElementById("resultImg");
181
 
182
- document.getElementById("uploadForm").addEventListener("submit", async e =>{
183
  e.preventDefault();
184
  const file = document.getElementById("fileInput").files[0];
185
- if (!file) return alert("Choose an image");
186
  const res = document.getElementById("resInput").value;
187
- const f = new FormData();
188
- f.append("file", file);
189
- f.append("resolution", res);
190
- const r = await fetch("/remove-background", { method:"POST", body:f });
191
- resultImg.src = URL.createObjectURL(await r.blob());
 
 
192
  });
193
 
194
- document.getElementById("urlForm").addEventListener("submit", async e =>{
195
  e.preventDefault();
196
  const url = document.getElementById("urlInput").value.trim();
197
- if (!url) return alert("Enter URL");
198
  const res = document.getElementById("urlResInput").value;
199
- const f = new FormData();
200
- f.append("image_url", url);
201
- f.append("resolution", res);
202
- const r = await fetch("/remove-background", { method:"POST", body:f });
203
- resultImg.src = URL.createObjectURL(await r.blob());
 
 
204
  });
205
  </script>
206
  </body>
@@ -209,7 +216,7 @@ async def index():
209
 
210
 
211
  # ---------------------------------------------------------
212
- # Run server
213
  # ---------------------------------------------------------
214
  if __name__ == "__main__":
215
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
27
  torch.set_num_threads(1)
28
 
29
  # ---------------------------------------------------------
30
+ # Load model
31
  # ---------------------------------------------------------
32
  MODEL_DIR = "models/BiRefNet"
33
  os.makedirs(MODEL_DIR, exist_ok=True)
 
35
  device = "cuda" if torch.cuda.is_available() else "cpu"
36
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
37
 
38
+ print("Loading model...")
39
  birefnet = AutoModelForImageSegmentation.from_pretrained(
40
  "ZhengPeng7/BiRefNet",
41
  cache_dir=MODEL_DIR,
 
43
  revision="main"
44
  )
45
  birefnet.to(device, dtype=dtype).eval()
46
+ print("Model ready.")
47
 
48
+ # Thread lock for safe inference on CPU
49
  inference_lock = threading.Lock()
50
 
51
  # ---------------------------------------------------------
 
61
 
62
 
63
  def auto_downscale(image: Image.Image, max_side: int = 3000) -> Image.Image:
 
64
  w, h = image.size
 
65
  if max(w, h) <= max_side:
66
  return image
67
 
 
69
  new_w = int(w * scale)
70
  new_h = int(h * scale)
71
 
72
+ print(f"[INFO] Downscaling large image {w}x{h} → {new_w}x{new_h}")
73
+ return image.resize((new_w, new_h), Image.LANCZOS)
 
 
74
 
75
 
76
  def transform_image(image: Image.Image, resolution: int = 512) -> torch.Tensor:
77
  image = image.resize((resolution, resolution))
78
  arr = np.array(image).astype(np.float32) / 255.0
79
+
80
  mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
81
  std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
 
82
  arr = (arr - mean) / std
 
83
 
84
+ arr = np.transpose(arr, (2, 0, 1))
85
  tensor = torch.from_numpy(arr).unsqueeze(0).to(device=device, dtype=dtype)
86
  return tensor
87
 
 
104
 
105
 
106
  # ---------------------------------------------------------
107
+ # FastAPI app
108
  # ---------------------------------------------------------
109
+ app = FastAPI(title="Background Remover API")
110
 
111
+
112
+ # ---------------------------------------------------------
113
+ # remove-background endpoint
114
+ # ---------------------------------------------------------
115
  @app.post("/remove-background")
116
  async def remove_background(
117
  file: UploadFile = File(None),
 
119
  resolution: int = Form(512)
120
  ):
121
  try:
122
+ # Load input
123
  if file:
124
+ raw = await file.read()
125
+ image = Image.open(BytesIO(raw)).convert("RGB")
126
  elif image_url:
127
  image = load_image_from_url(image_url)
128
  else:
129
  raise HTTPException(status_code=400, detail="Provide file or image_url.")
130
 
131
+ # Automatically compress very large images
132
  image = auto_downscale(image)
133
 
134
+ # Process image
135
  result = run_inference(image, resolution)
136
 
137
  buf = BytesIO()
 
147
 
148
 
149
  # ---------------------------------------------------------
150
+ # Web test UI
151
  # ---------------------------------------------------------
152
  @app.get("/", response_class=HTMLResponse)
153
+ async def ui():
154
  return """
155
  <html>
156
  <head>
 
157
  <title>Background Remover</title>
158
  <link rel='stylesheet' href='https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css'>
159
  </head>
 
161
  <div class='container text-center'>
162
  <h2>Background Remover API</h2>
163
 
164
+ <form id='f1' class='mb-4' enctype='multipart/form-data'>
165
  <input class='form-control mb-2' type='file' id='fileInput' name='file'>
166
  <input class='form-control mb-2' type='number' id='resInput' name='resolution' value='512'>
167
  <button class='btn btn-primary'>Upload</button>
 
169
 
170
  <div class='mb-3'>OR</div>
171
 
172
+ <form id='f2'>
173
  <input class='form-control mb-2' id='urlInput' placeholder='Image URL'>
174
  <input class='form-control mb-2' id='urlResInput' type='number' value='512'>
175
  <button class='btn btn-success'>Use URL</button>
176
  </form>
177
 
178
  <h5 class='mt-4'>Result:</h5>
179
+ <img id='out' style='max-width:100%;border-radius:10px;'/>
180
  </div>
181
 
182
  <script>
183
+ const out = document.getElementById("out");
184
 
185
+ document.getElementById("f1").addEventListener("submit", async e => {
186
  e.preventDefault();
187
  const file = document.getElementById("fileInput").files[0];
188
+ if (!file) return alert("Select an image");
189
  const res = document.getElementById("resInput").value;
190
+
191
+ const fd = new FormData();
192
+ fd.append("file", file);
193
+ fd.append("resolution", res);
194
+
195
+ const r = await fetch("/remove-background", { method:"POST", body:fd });
196
+ out.src = URL.createObjectURL(await r.blob());
197
  });
198
 
199
+ document.getElementById("f2").addEventListener("submit", async e => {
200
  e.preventDefault();
201
  const url = document.getElementById("urlInput").value.trim();
202
+ if (!url) return alert("Enter an image URL");
203
  const res = document.getElementById("urlResInput").value;
204
+
205
+ const fd = new FormData();
206
+ fd.append("image_url", url);
207
+ fd.append("resolution", res);
208
+
209
+ const r = await fetch("/remove-background", { method:"POST", body:fd });
210
+ out.src = URL.createObjectURL(await r.blob());
211
  });
212
  </script>
213
  </body>
 
216
 
217
 
218
  # ---------------------------------------------------------
219
+ # Start server
220
  # ---------------------------------------------------------
221
  if __name__ == "__main__":
222
  uvicorn.run(app, host="0.0.0.0", port=7860)