shuvo108 commited on
Commit
0cf052e
·
verified ·
1 Parent(s): 4263c0b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -33
app.py CHANGED
@@ -2,55 +2,46 @@ import gradio as gr
2
  import torch
3
  import numpy as np
4
  from PIL import Image
 
5
  from torchvision import transforms
6
 
7
- # BiRefNet import
8
- from birefnet import BiRefNet
 
 
 
 
9
 
10
- # Device
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
-
13
- # Load model
14
- model = BiRefNet.from_pretrained("ZhengPeng7/BiRefNet")
15
  model.to(device)
16
- model.eval()
17
 
18
- # Image transform
19
  transform = transforms.Compose([
20
- transforms.Resize((1024, 1024)),
21
- transforms.ToTensor(),
22
  ])
23
 
24
- def remove_bg(image):
25
- if image is None:
26
- return None
27
 
28
- original_size = image.size
29
-
30
- img = transform(image).unsqueeze(0).to(device)
31
 
32
  with torch.no_grad():
33
- pred = model(img)[-1]
34
- mask = torch.sigmoid(pred).cpu().squeeze().numpy()
35
-
36
- mask = Image.fromarray((mask * 255).astype(np.uint8)).resize(original_size)
37
-
38
- image = image.convert("RGBA")
39
- image_np = np.array(image)
40
- mask_np = np.array(mask)
41
 
42
- image_np[:, :, 3] = mask_np
 
 
43
 
44
- return Image.fromarray(image_np)
 
 
45
 
46
- # Gradio UI
47
  demo = gr.Interface(
48
  fn=remove_bg,
49
- inputs=gr.Image(type="pil", label="Upload Image"),
50
- outputs=gr.Image(type="pil", label="Background Removed"),
51
- title="BiRefNet AI – Background Remover",
52
- description="Powered by BiRefNet (Open Source AI Model)"
53
  )
54
 
55
- if __name__ == "__main__":
56
- demo.launch()
 
2
  import torch
3
  import numpy as np
4
  from PIL import Image
5
+ from transformers import AutoModelForImageSegmentation
6
  from torchvision import transforms
7
 
8
+ # Load RMBG-1.4 (BiRefNet)
9
+ model = AutoModelForImageSegmentation.from_pretrained(
10
+ "briaai/RMBG-1.4",
11
+ trust_remote_code=True
12
+ )
13
+ model.eval()
14
 
 
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
16
  model.to(device)
 
17
 
 
18
  transform = transforms.Compose([
19
+ transforms.ToTensor()
 
20
  ])
21
 
22
+ def remove_bg(image: Image.Image):
23
+ if image.mode != "RGB":
24
+ image = image.convert("RGB")
25
 
26
+ inp = transform(image).unsqueeze(0).to(device)
 
 
27
 
28
  with torch.no_grad():
29
+ pred = model(inp)[0][0]
 
 
 
 
 
 
 
30
 
31
+ mask = pred.sigmoid().cpu().numpy()
32
+ mask = (mask > 0.5).astype(np.uint8) * 255
33
+ mask = Image.fromarray(mask).resize(image.size)
34
 
35
+ output = image.copy()
36
+ output.putalpha(mask)
37
+ return output
38
 
 
39
  demo = gr.Interface(
40
  fn=remove_bg,
41
+ inputs=gr.Image(type="pil"),
42
+ outputs=gr.Image(type="pil"),
43
+ title="RMBG-1.4 Background Remover",
44
+ description="BiRefNet based background removal"
45
  )
46
 
47
+ demo.launch()