ihabooe commited on
Commit
e730218
·
verified ·
1 Parent(s): b82a86a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -40
app.py CHANGED
@@ -5,69 +5,71 @@ from torchvision.transforms.functional import normalize
5
  import gradio as gr
6
  from briarmbg import BriaRMBG
7
  from PIL import Image
8
- import tempfile
9
 
10
- # Load the model
11
  net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
  net.to(device)
14
  net.eval()
15
 
16
- # Resize input image
17
  def resize_image(image):
18
  image = image.convert('RGB')
19
  model_input_size = (1024, 1024)
20
- return image.resize(model_input_size, Image.BILINEAR)
 
21
 
22
- # Process the image
23
- def process(image, progress=gr.Progress()):
24
- progress(0.2) # 20% progress for loading
25
- orig_image = Image.fromarray(image)
26
- w, h = orig_image.size
27
 
28
- # Resize image
 
 
29
  image = resize_image(orig_image)
30
  im_np = np.array(image)
31
- im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) / 255.0
 
 
32
  im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
33
 
34
  if torch.cuda.is_available():
35
  im_tensor = im_tensor.cuda()
36
 
37
- # Inference
38
  result = net(im_tensor)
39
- progress(0.7) # 70% progress during inference
40
-
41
- # Post-process
42
  result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
43
- result = (result - result.min()) / (result.max() - result.min())
 
 
44
 
45
- # Convert to PIL image with alpha mask
46
  result_array = (result * 255).cpu().data.numpy().astype(np.uint8)
47
- pil_mask = Image.fromarray(result_array)
 
48
  new_im = orig_image.copy()
49
  new_im.putalpha(pil_mask)
50
 
51
- # Save the image for download
52
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
53
- new_im.save(temp_file.name, format='PNG')
54
-
55
- progress(1.0) # 100% complete
56
- return new_im, temp_file.name
57
-
58
- # Gradio interface
59
- with gr.Blocks() as demo:
60
- gr.Markdown("## BRIA RMBG 1.4 - Background Remover")
61
- gr.HTML("<p>Upload your image to remove the background.</p>")
62
-
63
- with gr.Row():
64
- input_image = gr.Image(type="numpy", label="Upload Image")
65
- output_image = gr.Image(type="pil", label="Processed Image")
66
-
67
- download_button = gr.File(label="Download Processed Image")
68
-
69
- input_image.change(fn=process, inputs=input_image, outputs=[output_image, download_button])
70
-
71
- # Run the app on port 7860
72
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
73
 
 
 
 
5
  import gradio as gr
6
  from briarmbg import BriaRMBG
7
  from PIL import Image
8
+ import io
9
 
 
10
  net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
  net.to(device)
13
  net.eval()
14
 
15
+
16
  def resize_image(image):
17
  image = image.convert('RGB')
18
  model_input_size = (1024, 1024)
19
+ image = image.resize(model_input_size, Image.BILINEAR)
20
+ return image
21
 
 
 
 
 
 
22
 
23
+ def process(image):
24
+ orig_image = Image.fromarray(image)
25
+ w, h = orig_im_size = orig_image.size
26
  image = resize_image(orig_image)
27
  im_np = np.array(image)
28
+ im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
29
+ im_tensor = torch.unsqueeze(im_tensor, 0)
30
+ im_tensor = torch.divide(im_tensor, 255.0)
31
  im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
32
 
33
  if torch.cuda.is_available():
34
  im_tensor = im_tensor.cuda()
35
 
 
36
  result = net(im_tensor)
 
 
 
37
  result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
38
+ ma = torch.max(result)
39
+ mi = torch.min(result)
40
+ result = (result - mi) / (ma - mi)
41
 
 
42
  result_array = (result * 255).cpu().data.numpy().astype(np.uint8)
43
+ pil_mask = Image.fromarray(result_array, mode="L")
44
+
45
  new_im = orig_image.copy()
46
  new_im.putalpha(pil_mask)
47
 
48
+ output_buffer = io.BytesIO()
49
+ new_im.save(output_buffer, format="PNG")
50
+ output_buffer.seek(0)
51
+
52
+ return new_im, output_buffer
53
+
54
+
55
+ def process_with_download(image):
56
+ new_im, output_buffer = process(image)
57
+ return new_im, ("output.png", output_buffer, "image/png")
58
+
59
+
60
+ gr.Markdown("## BRIA RMBG 1.4 - Background Remover")
61
+
62
+ demo = gr.Interface(
63
+ fn=process_with_download,
64
+ inputs=gr.Image(type="numpy"),
65
+ outputs=[
66
+ gr.Image(type="pil", label="Result"),
67
+ gr.File(label="Download Image")
68
+ ],
69
+ live=True,
70
+ title="BRIA RMBG Background Remover",
71
+ description="Upload an image to remove the background automatically.",
72
+ )
73
 
74
+ if __name__ == "__main__":
75
+ demo.launch()