videopix commited on
Commit
bf19655
·
verified ·
1 Parent(s): 3a722b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +145 -64
app.py CHANGED
@@ -1,80 +1,161 @@
1
- import io
2
  import os
3
- import threading
4
- from fastapi import FastAPI, File, UploadFile, Request
5
  from fastapi.responses import StreamingResponse, HTMLResponse
6
- from fastapi.templating import Jinja2Templates
7
  from PIL import Image
8
- import torch
9
- import torchvision.transforms as transforms
10
- import onnx
11
- import onnxruntime as ort
12
  import numpy as np
 
 
 
 
 
13
 
14
- # Settings
15
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
- ONNX_PATH = os.path.join(os.path.dirname(__file__), "birefnet.onnx")
17
-
18
- # Preprocessing transform
19
- transform_image = transforms.Compose([
20
- transforms.Resize((1024, 1024)),
21
- transforms.ToTensor(),
22
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
23
- ])
24
-
25
- # Load ONNX model safely
26
- if not os.path.exists(ONNX_PATH):
27
- raise FileNotFoundError(f"ONNX model not found at {ONNX_PATH}")
28
-
29
- # Attempt to load external data automatically
30
- try:
31
- onnx_session = ort.InferenceSession(ONNX_PATH, providers=["CUDAExecutionProvider"] if DEVICE=="cuda" else ["CPUExecutionProvider"])
32
- except ort.OnnxRuntimeError:
33
- # Embed external data into memory if original session fails
34
- print("Embedding external tensor data into the ONNX model...")
35
- model = onnx.load(ONNX_PATH, load_external_data=False) # embed data
36
- embedded_path = ONNX_PATH.replace(".onnx", "_embedded.onnx")
37
- onnx.save(model, embedded_path)
38
- onnx_session = ort.InferenceSession(embedded_path, providers=["CUDAExecutionProvider"] if DEVICE=="cuda" else ["CPUExecutionProvider"])
39
 
40
- print(f"ONNX model loaded with providers: {onnx_session.get_providers()}")
41
 
42
- # Lock for thread-safe inference
43
- onnx_lock = threading.Lock()
 
 
 
 
 
 
 
44
 
45
- def run_model_onnx(input_tensor: torch.Tensor) -> torch.Tensor:
46
- with onnx_lock:
47
- ort_inputs = {onnx_session.get_inputs()[0].name: input_tensor.cpu().numpy()}
48
- ort_outs = onnx_session.run(None, ort_inputs)
49
- preds = torch.from_numpy(ort_outs[-1]).sigmoid()
50
- return preds
 
 
 
 
 
 
51
 
52
  def process_image(image: Image.Image) -> Image.Image:
53
- original_size = image.size
54
- input_tensor = transform_image(image).unsqueeze(0) # (1,C,H,W)
55
- preds = run_model_onnx(input_tensor)
56
- pred = preds[0]
57
- if pred.dim() == 3:
58
- pred = pred[0].squeeze(0)
59
- mask = transforms.ToPILImage()(pred.clamp(0, 1))
60
- mask = mask.resize(original_size, resample=Image.BILINEAR)
61
- image_rgba = image.convert("RGBA")
62
- image_rgba.putalpha(mask)
63
- return image_rgba
64
 
65
- # FastAPI app
 
 
66
  app = FastAPI(title="Background Removal API")
67
- templates = Jinja2Templates(directory="templates")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  @app.post("/remove-background")
70
- async def remove_background(file: UploadFile = File(...)):
71
- image = Image.open(file.file).convert("RGB")
72
- result_image = process_image(image)
73
- buf = io.BytesIO()
74
- result_image.save(buf, format="PNG")
75
- buf.seek(0)
76
- return StreamingResponse(buf, media_type="image/png")
 
 
 
 
77
 
 
 
 
78
  @app.get("/", response_class=HTMLResponse)
79
- async def home(request: Request):
80
- return templates.TemplateResponse("index.html", {"request": request})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ from typing import Union
3
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
4
  from fastapi.responses import StreamingResponse, HTMLResponse
 
5
  from PIL import Image
 
 
 
 
6
  import numpy as np
7
+ import torch
8
+ from transformers import AutoModelForImageSegmentation
9
+ from io import BytesIO
10
+ from loadimg import load_img
11
+ import uvicorn
12
 
13
+ # -------------------------
14
+ # Model Setup (Load Once)
15
+ # -------------------------
16
+ MODEL_DIR = "models/BiRefNet"
17
+ os.makedirs(MODEL_DIR, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
+ print("Loading BiRefNet model (this may take a while on first run)...")
22
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
23
+ "ZhengPeng7/BiRefNet",
24
+ cache_dir=MODEL_DIR,
25
+ trust_remote_code=True
26
+ )
27
+ birefnet.to(device)
28
+ birefnet.eval()
29
+ print("Model loaded successfully.")
30
 
31
+ # -------------------------
32
+ # Image Preprocessing
33
+ # -------------------------
34
+ def transform_image(image: Image.Image) -> torch.Tensor:
35
+ image = image.resize((1024, 1024))
36
+ arr = np.array(image).astype(np.float32) / 255.0
37
+ mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
38
+ std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
39
+ arr = (arr - mean) / std
40
+ arr = np.transpose(arr, (2, 0, 1)) # HWC -> CHW
41
+ tensor = torch.from_numpy(arr).unsqueeze(0).to(torch.float32).to(device)
42
+ return tensor
43
 
