videopix commited on
Commit
9f86510
·
verified ·
1 Parent(s): 36c54d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -22
app.py CHANGED
@@ -1,13 +1,14 @@
1
- import gradio as gr
2
- from fastapi import FastAPI, Request
3
- from fastapi.responses import JSONResponse
4
- from PIL import Image
5
  from io import BytesIO
 
6
  import torch
7
  import numpy as np
8
  from transformers import AutoModelForImageSegmentation
9
- import os
10
- import requests
 
11
 
12
  # -------------------------
13
  # Model Setup
@@ -16,6 +17,7 @@ MODEL_DIR = "models/BiRefNet"
16
  os.makedirs(MODEL_DIR, exist_ok=True)
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
 
 
19
  birefnet = AutoModelForImageSegmentation.from_pretrained(
20
  "ZhengPeng7/BiRefNet",
21
  cache_dir=MODEL_DIR,
@@ -23,47 +25,61 @@ birefnet = AutoModelForImageSegmentation.from_pretrained(
23
  revision="main"
24
  )
25
  birefnet.to(device).eval()
 
26
 
27
- def transform_image(image: Image.Image):
 
 
 
28
  image = image.resize((1024, 1024))
29
  arr = np.array(image).astype(np.float32) / 255.0
30
- mean = np.array([0.485, 0.456, 0.406])
31
- std = np.array([0.229, 0.224, 0.225])
32
  arr = (arr - mean) / std
33
  arr = np.transpose(arr, (2, 0, 1))
34
- return torch.from_numpy(arr).unsqueeze(0).to(torch.float32).to(device)
 
35
 
36
- def process_image(image: Image.Image):
37
  input_tensor = transform_image(image)
38
  with torch.no_grad():
39
- pred = birefnet(input_tensor)[-1].sigmoid().cpu()[0,0]
40
- mask = Image.fromarray((pred.numpy()*255).astype(np.uint8)).resize(image.size)
 
41
  image.putalpha(mask)
42
  return image
43
 
 
 
 
44
  def remove_background_gradio(img):
45
  return process_image(img.convert("RGB"))
46
 
47
  # -------------------------
48
- # Gradio Interface
49
  # -------------------------
50
  demo = gr.Interface(
51
  fn=remove_background_gradio,
52
  inputs=gr.Image(type="pil"),
53
  outputs=gr.Image(type="pil"),
 
 
54
  )
55
 
56
  # -------------------------
57
  # FastAPI App
58
  # -------------------------
59
- app = gr.routes.FastAPI.create_app(demo) # Wraps Gradio app
60
 
61
- # Custom route for `/remove-background`
62
  @app.post("/remove-background")
63
  async def remove_background(request: Request):
64
- data = await request.form()
65
- file = data.get("file")
66
- image_url = data.get("image_url")
 
 
 
 
67
 
68
  if file:
69
  img = Image.open(file.file).convert("RGB")
@@ -71,10 +87,17 @@ async def remove_background(request: Request):
71
  resp = requests.get(image_url)
72
  img = Image.open(BytesIO(resp.content)).convert("RGB")
73
  else:
74
- return JSONResponse({"error": "Provide file or image_url"}, status_code=400)
75
 
76
- result = remove_background_gradio(img)
77
  buf = BytesIO()
78
  result.save(buf, format="PNG")
79
  buf.seek(0)
80
- return JSONResponse({"image": "data:image/png;base64," + base64.b64encode(buf.read()).decode()})
 
 
 
 
 
 
 
 
1
+ import os
2
+ import base64
3
+ import requests
 
4
  from io import BytesIO
5
+ from PIL import Image
6
  import torch
7
  import numpy as np
8
  from transformers import AutoModelForImageSegmentation
9
+ import gradio as gr
10
+ from fastapi import FastAPI, Request
11
+ from fastapi.responses import StreamingResponse, HTMLResponse
12
 
13
  # -------------------------
14
  # Model Setup
 
17
  os.makedirs(MODEL_DIR, exist_ok=True)
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
+ print("Loading BiRefNet model...")
21
  birefnet = AutoModelForImageSegmentation.from_pretrained(
22
  "ZhengPeng7/BiRefNet",
23
  cache_dir=MODEL_DIR,
 
25
  revision="main"
26
  )
27
  birefnet.to(device).eval()
28
+ print("Model loaded successfully.")
29
 
30
+ # -------------------------
31
+ # Image Preprocessing
32
+ # -------------------------
33
+ def transform_image(image: Image.Image) -> torch.Tensor:
34
  image = image.resize((1024, 1024))
35
  arr = np.array(image).astype(np.float32) / 255.0
36
+ mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
37
+ std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
38
  arr = (arr - mean) / std
39
  arr = np.transpose(arr, (2, 0, 1))
40
+ tensor = torch.from_numpy(arr).unsqueeze(0).to(torch.float32).to(device)
41
+ return tensor
42
 
43
+ def process_image(image: Image.Image) -> Image.Image:
44
  input_tensor = transform_image(image)
45
  with torch.no_grad():
46
+ pred = birefnet(input_tensor)[-1].sigmoid().cpu()[0, 0]
47
+ mask = Image.fromarray((pred.numpy() * 255).astype(np.uint8)).resize(image.size)
48
+ image = image.convert("RGBA")
49
  image.putalpha(mask)
50
  return image
51
 
52
+ # -------------------------
53
+ # Gradio Function
54
+ # -------------------------
55
  def remove_background_gradio(img):
56
  return process_image(img.convert("RGB"))
57
 
58
  # -------------------------
59
+ # Gradio Interface for UI
60
  # -------------------------
61
  demo = gr.Interface(
62
  fn=remove_background_gradio,
63
  inputs=gr.Image(type="pil"),
64
  outputs=gr.Image(type="pil"),
65
+ title="Background Removal Tool",
66
+ description="Upload an image and get a transparent background."
67
  )
68
 
69
  # -------------------------
70
  # FastAPI App
71
  # -------------------------
72
+ app = gr.routes.FastAPI.create_app(demo) # Wrap Gradio
73
 
 
74
  @app.post("/remove-background")
75
  async def remove_background(request: Request):
76
+ """
77
+ Custom endpoint: accepts 'file' upload or 'image_url' form.
78
+ Returns PNG bytes.
79
+ """
80
+ form = await request.form()
81
+ file = form.get("file")
82
+ image_url = form.get("image_url")
83
 
84
  if file:
85
  img = Image.open(file.file).convert("RGB")
 
87
  resp = requests.get(image_url)
88
  img = Image.open(BytesIO(resp.content)).convert("RGB")
89
  else:
90
+ return {"error": "Provide file or image_url"}
91
 
92
+ result = process_image(img)
93
  buf = BytesIO()
94
  result.save(buf, format="PNG")
95
  buf.seek(0)
96
+ return StreamingResponse(buf, media_type="image/png")
97
+
98
+ # -------------------------
99
+ # Optional: Root UI
100
+ # -------------------------
101
+ @app.get("/", response_class=HTMLResponse)
102
+ async def index():
103
+ return demo.launch(share=False, inline=True)[0] # Embed Gradio UI