videopix commited on
Commit
e0c20c2
·
verified ·
1 Parent(s): ebac79a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ from fastapi import FastAPI, File, UploadFile
4
+ from fastapi.responses import StreamingResponse
5
+ from PIL import Image
6
+ import torch
7
+ import torchvision.transforms as transforms
8
+ import onnxruntime as ort
9
+
10
+ # Settings
11
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
12
+ ONNX_PATH = os.path.join(os.path.dirname(__file__), "birefnet.onnx")
13
+
14
+ # Preprocessing transform
15
+ transform_image = transforms.Compose([
16
+ transforms.Resize((1024, 1024)),
17
+ transforms.ToTensor(),
18
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
19
+ ])
20
+
21
+ # Load ONNX model
22
+ if not os.path.exists(ONNX_PATH):
23
+ raise FileNotFoundError(f"ONNX model not found at {ONNX_PATH}")
24
+
25
+ providers = ["CUDAExecutionProvider"] if DEVICE == "cuda" else ["CPUExecutionProvider"]
26
+ onnx_session = ort.InferenceSession(ONNX_PATH, providers=providers)
27
+ print(f"ONNX model loaded with providers: {providers}")
28
+
29
+ # Helper functions
30
+ def run_model_onnx(input_tensor: torch.Tensor) -> torch.Tensor:
31
+ ort_inputs = {onnx_session.get_inputs()[0].name: input_tensor.cpu().numpy()}
32
+ ort_outs = onnx_session.run(None, ort_inputs)
33
+ preds = torch.from_numpy(ort_outs[-1]).sigmoid()
34
+ return preds
35
+
36
+ def process_image(image: Image.Image) -> Image.Image:
37
+ original_size = image.size
38
+ input_tensor = transform_image(image).unsqueeze(0) # (1,C,H,W)
39
+ preds = run_model_onnx(input_tensor)
40
+ pred = preds[0]
41
+ if pred.dim() == 3:
42
+ pred = pred[0].squeeze(0)
43
+ mask = transforms.ToPILImage()(pred.clamp(0, 1))
44
+ mask = mask.resize(original_size, resample=Image.BILINEAR)
45
+ image_rgba = image.convert("RGBA")
46
+ image_rgba.putalpha(mask)
47
+ return image_rgba
48
+
49
+ # FastAPI app
50
+ app = FastAPI(title="Background Removal API")
51
+
52
+ @app.post("/remove-background")
53
+ async def remove_background(file: UploadFile = File(...)):
54
+ image = Image.open(file.file).convert("RGB")
55
+ result_image = process_image(image)
56
+ buf = io.BytesIO()
57
+ result_image.save(buf, format="PNG")
58
+ buf.seek(0)
59
+ return StreamingResponse(buf, media_type="image/png")