b3h-young123 commited on
Commit
4d80aec
·
verified ·
1 Parent(s): 5c8f6e7

Add files using upload-large-folder tool

Browse files
Leffa/3rdparty/SCHP/__init__.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from PIL import Image
7
+ from SCHP import networks
8
+ from SCHP.utils.transforms import get_affine_transform, transform_logits
9
+ from torchvision import transforms
10
+
11
+
12
def get_palette(num_cls):
    """Return the color map for visualizing the segmentation mask.

    Each class index is turned into an RGB triple by spreading its bits
    (taken three at a time) across the R/G/B channels from the most
    significant bit downward — the standard PASCAL VOC palette scheme.

    Args:
        num_cls: Number of classes.

    Returns:
        A flat list of length ``num_cls * 3`` with R, G, B bytes per class.
    """
    palette = [0] * (num_cls * 3)
    for cls_idx in range(num_cls):
        value = cls_idx
        shift = 7
        while value:
            # Bit k of the current 3-bit group feeds channel k.
            for channel in range(3):
                palette[cls_idx * 3 + channel] |= ((value >> channel) & 1) << shift
            shift -= 1
            value >>= 3
    return palette
34
+
35
+
36
# Per-dataset configuration for the pretrained SCHP checkpoints.
# `input_size` is (height, width) fed to the network, and `label[i]` is the
# human-readable name of parsing class index i.
dataset_settings = {
    "lip": {
        "input_size": [473, 473],
        "num_classes": 20,
        "label": [
            "Background",
            "Hat",
            "Hair",
            "Glove",
            "Sunglasses",
            "Upper-clothes",
            "Dress",
            "Coat",
            "Socks",
            "Pants",
            "Jumpsuits",
            "Scarf",
            "Skirt",
            "Face",
            "Left-arm",
            "Right-arm",
            "Left-leg",
            "Right-leg",
            "Left-shoe",
            "Right-shoe",
        ],
    },
    "atr": {
        "input_size": [512, 512],
        "num_classes": 18,
        "label": [
            "Background",
            "Hat",
            "Hair",
            "Sunglasses",
            "Upper-clothes",
            "Skirt",
            "Pants",
            "Dress",
            "Belt",
            "Left-shoe",
            "Right-shoe",
            "Face",
            "Left-leg",
            "Right-leg",
            "Left-arm",
            "Right-arm",
            "Bag",
            "Scarf",
        ],
    },
    "pascal": {
        "input_size": [512, 512],
        "num_classes": 7,
        "label": [
            "Background",
            "Head",
            "Torso",
            "Upper Arms",
            "Lower Arms",
            "Upper Legs",
            "Lower Legs",
        ],
    },
}
101
+
102
+
103
class SCHP:
    """Self-Correction Human Parsing inference wrapper.

    Loads a pretrained CE2P/SCHP checkpoint (dataset inferred from the
    checkpoint path: "lip", "atr", or "pascal") and runs human-parsing
    inference, returning palettized PIL label images.
    """

    def __init__(self, ckpt_path, device):
        """Build the model and load weights.

        Args:
            ckpt_path: Path to the checkpoint; must contain "lip", "atr",
                or "pascal" so the dataset settings can be selected.
            device: torch device (or device string) to run inference on.
        """
        dataset_type = None
        if "lip" in ckpt_path:
            dataset_type = "lip"
        elif "atr" in ckpt_path:
            dataset_type = "atr"
        elif "pascal" in ckpt_path:
            dataset_type = "pascal"
        assert dataset_type is not None, "Dataset type not found in checkpoint path"
        self.device = device
        self.num_classes = dataset_settings[dataset_type]["num_classes"]
        self.input_size = dataset_settings[dataset_type]["input_size"]
        # width / height of the network input; used to letterbox source images.
        self.aspect_ratio = self.input_size[1] * 1.0 / self.input_size[0]
        self.palette = get_palette(self.num_classes)

        self.label = dataset_settings[dataset_type]["label"]
        self.model = networks.init_model(
            "resnet101", num_classes=self.num_classes, pretrained=None
        ).to(device)
        self.load_ckpt(ckpt_path)
        self.model.eval()

        # NOTE(review): mean/std are listed in B, G, R order, which matches the
        # BGR images produced by cv2.imread in preprocess() — confirm the PIL
        # branch of preprocess() is also expected to supply BGR.
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229]
                ),
            ]
        )
        self.upsample = torch.nn.Upsample(
            size=self.input_size, mode="bilinear", align_corners=True
        )

    def load_ckpt(self, ckpt_path):
        """Load checkpoint weights, stripping DataParallel prefixes.

        Keys are renamed via `rename_map` to account for an extra layer
        inserted into `decoder.conv3` / `fushion` relative to the checkpoint;
        loading is non-strict so unmatched keys are tolerated.
        """
        # Maps old checkpoint key -> key expected by the current architecture.
        # Applied in a single pass over the original keys, so the overlapping
        # "decoder.conv3.3.*" entries do not cascade.
        rename_map = {
            "decoder.conv3.2.weight": "decoder.conv3.3.weight",
            "decoder.conv3.3.weight": "decoder.conv3.4.weight",
            "decoder.conv3.3.bias": "decoder.conv3.4.bias",
            "decoder.conv3.3.running_mean": "decoder.conv3.4.running_mean",
            "decoder.conv3.3.running_var": "decoder.conv3.4.running_var",
            "fushion.3.weight": "fushion.4.weight",
            "fushion.3.bias": "fushion.4.bias",
        }
        state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]  # remove `module.`
            new_state_dict[name] = v
        new_state_dict_ = OrderedDict()
        for k, v in list(new_state_dict.items()):
            if k in rename_map:
                new_state_dict_[rename_map[k]] = v
            else:
                new_state_dict_[k] = v
        self.model.load_state_dict(new_state_dict_, strict=False)

    def _box2cs(self, box):
        """Convert an (x, y, w, h) box to (center, scale) for the affine crop."""
        x, y, w, h = box[:4]
        return self._xywh2cs(x, y, w, h)

    def _xywh2cs(self, x, y, w, h):
        """Return (center, scale), padding w or h to match the network aspect ratio."""
        center = np.zeros((2), dtype=np.float32)
        center[0] = x + w * 0.5
        center[1] = y + h * 0.5
        if w > self.aspect_ratio * h:
            h = w * 1.0 / self.aspect_ratio
        elif w < self.aspect_ratio * h:
            w = h * self.aspect_ratio
        scale = np.array([w, h], dtype=np.float32)
        return center, scale

    def preprocess(self, image):
        """Warp one image to the network input size.

        Args:
            image: a file path (read via cv2, BGR) or a PIL.Image.

        Returns:
            (input_tensor, meta): 1xCxHxW normalized tensor on self.device and
            the geometry metadata needed to map logits back to source coords.
        """
        if isinstance(image, str):
            img = cv2.imread(image, cv2.IMREAD_COLOR)
        elif isinstance(image, Image.Image):
            # to cv2 format
            # NOTE(review): np.array(PIL) yields RGB, while the path branch
            # yields BGR — no channel swap is done here; confirm intended.
            # Any other input type leaves `img` unbound (NameError below).
            img = np.array(image)

        h, w, _ = img.shape
        # Get person center and scale
        person_center, s = self._box2cs([0, 0, w - 1, h - 1])
        r = 0
        trans = get_affine_transform(person_center, s, r, self.input_size)
        input = cv2.warpAffine(
            img,
            trans,
            (int(self.input_size[1]), int(self.input_size[0])),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(0, 0, 0),
        )

        input = self.transform(input).to(self.device).unsqueeze(0)
        meta = {
            "center": person_center,
            "height": h,
            "width": w,
            "scale": s,
            "rotation": r,
        }
        return input, meta

    def __call__(self, image_or_path):
        """Run parsing on one image or a list of images.

        Returns a palettized PIL image of class indices per input (a single
        image when one input was given, else a list).

        NOTE(review): inference is not wrapped in torch.no_grad(), so autograd
        state is tracked unnecessarily — consider adding it.
        """
        if isinstance(image_or_path, list):
            image_list = []
            meta_list = []
            for image in image_or_path:
                image, meta = self.preprocess(image)
                image_list.append(image)
                meta_list.append(meta)
            image = torch.cat(image_list, dim=0)
        else:
            image, meta = self.preprocess(image_or_path)
            meta_list = [meta]

        output = self.model(image)
        # upsample_outputs = self.upsample(output[0][-1])
        upsample_outputs = self.upsample(output)
        upsample_outputs = upsample_outputs.permute(0, 2, 3, 1)  # BCHW -> BHWC

        output_img_list = []
        for upsample_output, meta in zip(upsample_outputs, meta_list):
            c, s, w, h = meta["center"], meta["scale"], meta["width"], meta["height"]
            # Map logits from network-input space back to source-image space.
            logits_result = transform_logits(
                upsample_output.data.cpu().numpy(),
                c,
                s,
                w,
                h,
                input_size=self.input_size,
            )
            parsing_result = np.argmax(logits_result, axis=2)
            output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
            output_img.putpalette(self.palette)
            output_img_list.append(output_img)

        return output_img_list[0] if len(output_img_list) == 1 else output_img_list
Leffa/3rdparty/SCHP/networks/AugmentCE2P.py ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ """
5
+ @Author : Peike Li
6
+ @Contact : peike.li@yahoo.com
7
+ @File : AugmentCE2P.py
8
+ @Time : 8/4/19 3:35 PM
9
+ @Desc :
10
+ @License : This source code is licensed under the license found in the
11
+ LICENSE file in the root directory of this source tree.
12
+ """
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+ from torch.nn import BatchNorm2d, functional as F, LeakyReLU
18
+
19
# Whether downsample BatchNorm layers use learnable affine parameters.
affine_par = True
# Preprocessing metadata for the ImageNet resnet101 backbone checkpoint.
# NOTE: input_space is BGR, so mean/std are listed in [B, G, R] order.
pretrained_settings = {
    "resnet101": {
        "imagenet": {
            "input_space": "BGR",
            "input_size": [3, 224, 224],
            "input_range": [0, 1],
            "mean": [0.406, 0.456, 0.485],
            "std": [0.225, 0.224, 0.229],
            "num_classes": 1000,
        }
    },
}
32
+
33
+
34
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1 (shape-preserving at stride 1)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
39
+
40
+
41
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus residual add.

    The 3x3 conv's padding and dilation are both ``dilation * multi_grid`` so the
    block supports atrous (dilated) stages.
    """

    # Channel expansion factor: the block outputs `planes * 4` channels.
    expansion = 4

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        dilation=1,
        downsample=None,
        fist_dilation=1,
        multi_grid=1,
    ):
        # NOTE(review): `fist_dilation` (sic) is accepted but never used here;
        # kept for signature compatibility with callers.
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=dilation * multi_grid,
            dilation=dilation * multi_grid,
            bias=False,
        )
        self.bn2 = BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(planes * 4)
        # Out-of-place ReLU for intermediate activations; in-place ReLU is used
        # only after the residual addition, where the input is not reused.
        self.relu = nn.ReLU(inplace=False)
        self.relu_inplace = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.dilation = dilation
        self.stride = stride

    def forward(self, x):
        """Apply the bottleneck; `downsample` (if set) projects the residual path."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out = out + residual
        out = self.relu_inplace(out)

        return out
96
+
97
+
98
class PSPModule(nn.Module):
    """
    Pyramid Scene Parsing context module.

    Reference:
        Zhao, Hengshuang, et al. *"Pyramid scene parsing network."*
    """

    def __init__(self, features, out_features=512, sizes=(1, 2, 3, 6)):
        super(PSPModule, self).__init__()

        self.stages = []
        # One stage per pyramid size: pool to (size x size), then project to
        # `out_features` channels.
        self.stages = nn.ModuleList(
            [self._make_stage(features, out_features, size) for size in sizes]
        )
        # Fuses the original features with all upsampled pyramid outputs.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(
                features + len(sizes) * out_features,
                out_features,
                kernel_size=3,
                padding=1,
                dilation=1,
                bias=False,
            ),
            BatchNorm2d(out_features),
            LeakyReLU(),
        )

    def _make_stage(self, features, out_features, size):
        """Build one pyramid level: adaptive-avg-pool -> 1x1 conv -> BN -> LeakyReLU."""
        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
        conv = nn.Conv2d(features, out_features, kernel_size=1, bias=False)
        return nn.Sequential(
            prior,
            conv,
            # bn
            BatchNorm2d(out_features),
            LeakyReLU(),
        )

    def forward(self, feats):
        """Return the context-enriched map at the input's spatial resolution."""
        h, w = feats.size(2), feats.size(3)
        priors = [
            F.interpolate(
                input=stage(feats), size=(h, w), mode="bilinear", align_corners=True
            )
            for stage in self.stages
        ] + [feats]
        bottle = self.bottleneck(torch.cat(priors, 1))
        return bottle
145
+
146
+
147
class ASPPModule(nn.Module):
    """
    Atrous Spatial Pyramid Pooling.

    Reference:
        Chen, Liang-Chieh, et al. *"Rethinking Atrous Convolution for Semantic Image Segmentation."*

    Five parallel branches (global pooling + 1x1 conv + three dilated 3x3
    convs) are concatenated and projected to `out_features` channels.
    """

    def __init__(
        self, features, inner_features=256, out_features=512, dilations=(12, 24, 36)
    ):
        super(ASPPModule, self).__init__()

        # Image-level branch: global average pool + 1x1 projection.
        self.conv1 = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(
                features,
                inner_features,
                kernel_size=1,
                padding=0,
                dilation=1,
                bias=False,
            ),
            # InPlaceABNSync(inner_features)
            BatchNorm2d(inner_features),
            LeakyReLU(),
        )
        # Plain 1x1 branch.
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                features,
                inner_features,
                kernel_size=1,
                padding=0,
                dilation=1,
                bias=False,
            ),
            BatchNorm2d(inner_features),
            LeakyReLU(),
        )
        # Dilated 3x3 branches at the three configured rates.
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                features,
                inner_features,
                kernel_size=3,
                padding=dilations[0],
                dilation=dilations[0],
                bias=False,
            ),
            BatchNorm2d(inner_features),
            LeakyReLU(),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(
                features,
                inner_features,
                kernel_size=3,
                padding=dilations[1],
                dilation=dilations[1],
                bias=False,
            ),
            BatchNorm2d(inner_features),
            LeakyReLU(),
        )
        self.conv5 = nn.Sequential(
            nn.Conv2d(
                features,
                inner_features,
                kernel_size=3,
                padding=dilations[2],
                dilation=dilations[2],
                bias=False,
            ),
            BatchNorm2d(inner_features),
            LeakyReLU(),
        )

        self.bottleneck = nn.Sequential(
            nn.Conv2d(
                inner_features * 5,
                out_features,
                kernel_size=1,
                padding=0,
                dilation=1,
                bias=False,
            ),
            # Fix: this BatchNorm was built with `inner_features`, but the conv
            # above outputs `out_features` channels — a channel-count mismatch
            # that crashes at runtime whenever inner_features != out_features.
            BatchNorm2d(out_features),
            LeakyReLU(),
            nn.Dropout2d(0.1),
        )

    def forward(self, x):
        """Return ASPP features with `out_features` channels at the input resolution."""
        _, _, h, w = x.size()

        # Broadcast the global-pooling branch back to the full spatial size.
        feat1 = F.interpolate(
            self.conv1(x), size=(h, w), mode="bilinear", align_corners=True
        )

        feat2 = self.conv2(x)
        feat3 = self.conv3(x)
        feat4 = self.conv4(x)
        feat5 = self.conv5(x)
        out = torch.cat((feat1, feat2, feat3, feat4, feat5), 1)

        bottle = self.bottleneck(out)
        return bottle
250
+
251
+
252
class Edge_Module(nn.Module):
    """
    Edge Learning Branch.

    Projects three backbone stages to a common `mid_fea` width, upsamples the
    two deeper stages to the shallowest stage's resolution, and returns the
    concatenation (3 * mid_fea channels).
    """

    def __init__(self, in_fea=[256, 512, 1024], mid_fea=256, out_fea=2):
        super(Edge_Module, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_fea[0], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False
            ),
            BatchNorm2d(mid_fea),
            LeakyReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_fea[1], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False
            ),
            BatchNorm2d(mid_fea),
            LeakyReLU(),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                in_fea[2], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False
            ),
            BatchNorm2d(mid_fea),
            LeakyReLU(),
        )
        # Edge-map head. Not used by forward() below, but kept so checkpoints
        # containing `conv4.*` weights still load.
        self.conv4 = nn.Conv2d(
            mid_fea, out_fea, kernel_size=3, padding=1, dilation=1, bias=True
        )
        # self.conv5 = nn.Conv2d(out_fea * 3, out_fea, kernel_size=1, padding=0, dilation=1, bias=True)

    def forward(self, x1, x2, x3):
        """Return concatenated edge features at the resolution of `x1`.

        Args:
            x1, x2, x3: feature maps from increasingly deep backbone stages.
        """
        _, _, h, w = x1.size()

        edge1_fea = self.conv1(x1)
        edge2_fea = self.conv2(x2)
        edge3_fea = self.conv3(x3)

        # Fix: the original also ran self.conv4 on edge2_fea/edge3_fea and
        # interpolated the resulting edge maps, but those values were never
        # returned or used — dead computation removed; output is unchanged.
        edge2_fea = F.interpolate(
            edge2_fea, size=(h, w), mode="bilinear", align_corners=True
        )
        edge3_fea = F.interpolate(
            edge3_fea, size=(h, w), mode="bilinear", align_corners=True
        )

        edge_fea = torch.cat([edge1_fea, edge2_fea, edge3_fea], dim=1)
        return edge_fea
311
+
312
+
313
class Decoder_Module(nn.Module):
    """
    Parsing Branch Decoder Module.

    Fuses high-level context features with low-level backbone features: the
    context map is reduced to 256 channels and upsampled to the low-level
    resolution, the low-level map is squeezed to 48 channels, and their
    concatenation (304 channels) is refined by two 1x1 conv blocks.
    """

    def __init__(self, num_classes):
        super(Decoder_Module, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=1, padding=0, dilation=1, bias=False),
            BatchNorm2d(256),
            LeakyReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                256, 48, kernel_size=1, stride=1, padding=0, dilation=1, bias=False
            ),
            BatchNorm2d(48),
            LeakyReLU(),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(304, 256, kernel_size=1, padding=0, dilation=1, bias=False),
            BatchNorm2d(256),
            LeakyReLU(),
            nn.Conv2d(256, 256, kernel_size=1, padding=0, dilation=1, bias=False),
            BatchNorm2d(256),
            LeakyReLU(),
        )

        # self.conv4 = nn.Conv2d(256, num_classes, kernel_size=1, padding=0, dilation=1, bias=True)

    def forward(self, xt, xl):
        """Return the fused 256-channel feature map at the resolution of `xl`."""
        _, _, target_h, target_w = xl.size()
        # Reduce the context features and bring them to the low-level resolution.
        context = F.interpolate(
            self.conv1(xt),
            size=(target_h, target_w),
            mode="bilinear",
            align_corners=True,
        )
        low_level = self.conv2(xl)
        fused = torch.cat([context, low_level], dim=1)
        return self.conv3(fused)
354
+
355
+
356
class ResNet(nn.Module):
    """CE2P-style ResNet backbone with PSP context, an edge branch, and a fusion head.

    forward() returns the fused parsing logits (num_classes channels) at 1/4
    of the input resolution; the commented-out lines show the original
    multi-output training form.
    """

    def __init__(self, block, layers, num_classes):
        self.inplanes = 128
        super(ResNet, self).__init__()
        # Deep stem: three 3x3 convs instead of a single 7x7.
        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=False)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=False)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=False)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        # Last stage keeps stride 1 and uses dilation 2 to preserve resolution.
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=1, dilation=2, multi_grid=(1, 1, 1)
        )

        self.context_encoding = PSPModule(2048, 512)

        self.edge = Edge_Module()
        self.decoder = Decoder_Module(num_classes)

        # Fusion head: 256 decoder channels + 3*256 edge channels = 1024 in.
        # "fushion" (sic) is kept to match pretrained checkpoint keys.
        self.fushion = nn.Sequential(
            nn.Conv2d(1024, 256, kernel_size=1, padding=0, dilation=1, bias=False),
            BatchNorm2d(256),
            LeakyReLU(),
            nn.Dropout2d(0.1),
            nn.Conv2d(
                256, num_classes, kernel_size=1, padding=0, dilation=1, bias=True
            ),
        )

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, multi_grid=1):
        """Stack `blocks` bottlenecks; the first may downsample/project the residual."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                BatchNorm2d(planes * block.expansion, affine=affine_par),
            )

        layers = []
        # Per-block dilation multiplier when `multi_grid` is a tuple, else 1.
        generate_multi_grid = lambda index, grids: (
            grids[index % len(grids)] if isinstance(grids, tuple) else 1
        )
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                dilation=dilation,
                downsample=downsample,
                multi_grid=generate_multi_grid(0, multi_grid),
            )
        )
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    dilation=dilation,
                    multi_grid=generate_multi_grid(i, multi_grid),
                )
            )

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return fused parsing logits at 1/4 of the input resolution."""
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)
        x2 = self.layer1(x)
        x3 = self.layer2(x2)
        x4 = self.layer3(x3)
        x5 = self.layer4(x4)
        x = self.context_encoding(x5)
        # parsing_result, parsing_fea = self.decoder(x, x2)
        parsing_fea = self.decoder(x, x2)
        # Edge Branch
        # edge_result, edge_fea = self.edge(x2, x3, x4)
        edge_fea = self.edge(x2, x3, x4)
        # Fusion Branch
        x = torch.cat([parsing_fea, edge_fea], dim=1)
        fusion_result = self.fushion(x)
        # return [[parsing_result, fusion_result], [edge_result]]
        return fusion_result
455
+
456
+
457
def initialize_pretrained_model(
    model, settings, pretrained="./models/resnet101-imagenet.pth"
):
    """Attach preprocessing metadata to *model* and optionally load ImageNet weights.

    Args:
        model: the network to initialize (attributes are set on it in place).
        settings: dict with input_space/input_size/input_range/mean/std keys.
        pretrained: checkpoint path, or None to skip weight loading.

    All checkpoint entries except the classifier ("fc.*") are copied into the
    model's state dict before loading.
    """
    for attr in ("input_space", "input_size", "input_range", "mean", "std"):
        setattr(model, attr, settings[attr])

    if pretrained is not None:
        saved_state_dict = torch.load(pretrained)
        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            # Skip the ImageNet classification head; everything else is kept.
            if key.split(".")[0] != "fc":
                new_params[key] = saved_state_dict[key]
        model.load_state_dict(new_params)
474
+
475
+
476
def resnet101(num_classes=20, pretrained="./models/resnet101-imagenet.pth"):
    """Build the ResNet-101 CE2P parsing network.

    Args:
        num_classes: number of parsing classes for the fusion head.
        pretrained: path to an ImageNet backbone checkpoint, or None to skip
            weight loading.

    Returns:
        The constructed ResNet with preprocessing metadata attributes attached.
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes)
    settings = pretrained_settings["resnet101"]["imagenet"]
    initialize_pretrained_model(model, settings, pretrained)
    return model
Leffa/3rdparty/SCHP/networks/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+
3
+ from SCHP.networks.AugmentCE2P import resnet101
4
+
5
# Registry mapping architecture names to their constructor functions.
__factory = {
    "resnet101": resnet101,
}


def init_model(name, *args, **kwargs):
    """Instantiate the architecture registered under *name*.

    Raises:
        KeyError: if *name* is not a registered architecture.
    """
    if name not in __factory:
        raise KeyError("Unknown model arch: {}".format(name))
    return __factory[name](*args, **kwargs)
Leffa/3rdparty/SCHP/utils/transforms.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft
3
+ # Licensed under the MIT License.
4
+ # Written by Bin Xiao (Bin.Xiao@microsoft.com)
5
+ # ------------------------------------------------------------------------------
6
+
7
+ from __future__ import absolute_import, division, print_function
8
+
9
+ import cv2
10
+
11
+ import numpy as np
12
+ import torch
13
+
14
+
15
class BRG2Tensor_transform(object):
    """Convert an HWC numpy image to a CHW torch tensor (no value scaling)."""

    def __call__(self, pic):
        # HWC -> CHW; shares memory with the numpy array.
        tensor = torch.from_numpy(pic.transpose((2, 0, 1)))
        # Promote byte images to float; other dtypes pass through unchanged.
        return tensor.float() if isinstance(tensor, torch.ByteTensor) else tensor
22
+
23
+
24
class BGR2RGB_transform(object):
    """Reorder a CHW tensor's color channels from BGR to RGB (or vice versa)."""

    def __call__(self, tensor):
        # Swap channels 0 and 2; channel 1 stays in place.
        return tensor[[2, 1, 0], :, :]
27
+
28
+
29
def flip_back(output_flipped, matched_parts):
    """Undo horizontal test-time flipping of joint heatmaps.

    Args:
        output_flipped: numpy.ndarray(batch_size, num_joints, height, width).
        matched_parts: iterable of (left, right) joint-index pairs to swap.

    Returns:
        The un-flipped array: width axis reversed and paired joints exchanged.
        Note the swap writes through a view, so the caller's array is mutated.
    """
    assert (
        output_flipped.ndim == 4
    ), "output_flipped should be [batch_size, num_joints, height, width]"

    # Mirror along the width axis (a view, not a copy).
    result = output_flipped[:, :, :, ::-1]

    # Exchange each left/right joint pair.
    for left, right in matched_parts:
        left_maps = result[:, left, :, :].copy()
        result[:, left, :, :] = result[:, right, :, :]
        result[:, right, :, :] = left_maps

    return result
45
+
46
+
47
def fliplr_joints(joints, joints_vis, width, matched_parts):
    """Horizontally flip joint coordinates (arrays are mutated in place).

    Args:
        joints: (num_joints, D) array whose column 0 is the x coordinate.
        joints_vis: visibility array with the same leading dimension.
        width: image width used to mirror the x coordinates.
        matched_parts: iterable of (left, right) joint-index pairs to swap.

    Returns:
        (joints * joints_vis, joints_vis) after mirroring and pair swapping.
    """
    # Mirror the x coordinate about the image's vertical center line.
    joints[:, 0] = width - joints[:, 0] - 1

    # Exchange each left/right pair — coordinates and visibility alike.
    for left, right in matched_parts:
        saved = joints[left, :].copy()
        joints[left, :] = joints[right, :]
        joints[right, :] = saved

        saved_vis = joints_vis[left, :].copy()
        joints_vis[left, :] = joints_vis[right, :]
        joints_vis[right, :] = saved_vis

    return joints * joints_vis, joints_vis
66
+
67
+
68
def transform_preds(coords, center, scale, input_size):
    """Map predicted coordinates from network-input space back to the source image.

    Applies the inverse of the affine transform that produced the network input
    to each (x, y) row of `coords`.
    """
    target_coords = np.zeros(coords.shape)
    trans = get_affine_transform(center, scale, 0, input_size, inv=1)
    for p in range(coords.shape[0]):
        target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
    return target_coords
