ihabooe commited on
Commit
7c07d5f
·
verified ·
1 Parent(s): ac34b16

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -25
app.py CHANGED
@@ -3,59 +3,66 @@ import torch
3
  import torch.nn.functional as F
4
  from torchvision.transforms.functional import normalize
5
  import gradio as gr
6
- from briarmbg import BriaRMBG
7
  from PIL import Image
 
 
8
 
9
 
 
10
  net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
  net.to(device)
13
- net.eval()
 
14
 
15
  def resize_image(image):
16
  image = image.convert('RGB')
17
  model_input_size = (1024, 1024)
18
- image = image.resize(model_input_size, Image.BILINEAR)
19
- return image
20
 
21
  def process(image):
22
  orig_image = Image.fromarray(image)
23
  w, h = orig_image.size
24
  image = resize_image(orig_image)
25
  im_np = np.array(image)
26
- im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) / 255.0
 
27
  im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
28
  if torch.cuda.is_available():
29
  im_tensor = im_tensor.cuda()
30
 
 
31
  result = net(im_tensor)
32
  result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
33
- result = (result - result.min()) / (result.max() - result.min())
34
- result_array = (result * 255).cpu().data.numpy().astype(np.uint8)
35
-
36
- # Fix: Convert to grayscale image
37
- pil_mask = Image.fromarray(result_array.squeeze(), mode="L")
38
 
 
 
39
  new_im = orig_image.copy()
40
  new_im.putalpha(pil_mask)
41
- return new_im
42
 
43
- gr.Markdown("## BRIA RMBG 1.4")
44
- gr.HTML('''
45
- <p style="margin-bottom: 10px; font-size: 94%">
46
- This is a demo for BRIA RMBG 1.4 that uses
47
- <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone.
48
- </p>
49
- ''')
50
 
51
- title = "Background Removal"
52
- description = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and is available as an open-source model for non-commercial use.<br>
53
- For test upload your image and wait. Read more at model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>. To purchase a commercial license, simply click <a href='https://go.bria.ai/3ZCBTLH' target='_blank'><b>Here</b></a>. <br>
54
- """
55
 
56
- examples = [['./input.jpg']]
 
 
57
 
58
- demo = gr.Interface(fn=process, inputs="image", outputs="image", examples=examples, title=title, description=description)
 
 
 
 
 
 
 
59
 
60
  if __name__ == "__main__":
61
- demo.launch(share=False)
 
 
3
  import torch.nn.functional as F
4
  from torchvision.transforms.functional import normalize
5
  import gradio as gr
 
6
  from PIL import Image
7
+ from briarmbg import BriaRMBG
8
+ import io
9
 
10
 
11
+ # Load the model
12
  net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  net.to(device)
15
+ net.eval()
16
+
17
 
18
  def resize_image(image):
19
  image = image.convert('RGB')
20
  model_input_size = (1024, 1024)
21
+ return image.resize(model_input_size, Image.BILINEAR)
22
+
23
 
24
  def process(image):
25
  orig_image = Image.fromarray(image)
26
  w, h = orig_image.size
27
  image = resize_image(orig_image)
28
  im_np = np.array(image)
29
+ im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
30
+ im_tensor = torch.unsqueeze(im_tensor, 0) / 255.0
31
  im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
32
  if torch.cuda.is_available():
33
  im_tensor = im_tensor.cuda()
34
 
35
+ # Inference
36
  result = net(im_tensor)
37
  result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
38
+ result_array = ((result - result.min()) / (result.max() - result.min()) * 255).cpu().data.numpy().astype(np.uint8)
 
 
 
 
39
 
40
+ # Add mask to original image
41
+ pil_mask = Image.fromarray(np.squeeze(result_array))
42
  new_im = orig_image.copy()
43
  new_im.putalpha(pil_mask)
 
44
 
45
+ # Convert to bytes for download
46
+ buffer = io.BytesIO()
47
+ new_im.save(buffer, format="PNG")
48
+ buffer.seek(0)
49
+
50
+ return new_im, buffer
 
51
 
 
 
 
 
52
 
53
+ def process_with_download(image):
54
+ new_image, buffer = process(image)
55
+ return new_image, ("background_removed.png", buffer)
56
 
57
+
58
+ demo = gr.Interface(
59
+ fn=process_with_download,
60
+ inputs=gr.Image(type="numpy"),
61
+ outputs=[gr.Image(label="Processed Image"), gr.File(label="Download")],
62
+ title="Background Removal with BRIA RMBG 1.4",
63
+ description="Upload an image to remove the background and download the result."
64
+ )
65
 
66
  if __name__ == "__main__":
67
+ demo.launch()
68
+