videopix commited on
Commit
0887e03
·
verified ·
1 Parent(s): 3d7434a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -52
app.py CHANGED
@@ -11,7 +11,7 @@ from transformers import AutoModelForImageSegmentation
11
  import uvicorn
12
 
13
  # ---------------------------------------------------------
14
- # Optional HEIC/HEIF support
15
  # ---------------------------------------------------------
16
  try:
17
  import pillow_heif
@@ -26,6 +26,12 @@ os.environ["OMP_NUM_THREADS"] = "1"
26
  os.environ["MKL_NUM_THREADS"] = "1"
27
  torch.set_num_threads(1)
28
 
 
 
 
 
 
 
29
  # ---------------------------------------------------------
30
  # Load model
31
  # ---------------------------------------------------------
@@ -35,7 +41,7 @@ os.makedirs(MODEL_DIR, exist_ok=True)
35
  device = "cuda" if torch.cuda.is_available() else "cpu"
36
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
37
 
38
- print("Loading model...")
39
  birefnet = AutoModelForImageSegmentation.from_pretrained(
40
  "ZhengPeng7/BiRefNet",
41
  cache_dir=MODEL_DIR,
@@ -45,7 +51,7 @@ birefnet = AutoModelForImageSegmentation.from_pretrained(
45
  birefnet.to(device, dtype=dtype).eval()
46
  print("Model ready.")
47
 
48
- # Thread lock for safe inference on CPU
49
  inference_lock = threading.Lock()
50
 
51
  # ---------------------------------------------------------
@@ -60,36 +66,36 @@ def load_image_from_url(url: str) -> Image.Image:
60
  raise HTTPException(status_code=400, detail=f"Cannot load image from URL: {str(e)}")
61
 
62
 
63
- def auto_downscale(image: Image.Image, max_side: int = 3000) -> Image.Image:
64
  w, h = image.size
65
- if max(w, h) <= max_side:
66
  return image
67
 
68
- scale = max_side / max(w, h)
69
  new_w = int(w * scale)
70
  new_h = int(h * scale)
71
 
72
- print(f"[INFO] Downscaling large image {w}x{h} → {new_w}x{new_h}")
73
  return image.resize((new_w, new_h), Image.LANCZOS)
74
 
75
 
76
- def transform_image(image: Image.Image, resolution: int = 512) -> torch.Tensor:
77
- image = image.resize((resolution, resolution))
78
- arr = np.array(image).astype(np.float32) / 255.0
79
 
 
80
  mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
81
  std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
82
- arr = (arr - mean) / std
83
 
 
84
  arr = np.transpose(arr, (2, 0, 1))
 
85
  tensor = torch.from_numpy(arr).unsqueeze(0).to(device=device, dtype=dtype)
86
  return tensor
87
 
88
 
89
- def run_inference(image: Image.Image, resolution: int = 512) -> Image.Image:
90
  orig_size = image.size
91
-
92
- input_tensor = transform_image(image, resolution)
93
 
94
  with inference_lock:
95
  with torch.no_grad():
@@ -102,24 +108,21 @@ def run_inference(image: Image.Image, resolution: int = 512) -> Image.Image:
102
  image.putalpha(mask)
103
  return image
104
 
105
-
106
  # ---------------------------------------------------------
107
  # FastAPI app
108
  # ---------------------------------------------------------
109
  app = FastAPI(title="Background Remover API")
110
 
111
-
112
  # ---------------------------------------------------------
113
- # remove-background endpoint
114
  # ---------------------------------------------------------
115
  @app.post("/remove-background")
116
  async def remove_background(
117
  file: UploadFile = File(None),
118
- image_url: str = Form(None),
119
- resolution: int = Form(512)
120
  ):
121
  try:
122
- # Load input
123
  if file:
124
  raw = await file.read()
125
  image = Image.open(BytesIO(raw)).convert("RGB")
@@ -128,12 +131,13 @@ async def remove_background(
128
  else:
129
  raise HTTPException(status_code=400, detail="Provide file or image_url.")
130
 
131
- # Automatically compress very large images
132
  image = auto_downscale(image)
133
 
134
- # Process image
135
- result = run_inference(image, resolution)
136
 
 
137
  buf = BytesIO()
138
  result.save(buf, format="PNG", optimize=True)
139
  buf.seek(0)
@@ -145,78 +149,74 @@ async def remove_background(
145
  except Exception as e:
146
  raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
147
 
148
-
149
  # ---------------------------------------------------------
150
- # Web test UI
151
  # ---------------------------------------------------------
152
  @app.get("/", response_class=HTMLResponse)
153
  async def ui():
154
  return """
155
  <html>
156
  <head>
157
- <title>Background Remover</title>
158
  <link rel='stylesheet' href='https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css'>
159
  </head>
 
160
  <body style='background:#f8f9fa;padding-top:40px;'>
161
  <div class='container text-center'>
162
- <h2>Background Remover API</h2>
 
163
 
164
- <form id='f1' class='mb-4' enctype='multipart/form-data'>
165
- <input class='form-control mb-2' type='file' id='fileInput' name='file'>
166
- <input class='form-control mb-2' type='number' id='resInput' name='resolution' value='512'>
167
- <button class='btn btn-primary'>Upload</button>
168
  </form>
169
 
170
- <div class='mb-3'>OR</div>
171
 
172
- <form id='f2'>
173
- <input class='form-control mb-2' id='urlInput' placeholder='Image URL'>
174
- <input class='form-control mb-2' id='urlResInput' type='number' value='512'>
175
- <button class='btn btn-success'>Use URL</button>
176
  </form>
177
 
178
- <h5 class='mt-4'>Result:</h5>
179
- <img id='out' style='max-width:100%;border-radius:10px;'/>
180
  </div>
181
 
182
  <script>
183
- const out = document.getElementById("out");
184
 
185
- document.getElementById("f1").addEventListener("submit", async e => {
186
  e.preventDefault();
187
  const file = document.getElementById("fileInput").files[0];
188
- if (!file) return alert("Select an image");
189
- const res = document.getElementById("resInput").value;
190
 
191
  const fd = new FormData();
192
  fd.append("file", file);
193
- fd.append("resolution", res);
194
 
195
- const r = await fetch("/remove-background", { method:"POST", body:fd });
196
- out.src = URL.createObjectURL(await r.blob());
197
  });
198
 
199
- document.getElementById("f2").addEventListener("submit", async e => {
200
  e.preventDefault();
201
  const url = document.getElementById("urlInput").value.trim();
202
  if (!url) return alert("Enter an image URL");
203
- const res = document.getElementById("urlResInput").value;
204
 
205
  const fd = new FormData();
206
  fd.append("image_url", url);
207
- fd.append("resolution", res);
208
 
209
- const r = await fetch("/remove-background", { method:"POST", body:fd });
210
- out.src = URL.createObjectURL(await r.blob());
211
  });
212
  </script>
213
  </body>
214
  </html>
215
  """
216
 
217
-
218
  # ---------------------------------------------------------
219
- # Start server
220
  # ---------------------------------------------------------
221
  if __name__ == "__main__":
222
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
11
  import uvicorn
12
 
13
  # ---------------------------------------------------------
14
+ # Optional HEIC/HEIF
15
  # ---------------------------------------------------------
16
  try:
17
  import pillow_heif
 
26
  os.environ["MKL_NUM_THREADS"] = "1"
27
  torch.set_num_threads(1)
28
 
29
+ # ---------------------------------------------------------
30
+ # Constants
31
+ # ---------------------------------------------------------
32
+ TARGET_SIZE = (512, 512) # Faster inference resolution
33
+ MAX_SIDE = 3000 # Auto-downscale limit for large uploads
34
+
35
  # ---------------------------------------------------------
36
  # Load model
37
  # ---------------------------------------------------------
 
41
  device = "cuda" if torch.cuda.is_available() else "cpu"
42
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
43
 
44
+ print("Loading BiRefNet…")
45
  birefnet = AutoModelForImageSegmentation.from_pretrained(
46
  "ZhengPeng7/BiRefNet",
47
  cache_dir=MODEL_DIR,
 
51
  birefnet.to(device, dtype=dtype).eval()
52
  print("Model ready.")
53
 
54
+ # Thread lock for CPU safety
55
  inference_lock = threading.Lock()
56
 
57
  # ---------------------------------------------------------
 
66
  raise HTTPException(status_code=400, detail=f"Cannot load image from URL: {str(e)}")
67
 
68
 
69
+ def auto_downscale(image: Image.Image) -> Image.Image:
70
  w, h = image.size
71
+ if max(w, h) <= MAX_SIDE:
72
  return image
73
 
74
+ scale = MAX_SIDE / max(w, h)
75
  new_w = int(w * scale)
76
  new_h = int(h * scale)
77
 
78
+ print(f"[INFO] Downscaling {w}×{h} → {new_w}×{new_h}")
79
  return image.resize((new_w, new_h), Image.LANCZOS)
80
 
81
 
82
+ def transform_image(image: Image.Image) -> torch.Tensor:
83
+ image = image.resize(TARGET_SIZE)
 
84
 
85
+ arr = np.array(image).astype(np.float32) / 255.0
86
  mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
87
  std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
 
88
 
89
+ arr = (arr - mean) / std
90
  arr = np.transpose(arr, (2, 0, 1))
91
+
92
  tensor = torch.from_numpy(arr).unsqueeze(0).to(device=device, dtype=dtype)
93
  return tensor
94
 
95
 
96
+ def run_inference(image: Image.Image) -> Image.Image:
97
  orig_size = image.size
98
+ input_tensor = transform_image(image)
 
99
 
100
  with inference_lock:
101
  with torch.no_grad():
 
108
  image.putalpha(mask)
109
  return image
110
 
 
111
  # ---------------------------------------------------------
112
  # FastAPI app
113
  # ---------------------------------------------------------
114
  app = FastAPI(title="Background Remover API")
115
 
 
116
  # ---------------------------------------------------------
117
+ # POST endpoint only (no GET processing)
118
  # ---------------------------------------------------------
119
  @app.post("/remove-background")
120
  async def remove_background(
121
  file: UploadFile = File(None),
122
+ image_url: str = Form(None)
 
123
  ):
124
  try:
125
+ # load image
126
  if file:
127
  raw = await file.read()
128
  image = Image.open(BytesIO(raw)).convert("RGB")
 
131
  else:
132
  raise HTTPException(status_code=400, detail="Provide file or image_url.")
133
 
134
+ # auto shrink large inputs
135
  image = auto_downscale(image)
136
 
137
+ # remove background
138
+ result = run_inference(image)
139
 
140
+ # return PNG
141
  buf = BytesIO()
142
  result.save(buf, format="PNG", optimize=True)
143
  buf.seek(0)
 
149
  except Exception as e:
150
  raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
151
 
 
152
  # ---------------------------------------------------------
153
+ # UI for POST method testing only
154
  # ---------------------------------------------------------
155
  @app.get("/", response_class=HTMLResponse)
156
  async def ui():
157
  return """
158
  <html>
159
  <head>
160
+ <title>Background Remover – Test Tool</title>
161
  <link rel='stylesheet' href='https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css'>
162
  </head>
163
+
164
  <body style='background:#f8f9fa;padding-top:40px;'>
165
  <div class='container text-center'>
166
+ <h2>POST Method Test Panel</h2>
167
+ <p>This UI only sends POST requests to <code>/remove-background</code>.</p>
168
 
169
+ <h5>Test with File Upload:</h5>
170
+ <form id='uploadForm' enctype='multipart/form-data'>
171
+ <input class='form-control mb-2' type='file' id='fileInput'>
172
+ <button class='btn btn-primary'>Send POST</button>
173
  </form>
174
 
175
+ <div class='my-4'>OR</div>
176
 
177
+ <h5>Test with Image URL:</h5>
178
+ <form id='urlForm'>
179
+ <input class='form-control mb-2' id='urlInput' placeholder='https://example.com/image.jpg'>
180
+ <button class='btn btn-success'>Send POST</button>
181
  </form>
182
 
183
+ <h4 class='mt-4'>Output:</h4>
184
+ <img id='outputImg' style='max-width:90%;border-radius:10px;'/>
185
  </div>
186
 
187
  <script>
188
+ const outputImg = document.getElementById("outputImg");
189
 
190
+ document.getElementById("uploadForm").addEventListener("submit", async e => {
191
  e.preventDefault();
192
  const file = document.getElementById("fileInput").files[0];
193
+ if (!file) return alert("Select a file first");
 
194
 
195
  const fd = new FormData();
196
  fd.append("file", file);
 
197
 
198
+ const r = await fetch("/remove-background", {method:"POST", body:fd});
199
+ outputImg.src = URL.createObjectURL(await r.blob());
200
  });
201
 
202
+ document.getElementById("urlForm").addEventListener("submit", async e => {
203
  e.preventDefault();
204
  const url = document.getElementById("urlInput").value.trim();
205
  if (!url) return alert("Enter an image URL");
 
206
 
207
  const fd = new FormData();
208
  fd.append("image_url", url);
 
209
 
210
+ const r = await fetch("/remove-background", {method:"POST", body:fd});
211
+ outputImg.src = URL.createObjectURL(await r.blob());
212
  });
213
  </script>
214
  </body>
215
  </html>
216
  """
217
 
 
218
  # ---------------------------------------------------------
219
+ # Run app
220
  # ---------------------------------------------------------
221
  if __name__ == "__main__":
222
  uvicorn.run(app, host="0.0.0.0", port=7860)