videopix commited on
Commit
c77bffa
·
verified ·
1 Parent(s): 347af97

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -39
app.py CHANGED
@@ -1,77 +1,162 @@
1
  import io
2
  import os
3
- from fastapi import FastAPI, File, UploadFile, HTTPException
 
4
  from fastapi.responses import StreamingResponse, HTMLResponse
5
- from PIL import Image, UnidentifiedImageError
 
6
  import torch
7
  import torchvision.transforms as transforms
8
  import onnxruntime as ort
9
 
10
- # -----------------------------
11
  # Settings
12
- # -----------------------------
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
- ONNX_PATH = "BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx"
15
 
 
 
 
 
 
 
 
 
16
  if not os.path.exists(ONNX_PATH):
17
  raise FileNotFoundError(f"ONNX model not found at {ONNX_PATH}")
18
 
19
- # Use only one model, load it once
20
  providers = ["CUDAExecutionProvider"] if DEVICE == "cuda" else ["CPUExecutionProvider"]
21
  onnx_session = ort.InferenceSession(ONNX_PATH, providers=providers)
 
22
 
23
- # Preprocessing
24
- transform_image = transforms.Compose([
25
- transforms.Resize((1024, 1024)),
26
- transforms.ToTensor(),
27
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
28
- ])
 
 
 
 
29
 
30
- # -----------------------------
31
- # Helper
32
- # -----------------------------
33
- def run_model(image: Image.Image):
34
- input_tensor = transform_image(image).unsqueeze(0)
35
- ort_inputs = {onnx_session.get_inputs()[0].name: input_tensor.cpu().numpy()}
36
- ort_outs = onnx_session.run(None, ort_inputs)
37
- pred = torch.from_numpy(ort_outs[-1])[0]
38
  if pred.dim() == 3:
39
  pred = pred[0].squeeze(0)
40
  mask = transforms.ToPILImage()(pred.clamp(0, 1))
41
- mask = mask.resize(image.size)
42
- result = image.convert("RGBA")
43
- result.putalpha(mask)
44
- return result
45
 
46
- # -----------------------------
47
  # FastAPI app
48
- # -----------------------------
49
  app = FastAPI(title="Background Removal API")
50
 
51
  @app.post("/remove-background")
52
  async def remove_background(file: UploadFile = File(...)):
53
- try:
54
- contents = await file.read()
55
- image = Image.open(io.BytesIO(contents)).convert("RGB")
56
- except UnidentifiedImageError:
57
- raise HTTPException(status_code=400, detail="Unsupported or corrupted image file.")
58
 
59
- result_image = run_model(image)
60
  buf = io.BytesIO()
61
  result_image.save(buf, format="PNG")
62
  buf.seek(0)
 
63
  return StreamingResponse(buf, media_type="image/png")
64
 
 
65
  @app.get("/", response_class=HTMLResponse)
66
- async def home():
67
  return """
 
68
  <html>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  <body>
70
- <h1>Background Remover</h1>
71
- <form method="post" enctype="multipart/form-data" action="/remove-background">
72
- <input type="file" name="file" accept="image/*" required>
73
- <button type="submit">Remove Background</button>
74
- </form>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  </body>
76
  </html>
77
- """
 
1
  import io
2
  import os
3
+ import threading
4
+ from fastapi import FastAPI, File, UploadFile, Request
5
  from fastapi.responses import StreamingResponse, HTMLResponse
6
+ from fastapi.staticfiles import StaticFiles
7
+ from PIL import Image
8
  import torch
9
  import torchvision.transforms as transforms
10
  import onnxruntime as ort
11
 
 
12
  # Settings
 
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
+ ONNX_PATH = os.path.join(os.path.dirname(__file__), "BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx")
15
 
16
+ # Preprocessing transform
17
+ transform_image = transforms.Compose([
18
+ transforms.Resize((1024, 1024)),
19
+ transforms.ToTensor(),
20
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
21
+ ])
22
+
23
+ # Load ONNX model
24
  if not os.path.exists(ONNX_PATH):
25
  raise FileNotFoundError(f"ONNX model not found at {ONNX_PATH}")
26
 
 
27
  providers = ["CUDAExecutionProvider"] if DEVICE == "cuda" else ["CPUExecutionProvider"]
28
  onnx_session = ort.InferenceSession(ONNX_PATH, providers=providers)
29
+ print(f"ONNX model loaded with providers: {providers}")
30
 
