videopix commited on
Commit
d37a6e5
·
verified ·
1 Parent(s): 6e265be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -71
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import os
2
  from io import BytesIO
3
- import uvicorn
4
 
5
  import numpy as np
6
  import requests
@@ -11,19 +10,19 @@ from fastapi.middleware.cors import CORSMiddleware
11
  from PIL import Image
12
  from transformers import AutoModelForImageSegmentation
13
 
14
- # -------------------------
15
- # Optional HEIC/HEIF Support
16
- # -------------------------
17
  try:
18
  import pillow_heif
19
  pillow_heif.register_heif_opener()
20
  print("HEIC/HEIF supported")
21
- except ImportError:
22
  print("Install pillow-heif for HEIC support")
23
 
24
- # -------------------------
25
- # Model Setup
26
- # -------------------------
27
  MODEL_DIR = "models/BiRefNet"
28
  os.makedirs(MODEL_DIR, exist_ok=True)
29
 
@@ -38,130 +37,183 @@ birefnet = AutoModelForImageSegmentation.from_pretrained(
38
  revision="main"
39
  )
40
  birefnet.to(device, dtype=dtype).eval()
41
- print("Model loaded successfully.")
42
 
43
- # -------------------------
44
  # FastAPI App
45
- # -------------------------
46
  app = FastAPI(title="Background Remover API")
47
 
48
- # Allow API calls from mobile apps, web apps, backend servers
49
  app.add_middleware(
50
  CORSMiddleware,
51
  allow_origins=["*"],
52
- allow_credentials=True,
53
  allow_methods=["*"],
54
  allow_headers=["*"],
55
  )
56
 
57
- # -------------------------
58
  # Utility Functions
59
- # -------------------------
60
  def load_image_from_url(url: str):
61
  try:
62
  resp = requests.get(url, timeout=10)
63
  resp.raise_for_status()
64
  return Image.open(BytesIO(resp.content)).convert("RGB")
65
  except Exception as e:
66
- raise HTTPException(status_code=400, detail=f"Invalid image URL: {str(e)}")
67
 
68
  def transform_image(image: Image.Image, resolution: int):
69
  image = image.resize((resolution, resolution))
70
- arr = np.array(image).astype(np.float32) / 255.0
71
- mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
72
- std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
73
- arr = (arr - mean) / std
74
- arr = arr.transpose((2, 0, 1))
75
- return torch.from_numpy(arr).unsqueeze(0).to(device=device, dtype=dtype)
76
 
77
  def process_image(image: Image.Image, resolution: int):
78
- orig = image.size
79
- t = transform_image(image, resolution)
 
80
  with torch.no_grad():
81
- pred = birefnet(t)[-1].sigmoid().cpu()[0, 0]
 
 
 
82
 
83
- mask = Image.fromarray((pred.numpy() * 255).astype("uint8")).resize(orig)
84
  image = image.convert("RGBA")
85
  image.putalpha(mask)
86
  return image
87
 
88
-
89
- # -------------------------
90
- # GET + POST API SUPPORT
91
- # -------------------------
92
  @app.api_route("/remove-background", methods=["GET", "POST"])
93
  async def remove_background(
94
  file: UploadFile = File(None),
95
  image_url: str = Form(None),
96
  resolution: int = Form(512),
97
- get_url: str = Query(None, description="Use for GET request: ?get_url=https://..."),
98
- get_res: int = Query(512, description="Resolution for GET request"),
99
  ):
100
- """
101
- Supports:
102
- - POST file upload
103
- - POST image_url
104
- - GET request using: /remove-background?get_url=...&get_res=512
105
- """
106
-
107
  try:
108
- # Determine GET or POST mode
109
  if get_url:
110
- image = load_image_from_url(get_url)
111
  resolution = get_res
112
 
 
113
  elif file:
114
- content = await file.read()
115
- if not content:
116
- raise HTTPException(status_code=400, detail="Uploaded file is empty.")
117
- image = Image.open(BytesIO(content)).convert("RGB")
118
 
 
119
  elif image_url:
120
- image = load_image_from_url(image_url)
121
 
122
  else:
123
- raise HTTPException(status_code=400, detail="No image provided.")
124
 
125
- result = process_image(image, resolution)
126
- buf = BytesIO()
127
- result.save(buf, format="PNG")
128
- buf.seek(0)
 
129
 
