WildanJR committed on
Commit
7253dd1
·
verified ·
1 Parent(s): 7e21f17

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -71
app.py CHANGED
@@ -1,94 +1,105 @@
1
- import torch
2
  import numpy as np
 
3
  import torch.nn.functional as F
4
  from torchvision.transforms.functional import normalize
5
  from huggingface_hub import hf_hub_download
6
  import gradio as gr
7
- from PIL import Image
8
  from briarmbg import BriaRMBG
 
 
9
  from typing import Tuple
10
 
11
- # Load the BriaRMBG model
12
- net = BriaRMBG()
13
  model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
14
  if torch.cuda.is_available():
15
  net.load_state_dict(torch.load(model_path))
16
- net = net.cuda()
17
  else:
18
- net.load_state_dict(torch.load(model_path, map_location="cpu"))
19
- net.eval()
20
-
21
- def resize_image(image) -> Image.Image:
22
- """Resize image to fit model's input requirements."""
23
- pixel_number = 960 * 960
24
- granularity_val = 64
25
- ratio = image.size[0] / image.size[1]
26
- width = int((pixel_number * ratio) ** 0.5)
27
- width -= width % granularity_val
28
- height = int(pixel_number / width)
29
- height -= height % granularity_val
30
- return image.resize((width, height))
31
-
32
- def get_masked_background_image(image, image_mask) -> Tuple[np.ndarray, np.ndarray]:
33
- """Apply the segmentation mask to the original image."""
34
- image_mask = image_mask.resize(image.size)
35
- image = np.array(image.convert("RGB")).transpose(2, 0, 1).astype(np.float32) / 255.0
36
- image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
37
- image[:, image_mask < 0.5] = 0
38
- return image, image_mask
39
-
40
- def get_control_image_tensor(vae, image, mask) -> torch.Tensor:
41
- """Prepare the masked image tensor for model input."""
42
- masked_image, image_mask = get_masked_background_image(image, mask)
43
- masked_image_tensor = torch.from_numpy(masked_image)
44
- masked_image_tensor = (masked_image_tensor - 0.5) / 0.5
45
- masked_image_tensor = masked_image_tensor.unsqueeze(0).to(device="cuda:0")
46
- control_latents = vae.encode(masked_image_tensor[:, :3, :, :].to(vae.dtype)).latent_dist.sample()
47
- control_latents = control_latents * vae.config.scaling_factor
48
- mask_tensor = torch.tensor(image_mask, dtype=torch.float32)[None, None, ...].to(device="cuda:0")
49
- mask_resized = torch.nn.functional.interpolate(mask_tensor, size=(control_latents.shape[2], control_latents.shape[3]), mode='nearest')
50
- control_tensor = torch.cat([control_latents, mask_resized], dim=1)
51
- return control_tensor
52
 