31
+ # Lock for thread-safe ONNX inference
32
+ onnx_lock = threading.Lock()
33
+
34
+ # Helper functions
35
+ def run_model_onnx(input_tensor: torch.Tensor) -> torch.Tensor:
36
+ with onnx_lock: # ensure thread safety
37
+ ort_inputs = {onnx_session.get_inputs()[0].name: input_tensor.cpu().numpy()}
38
+ ort_outs = onnx_session.run(None, ort_inputs)
39
+ preds = torch.from_numpy(ort_outs[-1]).sigmoid()
40
+ return preds
41
 
42
+ def process_image(image: Image.Image) -> Image.Image:
43
+ original_size = image.size
44
+ input_tensor = transform_image(image).unsqueeze(0) # (1,C,H,W)
45
+ preds = run_model_onnx(input_tensor)
46
+ pred = preds[0]
 
 
 
47
  if pred.dim() == 3:
48
  pred = pred[0].squeeze(0)
49
  mask = transforms.ToPILImage()(pred.clamp(0, 1))
50
+ mask = mask.resize(original_size, resample=Image.BILINEAR)
51
+ image_rgba = image.convert("RGBA")
52
+ image_rgba.putalpha(mask)
53
+ return image_rgba
54
 
 
55
  # FastAPI app
 
56
  app = FastAPI(title="Background Removal API")
57
 
58
  @app.post("/remove-background")
59
  async def remove_background(file: UploadFile = File(...)):
60
+ image = Image.open(file.file).convert("RGB")
61
+ result_image = process_image(image)
 
 
 
62
 
 
63
  buf = io.BytesIO()
64
  result_image.save(buf, format="PNG")
65
  buf.seek(0)
66
+
67
  return StreamingResponse(buf, media_type="image/png")
68
 
69
+ # Serve a simple HTML frontend for testing
70
  @app.get("/", response_class=HTMLResponse)
71
+ async def home(request: Request):
72
  return """
73
+ <!DOCTYPE html>
74
  <html>
75
+ <head>
76
+ <title>Background Remover</title>
77
+ <style>
78
+ body {
79
+ font-family: Arial, sans-serif;
80
+ display: flex;
81
+ flex-direction: column;
82
+ align-items: center;
83
+ justify-content: center;
84
+ min-height: 100vh;
85
+ margin: 0;
86
+ padding: 20px;
87
+ background: #f5f5f5;
88
+ }
89
+ h1 {
90
+ color: #333;
91
+ }
92
+ .container {
93
+ background: white;
94
+ padding: 20px;
95
+ border-radius: 12px;
96
+ box-shadow: 0 4px 8px rgba(0,0,0,0.1);
97
+ max-width: 500px;
98
+ width: 100%;
99
+ text-align: center;
100
+ }
101
+ input[type=file] {
102
+ margin: 15px 0;
103
+ }
104
+ button {
105
+ background: #4CAF50;
106
+ color: white;
107
+ padding: 10px 20px;
108
+ border: none;
109
+ border-radius: 6px;
110
+ cursor: pointer;
111
+ font-size: 16px;
112
+ }
113
+ button:hover {
114
+ background: #45a049;
115
+ }
116
+ img {
117
+ margin-top: 20px;
118
+ max-width: 100%;
119
+ border-radius: 8px;
120
+ }
121
+ </style>
122
+ </head>
123
  <body>
124
+ <div class="container">
125
+ <h1>Background Remover</h1>
126
+ <form id="upload-form">
127
+ <input type="file" id="file-input" name="file" accept="image/*" required />
128
+ <br/>
129
+ <button type="submit">Remove Background</button>
130
+ </form>
131
+ <div id="result"></div>
132
+ </div>
133
+ <script>
134
+ const form = document.getElementById('upload-form');
135
+ const resultDiv = document.getElementById('result');
136
+ form.addEventListener('submit', async (e) => {
137
+ e.preventDefault();
138
+ const fileInput = document.getElementById('file-input');
139
+ if (!fileInput.files.length) return;
140
+ const formData = new FormData();
141
+ formData.append('file', fileInput.files[0]);
142
+ resultDiv.innerHTML = "<p>Processing...</p>";
143
+ try {
144
+ const response = await fetch('/remove-background', {
145
+ method: 'POST',
146
+ body: formData
147
+ });
148
+ if (!response.ok) {
149
+ resultDiv.innerHTML = "<p style='color:red;'>Error processing image</p>";
150
+ return;
151
+ }
152
+ const blob = await response.blob();
153
+ const url = URL.createObjectURL(blob);
154
+ resultDiv.innerHTML = `<h3>Result:</h3><img src="${url}" alt="Processed Image"/>`;
155
+ } catch (err) {
156
+ resultDiv.innerHTML = "<p style='color:red;'>Request failed</p>";
157
+ }
158
+ });
159
+ </script>
160
  </body>
161
  </html>
162
+ """