mobisoft commited on
Commit
49f1016
·
verified ·
1 Parent(s): e876aaa

Upload 3 files

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. app.py +239 -0
  3. butterfly.jpg +3 -0
  4. requirements.txt +12 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ butterfly.jpg filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import threading
3
+ import torch
4
+ import numpy as np
5
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
6
+ from fastapi.responses import StreamingResponse, HTMLResponse, RedirectResponse, JSONResponse
7
+ from PIL import Image
8
+ from io import BytesIO
9
+ import requests
10
+ from transformers import AutoModelForImageSegmentation
11
+ import uvicorn
12
+
13
+ # ---------------------------------------------------------
14
+ # Optional HEIC/HEIF
15
+ # ---------------------------------------------------------
16
+ try:
17
+ import pillow_heif
18
+ pillow_heif.register_heif_opener()
19
+ except ImportError:
20
+ pass
21
+
22
+ # ---------------------------------------------------------
23
+ # Performance settings for HF CPU
24
+ # ---------------------------------------------------------
25
+ os.environ["OMP_NUM_THREADS"] = "1"
26
+ os.environ["MKL_NUM_THREADS"] = "1"
27
+ torch.set_num_threads(1)
28
+
29
+ # ---------------------------------------------------------
30
+ # Constants
31
+ # ---------------------------------------------------------
32
+ TARGET_SIZE = (512, 512) # Faster inference
33
+ MAX_SIDE = 3000 # Auto-downscale for huge uploads
34
+
35
+ # ---------------------------------------------------------
36
+ # Load model
37
+ # ---------------------------------------------------------
38
+ MODEL_DIR = "models/BiRefNet"
39
+ os.makedirs(MODEL_DIR, exist_ok=True)
40
+
41
+ device = "cuda" if torch.cuda.is_available() else "cpu"
42
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
43
+
44
+ print("Loading BiRefNet…")
45
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
46
+ "ZhengPeng7/BiRefNet",
47
+ cache_dir=MODEL_DIR,
48
+ trust_remote_code=True,
49
+ revision="main",
50
+ )
51
+ birefnet.to(device, dtype=dtype).eval()
52
+ print("Model ready.")
53
+
54
+ lock = threading.Lock()
55
+
56
+ # ---------------------------------------------------------
57
+ # Helpers
58
+ # ---------------------------------------------------------
59
+ def load_image_from_url(url: str) -> Image.Image:
60
+ try:
61
+ r = requests.get(url, timeout=10)
62
+ r.raise_for_status()
63
+ return Image.open(BytesIO(r.content)).convert("RGB")
64
+ except Exception:
65
+ raise HTTPException(status_code=400, detail="Invalid image URL")
66
+
67
+
68
+ def auto_downscale(img: Image.Image) -> Image.Image:
69
+ w, h = img.size
70
+ if max(w, h) <= MAX_SIDE:
71
+ return img
72
+
73
+ scale = MAX_SIDE / max(w, h)
74
+ new_w = int(w * scale)
75
+ new_h = int(h * scale)
76
+
77
+ print(f"[INFO] Downscaling {w}×{h} → {new_w}×{new_h}")
78
+ return img.resize((new_w, new_h), Image.LANCZOS)
79
+
80
+
81
+ def transform(img: Image.Image) -> torch.Tensor:
82
+ img = img.resize(TARGET_SIZE)
83
+
84
+ arr = np.array(img).astype(np.float32) / 255.0
85
+ mean = np.array([0.485, 0.456, 0.406])
86
+ std = np.array([0.229, 0.224, 0.225])
87
+ arr = (arr - mean) / std
88
+ arr = np.transpose(arr, (2, 0, 1))
89
+
90
+ t = torch.from_numpy(arr).unsqueeze(0).to(device=device, dtype=dtype)
91
+ return t
92
+
93
+
94
+ def run_inference(img: Image.Image) -> Image.Image:
95
+ orig_size = img.size
96
+ tensor = transform(img)
97
+
98
+ with lock:
99
+ with torch.no_grad():
100
+ pred = birefnet(tensor)[-1].sigmoid().cpu()[0, 0]
101
+
102
+ mask = Image.fromarray((pred.numpy() * 255).astype(np.uint8)).resize(orig_size)
103
+
104
+ img = img.convert("RGBA")
105
+ img.putalpha(mask)
106
+ return img
107
+
108
+
109
+ # ---------------------------------------------------------
110
+ # FastAPI app
111
+ # ---------------------------------------------------------
112
+ app = FastAPI(title="Background Remover API")
113
+
114
+ # ---------------------------------------------------------
115
+ # Redirect GET → POST logic
116
+ # ---------------------------------------------------------
117
+ @app.get("/remove-background")
118
+ async def redirect_to_post():
119
+ return JSONResponse(
120
+ {"detail": "This endpoint only supports POST. Use POST /remove-background"},
121
+ status_code=405
122
+ )
123
+
124
+ # ---------------------------------------------------------
125
+ # Main POST endpoint
126
+ # ---------------------------------------------------------
127
+ @app.post("/remove-background")
128
+ async def remove_bg(file: UploadFile = File(None), image_url: str = Form(None)):
129
+ try:
130
+ if file:
131
+ raw = await file.read()
132
+ img = Image.open(BytesIO(raw)).convert("RGB")
133
+ elif image_url:
134
+ img = load_image_from_url(image_url)
135
+ else:
136
+ raise HTTPException(status_code=400, detail="Upload file or image_url required")
137
+
138
+ img = auto_downscale(img)
139
+ result = run_inference(img)
140
+
141
+ buf = BytesIO()
142
+ result.save(buf, format="PNG")
143
+ buf.seek(0)
144
+
145
+ return StreamingResponse(buf, media_type="image/png")
146
+
147
+ except Exception as e:
148
+ raise HTTPException(status_code=500, detail=str(e))
149
+
150
+
151
+ # ---------------------------------------------------------
152
+ # UI: Show INPUT + OUTPUT (big preview)
153
+ # ---------------------------------------------------------
154
+ @app.get("/", response_class=HTMLResponse)
155
+ async def ui():
156
+ return """
157
+ <html>
158
+ <head>
159
+ <title>Background Remover – Test UI</title>
160
+ <link rel='stylesheet'
161
+ href='https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css'>
162
+ </head>
163
+ <body class='bg-light'>
164
+ <div class='container py-4 text-center'>
165
+
166
+ <h2 class='mb-4'>API Test Panel (POST Only)</h2>
167
+
168
+ <div class='row'>
169
+ <div class='col-md-6'>
170
+ <h5>Input Image</h5>
171
+ <img id='inputImg' style='max-width:100%; border-radius:10px;'>
172
+ </div>
173
+ <div class='col-md-6'>
174
+ <h5>Output Image</h5>
175
+ <img id='outputImg' style='max-width:100%; border-radius:10px;'>
176
+ </div>
177
+ </div>
178
+
179
+ <hr>
180
+
181
+ <h4>Upload Test</h4>
182
+ <form id="uploadForm" enctype='multipart/form-data'>
183
+ <input type='file' id='fileInput' class='form-control mb-3'>
184
+ <button class='btn btn-primary'>Send POST</button>
185
+ </form>
186
+
187
+ <hr>
188
+
189
+ <h4>URL Test</h4>
190
+ <form id='urlForm'>
191
+ <input id='urlInput' class='form-control mb-3' placeholder='https://example.com/image.jpg'>
192
+ <button class='btn btn-success'>Send POST</button>
193
+ </form>
194
+ </div>
195
+
196
+ <script>
197
+ const inputImg = document.getElementById("inputImg");
198
+ const outputImg = document.getElementById("outputImg");
199
+
200
+ // FILE TEST
201
+ document.getElementById("uploadForm").addEventListener("submit", async e => {
202
+ e.preventDefault();
203
+ const file = document.getElementById("fileInput").files[0];
204
+ if (!file) return alert("Select a file first.");
205
+
206
+ inputImg.src = URL.createObjectURL(file);
207
+
208
+ const fd = new FormData();
209
+ fd.append("file", file);
210
+
211
+ const r = await fetch("/remove-background", { method:"POST", body:fd });
212
+ outputImg.src = URL.createObjectURL(await r.blob());
213
+ });
214
+
215
+ // URL TEST
216
+ document.getElementById("urlForm").addEventListener("submit", async e => {
217
+ e.preventDefault();
218
+ const url = document.getElementById("urlInput").value.trim();
219
+ if (!url) return alert("Enter an image URL first.");
220
+
221
+ inputImg.src = url;
222
+
223
+ const fd = new FormData();
224
+ fd.append("image_url", url);
225
+
226
+ const r = await fetch("/remove-background", { method:"POST", body:fd });
227
+ outputImg.src = URL.createObjectURL(await r.blob());
228
+ });
229
+ </script>
230
+
231
+ </body>
232
+ </html>
233
+ """
234
+
235
+ # ---------------------------------------------------------
236
+ # Run app
237
+ # ---------------------------------------------------------
238
+ if __name__ == "__main__":
239
+ uvicorn.run(app, host="0.0.0.0", port=7860)
butterfly.jpg ADDED

Git LFS Details

  • SHA256: a90552572374e49e2f198a8d7a11eeee6e733013fe884f2dda268670e6c788e7
  • Pointer size: 131 Bytes
  • Size of remote file: 196 kB
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.1.0
2
+ transformers>=4.39.1
3
+ pillow>=10.0.0
4
+ pillow-heif>=0.15.0
5
+ numpy>=1.25.0
6
+ uvicorn>=0.23.0
7
+ fastapi>=0.102.0
8
+ loadimg>=0.1.1
9
+ timm>=0.9.2
10
+ kornia>=0.7.0
11
+ einops>=0.6.1
12
+ requests>=2.31.0