74
+
75
+
76
def transform_parsing(pred, center, scale, width, height, input_size):
    """Warp a parsing label map from network-input space back to source-image size.

    Uses nearest-neighbor interpolation so class indices are preserved.
    """
    trans = get_affine_transform(center, scale, 0, input_size, inv=1)
    target_pred = cv2.warpAffine(
        pred,
        trans,
        (int(width), int(height)),  # (int(width), int(height)),
        flags=cv2.INTER_NEAREST,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=(0),
    )

    return target_pred
89
+
90
+
91
def transform_logits(logits, center, scale, width, height, input_size):
    """Warp an HxWxC logits volume back to the source image size, channel by channel.

    Bilinear interpolation is used per channel; the channels are re-stacked on
    axis 2 of the returned array.
    """
    trans = get_affine_transform(center, scale, 0, input_size, inv=1)
    channel = logits.shape[2]
    target_logits = []
    for i in range(channel):
        target_logit = cv2.warpAffine(
            logits[:, :, i],
            trans,
            (int(width), int(height)),  # (int(width), int(height)),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(0),
        )
        target_logits.append(target_logit)
    target_logits = np.stack(target_logits, axis=2)

    return target_logits
109
+
110
+
111
def get_affine_transform(
    center, scale, rot, output_size, shift=np.array([0, 0], dtype=np.float32), inv=0
):
    """Build the 2x3 affine matrix mapping a scaled/rotated source box to output_size.

    Args:
        center: (x, y) box center in source-image coordinates.
        scale: box size as [w, h]; a bare scalar is promoted to [s, s].
        rot: rotation angle in degrees.
        output_size: destination size as (height, width).
        shift: fractional shift of the source box, in units of `scale`.
        inv: if truthy, return the inverse transform (destination -> source).

    Returns:
        A 2x3 float matrix suitable for cv2.warpAffine.
    """
    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
        # Fix: removed leftover debug `print(scale)` from this branch.
        scale = np.array([scale, scale])

    scale_tmp = scale

    src_w = scale_tmp[0]
    dst_w = output_size[1]
    dst_h = output_size[0]

    rot_rad = np.pi * rot / 180
    # Direction vectors define the second point of each coordinate frame.
    src_dir = get_dir([0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0, (dst_w - 1) * -0.5], np.float32)

    # Three corresponding points fully determine the affine transform.
    src = np.zeros((3, 2), dtype=np.float32)
    dst = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    dst[0, :] = [(dst_w - 1) * 0.5, (dst_h - 1) * 0.5]
    dst[1, :] = np.array([(dst_w - 1) * 0.5, (dst_h - 1) * 0.5]) + dst_dir

    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans
144
+
145
+
146
def affine_transform(pt, t):
    """Apply a 2x3 affine matrix *t* to the 2D point *pt*; returns the mapped (x, y)."""
    homogeneous = np.array([pt[0], pt[1], 1.0])
    return (t @ homogeneous)[:2]
150
+
151
+
152
def get_3rd_point(a, b):
    """Return the third point completing a right angle at *b* (rotate a-b by 90 degrees)."""
    delta = a - b
    return b + np.array([-delta[1], delta[0]], dtype=np.float32)
155
+
156
+
157
def get_dir(src_point, rot_rad):
    """Rotate the 2D vector *src_point* by *rot_rad* radians (counter-clockwise).

    Returns a two-element list [x', y'].
    """
    sin_r, cos_r = np.sin(rot_rad), np.cos(rot_rad)
    return [
        src_point[0] * cos_r - src_point[1] * sin_r,
        src_point[0] * sin_r + src_point[1] * cos_r,
    ]
165
+
166
+
167
def crop(img, center, scale, output_size, rot=0):
    """Crop/warp the region described by (center, scale, rot) to output_size.

    Args:
        img: source image (HWC numpy array).
        center: (x, y) crop center in source coordinates.
        scale: crop size as [w, h] (or a scalar).
        output_size: destination size as (height, width).
        rot: rotation angle in degrees.

    Returns:
        The warped image of shape output_size.
    """
    trans = get_affine_transform(center, scale, rot, output_size)

    dst_img = cv2.warpAffine(
        img, trans, (int(output_size[1]), int(output_size[0])), flags=cv2.INTER_LINEAR
    )

    return dst_img
Leffa/3rdparty/detectron2/data/transforms/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright (c) Facebook, Inc. and its affiliates.
from fvcore.transforms.transform import Transform, TransformList  # order them first
from fvcore.transforms.transform import *
from .transform import *
from .augmentation import *
from .augmentation_impl import *

# Re-export every public name brought in by the star imports above.
__all__ = [k for k in globals().keys() if not k.startswith("_")]


from detectron2.utils.env import fixup_module_metadata

# Make re-exported symbols report this module as their home (docs/pickling).
fixup_module_metadata(__name__, globals(), __all__)
del fixup_module_metadata
Leffa/3rdparty/detectron2/export/README.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ This directory contains code to prepare a detectron2 model for deployment.
3
+ Currently it supports exporting a detectron2 model to TorchScript, ONNX, or (deprecated) Caffe2 format.
4
+
5
+ Please see [documentation](https://detectron2.readthedocs.io/tutorials/deployment.html) for its usage.
6
+
7
+
8
+ ### Acknowledgements
9
+
10
+ Thanks to Mobile Vision team at Facebook for developing the Caffe2 conversion tools.
11
+
12
+ Thanks to the Computing Platform Department - PAI team at Alibaba Group (@bddpqq, @chenbohua3), who
+ helped export Detectron2 models to TorchScript.
14
+
15
+ Thanks to the ONNX Converter team at Microsoft, who helped export Detectron2 models to ONNX.
Leffa/3rdparty/detectron2/export/__init__.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import warnings
4
+
5
+ from .flatten import TracingAdapter
6
+ from .torchscript import dump_torchscript_IR, scripting_with_instances
7
+
8
+ try:
9
+ from caffe2.proto import caffe2_pb2 as _tmp
10
+ from caffe2.python import core
11
+
12
+ # caffe2 is optional
13
+ except ImportError:
14
+ pass
15
+ else:
16
+ from .api import *
17
+
18
+
19
+ # TODO: Update ONNX Opset version and run tests when a newer PyTorch is supported
20
+ STABLE_ONNX_OPSET_VERSION = 11
21
+
22
+
23
def add_export_config(cfg):
    """Deprecated no-op kept for backward compatibility.

    Emits a DeprecationWarning and returns *cfg* unchanged.
    """
    warnings.warn(
        "add_export_config has been deprecated and behaves as no-op function.", DeprecationWarning
    )
    return cfg
28
+
29
+
30
+ __all__ = [k for k in globals().keys() if not k.startswith("_")]
Leffa/3rdparty/detectron2/export/api.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import copy
3
+ import logging
4
+ import os
5
+ import torch
6
+ from caffe2.proto import caffe2_pb2
7
+ from torch import nn
8
+
9
+ from detectron2.config import CfgNode
10
+ from detectron2.utils.file_io import PathManager
11
+
12
+ from .caffe2_inference import ProtobufDetectionModel
13
+ from .caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP, convert_batched_inputs_to_c2_format
14
+ from .shared import get_pb_arg_vali, get_pb_arg_vals, save_graph
15
+
16
+ __all__ = [
17
+ "Caffe2Model",
18
+ "Caffe2Tracer",
19
+ ]
20
+
21
+
22
class Caffe2Tracer:
    """
    Wrap a detectron2 model so that it can be traced with Caffe2 operators.

    The wrapped model differs from the original in two ways:

    1. Parts of it are rewritten using Caffe2 ops (some of which have no GPU
       implementation).
    2. Post-processing is stripped, so only raw layer outputs are produced.

    Once built, the tracer can export the model to several deployment
    formats. The exported graph expects two input tensors:

    1. "data": a (1, C, H, W) float image (usually in [0, 255]); (H, W)
       often must be padded to a multiple of 32, depending on the model
       architecture.
    2. "im_info": a 1x3 float tensor whose rows are (height, width, 1.0),
       where height/width are the true image shape before padding.

    Only builtin meta architectures are supported. Batch inference is not
    supported, and contributions are welcome.
    """

    def __init__(self, cfg: CfgNode, model: nn.Module, inputs):
        """
        Args:
            cfg (CfgNode): a detectron2 config used to construct a
                caffe2-compatible model.
            model (nn.Module): an original pytorch model, among the few
                official detectron2 models that can be converted
                automatically. Weights must already be loaded.
            inputs: sample inputs used to trace the model. For most models,
                random inputs with no detected objects will not work, as
                they lead to wrong traces.
        """
        assert isinstance(cfg, CfgNode), cfg
        assert isinstance(model, torch.nn.Module), type(model)

        # TODO make it support custom models, by passing in c2 model directly
        c2_meta_arch = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[cfg.MODEL.META_ARCHITECTURE]
        self.traceable_model = c2_meta_arch(cfg, copy.deepcopy(model))
        self.inputs = inputs
        self.traceable_inputs = self.traceable_model.get_caffe2_inputs(inputs)

    def export_caffe2(self):
        """
        Export the model to Caffe2's protobuf format. The result can be
        saved with its :meth:`.save_protobuf()` method and then loaded and
        executed by the Caffe2 runtime.

        Returns:
            :class:`Caffe2Model`
        """
        from .caffe2_export import export_caffe2_detection_model

        predict_net, init_net = export_caffe2_detection_model(
            self.traceable_model, self.traceable_inputs
        )
        return Caffe2Model(predict_net, init_net)

    def export_onnx(self):
        """
        Export the model to ONNX format. The exported model contains custom
        ops only available in caffe2, so it cannot be executed directly by
        other runtimes (such as onnxruntime or TensorRT). Post-processing or
        transformation passes may adapt the model to other runtimes, but we
        currently do not provide support for them.

        Returns:
            onnx.ModelProto: an onnx model.
        """
        from .caffe2_export import export_onnx_model as export_onnx_model_impl

        return export_onnx_model_impl(self.traceable_model, (self.traceable_inputs,))

    def export_torchscript(self):
        """
        Export the model to a ``torch.jit.TracedModule`` by tracing. The
        returned object can be saved to a file by ``.save()``.

        Returns:
            torch.jit.TracedModule: a torch TracedModule
        """
        logger = logging.getLogger(__name__)
        logger.info("Tracing the model with torch.jit.trace ...")
        with torch.no_grad():
            return torch.jit.trace(self.traceable_model, (self.traceable_inputs,))
108
+
109
+
110
class Caffe2Model(nn.Module):
    """
    A wrapper around a model traced into Caffe2's protobuf format.

    The exported graph has different inputs/outputs from the original
    Pytorch model, as explained in :class:`Caffe2Tracer`; this class wraps
    the exported graph so it presents the same interface as the original
    Pytorch model, and provides save/load helpers for Caffe2's format.

    Examples:
    ::
        c2_model = Caffe2Tracer(cfg, torch_model, inputs).export_caffe2()
        inputs = [{"image": img_tensor_CHW}]
        outputs = c2_model(inputs)
        orig_outputs = torch_model(inputs)
    """

    def __init__(self, predict_net, init_net):
        super().__init__()
        self.eval()  # always in eval mode
        self._predict_net = predict_net
        self._init_net = init_net
        self._predictor = None  # built lazily on first __call__

    __init__.__HIDE_SPHINX_DOC__ = True

    @property
    def predict_net(self):
        """
        caffe2.core.Net: the underlying caffe2 predict net
        """
        return self._predict_net

    @property
    def init_net(self):
        """
        caffe2.core.Net: the underlying caffe2 init net
        """
        return self._init_net

    def save_protobuf(self, output_dir):
        """
        Save the model as caffe2's protobuf format. The following files are
        written:

        * "model.pb": definition of the graph. Can be visualized with
          tools like `netron <https://github.com/lutzroeder/netron>`_.
        * "model_init.pb": model parameters
        * "model.pbtxt": human-readable definition of the graph. Not
          needed for deployment.

        Args:
            output_dir (str): the output directory to save protobuf files.
        """
        logger = logging.getLogger(__name__)
        logger.info("Saving model to {} ...".format(output_dir))
        if not PathManager.exists(output_dir):
            PathManager.mkdirs(output_dir)

        # (filename, open mode, payload) for each artifact, written in order
        artifacts = (
            ("model.pb", "wb", self._predict_net.SerializeToString()),
            ("model.pbtxt", "w", str(self._predict_net)),
            ("model_init.pb", "wb", self._init_net.SerializeToString()),
        )
        for fname, mode, payload in artifacts:
            with PathManager.open(os.path.join(output_dir, fname), mode) as f:
                f.write(payload)

    def save_graph(self, output_file, inputs=None):
        """
        Save the graph as SVG format.

        Args:
            output_file (str): a SVG file
            inputs: optional inputs given to the model. If given, the
                inputs are run through the graph to record the shape of
                every tensor; the shapes are saved together with the graph.
        """
        from .caffe2_export import run_and_save_graph

        if inputs is None:
            save_graph(self._predict_net, output_file, op_only=False)
            return

        size_divisibility = get_pb_arg_vali(self._predict_net, "size_divisibility", 0)
        device = get_pb_arg_vals(self._predict_net, "device", b"cpu").decode("ascii")
        batched = convert_batched_inputs_to_c2_format(inputs, size_divisibility, device)
        run_and_save_graph(
            self._predict_net,
            self._init_net,
            [t.cpu().numpy() for t in batched],
            output_file,
        )

    @staticmethod
    def load_protobuf(dir):
        """
        Args:
            dir (str): a directory used to save Caffe2Model with
                :meth:`save_protobuf`. The files "model.pb" and
                "model_init.pb" are needed.

        Returns:
            Caffe2Model: the caffe2 model loaded from this directory.
        """

        def _read_net(fname):
            # Parse one serialized NetDef from the given file name.
            net = caffe2_pb2.NetDef()
            with PathManager.open(os.path.join(dir, fname), "rb") as f:
                net.ParseFromString(f.read())
            return net

        return Caffe2Model(_read_net("model.pb"), _read_net("model_init.pb"))

    def __call__(self, inputs):
        """
        An interface that wraps around a Caffe2 model and mimics detectron2's
        models' input/output format. See details about the format at
        :doc:`/tutorials/models`. This is used to compare the outputs of the
        caffe2 model with its original torch model.

        Due to the extra Pytorch/Caffe2 conversion, this method is not meant
        for benchmarking; it also depends on detectron2 in order to convert
        to detectron2's output format.
        """
        if self._predictor is None:
            self._predictor = ProtobufDetectionModel(self._predict_net, self._init_net)
        return self._predictor(inputs)
Leffa/3rdparty/detectron2/export/c10.py ADDED
@@ -0,0 +1,571 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ import math
4
+ from typing import Dict
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ from detectron2.layers import ShapeSpec, cat
9
+ from detectron2.layers.roi_align_rotated import ROIAlignRotated
10
+ from detectron2.modeling import poolers
11
+ from detectron2.modeling.proposal_generator import rpn
12
+ from detectron2.modeling.roi_heads.mask_head import mask_rcnn_inference
13
+ from detectron2.structures import Boxes, ImageList, Instances, Keypoints, RotatedBoxes
14
+
15
+ from .shared import alias, to_device
16
+
17
+
18
+ """
19
+ This file contains caffe2-compatible implementation of several detectron2 components.
20
+ """
21
+
22
+
23
class Caffe2Boxes(Boxes):
    """
    Boxes for a minibatch in caffe2 layout: each row carries a batch index
    followed by the box coordinates — 5 values per row for regular boxes,
    6 for RotatedBoxes; plain 4-value rows are also accepted.
    """

    def __init__(self, tensor):
        assert isinstance(tensor, torch.Tensor)
        assert tensor.dim() == 2 and tensor.size(-1) in (4, 5, 6), tensor.size()
        # TODO: make tensor immutable when dim is Nx5 for Boxes,
        # and Nx6 for RotatedBoxes?
        self.tensor = tensor
36
+
37
+
38
+ # TODO clean up this class, maybe just extend Instances
39
class InstancesList:
    """
    Tensor representation of a list of Instances object for a batch of images.

    When dealing with a batch of images with Caffe2 ops, a list of bboxes
    (instances) is usually represented by a single Tensor with size
    (sigma(Ni), 5) or (sigma(Ni), 4) plus a batch split Tensor. This class
    provides common functions to convert between the two representations.
    """

    def __init__(self, im_info, indices, extra_fields=None):
        # [N, 3] -> (H, W, Scale) per image
        self.im_info = im_info
        # [N,] -> index of the batch image each instance belongs to
        self.indices = indices
        # name -> [N, ...] per-instance data
        self.batch_extra_fields = extra_fields or {}

        # alias kept for compatibility with the Instances interface
        self.image_size = self.im_info

    def get_fields(self):
        """like `get_fields` in the Instances object,
        but return each field in tensor representations"""
        # shallow copy so callers cannot accidentally mutate our dict
        return dict(self.batch_extra_fields)

    def has(self, name):
        """Return True if a field called ``name`` exists."""
        return name in self.batch_extra_fields

    def set(self, name, value):
        """Set the field ``name`` to ``value``; its length must match existing fields."""
        # len(tensor) is a bad practice that generates ONNX constants during
        # tracing; read .shape[0] for tensor-like values instead. (torch's ONNX
        # exporter also emits a misleading warning for len() calls here.)
        if isinstance(value, Boxes):
            data_len = value.tensor.shape[0]
        elif isinstance(value, torch.Tensor):
            data_len = value.shape[0]
        else:
            data_len = len(value)
        if len(self.batch_extra_fields):
            assert (
                len(self) == data_len
            ), "Adding a field of length {} to an Instances of length {}".format(
                data_len, len(self)
            )
        self.batch_extra_fields[name] = value

    def __getattr__(self, name):
        # only reached when normal attribute lookup fails; fall back to fields
        if name not in self.batch_extra_fields:
            raise AttributeError("Cannot find field '{}' in the given Instances!".format(name))
        return self.batch_extra_fields[name]

    def __len__(self):
        return len(self.indices)

    def flatten(self):
        """Return all fields as a flat list of tensors."""
        return [
            v.tensor if isinstance(v, (Boxes, Keypoints)) else v
            for v in self.batch_extra_fields.values()
        ]

    @staticmethod
    def to_d2_instances_list(instances_list):
        """
        Convert InstancesList to List[Instances]. The input `instances_list` can
        also be a List[Instances], in this case this method is a non-op.
        """
        if not isinstance(instances_list, InstancesList):
            assert all(isinstance(x, Instances) for x in instances_list)
            return instances_list

        ret = []
        for i, info in enumerate(instances_list.im_info):
            instances = Instances(torch.Size([int(info[0].item()), int(info[1].item())]))

            # mask selecting instances that belong to image i
            ids = instances_list.indices == i
            for k, v in instances_list.batch_extra_fields.items():
                if isinstance(v, torch.Tensor):
                    instances.set(k, v[ids])
                    continue
                elif isinstance(v, Boxes):
                    # keep only the last 4 columns (drop the batch-index column)
                    instances.set(k, v[ids, -4:])
                    continue

                # otherwise v is a (target_type, tensor) pair
                target_type, tensor_source = v
                assert isinstance(tensor_source, torch.Tensor)
                assert tensor_source.shape[0] == instances_list.indices.shape[0]
                tensor_source = tensor_source[ids]

                if issubclass(target_type, Boxes):
                    instances.set(k, Boxes(tensor_source[:, -4:]))
                elif issubclass(target_type, Keypoints):
                    instances.set(k, Keypoints(tensor_source))
                elif issubclass(target_type, torch.Tensor):
                    instances.set(k, tensor_source)
                else:
                    raise ValueError("Can't handle target type: {}".format(target_type))

            ret.append(instances)
        return ret
148
+
149
+
150
class Caffe2Compatible:
    """
    Mixin marking a model as traceable and deployable with caffe2.
    """

    def _get_tensor_mode(self):
        return self._tensor_mode

    def _set_tensor_mode(self, value):
        self._tensor_mode = value

    # exposed as a property so subclasses can toggle it like a plain attribute
    tensor_mode = property(_get_tensor_mode, _set_tensor_mode)
    """
    If true, the model expects C2-style tensor only inputs/outputs format.
    """
165
+
166
+
167
class Caffe2RPN(Caffe2Compatible, rpn.RPN):
    """
    Caffe2-compatible RPN: proposal generation is delegated to the caffe2
    ops ``GenerateProposals`` and (for FPN) ``CollectRpnProposals``.
    Inference only — ``forward`` asserts the model is not in training mode.
    """

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        # super(Caffe2Compatible, cls) skips this mixin in the MRO and calls
        # rpn.RPN.from_config directly.
        ret = super(Caffe2Compatible, cls).from_config(cfg, input_shape)
        # GenerateProposals below takes no custom box-regression weights, so
        # only all-ones weights (4d, or 5d for rotated boxes) are supported.
        assert tuple(cfg.MODEL.RPN.BBOX_REG_WEIGHTS) == (1.0, 1.0, 1.0, 1.0) or tuple(
            cfg.MODEL.RPN.BBOX_REG_WEIGHTS
        ) == (1.0, 1.0, 1.0, 1.0, 1.0)
        return ret

    def _generate_proposals(
        self, images, objectness_logits_pred, anchor_deltas_pred, gt_instances=None
    ):
        """
        Run caffe2 GenerateProposals per feature level and, for multi-level
        (FPN) inputs, merge the per-level proposals with CollectRpnProposals.

        Returns:
            (proposals, {}): proposals in the format produced by
            :meth:`c2_postprocess`; the loss dict is always empty.
        """
        assert isinstance(images, ImageList)
        if self.tensor_mode:
            # in tensor mode image_sizes is already the im_info tensor
            im_info = images.image_sizes
        else:
            # build im_info rows of (height, width, scale=1.0)
            im_info = torch.tensor([[im_sz[0], im_sz[1], 1.0] for im_sz in images.image_sizes]).to(
                images.tensor.device
            )
        assert isinstance(im_info, torch.Tensor)

        rpn_rois_list = []
        rpn_roi_probs_list = []
        for scores, bbox_deltas, cell_anchors_tensor, feat_stride in zip(
            objectness_logits_pred,
            anchor_deltas_pred,
            [b for (n, b) in self.anchor_generator.cell_anchors.named_buffers()],
            self.anchor_generator.strides,
        ):
            # proposal generation is not differentiated through
            scores = scores.detach()
            bbox_deltas = bbox_deltas.detach()

            rpn_rois, rpn_roi_probs = torch.ops._caffe2.GenerateProposals(
                scores,
                bbox_deltas,
                im_info,
                cell_anchors_tensor,
                spatial_scale=1.0 / feat_stride,
                pre_nms_topN=self.pre_nms_topk[self.training],
                post_nms_topN=self.post_nms_topk[self.training],
                nms_thresh=self.nms_thresh,
                min_size=self.min_box_size,
                # correct_transform_coords=True,  # deprecated argument
                angle_bound_on=True,  # Default
                angle_bound_lo=-180,
                angle_bound_hi=180,
                clip_angle_thresh=1.0,  # Default
                legacy_plus_one=False,
            )
            rpn_rois_list.append(rpn_rois)
            rpn_roi_probs_list.append(rpn_roi_probs)

        # For FPN in D2, in RPN all proposals from different levels are concated
        # together, ranked and picked by top post_nms_topk. Then in ROIPooler
        # it calculates level_assignments and calls the RoIAlign from
        # the corresponding level.

        if len(objectness_logits_pred) == 1:
            rpn_rois = rpn_rois_list[0]
            rpn_roi_probs = rpn_roi_probs_list[0]
        else:
            assert len(rpn_rois_list) == len(rpn_roi_probs_list)
            rpn_post_nms_topN = self.post_nms_topk[self.training]

            device = rpn_rois_list[0].device
            input_list = [to_device(x, "cpu") for x in (rpn_rois_list + rpn_roi_probs_list)]

            # TODO remove this after confirming rpn_max_level/rpn_min_level
            # is not needed in CollectRpnProposals.
            feature_strides = list(self.anchor_generator.strides)
            rpn_min_level = int(math.log2(feature_strides[0]))
            rpn_max_level = int(math.log2(feature_strides[-1]))
            assert (rpn_max_level - rpn_min_level + 1) == len(
                rpn_rois_list
            ), "CollectRpnProposals requires continuous levels"

            rpn_rois = torch.ops._caffe2.CollectRpnProposals(
                input_list,
                # NOTE: in the current implementation rpn_max_level and
                # rpn_min_level are not needed; only their difference matters
                # and it can be inferred from the number of inputs. They are
                # kept for consistency.
                rpn_max_level=2 + len(rpn_rois_list) - 1,
                rpn_min_level=2,
                rpn_post_nms_topN=rpn_post_nms_topN,
            )
            rpn_rois = to_device(rpn_rois, device)
            # CollectRpnProposals does not return per-roi scores
            rpn_roi_probs = []

        proposals = self.c2_postprocess(im_info, rpn_rois, rpn_roi_probs, self.tensor_mode)
        return proposals, {}

    def forward(self, images, features, gt_instances=None):
        # inference only: the caffe2 path never computes RPN losses
        assert not self.training
        features = [features[f] for f in self.in_features]
        objectness_logits_pred, anchor_deltas_pred = self.rpn_head(features)
        return self._generate_proposals(
            images,
            objectness_logits_pred,
            anchor_deltas_pred,
            gt_instances,
        )

    @staticmethod
    def c2_postprocess(im_info, rpn_rois, rpn_roi_probs, tensor_mode):
        """
        Wrap raw (batch_idx, coords) rois into an InstancesList; in
        non-tensor mode, convert further into a list of d2 Instances.
        """
        proposals = InstancesList(
            im_info=im_info,
            indices=rpn_rois[:, 0],
            extra_fields={
                "proposal_boxes": Caffe2Boxes(rpn_rois),
                "objectness_logits": (torch.Tensor, rpn_roi_probs),
            },
        )
        if not tensor_mode:
            proposals = InstancesList.to_d2_instances_list(proposals)
        else:
            proposals = [proposals]
        return proposals
285
+
286
+
287
class Caffe2ROIPooler(Caffe2Compatible, poolers.ROIPooler):
    """
    Caffe2-compatible ROIPooler: pooling is performed by the caffe2 ops
    ``RoIAlign`` / ``RoIAlignRotated`` and, for FPN inputs,
    ``DistributeFpnProposals`` + ``BatchPermutation``. Inference only.
    """

    @staticmethod
    def c2_preprocess(box_lists):
        """Convert input boxes to the (batch_idx, coords) pooler format."""
        assert all(isinstance(x, Boxes) for x in box_lists)
        if all(isinstance(x, Caffe2Boxes) for x in box_lists):
            # input is pure-tensor based: already carries the batch index column
            assert len(box_lists) == 1
            pooler_fmt_boxes = box_lists[0].tensor
        else:
            pooler_fmt_boxes = poolers.convert_boxes_to_pooler_format(box_lists)
        return pooler_fmt_boxes

    def forward(self, x, box_lists):
        """
        Args:
            x: list of per-level feature maps.
            box_lists: boxes to pool, as Boxes/Caffe2Boxes per image.

        Returns:
            the pooled features for all boxes, in the input box order.
        """
        assert not self.training

        pooler_fmt_boxes = self.c2_preprocess(box_lists)
        num_level_assignments = len(self.level_poolers)

        if num_level_assignments == 1:
            # single-level path: one RoIAlign call, no FPN distribution needed
            if isinstance(self.level_poolers[0], ROIAlignRotated):
                c2_roi_align = torch.ops._caffe2.RoIAlignRotated
                aligned = True
            else:
                c2_roi_align = torch.ops._caffe2.RoIAlign
                aligned = self.level_poolers[0].aligned

            x0 = x[0]
            if x0.is_quantized:
                x0 = x0.dequantize()

            out = c2_roi_align(
                x0,
                pooler_fmt_boxes,
                order="NCHW",
                spatial_scale=float(self.level_poolers[0].spatial_scale),
                pooled_h=int(self.output_size[0]),
                pooled_w=int(self.output_size[1]),
                sampling_ratio=int(self.level_poolers[0].sampling_ratio),
                aligned=aligned,
            )
            return out

        device = pooler_fmt_boxes.device
        assert (
            self.max_level - self.min_level + 1 == 4
        ), "Currently DistributeFpnProposals only support 4 levels"
        # assign each roi to an FPN level; runs on CPU (caffe2 op requirement
        # — NOTE(review): presumably CPU-only, confirm against the op's impl)
        fpn_outputs = torch.ops._caffe2.DistributeFpnProposals(
            to_device(pooler_fmt_boxes, "cpu"),
            roi_canonical_scale=self.canonical_box_size,
            roi_canonical_level=self.canonical_level,
            roi_max_level=self.max_level,
            roi_min_level=self.min_level,
            legacy_plus_one=False,
        )
        fpn_outputs = [to_device(x, device) for x in fpn_outputs]

        # last output is the permutation restoring the original roi order
        rois_fpn_list = fpn_outputs[:-1]
        rois_idx_restore_int32 = fpn_outputs[-1]

        roi_feat_fpn_list = []
        for roi_fpn, x_level, pooler in zip(rois_fpn_list, x, self.level_poolers):
            if isinstance(pooler, ROIAlignRotated):
                c2_roi_align = torch.ops._caffe2.RoIAlignRotated
                aligned = True
            else:
                c2_roi_align = torch.ops._caffe2.RoIAlign
                aligned = bool(pooler.aligned)

            if x_level.is_quantized:
                x_level = x_level.dequantize()

            roi_feat_fpn = c2_roi_align(
                x_level,
                roi_fpn,
                order="NCHW",
                spatial_scale=float(pooler.spatial_scale),
                pooled_h=int(self.output_size[0]),
                pooled_w=int(self.output_size[1]),
                sampling_ratio=int(pooler.sampling_ratio),
                aligned=aligned,
            )
            roi_feat_fpn_list.append(roi_feat_fpn)

        roi_feat_shuffled = cat(roi_feat_fpn_list, dim=0)
        assert roi_feat_shuffled.numel() > 0 and rois_idx_restore_int32.numel() > 0, (
            "Caffe2 export requires tracing with a model checkpoint + input that can produce valid"
            " detections. But no detections were obtained with the given checkpoint and input!"
        )
        # undo the per-level grouping so features match the input box order
        roi_feat = torch.ops._caffe2.BatchPermutation(roi_feat_shuffled, rois_idx_restore_int32)
        return roi_feat
377
+
378
+
379
def caffe2_fast_rcnn_outputs_inference(tensor_mode, box_predictor, predictions, proposals):
    """equivalent to FastRCNNOutputLayers.inference

    Decodes box regressions with the caffe2 ``BBoxTransform`` op and applies
    per-class NMS with ``BoxWithNMSLimit``. Thresholds and top-k limits are
    read from ``box_predictor``.

    Args:
        tensor_mode (bool): if True, return caffe2-style InstancesList and raw
            keep tensors; otherwise convert to d2 Instances.
        box_predictor: a FastRCNNOutputLayers-like module.
        predictions: (class_logits, box_regression) tensors.
        proposals: list of per-image proposals (Instances or InstancesList).

    Returns:
        (results, kept_indices)
    """
    num_classes = box_predictor.num_classes
    score_thresh = box_predictor.test_score_thresh
    nms_thresh = box_predictor.test_nms_thresh
    topk_per_image = box_predictor.test_topk_per_image
    # 5 transform weights means rotated boxes (x, y, w, h, angle)
    is_rotated = len(box_predictor.box2box_transform.weights) == 5

    if is_rotated:
        box_dim = 5
        assert box_predictor.box2box_transform.weights[4] == 1, (
            "The weights for Rotated BBoxTransform in C2 have only 4 dimensions,"
            + " thus enforcing the angle weight to be 1 for now"
        )
        box2box_transform_weights = box_predictor.box2box_transform.weights[:4]
    else:
        box_dim = 4
        box2box_transform_weights = box_predictor.box2box_transform.weights

    class_logits, box_regression = predictions
    if num_classes + 1 == class_logits.shape[1]:
        # softmax head: logits include a background class
        class_prob = F.softmax(class_logits, -1)
    else:
        assert num_classes == class_logits.shape[1]
        class_prob = F.sigmoid(class_logits)
        # BoxWithNMSLimit will infer num_classes from the shape of the class_prob
        # So append a zero column as placeholder for the background class
        class_prob = torch.cat((class_prob, torch.zeros(class_prob.shape[0], 1)), dim=1)

    assert box_regression.shape[1] % box_dim == 0
    # a single set of deltas shared by all classes -> class-agnostic regression
    cls_agnostic_bbox_reg = box_regression.shape[1] // box_dim == 1

    # proposals already in (batch_idx, coords) layout -> caffe2 tensor input
    input_tensor_mode = proposals[0].proposal_boxes.tensor.shape[1] == box_dim + 1

    proposal_boxes = proposals[0].proposal_boxes
    if isinstance(proposal_boxes, Caffe2Boxes):
        rois = Caffe2Boxes.cat([p.proposal_boxes for p in proposals])
    elif isinstance(proposal_boxes, RotatedBoxes):
        rois = RotatedBoxes.cat([p.proposal_boxes for p in proposals])
    elif isinstance(proposal_boxes, Boxes):
        rois = Boxes.cat([p.proposal_boxes for p in proposals])
    else:
        raise NotImplementedError(
            'Expected proposals[0].proposal_boxes to be type "Boxes", '
            f"instead got {type(proposal_boxes)}"
        )

    device, dtype = rois.tensor.device, rois.tensor.dtype
    if input_tensor_mode:
        im_info = proposals[0].image_size
        rois = rois.tensor
    else:
        # build im_info rows of (height, width, scale=1.0) and prepend a
        # batch-index column to every roi
        im_info = torch.tensor([[sz[0], sz[1], 1.0] for sz in [x.image_size for x in proposals]])
        batch_ids = cat(
            [
                torch.full((b, 1), i, dtype=dtype, device=device)
                for i, b in enumerate(len(p) for p in proposals)
            ],
            dim=0,
        )
        rois = torch.cat([batch_ids, rois.tensor], dim=1)

    # decode the regression deltas into absolute boxes (caffe2 op, CPU inputs)
    roi_pred_bbox, roi_batch_splits = torch.ops._caffe2.BBoxTransform(
        to_device(rois, "cpu"),
        to_device(box_regression, "cpu"),
        to_device(im_info, "cpu"),
        weights=box2box_transform_weights,
        apply_scale=True,
        rotated=is_rotated,
        angle_bound_on=True,
        angle_bound_lo=-180,
        angle_bound_hi=180,
        clip_angle_thresh=1.0,
        legacy_plus_one=False,
    )
    roi_pred_bbox = to_device(roi_pred_bbox, device)
    roi_batch_splits = to_device(roi_batch_splits, device)

    # per-class score threshold + NMS + per-image top-k, all in one caffe2 op
    nms_outputs = torch.ops._caffe2.BoxWithNMSLimit(
        to_device(class_prob, "cpu"),
        to_device(roi_pred_bbox, "cpu"),
        to_device(roi_batch_splits, "cpu"),
        score_thresh=float(score_thresh),
        nms=float(nms_thresh),
        detections_per_im=int(topk_per_image),
        soft_nms_enabled=False,
        soft_nms_method="linear",
        soft_nms_sigma=0.5,
        soft_nms_min_score_thres=0.001,
        rotated=is_rotated,
        cls_agnostic_bbox_reg=cls_agnostic_bbox_reg,
        input_boxes_include_bg_cls=False,
        output_classes_include_bg_cls=False,
        legacy_plus_one=False,
    )
    roi_score_nms = to_device(nms_outputs[0], device)
    roi_bbox_nms = to_device(nms_outputs[1], device)
    roi_class_nms = to_device(nms_outputs[2], device)
    roi_batch_splits_nms = to_device(nms_outputs[3], device)
    roi_keeps_nms = to_device(nms_outputs[4], device)
    roi_keeps_size_nms = to_device(nms_outputs[5], device)
    if not tensor_mode:
        # d2 Instances expect integer class labels
        roi_class_nms = roi_class_nms.to(torch.int64)

    # rebuild a batch-index column from the per-image split sizes
    roi_batch_ids = cat(
        [
            torch.full((b, 1), i, dtype=dtype, device=device)
            for i, b in enumerate(int(x.item()) for x in roi_batch_splits_nms)
        ],
        dim=0,
    )

    # name the output blobs so they are addressable in the exported graph
    roi_class_nms = alias(roi_class_nms, "class_nms")
    roi_score_nms = alias(roi_score_nms, "score_nms")
    roi_bbox_nms = alias(roi_bbox_nms, "bbox_nms")
    roi_batch_splits_nms = alias(roi_batch_splits_nms, "batch_splits_nms")
    roi_keeps_nms = alias(roi_keeps_nms, "keeps_nms")
    roi_keeps_size_nms = alias(roi_keeps_size_nms, "keeps_size_nms")

    results = InstancesList(
        im_info=im_info,
        indices=roi_batch_ids[:, 0],
        extra_fields={
            "pred_boxes": Caffe2Boxes(roi_bbox_nms),
            "scores": roi_score_nms,
            "pred_classes": roi_class_nms,
        },
    )

    if not tensor_mode:
        results = InstancesList.to_d2_instances_list(results)
        batch_splits = roi_batch_splits_nms.int().tolist()
        kept_indices = list(roi_keeps_nms.to(torch.int64).split(batch_splits))
    else:
        results = [results]
        kept_indices = [roi_keeps_nms]

    return results, kept_indices
517
+
518
+
519
class Caffe2FastRCNNOutputsInference:
    """Callable box-head inference hook for the caffe2-compatible export
    path; remembers whether outputs use caffe2 tensor mode."""

    def __init__(self, tensor_mode):
        # whether the output is caffe2 tensor mode
        self.tensor_mode = tensor_mode

    def __call__(self, box_predictor, predictions, proposals):
        # delegate to the functional implementation with the stored flag
        return caffe2_fast_rcnn_outputs_inference(
            self.tensor_mode, box_predictor, predictions, proposals
        )
527
+
528
+
529
def caffe2_mask_rcnn_inference(pred_mask_logits, pred_instances):
    """equivalent to mask_head.mask_rcnn_inference

    In the caffe2 path (all instances are InstancesList) the sigmoid mask
    probabilities are exposed as the "mask_fcn_probs" blob; otherwise the
    standard d2 implementation is used.
    """
    if not all(isinstance(inst, InstancesList) for inst in pred_instances):
        # regular d2 Instances: defer to the standard implementation
        mask_rcnn_inference(pred_mask_logits, pred_instances)
        return
    assert len(pred_instances) == 1
    mask_probs = alias(pred_mask_logits.sigmoid(), "mask_fcn_probs")
    pred_instances[0].set("pred_masks", mask_probs)
538
+
539
+
540
class Caffe2MaskRCNNInference:
    """Stateless callable used as the mask-head inference hook in the
    caffe2-compatible export path."""

    def __call__(self, pred_mask_logits, pred_instances):
        # delegate to the functional implementation
        return caffe2_mask_rcnn_inference(pred_mask_logits, pred_instances)
543
+
544
+
545
def caffe2_keypoint_rcnn_inference(use_heatmap_max_keypoint, pred_keypoint_logits, pred_instances):
    """
    Caffe2-compatible keypoint inference: expose the raw keypoint heatmap as
    the "kps_score" blob and, when ``use_heatmap_max_keypoint`` is set,
    decode it with the HeatmapMaxKeypoint op ("keypoints_out" blob).
    Returns the (unmodified) keypoint logits.
    """
    kps_score = alias(pred_keypoint_logits, "kps_score")
    if all(isinstance(inst, InstancesList) for inst in pred_instances):
        assert len(pred_instances) == 1
        if use_heatmap_max_keypoint:
            device = kps_score.device
            decoded = torch.ops._caffe2.HeatmapMaxKeypoint(
                to_device(kps_score, "cpu"),
                pred_instances[0].pred_boxes.tensor,
                should_output_softmax=True,  # worth making it configurable?
            )
            kps_score = alias(to_device(decoded, device), "keypoints_out")
        pred_instances[0].set("pred_keypoints", kps_score)
    return pred_keypoint_logits
562
+
563
+
564
class Caffe2KeypointRCNNInference:
    """Callable keypoint-head inference hook for the caffe2-compatible
    export path; remembers whether to decode with HeatmapMaxKeypoint."""

    def __init__(self, use_heatmap_max_keypoint):
        self.use_heatmap_max_keypoint = use_heatmap_max_keypoint

    def __call__(self, pred_keypoint_logits, pred_instances):
        # delegate to the functional implementation with the stored flag
        return caffe2_keypoint_rcnn_inference(
            self.use_heatmap_max_keypoint, pred_keypoint_logits, pred_instances
        )
Leffa/3rdparty/detectron2/export/caffe2_export.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ import copy
4
+ import io
5
+ import logging
6
+ import numpy as np
7
+ from typing import List
8
+ import onnx
9
+ import onnx.optimizer
10
+ import torch
11
+ from caffe2.proto import caffe2_pb2
12
+ from caffe2.python import core
13
+ from caffe2.python.onnx.backend import Caffe2Backend
14
+ from tabulate import tabulate
15
+ from termcolor import colored
16
+ from torch.onnx import OperatorExportTypes
17
+
18
+ from .shared import (
19
+ ScopedWS,
20
+ construct_init_net_from_params,
21
+ fuse_alias_placeholder,
22
+ fuse_copy_between_cpu_and_gpu,
23
+ get_params_from_init_net,
24
+ group_norm_replace_aten_with_caffe2,
25
+ infer_device_type,
26
+ remove_dead_end_ops,
27
+ remove_reshape_for_fc,
28
+ save_graph,
29
+ )
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
def export_onnx_model(model, inputs):
    """
    Trace and export a model to onnx format.

    Args:
        model (nn.Module):
        inputs (tuple[args]): the model will be called by `model(*inputs)`

    Returns:
        an onnx model
    """
    assert isinstance(model, torch.nn.Module)

    # ONNX export may flip the training flag if modules disagree; require that
    # every submodule is already in eval mode so the exported graph is consistent.
    def _assert_eval_mode(module):
        assert not module.training

    model.apply(_assert_eval_mode)

    with torch.no_grad():
        with io.BytesIO() as buffer:
            torch.onnx.export(
                model,
                inputs,
                buffer,
                operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
                # verbose=True,  # NOTE: uncomment this for debugging
                # export_params=True,
            )
            return onnx.load_from_string(buffer.getvalue())
68
+
69
+
70
+ def _op_stats(net_def):
71
+ type_count = {}
72
+ for t in [op.type for op in net_def.op]:
73
+ type_count[t] = type_count.get(t, 0) + 1
74
+ type_count_list = sorted(type_count.items(), key=lambda kv: kv[0]) # alphabet
75
+ type_count_list = sorted(type_count_list, key=lambda kv: -kv[1]) # count
76
+ return "\n".join("{:>4}x {}".format(count, name) for name, count in type_count_list)
77
+
78
+
79
def _assign_device_option(
    predict_net: caffe2_pb2.NetDef, init_net: caffe2_pb2.NetDef, tensor_inputs: List[torch.Tensor]
):
    """
    An ONNX-exported network has no notion of device. Infer a device for every
    blob from the input tensors and attach the matching device_option to each
    op so the net is runnable on a GPU runtime.
    """

    def _device_type_of(torch_tensor):
        assert torch_tensor.device.type in ["cpu", "cuda"]
        assert torch_tensor.device.index == 0
        return torch_tensor.device.type

    def _annotate_ops(net_proto, net_ssa, blob_device_types):
        for op, ssa_i in zip(net_proto.op, net_ssa):
            if op.type in ["CopyCPUToGPU", "CopyGPUToCPU"]:
                # Cross-device copy ops always get a CUDA device option.
                op.device_option.CopyFrom(core.DeviceOption(caffe2_pb2.CUDA, 0))
            else:
                devices = [blob_device_types[b] for b in ssa_i[0] + ssa_i[1]]
                # A non-copy op must have all inputs/outputs on one device.
                assert all(d == devices[0] for d in devices)
                if devices[0] == "cuda":
                    op.device_option.CopyFrom(core.DeviceOption(caffe2_pb2.CUDA, 0))

    # predict_net: propagate device types forward from the external inputs.
    predict_net_input_device_types = {
        (name, 0): _device_type_of(tensor)
        for name, tensor in zip(predict_net.external_input, tensor_inputs)
    }
    predict_net_device_types = infer_device_type(
        predict_net, known_status=predict_net_input_device_types, device_name_style="pytorch"
    )
    predict_net_ssa, _ = core.get_ssa(predict_net)
    _annotate_ops(predict_net, predict_net_ssa, predict_net_device_types)

    # init_net: its external outputs must live where predict_net reads them,
    # so propagate backward from those.
    init_net_ssa, versions = core.get_ssa(init_net)
    init_net_output_device_types = {
        (name, versions[name]): predict_net_device_types[(name, 0)]
        for name in init_net.external_output
    }
    init_net_device_types = infer_device_type(
        init_net, known_status=init_net_output_device_types, device_name_style="pytorch"
    )
    _annotate_ops(init_net, init_net_ssa, init_net_device_types)
123
+
124
+
125
def export_caffe2_detection_model(model: torch.nn.Module, tensor_inputs: List[torch.Tensor]):
    """
    Export a caffe2-compatible Detectron2 model to caffe2 format via ONNX.

    Args:
        model: a caffe2-compatible version of detectron2 model, defined in caffe2_modeling.py
        tensor_inputs: a list of tensors that caffe2 model takes as input.

    Returns:
        tuple: the optimized (predict_net, init_net) protobuf pair.
    """
    # Work on a copy so the optimization passes cannot mutate the caller's model.
    model = copy.deepcopy(model)
    assert isinstance(model, torch.nn.Module)
    # Only caffe2_modeling wrappers provide the metadata hook used below.
    assert hasattr(model, "encode_additional_info")

    # Step 1: trace through ONNX.
    logger.info(
        "Exporting a {} model via ONNX ...".format(type(model).__name__)
        + " Some warnings from ONNX are expected and are usually not to worry about."
    )
    onnx_model = export_onnx_model(model, (tensor_inputs,))

    # Step 2: lower the ONNX graph to Caffe2 protobufs.
    init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model)
    op_rows = [[op.type, op.input, op.output] for op in predict_net.op]
    op_table = tabulate(op_rows, headers=["type", "input", "output"], tablefmt="pipe")
    logger.info(
        "ONNX export Done. Exported predict_net (before optimizations):\n" + colored(op_table, "cyan")
    )

    # Step 3: protobuf-level optimization passes (order matters).
    fuse_alias_placeholder(predict_net, init_net)
    if any(t.device.type != "cpu" for t in tensor_inputs):
        # GPU-specific cleanups are only needed when some input is on GPU.
        fuse_copy_between_cpu_and_gpu(predict_net)
        remove_dead_end_ops(init_net)
        _assign_device_option(predict_net, init_net, tensor_inputs)
    params, device_options = get_params_from_init_net(init_net)
    predict_net, params = remove_reshape_for_fc(predict_net, params)
    init_net = construct_init_net_from_params(params, device_options)
    group_norm_replace_aten_with_caffe2(predict_net)

    # Step 4: record the metadata needed to run the pb model inside Detectron2.
    model.encode_additional_info(predict_net, init_net)

    logger.info("Operators used in predict_net: \n{}".format(_op_stats(predict_net)))
    logger.info("Operators used in init_net: \n{}".format(_op_stats(init_net)))

    return predict_net, init_net
