Spaces:
Sleeping
Sleeping
File size: 4,524 Bytes
af0f898 61123b8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
from typing import Dict
import gradio as gr
import json
import PIL.Image, PIL.ImageOps
import torch
import torchvision.transforms.functional as F
from matplotlib import cm
from matplotlib.colors import to_hex
import numpy as np
from src.models.dino import DINOSegmentationModel
from src.models.vit import ViTSegmentation
from src.models.unet import UNet
from src.utils import get_transform
# All inference in this app runs on the CPU.
device = torch.device("cpu")

# Checkpoint locations, one per selectable model.
model_weight1 = "weights/dino.pth"
model_weight2 = "weights/vit.pth"
model_weight3 = "weights/unet.pth"


def _restore_weights(module, checkpoint_path):
    """Load the state dict stored at *checkpoint_path* into *module*."""
    module.load_state_dict(torch.load(checkpoint_path, map_location=device))


# DINO: only the segmentation head is restored from the checkpoint.
model1 = DINOSegmentationModel()
_restore_weights(model1.segmentation_head, model_weight1)
model1.eval()

# ViT: likewise, only the segmentation head is restored.
model2 = ViTSegmentation()
_restore_weights(model2.segmentation_head, model_weight2)
model2.eval()

# UNet: the checkpoint is loaded into the whole network.
model3 = UNet()
_restore_weights(model3, model_weight3)
model3.eval()
# Class id (as a string) -> human-readable label shown in the legend.
mask_labels = {
    "0": "Background", "1": "Hat", "2": "Hair", "3": "Sunglasses", "4": "Upper-clothes",
    "5": "Skirt", "6": "Pants", "7": "Dress", "8": "Belt", "9": "Right-shoe",
    "10": "Left-shoe", "11": "Face", "12": "Right-leg", "13": "Left-leg",
    "14": "Right-arm", "15": "Left-arm", "16": "Bag", "17": "Scarf"
}

# 'tab20' resampled to one entry per class.
try:
    color_map = cm.get_cmap("tab20", len(mask_labels))
except AttributeError:
    # cm.get_cmap is deprecated in recent Matplotlib and absent from the
    # newest releases; the colormap registry is the documented replacement.
    import matplotlib

    color_map = matplotlib.colormaps["tab20"].resampled(len(mask_labels))

# One RGB triple (floats in 0..1) per class, looked up by INTEGER index so the
# legend and the rendered mask are guaranteed to use the same palette entry.
# (A float lookup such as idx / n depends on idx / n * n landing exactly on idx.)
_class_palette = [color_map(idx)[:3] for idx in range(len(mask_labels))]

# Label key -> "#rrggbb" string used by the HTML legend.
label_colors = {
    label: to_hex(rgb) for label, rgb in zip(mask_labels, _class_palette)
}

# (num_classes, 3) array of 0..255 channel values used to paint masks.
# Rounded rather than left as raw products: to_hex rounds, whereas a later
# cast to uint8 truncates, and a product can land just below its integer.
fixed_colors = np.rint(np.array(_class_palette) * 255)
def mask_to_color(mask: np.ndarray) -> np.ndarray:
    """Render a 2-D class-index mask as an RGB image.

    Args:
        mask: Two-dimensional array of class indices.

    Returns:
        A ``(height, width, 3)`` ``uint8`` array in which every pixel whose
        class index is in ``0..17`` takes that class's row of the
        module-level ``fixed_colors`` table. Pixels holding any other value
        stay black.
    """
    rows, cols = mask.shape
    painted = np.zeros((rows, cols, 3), dtype=np.uint8)
    for class_idx in range(18):
        selected = mask == class_idx
        # Assigning into a uint8 array casts the palette row to uint8.
        painted[selected] = fixed_colors[class_idx]
    return painted