44
  def process_image(image: Image.Image) -> Image.Image:
45
+ image_size = image.size
46
+ input_tensor = transform_image(image)
47
+ with torch.no_grad():
48
+ preds = birefnet(input_tensor)[-1].sigmoid().cpu()
49
+ pred = preds[0, 0]
50
+ mask = Image.fromarray((pred.numpy() * 255).astype(np.uint8)).resize(image_size)
51
+ image = image.copy()
52
+ image.putalpha(mask)
53
+ return image
 
 
54
 
55
+ # -------------------------
56
+ # FastAPI App
57
+ # -------------------------
58
  app = FastAPI(title="Background Removal API")
59
+
60
+ # -------------------------
61
+ # API Endpoints
62
+ # -------------------------
63
+ @app.post("/remove-background")
64
+ async def remove_bg_file(file: UploadFile = File(...)):
65
+ """Upload an image file and get transparent PNG"""
66
+ try:
67
+ image = Image.open(BytesIO(await file.read())).convert("RGB")
68
+ transparent = process_image(image)
69
+ buf = BytesIO()
70
+ transparent.save(buf, format="PNG")
71
+ buf.seek(0)
72
+ return StreamingResponse(buf, media_type="image/png")
73
+ except Exception as e:
74
+ raise HTTPException(status_code=500, detail=str(e))
75
 
76
  @app.post("/remove-background")
77
+ async def remove_bg_url(image_url: str = Form(...)):
78
+ """Provide image URL and get transparent PNG"""
79
+ try:
80
+ image = load_img(image_url, output_type="pil").convert("RGB")
81
+ transparent = process_image(image)
82
+ buf = BytesIO()
83
+ transparent.save(buf, format="PNG")
84
+ buf.seek(0)
85
+ return StreamingResponse(buf, media_type="image/png")
86
+ except Exception as e:
87
+ raise HTTPException(status_code=500, detail=str(e))
88
 
89
+ # -------------------------
90
+ # Web Interface
91
+ # -------------------------
92
  @app.get("/", response_class=HTMLResponse)
93
+ async def index():
94
+ html_content = """
95
+ <!DOCTYPE html>
96
+ <html lang="en">
97
+ <head>
98
+ <meta charset="UTF-8">
99
+ <title>Background Removal Tool</title>
100
+ <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css" rel="stylesheet">
101
+ <style>
102
+ body { padding: 30px; background-color: #f8f9fa; }
103
+ .container { max-width: 600px; background: #fff; padding: 20px; border-radius: 10px; box-shadow: 0 0 10px rgba(0,0,0,0.1);}
104
+ img { max-width: 100%; margin-top: 10px; }
105
+ </style>
106
+ </head>
107
+ <body>
108
+ <div class="container">
109
+ <h2 class="mb-4">Background Removal Tool</h2>
110
+ <form id="fileForm" enctype="multipart/form-data">
111
+ <div class="mb-3">
112
+ <label for="fileInput" class="form-label">Upload Image</label>
113
+ <input class="form-control" type="file" id="fileInput" name="file">
114
+ </div>
115
+ <button class="btn btn-primary" type="submit">Remove Background</button>
116
+ </form>
117
+ <hr>
118
+ <form id="urlForm">
119
+ <div class="mb-3">
120
+ <label for="urlInput" class="form-label">Image URL</label>
121
+ <input class="form-control" type="text" id="urlInput" placeholder="Enter image URL">
122
+ </div>
123
+ <button class="btn btn-success" type="submit">Remove Background</button>
124
+ </form>
125
+ <hr>
126
+ <h5>Result:</h5>
127
+ <img id="resultImg" src="">
128
+ </div>
129
+ <script>
130
+ const fileForm = document.getElementById('fileForm');
131
+ fileForm.addEventListener('submit', async (e) => {
132
+ e.preventDefault();
133
+ const fileInput = document.getElementById('fileInput');
134
+ if(fileInput.files.length === 0) return alert("Select a file!");
135
+ const formData = new FormData();
136
+ formData.append("file", fileInput.files[0]);
137
+ const res = await fetch('/remove_bg_file', {method: 'POST', body: formData});
138
+ const blob = await res.blob();
139
+ document.getElementById('resultImg').src = URL.createObjectURL(blob);
140
+ });
141
+ const urlForm = document.getElementById('urlForm');
142
+ urlForm.addEventListener('submit', async (e) => {
143
+ e.preventDefault();
144
+ const urlInput = document.getElementById('urlInput').value;
145
+ const formData = new FormData();
146
+ formData.append("image_url", urlInput);
147
+ const res = await fetch('/remove-background', {method: 'POST', body: formData});
148
+ const blob = await res.blob();
149
+ document.getElementById('resultImg').src = URL.createObjectURL(blob);
150
+ });
151
+ </script>
152
+ </body>
153
+ </html>
154
+ """
155
+ return HTMLResponse(content=html_content)
156
+
157
+ # -------------------------
158
+ # Run the server
159
+ # -------------------------
160
+ if __name__ == "__main__":
161
+ uvicorn.run(app, host="0.0.0.0", port=7860)