53
- def remove_bg_from_image(image) -> Image.Image:
54
- """Use BriaRMBG to generate a segmentation mask."""
55
- from transformers import pipeline
56
- pipe = pipeline("image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True)
57
- mask = pipe(image, return_mask=True)
58
- return mask
59
 
60
- def paste_fg_over_image(gen_image: Image.Image, orig_image: Image.Image, fg_mask: Image.Image) -> Image.Image:
61
- """Paste the foreground over the generated image."""
62
- fg_mask = fg_mask.convert("L").resize(orig_image.size, Image.NEAREST)
63
- gen_image = gen_image.convert("RGBA")
64
- orig_image = orig_image.convert("RGBA")
65
- gen_image.paste(orig_image, (0, 0), fg_mask)
66
- return gen_image.convert("RGB")
67
 
68
  def process(image):
69
- """Process image for background removal and pasting over new background."""
 
70
  orig_image = Image.fromarray(image)
 
71
  image = resize_image(orig_image)
72
- mask = remove_bg_from_image(image)
73
- result_image = paste_fg_over_image(image, orig_image, mask)
74
- return result_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
- # Set up Gradio interface
 
 
 
 
 
 
 
 
 
 
77
  title = "Human Body Segmentation"
78
- description = """
79
- Human Body Segmentation model by <a href='https://github.com/WildanJR09' target='_blank'><b>WildanJR</b></a>.
80
- Separate foreground and background for various image categories. Trained for commercial content.
81
  """
82
- examples = [['./jisoo.jpg']]
83
-
84
- demo = gr.Interface(
85
- fn=process,
86
- inputs="image",
87
- outputs="image",
88
- title=title,
89
- description=description,
90
- examples=examples
91
- )
92
 
93
  if __name__ == "__main__":
94
- demo.launch(share=False)
 
 
1
  import numpy as np
2
+ import torch
3
  import torch.nn.functional as F
4
  from torchvision.transforms.functional import normalize
5
  from huggingface_hub import hf_hub_download
6
  import gradio as gr
7
+ from gradio_imageslider import ImageSlider
8
  from briarmbg import BriaRMBG
9
+ import PIL
10
+ from PIL import Image
11
  from typing import Tuple
12
 
13
+ net=BriaRMBG()
14
+ # model_path = "./model1.pth"
15
  model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
16
  if torch.cuda.is_available():
17
  net.load_state_dict(torch.load(model_path))
18
+ net=net.cuda()
19
  else:
20
+ net.load_state_dict(torch.load(model_path,map_location="cpu"))
21
+ net.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+
24
+ def resize_image(image):
25
+ image = image.convert('RGB')
26
+ model_input_size = (1024, 1024)
27
+ image = image.resize(model_input_size, Image.BILINEAR)
28
+ return image
29
 
 
 
 
 
 
 
 
30
 
31
  def process(image):
32
+
33
+ # prepare input
34
  orig_image = Image.fromarray(image)
35
+ w,h = orig_im_size = orig_image.size
36
  image = resize_image(orig_image)
37
+ im_np = np.array(image)
38
+ im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1)
39
+ im_tensor = torch.unsqueeze(im_tensor,0)
40
+ im_tensor = torch.divide(im_tensor,255.0)
41
+ im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
42
+ if torch.cuda.is_available():
43
+ im_tensor=im_tensor.cuda()
44
+
45
+ #inference
46
+ result=net(im_tensor)
47
+ # post process
48
+ result = torch.squeeze(F.interpolate(result[0][0], size=(h,w), mode='bilinear') ,0)
49
+ ma = torch.max(result)
50
+ mi = torch.min(result)
51
+ result = (result-mi)/(ma-mi)
52
+ # image to pil
53
+ im_array = (result*255).cpu().data.numpy().astype(np.uint8)
54
+ pil_im = Image.fromarray(np.squeeze(im_array))
55
+ # paste the mask on the original image
56
+ new_im = Image.new("RGBA", pil_im.size, (0,0,0,0))
57
+ new_im.paste(orig_image, mask=pil_im)
58
+ # new_orig_image = orig_image.convert('RGBA')
59
+
60
+ return new_im
61
+ # return [new_orig_image, new_im]]
62
+
63
+ # block = gr.Blocks().queue()
64
+
65
+ # with block:
66
+ # gr.Markdown("## HBS_V1")
67
+ # gr.HTML('''
68
+ # <p style="margin-bottom: 10px; font-size: 94%">
69
+ # This is a demo for Human Body Segmentation that using
70
+ # YoloV8 image instance model as backbone.
71
+ # </p>
72
+ # ''')
73
+ # with gr.Row():
74
+ # with gr.Column():
75
+ # input_image = gr.Image(sources=None, type="pil") # None for upload, ctrl+v and webcam
76
+ # # input_image = gr.Image(sources=None, type="numpy") # None for upload, ctrl+v and webcam
77
+ # run_button = gr.Button(value="Run")
78
+
79
+ # with gr.Column():
80
+ # result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[1], height='auto')
81
+ # ips = [input_image]
82
+ # run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
83
 
84
+ # block.launch(debug = True)
85
+
86
+ # block = gr.Blocks().queue()
87
+
88
+ gr.Markdown("## HBS_V1")
89
+ gr.HTML('''
90
+ <p style="margin-bottom: 10px; font-size: 94%">
91
+ This is a demo for Human Body Segmentation that using
92
+ YoloV8 image instance model as backbone.
93
+ </p>
94
+ ''')
95
  title = "Human Body Segmentation"
96
+ description = r"""Human Body Segmentation model developed by <a href='https://github.com/WildanJR09' target='_blank'><b>WildanJR</b></a>, Designed to effectively separate foreground from background in a range of categories and image types.<br>
97
+ This model has been trained on a carefully selected dataset, which includes: general stock images, e-commerce, gaming, and advertising content, making it suitable for commercial use cases powering enterprise content creation at scale. The accuracy, efficiency, and versatility currently rival leading source-available models. It is ideal where content safety, legally licensed datasets, and bias mitigation are paramount. For test upload your image and wait. </a>.<br>
 
98
  """
99
+ examples = [['./jisoo.jpg'],]
100
+ # output = ImageSlider(position=0.5,label='Image without background', type="pil", show_download_button=True)
101
+ # demo = gr.Interface(fn=process,inputs="image", outputs=output, examples=examples, title=title, description=description)
102
+ demo = gr.Interface(fn=process,inputs="image", outputs="image", examples=examples, title=title, description=description)
 
 
 
 
 
 
103
 
104
  if __name__ == "__main__":
105
+ demo.launch(share=False)