def segment_image(image, model_name: str) -> PIL.Image.Image:
    """Segment *image* with the selected model and return a colour mask.

    Args:
        image: PIL image supplied by the Gradio image input, or ``None`` when
            nothing has been uploaded.
        model_name: ``"DINO"`` or ``"ViT"``; any other value selects the UNet.

    Returns:
        An RGB PIL image with the same size as *image*. The colourised mask
        is scaled to the input height, keeping the mask's own aspect ratio,
        and pasted horizontally centred on a black canvas.

    Raises:
        gr.Error: If no image was provided.
    """
    # Predict can be clicked with an empty input (e.g. right after Clear);
    # show a readable message instead of failing on ``None.size``.
    if image is None:
        raise gr.Error("Please upload an image before running prediction.")

    if model_name == "DINO":
        model = model1
    elif model_name == "ViT":
        model = model2
    else:
        model = model3

    original_width, original_height = image.size

    # Preprocessing is built from the chosen model's own mean/std attributes.
    transform = get_transform(model.mean, model.std)
    input_tensor = transform(image).unsqueeze(0)  # add a leading batch axis

    with torch.no_grad():
        mask = model(input_tensor)

    # NOTE(review): assumes the squeezed output is (classes, H, W) so that
    # argmax over dim 0 gives a per-pixel class id -- confirm against models.
    mask = torch.argmax(mask.squeeze(), dim=0).cpu().numpy()

    mask_image = mask_to_color(mask)
    mask_image = PIL.Image.fromarray(mask_image)

    # Scale to the input height; NEAREST keeps class colours unblended.
    mask_aspect_ratio = mask_image.width / mask_image.height
    new_height = original_height
    new_width = int(new_height * mask_aspect_ratio)
    mask_image = mask_image.resize((new_width, new_height), PIL.Image.Resampling.NEAREST)

    # Centre horizontally on a canvas matching the input size.
    final_mask = PIL.Image.new("RGB", (original_width, original_height))
    offset = ((original_width - new_width) // 2, 0)
    final_mask.paste(mask_image, offset)

    return final_mask
def generate_legend_html_compact() -> str:
    """Build the HTML legend: one coloured chip per segmentation class.

    Iterates the module-level ``label_colors`` mapping, using each colour as
    the chip's border and background and the matching ``mask_labels`` entry
    as its text.

    Returns:
        An HTML string: a wrapping flex container holding one ``<div>`` chip
        per class.
    """
    container_open = (
        "\n"
        "<div style='display: flex; flex-wrap: wrap; gap: 10px; justify-content: center;'>\n"
    )
    chips = [
        "\n"
        "<div style='display: flex; align-items: center; justify-content: center;\n"
        f"padding: 5px 10px; border: 1px solid {swatch};\n"
        f"background-color: {swatch}; border-radius: 5px;\n"
        "color: white; font-size: 12px; text-align: center;'>\n"
        f"{mask_labels[class_key]}\n"
        "</div>\n"
        for class_key, swatch in label_colors.items()
    ]
    return container_open + "".join(chips) + "</div>"
# Bundled sample inputs for gr.Examples: one single-element row per image,
# since the examples widget feeds exactly one input component.
examples = [
    [f"assets/images_examples/image{number}.jpg"] for number in range(1, 4)
]
# UI layout. NOTE(review): the nesting below is reconstructed from the order
# of the `with` statements -- confirm it matches the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("## Clothes Segmentation")
    with gr.Row():
        # Left column: image input, model selector and action buttons.
        with gr.Column():
            pic = gr.Image(label="Upload Human Image", type="pil", height=300, width=300)
            model_choice = gr.Dropdown(choices=["DINO", "ViT", "UNet"], label="Select Model", value="DINO")
            with gr.Row():
                with gr.Column(scale=1):
                    predict_btn = gr.Button("Predict")
                with gr.Column(scale=1):
                    clear_btn = gr.Button("Clear")

        # Right column: predicted mask and the class colour legend.
        with gr.Column():
            output = gr.Image(label="Mask", type="pil", height=300, width=300)
            legend = gr.HTML(label="Legend", value=generate_legend_html_compact())

    # Predict runs segmentation on the current image with the chosen model.
    predict_btn.click(fn=segment_image, inputs=[pic, model_choice], outputs=output)
    # Clear resets both the input image and the mask.
    clear_btn.click(lambda: (None, None), outputs=[pic, output])

    gr.Examples(examples=examples, inputs=[pic])

demo.launch()