ihabooe commited on
Commit
38b6344
·
verified ·
1 Parent(s): daae5a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -42
app.py CHANGED
@@ -4,63 +4,32 @@ import torch.nn.functional as F
4
  from torchvision.transforms.functional import normalize
5
  import gradio as gr
6
  from briarmbg import BriaRMBG
7
- import PIL
8
  from PIL import Image
9
  import tempfile
10
 
11
-
12
- # Load the pre-trained model
13
  net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
  net.to(device)
16
- net.eval()
17
 
18
- # Resize the input image for model compatibility
19
  def resize_image(image):
20
  image = image.convert('RGB')
21
  model_input_size = (1024, 1024)
22
- image = image.resize(model_input_size, Image.BILINEAR)
23
- return image
24
 
25
- # Background removal process with progress bar
26
  def process(image, progress=gr.Progress()):
27
- progress(0.1) # Start with 10% for loading
28
- # Prepare the input
29
  orig_image = Image.fromarray(image)
30
- w, h = orig_im_size = orig_image.size
 
 
31
  image = resize_image(orig_image)
32
  im_np = np.array(image)
33
- im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
34
- im_tensor = torch.unsqueeze(im_tensor, 0)
35
- im_tensor = torch.divide(im_tensor, 255.0)
36
  im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
37
 
38
  if torch.cuda.is_available():
39
- im_tensor = im_tensor.cuda()
40
-
41
- # Inference with the model
42
- result = net(im_tensor)
43
-
44
- progress(0.5) # Progress 50% during inference
45
-
46
- # Post-process the result
47
- result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
48
- ma = torch.max(result)
49
- mi = torch.min(result)
50
- result = (result - mi) / (ma - mi) # Normalize result
51
-
52
- # Convert the result to an image
53
- result_array = (result * 255).cpu().data.numpy().astype(np.uint8)
54
- pil_mask = Image.fromarray(np.squeeze(result_array))
55
-
56
- # Add the mask as alpha channel to the original image
57
- new_im = orig_image.copy()
58
- new_im.putalpha(pil_mask)
59
-
60
- # Save the processed image to a temporary file
61
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
62
- new_im.save(temp_file, format='PNG')
63
- temp_file.close() # Ensure the file is closed before Gradio uses it
64
-
65
- progress(1.0) # Completion of the process
66
- return new_im, temp_file.name # Return the processed image and d
 
4
  from torchvision.transforms.functional import normalize
5
  import gradio as gr
6
  from briarmbg import BriaRMBG
 
7
  from PIL import Image
8
  import tempfile
9
 
10
+ # Load the model
 
11
  net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
  net.to(device)
14
+ net.eval()
15
 
16
+ # Resize input image
17
  def resize_image(image):
18
  image = image.convert('RGB')
19
  model_input_size = (1024, 1024)
20
+ return image.resize(model_input_size, Image.BILINEAR)
 
21
 
22
+ # Process the image
23
  def process(image, progress=gr.Progress()):
24
+ progress(0.2) # 20% progress for loading
 
25
  orig_image = Image.fromarray(image)
26
+ w, h = orig_image.size
27
+
28
+ # Resize image
29
  image = resize_image(orig_image)
30
  im_np = np.array(image)
31
+ im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) / 255.0
 
 
32
  im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
33
 
34
  if torch.cuda.is_available():
35
+ im_tensor