Janeka commited on
Commit
d554758
·
verified ·
1 Parent(s): 729a829

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -93
app.py CHANGED
@@ -1,100 +1,47 @@
1
- import cv2
2
  import numpy as np
3
  import torch
 
4
  from PIL import Image
5
- from fastapi import FastAPI, UploadFile, File, HTTPException
6
- from fastapi.responses import HTMLResponse, Response
7
- import time
8
- from io import BytesIO
9
- import os
10
- from huggingface_hub import hf_hub_download
11
-
12
- # Manual implementation of BiRefNet components
13
- class BiRefNet(torch.nn.Module):
14
- def __init__(self):
15
- super().__init__()
16
- # Simplified architecture for CPU
17
- self.conv1 = torch.nn.Conv2d(4, 64, kernel_size=3, padding=1)
18
- self.conv2 = torch.nn.Conv2d(64, 1, kernel_size=3, padding=1)
19
-
20
- def forward(self, x):
21
- x = torch.relu(self.conv1(x))
22
- return torch.sigmoid(self.conv2(x))
23
-
24
- def inference_image(model, image, device="cpu"):
25
- """Simplified inference for CPU"""
26
- input_tensor = torch.from_numpy(image).permute(2,0,1).unsqueeze(0).float()/255.0
27
- with torch.no_grad():
28
- output = model(input_tensor.to(device))
29
- return output.squeeze().cpu().numpy()
30
-
31
- app = FastAPI(title="BiRefNet Background Remover")
32
-
33
- # Configuration
34
- MAX_SIZE = 1024
35
- MODEL_PATH = "birefnet.pth"
36
- DEVICE = "cpu"
37
 
38
- # Initialize model
39
  model = BiRefNet()
40
- if not os.path.exists(MODEL_PATH):
41
- hf_hub_download(repo_id="ZhengPeng7/BiRefNet",
42
- filename="birefnet.pth",
43
- local_dir=".",
44
- force_filename="birefnet.pth")
45
- model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
46
- model.to(DEVICE)
47
  model.eval()
48
 
49
- def preprocess_image(image):
50
- if image.shape[2] == 3:
51
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGBA)
52
- h, w = image.shape[:2]
53
- if max(h, w) > MAX_SIZE:
54
- ratio = MAX_SIZE / max(h, w)
55
- image = cv2.resize(image, (int(w*ratio), int(h*ratio)),
56
- interpolation=cv2.INTER_AREA)
57
- return image
58
-
59
- @app.post("/remove_bg")
60
- async def remove_bg(file: UploadFile = File(...)):
61
- try:
62
- start_time = time.time()
63
- contents = await file.read()
64
- img = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_UNCHANGED)
65
- img = preprocess_image(img)
66
-
67
- with torch.no_grad():
68
- alpha = inference_image(model, img, DEVICE)
69
-
70
- if img.shape[2] == 3:
71
- img = cv2.cvtColor(img, cv2.COLOR_RGB2RGBA)
72
- img[:, :, 3] = (alpha * 255).astype(np.uint8)
73
-
74
- img_bytes = BytesIO()
75
- Image.fromarray(img).save(img_bytes, format="PNG")
76
- return Response(
77
- content=img_bytes.getvalue(),
78
- media_type="image/png",
79
- headers={"X-Processing-Time": f"{time.time()-start_time:.2f}s"}
80
- )
81
- except Exception as e:
82
- raise HTTPException(500, f"Error: {str(e)}")
83
-
84
- @app.get("/", response_class=HTMLResponse)
85
- async def home():
86
- return HTMLResponse("""
87
- <html>
88
- <body>
89
- <h1>BiRefNet Background Remover</h1>
90
- <form action="/remove_bg" method="post" enctype="multipart/form-data">
91
- <input type="file" name="file" accept="image/*" required>
92
- <button>Remove Background</button>
93
- </form>
94
- </body>
95
- </html>
96
- """)
97
-
98
- if __name__ == "__main__":
99
- import uvicorn
100
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ import gradio as gr
2
  import numpy as np
3
  import torch
4
+ import cv2
5
  from PIL import Image
6
+ from BiRefNet.models.BiRefNet import BiRefNet
7
+ from BiRefNet.utils.dataloader import test_dataset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ # Initialize model (smaller version for CPU)
10
  model = BiRefNet()
11
+ model.load_state_dict(torch.load('BiRefNet.pth', map_location=torch.device('cpu')))
 
 
 
 
 
 
12
  model.eval()
13
 
14
+ def remove_background(input_image):
15
+ # Preprocess image
16
+ image = np.array(input_image)
17
+ image = cv2.resize(image, (320, 320)) # Smaller size for CPU
18
+ image = test_dataset.preprocess(image)
19
+ image = torch.from_numpy(image).unsqueeze(0)
20
+
21
+ # Inference
22
+ with torch.no_grad():
23
+ pred = model(image)
24
+
25
+ # Post-process
26
+ pred = pred.squeeze().cpu().numpy()
27
+ mask = (pred > 0.5).astype(np.uint8) * 255
28
+
29
+ # Apply mask to original image
30
+ original = cv2.resize(np.array(input_image), (320, 320))
31
+ result = cv2.bitwise_and(original, original, mask=mask)
32
+
33
+ return Image.fromarray(result), Image.fromarray(mask)
34
+
35
+ # Gradio interface
36
+ interface = gr.Interface(
37
+ fn=remove_background,
38
+ inputs=gr.Image(type="pil", label="Input Image"),
39
+ outputs=[
40
+ gr.Image(type="pil", label="Result"),
41
+ gr.Image(type="pil", label="Mask")
42
+ ],
43
+ title="BiRefNet Background Remover (CPU)",
44
+ description="Upload an image to remove the background. Works on CPU but may be slow."
45
+ )
46
+
47
+ interface.launch()