ihabooe commited on
Commit
a75fa35
·
verified ·
1 Parent(s): 7c07d5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -28
app.py CHANGED
@@ -1,68 +1,90 @@
1
- import numpy as np
2
  import torch
3
  import torch.nn.functional as F
4
  from torchvision.transforms.functional import normalize
5
  import gradio as gr
6
- from PIL import Image
7
  from briarmbg import BriaRMBG
8
- import io
 
 
 
9
 
10
 
11
- # Load the model
12
  net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  net.to(device)
15
- net.eval()
16
-
17
 
 
18
  def resize_image(image):
19
  image = image.convert('RGB')
20
  model_input_size = (1024, 1024)
21
- return image.resize(model_input_size, Image.BILINEAR)
22
-
23
 
 
24
  def process(image):
 
25
  orig_image = Image.fromarray(image)
26
- w, h = orig_image.size
27
  image = resize_image(orig_image)
28
  im_np = np.array(image)
29
  im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
30
- im_tensor = torch.unsqueeze(im_tensor, 0) / 255.0
 
31
  im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
 
32
  if torch.cuda.is_available():
33
  im_tensor = im_tensor.cuda()
34
 
35
- # Inference
36
  result = net(im_tensor)
 
 
37
  result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
38
- result_array = ((result - result.min()) / (result.max() - result.min()) * 255).cpu().data.numpy().astype(np.uint8)
 
 
39
 
40
- # Add mask to original image
 
41
  pil_mask = Image.fromarray(np.squeeze(result_array))
 
 
42
  new_im = orig_image.copy()
43
  new_im.putalpha(pil_mask)
44
 
45
- # Convert to bytes for download
46
- buffer = io.BytesIO()
47
- new_im.save(buffer, format="PNG")
48
- buffer.seek(0)
49
 
50
- return new_im, buffer
51
 
 
 
 
 
 
 
52
 
53
- def process_with_download(image):
54
- new_image, buffer = process(image)
55
- return new_image, ("background_removed.png", buffer)
56
 
 
57
 
 
58
  demo = gr.Interface(
59
- fn=process_with_download,
60
- inputs=gr.Image(type="numpy"),
61
- outputs=[gr.Image(label="Processed Image"), gr.File(label="Download")],
62
- title="Background Removal with BRIA RMBG 1.4",
63
- description="Upload an image to remove the background and download the result."
 
 
64
  )
65
 
66
  if __name__ == "__main__":
67
- demo.launch()
68
-
 
1
+ import numpy as np
2
  import torch
3
  import torch.nn.functional as F
4
  from torchvision.transforms.functional import normalize
5
  import gradio as gr
 
6
  from briarmbg import BriaRMBG
7
+ import PIL
8
+ from PIL import Image
9
+ import tempfile
10
+ import os
11
 
12
 
13
+ # Load the pre-trained model
14
  net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
  net.to(device)
17
+ net.eval()
 
18
 
19
+ # Resize the input image for model compatibility
20
  def resize_image(image):
21
  image = image.convert('RGB')
22
  model_input_size = (1024, 1024)
23
+ image = image.resize(model_input_size, Image.BILINEAR)
24
+ return image
25
 
26
+ # Background removal process
27
  def process(image):
28
+ # Prepare the input
29
  orig_image = Image.fromarray(image)
30
+ w, h = orig_im_size = orig_image.size
31
  image = resize_image(orig_image)
32
  im_np = np.array(image)
33
  im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
34
+ im_tensor = torch.unsqueeze(im_tensor, 0)
35
+ im_tensor = torch.divide(im_tensor, 255.0)
36
  im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
37
+
38
  if torch.cuda.is_available():
39
  im_tensor = im_tensor.cuda()
40
 
41
+ # Inference with the model
42
  result = net(im_tensor)
43
+
44
+ # Post-process the result
45
  result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
46
+ ma = torch.max(result)
47
+ mi = torch.min(result)
48
+ result = (result - mi) / (ma - mi) # Normalize result
49
 
50
+ # Convert the result to an image
51
+ result_array = (result * 255).cpu().data.numpy().astype(np.uint8)
52
  pil_mask = Image.fromarray(np.squeeze(result_array))
53
+
54
+ # Add the mask as alpha channel to the original image
55
  new_im = orig_image.copy()
56
  new_im.putalpha(pil_mask)
57
 
58
+ # Save the processed image to a temporary file
59
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
60
+ new_im.save(temp_file, format='PNG')
61
+ temp_file.close() # Ensure the file is closed before Gradio uses it
62
 
63
+ return temp_file.name # Return the path to the temporary file for downloading
64
 
65
+ # Gradio interface setup
66
+ gr.Markdown("## BRIA RMBG 1.4")
67
+ gr.HTML('''<p style="margin-bottom: 10px; font-size: 94%">
68
+ This is a demo for BRIA RMBG 1.4 that uses
69
+ <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as a backbone.
70
+ </p>''')
71
 
72
+ title = "Background Removal"
73
+ description = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and is available as an open-source model for non-commercial use.<br>
74
+ For testing, upload your image and wait. Read more at model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>. To purchase a commercial license, simply click <a href='https://go.bria.ai/3ZCBTLH' target='_blank'><b>Here</b></a>. <br>"""
75
 
76
+ examples = [['./input.jpg'],]
77
 
78
+ # Modify the interface to use live updates and file download
79
  demo = gr.Interface(
80
+ fn=process, # The function to process the image
81
+ inputs=gr.Image(type="numpy"), # Input type (image)
82
+ outputs=gr.File(label="Download Processed Image"), # Output as a file (download button)
83
+ examples=examples, # Example images for users to try
84
+ title=title, # Title of the app
85
+ description=description, # Description of the app
86
+ live=True # Automatically processes when an image is uploaded
87
  )
88
 
89
  if __name__ == "__main__":
90
+ demo.launch(share=False)