169
+
170
+
171
def run_and_save_graph(predict_net, init_net, tensor_inputs, graph_save_path):
    """
    Run the caffe2 model on given inputs, recording the shape and draw the graph.

    Args:
        predict_net/init_net: caffe2 model.
        tensor_inputs: a list of tensors that caffe2 model takes as input.
        graph_save_path: path for saving graph of exported model.

    Returns:
        dict: every workspace blob after the run, keyed by blob name.
    """
    logger.info("Saving graph of ONNX exported model to {} ...".format(graph_save_path))
    save_graph(predict_net, graph_save_path, op_only=False)

    # Run the exported Caffe2 net inside a throwaway workspace.
    logger.info("Running ONNX exported model ...")
    with ScopedWS("__ws_tmp__", True) as ws:
        ws.RunNetOnce(init_net)
        initialized_blobs = set(ws.Blobs())
        # Blobs that init_net did not create must come from tensor_inputs,
        # in the order predict_net declares its external inputs.
        uninitialized = [inp for inp in predict_net.external_input if inp not in initialized_blobs]
        for name, blob in zip(uninitialized, tensor_inputs):
            ws.FeedBlob(name, blob)

        try:
            ws.RunNetOnce(predict_net)
        except RuntimeError as e:
            # Best-effort: keep whatever blobs were produced before the failure.
            logger.warning("Encountered RuntimeError: \n{}".format(str(e)))

        ws_blobs = {b: ws.FetchBlob(b) for b in ws.Blobs()}
        # Only ndarray blobs carry a shape to annotate the graph with.
        blob_sizes = {b: ws_blobs[b].shape for b in ws_blobs if isinstance(ws_blobs[b], np.ndarray)}

        # Re-draw the graph, now annotated with the observed blob shapes.
        logger.info("Saving graph with blob shapes to {} ...".format(graph_save_path))
        save_graph(predict_net, graph_save_path, op_only=False, blob_sizes=blob_sizes)

        return ws_blobs
Leffa/3rdparty/detectron2/export/caffe2_inference.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ import logging
4
+ import numpy as np
5
+ from itertools import count
6
+ import torch
7
+ from caffe2.proto import caffe2_pb2
8
+ from caffe2.python import core
9
+
10
+ from .caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP, convert_batched_inputs_to_c2_format
11
+ from .shared import ScopedWS, get_pb_arg_vali, get_pb_arg_vals, infer_device_type
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ # ===== ref: mobile-vision predictor's 'Caffe2Wrapper' class ======
17
# ===== ref: mobile-vision predictor's 'Caffe2Wrapper' class ======
class ProtobufModel(torch.nn.Module):
    """
    Wrapper of a caffe2's protobuf model.
    It works just like nn.Module, but running caffe2 under the hood.
    Input/Output are tuple[tensor] that match the caffe2 net's external_input/output.
    """

    # Monotonic counter giving every instance its own private workspace name.
    _ids = count(0)

    def __init__(self, predict_net, init_net):
        logger.info(f"Initializing ProtobufModel for: {predict_net.name} ...")
        super().__init__()
        assert isinstance(predict_net, caffe2_pb2.NetDef)
        assert isinstance(init_net, caffe2_pb2.NetDef)
        # create unique temporary workspace for each instance
        self.ws_name = "__tmp_ProtobufModel_{}__".format(next(self._ids))
        self.net = core.Net(predict_net)

        logger.info("Running init_net once to fill the parameters ...")
        with ScopedWS(self.ws_name, is_reset=True, is_cleanup=False) as ws:
            ws.RunNetOnce(init_net)
            # External inputs not produced by init_net must be fed by the
            # caller; create placeholders now so CreateNet succeeds.
            uninitialized_external_input = []
            for blob in self.net.Proto().external_input:
                if blob in ws.Blobs():
                    continue
                uninitialized_external_input.append(blob)
                ws.CreateBlob(blob)
            ws.CreateNet(self.net)

        self._error_msgs = set()  # RuntimeErrors already reported, to avoid log spam
        self._input_blobs = uninitialized_external_input

    def _infer_output_devices(self, inputs):
        """
        Returns:
            list[str]: list of device for each external output
        """

        def _device_type_of(torch_tensor):
            assert torch_tensor.device.type in ["cpu", "cuda"]
            assert torch_tensor.device.index == 0
            return torch_tensor.device.type

        predict_net = self.net.Proto()
        input_device_types = {
            (name, 0): _device_type_of(tensor)
            for name, tensor in zip(self._input_blobs, inputs)
        }
        device_type_map = infer_device_type(
            predict_net, known_status=input_device_types, device_name_style="pytorch"
        )
        _, versions = core.get_ssa(predict_net)
        versioned_outputs = [(name, versions[name]) for name in predict_net.external_output]
        return [device_type_map[outp] for outp in versioned_outputs]

    def forward(self, inputs):
        """
        Args:
            inputs (tuple[torch.Tensor])

        Returns:
            tuple[torch.Tensor]
        """
        assert len(inputs) == len(self._input_blobs), (
            f"Length of inputs ({len(inputs)}) "
            f"doesn't match the required input blobs: {self._input_blobs}"
        )

        with ScopedWS(self.ws_name, is_reset=False, is_cleanup=False) as ws:
            for blob_name, tensor in zip(self._input_blobs, inputs):
                ws.FeedBlob(blob_name, tensor)

            try:
                ws.RunNet(self.net.Proto().name)
            except RuntimeError as e:
                # Log each distinct failure once, then continue with whatever
                # partial results the net managed to produce.
                if str(e) not in self._error_msgs:
                    self._error_msgs.add(str(e))
                    logger.warning("Encountered new RuntimeError: \n{}".format(str(e)))
                logger.warning("Catch the error and use partial results.")

            c2_outputs = [ws.FetchBlob(b) for b in self.net.Proto().external_output]
            # Remove outputs of current run, this is necessary in order to
            # prevent fetching the result from previous run if the model fails
            # in the middle.
            for b in self.net.Proto().external_output:
                # Needs to create uninitialized blob to make the net runable.
                # This is "equivalent" to: ws.RemoveBlob(b) then ws.CreateBlob(b),
                # but there'no such API.
                ws.FeedBlob(b, f"{b}, a C++ native class of type nullptr (uninitialized).")

        # Decide the device each output tensor should land on; all-CPU inputs
        # short-circuit the (more expensive) device inference.
        if any(t.device.type != "cpu" for t in inputs):
            output_devices = self._infer_output_devices(inputs)
        else:
            output_devices = ["cpu" for _ in self.net.Proto().external_output]

        outputs = []
        for name, c2_output, device in zip(
            self.net.Proto().external_output, c2_outputs, output_devices
        ):
            if not isinstance(c2_output, np.ndarray):
                raise RuntimeError(
                    "Invalid output for blob {}, received: {}".format(name, c2_output)
                )
            outputs.append(torch.tensor(c2_output).to(device=device))
        return tuple(outputs)
