PawanratRung commited on
Commit
8bfc38f
·
verified ·
1 Parent(s): 8091dc1

Create __init__.py

Browse files
Files changed (1) hide show
  1. 3rdparty/SCHP/__init__.py +242 -0
3rdparty/SCHP/__init__.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from PIL import Image
7
+ from SCHP import networks
8
+ from SCHP.utils.transforms import get_affine_transform, transform_logits
9
+ from torchvision import transforms
10
+
11
+
12
+ def get_palette(num_cls):
13
+ """Returns the color map for visualizing the segmentation mask.
14
+ Args:
15
+ num_cls: Number of classes
16
+ Returns:
17
+ The color map
18
+ """
19
+ n = num_cls
20
+ palette = [0] * (n * 3)
21
+ for j in range(0, n):
22
+ lab = j
23
+ palette[j * 3 + 0] = 0
24
+ palette[j * 3 + 1] = 0
25
+ palette[j * 3 + 2] = 0
26
+ i = 0
27
+ while lab:
28
+ palette[j * 3 + 0] |= ((lab >> 0) & 1) << (7 - i)
29
+ palette[j * 3 + 1] |= ((lab >> 1) & 1) << (7 - i)
30
+ palette[j * 3 + 2] |= ((lab >> 2) & 1) << (7 - i)
31
+ i += 1
32
+ lab >>= 3
33
+ return palette
34
+
35
+
36
+ dataset_settings = {
37
+ "lip": {
38
+ "input_size": [473, 473],
39
+ "num_classes": 20,
40
+ "label": [
41
+ "Background",
42
+ "Hat",
43
+ "Hair",
44
+ "Glove",
45
+ "Sunglasses",
46
+ "Upper-clothes",
47
+ "Dress",
48
+ "Coat",
49
+ "Socks",
50
+ "Pants",
51
+ "Jumpsuits",
52
+ "Scarf",
53
+ "Skirt",
54
+ "Face",
55
+ "Left-arm",
56
+ "Right-arm",
57
+ "Left-leg",
58
+ "Right-leg",
59
+ "Left-shoe",
60
+ "Right-shoe",
61
+ ],
62
+ },
63
+ "atr": {
64
+ "input_size": [512, 512],
65
+ "num_classes": 18,
66
+ "label": [
67
+ "Background",
68
+ "Hat",
69
+ "Hair",
70
+ "Sunglasses",
71
+ "Upper-clothes",
72
+ "Skirt",
73
+ "Pants",
74
+ "Dress",
75
+ "Belt",
76
+ "Left-shoe",
77
+ "Right-shoe",
78
+ "Face",
79
+ "Left-leg",
80
+ "Right-leg",
81
+ "Left-arm",
82
+ "Right-arm",
83
+ "Bag",
84
+ "Scarf",
85
+ ],
86
+ },
87
+ "pascal": {
88
+ "input_size": [512, 512],
89
+ "num_classes": 7,
90
+ "label": [
91
+ "Background",
92
+ "Head",
93
+ "Torso",
94
+ "Upper Arms",
95
+ "Lower Arms",
96
+ "Upper Legs",
97
+ "Lower Legs",
98
+ ],
99
+ },
100
+ }
101
+
102
+
103
+ class SCHP:
104
+ def __init__(self, ckpt_path, device):
105
+ dataset_type = None
106
+ if "lip" in ckpt_path:
107
+ dataset_type = "lip"
108
+ elif "atr" in ckpt_path:
109
+ dataset_type = "atr"
110
+ elif "pascal" in ckpt_path:
111
+ dataset_type = "pascal"
112
+ assert dataset_type is not None, "Dataset type not found in checkpoint path"
113
+ self.device = device
114
+ self.num_classes = dataset_settings[dataset_type]["num_classes"]
115
+ self.input_size = dataset_settings[dataset_type]["input_size"]
116
+ self.aspect_ratio = self.input_size[1] * 1.0 / self.input_size[0]
117
+ self.palette = get_palette(self.num_classes)
118
+
119
+ self.label = dataset_settings[dataset_type]["label"]
120
+ self.model = networks.init_model(
121
+ "resnet101", num_classes=self.num_classes, pretrained=None
122
+ ).to(device)
123
+ self.load_ckpt(ckpt_path)
124
+ self.model.eval()
125
+
126
+ self.transform = transforms.Compose(
127
+ [
128
+ transforms.ToTensor(),
129
+ transforms.Normalize(
130
+ mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229]
131
+ ),
132
+ ]
133
+ )
134
+ self.upsample = torch.nn.Upsample(
135
+ size=self.input_size, mode="bilinear", align_corners=True
136
+ )
137
+
138
+ def load_ckpt(self, ckpt_path):
139
+ rename_map = {
140
+ "decoder.conv3.2.weight": "decoder.conv3.3.weight",
141
+ "decoder.conv3.3.weight": "decoder.conv3.4.weight",
142
+ "decoder.conv3.3.bias": "decoder.conv3.4.bias",
143
+ "decoder.conv3.3.running_mean": "decoder.conv3.4.running_mean",
144
+ "decoder.conv3.3.running_var": "decoder.conv3.4.running_var",
145
+ "fushion.3.weight": "fushion.4.weight",
146
+ "fushion.3.bias": "fushion.4.bias",
147
+ }
148
+ state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
149
+ new_state_dict = OrderedDict()
150
+ for k, v in state_dict.items():
151
+ name = k[7:] # remove `module.`
152
+ new_state_dict[name] = v
153
+ new_state_dict_ = OrderedDict()
154
+ for k, v in list(new_state_dict.items()):
155
+ if k in rename_map:
156
+ new_state_dict_[rename_map[k]] = v
157
+ else:
158
+ new_state_dict_[k] = v
159
+ self.model.load_state_dict(new_state_dict_, strict=False)
160
+
161
+ def _box2cs(self, box):
162
+ x, y, w, h = box[:4]
163
+ return self._xywh2cs(x, y, w, h)
164
+
165
+ def _xywh2cs(self, x, y, w, h):
166
+ center = np.zeros((2), dtype=np.float32)
167
+ center[0] = x + w * 0.5
168
+ center[1] = y + h * 0.5
169
+ if w > self.aspect_ratio * h:
170
+ h = w * 1.0 / self.aspect_ratio
171
+ elif w < self.aspect_ratio * h:
172
+ w = h * self.aspect_ratio
173
+ scale = np.array([w, h], dtype=np.float32)
174
+ return center, scale
175
+
176
+ def preprocess(self, image):
177
+ if isinstance(image, str):
178
+ img = cv2.imread(image, cv2.IMREAD_COLOR)
179
+ elif isinstance(image, Image.Image):
180
+ # to cv2 format
181
+ img = np.array(image)
182
+
183
+ h, w, _ = img.shape
184
+ # Get person center and scale
185
+ person_center, s = self._box2cs([0, 0, w - 1, h - 1])
186
+ r = 0
187
+ trans = get_affine_transform(person_center, s, r, self.input_size)
188
+ input = cv2.warpAffine(
189
+ img,
190
+ trans,
191
+ (int(self.input_size[1]), int(self.input_size[0])),
192
+ flags=cv2.INTER_LINEAR,
193
+ borderMode=cv2.BORDER_CONSTANT,
194
+ borderValue=(0, 0, 0),
195
+ )
196
+
197
+ input = self.transform(input).to(self.device).unsqueeze(0)
198
+ meta = {
199
+ "center": person_center,
200
+ "height": h,
201
+ "width": w,
202
+ "scale": s,
203
+ "rotation": r,
204
+ }
205
+ return input, meta
206
+
207
+ def __call__(self, image_or_path):
208
+ if isinstance(image_or_path, list):
209
+ image_list = []
210
+ meta_list = []
211
+ for image in image_or_path:
212
+ image, meta = self.preprocess(image)
213
+ image_list.append(image)
214
+ meta_list.append(meta)
215
+ image = torch.cat(image_list, dim=0)
216
+ else:
217
+ image, meta = self.preprocess(image_or_path)
218
+ meta_list = [meta]
219
+
220
+ output = self.model(image)
221
+ # upsample_outputs = self.upsample(output[0][-1])
222
+ upsample_outputs = self.upsample(output)
223
+ upsample_outputs = upsample_outputs.permute(0, 2, 3, 1) # BCHW -> BHWC
224
+
225
+ output_img_list = []
226
+ for upsample_output, meta in zip(upsample_outputs, meta_list):
227
+ c, s, w, h = meta["center"], meta["scale"], meta["width"], meta["height"]
228
+ logits_result = transform_logits(
229
+ upsample_output.data.cpu().numpy(),
230
+ c,
231
+ s,
232
+ w,
233
+ h,
234
+ input_size=self.input_size,
235
+ )
236
+ parsing_result = np.argmax(logits_result, axis=2)
237
+ output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
238
+ output_img.putpalette(self.palette)
239
+ output_img_list.append(output_img)
240
+
241
+ return output_img_list[0] if len(output_img_list) == 1 else output_img_list
242
+