videopix commited on
Commit
80cc4e7
·
verified ·
1 Parent(s): e62a63f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -115
app.py CHANGED
@@ -1,13 +1,14 @@
1
  import os
 
 
 
 
 
2
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Query
3
  from fastapi.responses import StreamingResponse, HTMLResponse
 
4
  from PIL import Image
5
- import torch
6
- import numpy as np
7
  from transformers import AutoModelForImageSegmentation
8
- from io import BytesIO
9
- import requests
10
- import uvicorn
11
 
12
  # -------------------------
13
  # Optional HEIC/HEIF Support
@@ -15,19 +16,20 @@ import uvicorn
15
  try:
16
  import pillow_heif
17
  pillow_heif.register_heif_opener()
18
- print("HEIC/HEIF format supported.")
19
  except ImportError:
20
- print("⚠️ Install pillow-heif for HEIC support: pip install pillow-heif")
21
 
22
  # -------------------------
23
  # Model Setup
24
  # -------------------------
25
  MODEL_DIR = "models/BiRefNet"
26
  os.makedirs(MODEL_DIR, exist_ok=True)
 
27
  device = "cuda" if torch.cuda.is_available() else "cpu"
28
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
29
 
30
- print("Loading BiRefNet model...")
31
  birefnet = AutoModelForImageSegmentation.from_pretrained(
32
  "ZhengPeng7/BiRefNet",
33
  cache_dir=MODEL_DIR,
@@ -42,154 +44,118 @@ print("Model loaded successfully.")
42
  # -------------------------
43
  app = FastAPI(title="Background Remover API")
44
 
 
 
 
 
 
 
 
 
 
45
  # -------------------------
46
  # Utility Functions
47
  # -------------------------
48
- def load_image_from_url(url: str) -> Image.Image:
49
  try:
50
- response = requests.get(url, timeout=10)
51
- response.raise_for_status()
52
- return Image.open(BytesIO(response.content)).convert("RGB")
53
  except Exception as e:
54
- raise HTTPException(status_code=400, detail=f"Error loading image from URL: {str(e)}")
55
 
56
- def transform_image(image: Image.Image, resolution: int = 512) -> torch.Tensor:
57
  image = image.resize((resolution, resolution))
58
  arr = np.array(image).astype(np.float32) / 255.0
59
  mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
60
  std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
61
  arr = (arr - mean) / std
62
- arr = np.transpose(arr, (2, 0, 1)) # HWC -> CHW
63
- tensor = torch.from_numpy(arr).unsqueeze(0).to(dtype).to(device)
64
- return tensor
65
 
66
- def process_image(image: Image.Image, resolution: int = 512) -> Image.Image:
67
- orig_size = image.size
68
- input_tensor = transform_image(image, resolution)
69
  with torch.no_grad():
70
- preds = birefnet(input_tensor)[-1].sigmoid().cpu()
71
- pred = preds[0, 0]
72
- mask = Image.fromarray((pred.numpy() * 255).astype(np.uint8)).resize(orig_size)
73
  image = image.convert("RGBA")
74
  image.putalpha(mask)
75
  return image
76
 
 
77
  # -------------------------
78
- # /remove-background Endpoint
79
  # -------------------------
80
- @app.post("/remove-background")
81
  async def remove_background(
82
  file: UploadFile = File(None),
83
  image_url: str = Form(None),
84
- resolution: int = Form(512)
 
 
85
  ):
86
  """
87
- Remove background from an image.
88
- Accepts a file upload or image URL.
89
- Optional resolution (default 512) for faster inference.
90
- Returns PNG with transparent background.
91
  """
 
92
  try:
93
- if file:
94
- image = Image.open(BytesIO(await file.read())).convert("RGB")
 
 
 
 
 
 
 
 
 
95
  elif image_url:
96
  image = load_image_from_url(image_url)
 
97
  else:
98
- raise HTTPException(status_code=400, detail="Provide either 'file' or 'image_url'.")
99
 
100
  result = process_image(image, resolution)
101
  buf = BytesIO()
102
  result.save(buf, format="PNG")
103
  buf.seek(0)
104
- return StreamingResponse(buf, media_type="image/png")
 
 
 
 
 
 
105
  except Exception as e:
106
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
107
 
108
  # -------------------------
109
- # Developer Test Page (Bootstrap)
110
  # -------------------------
111
  @app.get("/", response_class=HTMLResponse)
112
  async def index():
113
- html = """
114
- <!DOCTYPE html>
115
- <html lang="en">
116
- <head>
117
- <meta charset="UTF-8">
118
- <meta name="viewport" content="width=device-width, initial-scale=1">
119
- <title>Background Remover API Test</title>
120
- <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet">
121
- <style>
122
- body { background-color: #f8f9fa; padding-top: 40px; }
123
- .container { max-width: 700px; }
124
- img { max-width: 100%; margin-top: 20px; border-radius: 10px; }
125
- </style>
126
- </head>
127
- <body>
128
- <div class="container text-center">
129
- <h2 class="mb-4">Background Remover API Tester</h2>
130
- <form id="uploadForm" class="mb-4" enctype="multipart/form-data">
131
- <div class="mb-3">
132
- <label for="fileInput" class="form-label">Upload Image (any format, e.g. JPG, PNG, HEIC):</label>
133
- <input class="form-control" type="file" id="fileInput" name="file" accept="image/*">
134
- </div>
135
- <div class="mb-3">
136
- <label for="resInput" class="form-label">Resolution (default 512):</label>
137
- <input class="form-control" type="number" id="resInput" name="resolution" value="512" min="64" max="2048">
138
- </div>
139
- <button class="btn btn-primary" type="submit">Remove Background</button>
140
- </form>
141
- <div class="mb-4">OR</div>
142
- <form id="urlForm" class="mb-4">
143
- <div class="mb-3">
144
- <label for="urlInput" class="form-label">Enter Image URL:</label>
145
- <input class="form-control" type="text" id="urlInput" placeholder="https://example.com/image.jpg">
146
- </div>
147
- <div class="mb-3">
148
- <label for="urlResInput" class="form-label">Resolution (default 512):</label>
149
- <input class="form-control" type="number" id="urlResInput" name="resolution" value="512" min="64" max="2048">
150
- </div>
151
- <button class="btn btn-success" type="submit">Remove Background</button>
152
- </form>
153
- <div id="resultContainer" class="mt-4">
154
- <h5>Result:</h5>
155
- <img id="resultImg" src="" alt="">
156
- </div>
157
- </div>
158
- <script>
159
- const uploadForm = document.getElementById("uploadForm");
160
- const urlForm = document.getElementById("urlForm");
161
- const resultImg = document.getElementById("resultImg");
162
-
163
- uploadForm.addEventListener("submit", async e => {
164
- e.preventDefault();
165
- const fileInput = document.getElementById("fileInput");
166
- const res = document.getElementById("resInput").value || 512;
167
- if (!fileInput.files.length) return alert("Please select a file!");
168
- const formData = new FormData();
169
- formData.append("file", fileInput.files[0]);
170
- formData.append("resolution", res);
171
- const response = await fetch("/remove-background", { method: "POST", body: formData });
172
- const blob = await response.blob();
173
- resultImg.src = URL.createObjectURL(blob);
174
- });
175
-
176
- urlForm.addEventListener("submit", async e => {
177
- e.preventDefault();
178
- const url = document.getElementById("urlInput").value.trim();
179
- const res = document.getElementById("urlResInput").value || 512;
180
- if (!url) return alert("Please enter an image URL!");
181
- const formData = new FormData();
182
- formData.append("image_url", url);
183
- formData.append("resolution", res);
184
- const response = await fetch("/remove-background", { method: "POST", body: formData });
185
- const blob = await response.blob();
186
- resultImg.src = URL.createObjectURL(blob);
187
- });
188
- </script>
189
- </body>
190
- </html>
191
  """
192
- return HTMLResponse(html)
193
 
194
  # -------------------------
195
  # Run App
 
1
  import os
2
+ from io import BytesIO
3
+
4
+ import numpy as np
5
+ import requests
6
+ import torch
7
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Query
8
  from fastapi.responses import StreamingResponse, HTMLResponse
9
+ from fastapi.middleware.cors import CORSMiddleware
10
  from PIL import Image
 
 
11
  from transformers import AutoModelForImageSegmentation
 
 
 
12
 
13
  # -------------------------
14
  # Optional HEIC/HEIF Support
 
16
  try:
17
  import pillow_heif
18
  pillow_heif.register_heif_opener()
19
+ print("HEIC/HEIF supported")
20
  except ImportError:
21
+ print("Install pillow-heif for HEIC support")
22
 
23
  # -------------------------
24
  # Model Setup
25
  # -------------------------
26
  MODEL_DIR = "models/BiRefNet"
27
  os.makedirs(MODEL_DIR, exist_ok=True)
28
+
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
31
 
32
+ print("Loading BiRefNet...")
33
  birefnet = AutoModelForImageSegmentation.from_pretrained(
34
  "ZhengPeng7/BiRefNet",
35
  cache_dir=MODEL_DIR,
 
44
  # -------------------------
45
  app = FastAPI(title="Background Remover API")
46
 
47
+ # Allow API calls from mobile apps, web apps, backend servers
48
+ app.add_middleware(
49
+ CORSMiddleware,
50
+ allow_origins=["*"],
51
+ allow_credentials=True,
52
+ allow_methods=["*"],
53
+ allow_headers=["*"],
54
+ )
55
+
56
  # -------------------------
57
  # Utility Functions
58
  # -------------------------
59
+ def load_image_from_url(url: str):
60
  try:
61
+ resp = requests.get(url, timeout=10)
62
+ resp.raise_for_status()
63
+ return Image.open(BytesIO(resp.content)).convert("RGB")
64
  except Exception as e:
65
+ raise HTTPException(status_code=400, detail=f"Invalid image URL: {str(e)}")
66
 
67
+ def transform_image(image: Image.Image, resolution: int):
68
  image = image.resize((resolution, resolution))
69
  arr = np.array(image).astype(np.float32) / 255.0
70
  mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
71
  std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
72
  arr = (arr - mean) / std
73
+ arr = arr.transpose((2, 0, 1))
74
+ return torch.from_numpy(arr).unsqueeze(0).to(device=device, dtype=dtype)
 
75
 
76
+ def process_image(image: Image.Image, resolution: int):
77
+ orig = image.size
78
+ t = transform_image(image, resolution)
79
  with torch.no_grad():
80
+ pred = birefnet(t)[-1].sigmoid().cpu()[0, 0]
81
+
82
+ mask = Image.fromarray((pred.numpy() * 255).astype("uint8")).resize(orig)
83
  image = image.convert("RGBA")
84
  image.putalpha(mask)
85
  return image
86
 
87
+
88
  # -------------------------
89
+ # GET + POST API SUPPORT
90
  # -------------------------
91
+ @app.api_route("/remove-background", methods=["GET", "POST"])
92
  async def remove_background(
93
  file: UploadFile = File(None),
94
  image_url: str = Form(None),
95
+ resolution: int = Form(512),
96
+ get_url: str = Query(None, description="Use for GET request: ?get_url=https://..."),
97
+ get_res: int = Query(512, description="Resolution for GET request"),
98
  ):
99
  """
100
+ Supports:
101
+ - POST file upload
102
+ - POST image_url
103
+ - GET request using: /remove-background?get_url=...&get_res=512
104
  """
105
+
106
  try:
107
+ # Determine GET or POST mode
108
+ if get_url:
109
+ image = load_image_from_url(get_url)
110
+ resolution = get_res
111
+
112
+ elif file:
113
+ content = await file.read()
114
+ if not content:
115
+ raise HTTPException(status_code=400, detail="Uploaded file is empty.")
116
+ image = Image.open(BytesIO(content)).convert("RGB")
117
+
118
  elif image_url:
119
  image = load_image_from_url(image_url)
120
+
121
  else:
122
+ raise HTTPException(status_code=400, detail="No image provided.")
123
 
124
  result = process_image(image, resolution)
125
  buf = BytesIO()
126
  result.save(buf, format="PNG")
127
  buf.seek(0)
128
+
129
+ return StreamingResponse(
130
+ buf,
131
+ media_type="image/png",
132
+ headers={"Content-Disposition": "inline; filename=result.png"}
133
+ )
134
+
135
  except Exception as e:
136
+ raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
137
+
138
+
139
+ # -------------------------
140
+ # Favicon handler (fix 404)
141
+ # -------------------------
142
+ @app.get("/favicon.ico")
143
+ async def favicon():
144
+ return HTMLResponse("")
145
+
146
 
147
  # -------------------------
148
+ # Test Page
149
  # -------------------------
150
  @app.get("/", response_class=HTMLResponse)
151
  async def index():
152
+ return """
153
+ <h2>Background Remover API Live</h2>
154
+ <p>POST endpoint: <code>/remove-background</code></p>
155
+ <p>GET example:</p>
156
+ <pre>/remove-background?get_url=https://example.com/img.jpg&get_res=512</pre>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  """
158
+
159
 
160
  # -------------------------
161
  # Run App