123
+
124
+
125
class ProtobufDetectionModel(torch.nn.Module):
    """
    A class works just like a pytorch meta arch in terms of inference, but running
    caffe2 model under the hood.
    """

    def __init__(self, predict_net, init_net, *, convert_outputs=None):
        """
        Args:
            predict_net, init_net (core.Net): caffe2 nets
            convert_outputs (callable): a function that converts caffe2
                outputs to the same format of the original pytorch model.
                By default, use the one defined in the caffe2 meta_arch.
        """
        super().__init__()
        self.protobuf_model = ProtobufModel(predict_net, init_net)
        self.size_divisibility = get_pb_arg_vali(predict_net, "size_divisibility", 0)
        self.device = get_pb_arg_vals(predict_net, "device", b"cpu").decode("ascii")

        if convert_outputs is not None:
            self._convert_outputs = convert_outputs
        else:
            # Fall back to the converter declared by the exported meta arch.
            meta_arch = get_pb_arg_vals(predict_net, "meta_architecture", b"GeneralizedRCNN")
            meta_arch = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[meta_arch.decode("ascii")]
            self._convert_outputs = meta_arch.get_outputs_converter(predict_net, init_net)

    def _convert_inputs(self, batched_inputs):
        # currently all models convert inputs in the same way
        return convert_batched_inputs_to_c2_format(
            batched_inputs, self.size_divisibility, self.device
        )

    def forward(self, batched_inputs):
        c2_inputs = self._convert_inputs(batched_inputs)
        c2_results = self.protobuf_model(c2_inputs)
        external_outputs = self.protobuf_model.net.Proto().external_output
        named_results = dict(zip(external_outputs, c2_results))
        return self._convert_outputs(batched_inputs, c2_inputs, named_results)
Leffa/3rdparty/detectron2/export/caffe2_modeling.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ import functools
4
+ import io
5
+ import struct
6
+ import types
7
+ import torch
8
+
9
+ from detectron2.modeling import meta_arch
10
+ from detectron2.modeling.box_regression import Box2BoxTransform
11
+ from detectron2.modeling.roi_heads import keypoint_head
12
+ from detectron2.structures import Boxes, ImageList, Instances, RotatedBoxes
13
+
14
+ from .c10 import Caffe2Compatible
15
+ from .caffe2_patch import ROIHeadsPatcher, patch_generalized_rcnn
16
+ from .shared import (
17
+ alias,
18
+ check_set_pb_arg,
19
+ get_pb_arg_floats,
20
+ get_pb_arg_valf,
21
+ get_pb_arg_vali,
22
+ get_pb_arg_vals,
23
+ mock_torch_nn_functional_interpolate,
24
+ )
25
+
26
+
27
def assemble_rcnn_outputs_by_name(image_sizes, tensor_outputs, force_mask_on=False):
    """
    A function to assemble caffe2 model's outputs (i.e. Dict[str, Tensor])
    to detectron2's format (i.e. list of Instances instance).
    This only works when the model follows the Caffe2 detectron's naming convention.

    Args:
        image_sizes (List[List[int, int]]): [H, W] of every image.
        tensor_outputs (Dict[str, Tensor]): external_output to its tensor.
        force_mask_on (Bool): if true, makes sure there'll be pred_masks even
            if the mask is not found from tensor_outputs (usually due to model crash)
    """
    results = [Instances(image_size) for image_size in image_sizes]

    # Multi-image batches would need "batch_splits" to slice the flat
    # detection tensors per image; that path is not implemented.
    batch_splits = tensor_outputs.get("batch_splits", None)
    if batch_splits:
        raise NotImplementedError()
    assert len(image_sizes) == 1
    result = results[0]

    bbox_nms = tensor_outputs["bbox_nms"]
    score_nms = tensor_outputs["score_nms"]
    class_nms = tensor_outputs["class_nms"]
    # Detection will always success because Conv support 0-batch
    assert bbox_nms is not None
    assert score_nms is not None
    assert class_nms is not None
    # 5 columns means rotated boxes; otherwise axis-aligned boxes.
    if bbox_nms.shape[1] == 5:
        result.pred_boxes = RotatedBoxes(bbox_nms)
    else:
        result.pred_boxes = Boxes(bbox_nms)
    result.scores = score_nms
    result.pred_classes = class_nms.to(torch.int64)

    mask_fcn_probs = tensor_outputs.get("mask_fcn_probs", None)
    if mask_fcn_probs is not None:
        # Pick each instance's own-class mask channel and add the channel dim.
        num_masks = mask_fcn_probs.shape[0]
        class_pred = result.pred_classes
        indices = torch.arange(num_masks, device=class_pred.device)
        result.pred_masks = mask_fcn_probs[indices, class_pred][:, None]
    elif force_mask_on:
        # NOTE: there's no way to know the height/width of mask here, it won't be
        # used anyway when batch size is 0, so just set them to 0.
        result.pred_masks = torch.zeros([0, 1, 0, 0], dtype=torch.uint8)

    keypoints_out = tensor_outputs.get("keypoints_out", None)
    kps_score = tensor_outputs.get("kps_score", None)
    if keypoints_out is not None:
        # keypoints_out: [N, 4, #kypoints], where 4 is in order of (x, y, score, prob)
        # NOTE: it's possible that prob is not calculated if "should_output_softmax"
        # is set to False in HeatmapMaxKeypoint, so just using raw score, seems
        # it doesn't affect mAP. TODO: check more carefully.
        result.pred_keypoints = keypoints_out.transpose(1, 2)[:, :, [0, 1, 2]]
    elif kps_score is not None:
        # keypoint heatmap to sparse data structure
        keypoint_head.keypoint_rcnn_inference(kps_score, [result])

    return results
93
+
94
+
95
+ def _cast_to_f32(f64):
96
+ return struct.unpack("f", struct.pack("f", f64))[0]
97
+
98
+
99
def set_caffe2_compatible_tensor_mode(model, enable=True):
    """Toggle `tensor_mode` on every Caffe2Compatible submodule of `model`."""

    def _set_mode(submodule):
        if isinstance(submodule, Caffe2Compatible):
            submodule.tensor_mode = enable

    model.apply(_set_mode)
105
+
106
+
107
def convert_batched_inputs_to_c2_format(batched_inputs, size_divisibility, device):
    """
    See get_caffe2_inputs() below.
    """
    assert all(isinstance(x, dict) for x in batched_inputs)
    assert all(x["image"].dim() == 3 for x in batched_inputs)

    # Pad and batch the CHW images into one NCHW tensor.
    images = ImageList.from_tensors([x["image"] for x in batched_inputs], size_divisibility)

    im_info = []
    for input_per_image, image_size in zip(batched_inputs, images.image_sizes):
        target_height = input_per_image.get("height", image_size[0])
        target_width = input_per_image.get("width", image_size[1])  # noqa
        # NOTE: The scale inside im_info is kept as convention and for providing
        # post-processing information if further processing is needed. For
        # current Caffe2 model definitions that don't include post-processing inside
        # the model, this number is not used.
        # NOTE: There can be a slight difference between width and height
        # scales, using a single number can results in numerical difference
        # compared with D2's post-processing.
        scale = target_height / image_size[0]
        im_info.append([image_size[0], image_size[1], scale])
    im_info = torch.Tensor(im_info)

    return images.tensor.to(device), im_info.to(device)
133
+
134
+
135
class Caffe2MetaArch(Caffe2Compatible, torch.nn.Module):
    """
    Base class for caffe2-compatible implementation of a meta architecture.
    The forward is traceable and its traced graph can be converted to caffe2
    graph through ONNX.
    """

    def __init__(self, cfg, torch_model, enable_tensor_mode=True):
        """
        Args:
            cfg (CfgNode):
            torch_model (nn.Module): the detectron2 model (meta_arch) to be
                converted.
        """
        super().__init__()
        self._wrapped_model = torch_model
        # Export must always happen with the model in eval mode.
        self.eval()
        set_caffe2_compatible_tensor_mode(self, enable_tensor_mode)

    def get_caffe2_inputs(self, batched_inputs):
        """
        Convert pytorch-style structured inputs to caffe2-style inputs that
        are tuples of tensors.

        Args:
            batched_inputs (list[dict]): inputs to a detectron2 model
                in its standard format. Each dict has "image" (CHW tensor), and optionally
                "height" and "width".

        Returns:
            tuple[Tensor]:
                tuple of tensors that will be the inputs to the
                :meth:`forward` method. For existing models, the first
                is an NCHW tensor (padded and batched); the second is
                a im_info Nx3 tensor, where the rows are
                (height, width, unused legacy parameter)
        """
        return convert_batched_inputs_to_c2_format(
            batched_inputs,
            self._wrapped_model.backbone.size_divisibility,
            self._wrapped_model.device,
        )

    def encode_additional_info(self, predict_net, init_net):
        """
        Save extra metadata that will be used by inference in the output protobuf.
        """
        pass

    def forward(self, inputs):
        """
        Run the forward in caffe2-style. It has to use caffe2-compatible ops
        and the method will be used for tracing.

        Args:
            inputs (tuple[Tensor]): inputs defined by :meth:`get_caffe2_input`.
                They will be the inputs of the converted caffe2 graph.

        Returns:
            tuple[Tensor]: output tensors. They will be the outputs of the
                converted caffe2 graph.
        """
        raise NotImplementedError

    def _caffe2_preprocess_image(self, inputs):
        """
        Caffe2 implementation of preprocess_image, which is called inside each MetaArch's forward.
        It normalizes the input images, and the final caffe2 graph assumes the
        inputs have been batched already.
        """
        data, im_info = inputs
        data = alias(data, "data")
        im_info = alias(im_info, "im_info")
        mean, std = self._wrapped_model.pixel_mean, self._wrapped_model.pixel_std
        normalized_data = alias((data - mean) / std, "normalized_data")

        # Pack (data, im_info) into ImageList which is recognized by self.inference.
        return ImageList(tensor=normalized_data, image_sizes=im_info)

    @staticmethod
    def get_outputs_converter(predict_net, init_net):
        """
        Creates a function that converts outputs of the caffe2 model to
        detectron2's standard format.
        The function uses information in `predict_net` and `init_net` that are
        available at inference time. Therefore the function logic can be used in inference.

        The returned function has the following signature:

            def convert(batched_inputs, c2_inputs, c2_results) -> detectron2_outputs

        Where

            * batched_inputs (list[dict]): the original input format of the meta arch
            * c2_inputs (tuple[Tensor]): the caffe2 inputs.
            * c2_results (dict[str, Tensor]): the caffe2 output format,
              corresponding to the outputs of the :meth:`forward` function.
            * detectron2_outputs: the original output format of the meta arch.

        This function can be used to compare the outputs of the original meta arch and
        the converted caffe2 graph.

        Returns:
            callable: a callable of the above signature.
        """
        raise NotImplementedError
243
+
244
+
245
class Caffe2GeneralizedRCNN(Caffe2MetaArch):
    """Caffe2-compatible wrapper around meta_arch.GeneralizedRCNN."""

    def __init__(self, cfg, torch_model, enable_tensor_mode=True):
        assert isinstance(torch_model, meta_arch.GeneralizedRCNN)
        torch_model = patch_generalized_rcnn(torch_model)
        super().__init__(cfg, torch_model, enable_tensor_mode)

        # The config key is optional; default to heatmap-only keypoint export.
        try:
            use_heatmap_max_keypoint = cfg.EXPORT_CAFFE2.USE_HEATMAP_MAX_KEYPOINT
        except AttributeError:
            use_heatmap_max_keypoint = False
        self.roi_heads_patcher = ROIHeadsPatcher(
            self._wrapped_model.roi_heads, use_heatmap_max_keypoint
        )
        if self.tensor_mode:
            self.roi_heads_patcher.patch_roi_heads()

    def encode_additional_info(self, predict_net, init_net):
        # Inference needs these args to re-pack inputs and pick the converter.
        size_divisibility = self._wrapped_model.backbone.size_divisibility
        check_set_pb_arg(predict_net, "size_divisibility", "i", size_divisibility)
        check_set_pb_arg(
            predict_net, "device", "s", str.encode(str(self._wrapped_model.device), "ascii")
        )
        check_set_pb_arg(predict_net, "meta_architecture", "s", b"GeneralizedRCNN")

    @mock_torch_nn_functional_interpolate()
    def forward(self, inputs):
        if not self.tensor_mode:
            return self._wrapped_model.inference(inputs)
        images = self._caffe2_preprocess_image(inputs)
        features = self._wrapped_model.backbone(images.tensor)
        proposals, _ = self._wrapped_model.proposal_generator(images, features)
        detector_results, _ = self._wrapped_model.roi_heads(images, features, proposals)
        # Tracing needs a flat tuple of tensors as graph outputs.
        return tuple(detector_results[0].flatten())

    @staticmethod
    def get_outputs_converter(predict_net, init_net):
        def f(batched_inputs, c2_inputs, c2_results):
            _, im_info = c2_inputs
            image_sizes = [[int(im[0]), int(im[1])] for im in im_info]
            results = assemble_rcnn_outputs_by_name(image_sizes, c2_results)
            return meta_arch.GeneralizedRCNN._postprocess(results, batched_inputs, image_sizes)

        return f
288
+
289
+
290
class Caffe2RetinaNet(Caffe2MetaArch):
    """Caffe2-compatible wrapper around meta_arch.RetinaNet."""

    def __init__(self, cfg, torch_model):
        assert isinstance(torch_model, meta_arch.RetinaNet)
        super().__init__(cfg, torch_model)

    @mock_torch_nn_functional_interpolate()
    def forward(self, inputs):
        assert self.tensor_mode
        images = self._caffe2_preprocess_image(inputs)

        # explicitly return the images sizes to avoid removing "im_info" by ONNX
        # since it's not used in the forward path
        return_tensors = [images.image_sizes]

        features = self._wrapped_model.backbone(images.tensor)
        features = [features[f] for f in self._wrapped_model.head_in_features]
        for i, feature_i in enumerate(features):
            features[i] = alias(feature_i, "feature_{}".format(i), is_backward=True)
            return_tensors.append(features[i])

        pred_logits, pred_anchor_deltas = self._wrapped_model.head(features)
        for i, (box_cls_i, box_delta_i) in enumerate(zip(pred_logits, pred_anchor_deltas)):
            return_tensors.append(alias(box_cls_i, "box_cls_{}".format(i)))
            return_tensors.append(alias(box_delta_i, "box_delta_{}".format(i)))

        return tuple(return_tensors)

    def encode_additional_info(self, predict_net, init_net):
        # Common args shared with the other meta archs.
        size_divisibility = self._wrapped_model.backbone.size_divisibility
        check_set_pb_arg(predict_net, "size_divisibility", "i", size_divisibility)
        check_set_pb_arg(
            predict_net, "device", "s", str.encode(str(self._wrapped_model.device), "ascii")
        )
        check_set_pb_arg(predict_net, "meta_architecture", "s", b"RetinaNet")

        # Inference parameters:
        check_set_pb_arg(
            predict_net, "score_threshold", "f", _cast_to_f32(self._wrapped_model.test_score_thresh)
        )
        check_set_pb_arg(
            predict_net, "topk_candidates", "i", self._wrapped_model.test_topk_candidates
        )
        check_set_pb_arg(
            predict_net, "nms_threshold", "f", _cast_to_f32(self._wrapped_model.test_nms_thresh)
        )
        check_set_pb_arg(
            predict_net,
            "max_detections_per_image",
            "i",
            self._wrapped_model.max_detections_per_image,
        )

        check_set_pb_arg(
            predict_net,
            "bbox_reg_weights",
            "floats",
            [_cast_to_f32(w) for w in self._wrapped_model.box2box_transform.weights],
        )
        self._encode_anchor_generator_cfg(predict_net)

    def _encode_anchor_generator_cfg(self, predict_net):
        # serialize anchor_generator for future use
        serialized_anchor_generator = io.BytesIO()
        torch.save(self._wrapped_model.anchor_generator, serialized_anchor_generator)
        # Ideally we can put anchor generating inside the model, then we don't
        # need to store this information.
        serialized_bytes = serialized_anchor_generator.getvalue()
        check_set_pb_arg(predict_net, "serialized_anchor_generator", "s", serialized_bytes)

    @staticmethod
    def get_outputs_converter(predict_net, init_net):
        # Build a lightweight stand-in object carrying everything that the
        # borrowed RetinaNet inference methods read from `self`.
        self = types.SimpleNamespace()
        serialized_anchor_generator = io.BytesIO(
            get_pb_arg_vals(predict_net, "serialized_anchor_generator", None)
        )
        self.anchor_generator = torch.load(serialized_anchor_generator)
        bbox_reg_weights = get_pb_arg_floats(predict_net, "bbox_reg_weights", None)
        self.box2box_transform = Box2BoxTransform(weights=tuple(bbox_reg_weights))
        self.test_score_thresh = get_pb_arg_valf(predict_net, "score_threshold", None)
        self.test_topk_candidates = get_pb_arg_vali(predict_net, "topk_candidates", None)
        self.test_nms_thresh = get_pb_arg_valf(predict_net, "nms_threshold", None)
        self.max_detections_per_image = get_pb_arg_vali(
            predict_net, "max_detections_per_image", None
        )

        # hack to reuse inference code from RetinaNet
        for meth in [
            "forward_inference",
            "inference_single_image",
            "_transpose_dense_predictions",
            "_decode_multi_level_predictions",
            "_decode_per_level_predictions",
        ]:
            setattr(self, meth, functools.partial(getattr(meta_arch.RetinaNet, meth), self))

        def f(batched_inputs, c2_inputs, c2_results):
            _, im_info = c2_inputs
            image_sizes = [[int(im[0]), int(im[1])] for im in im_info]
            # The inference methods only read the image sizes; the tensor
            # contents are irrelevant, so random data is fine.
            dummy_images = ImageList(
                torch.randn(
                    (
                        len(im_info),
                        3,
                    )
                    + tuple(image_sizes[0])
                ),
                image_sizes,
            )

            num_features = len([x for x in c2_results.keys() if x.startswith("box_cls_")])
            pred_logits = [c2_results["box_cls_{}".format(i)] for i in range(num_features)]
            pred_anchor_deltas = [c2_results["box_delta_{}".format(i)] for i in range(num_features)]

            # For each feature level, feature should have the same batch size and
            # spatial dimension as the box_cls and box_delta.
            dummy_features = [x.clone()[:, 0:0, :, :] for x in pred_logits]
            # self.num_classes can be inferred from the per-level channel counts.
            self.num_classes = pred_logits[0].shape[1] // (pred_anchor_deltas[0].shape[1] // 4)

            results = self.forward_inference(
                dummy_images, dummy_features, [pred_logits, pred_anchor_deltas]
            )
            return meta_arch.GeneralizedRCNN._postprocess(results, batched_inputs, image_sizes)

        return f
415
+
416
+
417
# Maps the "meta_architecture" protobuf arg value (written by
# encode_additional_info) to the caffe2-compatible wrapper class.
META_ARCH_CAFFE2_EXPORT_TYPE_MAP = {
    "GeneralizedRCNN": Caffe2GeneralizedRCNN,
    "RetinaNet": Caffe2RetinaNet,
}
Leffa/3rdparty/detectron2/export/caffe2_patch.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ import contextlib
4
+ from unittest import mock
5
+ import torch
6
+
7
+ from detectron2.modeling import poolers
8
+ from detectron2.modeling.proposal_generator import rpn
9
+ from detectron2.modeling.roi_heads import keypoint_head, mask_head
10
+ from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers
11
+
12
+ from .c10 import (
13
+ Caffe2Compatible,
14
+ Caffe2FastRCNNOutputsInference,
15
+ Caffe2KeypointRCNNInference,
16
+ Caffe2MaskRCNNInference,
17
+ Caffe2ROIPooler,
18
+ Caffe2RPN,
19
+ caffe2_fast_rcnn_outputs_inference,
20
+ caffe2_keypoint_rcnn_inference,
21
+ caffe2_mask_rcnn_inference,
22
+ )
23
+
24
+
25
class GenericMixin:
    """Marker base class: subclasses are treated as mixins (combined with the
    wrapped module's original class) rather than as full replacement classes."""

    pass
27
+
28
+
29
class Caffe2CompatibleConverter:
    """
    A GenericUpdater which implements the `create_from` interface, by modifying
    module object and assign it with another class replaceCls.
    """

    def __init__(self, replaceCls):
        self.replaceCls = replaceCls

    def create_from(self, module):
        """Rebind ``module``'s class to ``self.replaceCls`` in place and return it."""
        assert isinstance(module, torch.nn.Module)
        if not issubclass(self.replaceCls, GenericMixin):
            # replaceCls is a complete class; this allows an arbitrary class swap
            module.__class__ = self.replaceCls
        else:
            # replaceCls acts as a mixin: synthesize a combined class on the fly
            # that inherits from both the mixin and the module's original class.
            mixed_name = "{}MixedWith{}".format(
                self.replaceCls.__name__, module.__class__.__name__
            )
            module.__class__ = type(
                mixed_name,
                (self.replaceCls, module.__class__),
                {},  # {"new_method": lambda self: ...},
            )

        # Newly Caffe2Compatible modules start out in non-tensor mode.
        if isinstance(module, Caffe2Compatible):
            module.tensor_mode = False

        return module
58
+
59
+
60
def patch(model, target, updater, *args, **kwargs):
    """
    recursively (post-order) update all modules with the target type and its
    subclasses, make a initialization/composition/inheritance/... via the
    updater.create_from.
    """
    # Post-order: rewrite children first, then (possibly) this module itself.
    for child_name, child in model.named_children():
        model._modules[child_name] = patch(child, target, updater, *args, **kwargs)
    if not isinstance(model, target):
        return model
    return updater.create_from(model, *args, **kwargs)
71
+
72
+
73
def patch_generalized_rcnn(model):
    """Swap a GeneralizedRCNN's RPN and ROIPooler submodules (in place) for
    their Caffe2-compatible counterparts; returns the same model object."""
    converter = Caffe2CompatibleConverter
    model = patch(model, rpn.RPN, converter(Caffe2RPN))
    model = patch(model, poolers.ROIPooler, converter(Caffe2ROIPooler))
    return model
79
+
80
+
81
@contextlib.contextmanager
def mock_fastrcnn_outputs_inference(
    tensor_mode, check=True, box_predictor_type=FastRCNNOutputLayers
):
    """Within the context, replace ``inference`` on the box-predictor class
    with the Caffe2 implementation; if ``check``, assert it was called."""
    # autospec keeps the original method signature on the mock.
    with mock.patch.object(
        box_predictor_type,
        "inference",
        autospec=True,
        side_effect=Caffe2FastRCNNOutputsInference(tensor_mode),
    ) as mocked_func:
        yield
    if check:
        assert mocked_func.call_count > 0
94
+
95
+
96
@contextlib.contextmanager
def mock_mask_rcnn_inference(tensor_mode, patched_module, check=True):
    """Within the context, replace ``mask_rcnn_inference`` in ``patched_module``
    with the Caffe2 implementation; if ``check``, assert it was called."""
    # NOTE(review): tensor_mode is accepted for signature symmetry with the
    # sibling mock helpers but is unused here.
    with mock.patch(
        "{}.mask_rcnn_inference".format(patched_module), side_effect=Caffe2MaskRCNNInference()
    ) as mocked_func:
        yield
    if check:
        assert mocked_func.call_count > 0
104
+
105
+
106
@contextlib.contextmanager
def mock_keypoint_rcnn_inference(tensor_mode, patched_module, use_heatmap_max_keypoint, check=True):
    """Within the context, replace ``keypoint_rcnn_inference`` in
    ``patched_module`` with the Caffe2 implementation; if ``check``, assert it
    was called."""
    # NOTE(review): tensor_mode is accepted for signature symmetry with the
    # sibling mock helpers but is unused here.
    with mock.patch(
        "{}.keypoint_rcnn_inference".format(patched_module),
        side_effect=Caffe2KeypointRCNNInference(use_heatmap_max_keypoint),
    ) as mocked_func:
        yield
    if check:
        assert mocked_func.call_count > 0
115
+
116
+
117
class ROIHeadsPatcher:
    """Swaps the inference entry points of a ROIHeads instance (box, keypoint,
    mask) for their Caffe2 counterparts — temporarily via ``mock_roi_heads`` or
    persistently via ``patch_roi_heads``/``unpatch_roi_heads``."""

    def __init__(self, heads, use_heatmap_max_keypoint):
        # heads: the ROIHeads module whose inference functions get swapped.
        # use_heatmap_max_keypoint: forwarded to the Caffe2 keypoint inference.
        self.heads = heads
        self.use_heatmap_max_keypoint = use_heatmap_max_keypoint
        # Originals saved by patch_roi_heads so unpatch_roi_heads can restore.
        self.previous_patched = {}

    @contextlib.contextmanager
    def mock_roi_heads(self, tensor_mode=True):
        """
        Patching several inference functions inside ROIHeads and its subclasses

        Args:
            tensor_mode (bool): whether the inputs/outputs are caffe2's tensor
                format or not. Default to True.
        """
        # NOTE: this requries the `keypoint_rcnn_inference` and `mask_rcnn_inference`
        # are called inside the same file as BaseXxxHead due to using mock.patch.
        kpt_heads_mod = keypoint_head.BaseKeypointRCNNHead.__module__
        mask_head_mod = mask_head.BaseMaskRCNNHead.__module__

        # Box-predictor inference is always patched; keypoint/mask only when
        # the corresponding head is enabled on this ROIHeads instance.
        mock_ctx_managers = [
            mock_fastrcnn_outputs_inference(
                tensor_mode=tensor_mode,
                check=True,
                box_predictor_type=type(self.heads.box_predictor),
            )
        ]
        if getattr(self.heads, "keypoint_on", False):
            mock_ctx_managers += [
                mock_keypoint_rcnn_inference(
                    tensor_mode, kpt_heads_mod, self.use_heatmap_max_keypoint
                )
            ]
        if getattr(self.heads, "mask_on", False):
            mock_ctx_managers += [mock_mask_rcnn_inference(tensor_mode, mask_head_mod)]

        # Enter all mocks at once; ExitStack unwinds them in reverse order.
        with contextlib.ExitStack() as stack:  # python 3.3+
            for mgr in mock_ctx_managers:
                stack.enter_context(mgr)
            yield

    def patch_roi_heads(self, tensor_mode=True):
        """Persistently install the Caffe2 inference functions, saving the
        originals for ``unpatch_roi_heads``."""
        # NOTE(review): tensor_mode is unused here — the box-predictor patch
        # below hard-codes True; confirm whether that is intentional.
        self.previous_patched["box_predictor"] = self.heads.box_predictor.inference
        self.previous_patched["keypoint_rcnn"] = keypoint_head.keypoint_rcnn_inference
        self.previous_patched["mask_rcnn"] = mask_head.mask_rcnn_inference

        def patched_fastrcnn_outputs_inference(predictions, proposal):
            return caffe2_fast_rcnn_outputs_inference(
                True, self.heads.box_predictor, predictions, proposal
            )

        self.heads.box_predictor.inference = patched_fastrcnn_outputs_inference

        if getattr(self.heads, "keypoint_on", False):

            def patched_keypoint_rcnn_inference(pred_keypoint_logits, pred_instances):
                return caffe2_keypoint_rcnn_inference(
                    self.use_heatmap_max_keypoint, pred_keypoint_logits, pred_instances
                )

            keypoint_head.keypoint_rcnn_inference = patched_keypoint_rcnn_inference

        if getattr(self.heads, "mask_on", False):

            def patched_mask_rcnn_inference(pred_mask_logits, pred_instances):
                return caffe2_mask_rcnn_inference(pred_mask_logits, pred_instances)

            mask_head.mask_rcnn_inference = patched_mask_rcnn_inference

    def unpatch_roi_heads(self):
        """Restore the inference functions saved by ``patch_roi_heads``."""
        self.heads.box_predictor.inference = self.previous_patched["box_predictor"]
        keypoint_head.keypoint_rcnn_inference = self.previous_patched["keypoint_rcnn"]
        mask_head.mask_rcnn_inference = self.previous_patched["mask_rcnn"]
