ihabooe commited on
Commit
c718d81
·
verified ·
1 Parent(s): a75fa35

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -24
app.py CHANGED
@@ -9,12 +9,11 @@ from PIL import Image
9
  import tempfile
10
  import os
11
 
12
-
13
  # Load the pre-trained model
14
  net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
  net.to(device)
17
- net.eval()
18
 
19
  # Resize the input image for model compatibility
20
  def resize_image(image):
@@ -34,48 +33,42 @@ def process(image):
34
  im_tensor = torch.unsqueeze(im_tensor, 0)
35
  im_tensor = torch.divide(im_tensor, 255.0)
36
  im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
37
-
38
  if torch.cuda.is_available():
39
  im_tensor = im_tensor.cuda()
40
-
41
  # Inference with the model
42
- result = net(im_tensor)
 
43
 
44
  # Post-process the result
45
  result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
46
  ma = torch.max(result)
47
  mi = torch.min(result)
48
- result = (result - mi) / (ma - mi) # Normalize result
49
-
50
  # Convert the result to an image
51
  result_array = (result * 255).cpu().data.numpy().astype(np.uint8)
52
  pil_mask = Image.fromarray(np.squeeze(result_array))
53
-
54
  # Add the mask as alpha channel to the original image
55
  new_im = orig_image.copy()
56
  new_im.putalpha(pil_mask)
57
-
58
  # Save the processed image to a temporary file
59
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
60
- new_im.save(temp_file, format='PNG')
61
- temp_file.close() # Ensure the file is closed before Gradio uses it
62
-
63
- return temp_file.name # Return the path to the temporary file for downloading
 
64
 
65
  # Gradio interface setup
66
- gr.Markdown("## BRIA RMBG 1.4")
67
- gr.HTML('''<p style="margin-bottom: 10px; font-size: 94%">
68
- This is a demo for BRIA RMBG 1.4 that uses
69
- <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as a backbone.
70
- </p>''')
71
-
72
  title = "Background Removal"
73
- description = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and is available as an open-source model for non-commercial use.<br>
74
- For testing, upload your image and wait. Read more at model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>. To purchase a commercial license, simply click <a href='https://go.bria.ai/3ZCBTLH' target='_blank'><b>Here</b></a>. <br>"""
75
 
76
- examples = [['./input.jpg'],]
77
 
78
- # Modify the interface to use live updates and file download
79
  demo = gr.Interface(
80
  fn=process, # The function to process the image
81
  inputs=gr.Image(type="numpy"), # Input type (image)
@@ -87,4 +80,4 @@ demo = gr.Interface(
87
  )
88
 
89
  if __name__ == "__main__":
90
- demo.launch(share=False)
 
9
  import tempfile
10
  import os
11
 
 
12
  # Load the pre-trained model
13
  net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
  net.to(device)
16
+ net.eval()
17
 
18
  # Resize the input image for model compatibility
19
  def resize_image(image):
 
33
  im_tensor = torch.unsqueeze(im_tensor, 0)
34
  im_tensor = torch.divide(im_tensor, 255.0)
35
  im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
 
36
  if torch.cuda.is_available():
37
  im_tensor = im_tensor.cuda()
38
+
39
  # Inference with the model
40
+ with torch.no_grad():
41
+ result = net(im_tensor)
42
 
43
  # Post-process the result
44
  result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
45
  ma = torch.max(result)
46
  mi = torch.min(result)
47
+ result = (result - mi) / (ma - mi)
48
+
49
  # Convert the result to an image
50
  result_array = (result * 255).cpu().data.numpy().astype(np.uint8)
51
  pil_mask = Image.fromarray(np.squeeze(result_array))
52
+
53
  # Add the mask as alpha channel to the original image
54
  new_im = orig_image.copy()
55
  new_im.putalpha(pil_mask)
56
+
57
  # Save the processed image to a temporary file
58
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
59
+ new_im.save(temp_file.name, format='PNG')
60
+ temp_file.close()
61
+
62
+ # Return the path to the temporary file for downloading
63
+ return temp_file.name
64
 
65
  # Gradio interface setup
 
 
 
 
 
 
66
  title = "Background Removal"
67
+ description = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and is available as an open-source model for non-commercial use.<br> For testing, upload your image and wait. Read more at model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>. To purchase a commercial license, simply click <a href='https://go.bria.ai/3ZCBTLH' target='_blank'><b>Here</b></a>. <br>"""
 
68
 
69
+ examples = [['./input.jpg']]
70
 
71
+ # Create the Gradio interface
72
  demo = gr.Interface(
73
  fn=process, # The function to process the image
74
  inputs=gr.Image(type="numpy"), # Input type (image)
 
80
  )
81
 
82
  if __name__ == "__main__":
83
+ demo.launch(share=False)