Update app.py
Browse files
app.py
CHANGED
|
@@ -1,94 +1,105 @@
|
|
| 1 |
-
import torch
|
| 2 |
import numpy as np
|
|
|
|
| 3 |
import torch.nn.functional as F
|
| 4 |
from torchvision.transforms.functional import normalize
|
| 5 |
from huggingface_hub import hf_hub_download
|
| 6 |
import gradio as gr
|
| 7 |
-
from
|
| 8 |
from briarmbg import BriaRMBG
|
|
|
|
|
|
|
| 9 |
from typing import Tuple
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
|
| 14 |
if torch.cuda.is_available():
|
| 15 |
net.load_state_dict(torch.load(model_path))
|
| 16 |
-
net
|
| 17 |
else:
|
| 18 |
-
net.load_state_dict(torch.load(model_path,
|
| 19 |
-
net.eval()
|
| 20 |
-
|
| 21 |
-
def resize_image(image) -> Image.Image:
    """Downscale the image toward a 960*960 pixel budget.

    Keeps the original aspect ratio and snaps both sides down to the
    nearest multiple of 64 (a common latent-grid granularity).
    """
    target_pixels = 960 * 960
    multiple = 64
    aspect = image.size[0] / image.size[1]
    new_w = int((target_pixels * aspect) ** 0.5)
    new_w -= new_w % multiple
    new_h = int(target_pixels / new_w)
    new_h -= new_h % multiple
    return image.resize((new_w, new_h))
|
| 31 |
-
|
| 32 |
-
def get_masked_background_image(image, image_mask) -> Tuple[np.ndarray, np.ndarray]:
    """Black out the background of *image* using *image_mask*.

    Returns a pair:
      - CHW float32 RGB array scaled to [0, 1] with pixels whose mask
        value is below 0.5 set to zero;
      - HW float32 mask array scaled to [0, 1].
    """
    resized_mask = image_mask.resize(image.size)
    rgb = np.array(image.convert("RGB")).transpose(2, 0, 1).astype(np.float32) / 255.0
    mask_arr = np.array(resized_mask.convert("L")).astype(np.float32) / 255.0
    # Boolean index over the trailing (H, W) dims zeroes all three channels.
    rgb[:, mask_arr < 0.5] = 0
    return rgb, mask_arr
|
| 39 |
-
|
| 40 |
-
def get_control_image_tensor(vae, image, mask) -> torch.Tensor:
    """Build the control tensor: VAE latents of the masked image plus the mask.

    The background-blacked image is normalized to [-1, 1], encoded with
    *vae*, and the mask is nearest-resized to the latent resolution and
    concatenated on the channel axis.
    """
    masked_np, mask_np = get_masked_background_image(image, mask)
    pixels = torch.from_numpy(masked_np)
    pixels = (pixels - 0.5) / 0.5  # [0, 1] -> [-1, 1]
    pixels = pixels.unsqueeze(0).to(device="cuda:0")
    latents = vae.encode(pixels[:, :3, :, :].to(vae.dtype)).latent_dist.sample()
    latents = latents * vae.config.scaling_factor
    mask_t = torch.tensor(mask_np, dtype=torch.float32)[None, None, ...].to(device="cuda:0")
    mask_small = torch.nn.functional.interpolate(
        mask_t, size=(latents.shape[2], latents.shape[3]), mode='nearest'
    )
    return torch.cat([latents, mask_small], dim=1)
|
| 52 |
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
return
|
| 59 |
|
| 60 |
-
def paste_fg_over_image(gen_image: Image.Image, orig_image: Image.Image, fg_mask: Image.Image) -> Image.Image:
    """Composite the original foreground (selected by *fg_mask*) onto *gen_image*."""
    mask_l = fg_mask.convert("L").resize(orig_image.size, Image.NEAREST)
    base = gen_image.convert("RGBA")
    fg = orig_image.convert("RGBA")
    base.paste(fg, (0, 0), mask_l)
    return base.convert("RGB")
|
| 67 |
|
| 68 |
def process(image):
|
| 69 |
-
|
|
|
|
| 70 |
orig_image = Image.fromarray(image)
|
|
|
|
| 71 |
image = resize_image(orig_image)
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
title = "Human Body Segmentation"
|
| 78 |
-
description = """
|
| 79 |
-
|
| 80 |
-
Separate foreground and background for various image categories. Trained for commercial content.
|
| 81 |
"""
|
| 82 |
-
examples = [['./jisoo.jpg']]
|
| 83 |
-
|
| 84 |
-
demo = gr.Interface(
|
| 85 |
-
|
| 86 |
-
inputs="image",
|
| 87 |
-
outputs="image",
|
| 88 |
-
title=title,
|
| 89 |
-
description=description,
|
| 90 |
-
examples=examples
|
| 91 |
-
)
|
| 92 |
|
| 93 |
if __name__ == "__main__":
|
| 94 |
-
demo.launch(share=False)
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
import torch.nn.functional as F
|
| 4 |
from torchvision.transforms.functional import normalize
|
| 5 |
from huggingface_hub import hf_hub_download
|
| 6 |
import gradio as gr
|
| 7 |
+
from gradio_imageslider import ImageSlider
|
| 8 |
from briarmbg import BriaRMBG
|
| 9 |
+
import PIL
|
| 10 |
+
from PIL import Image
|
| 11 |
from typing import Tuple
|
| 12 |
|
| 13 |
+
# Load the RMBG-1.4 segmentation weights from the Hugging Face Hub and put
# the network in eval mode, on GPU when one is available.
_use_cuda = torch.cuda.is_available()
net = BriaRMBG()
# model_path = "./model1.pth"
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
# map_location=None is torch.load's default (load to the tensors' saved
# device); force CPU only when CUDA is absent.
net.load_state_dict(torch.load(model_path, map_location=None if _use_cuda else "cpu"))
if _use_cuda:
    net = net.cuda()
net.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
+
|
| 24 |
+
def resize_image(image):
    """Return *image* converted to RGB and bilinearly resized to the
    fixed 1024x1024 input the RMBG model expects (aspect ratio is NOT
    preserved; the mask is resized back later)."""
    rgb = image.convert('RGB')
    return rgb.resize((1024, 1024), Image.BILINEAR)
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
def process(image):
    """Remove the background from a numpy HWC image.

    Runs the RMBG net on a 1024x1024 resize, min-max normalizes the
    predicted alpha, upsamples it back to the original resolution, and
    returns an RGBA PIL image with the original pixels under that alpha.
    """
    # prepare input
    orig_image = Image.fromarray(image)
    w, h = orig_image.size
    model_in = resize_image(orig_image)
    tensor = torch.tensor(np.array(model_in), dtype=torch.float32).permute(2, 0, 1)
    tensor = torch.unsqueeze(tensor, 0)
    tensor = tensor / 255.0
    tensor = normalize(tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        tensor = tensor.cuda()

    # inference
    result = net(tensor)

    # post process: upsample the prediction back to the original (h, w)
    # and min-max scale it into [0, 1].
    # NOTE(review): assumes result[0][0] is a (1, 1, H, W) logit map from
    # BriaRMBG — confirm against the model's forward().
    pred = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
    lo = torch.min(pred)
    hi = torch.max(pred)
    pred = (pred - lo) / (hi - lo)

    # prediction to a grayscale PIL mask
    mask = Image.fromarray(np.squeeze((pred * 255).cpu().data.numpy().astype(np.uint8)))

    # paste the original image through the mask onto a transparent canvas
    out = Image.new("RGBA", mask.size, (0, 0, 0, 0))
    out.paste(orig_image, mask=mask)
    return out
|
| 62 |
+
|
| 63 |
+
# block = gr.Blocks().queue()
|
| 64 |
+
|
| 65 |
+
# with block:
|
| 66 |
+
# gr.Markdown("## HBS_V1")
|
| 67 |
+
# gr.HTML('''
|
| 68 |
+
# <p style="margin-bottom: 10px; font-size: 94%">
|
| 69 |
+
# This is a demo for Human Body Segmentation that using
|
| 70 |
+
# YoloV8 image instance model as backbone.
|
| 71 |
+
# </p>
|
| 72 |
+
# ''')
|
| 73 |
+
# with gr.Row():
|
| 74 |
+
# with gr.Column():
|
| 75 |
+
# input_image = gr.Image(sources=None, type="pil") # None for upload, ctrl+v and webcam
|
| 76 |
+
# # input_image = gr.Image(sources=None, type="numpy") # None for upload, ctrl+v and webcam
|
| 77 |
+
# run_button = gr.Button(value="Run")
|
| 78 |
+
|
| 79 |
+
# with gr.Column():
|
| 80 |
+
# result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[1], height='auto')
|
| 81 |
+
# ips = [input_image]
|
| 82 |
+
# run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
|
| 83 |
|
| 84 |
+
# block.launch(debug = True)
|
| 85 |
+
|
| 86 |
+
# block = gr.Blocks().queue()
|
| 87 |
+
|
| 88 |
+
# Gradio UI wiring.
# NOTE(review): gr.Markdown / gr.HTML called at module level, outside any
# gr.Blocks context, presumably do not attach to the Interface below —
# looks like leftovers from the commented-out Blocks UI; confirm they
# render at all.
gr.Markdown("## HBS_V1")
gr.HTML('''
<p style="margin-bottom: 10px; font-size: 94%">
This is a demo for Human Body Segmentation that using
YoloV8 image instance model as backbone.
</p>
''')
title = "Human Body Segmentation"
description = r"""Human Body Segmentation model developed by <a href='https://github.com/WildanJR09' target='_blank'><b>WildanJR</b></a>, Designed to effectively separate foreground from background in a range of categories and image types.<br>
This model has been trained on a carefully selected dataset, which includes: general stock images, e-commerce, gaming, and advertising content, making it suitable for commercial use cases powering enterprise content creation at scale. The accuracy, efficiency, and versatility currently rival leading source-available models. It is ideal where content safety, legally licensed datasets, and bias mitigation are paramount. For test upload your image and wait. </a>.<br>
"""
# Bundled example image shown in the demo gallery.
examples = [['./jisoo.jpg'],]
# output = ImageSlider(position=0.5,label='Image without background', type="pil", show_download_button=True)
# demo = gr.Interface(fn=process,inputs="image", outputs=output, examples=examples, title=title, description=description)
demo = gr.Interface(fn=process,inputs="image", outputs="image", examples=examples, title=title, description=description)

if __name__ == "__main__":
    # share=False: serve locally only, no public Gradio tunnel.
    demo.launch(share=False)
|