ihabooe commited on
Commit
31ef11c
·
verified ·
1 Parent(s): c718d81

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -12
app.py CHANGED
@@ -8,12 +8,15 @@ import PIL
8
  from PIL import Image
9
  import tempfile
10
  import os
 
11
 
12
  # Load the pre-trained model
 
13
  net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
  net.to(device)
16
  net.eval()
 
17
 
18
  # Resize the input image for model compatibility
19
  def resize_image(image):
@@ -23,8 +26,14 @@ def resize_image(image):
23
  return image
24
 
25
  # Background removal process
26
- def process(image):
 
 
 
 
 
27
  # Prepare the input
 
28
  orig_image = Image.fromarray(image)
29
  w, h = orig_im_size = orig_image.size
30
  image = resize_image(orig_image)
@@ -33,6 +42,8 @@ def process(image):
33
  im_tensor = torch.unsqueeze(im_tensor, 0)
34
  im_tensor = torch.divide(im_tensor, 255.0)
35
  im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
 
 
36
  if torch.cuda.is_available():
37
  im_tensor = im_tensor.cuda()
38
 
@@ -40,6 +51,7 @@ def process(image):
40
  with torch.no_grad():
41
  result = net(im_tensor)
42
 
 
43
  # Post-process the result
44
  result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
45
  ma = torch.max(result)
@@ -54,13 +66,18 @@ def process(image):
54
  new_im = orig_image.copy()
55
  new_im.putalpha(pil_mask)
56
 
 
57
  # Save the processed image to a temporary file
58
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
59
  new_im.save(temp_file.name, format='PNG')
60
  temp_file.close()
61
 
62
- # Return the path to the temporary file for downloading
63
- return temp_file.name
 
 
 
 
64
 
65
  # Gradio interface setup
66
  title = "Background Removal"
@@ -69,15 +86,37 @@ description = r"""Background removal model developed by <a href='https://BRIA.AI
69
  examples = [['./input.jpg']]
70
 
71
  # Create the Gradio interface
72
- demo = gr.Interface(
73
- fn=process, # The function to process the image
74
- inputs=gr.Image(type="numpy"), # Input type (image)
75
- outputs=gr.File(label="Download Processed Image"), # Output as a file (download button)
76
- examples=examples, # Example images for users to try
77
- title=title, # Title of the app
78
- description=description, # Description of the app
79
- live=True # Automatically processes when an image is uploaded
80
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  if __name__ == "__main__":
83
  demo.launch(share=False)
 
8
  from PIL import Image
9
  import tempfile
10
  import os
11
+ import time
12
 
13
  # Load the pre-trained model
14
+ print("Loading model...")
15
  net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
  net.to(device)
18
  net.eval()
19
+ print(f"Model loaded on {device}")
20
 
21
  # Resize the input image for model compatibility
22
  def resize_image(image):
 
26
  return image
27
 
28
  # Background removal process
29
+ def process(image, progress=gr.Progress()):
30
+ if image is None:
31
+ return None, None
32
+
33
+ progress(0, desc="Starting processing...")
34
+
35
  # Prepare the input
36
+ progress(0.1, desc="Preparing image...")
37
  orig_image = Image.fromarray(image)
38
  w, h = orig_im_size = orig_image.size
39
  image = resize_image(orig_image)
 
42
  im_tensor = torch.unsqueeze(im_tensor, 0)
43
  im_tensor = torch.divide(im_tensor, 255.0)
44
  im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
45
+
46
+ progress(0.3, desc="Processing with AI model...")
47
  if torch.cuda.is_available():
48
  im_tensor = im_tensor.cuda()
49
 
 
51
  with torch.no_grad():
52
  result = net(im_tensor)
53
 
54
+ progress(0.6, desc="Post-processing...")
55
  # Post-process the result
56
  result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
57
  ma = torch.max(result)
 
66
  new_im = orig_image.copy()
67
  new_im.putalpha(pil_mask)
68
 
69
+ progress(0.8, desc="Preparing download...")
70
  # Save the processed image to a temporary file
71
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
72
  new_im.save(temp_file.name, format='PNG')
73
  temp_file.close()
74
 
75
+ # Convert to numpy array for display
76
+ output_array = np.array(new_im.convert("RGBA"))
77
+
78
+ progress(1.0, desc="Done!")
79
+ # Return both the image for display and the path for download
80
+ return output_array, temp_file.name
81
 
82
  # Gradio interface setup
83
  title = "Background Removal"
 
86
  examples = [['./input.jpg']]
87
 
88
  # Create the Gradio interface
89
+ with gr.Blocks() as demo:
90
+ gr.Markdown(f"# {title}")
91
+ gr.Markdown(description)
92
+
93
+ with gr.Row():
94
+ with gr.Column(scale=1):
95
+ input_image = gr.Image(type="numpy", label="Upload Image")
96
+ process_btn = gr.Button("Remove Background", variant="primary")
97
+
98
+ with gr.Column(scale=1):
99
+ output_image = gr.Image(type="numpy", label="Result")
100
+ download_btn = gr.File(label="Download Image")
101
+
102
+ # Set up example images
103
+ gr.Examples(examples, inputs=input_image)
104
+
105
+ # Set up processing logic
106
+ process_btn.click(
107
+ fn=process,
108
+ inputs=input_image,
109
+ outputs=[output_image, download_btn],
110
+ show_progress="full"
111
+ )
112
+
113
+ # Also process automatically when image is uploaded
114
+ input_image.change(
115
+ fn=process,
116
+ inputs=input_image,
117
+ outputs=[output_image, download_btn],
118
+ show_progress="full"
119
+ )
120
 
121
  if __name__ == "__main__":
122
  demo.launch(share=False)