whitepeacock committed on
Commit
7437b84
·
verified ·
1 Parent(s): d13bf05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -26
app.py CHANGED
@@ -23,11 +23,13 @@ pillow_heif.register_heif_opener()
23
  executor = ThreadPoolExecutor(max_workers=os.cpu_count() or 4)
24
 
25
  # -------------------------
26
- # Model Setup (Loaded Once)
27
  # -------------------------
28
  MODEL_DIR = "models/BiRefNet"
29
  os.makedirs(MODEL_DIR, exist_ok=True)
 
30
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
31
 
32
  print("Loading BiRefNet model (first run may take a while)...")
33
  birefnet = AutoModelForImageSegmentation.from_pretrained(
@@ -37,13 +39,15 @@ birefnet = AutoModelForImageSegmentation.from_pretrained(
37
  )
38
  birefnet.to(device)
39
  birefnet.eval()
40
- print("Model loaded successfully.")
41
 
42
  # -------------------------
43
  # Image Preprocessing
44
  # -------------------------
 
 
45
  def transform_image(image: Image.Image) -> torch.Tensor:
46
- image = image.resize((1024, 1024))
47
  arr = np.array(image).astype(np.float32) / 255.0
48
  mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
49
  std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
@@ -53,16 +57,22 @@ def transform_image(image: Image.Image) -> torch.Tensor:
53
  return tensor
54
 
55
  def process_image_sync(image: Image.Image) -> BytesIO:
56
- """Process image synchronously and return PNG bytes (no files written)."""
57
  image_size = image.size
58
  input_tensor = transform_image(image)
 
59
  with torch.no_grad():
60
- preds = birefnet(input_tensor)[-1].sigmoid().cpu()
 
 
 
 
 
 
61
 
62
  pred = preds[0, 0].numpy()
63
  mask = Image.fromarray((pred * 255).astype(np.uint8)).resize(image_size)
64
 
65
- # Apply alpha mask and keep only in-memory result
66
  image = image.copy()
67
  image.putalpha(mask)
68
 
@@ -85,24 +95,24 @@ def open_image_safely(file_bytes: bytes) -> Image.Image:
85
  img = Image.open(BytesIO(file_bytes))
86
  fmt = (img.format or "").lower()
87
 
88
- # Handle PDF: first page only
89
  if fmt == "pdf":
90
  from pdf2image import convert_from_bytes
91
  pdf_images = convert_from_bytes(file_bytes, first_page=1, last_page=1)
92
  return pdf_images[0].convert("RGB")
93
 
94
- # Handle animated GIF: first frame only
95
  if fmt == "gif" and getattr(img, "is_animated", False):
96
  img.seek(0)
97
  return img.convert("RGB")
98
 
99
- # Handle SVG: rasterize to PNG
100
  if fmt == "svg":
101
  import cairosvg
102
  png_bytes = cairosvg.svg2png(bytestring=file_bytes)
103
  return Image.open(BytesIO(png_bytes)).convert("RGB")
104
 
105
- # Other formats (HEIC, HEIF, JPG, PNG, etc.)
106
  return img.convert("RGB")
107
 
108
  except Exception as e:
@@ -118,15 +128,11 @@ app = FastAPI(title="Background Removal API", description="Removes image backgro
118
  # -------------------------
119
  @app.post("/remove_bg_file")
120
  async def remove_bg_file(file: UploadFile = File(...)):
121
- """Upload an image and get transparent PNG."""
122
  try:
123
  contents = await file.read()
124
  image = open_image_safely(contents)
125
  output_buffer = await process_image_async(image)
126
-
127
- # Return directly from memory
128
  return StreamingResponse(output_buffer, media_type="image/png")
129
-
130
  except HTTPException as e:
131
  raise e
132
  except Exception as e:
@@ -134,7 +140,6 @@ async def remove_bg_file(file: UploadFile = File(...)):
134
 
135
  @app.post("/remove_bg_url")
136
  async def remove_bg_url(image_url: str = Form(...)):
137
- """Provide image URL and get transparent PNG."""
138
  try:
139
  image = load_img(image_url, output_type="pil").convert("RGB")
140
  output_buffer = await process_image_async(image)
@@ -143,7 +148,7 @@ async def remove_bg_url(image_url: str = Form(...)):
143
  raise HTTPException(status_code=500, detail=f"Error processing URL: {e}")
144
 
145
  # -------------------------
146
- # Web Interface (no saving)
147
  # -------------------------
148
  @app.get("/", response_class=HTMLResponse)
149
  async def index():
@@ -211,7 +216,7 @@ async def index():
211
  fileForm.addEventListener('submit', async (e) => {
212
  e.preventDefault();
213
  const fileInput = document.getElementById('fileInput');
214
- if (fileInput.files.length === 0) return alert("Select a file!");
215
  const file = fileInput.files[0];
216
  beforeImg.src = URL.createObjectURL(file);
217
  const formData = new FormData();
@@ -249,16 +254,8 @@ async def index():
249
  return HTMLResponse(content=html)
250
 
251
  # -------------------------
252
- # Run Server
253
  # -------------------------
254
  if __name__ == "__main__":
255
- import sys
256
- import os
257
- import uvicorn
258
-
259
- # Get current filename without .py
260
  module_name = os.path.splitext(os.path.basename(__file__))[0]
261
-
262
- # Run uvicorn using the detected module name
263
  uvicorn.run(f"{module_name}:app", host="0.0.0.0", port=7860, workers=2)
264
-
 
23
  executor = ThreadPoolExecutor(max_workers=os.cpu_count() or 4)
24
 
25
  # -------------------------
26
+ # Model Setup (Load Once)
27
  # -------------------------
28
  MODEL_DIR = "models/BiRefNet"
29
  os.makedirs(MODEL_DIR, exist_ok=True)
30
+
31
  device = "cuda" if torch.cuda.is_available() else "cpu"
32
+ print(f"Using device: {device}")
33
 
34
  print("Loading BiRefNet model (first run may take a while)...")
35
  birefnet = AutoModelForImageSegmentation.from_pretrained(
 
39
  )
40
  birefnet.to(device)
41
  birefnet.eval()
42
+ print(f"Model loaded successfully on {device}.")
43
 
44
  # -------------------------
45
  # Image Preprocessing
46
  # -------------------------
47
+ TARGET_SIZE = (512, 512) # Lower resolution for faster inference
48
+
49
  def transform_image(image: Image.Image) -> torch.Tensor:
50
+ image = image.resize(TARGET_SIZE)
51
  arr = np.array(image).astype(np.float32) / 255.0
52
  mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
53
  std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
 
57
  return tensor
58
 
59
  def process_image_sync(image: Image.Image) -> BytesIO:
60
+ """Process image synchronously and return PNG bytes (in-memory)."""
61
  image_size = image.size
62
  input_tensor = transform_image(image)
63
+
64
  with torch.no_grad():
65
+ if device == "cuda":
66
+ # Mixed precision for GPU
67
+ with torch.cuda.amp.autocast():
68
+ preds = birefnet(input_tensor)[-1].sigmoid().cpu()
69
+ else:
70
+ # CPU fallback
71
+ preds = birefnet(input_tensor)[-1].sigmoid().cpu()
72
 
73
  pred = preds[0, 0].numpy()
74
  mask = Image.fromarray((pred * 255).astype(np.uint8)).resize(image_size)
75
 
 
76
  image = image.copy()
77
  image.putalpha(mask)
78
 
 
95
  img = Image.open(BytesIO(file_bytes))
96
  fmt = (img.format or "").lower()
97
 
98
+ # Handle PDF: first page
99
  if fmt == "pdf":
100
  from pdf2image import convert_from_bytes
101
  pdf_images = convert_from_bytes(file_bytes, first_page=1, last_page=1)
102
  return pdf_images[0].convert("RGB")
103
 
104
+ # Handle GIF: first frame
105
  if fmt == "gif" and getattr(img, "is_animated", False):
106
  img.seek(0)
107
  return img.convert("RGB")
108
 
109
+ # Handle SVG
110
  if fmt == "svg":
111
  import cairosvg
112
  png_bytes = cairosvg.svg2png(bytestring=file_bytes)
113
  return Image.open(BytesIO(png_bytes)).convert("RGB")
114
 
115
+ # Other formats (HEIC, HEIF, JPG, PNG)
116
  return img.convert("RGB")
117
 
118
  except Exception as e:
 
128
  # -------------------------
129
  @app.post("/remove_bg_file")
130
  async def remove_bg_file(file: UploadFile = File(...)):
 
131
  try:
132
  contents = await file.read()
133
  image = open_image_safely(contents)
134
  output_buffer = await process_image_async(image)
 
 
135
  return StreamingResponse(output_buffer, media_type="image/png")
 
136
  except HTTPException as e:
137
  raise e
138
  except Exception as e:
 
140
 
141
  @app.post("/remove_bg_url")
142
  async def remove_bg_url(image_url: str = Form(...)):
 
143
  try:
144
  image = load_img(image_url, output_type="pil").convert("RGB")
145
  output_buffer = await process_image_async(image)
 
148
  raise HTTPException(status_code=500, detail=f"Error processing URL: {e}")
149
 
150
  # -------------------------
151
+ # Web Interface
152
  # -------------------------
153
  @app.get("/", response_class=HTMLResponse)
154
  async def index():
 
216
  fileForm.addEventListener('submit', async (e) => {
217
  e.preventDefault();
218
  const fileInput = document.getElementById('fileInput');
219
+ if (!fileInput.files.length) return alert("Select a file!");
220
  const file = fileInput.files[0];
221
  beforeImg.src = URL.createObjectURL(file);
222
  const formData = new FormData();
 
254
  return HTMLResponse(content=html)
255
 
256
  # -------------------------
257
+ # Run Server (Auto-detect filename)
258
  # -------------------------
259
  if __name__ == "__main__":
 
 
 
 
 
260
  module_name = os.path.splitext(os.path.basename(__file__))[0]
 
 
261
  uvicorn.run(f"{module_name}:app", host="0.0.0.0", port=7860, workers=2)