Mr7Explorer commited on
Commit
5020db1
·
verified ·
1 Parent(s): 8029c2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -38
app.py CHANGED
@@ -1,64 +1,65 @@
 
1
  import gradio as gr
2
  import torch
3
  from torchvision import transforms
4
  from PIL import Image
5
  import requests
6
- from io import BytesIO
7
 
8
- # Download BiRefNet weights (automatically, if not present)
9
- model_url = "https://huggingface.co/ZhengPeng7/BiRefNet/resolve/main/BiRefNet.pth"
10
- model_path = "BiRefNet.pth"
11
 
12
  def download_weights():
13
- if not os.path.exists(model_path):
14
- response = requests.get(model_url)
15
- with open(model_path, "wb") as f:
16
- f.write(response.content)
17
-
18
- # Define BiRefNet model loading (use standard repo's class definition)
19
- import os
20
 
21
- class BiRefNetDummy(torch.nn.Module):
22
- def forward(self, x):
23
- # Dummy output, placeholder for real BiRefNet logic.
24
- # Will be replaced by actual model code in the next step.
25
- return torch.ones_like(x)
26
-
27
- ddef load_model():
28
  download_weights()
29
- net = BiRefNet()
30
- net.load_state_dict(torch.load("BiRefNet.pth", map_location="cpu"))
31
- net.eval()
32
- return net
 
33
 
34
  bi_ref_net = load_model()
35
 
 
 
 
 
 
 
36
  def remove_bg(input_image):
37
- # PIL image to tensor
38
- transform = transforms.Compose([
39
- transforms.ToTensor(),
40
- ])
41
- img_tensor = transform(input_image).unsqueeze(0)
 
 
42
 
43
- # Get mask from BiRefNet (replace with actual model code later)
44
- mask = bi_ref_net(img_tensor)
45
- mask = mask.squeeze().detach().numpy()
46
- mask_img = Image.fromarray((mask * 255).astype("uint8")).convert("L")
47
- mask_img = mask_img.resize(input_image.size)
48
 
49
- # Apply mask: set alpha channel
50
- result = input_image.convert("RGBA")
51
  result.putalpha(mask_img)
52
  return result
53
 
54
  demo = gr.Interface(
55
  fn=remove_bg,
56
  inputs=gr.Image(type="pil", label="Input Image"),
57
- outputs=gr.Image(type="pil", label="Output (Background Removed)"),
58
  title="Backdrop Studio - BiRefNet Background Removal",
59
- description="Upload an image to remove the background using BiRefNet AI. [v1.0 demo]"
60
  )
61
 
62
  if __name__ == "__main__":
63
- download_weights() # Download weights if needed
64
- demo.launch()
 
1
+ import os
2
  import gradio as gr
3
  import torch
4
  from torchvision import transforms
5
  from PIL import Image
6
  import requests
7
+ from BiRefNet import BiRefNet
8
 
9
+ # 1. Download BiRefNet weights if not present
10
+ MODEL_URL = "https://huggingface.co/ZhengPeng7/BiRefNet/resolve/main/BiRefNet.pth"
11
+ MODEL_PATH = "BiRefNet.pth"
12
 
13
  def download_weights():
14
+ if not os.path.exists(MODEL_PATH):
15
+ print("Downloading BiRefNet weights...")
16
+ r = requests.get(MODEL_URL)
17
+ with open(MODEL_PATH, "wb") as f:
18
+ f.write(r.content)
19
+ print("Done downloading BiRefNet weights.")
 
20
 
21
+ # 2. Load BiRefNet model
22
+ def load_model():
 
 
 
 
 
23
  download_weights()
24
+ model = BiRefNet()
25
+ state_dict = torch.load(MODEL_PATH, map_location="cpu")
26
+ model.load_state_dict(state_dict)
27
+ model.eval()
28
+ return model
29
 
30
  bi_ref_net = load_model()
31
 
32
+ # 3. Define transforms (assuming model expects 224x224 or similar, adjust if needed)
33
+ preprocess = transforms.Compose([
34
+ transforms.Resize((224, 224)), # Adjust to BiRefNet input size if different
35
+ transforms.ToTensor()
36
+ ])
37
+
38
  def remove_bg(input_image):
39
+ # Preprocess image
40
+ image = input_image.convert("RGB")
41
+ img_tensor = preprocess(image).unsqueeze(0) # Add batch dimension
42
+
43
+ # Inference (no gradients needed)
44
+ with torch.no_grad():
45
+ mask = bi_ref_net(img_tensor)[0, 0] # Output mask from model, shape: [H, W]
46
 
47
+ # Resize mask to original image size, normalize (if needed)
48
+ mask_img = transforms.ToPILImage()(mask.cpu().clamp(0, 1))
49
+ mask_img = mask_img.resize(image.size, Image.BILINEAR)
 
 
50
 
51
+ # Create RGBA output by setting alpha to mask
52
+ result = image.convert("RGBA")
53
  result.putalpha(mask_img)
54
  return result
55
 
56
  demo = gr.Interface(
57
  fn=remove_bg,
58
  inputs=gr.Image(type="pil", label="Input Image"),
59
+ outputs=gr.Image(type="pil", label="Background Removed (PNG)"),
60
  title="Backdrop Studio - BiRefNet Background Removal",
61
+ description="Upload an image to remove the background using BiRefNet AI."
62
  )
63
 
64
  if __name__ == "__main__":
65
+ demo.launch()