mobisoft commited on
Commit
72cef1f
Β·
verified Β·
1 Parent(s): 562254a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -86
app.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import torch
3
  import numpy as np
4
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
5
- from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse
6
  from PIL import Image
7
  from io import BytesIO
8
  import requests
@@ -10,27 +10,18 @@ from transformers import AutoModelForImageSegmentation
10
  import uvicorn
11
 
12
  # ---------------------------------------------------------
13
- # Optional HEIC support
14
- # ---------------------------------------------------------
15
- try:
16
- import pillow_heif
17
- pillow_heif.register_heif_opener()
18
- except ImportError:
19
- pass
20
-
21
- # ---------------------------------------------------------
22
- # Performance tuning (CPU)
23
  # ---------------------------------------------------------
24
  os.environ["OMP_NUM_THREADS"] = "1"
25
  os.environ["MKL_NUM_THREADS"] = "1"
26
  torch.set_num_threads(1)
27
 
28
  # ---------------------------------------------------------
29
- # Constants (optimized)
30
  # ---------------------------------------------------------
31
- TARGET_SIZE = (384, 384) # πŸ”₯ faster than 512
32
- MAX_SIDE = 3000
33
  MAX_FILE_SIZE = 5 * 1024 * 1024 # 5MB
 
34
 
35
  # ---------------------------------------------------------
36
  # Load model
@@ -41,76 +32,53 @@ os.makedirs(MODEL_DIR, exist_ok=True)
41
  device = "cuda" if torch.cuda.is_available() else "cpu"
42
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
43
 
44
- print("Loading BiRefNet...")
45
 
46
- birefnet = AutoModelForImageSegmentation.from_pretrained(
47
  "ZhengPeng7/BiRefNet",
48
  cache_dir=MODEL_DIR,
49
- trust_remote_code=True,
50
- revision="main",
51
  )
52
 
53
- birefnet.to(device, dtype=dtype).eval()
54
- print("Model ready.")
 
55
 
56
  # ---------------------------------------------------------
57
- # Helpers
58
  # ---------------------------------------------------------
59
- def load_image_from_url(url: str) -> Image.Image:
60
- try:
61
- r = requests.get(url, timeout=10)
62
- r.raise_for_status()
63
- return Image.open(BytesIO(r.content)).convert("RGB")
64
- except Exception:
65
- raise HTTPException(status_code=400, detail="Invalid image URL")
66
-
67
 
68
- def compress_if_needed(img: Image.Image, raw_bytes: bytes) -> Image.Image:
69
- size = len(raw_bytes)
70
 
71
- if size <= MAX_FILE_SIZE:
 
 
72
  return img
73
 
74
- print(f"[INFO] Compressing {size/1024/1024:.2f}MB image...")
75
 
76
  img = img.convert("RGB")
77
 
 
78
  w, h = img.size
79
- scale = min(1.0, 1600 / max(w, h)) # πŸ”₯ more aggressive
80
  img = img.resize((int(w * scale), int(h * scale)), Image.BILINEAR)
81
 
82
- quality = 85
83
-
84
- while True:
85
- buffer = BytesIO()
86
- img.save(buffer, format="JPEG", quality=quality, optimize=True)
87
- compressed_size = buffer.tell()
88
-
89
- if compressed_size <= MAX_FILE_SIZE or quality <= 40:
90
- print(f"[INFO] Final size: {compressed_size/1024/1024:.2f}MB")
91
- buffer.seek(0)
92
- return Image.open(buffer).convert("RGB")
93
-
94
- quality -= 10
95
-
96
-
97
- def auto_downscale(img: Image.Image) -> Image.Image:
98
- w, h = img.size
99
-
100
- if max(w, h) <= MAX_SIDE:
101
- return img
102
 
103
- scale = MAX_SIDE / max(w, h)
104
- new_size = (int(w * scale), int(h * scale))
105
 
106
- print(f"[INFO] Downscaling {w} β†’ {new_size}")
107
- return img.resize(new_size, Image.BILINEAR)
108
 
109
-
110
- def transform(img: Image.Image) -> torch.Tensor:
111
  img = img.resize(TARGET_SIZE, Image.BILINEAR)