130
  return StreamingResponse(
131
- buf,
132
  media_type="image/png",
133
- headers={"Content-Disposition": "inline; filename=result.png"}
134
  )
135
 
 
 
136
  except Exception as e:
137
  raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
138
 
139
-
140
- # -------------------------
141
- # Favicon handler (fix 404)
142
- # -------------------------
143
  @app.get("/favicon.ico")
144
  async def favicon():
145
  return HTMLResponse("")
146
 
147
-
148
- # -------------------------
149
- # Test Page
150
- # -------------------------
151
  @app.get("/", response_class=HTMLResponse)
152
  async def index():
153
  return """
154
- <h2>Background Remover API Live</h2>
155
- <p>POST endpoint: <code>/remove-background</code></p>
156
- <p>GET example:</p>
157
- <pre>/remove-background?get_url=https://example.com/img.jpg&get_res=512</pre>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  """
159
 
160
-
161
-
162
- # -------------------------
163
  # Run App
164
- # -------------------------
165
  if __name__ == "__main__":
166
- uvicorn.run(app, host="0.0.0.0", port=7860)
167
-
 
1
  import os
2
  from io import BytesIO
 
3
 
4
  import numpy as np
5
  import requests
 
10
  from PIL import Image
11
  from transformers import AutoModelForImageSegmentation
12
 
13
+ # ---------------------------------------------------------
14
+ # Optional HEIC Support
15
+ # ---------------------------------------------------------
16
  try:
17
  import pillow_heif
18
  pillow_heif.register_heif_opener()
19
  print("HEIC/HEIF supported")
20
+ except:
21
  print("Install pillow-heif for HEIC support")
22
 
23
+ # ---------------------------------------------------------
24
+ # Load Model
25
+ # ---------------------------------------------------------
26
  MODEL_DIR = "models/BiRefNet"
27
  os.makedirs(MODEL_DIR, exist_ok=True)
28
 
 
37
  revision="main"
38
  )
39
  birefnet.to(device, dtype=dtype).eval()
40
+ print("Model Ready")
41
 
42
+ # ---------------------------------------------------------
43
  # FastAPI App
44
+ # ---------------------------------------------------------
45
  app = FastAPI(title="Background Remover API")
46
 
47
+ # Allow requests from any app
48
  app.add_middleware(
49
  CORSMiddleware,
50
  allow_origins=["*"],
 
51
  allow_methods=["*"],
52
  allow_headers=["*"],
53
  )
54
 
55
+ # ---------------------------------------------------------
56
  # Utility Functions
57
+ # ---------------------------------------------------------
58
  def load_image_from_url(url: str):
59
  try:
60
  resp = requests.get(url, timeout=10)
61
  resp.raise_for_status()
62
  return Image.open(BytesIO(resp.content)).convert("RGB")
63
  except Exception as e:
64
+ raise HTTPException(status_code=400, detail=f"Error loading image URL: {str(e)}")
65
 
66
  def transform_image(image: Image.Image, resolution: int):
67
  image = image.resize((resolution, resolution))
68
+ img = np.array(image).astype("float32") / 255.0
69
+ mean = np.array([0.485, 0.456, 0.406], dtype="float32")
70
+ std = np.array([0.229, 0.224, 0.225], dtype="float32")
71
+ img = (img - mean) / std
72
+ img = img.transpose((2, 0, 1))
73
+ return torch.from_numpy(img).unsqueeze(0).to(device=device, dtype=dtype)
74
 
75
  def process_image(image: Image.Image, resolution: int):
76
+ original_size = image.size
77
+ tensor = transform_image(image, resolution)
78
+
79
  with torch.no_grad():
80
+ mask_pred = birefnet(tensor)[-1].sigmoid().cpu()[0, 0]
81
+
82
+ mask = Image.fromarray((mask_pred.numpy() * 255).astype("uint8"))
83
+ mask = mask.resize(original_size)
84
 
 
85
  image = image.convert("RGBA")
86
  image.putalpha(mask)
87
  return image
88
 
89
+ # ---------------------------------------------------------
90
+ # GET + POST endpoint
91
+ # ---------------------------------------------------------
 
92
  @app.api_route("/remove-background", methods=["GET", "POST"])
93
  async def remove_background(
94
  file: UploadFile = File(None),
95
  image_url: str = Form(None),
96
  resolution: int = Form(512),
97
+ get_url: str = Query(None),
98
+ get_res: int = Query(512),
99
  ):
 
 
 
 
 
 
 
