videopix commited on
Commit
5d90ae5
·
verified ·
1 Parent(s): 8a357a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -35
app.py CHANGED
@@ -1,12 +1,13 @@
1
  import os
2
- from io import BytesIO
3
- from PIL import Image, UnidentifiedImageError
 
4
  import torch
5
  import numpy as np
6
  from transformers import AutoModelForImageSegmentation
7
- from fastapi import FastAPI, File, Form, UploadFile, HTTPException
8
- from fastapi.responses import StreamingResponse
9
  import requests
 
10
 
11
  # -------------------------
12
  # Model Setup
@@ -26,71 +27,71 @@ birefnet.to(device).eval()
26
  print("Model loaded successfully.")
27
 
28
  # -------------------------
29
- # Image Preprocessing
 
 
 
 
 
30
  # -------------------------
 
 
 
 
 
 
 
 
31
  def transform_image(image: Image.Image) -> torch.Tensor:
32
  image = image.resize((1024, 1024))
33
  arr = np.array(image).astype(np.float32) / 255.0
34
  mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
35
  std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
36
  arr = (arr - mean) / std
37
- arr = np.transpose(arr, (2, 0, 1))
38
  tensor = torch.from_numpy(arr).unsqueeze(0).to(torch.float32).to(device)
39
  return tensor
40
 
41
  def process_image(image: Image.Image) -> Image.Image:
 
42
  input_tensor = transform_image(image)
43
  with torch.no_grad():
44
- pred = birefnet(input_tensor)[-1].sigmoid().cpu()[0, 0]
45
- mask = Image.fromarray((pred.numpy() * 255).astype(np.uint8)).resize(image.size)
 
46
  image = image.convert("RGBA")
47
  image.putalpha(mask)
48
  return image
49
 
50
  # -------------------------
51
- # FastAPI App
52
- # -------------------------
53
- app = FastAPI(title="Background Removal API")
54
-
55
- # -------------------------
56
- # API Endpoint: Return PNG
57
  # -------------------------
58
  @app.post("/remove-background")
59
  async def remove_background(file: UploadFile = File(None), image_url: str = Form(None)):
60
  """
61
- Accept either an uploaded file or an image URL.
62
- Returns PNG with transparent background.
 
63
  """
64
  try:
65
  if file:
66
- img_bytes = await file.read()
67
- img = Image.open(BytesIO(img_bytes)).convert("RGB")
68
  elif image_url:
69
- resp = requests.get(image_url, timeout=10)
70
- resp.raise_for_status()
71
- img = Image.open(BytesIO(resp.content)).convert("RGB")
72
  else:
73
- raise HTTPException(status_code=400, detail="Provide file or image_url")
74
 
75
- # Process background removal
76
- result_img = process_image(img)
77
 
78
- # Convert to PNG bytes
79
  buf = BytesIO()
80
- result_img.save(buf, format="PNG")
81
  buf.seek(0)
82
  return StreamingResponse(buf, media_type="image/png")
83
-
84
- except UnidentifiedImageError:
85
- raise HTTPException(status_code=400, detail="Invalid image format")
86
- except requests.RequestException:
87
- raise HTTPException(status_code=400, detail="Failed to fetch image from URL")
88
  except Exception as e:
89
  raise HTTPException(status_code=500, detail=str(e))
90
 
91
  # -------------------------
92
- # Optional Root Info
93
  # -------------------------
94
- @app.get("/")
95
- async def index():
96
- return {"message": "POST /remove-background with 'file' or 'image_url'. Returns PNG."}
 
1
  import os
2
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
3
+ from fastapi.responses import StreamingResponse
4
+ from PIL import Image
5
  import torch
6
  import numpy as np
7
  from transformers import AutoModelForImageSegmentation
8
+ from io import BytesIO
 
9
  import requests
10
+ import uvicorn
11
 
12
  # -------------------------
13
  # Model Setup
 
27
  print("Model loaded successfully.")
28
 
29
  # -------------------------
30
+ # FastAPI App
31
+ # -------------------------
32
+ app = FastAPI(title="Background Remover API")
33
+
34
+ # -------------------------
35
+ # Utility Functions
36
  # -------------------------
37
+ def load_image_from_url(url: str) -> Image.Image:
38
+ try:
39
+ response = requests.get(url, timeout=10)
40
+ response.raise_for_status()
41
+ return Image.open(BytesIO(response.content)).convert("RGB")
42
+ except Exception as e:
43
+ raise HTTPException(status_code=400, detail=f"Error loading image from URL: {str(e)}")
44
+
45
  def transform_image(image: Image.Image) -> torch.Tensor:
46
  image = image.resize((1024, 1024))
47
  arr = np.array(image).astype(np.float32) / 255.0
48
  mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
49
  std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
50
  arr = (arr - mean) / std
51
+ arr = np.transpose(arr, (2, 0, 1)) # HWC -> CHW
52
  tensor = torch.from_numpy(arr).unsqueeze(0).to(torch.float32).to(device)
53
  return tensor
54
 
55
  def process_image(image: Image.Image) -> Image.Image:
56
+ image_size = image.size
57
  input_tensor = transform_image(image)
58
  with torch.no_grad():
59
+ preds = birefnet(input_tensor)[-1].sigmoid().cpu()
60
+ pred = preds[0, 0]
61
+ mask = Image.fromarray((pred.numpy() * 255).astype(np.uint8)).resize(image_size)
62
  image = image.convert("RGBA")
63
  image.putalpha(mask)
64
  return image
65
 
66
  # -------------------------
67
+ # /remove-background Endpoint
 
 
 
 
 
68
  # -------------------------
69
  @app.post("/remove-background")
70
  async def remove_background(file: UploadFile = File(None), image_url: str = Form(None)):
71
  """
72
+ Remove background from an image.
73
+ Accepts either a file upload or an image URL.
74
+ Returns a PNG with transparent background.
75
  """
76
  try:
77
  if file:
78
+ image = Image.open(BytesIO(await file.read())).convert("RGB")
 
79
  elif image_url:
80
+ image = load_image_from_url(image_url)
 
 
81
  else:
82
+ raise HTTPException(status_code=400, detail="Provide either 'file' or 'image_url'.")
83
 
84
+ result = process_image(image)
 
85
 
 
86
  buf = BytesIO()
87
+ result.save(buf, format="PNG")
88
  buf.seek(0)
89
  return StreamingResponse(buf, media_type="image/png")
 
 
 
 
 
90
  except Exception as e:
91
  raise HTTPException(status_code=500, detail=str(e))
92
 
93
  # -------------------------
94
+ # Run App on Spaces
95
  # -------------------------
96
+ if __name__ == "__main__":
97
+ uvicorn.run(app, host="0.0.0.0", port=7860)