Leffa/3rdparty/detectron2/export/flatten.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import collections
3
+ from dataclasses import dataclass
4
+ from typing import Callable, List, Optional, Tuple
5
+ import torch
6
+ from torch import nn
7
+
8
+ from detectron2.structures import Boxes, Instances, ROIMasks
9
+ from detectron2.utils.registry import _convert_target_to_string, locate
10
+
11
+ from .torchscript_patch import patch_builtin_len
12
+
13
+
14
@dataclass
class Schema:
    """
    A Schema defines how to flatten a possibly hierarchical object into tuple of
    primitive objects, so it can be used as inputs/outputs of PyTorch's tracing.

    PyTorch does not support tracing a function that produces rich output
    structures (e.g. dict, Instances, Boxes). To trace such a function, we
    flatten the rich object into tuple of tensors, and return this tuple of tensors
    instead. Meanwhile, we also need to know how to "rebuild" the original object
    from the flattened results, so we can evaluate the flattened results.
    A Schema defines how to flatten an object, and while flattening it, it records
    necessary schemas so that the object can be rebuilt using the flattened outputs.

    The flattened object and the schema object is returned by ``.flatten`` classmethod.
    Then the original object can be rebuilt with the ``__call__`` method of schema.

    A Schema is a dataclass that can be serialized easily.
    """

    # inspired by FetchMapper in tensorflow/python/client/session.py

    @classmethod
    def flatten(cls, obj):
        """Flatten ``obj`` into ``(tuple_of_values, schema)``. Subclass hook."""
        raise NotImplementedError

    def __call__(self, values):
        """Rebuild the original object from flattened ``values``. Subclass hook."""
        raise NotImplementedError

    @staticmethod
    def _concat(values):
        """Concatenate a list of flattened tuples into one tuple.

        Returns the concatenation plus the per-piece lengths, so ``_split``
        can undo it.
        """
        ret = ()
        sizes = []
        for v in values:
            assert isinstance(v, tuple), "Flattened results must be a tuple"
            ret = ret + v
            sizes.append(len(v))
        return ret, sizes

    @staticmethod
    def _split(values, sizes):
        """Inverse of ``_concat``: cut ``values`` into consecutive chunks of
        the given ``sizes``.

        Uses a running offset so the split is O(len(sizes)); the original
        re-summed the prefix for every chunk, which was O(len(sizes)^2).
        """
        if len(sizes):
            expected_len = sum(sizes)
            assert (
                len(values) == expected_len
            ), f"Values has length {len(values)} but expect length {expected_len}."
        ret = []
        begin = 0
        for size in sizes:
            ret.append(values[begin : begin + size])
            begin += size
        return ret
+
66
+
67
@dataclass
class ListSchema(Schema):
    """Schema for a list: each element is flattened by its own sub-schema."""

    schemas: List[Schema]  # the schemas that define how to flatten each element in the list
    sizes: List[int]  # the flattened length of each element

    def __call__(self, values):
        chunks = self._split(values, self.sizes)
        if len(chunks) != len(self.schemas):
            raise ValueError(
                f"Values has length {len(chunks)} but schemas " f"has length {len(self.schemas)}!"
            )
        return [rebuild(chunk) for rebuild, chunk in zip(self.schemas, chunks)]

    @classmethod
    def flatten(cls, obj):
        flattened = [flatten_to_tuple(element) for element in obj]
        values, sizes = cls._concat([pair[0] for pair in flattened])
        return values, cls([pair[1] for pair in flattened], sizes)
86
+
87
+
88
@dataclass
class TupleSchema(ListSchema):
    """Same as ListSchema, but rebuilds a tuple instead of a list."""

    def __call__(self, values):
        rebuilt = super().__call__(values)
        return tuple(rebuilt)
92
+
93
+
94
@dataclass
class IdentitySchema(Schema):
    """Schema for leaf objects: the object itself is the single flattened value."""

    @classmethod
    def flatten(cls, obj):
        return (obj,), cls()

    def __call__(self, values):
        return values[0]
102
+
103
+
104
@dataclass
class DictSchema(ListSchema):
    """Schema for a str-keyed dict: values are flattened in sorted-key order."""

    keys: List[str]

    def __call__(self, values):
        rebuilt = super().__call__(values)
        return dict(zip(self.keys, rebuilt))

    @classmethod
    def flatten(cls, obj):
        if any(not isinstance(key, str) for key in obj.keys()):
            raise KeyError("Only support flattening dictionaries if keys are str.")
        # Sorting makes the flattened order deterministic across runs.
        keys = sorted(obj.keys())
        ret, schema = ListSchema.flatten([obj[key] for key in keys])
        return ret, cls(schema.schemas, schema.sizes, keys)
121
+
122
+
123
@dataclass
class InstancesSchema(DictSchema):
    """Schema for `Instances`: flattened as its fields (via DictSchema) with the
    image size appended as a trailing tensor."""

    def __call__(self, values):
        # Last flattened value is the image size; the rest rebuild the fields.
        image_size, fields = values[-1], values[:-1]
        fields = super().__call__(fields)
        return Instances(image_size, **fields)

    @classmethod
    def flatten(cls, obj):
        ret, schema = super().flatten(obj.get_fields())
        size = obj.image_size
        # Image size may be a plain tuple; convert so the flattened result is
        # all tensors (required for tracing).
        if not isinstance(size, torch.Tensor):
            size = torch.tensor(size)
        return ret + (size,), schema
137
+
138
+
139
@dataclass
class TensorWrapSchema(Schema):
    """
    For classes that are simple wrapper of tensors, e.g.
    Boxes, RotatedBoxes, BitMasks
    """

    # Fully-qualified name of the wrapper class, so the schema stays a
    # serializable dataclass (the class object itself is re-located on rebuild).
    class_name: str

    def __call__(self, values):
        return locate(self.class_name)(values[0])

    @classmethod
    def flatten(cls, obj):
        return (obj.tensor,), cls(_convert_target_to_string(type(obj)))
154
+
155
+
156
+ # if more custom structures needed in the future, can allow
157
+ # passing in extra schemas for custom types
158
def flatten_to_tuple(obj):
    """
    Flatten an object so it can be used for PyTorch tracing.
    Also returns how to rebuild the original object from the flattened outputs.

    Returns:
        res (tuple): the flattened results that can be used as tracing outputs
        schema: an object with a ``__call__`` method such that ``schema(res) == obj``.
            It is a pure dataclass that can be serialized.
    """
    # Ordered dispatch table: first matching type wins; anything unmatched is
    # treated as a leaf via IdentitySchema.
    dispatch = (
        ((str, bytes), IdentitySchema),
        (list, ListSchema),
        (tuple, TupleSchema),
        (collections.abc.Mapping, DictSchema),
        (Instances, InstancesSchema),
        ((Boxes, ROIMasks), TensorWrapSchema),
    )
    chosen = IdentitySchema
    for klass, schema_cls in dispatch:
        if isinstance(obj, klass):
            chosen = schema_cls
            break
    return chosen.flatten(obj)
184
+
185
+
186
class TracingAdapter(nn.Module):
    """
    A model may take rich input/output format (e.g. dict or custom classes),
    but `torch.jit.trace` requires tuple of tensors as input/output.
    This adapter flattens input/output format of a model so it becomes traceable.

    It also records the necessary schema to rebuild model's inputs/outputs from flattened
    inputs/outputs.

    Example:
    ::
        outputs = model(inputs)  # inputs/outputs may be rich structure
        adapter = TracingAdapter(model, inputs)

        # can now trace the model, with adapter.flattened_inputs, or another
        # tuple of tensors with the same length and meaning
        traced = torch.jit.trace(adapter, adapter.flattened_inputs)

        # traced model can only produce flattened outputs (tuple of tensors)
        flattened_outputs = traced(*adapter.flattened_inputs)
        # adapter knows the schema to convert it back (new_outputs == outputs)
        new_outputs = adapter.outputs_schema(flattened_outputs)
    """

    flattened_inputs: Tuple[torch.Tensor] = None
    """
    Flattened version of inputs given to this class's constructor.
    """

    inputs_schema: Schema = None
    """
    Schema of the inputs given to this class's constructor.
    """

    outputs_schema: Schema = None
    """
    Schema of the output produced by calling the given model with inputs.
    """

    def __init__(
        self,
        model: nn.Module,
        inputs,
        inference_func: Optional[Callable] = None,
        allow_non_tensor: bool = False,
    ):
        """
        Args:
            model: an nn.Module
            inputs: An input argument or a tuple of input arguments used to call model.
                After flattening, it has to only consist of tensors.
            inference_func: a callable that takes (model, *inputs), calls the
                model with inputs, and return outputs. By default it
                is ``lambda model, *inputs: model(*inputs)``. Can be override
                if you need to call the model differently.
            allow_non_tensor: allow inputs/outputs to contain non-tensor objects.
                This option will filter out non-tensor objects to make the
                model traceable, but ``inputs_schema``/``outputs_schema`` cannot be
                used anymore because inputs/outputs cannot be rebuilt from pure tensors.
                This is useful when you're only interested in the single trace of
                execution (e.g. for flop count), but not interested in
                generalizing the traced graph to new inputs.
        """
        super().__init__()
        # Unwrap DDP/DataParallel so tracing sees the underlying module.
        if isinstance(model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)):
            model = model.module
        self.model = model
        if not isinstance(inputs, tuple):
            inputs = (inputs,)
        self.inputs = inputs
        self.allow_non_tensor = allow_non_tensor

        if inference_func is None:
            inference_func = lambda model, *inputs: model(*inputs)  # noqa
        self.inference_func = inference_func

        self.flattened_inputs, self.inputs_schema = flatten_to_tuple(inputs)

        if all(isinstance(x, torch.Tensor) for x in self.flattened_inputs):
            return
        # Some flattened inputs are not tensors: either drop them (and give up
        # the schema) or refuse, depending on allow_non_tensor.
        if self.allow_non_tensor:
            self.flattened_inputs = tuple(
                [x for x in self.flattened_inputs if isinstance(x, torch.Tensor)]
            )
            self.inputs_schema = None
        else:
            for input in self.flattened_inputs:
                if not isinstance(input, torch.Tensor):
                    raise ValueError(
                        "Inputs for tracing must only contain tensors. "
                        f"Got a {type(input)} instead."
                    )

    def forward(self, *args: torch.Tensor):
        """Rebuild rich inputs from flattened ``args``, run the model, and
        return the flattened (tensor-only) outputs; records ``outputs_schema``
        as a side effect on the first call."""
        with torch.no_grad(), patch_builtin_len():
            if self.inputs_schema is not None:
                inputs_orig_format = self.inputs_schema(args)
            else:
                # No schema (allow_non_tensor dropped some inputs): only the
                # exact original flattened inputs can be replayed.
                if len(args) != len(self.flattened_inputs) or any(
                    x is not y for x, y in zip(args, self.flattened_inputs)
                ):
                    raise ValueError(
                        "TracingAdapter does not contain valid inputs_schema."
                        " So it cannot generalize to other inputs and must be"
                        " traced with `.flattened_inputs`."
                    )
                inputs_orig_format = self.inputs

            outputs = self.inference_func(self.model, *inputs_orig_format)
            flattened_outputs, schema = flatten_to_tuple(outputs)

            flattened_output_tensors = tuple(
                [x for x in flattened_outputs if isinstance(x, torch.Tensor)]
            )
            if len(flattened_output_tensors) < len(flattened_outputs):
                if self.allow_non_tensor:
                    flattened_outputs = flattened_output_tensors
                    self.outputs_schema = None
                else:
                    raise ValueError(
                        "Model cannot be traced because some model outputs "
                        "cannot flatten to tensors."
                    )
            else:  # schema is valid
                if self.outputs_schema is None:
                    self.outputs_schema = schema
                else:
                    # Tracing runs forward more than once; the output structure
                    # must be identical every time for the schema to be usable.
                    assert self.outputs_schema == schema, (
                        "Model should always return outputs with the same "
                        "structure so it can be traced!"
                    )
            return flattened_outputs

    def _create_wrapper(self, traced_model):
        """
        Return a function that has an input/output interface the same as the
        original model, but it calls the given traced model under the hood.
        """

        def forward(*args):
            flattened_inputs, _ = flatten_to_tuple(args)
            flattened_outputs = traced_model(*flattened_inputs)
            return self.outputs_schema(flattened_outputs)

        return forward
Leffa/3rdparty/detectron2/export/shared.py ADDED
@@ -0,0 +1,1039 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ import collections
4
+ import copy
5
+ import functools
6
+ import logging
7
+ import numpy as np
8
+ import os
9
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
10
+ from unittest import mock
11
+ import caffe2.python.utils as putils
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from caffe2.proto import caffe2_pb2
15
+ from caffe2.python import core, net_drawer, workspace
16
+ from torch.nn.functional import interpolate as interp
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ # ==== torch/utils_toffee/cast.py =======================================
22
+
23
+
24
def to_device(t, device_str):
    """
    This function is a replacement of .to(another_device) such that it allows the
    casting to be traced properly by explicitly calling the underlying copy ops.
    It also avoids introducing unncessary op when casting to the same device.
    """
    src = t.device
    dst = torch.device(device_str)

    # Same device: no-op, and no op is recorded in the trace.
    if src == dst:
        return t
    if src.type == "cuda" and dst.type == "cpu":
        return torch.ops._caffe2.CopyGPUToCPU(t)
    if src.type == "cpu" and dst.type == "cuda":
        return torch.ops._caffe2.CopyCPUToGPU(t)
    raise RuntimeError("Can't cast tensor from device {} to device {}".format(src, dst))