112
 
113
- arr = np.array(img).astype(np.float32) / 255.0
114
 
115
  mean = np.array([0.485, 0.456, 0.406])
116
  std = np.array([0.229, 0.224, 0.225])
@@ -121,12 +89,13 @@ def transform(img: Image.Image) -> torch.Tensor:
121
  return torch.from_numpy(arr).unsqueeze(0).to(device=device, dtype=dtype)
122
 
123
 
124
- def run_inference(img: Image.Image) -> Image.Image:
 
125
  orig_size = img.size
126
  tensor = transform(img)
127
 
128
- with torch.inference_mode(): # πŸ”₯ faster
129
- pred = birefnet(tensor)
130
  pred = pred[-1] if isinstance(pred, (list, tuple)) else pred
131
  pred = pred.sigmoid()[0, 0].cpu()
132
 
@@ -139,39 +108,28 @@ def run_inference(img: Image.Image) -> Image.Image:
139
 
140
 
141
  # ---------------------------------------------------------
142
- # FastAPI app
143
  # ---------------------------------------------------------
144
- app = FastAPI(title="Fast Background Remover API")
145
 
146
- # ---------------------------------------------------------
147
- # Redirect GET
148
- # ---------------------------------------------------------
149
- @app.get("/remove-background")
150
- async def redirect_to_post():
151
- return JSONResponse(
152
- {"detail": "Use POST /remove-background"},
153
- status_code=405
154
- )
155
-
156
- # ---------------------------------------------------------
157
- # Main endpoint
158
- # ---------------------------------------------------------
159
  @app.post("/remove-background")
160
  async def remove_bg(file: UploadFile = File(None), image_url: str = Form(None)):
161
  try:
162
  if file:
163
  raw = await file.read()
164
  img = Image.open(BytesIO(raw)).convert("RGB")
 
 
165
  img = compress_if_needed(img, raw)
166
 
167
  elif image_url:
168
  img = load_image_from_url(image_url)
169
 
170
  else:
171
- raise HTTPException(status_code=400, detail="Provide file or image_url")
172
 
173
- img = auto_downscale(img)
174
- result = run_inference(img)
175
 
176
  buf = BytesIO()
177
  result.save(buf, format="PNG")
@@ -180,16 +138,16 @@ async def remove_bg(file: UploadFile = File(None), image_url: str = Form(None)):
180
  return StreamingResponse(buf, media_type="image/png")
181
 
182
  except Exception as e:
183
- raise HTTPException(status_code=500, detail=str(e))
184
 
185
 
186
  # ---------------------------------------------------------
187
- # UI
188
  # ---------------------------------------------------------
189
  @app.get("/", response_class=HTMLResponse)
190
- async def ui():
191
  return """
192
- <html>
193
  <head>
194
  <title>Fast Background Remover</title>
195
  <link rel='stylesheet'
@@ -265,6 +223,7 @@ async def ui():
265
  """
266
 
267
 
 
268
  # ---------------------------------------------------------
269
  # Run
270
  # ---------------------------------------------------------
 
2
  import torch
3
  import numpy as np
4
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
5
+ from fastapi.responses import StreamingResponse, HTMLResponse
6
  from PIL import Image
7
  from io import BytesIO
8
  import requests
 
10
  import uvicorn
11
 
12
  # ---------------------------------------------------------
13
+ # CPU optimization (important for HF Spaces)
 
 
 
 
 
 
 
 
 
14
  # ---------------------------------------------------------
15
  os.environ["OMP_NUM_THREADS"] = "1"
16
  os.environ["MKL_NUM_THREADS"] = "1"
17
  torch.set_num_threads(1)
18
 
19
  # ---------------------------------------------------------
20
+ # Config (speed focused)
21
  # ---------------------------------------------------------
22
+ TARGET_SIZE = (320, 320) # πŸ”₯ faster inference
 
23
  MAX_FILE_SIZE = 5 * 1024 * 1024 # 5MB
24
+ MAX_COMPRESS_DIM = 1400 # aggressive resize
25
 
26
  # ---------------------------------------------------------
27
  # Load model
 
