Daimler commited on
Commit
af0f898
·
verified ·
1 Parent(s): 35a854c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -127
app.py CHANGED
@@ -1,128 +1,130 @@
1
- from typing import Dict
2
- import gradio as gr
3
- import json
4
- import PIL.Image, PIL.ImageOps
5
- import torch
6
- import torchvision.transforms.functional as F
7
- from matplotlib import cm
8
- from matplotlib.colors import to_hex
9
- import numpy as np
10
-
11
- from src.models.dino import DINOSegmentationModel
12
- from src.models.vit import ViTSegmentation
13
- from src.models.unet import UNet
14
- from src.utils import get_transform
15
-
16
-
17
- device = torch.device("cpu")
18
- model_weight1 = "weights/dino.pth"
19
- model_weight2 = "weights/vit.pth"
20
- model_weight3 = "weights/unet.pth"
21
-
22
- model1 = DINOSegmentationModel()
23
- model1.segmentation_head.load_state_dict(torch.load(model_weight1, map_location=device))
24
- model1.eval()
25
- model2 = ViTSegmentation()
26
- model2.segmentation_head.load_state_dict(torch.load(model_weight2, map_location=device))
27
- model2.eval()
28
- model3 = UNet()
29
- model3.load_state_dict(torch.load(model_weight3, map_location=device))
30
- model3.eval()
31
-
32
- mask_labels = {
33
- "0": "Background", "1": "Hat", "2": "Hair", "3": "Sunglasses", "4": "Upper-clothes",
34
- "5": "Skirt", "6": "Pants", "7": "Dress", "8": "Belt", "9": "Right-shoe",
35
- "10": "Left-shoe", "11": "Face", "12": "Right-leg", "13": "Left-leg",
36
- "14": "Right-arm", "15": "Left-arm", "16": "Bag", "17": "Scarf"
37
- }
38
-
39
- color_map = cm.get_cmap('tab20', 18)
40
- label_colors = {label: to_hex(color_map(idx / len(mask_labels))[:3]) for idx, label in enumerate(mask_labels)}
41
- fixed_colors = np.array([color_map(i)[:3] for i in range(18)]) * 255
42
-
43
-
44
- def mask_to_color(mask: np.ndarray) -> np.ndarray:
45
- h, w = mask.shape
46
- color_mask = np.zeros((h, w, 3), dtype=np.uint8)
47
- for class_idx in range(18):
48
- color_mask[mask == class_idx] = fixed_colors[class_idx]
49
- return color_mask
50
-
51
-
52
- def segment_image(image, model_name: str) -> PIL.Image:
53
- if model_name == "DINO":
54
- model = model1
55
- elif model_name == "ViT":
56
- model = model2
57
- else:
58
- model = model3
59
-
60
- original_width, original_height = image.size
61
- transform = get_transform(model.mean, model.std)
62
- input_tensor = transform(image).unsqueeze(0)
63
-
64
- with torch.no_grad():
65
- mask = model(input_tensor)
66
- mask = torch.argmax(mask.squeeze(), dim=0).cpu().numpy()
67
-
68
- mask_image = mask_to_color(mask)
69
-
70
- mask_image = PIL.Image.fromarray(mask_image)
71
- mask_aspect_ratio = mask_image.width / mask_image.height
72
-
73
- new_height = original_height
74
- new_width = int(new_height * mask_aspect_ratio)
75
- mask_image = mask_image.resize((new_width, new_height), PIL.Image.Resampling.NEAREST)
76
-
77
- final_mask = PIL.Image.new("RGB", (original_width, original_height))
78
- offset = ((original_width - new_width) // 2, 0)
79
- final_mask.paste(mask_image, offset)
80
-
81
- return final_mask
82
-
83
-
84
- def generate_legend_html_compact() -> str:
85
- legend_html = """
86
- <div style='display: flex; flex-wrap: wrap; gap: 10px; justify-content: center;'>
87
- """
88
- for idx, (label, color) in enumerate(label_colors.items()):
89
- legend_html += f"""
90
- <div style='display: flex; align-items: center; justify-content: center;
91
- padding: 5px 10px; border: 1px solid {color};
92
- background-color: {color}; border-radius: 5px;
93
- color: white; font-size: 12px; text-align: center;'>
94
- {mask_labels[label]}
95
- </div>
96
- """
97
- legend_html += "</div>"
98
- return legend_html
99
-
100
-
101
- examples = [
102
- ["assets/images_examples/image1.jpg"],
103
- ["assets/images_examples/image2.jpg"],
104
- ["assets/images_examples/image3.jpg"]
105
- ]
106
-
107
-
108
- with gr.Blocks() as demo:
109
- gr.Markdown("## Clothes Segmentation")
110
- with gr.Row():
111
- with gr.Column():
112
- pic = gr.Image(label="Upload Human Image", type="pil", height=300, width=300)
113
- model_choice = gr.Dropdown(choices=["DINO", "ViT", "UNet"], label="Select Model", value="DINO")
114
- with gr.Row():
115
- with gr.Column(scale=1):
116
- predict_btn = gr.Button("Predict")
117
- with gr.Column(scale=1):
118
- clear_btn = gr.Button("Clear")
119
-
120
- with gr.Column():
121
- output = gr.Image(label="Mask", type="pil", height=300, width=300)
122
- legend = gr.HTML(label="Legend", value=generate_legend_html_compact())
123
-
124
- predict_btn.click(fn=segment_image, inputs=[pic, model_choice], outputs=output, api_name="predict")
125
- clear_btn.click(lambda: (None, None), outputs=[pic, output])
126
- gr.Examples(examples=examples, inputs=[pic])
127
-
 
 
128
  demo.launch()
 
1
+ from typing import Dict
2
+ import gradio as gr
3
+ import json
4
+ import PIL.Image, PIL.ImageOps
5
+ import torch
6
+ import torchvision.transforms.functional as F
7
+ from matplotlib import cm
8
+ from matplotlib.colors import to_hex
9
+ import numpy as np
10
+
11
+ from src.models.dino import DINOSegmentationModel
12
+ from src.models.vit import ViTSegmentation
13
+ from src.models.unet import UNet
14
+ from src.utils import get_transform
15
+
16
+
17
+ device = torch.device("cpu")
18
+ model_weight1 = "weights/dino.pth"
19
+ model_weight2 = "weights/vit.pth"
20
+ model_weight3 = "weights/unet.pth"
21
+
22
+ model1 = DINOSegmentationModel()
23
+ model1.segmentation_head.load_state_dict(torch.load(model_weight1, map_location=device))
24
+ model1.eval()
25
+ model2 = ViTSegmentation()
26
+ model2.segmentation_head.load_state_dict(torch.load(model_weight2, map_location=device))
27
+ model2.eval()
28
+ model3 = UNet()
29
+ model3.load_state_dict(torch.load(model_weight3, map_location=device))
30
+ model3.eval()
31
+
32
+ mask_labels = {
33
+ "0": "Background", "1": "Hat", "2": "Hair", "3": "Sunglasses", "4": "Upper-clothes",
34
+ "5": "Skirt", "6": "Pants", "7": "Dress", "8": "Belt", "9": "Right-shoe",
35
+ "10": "Left-shoe", "11": "Face", "12": "Right-leg", "13": "Left-leg",
36
+ "14": "Right-arm", "15": "Left-arm", "16": "Bag", "17": "Scarf"
37
+ }
38
+
39
+ color_map = cm.get_cmap('tab20', 18)
40
+ label_colors = {label: to_hex(color_map(idx / len(mask_labels))[:3]) for idx, label in enumerate(mask_labels)}
41
+ fixed_colors = np.array([color_map(i)[:3] for i in range(18)]) * 255
42
+
43
+
44
+ def mask_to_color(mask: np.ndarray) -> np.ndarray:
45
+ h, w = mask.shape
46
+ color_mask = np.zeros((h, w, 3), dtype=np.uint8)
47
+ for class_idx in range(18):
48
+ color_mask[mask == class_idx] = fixed_colors[class_idx]
49
+ return color_mask
50
+
51
+
52
+ def segment_image(image, model_name: str) -> PIL.Image:
53
+ if model_name == "DINO":
54
+ model = model1
55
+ elif model_name == "ViT":
56
+ model = model2
57
+ else:
58
+ model = model3
59
+
60
+ original_width, original_height = image.size
61
+ transform = get_transform(model.mean, model.std)
62
+ input_tensor = transform(image).unsqueeze(0)
63
+
64
+ with torch.no_grad():
65
+ mask = model(input_tensor)
66
+ mask = torch.argmax(mask.squeeze(), dim=0).cpu().numpy()
67
+
68
+ mask_image = mask_to_color(mask)
69
+
70
+ mask_image = PIL.Image.fromarray(mask_image)
71
+ mask_aspect_ratio = mask_image.width / mask_image.height
72
+
73
+ new_height = original_height
74
+ new_width = int(new_height * mask_aspect_ratio)
75
+ mask_image = mask_image.resize((new_width, new_height), PIL.Image.Resampling.NEAREST)
76
+
77
+ final_mask = PIL.Image.new("RGB", (original_width, original_height))
78
+ offset = ((original_width - new_width) // 2, 0)
79
+ final_mask.paste(mask_image, offset)
80
+
81
+ return final_mask
82
+
83
+
84
+ def generate_legend_html_compact() -> str:
85
+ legend_html = """
86
+ <div style='display: flex; flex-wrap: wrap; gap: 10px; justify-content: center;'>
87
+ """
88
+ for idx, (label, color) in enumerate(label_colors.items()):
89
+ legend_html += f"""
90
+ <div style='display: flex; align-items: center; justify-content: center;
91
+ padding: 5px 10px; border: 1px solid {color};
92
+ background-color: {color}; border-radius: 5px;
93
+ color: white; font-size: 12px; text-align: center;'>
94
+ {mask_labels[label]}
95
+ </div>
96
+ """
97
+ legend_html += "</div>"
98
+ return legend_html
99
+
100
+
101
+ examples = [
102
+ ["assets/images_examples/image1.jpg"],
103
+ ["assets/images_examples/image2.jpg"],
104
+ ["assets/images_examples/image3.jpg"]
105
+ ]
106
+
107
+
108
+ with gr.Blocks() as demo:
109
+ gr.Markdown("## Clothes Segmentation")
110
+ with gr.Row():
111
+ with gr.Column():
112
+ pic = gr.Image(label="Upload Human Image", type="pil", height=300, width=300)
113
+ model_choice = gr.Dropdown(choices=["DINO", "ViT", "UNet"], label="Select Model", value="DINO")
114
+ with gr.Row():
115
+ with gr.Column(scale=1):
116
+ predict_btn = gr.Button("Predict")
117
+ with gr.Column(scale=1):
118
+ clear_btn = gr.Button("Clear")
119
+
120
+ with gr.Column():
121
+ output = gr.Image(label="Mask", type="pil", height=300, width=300)
122
+ legend = gr.HTML(label="Legend", value=generate_legend_html_compact())
123
+
124
+ #predict_btn.click(fn=segment_image, inputs=[pic, model_choice], outputs=output, api_name="predict")
125
+
126
+ predict_btn.click(fn=segment_image, inputs=[pic, model_choice], outputs=output)
127
+ clear_btn.click(lambda: (None, None), outputs=[pic, output])
128
+ gr.Examples(examples=examples, inputs=[pic])
129
+
130
  demo.launch()