41
+
42
+
43
+ # ==== torch/utils_toffee/interpolate.py =======================================
44
+
45
+
46
+ # Note: borrowed from vision/detection/fair/detectron/detectron/modeling/detector.py
47
+ def BilinearInterpolation(tensor_in, up_scale):
48
+ assert up_scale % 2 == 0, "Scale should be even"
49
+
50
+ def upsample_filt(size):
51
+ factor = (size + 1) // 2
52
+ if size % 2 == 1:
53
+ center = factor - 1
54
+ else:
55
+ center = factor - 0.5
56
+
57
+ og = np.ogrid[:size, :size]
58
+ return (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
59
+
60
+ kernel_size = int(up_scale) * 2
61
+ bil_filt = upsample_filt(kernel_size)
62
+
63
+ dim = int(tensor_in.shape[1])
64
+ kernel = np.zeros((dim, dim, kernel_size, kernel_size), dtype=np.float32)
65
+ kernel[range(dim), range(dim), :, :] = bil_filt
66
+
67
+ tensor_out = F.conv_transpose2d(
68
+ tensor_in,
69
+ weight=to_device(torch.Tensor(kernel), tensor_in.device),
70
+ bias=None,
71
+ stride=int(up_scale),
72
+ padding=int(up_scale / 2),
73
+ )
74
+
75
+ return tensor_out
76
+
77
+
78
+ # NOTE: ONNX is incompatible with traced torch.nn.functional.interpolate if
79
+ # using dynamic `scale_factor` rather than static `size`. (T43166860)
80
+ # NOTE: Caffe2 Int8 conversion might not be able to quantize `size` properly.
81
# NOTE(review): name typo ("compatibale") is kept as-is for API compatibility.
def onnx_compatibale_interpolate(
    input, size=None, scale_factor=None, mode="nearest", align_corners=None
):
    """Drop-in for F.interpolate that routes 4-D scale_factor resizes through
    Caffe2-friendly ops during ONNX export; otherwise falls back to interp."""
    # NOTE: The input dimensions are interpreted in the form:
    # `mini-batch x channels x [optional depth] x [optional height] x width`.
    if size is None and scale_factor is not None:
        if input.dim() == 4:
            if isinstance(scale_factor, (int, float)):
                height_scale, width_scale = (scale_factor, scale_factor)
            else:
                assert isinstance(scale_factor, (tuple, list))
                assert len(scale_factor) == 2
                height_scale, width_scale = scale_factor

            assert not align_corners, "No matching C2 op for align_corners == True"
            if mode == "nearest":
                return torch.ops._caffe2.ResizeNearest(
                    input, order="NCHW", width_scale=width_scale, height_scale=height_scale
                )
            elif mode == "bilinear":
                logger.warning(
                    "Use F.conv_transpose2d for bilinear interpolate"
                    " because there's no such C2 op, this may cause significant"
                    " slowdown and the boundary pixels won't be as same as"
                    " using F.interpolate due to padding."
                )
                assert height_scale == width_scale
                return BilinearInterpolation(input, up_scale=height_scale)
        # Non-4D input, or an unhandled mode: fall through to F.interpolate.
        logger.warning("Output size is not static, it might cause ONNX conversion issue")

    return interp(input, size, scale_factor, mode, align_corners)
112
+
113
+
114
def mock_torch_nn_functional_interpolate():
    """Decorator factory: while inside ONNX export, runs the wrapped function
    with ``torch.nn.functional.interpolate`` swapped for the ONNX-compatible
    implementation; outside export the function runs untouched."""

    def decorator(func):
        @functools.wraps(func)
        def _mock_torch_nn_functional_interpolate(*args, **kwargs):
            if not torch.onnx.is_in_onnx_export():
                return func(*args, **kwargs)
            with mock.patch(
                "torch.nn.functional.interpolate", side_effect=onnx_compatibale_interpolate
            ):
                return func(*args, **kwargs)

        return _mock_torch_nn_functional_interpolate

    return decorator
129
+
130
+
131
+ # ==== torch/utils_caffe2/ws_utils.py ==========================================
132
+
133
+
134
class ScopedWS:
    """Context manager that switches to a named Caffe2 workspace for the
    duration of the block, optionally resetting it on entry and/or exit, and
    restores the previously-current workspace afterwards."""

    def __init__(self, ws_name, is_reset, is_cleanup=False):
        # ws_name: workspace to switch to (None = stay on the current one).
        # is_reset: reset the target workspace when entering.
        # is_cleanup: reset the target workspace again when exiting.
        self.ws_name = ws_name
        self.is_reset = is_reset
        self.is_cleanup = is_cleanup
        self.org_ws = ""

    def __enter__(self):
        # Remember the current workspace so __exit__ can switch back.
        self.org_ws = workspace.CurrentWorkspace()
        if self.ws_name is not None:
            workspace.SwitchWorkspace(self.ws_name, True)
        if self.is_reset:
            workspace.ResetWorkspace()

        return workspace

    def __exit__(self, *args):
        if self.is_cleanup:
            workspace.ResetWorkspace()
        if self.ws_name is not None:
            workspace.SwitchWorkspace(self.org_ws)
155
+
156
+
157
def fetch_any_blob(name):
    """Fetch a blob from the current workspace.

    Falls back to the Int8 fetcher on TypeError; returns None (after logging)
    if both fetch paths fail.
    """
    bb = None
    try:
        bb = workspace.FetchBlob(name)
    except TypeError:
        # Presumably Int8 blobs make FetchBlob raise TypeError; try the
        # dedicated fetcher for those.
        bb = workspace.FetchInt8Blob(name)
    except Exception as e:
        logger.error("Get blob {} error: {}".format(name, e))

    return bb
167
+
168
+
169
+ # ==== torch/utils_caffe2/protobuf.py ==========================================
170
+
171
+
172
def get_pb_arg(pb, arg_name):
    """Return the Argument in ``pb.arg`` named ``arg_name``, or None if absent."""
    for candidate in pb.arg:
        if candidate.name == arg_name:
            return candidate
    return None


def get_pb_arg_valf(pb, arg_name, default_val):
    """Scalar float argument value, or ``default_val`` when absent."""
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        return default_val
    return arg.f


def get_pb_arg_floats(pb, arg_name, default_val):
    """Repeated float argument as a list, or ``default_val`` when absent."""
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        return default_val
    return [float(x) for x in arg.floats]


def get_pb_arg_ints(pb, arg_name, default_val):
    """Repeated int argument as a list, or ``default_val`` when absent."""
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        return default_val
    return [int(x) for x in arg.ints]


def get_pb_arg_vali(pb, arg_name, default_val):
    """Scalar int argument value, or ``default_val`` when absent."""
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        return default_val
    return arg.i


def get_pb_arg_vals(pb, arg_name, default_val):
    """Scalar bytes argument value, or ``default_val`` when absent."""
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        return default_val
    return arg.s


def get_pb_arg_valstrings(pb, arg_name, default_val):
    """Repeated string argument as a list, or ``default_val`` when absent."""
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        return default_val
    return list(arg.strings)
207
+
208
+
209
def check_set_pb_arg(pb, arg_name, arg_attr, arg_value, allow_override=False):
    """Ensure ``pb`` has argument ``arg_name`` with value ``arg_value`` stored
    under field ``arg_attr`` (e.g. "i", "f", "s").

    Creates the argument if missing. If it already exists with a different
    value, overrides it (with a warning) when ``allow_override`` is set,
    otherwise fails the assertion.
    """
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        arg = putils.MakeArgument(arg_name, arg_value)
        assert hasattr(arg, arg_attr)
        pb.arg.extend([arg])
    if allow_override and getattr(arg, arg_attr) != arg_value:
        logger.warning(
            "Override argument {}: {} -> {}".format(arg_name, getattr(arg, arg_attr), arg_value)
        )
        setattr(arg, arg_attr, arg_value)
    else:
        assert arg is not None
        assert getattr(arg, arg_attr) == arg_value, "Existing value {}, new value {}".format(
            getattr(arg, arg_attr), arg_value
        )
225
+
226
+
227
def _create_const_fill_op_from_numpy(name, tensor, device_option=None):
    """Build the GivenTensor*Fill operator that materializes numpy ``tensor``
    as constant blob ``name``."""
    assert type(tensor) == np.ndarray
    # numpy dtype -> Caffe2 fill-op type
    kTypeNameMapper = {
        np.dtype("float32"): "GivenTensorFill",
        np.dtype("int32"): "GivenTensorIntFill",
        np.dtype("int64"): "GivenTensorInt64Fill",
        np.dtype("uint8"): "GivenTensorStringFill",
    }

    args_dict = {}
    if tensor.dtype == np.dtype("uint8"):
        # uint8 data is stored as a single string value.
        # NOTE(review): str(tensor.data) stringifies a memoryview object —
        # verify this matches what downstream consumers expect on Python 3.
        args_dict.update({"values": [str(tensor.data)], "shape": [1]})
    else:
        args_dict.update({"values": tensor, "shape": tensor.shape})

    if device_option is not None:
        args_dict["device_option"] = device_option

    return core.CreateOperator(kTypeNameMapper[tensor.dtype], [], [name], **args_dict)
246
+
247
+
248
def _create_const_fill_op_from_c2_int8_tensor(name, int8_tensor):
    """Build the Int8Given*TensorFill operator that materializes a Caffe2
    Int8Tensor (data + quantization scale/zero-point) as constant blob ``name``."""
    assert type(int8_tensor) == workspace.Int8Tensor
    kTypeNameMapper = {
        np.dtype("int32"): "Int8GivenIntTensorFill",
        np.dtype("uint8"): "Int8GivenTensorFill",
    }

    tensor = int8_tensor.data
    assert tensor.dtype in [np.dtype("uint8"), np.dtype("int32")]
    # uint8 payloads are serialized as raw bytes; int32 are passed as-is.
    values = tensor.tobytes() if tensor.dtype == np.dtype("uint8") else tensor

    return core.CreateOperator(
        kTypeNameMapper[tensor.dtype],
        [],
        [name],
        values=values,
        shape=tensor.shape,
        Y_scale=int8_tensor.scale,
        Y_zero_point=int8_tensor.zero_point,
    )
268
+
269
+
270
def create_const_fill_op(
    name: str,
    blob: Union[np.ndarray, workspace.Int8Tensor],
    device_option: Optional[caffe2_pb2.DeviceOption] = None,
) -> caffe2_pb2.OperatorDef:
    """
    Given a blob object, return the Caffe2 operator that creates this blob
    as constant. Currently support NumPy tensor and Caffe2 Int8Tensor.
    """
    tensor_type = type(blob)
    assert tensor_type in [
        np.ndarray,
        workspace.Int8Tensor,
    ], 'Error when creating const fill op for "{}", unsupported blob type: {}'.format(
        name, type(blob)
    )

    if tensor_type == np.ndarray:
        return _create_const_fill_op_from_numpy(name, blob, device_option)
    # Int8Tensor path: quantization params travel with the tensor, and a
    # device option is not supported for it.
    assert device_option is None
    return _create_const_fill_op_from_c2_int8_tensor(name, blob)
293
+
294
+
295
def construct_init_net_from_params(
    params: Dict[str, Any], device_options: Optional[Dict[str, caffe2_pb2.DeviceOption]] = None
) -> caffe2_pb2.NetDef:
    """
    Construct the init_net from params dictionary
    """
    init_net = caffe2_pb2.NetDef()
    device_options = device_options or {}
    for name, blob in params.items():
        if isinstance(blob, str):
            # String blobs cannot be expressed as const-fill ops; skip them
            # with a warning instead of failing the whole export.
            logger.warning(
                (
                    "Blob {} with type {} is not supported in generating init net,"
                    " skipped.".format(name, type(blob))
                )
            )
            continue
        init_net.op.extend(
            [create_const_fill_op(name, blob, device_option=device_options.get(name, None))]
        )
        # Expose every created blob as an output of the init net.
        init_net.external_output.append(name)
    return init_net
317
+
318
+
319
def get_producer_map(ssa):
    """
    Return dict from versioned blob to (i, j),
    where i is index of producer op, j is the index of output of that op.
    """
    producer_map = {}
    for op_index, (_, outputs) in enumerate(ssa):
        for out_index, versioned_blob in enumerate(outputs):
            producer_map[versioned_blob] = (op_index, out_index)
    return producer_map
330
+
331
+
332
def get_consumer_map(ssa):
    """
    Return dict from versioned blob to list of (i, j),
    where i is index of consumer op, j is the index of input of that op.
    """
    # defaultdict so blobs with no consumers simply yield an empty list.
    consumers = collections.defaultdict(list)
    for op_index, (inputs, _) in enumerate(ssa):
        for in_index, inp in enumerate(inputs):
            consumers[inp].append((op_index, in_index))
    return consumers
343
+
344
+
345
def get_params_from_init_net(
    init_net: caffe2_pb2.NetDef,
) -> Tuple[Dict[str, Any], Dict[str, caffe2_pb2.DeviceOption]]:
    """
    Take the output blobs from init_net by running it.

    Args:
        init_net: net whose external_output blobs are the parameters to fetch.

    Returns:
        params: dict from blob name to numpy array
        device_options: dict from blob name to the device option of its creating op

    Note: the original annotation used a list literal ``[Dict, Dict]`` which is
    not a valid type expression; the function returns a 2-tuple, now annotated
    as such.
    """

    # NOTE: this assumes that the params is determined by producer op with the
    # only exception be CopyGPUToCPU which is CUDA op but returns CPU tensor.
    def _get_device_option(producer_op):
        if producer_op.type == "CopyGPUToCPU":
            # result lives on CPU even though the op itself is a CUDA op
            return caffe2_pb2.DeviceOption()
        else:
            return producer_op.device_option

    # Run in a scratch workspace so the caller's workspace is untouched.
    with ScopedWS("__get_params_from_init_net__", is_reset=True, is_cleanup=True) as ws:
        ws.RunNetOnce(init_net)
        params = {b: fetch_any_blob(b) for b in init_net.external_output}
    ssa, versions = core.get_ssa(init_net)
    producer_map = get_producer_map(ssa)
    device_options = {
        b: _get_device_option(init_net.op[producer_map[(b, versions[b])][0]])
        for b in init_net.external_output
    }
    return params, device_options
372
+
373
+
374
+ def _updater_raise(op, input_types, output_types):
375
+ raise RuntimeError(
376
+ "Failed to apply updater for op {} given input_types {} and"
377
+ " output_types {}".format(op, input_types, output_types)
378
+ )
379
+
380
+
381
def _generic_status_identifier(
    predict_net: caffe2_pb2.NetDef,
    status_updater: Callable,
    known_status: Dict[Tuple[str, int], Any],
) -> Dict[Tuple[str, int], Any]:
    """
    Statically infer the status of each blob, the status can be such as device type
    (CPU/GPU), layout (NCHW/NHWC), data type (float32/int8), etc. "Blob" here
    is versioned blob (Tuple[str, int]) in the format compatible with ssa.
    Inputs:
        predict_net: the caffe2 network
        status_updater: a callable, given an op and the status of its input/output,
            it returns the updated status of input/output. `None` is used for
            representing unknown status.
        known_status: a dict containing known status, used as initialization.
    Outputs:
        A dict mapping from versioned blob to its status
    """
    ssa, versions = core.get_ssa(predict_net)
    versioned_ext_input = [(b, 0) for b in predict_net.external_input]
    versioned_ext_output = [(b, versions[b]) for b in predict_net.external_output]
    # every versioned blob that appears as some op's input or output
    all_versioned_blobs = set().union(*[set(x[0] + x[1]) for x in ssa])

    # sanity: seeds must refer to blobs that actually exist, with known values
    allowed_vbs = all_versioned_blobs.union(versioned_ext_input).union(versioned_ext_output)
    assert all(k in allowed_vbs for k in known_status)
    assert all(v is not None for v in known_status.values())
    # deep-copied so the caller's seed dict is never mutated
    _known_status = copy.deepcopy(known_status)

    def _check_and_update(key, value):
        # a blob's status may be discovered twice (forward + backward pass);
        # it must agree both times, otherwise the updater is inconsistent
        assert value is not None
        if key in _known_status:
            if not _known_status[key] == value:
                raise RuntimeError(
                    "Confilict status for {}, existing status {}, new status {}".format(
                        key, _known_status[key], value
                    )
                )
        _known_status[key] = value

    def _update_i(op, ssa_i):
        # run the updater on one op, feeding it whatever is currently known
        versioned_inputs = ssa_i[0]
        versioned_outputs = ssa_i[1]

        inputs_status = [_known_status.get(b, None) for b in versioned_inputs]
        outputs_status = [_known_status.get(b, None) for b in versioned_outputs]

        new_inputs_status, new_outputs_status = status_updater(op, inputs_status, outputs_status)

        for versioned_blob, status in zip(
            versioned_inputs + versioned_outputs, new_inputs_status + new_outputs_status
        ):
            if status is not None:
                _check_and_update(versioned_blob, status)

    # one forward sweep propagates status downstream, one backward sweep
    # propagates it upstream; together they cover a single linear chain
    for op, ssa_i in zip(predict_net.op, ssa):
        _update_i(op, ssa_i)
    for op, ssa_i in zip(reversed(predict_net.op), reversed(ssa)):
        _update_i(op, ssa_i)

    # NOTE: This strictly checks all the blob from predict_net must be assigned
    # a known status. However sometimes it's impossible (eg. having deadend op),
    # we may relax this constraint if needed.
    for k in all_versioned_blobs:
        if k not in _known_status:
            raise NotImplementedError(
                "Can not infer the status for {}. Currently only support the case where"
                " a single forward and backward pass can identify status for all blobs.".format(k)
            )

    return _known_status
451
+
452
+
453
def infer_device_type(
    predict_net: caffe2_pb2.NetDef,
    known_status: Dict[Tuple[str, int], Any],
    device_name_style: str = "caffe2",
) -> Dict[Tuple[str, int], str]:
    """Return the device type ("cpu" or "gpu"/"cuda") of each (versioned) blob.

    Thin wrapper over :func:`_generic_status_identifier` with a device-specific
    status updater: copy ops force a cpu->gpu (or gpu->cpu) transition, while
    every other op keeps all of its blobs on a single device.
    """

    assert device_name_style in ["caffe2", "pytorch"]
    _CPU_STR = "cpu"
    # caffe2 calls it "gpu", pytorch calls it "cuda"
    _GPU_STR = "gpu" if device_name_style == "caffe2" else "cuda"

    def _copy_cpu_to_gpu_updater(op, input_types, output_types):
        # input must be CPU (or unknown) and output must be GPU (or unknown)
        if input_types[0] == _GPU_STR or output_types[0] == _CPU_STR:
            _updater_raise(op, input_types, output_types)
        return ([_CPU_STR], [_GPU_STR])

    def _copy_gpu_to_cpu_updater(op, input_types, output_types):
        # mirror image of the cpu->gpu case
        if input_types[0] == _CPU_STR or output_types[0] == _GPU_STR:
            _updater_raise(op, input_types, output_types)
        return ([_GPU_STR], [_CPU_STR])

    def _other_ops_updater(op, input_types, output_types):
        # non-copy ops are assumed single-device: propagate the first known
        # device to all blobs, and fail if two different devices were seen
        non_none_types = [x for x in input_types + output_types if x is not None]
        if len(non_none_types) > 0:
            the_type = non_none_types[0]
            if not all(x == the_type for x in non_none_types):
                _updater_raise(op, input_types, output_types)
        else:
            the_type = None
        return ([the_type for _ in op.input], [the_type for _ in op.output])

    def _device_updater(op, *args, **kwargs):
        # dispatch on op type; anything that isn't an explicit copy op falls
        # through to the single-device updater
        return {
            "CopyCPUToGPU": _copy_cpu_to_gpu_updater,
            "CopyGPUToCPU": _copy_gpu_to_cpu_updater,
        }.get(op.type, _other_ops_updater)(op, *args, **kwargs)

    return _generic_status_identifier(predict_net, _device_updater, known_status)
491
+
492
+
493
+ # ==== torch/utils_caffe2/vis.py ===============================================
494
+
495
+
496
+ def _modify_blob_names(ops, blob_rename_f):
497
+ ret = []
498
+
499
+ def _replace_list(blob_list, replaced_list):
500
+ del blob_list[:]
501
+ blob_list.extend(replaced_list)
502
+
503
+ for x in ops:
504
+ cur = copy.deepcopy(x)
505
+ _replace_list(cur.input, list(map(blob_rename_f, cur.input)))
506
+ _replace_list(cur.output, list(map(blob_rename_f, cur.output)))
507
+ ret.append(cur)
508
+
509
+ return ret
510
+
511
+
512
+ def _rename_blob(name, blob_sizes, blob_ranges):
513
+ def _list_to_str(bsize):
514
+ ret = ", ".join([str(x) for x in bsize])
515
+ ret = "[" + ret + "]"
516
+ return ret
517
+
518
+ ret = name
519
+ if blob_sizes is not None and name in blob_sizes:
520
+ ret += "\n" + _list_to_str(blob_sizes[name])
521
+ if blob_ranges is not None and name in blob_ranges:
522
+ ret += "\n" + _list_to_str(blob_ranges[name])
523
+
524
+ return ret
525
+
526
+
527
+ # graph_name could not contain word 'graph'
528
def save_graph(net, file_name, graph_name="net", op_only=True, blob_sizes=None, blob_ranges=None):
    """Render `net` to an image file, annotating each blob with its size and
    value range when provided. Returns the pydot graph object."""
    renamer = functools.partial(_rename_blob, blob_sizes=blob_sizes, blob_ranges=blob_ranges)
    return save_graph_base(net, file_name, graph_name, op_only, renamer)
531
+
532
+
533
def save_graph_base(net, file_name, graph_name="net", op_only=True, blob_rename_func=None):
    """
    Render a caffe2 NetDef to an image file via pydot.

    Args:
        net: caffe2_pb2.NetDef to visualize.
        file_name: output path; the extension picks the format (.png/.pdf/.svg).
        graph_name: pydot graph name (must not contain the word 'graph').
        op_only: if True, draw the minimal operator-only graph.
        blob_rename_func: optional callable applied to every blob name.

    Returns:
        The pydot graph object (returned even if writing the file failed).
    """
    ops = net.op
    if blob_rename_func is not None:
        # operate on renamed deep copies; `net` itself is never modified
        ops = _modify_blob_names(ops, blob_rename_func)
    if not op_only:
        graph = net_drawer.GetPydotGraph(ops, graph_name, rankdir="TB")
    else:
        graph = net_drawer.GetPydotGraphMinimal(
            ops, graph_name, rankdir="TB", minimal_dependency=True
        )

    try:
        par_dir = os.path.dirname(file_name)
        # BUG FIX: os.path.dirname returns "" for a bare filename, and
        # os.makedirs("") raises — which previously aborted the write for any
        # cwd-relative file_name. Only create a directory when one is given;
        # exist_ok avoids a check-then-create race.
        if par_dir:
            os.makedirs(par_dir, exist_ok=True)

        # `ext` instead of `format`, which shadows the builtin
        ext = os.path.splitext(os.path.basename(file_name))[-1]
        if ext == ".png":
            graph.write_png(file_name)
        elif ext == ".pdf":
            graph.write_pdf(file_name)
        elif ext == ".svg":
            graph.write_svg(file_name)
        else:
            print("Incorrect format {}".format(ext))
    except Exception as e:
        # best-effort: visualization failure should never crash the caller
        print("Error when writing graph to image {}".format(e))

    return graph
563
+
564
+
565
+ # ==== torch/utils_toffee/aten_to_caffe2.py ====================================
566
+
567
+
568
def group_norm_replace_aten_with_caffe2(predict_net: caffe2_pb2.NetDef):
    """
    For ONNX exported model, GroupNorm will be represented as ATen op
    (operator="group_norm"). Rewrite those ops in-place into the native caffe2
    GroupNorm op: drop the ATen-only "operator"/"cudnn_enabled" args and
    translate "num_groups" into caffe2's "group" arg.
    """
    count = 0
    for op in predict_net.op:
        if op.type == "ATen":
            op_name = get_pb_arg_vals(op, "operator", None)  # return byte in py3
            if op_name and op_name.decode() == "group_norm":
                op.arg.remove(get_pb_arg(op, "operator"))

                if get_pb_arg_vali(op, "cudnn_enabled", None):
                    op.arg.remove(get_pb_arg(op, "cudnn_enabled"))

                num_groups = get_pb_arg_vali(op, "num_groups", None)
                if num_groups is not None:
                    op.arg.remove(get_pb_arg(op, "num_groups"))
                    check_set_pb_arg(op, "group", "i", num_groups)

                op.type = "GroupNorm"
                count += 1
    # BUG FIX: was `count > 1`, which silently skipped logging when exactly one
    # op was replaced — the most common case.
    if count > 0:
        logger.info("Replaced {} ATen operator to GroupNormOp".format(count))
592
+
593
+
594
+ # ==== torch/utils_toffee/alias.py =============================================
595
+
596
+
597
def alias(x, name, is_backward=False):
    """Attach `name` to tensor `x` via AliasWithName during ONNX export;
    outside of export this is a no-op that returns `x` unchanged."""
    if torch.onnx.is_in_onnx_export():
        assert isinstance(x, torch.Tensor)
        return torch.ops._caffe2.AliasWithName(x, name, is_backward=is_backward)
    return x
602
+
603
+
604
def fuse_alias_placeholder(predict_net, init_net):
    """Remove AliasWithName placeholder and rename the input/output of it.

    Each AliasWithName op carries the desired blob name in its "name" arg; the
    blob feeding it (and the one it produces) are renamed to that name, after
    which the op is an identity and can be dropped.
    """
    # First we finish all the re-naming
    for i, op in enumerate(predict_net.op):
        if op.type == "AliasWithName":
            assert len(op.input) == 1
            assert len(op.output) == 1
            name = get_pb_arg_vals(op, "name", None).decode()
            is_backward = bool(get_pb_arg_vali(op, "is_backward", 0))
            # backward aliases rename via the producer so every consumer of the
            # blob picks up the new name
            rename_op_input(predict_net, init_net, i, 0, name, from_producer=is_backward)
            rename_op_output(predict_net, i, 0, name)

    # Remove AliasWithName, should be very safe since it's a non-op
    new_ops = []
    for op in predict_net.op:
        if op.type != "AliasWithName":
            new_ops.append(op)
        else:
            # safety check: after renaming, input == output == the alias name
            assert op.input == op.output
            assert op.input[0] == op.arg[0].s.decode()
    del predict_net.op[:]
    predict_net.op.extend(new_ops)
627
+
628
+
629
+ # ==== torch/utils_caffe2/graph_transform.py ===================================
630
+
631
+
632
class IllegalGraphTransformError(ValueError):
    """Raised when a graph transform function call can't be executed safely."""
634
+
635
+
636
def _rename_versioned_blob_in_proto(
    proto: caffe2_pb2.NetDef,
    old_name: str,
    new_name: str,
    version: int,
    ssa: List[Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]],
    start_versions: Dict[str, int],
    end_versions: Dict[str, int],
):
    """In given proto, rename all blobs with matched version.

    Only occurrences of (old_name, version) are touched; other versions of the
    same blob name are left alone. `start_versions`/`end_versions` give the
    blob versions at the proto's entry/exit, used to decide whether
    external_input/external_output also need renaming.
    """
    # Operator list: rename matching op inputs/outputs using the ssa to check
    # each occurrence's version.
    for op, i_th_ssa in zip(proto.op, ssa):
        versioned_inputs, versioned_outputs = i_th_ssa
        for i in range(len(op.input)):
            if versioned_inputs[i] == (old_name, version):
                op.input[i] = new_name
        for i in range(len(op.output)):
            if versioned_outputs[i] == (old_name, version):
                op.output[i] = new_name
    # external_input: only renamed if the blob enters the proto at this version
    if start_versions.get(old_name, 0) == version:
        for i in range(len(proto.external_input)):
            if proto.external_input[i] == old_name:
                proto.external_input[i] = new_name
    # external_output: only renamed if the blob leaves the proto at this version
    if end_versions.get(old_name, 0) == version:
        for i in range(len(proto.external_output)):
            if proto.external_output[i] == old_name:
                proto.external_output[i] = new_name
665
+
666
+
667
def rename_op_input(
    predict_net: caffe2_pb2.NetDef,
    init_net: caffe2_pb2.NetDef,
    op_id: int,
    input_id: int,
    new_name: str,
    from_producer: bool = False,
):
    """
    Rename the op_id-th operator in predict_net, change it's input_id-th input's
    name to the new_name. It also does automatic re-route and change
    external_input and init_net if necessary.
    - It requires the input is only consumed by this op.
    - This function modifies predict_net and init_net in-place.
    - When from_producer is enable, this also updates other operators that consumes
      the same input. Be cautious because may trigger unintended behavior.
    """
    assert isinstance(predict_net, caffe2_pb2.NetDef)
    assert isinstance(init_net, caffe2_pb2.NetDef)

    # Seed predict_net's SSA with init_net's final versions so both nets share
    # one consistent versioned-blob namespace.
    init_net_ssa, init_net_versions = core.get_ssa(init_net)
    predict_net_ssa, predict_net_versions = core.get_ssa(
        predict_net, copy.deepcopy(init_net_versions)
    )

    versioned_inputs, versioned_outputs = predict_net_ssa[op_id]
    old_name, version = versioned_inputs[input_id]

    if from_producer:
        # Delegate to rename_op_output on the producing op, which renames the
        # blob for every consumer at once.
        producer_map = get_producer_map(predict_net_ssa)
        if not (old_name, version) in producer_map:
            raise NotImplementedError(
                "Can't find producer, the input {} is probably from"
                " init_net, this is not supported yet.".format(old_name)
            )
        producer = producer_map[(old_name, version)]
        rename_op_output(predict_net, producer[0], producer[1], new_name)
        return

    def contain_targets(op_ssa):
        # does this op consume the blob being renamed?
        return (old_name, version) in op_ssa[0]

    # Renaming only this op's input would break any other consumer of the same
    # blob, so refuse when the blob has more than one consumer.
    is_consumer = [contain_targets(op_ssa) for op_ssa in predict_net_ssa]
    if sum(is_consumer) > 1:
        raise IllegalGraphTransformError(
            (
                "Input '{}' of operator(#{}) are consumed by other ops, please use"
                + " rename_op_output on the producer instead. Offending op: \n{}"
            ).format(old_name, op_id, predict_net.op[op_id])
        )

    # update init_net
    _rename_versioned_blob_in_proto(
        init_net, old_name, new_name, version, init_net_ssa, {}, init_net_versions
    )
    # update predict_net
    _rename_versioned_blob_in_proto(
        predict_net,
        old_name,
        new_name,
        version,
        predict_net_ssa,
        init_net_versions,
        predict_net_versions,
    )
732
+
733
+
734
def rename_op_output(predict_net: caffe2_pb2.NetDef, op_id: int, output_id: int, new_name: str):
    """
    Rename the op_id-th operator in predict_net, change it's output_id-th input's
    name to the new_name. It also does automatic re-route and change
    external_output and if necessary.
    - It allows multiple consumers of its output.
    - This function modifies predict_net in-place, doesn't need init_net.
    """
    assert isinstance(predict_net, caffe2_pb2.NetDef)

    ssa, blob_versions = core.get_ssa(predict_net)

    versioned_inputs, versioned_outputs = ssa[op_id]
    old_name, version = versioned_outputs[output_id]

    # update predict_net: renaming the producer's output automatically re-routes
    # every downstream consumer of (old_name, version), plus external_output.
    _rename_versioned_blob_in_proto(
        predict_net, old_name, new_name, version, ssa, {}, blob_versions
    )
753
+
754
+
755
def get_sub_graph_external_input_output(
    predict_net: caffe2_pb2.NetDef, sub_graph_op_indices: List[int]
) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
    """
    Return the list of external input/output of sub-graph,
    each element is tuple of the name and corresponding version in predict_net.

    external input/output is defined the same way as caffe2 NetDef.
    """
    ssa, versions = core.get_ssa(predict_net)

    # collect every versioned blob consumed / produced inside the sub-graph
    all_inputs = []
    all_outputs = []
    for op_id in sub_graph_op_indices:
        all_inputs += [inp for inp in ssa[op_id][0] if inp not in all_inputs]
        all_outputs += list(ssa[op_id][1])  # ssa output won't repeat

    # for versioned blobs, external inputs are just those blob in all_inputs
    # but not in all_outputs
    ext_inputs = [inp for inp in all_inputs if inp not in all_outputs]

    # external outputs are essentially outputs of this subgraph that are used
    # outside of this sub-graph (including predict_net.external_output)
    all_other_inputs = sum(
        (ssa[i][0] for i in range(len(ssa)) if i not in sub_graph_op_indices),
        [(outp, versions[outp]) for outp in predict_net.external_output],
    )
    ext_outputs = [outp for outp in all_outputs if outp in set(all_other_inputs)]

    return ext_inputs, ext_outputs
785
+
786
+
787
class DiGraph:
    """A DAG representation of caffe2 graph, each vertice is a versioned blob."""

    def __init__(self):
        self.vertices = set()
        self.graph = collections.defaultdict(list)

    def add_edge(self, u, v):
        """Insert directed edge u -> v, registering both endpoints as vertices."""
        self.graph[u].append(v)
        self.vertices.add(u)
        self.vertices.add(v)

    # grab from https://www.geeksforgeeks.org/find-paths-given-source-destination/
    def get_all_paths(self, s, d):
        """Enumerate every simple path from s to d via DFS with backtracking."""
        visited = dict.fromkeys(self.vertices, False)
        current_path = []
        found_paths = []

        def _dfs(node):
            visited[node] = True
            current_path.append(node)
            if node == d:
                # snapshot the path; it is mutated as the search backtracks
                found_paths.append(list(current_path))
            else:
                for neighbor in self.graph[node]:
                    if not visited[neighbor]:
                        _dfs(neighbor)
            current_path.pop()
            visited[node] = False

        _dfs(s)
        return found_paths

    @staticmethod
    def from_ssa(ssa):
        """Build the blob-level DAG implied by an ssa op list: every input of an
        op gets an edge to every output of that op."""
        dag = DiGraph()
        for op_inputs, op_outputs in ssa:
            for src in op_inputs:
                for dst in op_outputs:
                    dag.add_edge(src, dst)
        return dag
828
+
829
+
830
def _get_dependency_chain(ssa, versioned_target, versioned_source):
    """
    Return the index list of relevant operator to produce target blob from source blob,
    if there's no dependency, return empty list.
    """

    # finding all paths between nodes can be O(N!), thus we can only search
    # in the subgraph using the op starting from the first consumer of source blob
    # to the producer of the target blob.
    consumer_map = get_consumer_map(ssa)
    producer_map = get_producer_map(ssa)
    # BUG FIX: the 15-op padding could push start_op below zero, and a negative
    # slice start wraps around and selects ops from the END of the net; clamp
    # it at 0.
    start_op = max(0, min(x[0] for x in consumer_map[versioned_source]) - 15)
    end_op = (
        producer_map[versioned_target][0] + 15 if versioned_target in producer_map else start_op
    )
    sub_graph_ssa = ssa[start_op : end_op + 1]
    if len(sub_graph_ssa) > 30:
        logger.warning(
            "Subgraph between {} and {} is large (from op#{} to op#{}), it"
            " might take non-trival time to find all paths between them.".format(
                versioned_source, versioned_target, start_op, end_op
            )
        )

    dag = DiGraph.from_ssa(sub_graph_ssa)
    paths = dag.get_all_paths(versioned_source, versioned_target)  # include two ends
    # drop the source blob itself (path[1:]) — only producers along the way count
    ops_in_paths = [[producer_map[blob][0] for blob in path[1:]] for path in paths]
    return sorted(set().union(*[set(ops) for ops in ops_in_paths]))
858
+
859
+
860
def identify_reshape_sub_graph(predict_net: caffe2_pb2.NetDef) -> List[List[int]]:
    """
    Identify the reshape sub-graph in a protobuf.
    The reshape sub-graph is defined as matching the following pattern:

    (input_blob) -> Op_1 -> ... -> Op_N -> (new_shape) -─┐
        └-------------------------------------------> Reshape -> (output_blob)

    Return:
        List of sub-graphs, each sub-graph is represented as a list of indices
        of the relevant ops, [Op_1, Op_2, ..., Op_N, Reshape]
    """

    ssa, _ = core.get_ssa(predict_net)

    ret = []
    for i, op in enumerate(predict_net.op):
        if op.type == "Reshape":
            assert len(op.input) == 2
            # input 0 is the data blob, input 1 is the dynamically-computed shape
            input_ssa = ssa[i][0]
            data_source = input_ssa[0]
            shape_source = input_ssa[1]
            # ops on the path data_source -> shape_source are the Op_1..Op_N body
            op_indices = _get_dependency_chain(ssa, shape_source, data_source)
            ret.append(op_indices + [i])
    return ret
885
+
886
+
887
def remove_reshape_for_fc(predict_net, params):
    """
    In PyTorch nn.Linear has to take 2D tensor, this often leads to reshape
    a 4D tensor to 2D by calling .view(). However this (dynamic) reshaping
    doesn't work well with ONNX and Int8 tools, and cause using extra
    ops (eg. ExpandDims) that might not be available on mobile.
    Luckily Caffe2 supports 4D tensor for FC, so we can remove those reshape
    after exporting ONNX model.

    Returns a (new predict_net, params) pair; `params` is also mutated in place
    (removed blobs are deleted from it).
    """
    from caffe2.python import core

    # find all reshape sub-graph that can be removed, which is now all Reshape
    # sub-graph whose output is only consumed by FC.
    # TODO: to make it safer, we may need the actually value to better determine
    # if a Reshape before FC is removable.
    reshape_sub_graphs = identify_reshape_sub_graph(predict_net)
    sub_graphs_to_remove = []
    for reshape_sub_graph in reshape_sub_graphs:
        reshape_op_id = reshape_sub_graph[-1]
        assert predict_net.op[reshape_op_id].type == "Reshape"
        ssa, _ = core.get_ssa(predict_net)
        reshape_output = ssa[reshape_op_id][1][0]
        consumers = [i for i in range(len(ssa)) if reshape_output in ssa[i][0]]
        if all(predict_net.op[consumer].type == "FC" for consumer in consumers):
            # safety check if the sub-graph is isolated, for this reshape sub-graph,
            # it means it has one non-param external input and one external output.
            ext_inputs, ext_outputs = get_sub_graph_external_input_output(
                predict_net, reshape_sub_graph
            )
            # version != 0 means the blob is produced inside predict_net, i.e.
            # not a parameter from init_net
            non_params_ext_inputs = [inp for inp in ext_inputs if inp[1] != 0]
            if len(non_params_ext_inputs) == 1 and len(ext_outputs) == 1:
                sub_graphs_to_remove.append(reshape_sub_graph)

    # perform removing subgraph by:
    # 1: rename the Reshape's output to its input, then the graph can be
    # seen as in-place itentify, meaning whose external input/output are the same.
    # 2: simply remove those ops.
    remove_op_ids = []
    params_to_remove = []
    for sub_graph in sub_graphs_to_remove:
        logger.info(
            "Remove Reshape sub-graph:\n{}".format(
                "".join(["(#{:>4})\n{}".format(i, predict_net.op[i]) for i in sub_graph])
            )
        )
        reshape_op_id = sub_graph[-1]
        new_reshap_output = predict_net.op[reshape_op_id].input[0]
        # after this rename the sub-graph is an in-place identity: its external
        # input blob and external output blob carry the same name
        rename_op_output(predict_net, reshape_op_id, 0, new_reshap_output)
        ext_inputs, ext_outputs = get_sub_graph_external_input_output(predict_net, sub_graph)
        non_params_ext_inputs = [inp for inp in ext_inputs if inp[1] != 0]
        params_ext_inputs = [inp for inp in ext_inputs if inp[1] == 0]
        assert len(non_params_ext_inputs) == 1 and len(ext_outputs) == 1
        assert ext_outputs[0][0] == non_params_ext_inputs[0][0]
        assert ext_outputs[0][1] == non_params_ext_inputs[0][1] + 1
        remove_op_ids.extend(sub_graph)
        params_to_remove.extend(params_ext_inputs)

    # work on a copy for the op deletion; note the renames above were already
    # applied in place before the deepcopy
    predict_net = copy.deepcopy(predict_net)
    new_ops = [op for i, op in enumerate(predict_net.op) if i not in remove_op_ids]
    del predict_net.op[:]
    predict_net.op.extend(new_ops)
    for versioned_params in params_to_remove:
        name = versioned_params[0]
        logger.info("Remove params: {} from init_net and predict_net.external_input".format(name))
        del params[name]
        predict_net.external_input.remove(name)

    return predict_net, params
