videopix commited on
Commit
36c54d6
·
verified ·
1 Parent(s): 36b7654

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -121
app.py CHANGED
@@ -1,13 +1,13 @@
1
- import os
2
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException
3
- from fastapi.responses import StreamingResponse, HTMLResponse
4
  from PIL import Image
 
5
  import torch
6
  import numpy as np
7
  from transformers import AutoModelForImageSegmentation
8
- from io import BytesIO
9
- from loadimg import load_img
10
- from contextlib import asynccontextmanager
11
 
12
  # -------------------------
13
  # Model Setup
@@ -15,134 +15,66 @@ from contextlib import asynccontextmanager
15
  MODEL_DIR = "models/BiRefNet"
16
  os.makedirs(MODEL_DIR, exist_ok=True)
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
- birefnet = None # will initialize on startup
19
 
20
- # -------------------------
21
- # Lifespan (Startup/Shutdown)
22
- # -------------------------
23
- @asynccontextmanager
24
- async def lifespan(app: FastAPI):
25
- global birefnet
26
- if birefnet is None:
27
- print("Loading BiRefNet model...")
28
- birefnet = AutoModelForImageSegmentation.from_pretrained(
29
- "ZhengPeng7/BiRefNet",
30
- cache_dir=MODEL_DIR,
31
- trust_remote_code=True,
32
- revision="main"
33
- )
34
- birefnet.to(device).eval()
35
- print("Model loaded successfully.")
36
- yield
37
- # shutdown logic (optional)
38
-
39
- # -------------------------
40
- # FastAPI App
41
- # -------------------------
42
- app = FastAPI(title="Background Removal API", lifespan=lifespan)
43
 
44
- # -------------------------
45
- # Image Processing
46
- # -------------------------
47
- def transform_image(image: Image.Image) -> torch.Tensor:
48
  image = image.resize((1024, 1024))
49
  arr = np.array(image).astype(np.float32) / 255.0
50
- mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
51
- std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
52
  arr = (arr - mean) / std
53
- arr = np.transpose(arr, (2, 0, 1)) # HWC -> CHW
54
- tensor = torch.from_numpy(arr).unsqueeze(0).to(torch.float32).to(device)
55
- return tensor
56
 
57
- def process_image(image: Image.Image) -> Image.Image:
58
- image_size = image.size
59
  input_tensor = transform_image(image)
60
  with torch.no_grad():
61
- preds = birefnet(input_tensor)[-1].sigmoid().cpu()
62
- pred = preds[0, 0]
63
- mask = Image.fromarray((pred.numpy() * 255).astype(np.uint8)).resize(image_size)
64
- image = image.copy()
65
  image.putalpha(mask)
66
  return image
67
 
 
 
 
68
  # -------------------------
69
- # API Endpoint
70
  # -------------------------
71
- @app.post("/remove-background")
72
- async def remove_background(file: UploadFile = File(None), image_url: str = Form(None)):
73
- """
74
- Accept either an uploaded file or image URL.
75
- Returns PNG with transparent background.
76
- """
77
- try:
78
- if file:
79
- image = Image.open(BytesIO(await file.read())).convert("RGB")
80
- elif image_url:
81
- image = load_img(image_url, output_type="pil").convert("RGB")
82
- else:
83
- raise HTTPException(status_code=400, detail="Provide file or image_url")
84
-
85
- result = process_image(image)
86
- buf = BytesIO()
87
- result.save(buf, format="PNG")
88
- buf.seek(0)
89
- return StreamingResponse(buf, media_type="image/png")
90
- except Exception as e:
91
- raise HTTPException(status_code=500, detail=str(e))
92
 
93
  # -------------------------
94
- # Web Interface
95
  # -------------------------