100
  try:
101
+ # GET mode: /remove-background?get_url=...&get_res=512
102
  if get_url:
103
+ img = load_image_from_url(get_url)
104
  resolution = get_res
105
 
106
+ # POST mode - file upload
107
  elif file:
108
+ data = await file.read()
109
+ if not data:
110
+ raise HTTPException(status_code=400, detail="Empty file")
111
+ img = Image.open(BytesIO(data)).convert("RGB")
112
 
113
+ # POST mode - URL
114
  elif image_url:
115
+ img = load_image_from_url(image_url)
116
 
117
  else:
118
+ raise HTTPException(status_code=400, detail="No image provided")
119
 
120
+ result = process_image(img, resolution)
121
+
122
+ buffer = BytesIO()
123
+ result.save(buffer, format="PNG")
124
+ buffer.seek(0)
125
 
126
  return StreamingResponse(
127
+ buffer,
128
  media_type="image/png",
129
+ headers={"Content-Disposition": "inline; filename=result.png"},
130
  )
131
 
132
+ except HTTPException:
133
+ raise
134
  except Exception as e:
135
  raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
136
 
137
+ # ---------------------------------------------------------
138
+ # Favicon (stop 404 logs)
139
+ # ---------------------------------------------------------
 
140
  @app.get("/favicon.ico")
141
  async def favicon():
142
  return HTMLResponse("")
143
 
144
+ # ---------------------------------------------------------
145
+ # UI for testing (POST method)
146
+ # ---------------------------------------------------------
 
147
  @app.get("/", response_class=HTMLResponse)
148
  async def index():
149
  return """
150
+ <!DOCTYPE html>
151
+ <html>
152
+ <head>
153
+ <title>Background Remover API Tester</title>
154
+ <style>
155
+ body { font-family: Arial; padding: 20px; }
156
+ img { max-width: 100%; margin-top: 20px; border-radius: 10px; }
157
+ </style>
158
+ </head>
159
+ <body>
160
+ <h2>Background Remover API Test (POST)</h2>
161
+
162
+ <h4>Upload Image</h4>
163
+ <form id="uploadForm" enctype="multipart/form-data">
164
+ <input type="file" id="file" name="file"><br><br>
165
+ <label>Resolution:</label>
166
+ <input type="number" id="resFile" value="512"><br><br>
167
+ <button type="submit">Remove Background</button>
168
+ </form>
169
+
170
+ <h4>Or use Image URL</h4>
171
+ <form id="urlForm">
172
+ <input type="text" id="imgUrl" placeholder="https://example.com/image.jpg" size="50"><br><br>
173
+ <label>Resolution:</label>
174
+ <input type="number" id="resUrl" value="512"><br><br>
175
+ <button type="submit">Remove Background</button>
176
+ </form>
177
+
178
+ <h3>Result:</h3>
179
+ <img id="result" />
180
+
181
+ <script>
182
+ const resultImg = document.getElementById("result");
183
+
184
+ document.getElementById("uploadForm").onsubmit = async (e) => {
185
+ e.preventDefault();
186
+ const file = document.getElementById("file").files[0];
187
+ const res = document.getElementById("resFile").value;
188
+
189
+ const fd = new FormData();
190
+ fd.append("file", file);
191
+ fd.append("resolution", res);
192
+
193
+ const r = await fetch("/remove-background", { method: "POST", body: fd });
194
+ resultImg.src = URL.createObjectURL(await r.blob());
195
+ };
196
+
197
+ document.getElementById("urlForm").onsubmit = async (e) => {
198
+ e.preventDefault();
199
+ const url = document.getElementById("imgUrl").value;
200
+ const res = document.getElementById("resUrl").value;
201
+
202
+ const fd = new FormData();
203
+ fd.append("image_url", url);
204
+ fd.append("resolution", res);
205
+
206
+ const r = await fetch("/remove-background", { method: "POST", body: fd });
207
+ resultImg.src = URL.createObjectURL(await r.blob());
208
+ };
209
+ </script>
210
+ </body>
211
+ </html>
212
  """
213
 
214
+ # ---------------------------------------------------------
 
 
215
  # Run App
216
+ # ---------------------------------------------------------
217
  if __name__ == "__main__":
218
+ import uvicorn
219
+ uvicorn.run("app:app", host="0.0.0.0", port=7860)