ihabooe commited on
Commit
74499c7
·
verified ·
1 Parent(s): 9fb66bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -32
app.py CHANGED
@@ -3,70 +3,88 @@ import torch
3
  import torch.nn.functional as F
4
  from torchvision.transforms.functional import normalize
5
  import gradio as gr
6
- from gradio_imageslider import ImageSlider
7
  from briarmbg import BriaRMBG
8
  import PIL
9
  from PIL import Image
10
- from typing import Tuple
 
11
 
12
 
 
13
  net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
  net.to(device)
16
  net.eval()
17
 
18
-
19
  def resize_image(image):
20
  image = image.convert('RGB')
21
  model_input_size = (1024, 1024)
22
  image = image.resize(model_input_size, Image.BILINEAR)
23
  return image
24
 
25
-
26
  def process(image):
27
-
28
- # prepare input
29
  orig_image = Image.fromarray(image)
30
- w,h = orig_im_size = orig_image.size
31
  image = resize_image(orig_image)
32
  im_np = np.array(image)
33
- im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1)
34
- im_tensor = torch.unsqueeze(im_tensor,0)
35
- im_tensor = torch.divide(im_tensor,255.0)
36
- im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
 
37
  if torch.cuda.is_available():
38
- im_tensor=im_tensor.cuda()
39
 
40
- #inference
41
- result=net(im_tensor)
42
- # post process
43
- result = torch.squeeze(F.interpolate(result[0][0], size=(h,w), mode='bilinear') ,0)
 
44
  ma = torch.max(result)
45
  mi = torch.min(result)
46
- result = (result-mi)/(ma-mi)
47
- # image to pil
48
- result_array = (result*255).cpu().data.numpy().astype(np.uint8)
 
49
  pil_mask = Image.fromarray(np.squeeze(result_array))
50
- # add the mask on the original image as alpha channel
 
51
  new_im = orig_image.copy()
52
  new_im.putalpha(pil_mask)
53
- return new_im
54
- # return [new_orig_image, new_im]
55
 
 
 
 
 
 
 
56
 
 
57
  gr.Markdown("## BRIA RMBG 1.4")
58
- gr.HTML('''
59
- <p style="margin-bottom: 10px; font-size: 94%">
60
- This is a demo for BRIA RMBG 1.4 that using
61
- <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone.
62
- </p>
63
- ''')
64
  title = "Background Removal"
65
  description = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and is available as an open-source model for non-commercial use.<br>
66
- For test upload your image and wait. Read more at model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>. To purchase a commercial license, simply click <a href='https://go.bria.ai/3ZCBTLH' target='_blank'><b>Here</b></a>. <br>
67
- """
68
  examples = [['./input.jpg'],]
69
- demo = gr.Interface(fn=process,inputs="image", outputs="image", examples=examples, title=title, description=description)
 
 
 
 
 
 
 
 
 
 
70
 
71
  if __name__ == "__main__":
72
- demo.launch(share=False)
 
3
  import torch.nn.functional as F
4
  from torchvision.transforms.functional import normalize
5
  import gradio as gr
 
6
  from briarmbg import BriaRMBG
7
  import PIL
8
  from PIL import Image
9
+ import tempfile
10
+ import os
11
 
12
 
13
+ # Load the pre-trained model
14
  net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
  net.to(device)
17
  net.eval()
18
 
19
+ # Resize the input image for model compatibility
20
  def resize_image(image):
21
  image = image.convert('RGB')
22
  model_input_size = (1024, 1024)
23
  image = image.resize(model_input_size, Image.BILINEAR)
24
  return image
25
 
26
+ # Background removal process
27
  def process(image):
28
+ # Prepare the input
 
29
  orig_image = Image.fromarray(image)
30
+ w, h = orig_im_size = orig_image.size
31
  image = resize_image(orig_image)
32
  im_np = np.array(image)
33
+ im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
34
+ im_tensor = torch.unsqueeze(im_tensor, 0)
35
+ im_tensor = torch.divide(im_tensor, 255.0)
36
+ im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
37
+
38
  if torch.cuda.is_available():
39
+ im_tensor = im_tensor.cuda()
40
 
41
+ # Inference with the model
42
+ result = net(im_tensor)
43
+
44
+ # Post-process the result
45
+ result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
46
  ma = torch.max(result)
47
  mi = torch.min(result)
48
+ result = (result - mi) / (ma - mi) # Normalize result
49
+
50
+ # Convert the result to an image
51
+ result_array = (result * 255).cpu().data.numpy().astype(np.uint8)
52
  pil_mask = Image.fromarray(np.squeeze(result_array))
53
+
54
+ # Add the mask as alpha channel to the original image
55
  new_im = orig_image.copy()
56
  new_im.putalpha(pil_mask)
 
 
57
 
58
+ # Save the processed image to a temporary file
59
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
60
+ new_im.save(temp_file, format='PNG')
61
+ temp_file.close() # Ensure the file is closed before Gradio uses it
62
+
63
+ return temp_file.name # Return the path to the temporary file for downloading
64
 
65
+ # Gradio interface setup
66
  gr.Markdown("## BRIA RMBG 1.4")
67
+ gr.HTML('''<p style="margin-bottom: 10px; font-size: 94%">
68
+ This is a demo for BRIA RMBG 1.4 that uses
69
+ <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as a backbone.
70
+ </p>''')
71
+
 
72
  title = "Background Removal"
73
  description = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and is available as an open-source model for non-commercial use.<br>
74
+ For testing, upload your image and wait. Read more at model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>. To purchase a commercial license, simply click <a href='https://go.bria.ai/3ZCBTLH' target='_blank'><b>Here</b></a>. <br>"""
75
+
76
  examples = [['./input.jpg'],]
77
+
78
+ # Modify the interface to use live updates and file download
79
+ demo = gr.Interface(
80
+ fn=process, # The function to process the image
81
+ inputs=gr.inputs.Image(type="numpy"), # Input type (image)
82
+ outputs=gr.File(label="Download Processed Image"), # Output as a file (download button)
83
+ examples=examples, # Example images for users to try
84
+ title=title, # Title of the app
85
+ description=description, # Description of the app
86
+ live=True # Automatically processes when an image is uploaded
87
+ )
88
 
89
  if __name__ == "__main__":
90
+ demo.launch(share=False)