955
+
956
+
957
def fuse_copy_between_cpu_and_gpu(predict_net: caffe2_pb2.NetDef):
    """
    In-place fuse extra copy ops between cpu/gpu for the following case:
        a -CopyAToB-> b -CopyBToA> c1 -NextOp1-> d1
                        -CopyBToA> c2 -NextOp2-> d2
    The fused network will look like:
        a -NextOp1-> d1
          -NextOp2-> d2
    """

    _COPY_OPS = ["CopyCPUToGPU", "CopyGPUToCPU"]

    def _fuse_once(predict_net):
        # Fuse at most one copy/copy-back pair per call; SSA info is rebuilt
        # each time because the fusion changes op indices.
        ssa, blob_versions = core.get_ssa(predict_net)
        consumer_map = get_consumer_map(ssa)
        versioned_external_output = [
            (name, blob_versions[name]) for name in predict_net.external_output
        ]

        for op_id, op in enumerate(predict_net.op):
            if op.type in _COPY_OPS:
                fw_copy_versioned_output = ssa[op_id][1][0]
                consumer_ids = [x[0] for x in consumer_map[fw_copy_versioned_output]]
                # the matching reverse direction, e.g. CopyCPUToGPU <-> CopyGPUToCPU
                reverse_op_type = _COPY_OPS[1 - _COPY_OPS.index(op.type)]

                # fusable iff every consumer of the forward copy is a reverse
                # copy, and neither the intermediate blob nor any reverse-copy
                # output is an external output of the net
                is_fusable = (
                    len(consumer_ids) > 0
                    and fw_copy_versioned_output not in versioned_external_output
                    and all(
                        predict_net.op[_op_id].type == reverse_op_type
                        and ssa[_op_id][1][0] not in versioned_external_output
                        for _op_id in consumer_ids
                    )
                )

                if is_fusable:
                    for rv_copy_op_id in consumer_ids:
                        # making each NextOp uses "a" directly and removing Copy ops
                        rs_copy_versioned_output = ssa[rv_copy_op_id][1][0]
                        next_op_id, inp_id = consumer_map[rs_copy_versioned_output][0]
                        predict_net.op[next_op_id].input[inp_id] = op.input[0]
                    # remove CopyOps
                    new_ops = [
                        op
                        for i, op in enumerate(predict_net.op)
                        if i != op_id and i not in consumer_ids
                    ]
                    del predict_net.op[:]
                    predict_net.op.extend(new_ops)
                    return True

        return False

    # _fuse_once returns False if nothing can be fused
    while _fuse_once(predict_net):
        pass
1013
+
1014
+
1015
def remove_dead_end_ops(net_def: caffe2_pb2.NetDef):
    """remove ops if its output is not used or not in external_output"""
    ssa, versions = core.get_ssa(net_def)
    versioned_external_output = [(name, versions[name]) for name in net_def.external_output]
    consumer_map = get_consumer_map(ssa)
    removed_op_ids = set()

    def _is_dead_end(versioned_blob):
        # a blob is dead iff it is not an external output AND it has no
        # consumer that survives the removal so far
        return not (
            versioned_blob in versioned_external_output
            or (
                len(consumer_map[versioned_blob]) > 0
                and all(x[0] not in removed_op_ids for x in consumer_map[versioned_blob])
            )
        )

    # iterate in reverse so removing an op can transitively make its upstream
    # producers dead within the same pass
    for i, ssa_i in reversed(list(enumerate(ssa))):
        versioned_outputs = ssa_i[1]
        if all(_is_dead_end(outp) for outp in versioned_outputs):
            removed_op_ids.add(i)

    # simply removing those deadend ops should have no effect to external_output
    new_ops = [op for i, op in enumerate(net_def.op) if i not in removed_op_ids]
    del net_def.op[:]
    net_def.op.extend(new_ops)
Leffa/3rdparty/detectron2/export/torchscript.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ import os
4
+ import torch
5
+
6
+ from detectron2.utils.file_io import PathManager
7
+
8
+ from .torchscript_patch import freeze_training_mode, patch_instances
9
+
10
+ __all__ = ["scripting_with_instances", "dump_torchscript_IR"]
11
+
12
+
13
def scripting_with_instances(model, fields):
    """
    Run :func:`torch.jit.script` on a model that uses the :class:`Instances` class. Since
    attributes of :class:`Instances` are "dynamically" added in eager mode, it is difficult
    for scripting to support it out of the box. This function is made to support scripting
    a model that uses :class:`Instances`. It does the following:

    1. Create a scriptable ``new_Instances`` class which behaves similarly to ``Instances``,
       but with all attributes been "static".
       The attributes need to be statically declared in the ``fields`` argument.
    2. Register ``new_Instances``, and force scripting compiler to
       use it when trying to compile ``Instances``.

    After this function, the process will be reverted. User should be able to script another model
    using different fields.

    Example:
        Assume that ``Instances`` in the model consist of two attributes named
        ``proposal_boxes`` and ``objectness_logits`` with type :class:`Boxes` and
        :class:`Tensor` respectively during inference. You can call this function like:
        ::
            fields = {"proposal_boxes": Boxes, "objectness_logits": torch.Tensor}
            torchscript_model = scripting_with_instances(model, fields)

    Note:
        It only support models in evaluation mode.

    Args:
        model (nn.Module): The input model to be exported by scripting.
        fields (Dict[str, type]): Attribute names and corresponding type that
            ``Instances`` will use in the model. Note that all attributes used in ``Instances``
            need to be added, regardless of whether they are inputs/outputs of the model.
            Data type not defined in detectron2 is not supported for now.

    Returns:
        torch.jit.ScriptModule: the model in torchscript format
    """
    assert (
        not model.training
    ), "Currently we only support exporting models in evaluation mode to torchscript"

    # both context managers revert their patches on exit, so subsequent calls
    # can script a different model with a different field set
    with freeze_training_mode(model), patch_instances(fields):
        scripted_model = torch.jit.script(model)
        return scripted_model
57
+
58
+
59
# Backward-compatible alias: the function was previously exported under this name.
export_torchscript_with_instances = scripting_with_instances
61
+
62
+
63
def dump_torchscript_IR(model, dir):
    """
    Dump IR of a TracedModule/ScriptModule/Function in various format (code, graph,
    inlined graph). Useful for debugging.

    Writes four files into `dir`: model_ts_code.txt, model_ts_IR.txt,
    model_ts_IR_inlined.txt and (for modules) model.txt.

    Args:
        model (TracedModule/ScriptModule/ScriptFunction): traced or scripted module
        dir (str): output directory to dump files.
    """
    dir = os.path.expanduser(dir)
    PathManager.mkdirs(dir)

    def _get_script_mod(mod):
        # a TracedModule wraps the real ScriptModule; unwrap to reach ._c
        if isinstance(mod, torch.jit.TracedModule):
            return mod._actual_script_module
        return mod

    # Dump pretty-printed code: https://pytorch.org/docs/stable/jit.html#inspecting-code
    with PathManager.open(os.path.join(dir, "model_ts_code.txt"), "w") as f:

        def get_code(mod):
            # Try a few ways to get code using private attributes.
            try:
                # This contains more information than just `mod.code`
                return _get_script_mod(mod)._c.code
            except AttributeError:
                pass
            try:
                return mod.code
            except AttributeError:
                return None

        def dump_code(prefix, mod):
            # recursively write this module's code, then its children's
            code = get_code(mod)
            name = prefix or "root model"
            if code is None:
                f.write(f"Could not found code for {name} (type={mod.original_name})\n")
                f.write("\n")
            else:
                f.write(f"\nCode for {name}, type={mod.original_name}:\n")
                f.write(code)
                f.write("\n")
                f.write("-" * 80)

            for name, m in mod.named_children():
                dump_code(prefix + "." + name, m)

        if isinstance(model, torch.jit.ScriptFunction):
            f.write(get_code(model))
        else:
            dump_code("", model)

    def _get_graph(model):
        try:
            # Recursively dump IR of all modules
            return _get_script_mod(model)._c.dump_to_str(True, False, False)
        except AttributeError:
            # ScriptFunction has no _c; fall back to its graph
            return model.graph.str()

    with PathManager.open(os.path.join(dir, "model_ts_IR.txt"), "w") as f:
        f.write(_get_graph(model))

    # Dump IR of the entire graph (all submodules inlined)
    with PathManager.open(os.path.join(dir, "model_ts_IR_inlined.txt"), "w") as f:
        f.write(str(model.inlined_graph))

    if not isinstance(model, torch.jit.ScriptFunction):
        # Dump the model structure in pytorch style
        with PathManager.open(os.path.join(dir, "model.txt"), "w") as f:
            f.write(str(model))
Leffa/3rdparty/detectron2/export/torchscript_patch.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ import os
4
+ import sys
5
+ import tempfile
6
+ from contextlib import ExitStack, contextmanager
7
+ from copy import deepcopy
8
+ from unittest import mock
9
+ import torch
10
+ from torch import nn
11
+
12
+ # need some explicit imports due to https://github.com/pytorch/pytorch/issues/38964
13
+ import detectron2 # noqa F401
14
+ from detectron2.structures import Boxes, Instances
15
+ from detectron2.utils.env import _import_file
16
+
17
+ _counter = 0
18
+
19
+
20
def _clear_jit_cache():
    # Drop TorchScript's global compilation caches. `patch_instances` may
    # regenerate `Instances` as a brand-new class each time, so results
    # compiled against an older class must not be reused.
    # NOTE(review): relies on private torch.jit internals — verify against the
    # pinned torch version.
    from torch.jit._recursive import concrete_type_store
    from torch.jit._state import _jit_caching_layer

    concrete_type_store.type_store.clear()  # for modules
    _jit_caching_layer.clear()  # for free functions
26
+
27
+
28
def _add_instances_conversion_methods(newInstances):
    """
    Add from_instances methods to the scripted Instances class.
    """
    scripted_cls_name = newInstances.__name__

    @torch.jit.unused
    def from_instances(instances: Instances):
        """
        Create scripted Instances from original Instances
        """
        ret = newInstances(instances.image_size)
        for field_name, field_value in instances.get_fields().items():
            # Every eager field must have a matching typed slot on the
            # generated class.
            assert hasattr(
                ret, f"_{field_name}"
            ), f"No attribute named {field_name} in {scripted_cls_name}"
            setattr(ret, field_name, deepcopy(field_value))
        return ret

    newInstances.from_instances = from_instances
48
+
49
+
50
@contextmanager
def patch_instances(fields):
    """
    A contextmanager, under which the Instances class in detectron2 is replaced
    by a statically-typed scriptable class, defined by `fields`.
    See more in `scripting_with_instances`.
    """

    # delete=False: the file must stay on disk after close() so it can be
    # imported as a module (also required on Windows).
    with tempfile.TemporaryDirectory(prefix="detectron2") as dir, tempfile.NamedTemporaryFile(
        mode="w", encoding="utf-8", suffix=".py", dir=dir, delete=False
    ) as f:
        try:
            # Objects that use Instances should not reuse previously-compiled
            # results in cache, because `Instances` could be a new class each time.
            _clear_jit_cache()

            # Generate the typed replacement class, write it to the temp module,
            # import it, and script it eagerly so errors surface here.
            cls_name, s = _gen_instance_module(fields)
            f.write(s)
            f.flush()
            f.close()

            module = _import(f.name)
            new_instances = getattr(module, cls_name)
            _ = torch.jit.script(new_instances)
            # let torchscript think Instances was scripted already
            Instances.__torch_script_class__ = True
            # let torchscript find new_instances when looking for the jit type of Instances
            Instances._jit_override_qualname = torch._jit_internal._qualified_name(new_instances)

            _add_instances_conversion_methods(new_instances)
            yield new_instances
        finally:
            # Undo the class-level patches even if the body raised, so eager
            # Instances keeps working outside this context.
            try:
                del Instances.__torch_script_class__
                del Instances._jit_override_qualname
            except AttributeError:
                pass
            sys.modules.pop(module.__name__)
88
+
89
+
90
def _gen_instance_class(fields):
    """
    Generate the source code of a statically-typed, scriptable stand-in for
    ``Instances`` with one Optional-typed attribute per entry in ``fields``.

    Args:
        fields (dict[name: type])

    Returns:
        (str, str): the generated class name, and its python source code.
    """

    class _FieldType:
        # Record for one field: its name, python type object, and the
        # fully-qualified annotation string used in the generated source.
        def __init__(self, name, type_):
            assert isinstance(name, str), f"Field name must be str, got {name}"
            self.name = name
            self.type_ = type_
            self.annotation = f"{type_.__module__}.{type_.__name__}"

    fields = [_FieldType(k, v) for k, v in fields.items()]

    def indent(level, s):
        return " " * 4 * level + s

    lines = []

    # Each generated class gets a unique name so stale jit caches never match.
    global _counter
    _counter += 1

    cls_name = "ScriptedInstances{}".format(_counter)

    field_names = tuple(x.name for x in fields)
    extra_args = ", ".join([f"{f.name}: Optional[{f.annotation}] = None" for f in fields])
    lines.append(
        f"""
class {cls_name}:
    def __init__(self, image_size: Tuple[int, int], {extra_args}):
        self.image_size = image_size
        self._field_names = {field_names}
"""
    )

    for f in fields:
        lines.append(
            indent(2, f"self._{f.name} = torch.jit.annotate(Optional[{f.annotation}], {f.name})")
        )

    # One property + setter per field, with Optional refinement in the getter.
    for f in fields:
        lines.append(
            f"""
    @property
    def {f.name}(self) -> {f.annotation}:
        # has to use a local for type refinement
        # https://pytorch.org/docs/stable/jit_language_reference.html#optional-type-refinement
        t = self._{f.name}
        assert t is not None, "{f.name} is None and cannot be accessed!"
        return t

    @{f.name}.setter
    def {f.name}(self, value: {f.annotation}) -> None:
        self._{f.name} = value
"""
        )

    # support method `__len__`
    lines.append(
        """
    def __len__(self) -> int:
"""
    )
    for f in fields:
        lines.append(
            f"""
        t = self._{f.name}
        if t is not None:
            return len(t)
"""
        )
    lines.append(
        """
        raise NotImplementedError("Empty Instances does not support __len__!")
"""
    )

    # support method `has`
    lines.append(
        """
    def has(self, name: str) -> bool:
"""
    )
    for f in fields:
        lines.append(
            f"""
        if name == "{f.name}":
            return self._{f.name} is not None
"""
        )
    lines.append(
        """
        return False
"""
    )

    # support method `to`
    none_args = ", None" * len(fields)
    lines.append(
        f"""
    def to(self, device: torch.device) -> "{cls_name}":
        ret = {cls_name}(self.image_size{none_args})
"""
    )
    for f in fields:
        if hasattr(f.type_, "to"):
            lines.append(
                f"""
        t = self._{f.name}
        if t is not None:
            ret._{f.name} = t.to(device)
"""
            )
        else:
            # For now, ignore fields that cannot be moved to devices.
            # Maybe can support other tensor-like classes (e.g. __torch_function__)
            pass
    lines.append(
        """
        return ret
"""
    )

    # support method `getitem`
    none_args = ", None" * len(fields)
    lines.append(
        f"""
    def __getitem__(self, item) -> "{cls_name}":
        ret = {cls_name}(self.image_size{none_args})
"""
    )
    for f in fields:
        lines.append(
            f"""
        t = self._{f.name}
        if t is not None:
            ret._{f.name} = t[item]
"""
        )
    lines.append(
        """
        return ret
"""
    )

    # support method `cat`
    # this version does not contain checks that all instances have same size and fields
    none_args = ", None" * len(fields)
    lines.append(
        f"""
    def cat(self, instances: List["{cls_name}"]) -> "{cls_name}":
        ret = {cls_name}(self.image_size{none_args})
"""
    )
    for f in fields:
        lines.append(
            f"""
        t = self._{f.name}
        if t is not None:
            values: List[{f.annotation}] = [x.{f.name} for x in instances]
            if torch.jit.isinstance(t, torch.Tensor):
                ret._{f.name} = torch.cat(values, dim=0)
            else:
                ret._{f.name} = t.cat(values)
"""
        )
    lines.append(
        """
        return ret"""
    )

    # support method `get_fields()`
    lines.append(
        """
    def get_fields(self) -> Dict[str, Tensor]:
        ret = {}
"""
    )
    for f in fields:
        # Only Boxes and raw Tensors can be exposed as a Dict[str, Tensor];
        # anything else becomes a generated assert-failure.
        if f.type_ == Boxes:
            stmt = "t.tensor"
        elif f.type_ == torch.Tensor:
            stmt = "t"
        else:
            stmt = f'assert False, "unsupported type {str(f.type_)}"'
        lines.append(
            f"""
        t = self._{f.name}
        if t is not None:
            ret["{f.name}"] = {stmt}
"""
        )
    lines.append(
        """
        return ret"""
    )
    return cls_name, os.linesep.join(lines)
288
+
289
+
290
def _gen_instance_module(fields):
    """
    Generate the full source of a module defining the scripted Instances
    class for ``fields``, including the imports the generated code needs.

    Returns:
        (str, str): generated class name, module source code.
    """
    # TODO: find a more automatic way to enable import of other classes
    header = """
from copy import deepcopy
import torch
from torch import Tensor
import typing
from typing import *

import detectron2
from detectron2.structures import Boxes, Instances

"""

    cls_name, cls_def = _gen_instance_class(fields)
    return cls_name, header + cls_def
307
+
308
+
309
def _import(path):
    """Import the generated file at `path` under a unique module name."""
    # Suffix with the global counter so repeated generations never collide
    # in sys.modules.
    unique_name = f"{sys.modules[__name__].__name__}{_counter}"
    return _import_file(unique_name, path, make_importable=True)
313
+
314
+
315
@contextmanager
def patch_builtin_len(modules=()):
    """
    Patch the builtin len() function of a few detectron2 modules
    to use __len__ instead, because __len__ does not convert values to
    integers and therefore is friendly to tracing.

    Args:
        modules (list[str]): names of extra modules to patch len(), in
            addition to those in detectron2.
    """

    def _new_len(obj):
        # Calling __len__ directly keeps a traced value symbolic, whereas
        # builtin len() would coerce it to a plain python int.
        return obj.__len__()

    with ExitStack() as stack:
        MODULES = [
            "detectron2.modeling.roi_heads.fast_rcnn",
            "detectron2.modeling.roi_heads.mask_head",
            "detectron2.modeling.roi_heads.keypoint_head",
        ] + list(modules)
        # mock.patch replaces the module-level `len` binding; restored on exit.
        ctxs = [stack.enter_context(mock.patch(mod + ".len")) for mod in MODULES]
        for m in ctxs:
            m.side_effect = _new_len
        yield
340
+
341
+
342
def patch_nonscriptable_classes():
    """
    Apply patches on a few nonscriptable detectron2 classes.
    Should not have side-effects on eager usage.
    """
    # __prepare_scriptable__ can also be added to models for easier maintenance.
    # But it complicates the clean model code.

    from detectron2.modeling.backbone import ResNet, FPN

    # Due to https://github.com/pytorch/pytorch/issues/36061,
    # we change backbone to use ModuleList for scripting.
    # (note: this changes param names in state_dict)

    def prepare_resnet(self):
        # Deep-copy so scripting consumes the copy and the eager model stays
        # untouched.
        ret = deepcopy(self)
        ret.stages = nn.ModuleList(ret.stages)
        for k in self.stage_names:
            delattr(ret, k)
        return ret

    ResNet.__prepare_scriptable__ = prepare_resnet

    def prepare_fpn(self):
        ret = deepcopy(self)
        ret.lateral_convs = nn.ModuleList(ret.lateral_convs)
        ret.output_convs = nn.ModuleList(ret.output_convs)
        # Drop the per-level fpn_* attributes that duplicate the ModuleLists.
        for name, _ in self.named_children():
            if name.startswith("fpn_"):
                delattr(ret, name)
        return ret

    FPN.__prepare_scriptable__ = prepare_fpn

    # Annotate some attributes to be constants for the purpose of scripting,
    # even though they are not constants in eager mode.
    from detectron2.modeling.roi_heads import StandardROIHeads

    if hasattr(StandardROIHeads, "__annotations__"):
        # copy first to avoid editing annotations of base class
        StandardROIHeads.__annotations__ = deepcopy(StandardROIHeads.__annotations__)
        StandardROIHeads.__annotations__["mask_on"] = torch.jit.Final[bool]
        StandardROIHeads.__annotations__["keypoint_on"] = torch.jit.Final[bool]
385
+
386
+
387
# Applied once at import time. These patches only attach export hooks and
# scripting annotations; they are not supposed to have side-effects on
# eager-mode behavior.
patch_nonscriptable_classes()
389
+
390
+
391
@contextmanager
def freeze_training_mode(model):
    """
    A context manager that annotates the "training" attribute of every submodule
    to constant, so that the training codepath in these modules can be
    meta-compiled away. Upon exiting, the annotations are reverted.

    Fixes over the previous version: the revert now runs in a ``finally``
    block, so an exception raised inside the context can no longer leave the
    annotations patched; and each class's *original* "training" annotation is
    restored instead of being unconditionally overwritten with ``bool``
    (which leaked a "training" annotation into classes that never had one).
    """
    classes = {type(x) for x in model.modules()}
    # __constants__ is the old way to annotate constants and not compatible
    # with __annotations__ .
    classes = {x for x in classes if not hasattr(x, "__constants__")}
    _MISSING = object()  # sentinel: class had no own "training" annotation
    saved = {}
    try:
        for cls in classes:
            saved[cls] = cls.__annotations__.get("training", _MISSING)
            cls.__annotations__["training"] = torch.jit.Final[bool]
        yield
    finally:
        for cls, prev in saved.items():
            if prev is _MISSING:
                cls.__annotations__.pop("training", None)
            else:
                cls.__annotations__["training"] = prev
Leffa/SCHP/__init__.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from PIL import Image
7
+ from SCHP import networks
8
+ from SCHP.utils.transforms import get_affine_transform, transform_logits
9
+ from torchvision import transforms
10
+
11
+
12
def get_palette(num_cls):
    """Returns the color map for visualizing the segmentation mask.

    Each class index is mapped to an (R, G, B) color by spreading the bits of
    the index across the three channels, most-significant bit first — the
    standard PASCAL VOC palette construction.

    Args:
        num_cls: Number of classes
    Returns:
        The color map as a flat list [r0, g0, b0, r1, g1, b1, ...]
    """
    palette = [0] * (num_cls * 3)
    for cls_idx in range(num_cls):
        red = green = blue = 0
        remaining = cls_idx
        bit_pos = 7
        while remaining:
            red |= ((remaining >> 0) & 1) << bit_pos
            green |= ((remaining >> 1) & 1) << bit_pos
            blue |= ((remaining >> 2) & 1) << bit_pos
            bit_pos -= 1
            remaining >>= 3
        palette[3 * cls_idx : 3 * cls_idx + 3] = [red, green, blue]
    return palette