96
- @app.get("/", response_class=HTMLResponse)
97
- async def index():
98
- html_content = """
99
- <!DOCTYPE html>
100
- <html>
101
- <head>
102
- <title>Background Removal</title>
103
- <style>
104
- body { font-family: Arial; padding: 20px; }
105
- .container { max-width: 600px; margin: auto; background: #f9f9f9; padding: 20px; border-radius: 10px; }
106
- img { max-width: 100%; margin-top: 20px; }
107
- </style>
108
- </head>
109
- <body>
110
- <div class="container">
111
- <h2>Background Removal Tool</h2>
112
- <form id="fileForm" enctype="multipart/form-data">
113
- <input type="file" name="file" id="fileInput">
114
- <button type="submit">Remove Background</button>
115
- </form>
116
- <hr>
117
- <form id="urlForm">
118
- <input type="text" id="urlInput" placeholder="Image URL">
119
- <button type="submit">Remove Background</button>
120
- </form>
121
- <img id="resultImg" src="">
122
- </div>
123
- <script>
124
- const fileForm = document.getElementById('fileForm');
125
- fileForm.addEventListener('submit', async e => {
126
- e.preventDefault();
127
- const fileInput = document.getElementById('fileInput');
128
- if(fileInput.files.length === 0) return alert("Select a file!");
129
- const formData = new FormData();
130
- formData.append("file", fileInput.files[0]);
131
- const res = await fetch('/remove-background', {method:'POST', body:formData});
132
- const blob = await res.blob();
133
- document.getElementById('resultImg').src = URL.createObjectURL(blob);
134
- });
135
- const urlForm = document.getElementById('urlForm');
136
- urlForm.addEventListener('submit', async e => {
137
- e.preventDefault();
138
- const formData = new FormData();
139
- formData.append("image_url", document.getElementById('urlInput').value);
140
- const res = await fetch('/remove-background', {method:'POST', body:formData});
141
- const blob = await res.blob();
142
- document.getElementById('resultImg').src = URL.createObjectURL(blob);
143
- });
144
- </script>
145
- </body>
146
- </html>
147
- """
148
- return HTMLResponse(html_content)
 
1
+ import gradio as gr
2
+ from fastapi import FastAPI, Request
3
+ from fastapi.responses import JSONResponse
4
  from PIL import Image
5
+ from io import BytesIO
6
  import torch
7
  import numpy as np
8
  from transformers import AutoModelForImageSegmentation
9
+ import os
10
+ import requests
 
11
 
12
  # -------------------------
13
  # Model Setup
 
15
  MODEL_DIR = "models/BiRefNet"
16
  os.makedirs(MODEL_DIR, exist_ok=True)
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
18
 
19
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
20
+ "ZhengPeng7/BiRefNet",
21
+ cache_dir=MODEL_DIR,
22
+ trust_remote_code=True,
23
+ revision="main"
24
+ )
25
+ birefnet.to(device).eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ def transform_image(image: Image.Image):
 
 
 
28
  image = image.resize((1024, 1024))
29
  arr = np.array(image).astype(np.float32) / 255.0
30
+ mean = np.array([0.485, 0.456, 0.406])
31
+ std = np.array([0.229, 0.224, 0.225])
32
  arr = (arr - mean) / std
33
+ arr = np.transpose(arr, (2, 0, 1))
34
+ return torch.from_numpy(arr).unsqueeze(0).to(torch.float32).to(device)
 
35
 
36
+ def process_image(image: Image.Image):
 
37
  input_tensor = transform_image(image)
38
  with torch.no_grad():
39
+ pred = birefnet(input_tensor)[-1].sigmoid().cpu()[0,0]
40
+ mask = Image.fromarray((pred.numpy()*255).astype(np.uint8)).resize(image.size)
 
 
41
  image.putalpha(mask)
42
  return image
43
 
44
+ def remove_background_gradio(img):
45
+ return process_image(img.convert("RGB"))
46
+
47
  # -------------------------
48
+ # Gradio Interface
49
  # -------------------------
50
+ demo = gr.Interface(
51
+ fn=remove_background_gradio,
52
+ inputs=gr.Image(type="pil"),
53
+ outputs=gr.Image(type="pil"),
54
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  # -------------------------
57
+ # FastAPI App
58
  # -------------------------
59
+ app = gr.routes.FastAPI.create_app(demo) # Wraps Gradio app
60
+
61
+ # Custom route for `/remove-background`
62
+ @app.post("/remove-background")
63
+ async def remove_background(request: Request):
64
+ data = await request.form()
65
+ file = data.get("file")
66
+ image_url = data.get("image_url")
67
+
68
+ if file:
69
+ img = Image.open(file.file).convert("RGB")
70
+ elif image_url:
71
+ resp = requests.get(image_url)
72
+ img = Image.open(BytesIO(resp.content)).convert("RGB")
73
+ else:
74
+ return JSONResponse({"error": "Provide file or image_url"}, status_code=400)
75
+
76
+ result = remove_background_gradio(img)
77
+ buf = BytesIO()
78
+ result.save(buf, format="PNG")
79
+ buf.seek(0)
80
+ return JSONResponse({"image": "data:image/png;base64," + base64.b64encode(buf.read()).decode()})