32
  device = "cuda" if torch.cuda.is_available() else "cpu"
33
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
34
 
35
+ print("Loading model...")
36
 
37
+ model = AutoModelForImageSegmentation.from_pretrained(
38
  "ZhengPeng7/BiRefNet",
39
  cache_dir=MODEL_DIR,
40
+ trust_remote_code=True
 
41
  )
42
 
43
+ model.to(device, dtype=dtype).eval()
44
+
45
+ print("Model ready")
46
 
47
  # ---------------------------------------------------------
48
+ # Image helpers
49
  # ---------------------------------------------------------
50
+ def load_image_from_url(url: str):
51
+ r = requests.get(url, timeout=10)
52
+ r.raise_for_status()
53
+ return Image.open(BytesIO(r.content)).convert("RGB")
 
 
 
 
54
 
 
 
55
 
56
+ # πŸ”₯ FAST compression (key part)
57
+ def compress_if_needed(img: Image.Image, raw_bytes: bytes):
58
+ if len(raw_bytes) <= MAX_FILE_SIZE:
59
  return img
60
 
61
+ print("[INFO] Compressing image >5MB")
62
 
63
  img = img.convert("RGB")
64
 
65
+ # Resize aggressively
66
  w, h = img.size
67
+ scale = min(1.0, MAX_COMPRESS_DIM / max(w, h))
68
  img = img.resize((int(w * scale), int(h * scale)), Image.BILINEAR)
69
 
70
+ # Reduce quality quickly (no loop β†’ faster)
71
+ buffer = BytesIO()
72
+ img.save(buffer, format="JPEG", quality=70, optimize=True)
73
+ buffer.seek(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
+ return Image.open(buffer).convert("RGB")
 
76
 
 
 
77
 
78
+ def transform(img):
 
79
  img = img.resize(TARGET_SIZE, Image.BILINEAR)
80
 
81
+ arr = np.asarray(img, dtype=np.float32) / 255.0
82
 
83
  mean = np.array([0.485, 0.456, 0.406])
84
  std = np.array([0.229, 0.224, 0.225])
 
89
  return torch.from_numpy(arr).unsqueeze(0).to(device=device, dtype=dtype)
90
 
91
 
92
+ # πŸ”₯ FAST inference
93
+ def remove_background(img: Image.Image):
94
  orig_size = img.size
95
  tensor = transform(img)
96
 
97
+ with torch.inference_mode():
98
+ pred = model(tensor)
99
  pred = pred[-1] if isinstance(pred, (list, tuple)) else pred
100
  pred = pred.sigmoid()[0, 0].cpu()
101
 
 
108
 
109
 
110
  # ---------------------------------------------------------
111
+ # FastAPI
112
  # ---------------------------------------------------------
113
+ app = FastAPI()
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  @app.post("/remove-background")
116
  async def remove_bg(file: UploadFile = File(None), image_url: str = Form(None)):
117
  try:
118
  if file:
119
  raw = await file.read()
120
  img = Image.open(BytesIO(raw)).convert("RGB")
121
+
122
+ # βœ… Step 1: compress if >5MB
123
  img = compress_if_needed(img, raw)
124
 
125
  elif image_url:
126
  img = load_image_from_url(image_url)
127
 
128
  else:
129
+ raise HTTPException(400, "Provide file or URL")
130
 
131
+ # βœ… Step 2: remove background
132
+ result = remove_background(img)
133
 
134
  buf = BytesIO()
135
  result.save(buf, format="PNG")
 
138
  return StreamingResponse(buf, media_type="image/png")
139
 
140
  except Exception as e:
141
+ raise HTTPException(500, str(e))
142
 
143
 
144
  # ---------------------------------------------------------
145
+ # Simple UI
146
  # ---------------------------------------------------------
147
  @app.get("/", response_class=HTMLResponse)
148
+ async def home():
149
  return """
150
+ <html>
151
  <head>
152
  <title>Fast Background Remover</title>
153
  <link rel='stylesheet'
 
223
  """
224
 
225
 
226
+
227
  # ---------------------------------------------------------
228
  # Run
229
  # ---------------------------------------------------------