34
+
35
+
36
# Per-dataset inference settings: network input resolution (H, W), number of
# parsing classes, and the human-readable name of each class index (index i
# in "label" is class id i in the predicted parsing map).
dataset_settings = {
    # LIP: Look Into Person (20 classes)
    "lip": {
        "input_size": [473, 473],
        "num_classes": 20,
        "label": [
            "Background",
            "Hat",
            "Hair",
            "Glove",
            "Sunglasses",
            "Upper-clothes",
            "Dress",
            "Coat",
            "Socks",
            "Pants",
            "Jumpsuits",
            "Scarf",
            "Skirt",
            "Face",
            "Left-arm",
            "Right-arm",
            "Left-leg",
            "Right-leg",
            "Left-shoe",
            "Right-shoe",
        ],
    },
    # ATR: single-person fashion parsing (18 classes)
    "atr": {
        "input_size": [512, 512],
        "num_classes": 18,
        "label": [
            "Background",
            "Hat",
            "Hair",
            "Sunglasses",
            "Upper-clothes",
            "Skirt",
            "Pants",
            "Dress",
            "Belt",
            "Left-shoe",
            "Right-shoe",
            "Face",
            "Left-leg",
            "Right-leg",
            "Left-arm",
            "Right-arm",
            "Bag",
            "Scarf",
        ],
    },
    # Pascal-Person-Part (7 coarse body-part classes)
    "pascal": {
        "input_size": [512, 512],
        "num_classes": 7,
        "label": [
            "Background",
            "Head",
            "Torso",
            "Upper Arms",
            "Lower Arms",
            "Upper Legs",
            "Lower Legs",
        ],
    },
}
101
+
102
+
103
class SCHP:
    """
    Self-Correction Human Parsing (SCHP) inference wrapper.

    Infers the dataset type (lip/atr/pascal) from the checkpoint path, builds
    the resnet101 parsing network, and converts input images into paletted
    parsing maps (one PIL "P"-mode image per input).
    """

    def __init__(self, ckpt_path, device):
        # The dataset type is encoded in the checkpoint filename/path.
        dataset_type = None
        if "lip" in ckpt_path:
            dataset_type = "lip"
        elif "atr" in ckpt_path:
            dataset_type = "atr"
        elif "pascal" in ckpt_path:
            dataset_type = "pascal"
        assert dataset_type is not None, "Dataset type not found in checkpoint path"
        self.device = device
        self.num_classes = dataset_settings[dataset_type]["num_classes"]
        self.input_size = dataset_settings[dataset_type]["input_size"]
        # width / height ratio used to pad person crops to the network shape
        self.aspect_ratio = self.input_size[1] * 1.0 / self.input_size[0]
        self.palette = get_palette(self.num_classes)

        self.label = dataset_settings[dataset_type]["label"]
        self.model = networks.init_model(
            "resnet101", num_classes=self.num_classes, pretrained=None
        ).to(device)
        self.load_ckpt(ckpt_path)
        self.model.eval()

        # NOTE(review): mean/std are the ImageNet statistics in BGR channel
        # order, matching cv2.imread's BGR output in preprocess().
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229]
                ),
            ]
        )
        self.upsample = torch.nn.Upsample(
            size=self.input_size, mode="bilinear", align_corners=True
        )

    def load_ckpt(self, ckpt_path):
        """Load a DataParallel checkpoint, remapping layer indices.

        The published checkpoints were saved from a model whose Sequential
        containers had one fewer child (an extra normalization layer shifts
        the indices); `rename_map` moves those keys to the new positions.
        """
        rename_map = {
            "decoder.conv3.2.weight": "decoder.conv3.3.weight",
            "decoder.conv3.3.weight": "decoder.conv3.4.weight",
            "decoder.conv3.3.bias": "decoder.conv3.4.bias",
            "decoder.conv3.3.running_mean": "decoder.conv3.4.running_mean",
            "decoder.conv3.3.running_var": "decoder.conv3.4.running_var",
            "fushion.3.weight": "fushion.4.weight",
            "fushion.3.bias": "fushion.4.bias",
        }
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoint files.
        state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]  # remove `module.` (DataParallel prefix)
            new_state_dict[name] = v
        new_state_dict_ = OrderedDict()
        for k, v in list(new_state_dict.items()):
            if k in rename_map:
                new_state_dict_[rename_map[k]] = v
            else:
                new_state_dict_[k] = v
        # strict=False: keys absent from the remapped checkpoint keep their
        # randomly initialized values.
        self.model.load_state_dict(new_state_dict_, strict=False)

    def _box2cs(self, box):
        """Convert an [x, y, w, h] box to (center, scale) for the affine crop."""
        x, y, w, h = box[:4]
        return self._xywh2cs(x, y, w, h)

    def _xywh2cs(self, x, y, w, h):
        """Return the box center and a scale padded to self.aspect_ratio."""
        center = np.zeros((2), dtype=np.float32)
        center[0] = x + w * 0.5
        center[1] = y + h * 0.5
        # Grow the short side so the crop matches the network aspect ratio.
        if w > self.aspect_ratio * h:
            h = w * 1.0 / self.aspect_ratio
        elif w < self.aspect_ratio * h:
            w = h * self.aspect_ratio
        scale = np.array([w, h], dtype=np.float32)
        return center, scale

    def preprocess(self, image):
        """Load/convert one image and warp it to the network input size.

        Args:
            image: a file path (read with cv2, BGR) or a PIL.Image.

        Returns:
            (tensor, meta): normalized (1, 3, H, W) input tensor on
            self.device, and the affine metadata needed to map logits back.

        Raises:
            TypeError: if `image` is neither a str path nor a PIL.Image.
        """
        if isinstance(image, str):
            img = cv2.imread(image, cv2.IMREAD_COLOR)  # BGR
        elif isinstance(image, Image.Image):
            # to cv2 format
            # NOTE(review): np.array(PIL) yields RGB while cv2.imread yields
            # BGR; confirm which channel order the checkpoint expects.
            img = np.array(image)
        else:
            # Previously `img` was left unbound here, raising a confusing
            # NameError below; fail fast with a clear message instead.
            raise TypeError(
                "image must be a file path or a PIL.Image, got {}".format(type(image))
            )

        h, w, _ = img.shape
        # Get person center and scale
        person_center, s = self._box2cs([0, 0, w - 1, h - 1])
        r = 0
        trans = get_affine_transform(person_center, s, r, self.input_size)
        input = cv2.warpAffine(
            img,
            trans,
            (int(self.input_size[1]), int(self.input_size[0])),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(0, 0, 0),
        )

        input = self.transform(input).to(self.device).unsqueeze(0)
        meta = {
            "center": person_center,
            "height": h,
            "width": w,
            "scale": s,
            "rotation": r,
        }
        return input, meta

    def __call__(self, image_or_path):
        """Run parsing on one image or a list of images.

        Returns a single paletted PIL image for a single input, or a list of
        them for a list input.
        """
        if isinstance(image_or_path, list):
            image_list = []
            meta_list = []
            for image in image_or_path:
                image, meta = self.preprocess(image)
                image_list.append(image)
                meta_list.append(meta)
            image = torch.cat(image_list, dim=0)
        else:
            image, meta = self.preprocess(image_or_path)
            meta_list = [meta]

        output = self.model(image)
        # upsample_outputs = self.upsample(output[0][-1])
        upsample_outputs = self.upsample(output)
        upsample_outputs = upsample_outputs.permute(0, 2, 3, 1)  # BCHW -> BHWC

        output_img_list = []
        for upsample_output, meta in zip(upsample_outputs, meta_list):
            c, s, w, h = meta["center"], meta["scale"], meta["width"], meta["height"]
            # Map logits from the network crop back to original image coords.
            logits_result = transform_logits(
                upsample_output.data.cpu().numpy(),
                c,
                s,
                w,
                h,
                input_size=self.input_size,
            )
            parsing_result = np.argmax(logits_result, axis=2)
            output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
            output_img.putpalette(self.palette)
            output_img_list.append(output_img)

        return output_img_list[0] if len(output_img_list) == 1 else output_img_list
Leffa/SCHP/networks/AugmentCE2P.py ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ """
5
+ @Author : Peike Li
6
+ @Contact : peike.li@yahoo.com
7
+ @File : AugmentCE2P.py
8
+ @Time : 8/4/19 3:35 PM
9
+ @Desc :
10
+ @License : This source code is licensed under the license found in the
11
+ LICENSE file in the root directory of this source tree.
12
+ """
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+ from torch.nn import BatchNorm2d, functional as F, LeakyReLU
18
+
19
+ affine_par = True
20
+ pretrained_settings = {
21
+ "resnet101": {
22
+ "imagenet": {
23
+ "input_space": "BGR",
24
+ "input_size": [3, 224, 224],
25
+ "input_range": [0, 1],
26
+ "mean": [0.406, 0.456, 0.485],
27
+ "std": [0.225, 0.224, 0.229],
28
+ "num_classes": 1000,
29
+ }
30
+ },
31
+ }
32
+
33
+
34
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding 1 and no bias."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
39
+
40
+
41
class Bottleneck(nn.Module):
    """ResNet bottleneck block (1x1 -> 3x3 -> 1x1) with optional dilation,
    multi-grid rate, and a downsample path for the residual connection.
    """

    expansion = 4

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        dilation=1,
        downsample=None,
        fist_dilation=1,
        multi_grid=1,
    ):
        super(Bottleneck, self).__init__()
        effective_dilation = dilation * multi_grid
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=effective_dilation,
            dilation=effective_dilation,
            bias=False,
        )
        self.bn2 = BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(planes * 4)
        # Out-of-place ReLU inside the branch; in-place only after the add.
        self.relu = nn.ReLU(inplace=False)
        self.relu_inplace = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.dilation = dilation
        self.stride = stride

    def forward(self, x):
        # Residual shortcut, projected when the shapes differ.
        shortcut = x if self.downsample is None else self.downsample(x)

        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))

        return self.relu_inplace(branch + shortcut)
96
+
97
+
98
class PSPModule(nn.Module):
    """
    Pyramid Scene Parsing context module: pools the input at several grid
    sizes, projects each pool, upsamples back, and fuses with the input.

    Reference:
        Zhao, Hengshuang, et al. *"Pyramid scene parsing network."*
    """

    def __init__(self, features, out_features=512, sizes=(1, 2, 3, 6)):
        super(PSPModule, self).__init__()

        self.stages = nn.ModuleList(
            [self._make_stage(features, out_features, size) for size in sizes]
        )
        self.bottleneck = nn.Sequential(
            nn.Conv2d(
                features + len(sizes) * out_features,
                out_features,
                kernel_size=3,
                padding=1,
                dilation=1,
                bias=False,
            ),
            BatchNorm2d(out_features),
            LeakyReLU(),
        )

    def _make_stage(self, features, out_features, size):
        """Build one pyramid level: adaptive pool -> 1x1 conv -> BN -> LReLU."""
        return nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(size, size)),
            nn.Conv2d(features, out_features, kernel_size=1, bias=False),
            BatchNorm2d(out_features),
            LeakyReLU(),
        )

    def forward(self, feats):
        height, width = feats.size(2), feats.size(3)
        # Pyramid branches first, raw features last (concat order matters for
        # checkpoint compatibility of the bottleneck conv).
        pyramid = [
            F.interpolate(
                stage(feats), size=(height, width), mode="bilinear", align_corners=True
            )
            for stage in self.stages
        ]
        pyramid.append(feats)
        return self.bottleneck(torch.cat(pyramid, 1))
145
+
146
+
147
class ASPPModule(nn.Module):
    """
    Atrous Spatial Pyramid Pooling: parallel image-pool, 1x1, and three
    dilated 3x3 branches, concatenated and fused by a 1x1 bottleneck.

    Reference:
        Chen, Liang-Chieh, et al. *"Rethinking Atrous Convolution for Semantic Image Segmentation."*
    """

    def __init__(
        self, features, inner_features=256, out_features=512, dilations=(12, 24, 36)
    ):
        super(ASPPModule, self).__init__()

        # Branch 1: global image pooling, upsampled back in forward().
        self.conv1 = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(
                features,
                inner_features,
                kernel_size=1,
                padding=0,
                dilation=1,
                bias=False,
            ),
            # InPlaceABNSync(inner_features)
            BatchNorm2d(inner_features),
            LeakyReLU(),
        )
        # Branch 2: plain 1x1 projection.
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                features,
                inner_features,
                kernel_size=1,
                padding=0,
                dilation=1,
                bias=False,
            ),
            BatchNorm2d(inner_features),
            LeakyReLU(),
        )
        # Branches 3-5: 3x3 convs with increasing dilation (padding == dilation
        # preserves the spatial size).
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                features,
                inner_features,
                kernel_size=3,
                padding=dilations[0],
                dilation=dilations[0],
                bias=False,
            ),
            BatchNorm2d(inner_features),
            LeakyReLU(),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(
                features,
                inner_features,
                kernel_size=3,
                padding=dilations[1],
                dilation=dilations[1],
                bias=False,
            ),
            BatchNorm2d(inner_features),
            LeakyReLU(),
        )
        self.conv5 = nn.Sequential(
            nn.Conv2d(
                features,
                inner_features,
                kernel_size=3,
                padding=dilations[2],
                dilation=dilations[2],
                bias=False,
            ),
            BatchNorm2d(inner_features),
            LeakyReLU(),
        )

        self.bottleneck = nn.Sequential(
            nn.Conv2d(
                inner_features * 5,
                out_features,
                kernel_size=1,
                padding=0,
                dilation=1,
                bias=False,
            ),
            # BUGFIX: was BatchNorm2d(inner_features) — the conv above outputs
            # out_features channels, so the old code crashed at forward time
            # whenever out_features != inner_features (the defaults 512/256!).
            BatchNorm2d(out_features),
            LeakyReLU(),
            nn.Dropout2d(0.1),
        )

    def forward(self, x):
        _, _, h, w = x.size()

        feat1 = F.interpolate(
            self.conv1(x), size=(h, w), mode="bilinear", align_corners=True
        )

        feat2 = self.conv2(x)
        feat3 = self.conv3(x)
        feat4 = self.conv4(x)
        feat5 = self.conv5(x)
        out = torch.cat((feat1, feat2, feat3, feat4, feat5), 1)

        bottle = self.bottleneck(out)
        return bottle
250
+
251
+
252
class Edge_Module(nn.Module):
    """
    Edge Learning Branch: projects three backbone stages to `mid_fea`
    channels, upsamples the deeper two to the shallowest resolution, and
    returns their concatenation (3 * mid_fea channels).
    """

    def __init__(self, in_fea=[256, 512, 1024], mid_fea=256, out_fea=2):
        super(Edge_Module, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_fea[0], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False
            ),
            BatchNorm2d(mid_fea),
            LeakyReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_fea[1], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False
            ),
            BatchNorm2d(mid_fea),
            LeakyReLU(),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                in_fea[2], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False
            ),
            BatchNorm2d(mid_fea),
            LeakyReLU(),
        )
        # conv4 is unused in forward() but kept so checkpoints containing its
        # weights still load cleanly.
        self.conv4 = nn.Conv2d(
            mid_fea, out_fea, kernel_size=3, padding=1, dilation=1, bias=True
        )
        # self.conv5 = nn.Conv2d(out_fea * 3, out_fea, kernel_size=1, padding=0, dilation=1, bias=True)

    def forward(self, x1, x2, x3):
        """Return concatenated edge features at x1's spatial resolution."""
        _, _, h, w = x1.size()

        edge1_fea = self.conv1(x1)
        edge2_fea = self.conv2(x2)
        edge3_fea = self.conv3(x3)

        # PERF: the old code also ran self.conv4 on edge2_fea/edge3_fea and
        # interpolated the results, but those edge maps were discarded; the
        # dead computation has been removed (return value is unchanged).
        edge2_fea = F.interpolate(
            edge2_fea, size=(h, w), mode="bilinear", align_corners=True
        )
        edge3_fea = F.interpolate(
            edge3_fea, size=(h, w), mode="bilinear", align_corners=True
        )

        edge_fea = torch.cat([edge1_fea, edge2_fea, edge3_fea], dim=1)
        return edge_fea
311
+
312
+
313
class Decoder_Module(nn.Module):
    """
    Parsing Branch Decoder Module: fuses high-level context features with
    low-level backbone features and returns a 256-channel feature map at the
    low-level resolution. (`num_classes` is kept for checkpoint/interface
    compatibility; the final classifier conv is commented out upstream.)
    """

    def __init__(self, num_classes):
        super(Decoder_Module, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=1, padding=0, dilation=1, bias=False),
            BatchNorm2d(256),
            LeakyReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                256, 48, kernel_size=1, stride=1, padding=0, dilation=1, bias=False
            ),
            BatchNorm2d(48),
            LeakyReLU(),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(304, 256, kernel_size=1, padding=0, dilation=1, bias=False),
            BatchNorm2d(256),
            LeakyReLU(),
            nn.Conv2d(256, 256, kernel_size=1, padding=0, dilation=1, bias=False),
            BatchNorm2d(256),
            LeakyReLU(),
        )

        # self.conv4 = nn.Conv2d(256, num_classes, kernel_size=1, padding=0, dilation=1, bias=True)

    def forward(self, xt, xl):
        _, _, low_h, low_w = xl.size()
        # Upsample projected context to the low-level resolution.
        context = F.interpolate(
            self.conv1(xt), size=(low_h, low_w), mode="bilinear", align_corners=True
        )
        low_level = self.conv2(xl)
        fused = torch.cat([context, low_level], dim=1)
        return self.conv3(fused)
354
+
355
+
356
class ResNet(nn.Module):
    """
    CE2P parsing network: a dilated ResNet backbone with a deep 3-conv stem,
    a PSP context module, an edge branch, a decoder branch, and a fusion head
    that outputs per-pixel class logits (at 1/4 input resolution).
    """

    def __init__(self, block, layers, num_classes):
        self.inplanes = 128
        super(ResNet, self).__init__()
        # ResNet-C style deep stem: three 3x3 convs replace the usual 7x7.
        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=False)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=False)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=False)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        # layer4 keeps stride 1 and uses dilation 2 to preserve resolution.
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=1, dilation=2, multi_grid=(1, 1, 1)
        )

        self.context_encoding = PSPModule(2048, 512)

        self.edge = Edge_Module()
        self.decoder = Decoder_Module(num_classes)

        # Fuses decoder features (256) with edge features (3*256) -> logits.
        self.fushion = nn.Sequential(
            nn.Conv2d(1024, 256, kernel_size=1, padding=0, dilation=1, bias=False),
            BatchNorm2d(256),
            LeakyReLU(),
            nn.Dropout2d(0.1),
            nn.Conv2d(
                256, num_classes, kernel_size=1, padding=0, dilation=1, bias=True
            ),
        )

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, multi_grid=1):
        """Stack `blocks` bottleneck blocks; first block may downsample."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                BatchNorm2d(planes * block.expansion, affine=affine_par),
            )

        layers = []
        # Per-block multi-grid dilation multiplier (1 unless a tuple is given).
        generate_multi_grid = lambda index, grids: (
            grids[index % len(grids)] if isinstance(grids, tuple) else 1
        )
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                dilation=dilation,
                downsample=downsample,
                multi_grid=generate_multi_grid(0, multi_grid),
            )
        )
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    dilation=dilation,
                    multi_grid=generate_multi_grid(i, multi_grid),
                )
            )

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)
        x2 = self.layer1(x)
        x3 = self.layer2(x2)
        x4 = self.layer3(x3)
        x5 = self.layer4(x4)
        x = self.context_encoding(x5)
        # parsing_result, parsing_fea = self.decoder(x, x2)
        parsing_fea = self.decoder(x, x2)
        # Edge Branch
        # edge_result, edge_fea = self.edge(x2, x3, x4)
        edge_fea = self.edge(x2, x3, x4)
        # Fusion Branch
        x = torch.cat([parsing_fea, edge_fea], dim=1)
        fusion_result = self.fushion(x)
        # return [[parsing_result, fusion_result], [edge_result]]
        return fusion_result
455
+
456
+
457
def initialize_pretrained_model(
    model, settings, pretrained="./models/resnet101-imagenet.pth"
):
    """Attach preprocessing metadata from `settings` to `model` and, when
    `pretrained` is a path, load ImageNet weights while skipping the
    classification head ("fc.*" keys).
    """
    model.input_space = settings["input_space"]
    model.input_size = settings["input_size"]
    model.input_range = settings["input_range"]
    model.mean = settings["mean"]
    model.std = settings["std"]

    if pretrained is not None:
        saved_state_dict = torch.load(pretrained)
        merged_params = model.state_dict().copy()
        for key, value in saved_state_dict.items():
            # Keep the randomly initialized head: skip "fc.*" weights.
            if key.split(".")[0] != "fc":
                merged_params[key] = value
        model.load_state_dict(merged_params)
474
+
475
+
476
def resnet101(num_classes=20, pretrained="./models/resnet101-imagenet.pth"):
    """Build the CE2P ResNet-101 parsing network, optionally initialized from
    an ImageNet checkpoint (pass pretrained=None to skip loading)."""
    settings = pretrained_settings["resnet101"]["imagenet"]
    net = ResNet(Bottleneck, [3, 4, 23, 3], num_classes)
    initialize_pretrained_model(net, settings, pretrained)
    return net
Leffa/SCHP/networks/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+
3
+ from SCHP.networks.AugmentCE2P import resnet101
4
+
5
+ __factory = {
6
+ "resnet101": resnet101,
7
+ }
8
+
9
+
10
def init_model(name, *args, **kwargs):
    """Instantiate a model architecture registered in `__factory`.

    Args:
        name: Architecture key (e.g. "resnet101").
        *args: Positional arguments forwarded to the constructor.
        **kwargs: Keyword arguments forwarded to the constructor.

    Raises:
        KeyError: If `name` is not a registered architecture.
    """
    # Membership test on the dict itself; `.keys()` was redundant.
    if name not in __factory:
        raise KeyError("Unknown model arch: {}".format(name))
    return __factory[name](*args, **kwargs)
Leffa/SCHP/utils/transforms.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft
3
+ # Licensed under the MIT License.
4
+ # Written by Bin Xiao (Bin.Xiao@microsoft.com)
5
+ # ------------------------------------------------------------------------------
6
+
7
+ from __future__ import absolute_import, division, print_function
8
+
9
+ import cv2
10
+
11
+ import numpy as np
12
+ import torch
13
+
14
+
15
class BRG2Tensor_transform(object):
    """Convert an HWC numpy image to a CHW torch tensor.

    Byte (uint8) images are promoted to float32; any other dtype is
    returned unchanged. Note: no value scaling is performed (unlike
    torchvision's ToTensor).
    """

    def __call__(self, pic):
        # HWC -> CHW; from_numpy shares memory with the input array.
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # dtype check replaces the legacy `isinstance(img, torch.ByteTensor)`
        # tensor-type test, which is deprecated in modern PyTorch.
        if img.dtype == torch.uint8:
            return img.float()
        return img
22
+
23
+
24
class BGR2RGB_transform(object):
    """Reorder the channel axis of a CHW tensor from BGR to RGB."""

    def __call__(self, tensor):
        # Explicit channel permutation: (B, G, R) -> (R, G, B).
        rgb_order = [2, 1, 0]
        return tensor[rgb_order, :, :]
27
+
28
+
29
def flip_back(output_flipped, matched_parts):
    """
    ouput_flipped: numpy.ndarray(batch_size, num_joints, height, width)

    Undo a horizontal flip of joint heatmaps: mirror the width axis and
    exchange each left/right joint channel pair.
    """
    assert (
        output_flipped.ndim == 4
    ), "output_flipped should be [batch_size, num_joints, height, width]"

    # Mirror the width axis back.
    output_flipped = output_flipped[:, :, :, ::-1]

    # Exchange each matched pair; fancy-indexed RHS is a copy, so the
    # simultaneous swap is safe.
    for left, right in matched_parts:
        output_flipped[:, [left, right], :, :] = output_flipped[:, [right, left], :, :]

    return output_flipped
45
+
46
+
47
def fliplr_joints(joints, joints_vis, width, matched_parts):
    """
    flip coords

    Mirror joint x-coordinates horizontally within an image of the given
    width, then exchange each left/right joint pair (coordinates and
    visibility alike).

    Returns:
        Tuple of (joints masked by visibility, joints_vis).
    """
    # Mirror x coordinates around the image's vertical centre line.
    joints[:, 0] = width - joints[:, 0] - 1

    # Swap each matched pair in-place via an explicit temporary copy.
    for left, right in matched_parts:
        tmp = joints[left, :].copy()
        joints[left, :] = joints[right, :]
        joints[right, :] = tmp

        tmp_vis = joints_vis[left, :].copy()
        joints_vis[left, :] = joints_vis[right, :]
        joints_vis[right, :] = tmp_vis

    return joints * joints_vis, joints_vis
66
+
67
+
68
def transform_preds(coords, center, scale, input_size):
    """Map predicted coordinates from network-input space back to the
    original image space using the inverse affine transform."""
    inv_trans = get_affine_transform(center, scale, 0, input_size, inv=1)
    target_coords = np.zeros(coords.shape)
    for idx in range(coords.shape[0]):
        target_coords[idx, 0:2] = affine_transform(coords[idx, 0:2], inv_trans)
    return target_coords
74
+
75
+
76
def transform_parsing(pred, center, scale, width, height, input_size):
    """Warp a parsing label map back to the original image size.

    Nearest-neighbour interpolation keeps the integer class labels intact.
    """
    inv_trans = get_affine_transform(center, scale, 0, input_size, inv=1)
    warped = cv2.warpAffine(
        pred,
        inv_trans,
        (int(width), int(height)),
        flags=cv2.INTER_NEAREST,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=(0),
    )
    return warped
89
+
90
+
91
def transform_logits(logits, center, scale, width, height, input_size):
    """Warp per-class logit maps back to the original image size.

    Each channel of the (H, W, C) logits is warped independently with
    bilinear interpolation, then the results are re-stacked on axis 2.
    """
    inv_trans = get_affine_transform(center, scale, 0, input_size, inv=1)
    warped_channels = [
        cv2.warpAffine(
            logits[:, :, ch],
            inv_trans,
            (int(width), int(height)),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(0),
        )
        for ch in range(logits.shape[2])
    ]
    return np.stack(warped_channels, axis=2)
109
+
110
+
111
def get_affine_transform(
    center, scale, rot, output_size, shift=np.array([0, 0], dtype=np.float32), inv=0
):
    """Compute the 2x3 affine matrix mapping a scaled/rotated region around
    `center` onto `output_size` (or the inverse mapping when `inv`).

    Args:
        center: (x, y) centre of the region in the source image.
        scale: Region size; a scalar is promoted to (scale, scale).
        rot: Rotation angle in degrees.
        output_size: (height, width) of the destination.
        shift: Fractional shift of the source region, in units of `scale`.
        inv: When truthy, return the destination-to-source transform.

    Returns:
        A 2x3 float matrix suitable for `cv2.warpAffine`.
    """
    # Promote a scalar scale to an isotropic pair. (Leftover debug
    # `print(scale)` removed; the two isinstance checks merged.)
    if not isinstance(scale, (np.ndarray, list)):
        scale = np.array([scale, scale])

    scale_tmp = scale

    src_w = scale_tmp[0]
    dst_w = output_size[1]
    dst_h = output_size[0]

    # Direction vectors give the second reference point, encoding rotation.
    rot_rad = np.pi * rot / 180
    src_dir = get_dir([0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0, (dst_w - 1) * -0.5], np.float32)

    # Three point correspondences fully determine an affine transform.
    src = np.zeros((3, 2), dtype=np.float32)
    dst = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    dst[0, :] = [(dst_w - 1) * 0.5, (dst_h - 1) * 0.5]
    dst[1, :] = np.array([(dst_w - 1) * 0.5, (dst_h - 1) * 0.5]) + dst_dir

    # The third point completes a right triangle with the first two.
    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans
144
+
145
+
146
def affine_transform(pt, t):
    """Apply the 2x3 affine matrix `t` to the 2D point `pt`."""
    # Lift the point to homogeneous coordinates, then project back to 2D.
    homogeneous = np.array([pt[0], pt[1], 1.0]).T
    return np.dot(t, homogeneous)[:2]
150
+
151
+
152
def get_3rd_point(a, b):
    """Return the point completing a right angle at `b`: `b` plus the
    vector `a - b` rotated 90 degrees."""
    delta = a - b
    return b + np.array([-delta[1], delta[0]], dtype=np.float32)
155
+
156
+
157
def get_dir(src_point, rot_rad):
    """Rotate the 2D point `src_point` by `rot_rad` radians about the origin."""
    sin_r, cos_r = np.sin(rot_rad), np.cos(rot_rad)
    return [
        src_point[0] * cos_r - src_point[1] * sin_r,
        src_point[0] * sin_r + src_point[1] * cos_r,
    ]
165
+
166
+
167
def crop(img, center, scale, output_size, rot=0):
    """Crop (and optionally rotate) a region of `img` to `output_size`
    using an affine warp with bilinear interpolation."""
    trans = get_affine_transform(center, scale, rot, output_size)
    # warpAffine expects (width, height) order.
    dst_w, dst_h = int(output_size[1]), int(output_size[0])
    return cv2.warpAffine(img, trans, (dst_w, dst_h), flags=cv2.INTER_LINEAR)