Add files using upload-large-folder tool
Browse files- Leffa/3rdparty/SCHP/__init__.py +241 -0
- Leffa/3rdparty/SCHP/networks/AugmentCE2P.py +480 -0
- Leffa/3rdparty/SCHP/networks/__init__.py +13 -0
- Leffa/3rdparty/SCHP/utils/transforms.py +174 -0
- Leffa/3rdparty/detectron2/data/transforms/__init__.py +14 -0
- Leffa/3rdparty/detectron2/export/README.md +15 -0
- Leffa/3rdparty/detectron2/export/__init__.py +30 -0
- Leffa/3rdparty/detectron2/export/api.py +230 -0
- Leffa/3rdparty/detectron2/export/c10.py +571 -0
- Leffa/3rdparty/detectron2/export/caffe2_export.py +203 -0
- Leffa/3rdparty/detectron2/export/caffe2_inference.py +161 -0
- Leffa/3rdparty/detectron2/export/caffe2_modeling.py +420 -0
- Leffa/3rdparty/detectron2/export/caffe2_patch.py +189 -0
- Leffa/3rdparty/detectron2/export/flatten.py +330 -0
- Leffa/3rdparty/detectron2/export/shared.py +1039 -0
- Leffa/3rdparty/detectron2/export/torchscript.py +132 -0
- Leffa/3rdparty/detectron2/export/torchscript_patch.py +406 -0
- Leffa/SCHP/__init__.py +241 -0
- Leffa/SCHP/networks/AugmentCE2P.py +480 -0
- Leffa/SCHP/networks/__init__.py +13 -0
- Leffa/SCHP/utils/transforms.py +174 -0
Leffa/3rdparty/SCHP/__init__.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from SCHP import networks
|
| 8 |
+
from SCHP.utils.transforms import get_affine_transform, transform_logits
|
| 9 |
+
from torchvision import transforms
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_palette(num_cls):
    """Returns the color map for visualizing the segmentation mask.

    Each class index is expanded into an (R, G, B) triple using the standard
    PASCAL VOC bit-interleaving scheme, so the palette is deterministic and
    visually distinct for small class counts.

    Args:
        num_cls: Number of classes
    Returns:
        The color map as a flat list of num_cls * 3 integers (R, G, B per class)
    """
    palette = [0] * (num_cls * 3)
    for cls_idx in range(num_cls):
        remaining = cls_idx
        red = green = blue = 0
        bit = 0
        # Spread the bits of the class index across the three channels,
        # from the most significant color bit downwards.
        while remaining:
            red |= ((remaining >> 0) & 1) << (7 - bit)
            green |= ((remaining >> 1) & 1) << (7 - bit)
            blue |= ((remaining >> 2) & 1) << (7 - bit)
            bit += 1
            remaining >>= 3
        palette[cls_idx * 3 + 0] = red
        palette[cls_idx * 3 + 1] = green
        palette[cls_idx * 3 + 2] = blue
    return palette
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Per-dataset inference configuration for the SCHP parser.
# "input_size" is the (height, width) the network is fed; "label" maps class
# index -> human-readable part name. The list ORDER matters: it matches the
# channel order of the model's output logits.
dataset_settings = {
    # LIP: Look Into Person, 20 body/clothing classes.
    "lip": {
        "input_size": [473, 473],
        "num_classes": 20,
        "label": [
            "Background",
            "Hat",
            "Hair",
            "Glove",
            "Sunglasses",
            "Upper-clothes",
            "Dress",
            "Coat",
            "Socks",
            "Pants",
            "Jumpsuits",
            "Scarf",
            "Skirt",
            "Face",
            "Left-arm",
            "Right-arm",
            "Left-leg",
            "Right-leg",
            "Left-shoe",
            "Right-shoe",
        ],
    },
    # ATR: 18 clothing-centric classes.
    "atr": {
        "input_size": [512, 512],
        "num_classes": 18,
        "label": [
            "Background",
            "Hat",
            "Hair",
            "Sunglasses",
            "Upper-clothes",
            "Skirt",
            "Pants",
            "Dress",
            "Belt",
            "Left-shoe",
            "Right-shoe",
            "Face",
            "Left-leg",
            "Right-leg",
            "Left-arm",
            "Right-arm",
            "Bag",
            "Scarf",
        ],
    },
    # Pascal-Person-Part: 7 coarse body-part classes.
    "pascal": {
        "input_size": [512, 512],
        "num_classes": 7,
        "label": [
            "Background",
            "Head",
            "Torso",
            "Upper Arms",
            "Lower Arms",
            "Upper Legs",
            "Lower Legs",
        ],
    },
}
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class SCHP:
    """Self-Correction Human Parsing (SCHP) inference wrapper.

    Loads a pretrained resnet101-based parsing network and exposes a callable
    that maps an image (path, PIL image, or a list of either) to a paletted
    PIL image of per-pixel part labels.
    """

    def __init__(self, ckpt_path, device):
        # Infer which dataset the checkpoint was trained on from its path;
        # this selects the input size, class count and label names.
        dataset_type = None
        if "lip" in ckpt_path:
            dataset_type = "lip"
        elif "atr" in ckpt_path:
            dataset_type = "atr"
        elif "pascal" in ckpt_path:
            dataset_type = "pascal"
        assert dataset_type is not None, "Dataset type not found in checkpoint path"
        self.device = device
        self.num_classes = dataset_settings[dataset_type]["num_classes"]
        self.input_size = dataset_settings[dataset_type]["input_size"]
        # width / height of the network input; used to pad crops to the
        # expected aspect ratio in _xywh2cs.
        self.aspect_ratio = self.input_size[1] * 1.0 / self.input_size[0]
        self.palette = get_palette(self.num_classes)

        self.label = dataset_settings[dataset_type]["label"]
        self.model = networks.init_model(
            "resnet101", num_classes=self.num_classes, pretrained=None
        ).to(device)
        self.load_ckpt(ckpt_path)
        self.model.eval()

        # NOTE(review): mean/std are in BGR channel order (matching
        # cv2.imread output). PIL inputs in preprocess() are RGB and are NOT
        # converted to BGR before this transform — confirm this mismatch is
        # intentional upstream.
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229]
                ),
            ]
        )
        # Upsamples raw logits back to the network input resolution.
        self.upsample = torch.nn.Upsample(
            size=self.input_size, mode="bilinear", align_corners=True
        )

    def load_ckpt(self, ckpt_path):
        """Load a DataParallel-saved checkpoint, remapping renamed layers.

        The rename_map accounts for extra layers inserted into the decoder /
        fusion heads relative to the checkpoint's layout.
        """
        rename_map = {
            "decoder.conv3.2.weight": "decoder.conv3.3.weight",
            "decoder.conv3.3.weight": "decoder.conv3.4.weight",
            "decoder.conv3.3.bias": "decoder.conv3.4.bias",
            "decoder.conv3.3.running_mean": "decoder.conv3.4.running_mean",
            "decoder.conv3.3.running_var": "decoder.conv3.4.running_var",
            "fushion.3.weight": "fushion.4.weight",
            "fushion.3.bias": "fushion.4.bias",
        }
        state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]  # remove `module.` (checkpoint was saved from DataParallel)
            new_state_dict[name] = v
        new_state_dict_ = OrderedDict()
        for k, v in list(new_state_dict.items()):
            if k in rename_map:
                new_state_dict_[rename_map[k]] = v
            else:
                new_state_dict_[k] = v
        # NOTE(review): strict=False silently ignores missing/unexpected keys;
        # a bad rename_map would not raise here.
        self.model.load_state_dict(new_state_dict_, strict=False)

    def _box2cs(self, box):
        # box is (x, y, w, h); delegate to _xywh2cs for center/scale.
        x, y, w, h = box[:4]
        return self._xywh2cs(x, y, w, h)

    def _xywh2cs(self, x, y, w, h):
        """Convert an xywh box to (center, scale), padding to the model aspect ratio."""
        center = np.zeros((2), dtype=np.float32)
        center[0] = x + w * 0.5
        center[1] = y + h * 0.5
        # Grow the shorter side so the box matches self.aspect_ratio (w/h).
        if w > self.aspect_ratio * h:
            h = w * 1.0 / self.aspect_ratio
        elif w < self.aspect_ratio * h:
            w = h * self.aspect_ratio
        scale = np.array([w, h], dtype=np.float32)
        return center, scale

    def preprocess(self, image):
        """Load/convert an image and warp it to the network input size.

        Returns the normalized (1, C, H, W) tensor on self.device and a meta
        dict used later to map logits back to the original image geometry.
        """
        if isinstance(image, str):
            img = cv2.imread(image, cv2.IMREAD_COLOR)
        elif isinstance(image, Image.Image):
            # to cv2 format
            # NOTE(review): np.array(PIL) keeps RGB order — no BGR swap here.
            img = np.array(image)
        # NOTE(review): any other input type leaves `img` unbound and raises
        # UnboundLocalError below — consider an explicit TypeError.

        h, w, _ = img.shape
        # Get person center and scale
        person_center, s = self._box2cs([0, 0, w - 1, h - 1])
        r = 0
        trans = get_affine_transform(person_center, s, r, self.input_size)
        input = cv2.warpAffine(
            img,
            trans,
            (int(self.input_size[1]), int(self.input_size[0])),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(0, 0, 0),
        )

        input = self.transform(input).to(self.device).unsqueeze(0)
        meta = {
            "center": person_center,
            "height": h,
            "width": w,
            "scale": s,
            "rotation": r,
        }
        return input, meta

    def __call__(self, image_or_path):
        """Run parsing on one image or a list of images.

        Returns a single paletted PIL image for a single input, or a list of
        them for a list input. Pixel values are class indices.
        """
        if isinstance(image_or_path, list):
            image_list = []
            meta_list = []
            for image in image_or_path:
                image, meta = self.preprocess(image)
                image_list.append(image)
                meta_list.append(meta)
            # Batch all preprocessed images together for one forward pass.
            image = torch.cat(image_list, dim=0)
        else:
            image, meta = self.preprocess(image_or_path)
            meta_list = [meta]

        # NOTE(review): no torch.no_grad() here — gradients are tracked
        # needlessly during inference.
        output = self.model(image)
        # upsample_outputs = self.upsample(output[0][-1])
        upsample_outputs = self.upsample(output)
        upsample_outputs = upsample_outputs.permute(0, 2, 3, 1)  # BCHW -> BHWC

        output_img_list = []
        for upsample_output, meta in zip(upsample_outputs, meta_list):
            c, s, w, h = meta["center"], meta["scale"], meta["width"], meta["height"]
            # Map logits from the warped crop back to original image coords.
            logits_result = transform_logits(
                upsample_output.data.cpu().numpy(),
                c,
                s,
                w,
                h,
                input_size=self.input_size,
            )
            parsing_result = np.argmax(logits_result, axis=2)
            output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
            output_img.putpalette(self.palette)
            output_img_list.append(output_img)

        return output_img_list[0] if len(output_img_list) == 1 else output_img_list
|
Leffa/3rdparty/SCHP/networks/AugmentCE2P.py
ADDED
|
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- encoding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
@Author : Peike Li
|
| 6 |
+
@Contact : peike.li@yahoo.com
|
| 7 |
+
@File : AugmentCE2P.py
|
| 8 |
+
@Time : 8/4/19 3:35 PM
|
| 9 |
+
@Desc :
|
| 10 |
+
@License : This source code is licensed under the license found in the
|
| 11 |
+
LICENSE file in the root directory of this source tree.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
|
| 17 |
+
from torch.nn import BatchNorm2d, functional as F, LeakyReLU
|
| 18 |
+
|
| 19 |
+
# Whether BatchNorm layers created for ResNet downsample paths are affine.
affine_par = True
# Input metadata for the ImageNet-pretrained resnet101 backbone.
# NOTE(review): mean/std are listed in BGR order ("input_space": "BGR"),
# matching cv2-loaded images — confirm against the preprocessing pipeline.
pretrained_settings = {
    "resnet101": {
        "imagenet": {
            "input_space": "BGR",
            "input_size": [3, 224, 224],
            "input_range": [0, 1],
            "mean": [0.406, 0.456, 0.485],
            "std": [0.225, 0.224, 0.229],
            "num_classes": 1000,
        }
    },
}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding 1 and no bias (BN follows in callers)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 (dilated) -> 1x1 expand.

    Output channel count is planes * expansion; the skip connection is
    optionally projected by `downsample` to match.
    """

    expansion = 4

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        dilation=1,
        downsample=None,
        fist_dilation=1,
        multi_grid=1,
    ):
        super(Bottleneck, self).__init__()
        # 1x1 channel reduction.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        # 3x3 spatial convolution; padding equals the effective dilation so
        # the spatial size is preserved (up to stride).
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=dilation * multi_grid,
            dilation=dilation * multi_grid,
            bias=False,
        )
        self.bn2 = BatchNorm2d(planes)
        # 1x1 expansion back to planes * 4 channels.
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=False)
        self.relu_inplace = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.dilation = dilation
        self.stride = stride

    def forward(self, x):
        """Apply the bottleneck transform and add the (projected) skip path."""
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        return self.relu_inplace(y + shortcut)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class PSPModule(nn.Module):
    """Pyramid Scene Parsing pooling module.

    Reference:
        Zhao, Hengshuang, et al. *"Pyramid scene parsing network."*
    """

    def __init__(self, features, out_features=512, sizes=(1, 2, 3, 6)):
        super(PSPModule, self).__init__()

        # One pooled branch per pyramid level.
        self.stages = nn.ModuleList(
            [self._make_stage(features, out_features, size) for size in sizes]
        )
        # Fuses the original features with every upsampled branch.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(
                features + len(sizes) * out_features,
                out_features,
                kernel_size=3,
                padding=1,
                dilation=1,
                bias=False,
            ),
            BatchNorm2d(out_features),
            LeakyReLU(),
        )

    def _make_stage(self, features, out_features, size):
        """Build one branch: adaptive pool to size x size, 1x1 conv, BN, LeakyReLU."""
        return nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(size, size)),
            nn.Conv2d(features, out_features, kernel_size=1, bias=False),
            # bn
            BatchNorm2d(out_features),
            LeakyReLU(),
        )

    def forward(self, feats):
        """Return multi-scale context features at feats' spatial resolution."""
        height, width = feats.size(2), feats.size(3)
        branches = []
        for stage in self.stages:
            branches.append(
                F.interpolate(
                    input=stage(feats),
                    size=(height, width),
                    mode="bilinear",
                    align_corners=True,
                )
            )
        branches.append(feats)
        return self.bottleneck(torch.cat(branches, 1))
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class ASPPModule(nn.Module):
    """
    Atrous Spatial Pyramid Pooling.

    Reference:
        Chen, Liang-Chieh, et al. *"Rethinking Atrous Convolution for Semantic Image Segmentation."*

    Five parallel branches (global pool + 1x1 conv, a plain 1x1 conv, and
    three 3x3 convs at increasing dilation rates) are concatenated and fused
    down to ``out_features`` channels.
    """

    def __init__(
        self, features, inner_features=256, out_features=512, dilations=(12, 24, 36)
    ):
        super(ASPPModule, self).__init__()

        # Image-level branch: global average pool, then 1x1 conv.
        self.conv1 = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(
                features,
                inner_features,
                kernel_size=1,
                padding=0,
                dilation=1,
                bias=False,
            ),
            # InPlaceABNSync(inner_features)
            BatchNorm2d(inner_features),
            LeakyReLU(),
        )
        # Plain 1x1 branch.
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                features,
                inner_features,
                kernel_size=1,
                padding=0,
                dilation=1,
                bias=False,
            ),
            BatchNorm2d(inner_features),
            LeakyReLU(),
        )
        # Three dilated 3x3 branches; padding == dilation keeps spatial size.
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                features,
                inner_features,
                kernel_size=3,
                padding=dilations[0],
                dilation=dilations[0],
                bias=False,
            ),
            BatchNorm2d(inner_features),
            LeakyReLU(),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(
                features,
                inner_features,
                kernel_size=3,
                padding=dilations[1],
                dilation=dilations[1],
                bias=False,
            ),
            BatchNorm2d(inner_features),
            LeakyReLU(),
        )
        self.conv5 = nn.Sequential(
            nn.Conv2d(
                features,
                inner_features,
                kernel_size=3,
                padding=dilations[2],
                dilation=dilations[2],
                bias=False,
            ),
            BatchNorm2d(inner_features),
            LeakyReLU(),
        )

        self.bottleneck = nn.Sequential(
            nn.Conv2d(
                inner_features * 5,
                out_features,
                kernel_size=1,
                padding=0,
                dilation=1,
                bias=False,
            ),
            # BUGFIX: the conv above outputs out_features channels, so the
            # BatchNorm must also be over out_features (was inner_features,
            # which raises a channel-mismatch error whenever
            # inner_features != out_features, as with the defaults 256/512).
            BatchNorm2d(out_features),
            LeakyReLU(),
            nn.Dropout2d(0.1),
        )

    def forward(self, x):
        """Return fused ASPP features with out_features channels at x's size."""
        _, _, h, w = x.size()

        # Broadcast the global-context branch back to full resolution.
        feat1 = F.interpolate(
            self.conv1(x), size=(h, w), mode="bilinear", align_corners=True
        )

        feat2 = self.conv2(x)
        feat3 = self.conv3(x)
        feat4 = self.conv4(x)
        feat5 = self.conv5(x)
        out = torch.cat((feat1, feat2, feat3, feat4, feat5), 1)

        bottle = self.bottleneck(out)
        return bottle
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class Edge_Module(nn.Module):
    """
    Edge Learning Branch

    Projects features from three backbone stages to mid_fea channels each,
    upsamples them to the first stage's resolution, and returns their
    concatenation (3 * mid_fea channels).
    """

    def __init__(self, in_fea=[256, 512, 1024], mid_fea=256, out_fea=2):
        super(Edge_Module, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_fea[0], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False
            ),
            BatchNorm2d(mid_fea),
            LeakyReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_fea[1], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False
            ),
            BatchNorm2d(mid_fea),
            LeakyReLU(),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                in_fea[2], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False
            ),
            BatchNorm2d(mid_fea),
            LeakyReLU(),
        )
        # conv4 is retained only so pretrained checkpoints still load; its
        # edge-map predictions are no longer computed in forward() because
        # the original code discarded them.
        self.conv4 = nn.Conv2d(
            mid_fea, out_fea, kernel_size=3, padding=1, dilation=1, bias=True
        )
        # self.conv5 = nn.Conv2d(out_fea * 3, out_fea, kernel_size=1, padding=0, dilation=1, bias=True)

    def forward(self, x1, x2, x3):
        """Return concatenated edge features (3 * mid_fea channels) at x1's size.

        FIX: the previous implementation also ran conv4 on the second and
        third branches and upsampled those edge maps, then threw the results
        away — that dead computation has been removed; the returned value is
        unchanged.
        """
        _, _, h, w = x1.size()

        edge1_fea = self.conv1(x1)
        edge2_fea = F.interpolate(
            self.conv2(x2), size=(h, w), mode="bilinear", align_corners=True
        )
        edge3_fea = F.interpolate(
            self.conv3(x3), size=(h, w), mode="bilinear", align_corners=True
        )

        edge_fea = torch.cat([edge1_fea, edge2_fea, edge3_fea], dim=1)
        return edge_fea
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
class Decoder_Module(nn.Module):
    """
    Parsing Branch Decoder Module.

    Upsamples 512-channel context features to the low-level feature
    resolution, fuses them with a 48-channel projection of the low-level
    features, and returns a 256-channel map.
    """

    def __init__(self, num_classes):
        super(Decoder_Module, self).__init__()
        # 512 -> 256 projection of the high-level context features.
        self.conv1 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=1, padding=0, dilation=1, bias=False),
            BatchNorm2d(256),
            LeakyReLU(),
        )
        # 256 -> 48 projection of the low-level backbone features.
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                256, 48, kernel_size=1, stride=1, padding=0, dilation=1, bias=False
            ),
            BatchNorm2d(48),
            LeakyReLU(),
        )
        # Fuse the concatenated 256 + 48 = 304 channels down to 256.
        self.conv3 = nn.Sequential(
            nn.Conv2d(304, 256, kernel_size=1, padding=0, dilation=1, bias=False),
            BatchNorm2d(256),
            LeakyReLU(),
            nn.Conv2d(256, 256, kernel_size=1, padding=0, dilation=1, bias=False),
            BatchNorm2d(256),
            LeakyReLU(),
        )

        # self.conv4 = nn.Conv2d(256, num_classes, kernel_size=1, padding=0, dilation=1, bias=True)

    def forward(self, xt, xl):
        """Return 256-channel fused features at xl's spatial resolution."""
        _, _, height, width = xl.size()
        upsampled = F.interpolate(
            self.conv1(xt), size=(height, width), mode="bilinear", align_corners=True
        )
        fused = torch.cat([upsampled, self.conv2(xl)], dim=1)
        # seg = self.conv4(x)
        # return seg, x
        return self.conv3(fused)
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
class ResNet(nn.Module):
    """CE2P-style ResNet backbone with context, edge, and decoder heads.

    forward() returns a single fused parsing logit map with num_classes
    channels at the decoder's spatial resolution (the auxiliary parsing and
    edge outputs of the original CE2P are commented out).
    """

    def __init__(self, block, layers, num_classes):
        # Deep-stem variant: three 3x3 convs replace the usual single 7x7.
        self.inplanes = 128
        super(ResNet, self).__init__()
        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=False)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=False)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=False)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        # Final stage keeps stride 1 and uses dilation 2 to preserve
        # resolution for dense prediction.
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=1, dilation=2, multi_grid=(1, 1, 1)
        )

        # Pyramid pooling over the 2048-channel stage-4 output.
        self.context_encoding = PSPModule(2048, 512)

        self.edge = Edge_Module()
        self.decoder = Decoder_Module(num_classes)

        # Fuses decoder features (256) with edge features (3 * 256 = 768):
        # 1024 channels in total.
        self.fushion = nn.Sequential(
            nn.Conv2d(1024, 256, kernel_size=1, padding=0, dilation=1, bias=False),
            BatchNorm2d(256),
            LeakyReLU(),
            nn.Dropout2d(0.1),
            nn.Conv2d(
                256, num_classes, kernel_size=1, padding=0, dilation=1, bias=True
            ),
        )

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, multi_grid=1):
        """Stack `blocks` bottleneck blocks, projecting the skip on the first one."""
        downsample = None
        # Project the residual when stride or channel count changes.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                BatchNorm2d(planes * block.expansion, affine=affine_par),
            )

        layers = []
        # Per-block dilation multiplier; falls back to 1 when multi_grid is
        # a plain int rather than a tuple.
        generate_multi_grid = lambda index, grids: (
            grids[index % len(grids)] if isinstance(grids, tuple) else 1
        )
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                dilation=dilation,
                downsample=downsample,
                multi_grid=generate_multi_grid(0, multi_grid),
            )
        )
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    dilation=dilation,
                    multi_grid=generate_multi_grid(i, multi_grid),
                )
            )

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return fused parsing logits (num_classes channels) for input x."""
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)
        x2 = self.layer1(x)
        x3 = self.layer2(x2)
        x4 = self.layer3(x3)
        x5 = self.layer4(x4)
        x = self.context_encoding(x5)
        # parsing_result, parsing_fea = self.decoder(x, x2)
        parsing_fea = self.decoder(x, x2)
        # Edge Branch
        # edge_result, edge_fea = self.edge(x2, x3, x4)
        edge_fea = self.edge(x2, x3, x4)
        # Fusion Branch
        x = torch.cat([parsing_fea, edge_fea], dim=1)
        fusion_result = self.fushion(x)
        # return [[parsing_result, fusion_result], [edge_result]]
        return fusion_result
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def initialize_pretrained_model(
    model, settings, pretrained="./models/resnet101-imagenet.pth"
):
    """Attach input metadata to `model` and optionally load ImageNet weights.

    All checkpoint entries except the classifier head (keys starting with
    "fc.") are copied into the model's state dict.
    """
    for attr in ("input_space", "input_size", "input_range", "mean", "std"):
        setattr(model, attr, settings[attr])

    if pretrained is not None:
        checkpoint = torch.load(pretrained)
        merged = model.state_dict().copy()
        for key in checkpoint:
            # Skip the ImageNet classification head; keep everything else.
            if key.split(".")[0] != "fc":
                merged[key] = checkpoint[key]
        model.load_state_dict(merged)
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
def resnet101(num_classes=20, pretrained="./models/resnet101-imagenet.pth"):
    """Build the CE2P ResNet-101 parser, optionally loading ImageNet weights."""
    settings = pretrained_settings["resnet101"]["imagenet"]
    model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes)
    initialize_pretrained_model(model, settings, pretrained)
    return model
|
Leffa/3rdparty/SCHP/networks/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import absolute_import
|
| 2 |
+
|
| 3 |
+
from SCHP.networks.AugmentCE2P import resnet101
|
| 4 |
+
|
| 5 |
+
# Registry of supported backbone architectures.
__factory = {
    "resnet101": resnet101,
}


def init_model(name, *args, **kwargs):
    """Instantiate a registered model architecture by name.

    Raises:
        KeyError: if `name` is not a registered architecture.
    """
    if name not in __factory:
        raise KeyError("Unknown model arch: {}".format(name))
    builder = __factory[name]
    return builder(*args, **kwargs)
|
Leffa/3rdparty/SCHP/utils/transforms.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------------
|
| 2 |
+
# Copyright (c) Microsoft
|
| 3 |
+
# Licensed under the MIT License.
|
| 4 |
+
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
|
| 5 |
+
# ------------------------------------------------------------------------------
|
| 6 |
+
|
| 7 |
+
from __future__ import absolute_import, division, print_function
|
| 8 |
+
|
| 9 |
+
import cv2
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class BRG2Tensor_transform(object):
|
| 16 |
+
def __call__(self, pic):
|
| 17 |
+
img = torch.from_numpy(pic.transpose((2, 0, 1)))
|
| 18 |
+
if isinstance(img, torch.ByteTensor):
|
| 19 |
+
return img.float()
|
| 20 |
+
else:
|
| 21 |
+
return img
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class BGR2RGB_transform(object):
|
| 25 |
+
def __call__(self, tensor):
|
| 26 |
+
return tensor[[2, 1, 0], :, :]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def flip_back(output_flipped, matched_parts):
|
| 30 |
+
"""
|
| 31 |
+
ouput_flipped: numpy.ndarray(batch_size, num_joints, height, width)
|
| 32 |
+
"""
|
| 33 |
+
assert (
|
| 34 |
+
output_flipped.ndim == 4
|
| 35 |
+
), "output_flipped should be [batch_size, num_joints, height, width]"
|
| 36 |
+
|
| 37 |
+
output_flipped = output_flipped[:, :, :, ::-1]
|
| 38 |
+
|
| 39 |
+
for pair in matched_parts:
|
| 40 |
+
tmp = output_flipped[:, pair[0], :, :].copy()
|
| 41 |
+
output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :]
|
| 42 |
+
output_flipped[:, pair[1], :, :] = tmp
|
| 43 |
+
|
| 44 |
+
return output_flipped
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def fliplr_joints(joints, joints_vis, width, matched_parts):
|
| 48 |
+
"""
|
| 49 |
+
flip coords
|
| 50 |
+
"""
|
| 51 |
+
# Flip horizontal
|
| 52 |
+
joints[:, 0] = width - joints[:, 0] - 1
|
| 53 |
+
|
| 54 |
+
# Change left-right parts
|
| 55 |
+
for pair in matched_parts:
|
| 56 |
+
joints[pair[0], :], joints[pair[1], :] = (
|
| 57 |
+
joints[pair[1], :],
|
| 58 |
+
joints[pair[0], :].copy(),
|
| 59 |
+
)
|
| 60 |
+
joints_vis[pair[0], :], joints_vis[pair[1], :] = (
|
| 61 |
+
joints_vis[pair[1], :],
|
| 62 |
+
joints_vis[pair[0], :].copy(),
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
return joints * joints_vis, joints_vis
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def transform_preds(coords, center, scale, input_size):
|
| 69 |
+
target_coords = np.zeros(coords.shape)
|
| 70 |
+
trans = get_affine_transform(center, scale, 0, input_size, inv=1)
|
| 71 |
+
for p in range(coords.shape[0]):
|
| 72 |
+
target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
|
| 73 |
+
return target_coords
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def transform_parsing(pred, center, scale, width, height, input_size):
|
| 77 |
+
|
| 78 |
+
trans = get_affine_transform(center, scale, 0, input_size, inv=1)
|
| 79 |
+
target_pred = cv2.warpAffine(
|
| 80 |
+
pred,
|
| 81 |
+
trans,
|
| 82 |
+
(int(width), int(height)), # (int(width), int(height)),
|
| 83 |
+
flags=cv2.INTER_NEAREST,
|
| 84 |
+
borderMode=cv2.BORDER_CONSTANT,
|
| 85 |
+
borderValue=(0),
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
return target_pred
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def transform_logits(logits, center, scale, width, height, input_size):
|
| 92 |
+
|
| 93 |
+
trans = get_affine_transform(center, scale, 0, input_size, inv=1)
|
| 94 |
+
channel = logits.shape[2]
|
| 95 |
+
target_logits = []
|
| 96 |
+
for i in range(channel):
|
| 97 |
+
target_logit = cv2.warpAffine(
|
| 98 |
+
logits[:, :, i],
|
| 99 |
+
trans,
|
| 100 |
+
(int(width), int(height)), # (int(width), int(height)),
|
| 101 |
+
flags=cv2.INTER_LINEAR,
|
| 102 |
+
borderMode=cv2.BORDER_CONSTANT,
|
| 103 |
+
borderValue=(0),
|
| 104 |
+
)
|
| 105 |
+
target_logits.append(target_logit)
|
| 106 |
+
target_logits = np.stack(target_logits, axis=2)
|
| 107 |
+
|
| 108 |
+
return target_logits
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def get_affine_transform(
|
| 112 |
+
center, scale, rot, output_size, shift=np.array([0, 0], dtype=np.float32), inv=0
|
| 113 |
+
):
|
| 114 |
+
if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
|
| 115 |
+
print(scale)
|
| 116 |
+
scale = np.array([scale, scale])
|
| 117 |
+
|
| 118 |
+
scale_tmp = scale
|
| 119 |
+
|
| 120 |
+
src_w = scale_tmp[0]
|
| 121 |
+
dst_w = output_size[1]
|
| 122 |
+
dst_h = output_size[0]
|
| 123 |
+
|
| 124 |
+
rot_rad = np.pi * rot / 180
|
| 125 |
+
src_dir = get_dir([0, src_w * -0.5], rot_rad)
|
| 126 |
+
dst_dir = np.array([0, (dst_w - 1) * -0.5], np.float32)
|
| 127 |
+
|
| 128 |
+
src = np.zeros((3, 2), dtype=np.float32)
|
| 129 |
+
dst = np.zeros((3, 2), dtype=np.float32)
|
| 130 |
+
src[0, :] = center + scale_tmp * shift
|
| 131 |
+
src[1, :] = center + src_dir + scale_tmp * shift
|
| 132 |
+
dst[0, :] = [(dst_w - 1) * 0.5, (dst_h - 1) * 0.5]
|
| 133 |
+
dst[1, :] = np.array([(dst_w - 1) * 0.5, (dst_h - 1) * 0.5]) + dst_dir
|
| 134 |
+
|
| 135 |
+
src[2:, :] = get_3rd_point(src[0, :], src[1, :])
|
| 136 |
+
dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
|
| 137 |
+
|
| 138 |
+
if inv:
|
| 139 |
+
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
|
| 140 |
+
else:
|
| 141 |
+
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
|
| 142 |
+
|
| 143 |
+
return trans
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def affine_transform(pt, t):
|
| 147 |
+
new_pt = np.array([pt[0], pt[1], 1.0]).T
|
| 148 |
+
new_pt = np.dot(t, new_pt)
|
| 149 |
+
return new_pt[:2]
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def get_3rd_point(a, b):
|
| 153 |
+
direct = a - b
|
| 154 |
+
return b + np.array([-direct[1], direct[0]], dtype=np.float32)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def get_dir(src_point, rot_rad):
|
| 158 |
+
sn, cs = np.sin(rot_rad), np.cos(rot_rad)
|
| 159 |
+
|
| 160 |
+
src_result = [0, 0]
|
| 161 |
+
src_result[0] = src_point[0] * cs - src_point[1] * sn
|
| 162 |
+
src_result[1] = src_point[0] * sn + src_point[1] * cs
|
| 163 |
+
|
| 164 |
+
return src_result
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def crop(img, center, scale, output_size, rot=0):
|
| 168 |
+
trans = get_affine_transform(center, scale, rot, output_size)
|
| 169 |
+
|
| 170 |
+
dst_img = cv2.warpAffine(
|
| 171 |
+
img, trans, (int(output_size[1]), int(output_size[0])), flags=cv2.INTER_LINEAR
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
return dst_img
|
Leffa/3rdparty/detectron2/data/transforms/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
from fvcore.transforms.transform import Transform, TransformList # order them first
|
| 3 |
+
from fvcore.transforms.transform import *
|
| 4 |
+
from .transform import *
|
| 5 |
+
from .augmentation import *
|
| 6 |
+
from .augmentation_impl import *
|
| 7 |
+
|
| 8 |
+
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from detectron2.utils.env import fixup_module_metadata
|
| 12 |
+
|
| 13 |
+
fixup_module_metadata(__name__, globals(), __all__)
|
| 14 |
+
del fixup_module_metadata
|
Leffa/3rdparty/detectron2/export/README.md
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
This directory contains code to prepare a detectron2 model for deployment.
|
| 3 |
+
Currently it supports exporting a detectron2 model to TorchScript, ONNX, or (deprecated) Caffe2 format.
|
| 4 |
+
|
| 5 |
+
Please see [documentation](https://detectron2.readthedocs.io/tutorials/deployment.html) for its usage.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
### Acknowledgements
|
| 9 |
+
|
| 10 |
+
Thanks to Mobile Vision team at Facebook for developing the Caffe2 conversion tools.
|
| 11 |
+
|
| 12 |
+
Thanks to Computing Platform Department - PAI team at Alibaba Group (@bddpqq, @chenbohua3) who
|
| 13 |
+
help export Detectron2 models to TorchScript.
|
| 14 |
+
|
| 15 |
+
Thanks to ONNX Converter team at Microsoft who help export Detectron2 models to ONNX.
|
Leffa/3rdparty/detectron2/export/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
import warnings
|
| 4 |
+
|
| 5 |
+
from .flatten import TracingAdapter
|
| 6 |
+
from .torchscript import dump_torchscript_IR, scripting_with_instances
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
from caffe2.proto import caffe2_pb2 as _tmp
|
| 10 |
+
from caffe2.python import core
|
| 11 |
+
|
| 12 |
+
# caffe2 is optional
|
| 13 |
+
except ImportError:
|
| 14 |
+
pass
|
| 15 |
+
else:
|
| 16 |
+
from .api import *
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# TODO: Update ONNX Opset version and run tests when a newer PyTorch is supported
|
| 20 |
+
STABLE_ONNX_OPSET_VERSION = 11
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def add_export_config(cfg):
|
| 24 |
+
warnings.warn(
|
| 25 |
+
"add_export_config has been deprecated and behaves as no-op function.", DeprecationWarning
|
| 26 |
+
)
|
| 27 |
+
return cfg
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
Leffa/3rdparty/detectron2/export/api.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
import copy
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import torch
|
| 6 |
+
from caffe2.proto import caffe2_pb2
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
from detectron2.config import CfgNode
|
| 10 |
+
from detectron2.utils.file_io import PathManager
|
| 11 |
+
|
| 12 |
+
from .caffe2_inference import ProtobufDetectionModel
|
| 13 |
+
from .caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP, convert_batched_inputs_to_c2_format
|
| 14 |
+
from .shared import get_pb_arg_vali, get_pb_arg_vals, save_graph
|
| 15 |
+
|
| 16 |
+
__all__ = [
|
| 17 |
+
"Caffe2Model",
|
| 18 |
+
"Caffe2Tracer",
|
| 19 |
+
]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Caffe2Tracer:
|
| 23 |
+
"""
|
| 24 |
+
Make a detectron2 model traceable with Caffe2 operators.
|
| 25 |
+
This class creates a traceable version of a detectron2 model which:
|
| 26 |
+
|
| 27 |
+
1. Rewrite parts of the model using ops in Caffe2. Note that some ops do
|
| 28 |
+
not have GPU implementation in Caffe2.
|
| 29 |
+
2. Remove post-processing and only produce raw layer outputs
|
| 30 |
+
|
| 31 |
+
After making a traceable model, the class provide methods to export such a
|
| 32 |
+
model to different deployment formats.
|
| 33 |
+
Exported graph produced by this class take two input tensors:
|
| 34 |
+
|
| 35 |
+
1. (1, C, H, W) float "data" which is an image (usually in [0, 255]).
|
| 36 |
+
(H, W) often has to be padded to multiple of 32 (depend on the model
|
| 37 |
+
architecture).
|
| 38 |
+
2. 1x3 float "im_info", each row of which is (height, width, 1.0).
|
| 39 |
+
Height and width are true image shapes before padding.
|
| 40 |
+
|
| 41 |
+
The class currently only supports models using builtin meta architectures.
|
| 42 |
+
Batch inference is not supported, and contributions are welcome.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
def __init__(self, cfg: CfgNode, model: nn.Module, inputs):
|
| 46 |
+
"""
|
| 47 |
+
Args:
|
| 48 |
+
cfg (CfgNode): a detectron2 config used to construct caffe2-compatible model.
|
| 49 |
+
model (nn.Module): An original pytorch model. Must be among a few official models
|
| 50 |
+
in detectron2 that can be converted to become caffe2-compatible automatically.
|
| 51 |
+
Weights have to be already loaded to this model.
|
| 52 |
+
inputs: sample inputs that the given model takes for inference.
|
| 53 |
+
Will be used to trace the model. For most models, random inputs with
|
| 54 |
+
no detected objects will not work as they lead to wrong traces.
|
| 55 |
+
"""
|
| 56 |
+
assert isinstance(cfg, CfgNode), cfg
|
| 57 |
+
assert isinstance(model, torch.nn.Module), type(model)
|
| 58 |
+
|
| 59 |
+
# TODO make it support custom models, by passing in c2 model directly
|
| 60 |
+
C2MetaArch = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[cfg.MODEL.META_ARCHITECTURE]
|
| 61 |
+
self.traceable_model = C2MetaArch(cfg, copy.deepcopy(model))
|
| 62 |
+
self.inputs = inputs
|
| 63 |
+
self.traceable_inputs = self.traceable_model.get_caffe2_inputs(inputs)
|
| 64 |
+
|
| 65 |
+
def export_caffe2(self):
|
| 66 |
+
"""
|
| 67 |
+
Export the model to Caffe2's protobuf format.
|
| 68 |
+
The returned object can be saved with its :meth:`.save_protobuf()` method.
|
| 69 |
+
The result can be loaded and executed using Caffe2 runtime.
|
| 70 |
+
|
| 71 |
+
Returns:
|
| 72 |
+
:class:`Caffe2Model`
|
| 73 |
+
"""
|
| 74 |
+
from .caffe2_export import export_caffe2_detection_model
|
| 75 |
+
|
| 76 |
+
predict_net, init_net = export_caffe2_detection_model(
|
| 77 |
+
self.traceable_model, self.traceable_inputs
|
| 78 |
+
)
|
| 79 |
+
return Caffe2Model(predict_net, init_net)
|
| 80 |
+
|
| 81 |
+
def export_onnx(self):
|
| 82 |
+
"""
|
| 83 |
+
Export the model to ONNX format.
|
| 84 |
+
Note that the exported model contains custom ops only available in caffe2, therefore it
|
| 85 |
+
cannot be directly executed by other runtime (such as onnxruntime or TensorRT).
|
| 86 |
+
Post-processing or transformation passes may be applied on the model to accommodate
|
| 87 |
+
different runtimes, but we currently do not provide support for them.
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
onnx.ModelProto: an onnx model.
|
| 91 |
+
"""
|
| 92 |
+
from .caffe2_export import export_onnx_model as export_onnx_model_impl
|
| 93 |
+
|
| 94 |
+
return export_onnx_model_impl(self.traceable_model, (self.traceable_inputs,))
|
| 95 |
+
|
| 96 |
+
def export_torchscript(self):
|
| 97 |
+
"""
|
| 98 |
+
Export the model to a ``torch.jit.TracedModule`` by tracing.
|
| 99 |
+
The returned object can be saved to a file by ``.save()``.
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
torch.jit.TracedModule: a torch TracedModule
|
| 103 |
+
"""
|
| 104 |
+
logger = logging.getLogger(__name__)
|
| 105 |
+
logger.info("Tracing the model with torch.jit.trace ...")
|
| 106 |
+
with torch.no_grad():
|
| 107 |
+
return torch.jit.trace(self.traceable_model, (self.traceable_inputs,))
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class Caffe2Model(nn.Module):
|
| 111 |
+
"""
|
| 112 |
+
A wrapper around the traced model in Caffe2's protobuf format.
|
| 113 |
+
The exported graph has different inputs/outputs from the original Pytorch
|
| 114 |
+
model, as explained in :class:`Caffe2Tracer`. This class wraps around the
|
| 115 |
+
exported graph to simulate the same interface as the original Pytorch model.
|
| 116 |
+
It also provides functions to save/load models in Caffe2's format.'
|
| 117 |
+
|
| 118 |
+
Examples:
|
| 119 |
+
::
|
| 120 |
+
c2_model = Caffe2Tracer(cfg, torch_model, inputs).export_caffe2()
|
| 121 |
+
inputs = [{"image": img_tensor_CHW}]
|
| 122 |
+
outputs = c2_model(inputs)
|
| 123 |
+
orig_outputs = torch_model(inputs)
|
| 124 |
+
"""
|
| 125 |
+
|
| 126 |
+
def __init__(self, predict_net, init_net):
|
| 127 |
+
super().__init__()
|
| 128 |
+
self.eval() # always in eval mode
|
| 129 |
+
self._predict_net = predict_net
|
| 130 |
+
self._init_net = init_net
|
| 131 |
+
self._predictor = None
|
| 132 |
+
|
| 133 |
+
__init__.__HIDE_SPHINX_DOC__ = True
|
| 134 |
+
|
| 135 |
+
@property
|
| 136 |
+
def predict_net(self):
|
| 137 |
+
"""
|
| 138 |
+
caffe2.core.Net: the underlying caffe2 predict net
|
| 139 |
+
"""
|
| 140 |
+
return self._predict_net
|
| 141 |
+
|
| 142 |
+
@property
|
| 143 |
+
def init_net(self):
|
| 144 |
+
"""
|
| 145 |
+
caffe2.core.Net: the underlying caffe2 init net
|
| 146 |
+
"""
|
| 147 |
+
return self._init_net
|
| 148 |
+
|
| 149 |
+
def save_protobuf(self, output_dir):
|
| 150 |
+
"""
|
| 151 |
+
Save the model as caffe2's protobuf format.
|
| 152 |
+
It saves the following files:
|
| 153 |
+
|
| 154 |
+
* "model.pb": definition of the graph. Can be visualized with
|
| 155 |
+
tools like `netron <https://github.com/lutzroeder/netron>`_.
|
| 156 |
+
* "model_init.pb": model parameters
|
| 157 |
+
* "model.pbtxt": human-readable definition of the graph. Not
|
| 158 |
+
needed for deployment.
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
output_dir (str): the output directory to save protobuf files.
|
| 162 |
+
"""
|
| 163 |
+
logger = logging.getLogger(__name__)
|
| 164 |
+
logger.info("Saving model to {} ...".format(output_dir))
|
| 165 |
+
if not PathManager.exists(output_dir):
|
| 166 |
+
PathManager.mkdirs(output_dir)
|
| 167 |
+
|
| 168 |
+
with PathManager.open(os.path.join(output_dir, "model.pb"), "wb") as f:
|
| 169 |
+
f.write(self._predict_net.SerializeToString())
|
| 170 |
+
with PathManager.open(os.path.join(output_dir, "model.pbtxt"), "w") as f:
|
| 171 |
+
f.write(str(self._predict_net))
|
| 172 |
+
with PathManager.open(os.path.join(output_dir, "model_init.pb"), "wb") as f:
|
| 173 |
+
f.write(self._init_net.SerializeToString())
|
| 174 |
+
|
| 175 |
+
def save_graph(self, output_file, inputs=None):
|
| 176 |
+
"""
|
| 177 |
+
Save the graph as SVG format.
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
output_file (str): a SVG file
|
| 181 |
+
inputs: optional inputs given to the model.
|
| 182 |
+
If given, the inputs will be used to run the graph to record
|
| 183 |
+
shape of every tensor. The shape information will be
|
| 184 |
+
saved together with the graph.
|
| 185 |
+
"""
|
| 186 |
+
from .caffe2_export import run_and_save_graph
|
| 187 |
+
|
| 188 |
+
if inputs is None:
|
| 189 |
+
save_graph(self._predict_net, output_file, op_only=False)
|
| 190 |
+
else:
|
| 191 |
+
size_divisibility = get_pb_arg_vali(self._predict_net, "size_divisibility", 0)
|
| 192 |
+
device = get_pb_arg_vals(self._predict_net, "device", b"cpu").decode("ascii")
|
| 193 |
+
inputs = convert_batched_inputs_to_c2_format(inputs, size_divisibility, device)
|
| 194 |
+
inputs = [x.cpu().numpy() for x in inputs]
|
| 195 |
+
run_and_save_graph(self._predict_net, self._init_net, inputs, output_file)
|
| 196 |
+
|
| 197 |
+
@staticmethod
|
| 198 |
+
def load_protobuf(dir):
|
| 199 |
+
"""
|
| 200 |
+
Args:
|
| 201 |
+
dir (str): a directory used to save Caffe2Model with
|
| 202 |
+
:meth:`save_protobuf`.
|
| 203 |
+
The files "model.pb" and "model_init.pb" are needed.
|
| 204 |
+
|
| 205 |
+
Returns:
|
| 206 |
+
Caffe2Model: the caffe2 model loaded from this directory.
|
| 207 |
+
"""
|
| 208 |
+
predict_net = caffe2_pb2.NetDef()
|
| 209 |
+
with PathManager.open(os.path.join(dir, "model.pb"), "rb") as f:
|
| 210 |
+
predict_net.ParseFromString(f.read())
|
| 211 |
+
|
| 212 |
+
init_net = caffe2_pb2.NetDef()
|
| 213 |
+
with PathManager.open(os.path.join(dir, "model_init.pb"), "rb") as f:
|
| 214 |
+
init_net.ParseFromString(f.read())
|
| 215 |
+
|
| 216 |
+
return Caffe2Model(predict_net, init_net)
|
| 217 |
+
|
| 218 |
+
def __call__(self, inputs):
|
| 219 |
+
"""
|
| 220 |
+
An interface that wraps around a Caffe2 model and mimics detectron2's models'
|
| 221 |
+
input/output format. See details about the format at :doc:`/tutorials/models`.
|
| 222 |
+
This is used to compare the outputs of caffe2 model with its original torch model.
|
| 223 |
+
|
| 224 |
+
Due to the extra conversion between Pytorch/Caffe2, this method is not meant for
|
| 225 |
+
benchmark. Because of the conversion, this method also has dependency
|
| 226 |
+
on detectron2 in order to convert to detectron2's output format.
|
| 227 |
+
"""
|
| 228 |
+
if self._predictor is None:
|
| 229 |
+
self._predictor = ProtobufDetectionModel(self._predict_net, self._init_net)
|
| 230 |
+
return self._predictor(inputs)
|
Leffa/3rdparty/detectron2/export/c10.py
ADDED
|
@@ -0,0 +1,571 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from typing import Dict
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
from detectron2.layers import ShapeSpec, cat
|
| 9 |
+
from detectron2.layers.roi_align_rotated import ROIAlignRotated
|
| 10 |
+
from detectron2.modeling import poolers
|
| 11 |
+
from detectron2.modeling.proposal_generator import rpn
|
| 12 |
+
from detectron2.modeling.roi_heads.mask_head import mask_rcnn_inference
|
| 13 |
+
from detectron2.structures import Boxes, ImageList, Instances, Keypoints, RotatedBoxes
|
| 14 |
+
|
| 15 |
+
from .shared import alias, to_device
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
"""
|
| 19 |
+
This file contains caffe2-compatible implementation of several detectron2 components.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class Caffe2Boxes(Boxes):
|
| 24 |
+
"""
|
| 25 |
+
Representing a list of detectron2.structures.Boxes from minibatch, each box
|
| 26 |
+
is represented by a 5d vector (batch index + 4 coordinates), or a 6d vector
|
| 27 |
+
(batch index + 5 coordinates) for RotatedBoxes.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(self, tensor):
|
| 31 |
+
assert isinstance(tensor, torch.Tensor)
|
| 32 |
+
assert tensor.dim() == 2 and tensor.size(-1) in [4, 5, 6], tensor.size()
|
| 33 |
+
# TODO: make tensor immutable when dim is Nx5 for Boxes,
|
| 34 |
+
# and Nx6 for RotatedBoxes?
|
| 35 |
+
self.tensor = tensor
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# TODO clean up this class, maybe just extend Instances
|
| 39 |
+
class InstancesList:
|
| 40 |
+
"""
|
| 41 |
+
Tensor representation of a list of Instances object for a batch of images.
|
| 42 |
+
|
| 43 |
+
When dealing with a batch of images with Caffe2 ops, a list of bboxes
|
| 44 |
+
(instances) are usually represented by single Tensor with size
|
| 45 |
+
(sigma(Ni), 5) or (sigma(Ni), 4) plus a batch split Tensor. This class is
|
| 46 |
+
for providing common functions to convert between these two representations.
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
def __init__(self, im_info, indices, extra_fields=None):
|
| 50 |
+
# [N, 3] -> (H, W, Scale)
|
| 51 |
+
self.im_info = im_info
|
| 52 |
+
# [N,] -> indice of batch to which the instance belongs
|
| 53 |
+
self.indices = indices
|
| 54 |
+
# [N, ...]
|
| 55 |
+
self.batch_extra_fields = extra_fields or {}
|
| 56 |
+
|
| 57 |
+
self.image_size = self.im_info
|
| 58 |
+
|
| 59 |
+
def get_fields(self):
|
| 60 |
+
"""like `get_fields` in the Instances object,
|
| 61 |
+
but return each field in tensor representations"""
|
| 62 |
+
ret = {}
|
| 63 |
+
for k, v in self.batch_extra_fields.items():
|
| 64 |
+
# if isinstance(v, torch.Tensor):
|
| 65 |
+
# tensor_rep = v
|
| 66 |
+
# elif isinstance(v, (Boxes, Keypoints)):
|
| 67 |
+
# tensor_rep = v.tensor
|
| 68 |
+
# else:
|
| 69 |
+
# raise ValueError("Can't find tensor representation for: {}".format())
|
| 70 |
+
ret[k] = v
|
| 71 |
+
return ret
|
| 72 |
+
|
| 73 |
+
def has(self, name):
|
| 74 |
+
return name in self.batch_extra_fields
|
| 75 |
+
|
| 76 |
+
def set(self, name, value):
|
| 77 |
+
# len(tensor) is a bad practice that generates ONNX constants during tracing.
|
| 78 |
+
# Although not a problem for the `assert` statement below, torch ONNX exporter
|
| 79 |
+
# still raises a misleading warning as it does not this call comes from `assert`
|
| 80 |
+
if isinstance(value, Boxes):
|
| 81 |
+
data_len = value.tensor.shape[0]
|
| 82 |
+
elif isinstance(value, torch.Tensor):
|
| 83 |
+
data_len = value.shape[0]
|
| 84 |
+
else:
|
| 85 |
+
data_len = len(value)
|
| 86 |
+
if len(self.batch_extra_fields):
|
| 87 |
+
assert (
|
| 88 |
+
len(self) == data_len
|
| 89 |
+
), "Adding a field of length {} to a Instances of length {}".format(data_len, len(self))
|
| 90 |
+
self.batch_extra_fields[name] = value
|
| 91 |
+
|
| 92 |
+
def __getattr__(self, name):
|
| 93 |
+
if name not in self.batch_extra_fields:
|
| 94 |
+
raise AttributeError("Cannot find field '{}' in the given Instances!".format(name))
|
| 95 |
+
return self.batch_extra_fields[name]
|
| 96 |
+
|
| 97 |
+
def __len__(self):
|
| 98 |
+
return len(self.indices)
|
| 99 |
+
|
| 100 |
+
def flatten(self):
|
| 101 |
+
ret = []
|
| 102 |
+
for _, v in self.batch_extra_fields.items():
|
| 103 |
+
if isinstance(v, (Boxes, Keypoints)):
|
| 104 |
+
ret.append(v.tensor)
|
| 105 |
+
else:
|
| 106 |
+
ret.append(v)
|
| 107 |
+
return ret
|
| 108 |
+
|
| 109 |
+
@staticmethod
|
| 110 |
+
def to_d2_instances_list(instances_list):
|
| 111 |
+
"""
|
| 112 |
+
Convert InstancesList to List[Instances]. The input `instances_list` can
|
| 113 |
+
also be a List[Instances], in this case this method is a non-op.
|
| 114 |
+
"""
|
| 115 |
+
if not isinstance(instances_list, InstancesList):
|
| 116 |
+
assert all(isinstance(x, Instances) for x in instances_list)
|
| 117 |
+
return instances_list
|
| 118 |
+
|
| 119 |
+
ret = []
|
| 120 |
+
for i, info in enumerate(instances_list.im_info):
|
| 121 |
+
instances = Instances(torch.Size([int(info[0].item()), int(info[1].item())]))
|
| 122 |
+
|
| 123 |
+
ids = instances_list.indices == i
|
| 124 |
+
for k, v in instances_list.batch_extra_fields.items():
|
| 125 |
+
if isinstance(v, torch.Tensor):
|
| 126 |
+
instances.set(k, v[ids])
|
| 127 |
+
continue
|
| 128 |
+
elif isinstance(v, Boxes):
|
| 129 |
+
instances.set(k, v[ids, -4:])
|
| 130 |
+
continue
|
| 131 |
+
|
| 132 |
+
target_type, tensor_source = v
|
| 133 |
+
assert isinstance(tensor_source, torch.Tensor)
|
| 134 |
+
assert tensor_source.shape[0] == instances_list.indices.shape[0]
|
| 135 |
+
tensor_source = tensor_source[ids]
|
| 136 |
+
|
| 137 |
+
if issubclass(target_type, Boxes):
|
| 138 |
+
instances.set(k, Boxes(tensor_source[:, -4:]))
|
| 139 |
+
elif issubclass(target_type, Keypoints):
|
| 140 |
+
instances.set(k, Keypoints(tensor_source))
|
| 141 |
+
elif issubclass(target_type, torch.Tensor):
|
| 142 |
+
instances.set(k, tensor_source)
|
| 143 |
+
else:
|
| 144 |
+
raise ValueError("Can't handle targe type: {}".format(target_type))
|
| 145 |
+
|
| 146 |
+
ret.append(instances)
|
| 147 |
+
return ret
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class Caffe2Compatible:
|
| 151 |
+
"""
|
| 152 |
+
A model can inherit this class to indicate that it can be traced and deployed with caffe2.
|
| 153 |
+
"""
|
| 154 |
+
|
| 155 |
+
def _get_tensor_mode(self):
|
| 156 |
+
return self._tensor_mode
|
| 157 |
+
|
| 158 |
+
def _set_tensor_mode(self, v):
|
| 159 |
+
self._tensor_mode = v
|
| 160 |
+
|
| 161 |
+
tensor_mode = property(_get_tensor_mode, _set_tensor_mode)
|
| 162 |
+
"""
|
| 163 |
+
If true, the model expects C2-style tensor only inputs/outputs format.
|
| 164 |
+
"""
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class Caffe2RPN(Caffe2Compatible, rpn.RPN):
    """
    Caffe2-traceable replacement for detectron2's RPN: proposal generation is
    delegated to the ``torch.ops._caffe2.GenerateProposals`` /
    ``CollectRpnProposals`` operators so the traced graph maps onto caffe2 ops.
    """

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        # Build via the regular RPN config path (skip Caffe2Compatible in MRO).
        ret = super(Caffe2Compatible, cls).from_config(cfg, input_shape)
        # The caffe2 GenerateProposals op assumes unit box-regression weights;
        # 4 weights for axis-aligned boxes, 5 for rotated boxes.
        assert tuple(cfg.MODEL.RPN.BBOX_REG_WEIGHTS) == (1.0, 1.0, 1.0, 1.0) or tuple(
            cfg.MODEL.RPN.BBOX_REG_WEIGHTS
        ) == (1.0, 1.0, 1.0, 1.0, 1.0)
        return ret

    def _generate_proposals(
        self, images, objectness_logits_pred, anchor_deltas_pred, gt_instances=None
    ):
        """
        Run GenerateProposals per feature level and merge the results.

        Returns:
            (proposals, {}): proposals in the format produced by
            ``c2_postprocess``; the empty dict mirrors RPN's loss-dict slot
            (this path is inference-only).
        """
        assert isinstance(images, ImageList)
        if self.tensor_mode:
            # In tensor mode the caller already supplies a C2-style im_info tensor.
            im_info = images.image_sizes
        else:
            # Build an (N, 3) [height, width, scale] tensor expected by caffe2 ops.
            im_info = torch.tensor([[im_sz[0], im_sz[1], 1.0] for im_sz in images.image_sizes]).to(
                images.tensor.device
            )
        assert isinstance(im_info, torch.Tensor)

        rpn_rois_list = []
        rpn_roi_probs_list = []
        # One GenerateProposals call per FPN level.
        for scores, bbox_deltas, cell_anchors_tensor, feat_stride in zip(
            objectness_logits_pred,
            anchor_deltas_pred,
            [b for (n, b) in self.anchor_generator.cell_anchors.named_buffers()],
            self.anchor_generator.strides,
        ):
            # Proposals are not differentiated through; detach to keep the trace clean.
            scores = scores.detach()
            bbox_deltas = bbox_deltas.detach()

            rpn_rois, rpn_roi_probs = torch.ops._caffe2.GenerateProposals(
                scores,
                bbox_deltas,
                im_info,
                cell_anchors_tensor,
                spatial_scale=1.0 / feat_stride,
                pre_nms_topN=self.pre_nms_topk[self.training],
                post_nms_topN=self.post_nms_topk[self.training],
                nms_thresh=self.nms_thresh,
                min_size=self.min_box_size,
                # correct_transform_coords=True,  # deprecated argument
                angle_bound_on=True,  # Default
                angle_bound_lo=-180,
                angle_bound_hi=180,
                clip_angle_thresh=1.0,  # Default
                legacy_plus_one=False,
            )
            rpn_rois_list.append(rpn_rois)
            rpn_roi_probs_list.append(rpn_roi_probs)

        # For FPN in D2, in RPN all proposals from different levels are concated
        # together, ranked and picked by top post_nms_topk. Then in ROIPooler
        # it calculates level_assignments and calls the RoIAlign from
        # the corresponding level.

        if len(objectness_logits_pred) == 1:
            # Single level (non-FPN): no cross-level merge needed.
            rpn_rois = rpn_rois_list[0]
            rpn_roi_probs = rpn_roi_probs_list[0]
        else:
            assert len(rpn_rois_list) == len(rpn_roi_probs_list)
            rpn_post_nms_topN = self.post_nms_topk[self.training]

            # CollectRpnProposals is a CPU-only op; round-trip through CPU.
            device = rpn_rois_list[0].device
            input_list = [to_device(x, "cpu") for x in (rpn_rois_list + rpn_roi_probs_list)]

            # TODO remove this after confirming rpn_max_level/rpn_min_level
            # is not needed in CollectRpnProposals.
            feature_strides = list(self.anchor_generator.strides)
            rpn_min_level = int(math.log2(feature_strides[0]))
            rpn_max_level = int(math.log2(feature_strides[-1]))
            assert (rpn_max_level - rpn_min_level + 1) == len(
                rpn_rois_list
            ), "CollectRpnProposals requires continuous levels"

            rpn_rois = torch.ops._caffe2.CollectRpnProposals(
                input_list,
                # NOTE: in current implementation, rpn_max_level and rpn_min_level
                # are not needed, only the subtraction of two matters and it
                # can be infer from the number of inputs. Keep them now for
                # consistency.
                rpn_max_level=2 + len(rpn_rois_list) - 1,
                rpn_min_level=2,
                rpn_post_nms_topN=rpn_post_nms_topN,
            )
            rpn_rois = to_device(rpn_rois, device)
            # Scores are consumed by the collect op; none survive the merge.
            rpn_roi_probs = []

        proposals = self.c2_postprocess(im_info, rpn_rois, rpn_roi_probs, self.tensor_mode)
        return proposals, {}

    def forward(self, images, features, gt_instances=None):
        # Inference-only: training would need the original RPN loss path.
        assert not self.training
        features = [features[f] for f in self.in_features]
        objectness_logits_pred, anchor_deltas_pred = self.rpn_head(features)
        return self._generate_proposals(
            images,
            objectness_logits_pred,
            anchor_deltas_pred,
            gt_instances,
        )

    @staticmethod
    def c2_postprocess(im_info, rpn_rois, rpn_roi_probs, tensor_mode):
        """
        Wrap caffe2-format rois into InstancesList; convert to d2 Instances
        unless running in tensor mode.
        """
        proposals = InstancesList(
            im_info=im_info,
            # Column 0 of caffe2 rois is the batch index.
            indices=rpn_rois[:, 0],
            extra_fields={
                "proposal_boxes": Caffe2Boxes(rpn_rois),
                "objectness_logits": (torch.Tensor, rpn_roi_probs),
            },
        )
        if not tensor_mode:
            proposals = InstancesList.to_d2_instances_list(proposals)
        else:
            proposals = [proposals]
        return proposals
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class Caffe2ROIPooler(Caffe2Compatible, poolers.ROIPooler):
    """
    Caffe2-traceable ROIPooler: RoIAlign / RoIAlignRotated and the FPN level
    distribution are expressed with ``torch.ops._caffe2`` operators.
    """

    @staticmethod
    def c2_preprocess(box_lists):
        """Convert input boxes to the caffe2 pooler format (batch_idx + coords)."""
        assert all(isinstance(x, Boxes) for x in box_lists)
        if all(isinstance(x, Caffe2Boxes) for x in box_lists):
            # input is pure-tensor based
            assert len(box_lists) == 1
            pooler_fmt_boxes = box_lists[0].tensor
        else:
            pooler_fmt_boxes = poolers.convert_boxes_to_pooler_format(box_lists)
        return pooler_fmt_boxes

    def forward(self, x, box_lists):
        # Inference-only path.
        assert not self.training

        pooler_fmt_boxes = self.c2_preprocess(box_lists)
        num_level_assignments = len(self.level_poolers)

        if num_level_assignments == 1:
            # Single-level case: one RoIAlign(-Rotated) call, no FPN distribution.
            if isinstance(self.level_poolers[0], ROIAlignRotated):
                c2_roi_align = torch.ops._caffe2.RoIAlignRotated
                aligned = True
            else:
                c2_roi_align = torch.ops._caffe2.RoIAlign
                aligned = self.level_poolers[0].aligned

            x0 = x[0]
            # The caffe2 op takes float tensors; dequantize if needed.
            if x0.is_quantized:
                x0 = x0.dequantize()

            out = c2_roi_align(
                x0,
                pooler_fmt_boxes,
                order="NCHW",
                spatial_scale=float(self.level_poolers[0].spatial_scale),
                pooled_h=int(self.output_size[0]),
                pooled_w=int(self.output_size[1]),
                sampling_ratio=int(self.level_poolers[0].sampling_ratio),
                aligned=aligned,
            )
            return out

        device = pooler_fmt_boxes.device
        assert (
            self.max_level - self.min_level + 1 == 4
        ), "Currently DistributeFpnProposals only support 4 levels"
        # DistributeFpnProposals is CPU-only: assign each roi to an FPN level.
        fpn_outputs = torch.ops._caffe2.DistributeFpnProposals(
            to_device(pooler_fmt_boxes, "cpu"),
            roi_canonical_scale=self.canonical_box_size,
            roi_canonical_level=self.canonical_level,
            roi_max_level=self.max_level,
            roi_min_level=self.min_level,
            legacy_plus_one=False,
        )
        fpn_outputs = [to_device(x, device) for x in fpn_outputs]

        # Last output is the permutation that restores the original roi order.
        rois_fpn_list = fpn_outputs[:-1]
        rois_idx_restore_int32 = fpn_outputs[-1]

        roi_feat_fpn_list = []
        # Pool each level's rois against that level's feature map.
        for roi_fpn, x_level, pooler in zip(rois_fpn_list, x, self.level_poolers):
            if isinstance(pooler, ROIAlignRotated):
                c2_roi_align = torch.ops._caffe2.RoIAlignRotated
                aligned = True
            else:
                c2_roi_align = torch.ops._caffe2.RoIAlign
                aligned = bool(pooler.aligned)

            if x_level.is_quantized:
                x_level = x_level.dequantize()

            roi_feat_fpn = c2_roi_align(
                x_level,
                roi_fpn,
                order="NCHW",
                spatial_scale=float(pooler.spatial_scale),
                pooled_h=int(self.output_size[0]),
                pooled_w=int(self.output_size[1]),
                sampling_ratio=int(pooler.sampling_ratio),
                aligned=aligned,
            )
            roi_feat_fpn_list.append(roi_feat_fpn)

        roi_feat_shuffled = cat(roi_feat_fpn_list, dim=0)
        assert roi_feat_shuffled.numel() > 0 and rois_idx_restore_int32.numel() > 0, (
            "Caffe2 export requires tracing with a model checkpoint + input that can produce valid"
            " detections. But no detections were obtained with the given checkpoint and input!"
        )
        # Undo the per-level shuffle so output rows match the input roi order.
        roi_feat = torch.ops._caffe2.BatchPermutation(roi_feat_shuffled, rois_idx_restore_int32)
        return roi_feat
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def caffe2_fast_rcnn_outputs_inference(tensor_mode, box_predictor, predictions, proposals):
    """equivalent to FastRCNNOutputLayers.inference

    Decodes box regression and applies NMS using caffe2 operators
    (BBoxTransform + BoxWithNMSLimit) so the graph is exportable.

    Args:
        tensor_mode: if True, return raw InstancesList/tensor outputs instead
            of d2 Instances.
        box_predictor: FastRCNNOutputLayers providing thresholds and the
            box2box transform weights.
        predictions: (class_logits, box_regression) pair.
        proposals: list of per-image proposals (d2 Instances or InstancesList).

    Returns:
        (results, kept_indices): detections and the indices kept by NMS.
    """
    num_classes = box_predictor.num_classes
    score_thresh = box_predictor.test_score_thresh
    nms_thresh = box_predictor.test_nms_thresh
    topk_per_image = box_predictor.test_topk_per_image
    # 5 transform weights <=> rotated boxes (cx, cy, w, h, angle).
    is_rotated = len(box_predictor.box2box_transform.weights) == 5

    if is_rotated:
        box_dim = 5
        assert box_predictor.box2box_transform.weights[4] == 1, (
            "The weights for Rotated BBoxTransform in C2 have only 4 dimensions,"
            + " thus enforcing the angle weight to be 1 for now"
        )
        box2box_transform_weights = box_predictor.box2box_transform.weights[:4]
    else:
        box_dim = 4
        box2box_transform_weights = box_predictor.box2box_transform.weights

    class_logits, box_regression = predictions
    if num_classes + 1 == class_logits.shape[1]:
        # Logits include a background column -> softmax over classes+bg.
        class_prob = F.softmax(class_logits, -1)
    else:
        assert num_classes == class_logits.shape[1]
        class_prob = F.sigmoid(class_logits)
        # BoxWithNMSLimit will infer num_classes from the shape of the class_prob
        # So append a zero column as placeholder for the background class
        # NOTE(review): torch.zeros has no device= argument here, so this cat
        # would mix CPU/CUDA tensors when class_prob is on GPU — confirm.
        class_prob = torch.cat((class_prob, torch.zeros(class_prob.shape[0], 1)), dim=1)

    assert box_regression.shape[1] % box_dim == 0
    # One box per roi (class-agnostic) vs one box per class.
    cls_agnostic_bbox_reg = box_regression.shape[1] // box_dim == 1

    # Caffe2-format rois carry an extra leading batch-index column.
    input_tensor_mode = proposals[0].proposal_boxes.tensor.shape[1] == box_dim + 1

    proposal_boxes = proposals[0].proposal_boxes
    if isinstance(proposal_boxes, Caffe2Boxes):
        rois = Caffe2Boxes.cat([p.proposal_boxes for p in proposals])
    elif isinstance(proposal_boxes, RotatedBoxes):
        rois = RotatedBoxes.cat([p.proposal_boxes for p in proposals])
    elif isinstance(proposal_boxes, Boxes):
        rois = Boxes.cat([p.proposal_boxes for p in proposals])
    else:
        raise NotImplementedError(
            'Expected proposals[0].proposal_boxes to be type "Boxes", '
            f"instead got {type(proposal_boxes)}"
        )

    device, dtype = rois.tensor.device, rois.tensor.dtype
    if input_tensor_mode:
        # Boxes already carry batch ids; im_info is passed through unchanged.
        im_info = proposals[0].image_size
        rois = rois.tensor
    else:
        # Build (N, 3) [h, w, scale] im_info and prepend per-image batch ids.
        im_info = torch.tensor([[sz[0], sz[1], 1.0] for sz in [x.image_size for x in proposals]])
        batch_ids = cat(
            [
                torch.full((b, 1), i, dtype=dtype, device=device)
                for i, b in enumerate(len(p) for p in proposals)
            ],
            dim=0,
        )
        rois = torch.cat([batch_ids, rois.tensor], dim=1)

    # BBoxTransform is a CPU op: decode deltas into absolute boxes.
    roi_pred_bbox, roi_batch_splits = torch.ops._caffe2.BBoxTransform(
        to_device(rois, "cpu"),
        to_device(box_regression, "cpu"),
        to_device(im_info, "cpu"),
        weights=box2box_transform_weights,
        apply_scale=True,
        rotated=is_rotated,
        angle_bound_on=True,
        angle_bound_lo=-180,
        angle_bound_hi=180,
        clip_angle_thresh=1.0,
        legacy_plus_one=False,
    )
    roi_pred_bbox = to_device(roi_pred_bbox, device)
    roi_batch_splits = to_device(roi_batch_splits, device)

    # Score-threshold + per-class NMS + top-k, all in one caffe2 op.
    nms_outputs = torch.ops._caffe2.BoxWithNMSLimit(
        to_device(class_prob, "cpu"),
        to_device(roi_pred_bbox, "cpu"),
        to_device(roi_batch_splits, "cpu"),
        score_thresh=float(score_thresh),
        nms=float(nms_thresh),
        detections_per_im=int(topk_per_image),
        soft_nms_enabled=False,
        soft_nms_method="linear",
        soft_nms_sigma=0.5,
        soft_nms_min_score_thres=0.001,
        rotated=is_rotated,
        cls_agnostic_bbox_reg=cls_agnostic_bbox_reg,
        input_boxes_include_bg_cls=False,
        output_classes_include_bg_cls=False,
        legacy_plus_one=False,
    )
    roi_score_nms = to_device(nms_outputs[0], device)
    roi_bbox_nms = to_device(nms_outputs[1], device)
    roi_class_nms = to_device(nms_outputs[2], device)
    roi_batch_splits_nms = to_device(nms_outputs[3], device)
    roi_keeps_nms = to_device(nms_outputs[4], device)
    roi_keeps_size_nms = to_device(nms_outputs[5], device)
    if not tensor_mode:
        # d2 Instances expect integer class ids.
        roi_class_nms = roi_class_nms.to(torch.int64)

    # Reconstruct per-row batch ids from the post-NMS split sizes.
    roi_batch_ids = cat(
        [
            torch.full((b, 1), i, dtype=dtype, device=device)
            for i, b in enumerate(int(x.item()) for x in roi_batch_splits_nms)
        ],
        dim=0,
    )

    # Alias outputs so the exported net exposes stable blob names.
    roi_class_nms = alias(roi_class_nms, "class_nms")
    roi_score_nms = alias(roi_score_nms, "score_nms")
    roi_bbox_nms = alias(roi_bbox_nms, "bbox_nms")
    roi_batch_splits_nms = alias(roi_batch_splits_nms, "batch_splits_nms")
    roi_keeps_nms = alias(roi_keeps_nms, "keeps_nms")
    roi_keeps_size_nms = alias(roi_keeps_size_nms, "keeps_size_nms")

    results = InstancesList(
        im_info=im_info,
        indices=roi_batch_ids[:, 0],
        extra_fields={
            "pred_boxes": Caffe2Boxes(roi_bbox_nms),
            "scores": roi_score_nms,
            "pred_classes": roi_class_nms,
        },
    )

    if not tensor_mode:
        # Split the flat outputs back into one Instances per image.
        results = InstancesList.to_d2_instances_list(results)
        batch_splits = roi_batch_splits_nms.int().tolist()
        kept_indices = list(roi_keeps_nms.to(torch.int64).split(batch_splits))
    else:
        results = [results]
        kept_indices = [roi_keeps_nms]

    return results, kept_indices
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
class Caffe2FastRCNNOutputsInference:
    """
    Callable wrapper that patches in place of ``FastRCNNOutputLayers.inference``,
    binding the caffe2 tensor-mode flag at construction time.
    """

    def __init__(self, tensor_mode):
        # whether the output is caffe2 tensor mode
        self.tensor_mode = tensor_mode

    def __call__(self, box_predictor, predictions, proposals):
        # Forward to the free function with the bound tensor-mode flag.
        return caffe2_fast_rcnn_outputs_inference(
            self.tensor_mode, box_predictor, predictions, proposals
        )
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
def caffe2_mask_rcnn_inference(pred_mask_logits, pred_instances):
    """equivalent to mask_head.mask_rcnn_inference

    In caffe2 tensor mode (all instances are InstancesList), attach the
    sigmoid mask probabilities under the aliased blob name; otherwise defer
    to the stock detectron2 implementation.
    """
    in_tensor_mode = all(isinstance(inst, InstancesList) for inst in pred_instances)
    if not in_tensor_mode:
        # Regular d2 Instances path.
        mask_rcnn_inference(pred_mask_logits, pred_instances)
        return
    assert len(pred_instances) == 1
    probs = alias(pred_mask_logits.sigmoid(), "mask_fcn_probs")
    pred_instances[0].set("pred_masks", probs)
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
class Caffe2MaskRCNNInference:
    """Callable wrapper that patches in place of ``mask_rcnn_inference``."""

    def __call__(self, pred_mask_logits, pred_instances):
        # Delegate to the caffe2-compatible free function.
        return caffe2_mask_rcnn_inference(pred_mask_logits, pred_instances)
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
def caffe2_keypoint_rcnn_inference(use_heatmap_max_keypoint, pred_keypoint_logits, pred_instances):
    """
    Caffe2-mode keypoint inference.

    In tensor mode, attaches either the raw keypoint heatmap (aliased
    "kps_score") or, when *use_heatmap_max_keypoint* is set, the decoded
    keypoints from the HeatmapMaxKeypoint op (aliased "keypoints_out") to the
    single InstancesList. Returns the original logits unchanged.
    """
    # just return the keypoint heatmap for now,
    # there will be option to call HeatmapMaxKeypointOp
    output = alias(pred_keypoint_logits, "kps_score")
    if all(isinstance(x, InstancesList) for x in pred_instances):
        assert len(pred_instances) == 1
        if use_heatmap_max_keypoint:
            # HeatmapMaxKeypoint is a CPU op; round-trip through CPU.
            device = output.device
            output = torch.ops._caffe2.HeatmapMaxKeypoint(
                to_device(output, "cpu"),
                pred_instances[0].pred_boxes.tensor,
                should_output_softmax=True,  # worth making it configurable?
            )
            output = to_device(output, device)
            output = alias(output, "keypoints_out")
        pred_instances[0].set("pred_keypoints", output)
    return pred_keypoint_logits
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
class Caffe2KeypointRCNNInference:
    """
    Callable wrapper that patches in place of keypoint inference, binding the
    heatmap-decoding option at construction time.
    """

    def __init__(self, use_heatmap_max_keypoint):
        # Whether to decode keypoints with HeatmapMaxKeypoint in tensor mode.
        self.use_heatmap_max_keypoint = use_heatmap_max_keypoint

    def __call__(self, pred_keypoint_logits, pred_instances):
        # Forward to the free function with the bound option.
        return caffe2_keypoint_rcnn_inference(
            self.use_heatmap_max_keypoint, pred_keypoint_logits, pred_instances
        )
|
Leffa/3rdparty/detectron2/export/caffe2_export.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
import copy
|
| 4 |
+
import io
|
| 5 |
+
import logging
|
| 6 |
+
import numpy as np
|
| 7 |
+
from typing import List
|
| 8 |
+
import onnx
|
| 9 |
+
import onnx.optimizer
|
| 10 |
+
import torch
|
| 11 |
+
from caffe2.proto import caffe2_pb2
|
| 12 |
+
from caffe2.python import core
|
| 13 |
+
from caffe2.python.onnx.backend import Caffe2Backend
|
| 14 |
+
from tabulate import tabulate
|
| 15 |
+
from termcolor import colored
|
| 16 |
+
from torch.onnx import OperatorExportTypes
|
| 17 |
+
|
| 18 |
+
from .shared import (
|
| 19 |
+
ScopedWS,
|
| 20 |
+
construct_init_net_from_params,
|
| 21 |
+
fuse_alias_placeholder,
|
| 22 |
+
fuse_copy_between_cpu_and_gpu,
|
| 23 |
+
get_params_from_init_net,
|
| 24 |
+
group_norm_replace_aten_with_caffe2,
|
| 25 |
+
infer_device_type,
|
| 26 |
+
remove_dead_end_ops,
|
| 27 |
+
remove_reshape_for_fc,
|
| 28 |
+
save_graph,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
logger = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def export_onnx_model(model, inputs):
    """
    Trace and export a model to onnx format.

    Args:
        model (nn.Module): the model; every submodule must already be in eval
            mode, since ONNX export may change inconsistent training states.
        inputs (tuple[args]): the model will be called by ``model(*inputs)``

    Returns:
        an onnx model
    """
    assert isinstance(model, torch.nn.Module)

    # make sure all modules are in eval mode, onnx may change the training state
    # of the module if the states are not consistent
    for submodule in model.modules():
        assert not submodule.training

    # Export the model to ONNX into an in-memory buffer, then parse it back.
    with torch.no_grad(), io.BytesIO() as buffer:
        torch.onnx.export(
            model,
            inputs,
            buffer,
            operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
            # verbose=True, # NOTE: uncomment this for debugging
            # export_params=True,
        )
        return onnx.load_from_string(buffer.getvalue())
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _op_stats(net_def):
|
| 71 |
+
type_count = {}
|
| 72 |
+
for t in [op.type for op in net_def.op]:
|
| 73 |
+
type_count[t] = type_count.get(t, 0) + 1
|
| 74 |
+
type_count_list = sorted(type_count.items(), key=lambda kv: kv[0]) # alphabet
|
| 75 |
+
type_count_list = sorted(type_count_list, key=lambda kv: -kv[1]) # count
|
| 76 |
+
return "\n".join("{:>4}x {}".format(count, name) for name, count in type_count_list)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _assign_device_option(
    predict_net: caffe2_pb2.NetDef, init_net: caffe2_pb2.NetDef, tensor_inputs: List[torch.Tensor]
):
    """
    ONNX exported network doesn't have concept of device, assign necessary
    device option for each op in order to make it runable on GPU runtime.

    Mutates *predict_net* and *init_net* in place. Device placement is
    propagated from the devices of *tensor_inputs* through the net's SSA form.
    """

    def _get_device_type(torch_tensor):
        # NOTE(review): cpu tensors usually report device.index as None, which
        # would fail the == 0 assert — looks like this path expects explicitly
        # indexed devices; confirm.
        assert torch_tensor.device.type in ["cpu", "cuda"]
        assert torch_tensor.device.index == 0
        return torch_tensor.device.type

    def _assign_op_device_option(net_proto, net_ssa, blob_device_types):
        # Give each op a CUDA device option when its blobs live on "cuda";
        # the cross-device copy ops always get the CUDA option.
        for op, ssa_i in zip(net_proto.op, net_ssa):
            if op.type in ["CopyCPUToGPU", "CopyGPUToCPU"]:
                op.device_option.CopyFrom(core.DeviceOption(caffe2_pb2.CUDA, 0))
            else:
                devices = [blob_device_types[b] for b in ssa_i[0] + ssa_i[1]]
                # All inputs/outputs of a single op must agree on device.
                assert all(d == devices[0] for d in devices)
                if devices[0] == "cuda":
                    op.device_option.CopyFrom(core.DeviceOption(caffe2_pb2.CUDA, 0))

    # update ops in predict_net
    predict_net_input_device_types = {
        (name, 0): _get_device_type(tensor)
        for name, tensor in zip(predict_net.external_input, tensor_inputs)
    }
    predict_net_device_types = infer_device_type(
        predict_net, known_status=predict_net_input_device_types, device_name_style="pytorch"
    )
    predict_net_ssa, _ = core.get_ssa(predict_net)
    _assign_op_device_option(predict_net, predict_net_ssa, predict_net_device_types)

    # update ops in init_net
    # init_net outputs feed predict_net inputs, so their devices are taken
    # from the predict_net inference and propagated backwards.
    init_net_ssa, versions = core.get_ssa(init_net)
    init_net_output_device_types = {
        (name, versions[name]): predict_net_device_types[(name, 0)]
        for name in init_net.external_output
    }
    init_net_device_types = infer_device_type(
        init_net, known_status=init_net_output_device_types, device_name_style="pytorch"
    )
    _assign_op_device_option(init_net, init_net_ssa, init_net_device_types)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def export_caffe2_detection_model(model: torch.nn.Module, tensor_inputs: List[torch.Tensor]):
    """
    Export a caffe2-compatible Detectron2 model to caffe2 format via ONNX.

    Arg:
        model: a caffe2-compatible version of detectron2 model, defined in caffe2_modeling.py
        tensor_inputs: a list of tensors that caffe2 model takes as input.

    Returns:
        (predict_net, init_net): the exported caffe2 NetDefs.
    """
    # Work on a copy so tracing/optimization never mutates the caller's model.
    model = copy.deepcopy(model)
    assert isinstance(model, torch.nn.Module)
    assert hasattr(model, "encode_additional_info")

    # Export via ONNX
    logger.info(
        "Exporting a {} model via ONNX ...".format(type(model).__name__)
        + " Some warnings from ONNX are expected and are usually not to worry about."
    )
    onnx_model = export_onnx_model(model, (tensor_inputs,))
    # Convert ONNX model to Caffe2 protobuf
    init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model)
    ops_table = [[op.type, op.input, op.output] for op in predict_net.op]
    table = tabulate(ops_table, headers=["type", "input", "output"], tablefmt="pipe")
    logger.info(
        "ONNX export Done. Exported predict_net (before optimizations):\n" + colored(table, "cyan")
    )

    # Apply protobuf optimization
    fuse_alias_placeholder(predict_net, init_net)
    if any(t.device.type != "cpu" for t in tensor_inputs):
        # GPU-path-only rewrites: fuse device copies and assign device options.
        fuse_copy_between_cpu_and_gpu(predict_net)
        remove_dead_end_ops(init_net)
        _assign_device_option(predict_net, init_net, tensor_inputs)
    # Rebuild init_net from extracted params (drops redundant ops).
    params, device_options = get_params_from_init_net(init_net)
    predict_net, params = remove_reshape_for_fc(predict_net, params)
    init_net = construct_init_net_from_params(params, device_options)
    group_norm_replace_aten_with_caffe2(predict_net)

    # Record necessary information for running the pb model in Detectron2 system.
    model.encode_additional_info(predict_net, init_net)

    logger.info("Operators used in predict_net: \n{}".format(_op_stats(predict_net)))
    logger.info("Operators used in init_net: \n{}".format(_op_stats(init_net)))

    return predict_net, init_net
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def run_and_save_graph(predict_net, init_net, tensor_inputs, graph_save_path):
    """
    Run the caffe2 model on given inputs, recording the shape and draw the graph.

    predict_net/init_net: caffe2 model.
    tensor_inputs: a list of tensors that caffe2 model takes as input.
    graph_save_path: path for saving graph of exported model.

    Returns:
        dict: blob name -> fetched blob value after the run.
    """

    logger.info("Saving graph of ONNX exported model to {} ...".format(graph_save_path))
    save_graph(predict_net, graph_save_path, op_only=False)

    # Run the exported Caffe2 net
    logger.info("Running ONNX exported model ...")
    # Run inside a throwaway workspace so nothing leaks into the caller's.
    with ScopedWS("__ws_tmp__", True) as ws:
        ws.RunNetOnce(init_net)
        initialized_blobs = set(ws.Blobs())
        # Inputs are the external_inputs that init_net did not create,
        # matched positionally against tensor_inputs.
        uninitialized = [inp for inp in predict_net.external_input if inp not in initialized_blobs]
        for name, blob in zip(uninitialized, tensor_inputs):
            ws.FeedBlob(name, blob)

        try:
            ws.RunNetOnce(predict_net)
        except RuntimeError as e:
            # Best-effort: keep whatever blobs were produced before the failure.
            logger.warning("Encountered RuntimeError: \n{}".format(str(e)))

        ws_blobs = {b: ws.FetchBlob(b) for b in ws.Blobs()}
        # Only ndarray blobs have a shape worth annotating on the graph.
        blob_sizes = {b: ws_blobs[b].shape for b in ws_blobs if isinstance(ws_blobs[b], np.ndarray)}

        logger.info("Saving graph with blob shapes to {} ...".format(graph_save_path))
        save_graph(predict_net, graph_save_path, op_only=False, blob_sizes=blob_sizes)

        return ws_blobs
|
Leffa/3rdparty/detectron2/export/caffe2_inference.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import numpy as np
|
| 5 |
+
from itertools import count
|
| 6 |
+
import torch
|
| 7 |
+
from caffe2.proto import caffe2_pb2
|
| 8 |
+
from caffe2.python import core
|
| 9 |
+
|
| 10 |
+
from .caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP, convert_batched_inputs_to_c2_format
|
| 11 |
+
from .shared import ScopedWS, get_pb_arg_vali, get_pb_arg_vals, infer_device_type
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# ===== ref: mobile-vision predictor's 'Caffe2Wrapper' class ======
|
| 17 |
+
class ProtobufModel(torch.nn.Module):
|
| 18 |
+
"""
|
| 19 |
+
Wrapper of a caffe2's protobuf model.
|
| 20 |
+
It works just like nn.Module, but running caffe2 under the hood.
|
| 21 |
+
Input/Output are tuple[tensor] that match the caffe2 net's external_input/output.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
_ids = count(0)
|
| 25 |
+
|
| 26 |
+
    def __init__(self, predict_net, init_net):
        """
        Args:
            predict_net (caffe2_pb2.NetDef): the inference net.
            init_net (caffe2_pb2.NetDef): net that fills the parameters.
        """
        logger.info(f"Initializing ProtobufModel for: {predict_net.name} ...")
        super().__init__()
        assert isinstance(predict_net, caffe2_pb2.NetDef)
        assert isinstance(init_net, caffe2_pb2.NetDef)
        # create unique temporary workspace for each instance
        self.ws_name = "__tmp_ProtobufModel_{}__".format(next(self._ids))
        self.net = core.Net(predict_net)

        logger.info("Running init_net once to fill the parameters ...")
        # Workspace persists across forward() calls (is_cleanup=False).
        with ScopedWS(self.ws_name, is_reset=True, is_cleanup=False) as ws:
            ws.RunNetOnce(init_net)
            # external_inputs not produced by init_net are the model's runtime
            # inputs; create empty blobs for them so the net can be created.
            uninitialized_external_input = []
            for blob in self.net.Proto().external_input:
                if blob not in ws.Blobs():
                    uninitialized_external_input.append(blob)
                    ws.CreateBlob(blob)
            ws.CreateNet(self.net)

        # Remember errors already logged so repeats are silenced in forward().
        self._error_msgs = set()
        self._input_blobs = uninitialized_external_input
+
|
| 48 |
+
    def _infer_output_devices(self, inputs):
        """
        Infer the device of each external output by propagating the devices of
        *inputs* through the net's SSA form.

        Args:
            inputs (tuple[torch.Tensor]): tensors matched to self._input_blobs.

        Returns:
            list[str]: list of device for each external output
        """

        def _get_device_type(torch_tensor):
            # NOTE(review): cpu tensors usually report device.index as None,
            # which would fail the == 0 assert — confirm callers always pass
            # explicitly indexed devices on this path.
            assert torch_tensor.device.type in ["cpu", "cuda"]
            assert torch_tensor.device.index == 0
            return torch_tensor.device.type

        predict_net = self.net.Proto()
        # Seed the inference with the (blob, version 0) devices of the inputs.
        input_device_types = {
            (name, 0): _get_device_type(tensor) for name, tensor in zip(self._input_blobs, inputs)
        }
        device_type_map = infer_device_type(
            predict_net, known_status=input_device_types, device_name_style="pytorch"
        )
        # Look up each external output at its final SSA version.
        ssa, versions = core.get_ssa(predict_net)
        versioned_outputs = [(name, versions[name]) for name in predict_net.external_output]
        output_devices = [device_type_map[outp] for outp in versioned_outputs]
        return output_devices
|
| 70 |
+
|
| 71 |
+
def forward(self, inputs):
    """
    Feed `inputs` into the private workspace, run the caffe2 net, and return
    its external outputs as torch tensors.

    Args:
        inputs (tuple[torch.Tensor]): one tensor per blob in self._input_blobs.

    Returns:
        tuple[torch.Tensor]: one tensor per external output, placed on the
            device inferred from the inputs (cpu when all inputs are cpu).
    """
    assert len(inputs) == len(self._input_blobs), (
        f"Length of inputs ({len(inputs)}) "
        f"doesn't match the required input blobs: {self._input_blobs}"
    )

    with ScopedWS(self.ws_name, is_reset=False, is_cleanup=False) as ws:
        for b, tensor in zip(self._input_blobs, inputs):
            ws.FeedBlob(b, tensor)

        try:
            ws.RunNet(self.net.Proto().name)
        except RuntimeError as e:
            # Log each distinct error message only once, then continue with
            # whatever outputs the partial run produced.
            if not str(e) in self._error_msgs:
                self._error_msgs.add(str(e))
                logger.warning("Encountered new RuntimeError: \n{}".format(str(e)))
            logger.warning("Catch the error and use partial results.")

        c2_outputs = [ws.FetchBlob(b) for b in self.net.Proto().external_output]
        # Remove outputs of current run, this is necessary in order to
        # prevent fetching the result from previous run if the model fails
        # in the middle.
        for b in self.net.Proto().external_output:
            # Needs to create an uninitialized blob to make the net runnable.
            # This is "equivalent" to: ws.RemoveBlob(b) then ws.CreateBlob(b),
            # but there's no such API.
            ws.FeedBlob(b, f"{b}, a C++ native class of type nullptr (uninitialized).")

    # Cast output to torch.Tensor on the desired device
    output_devices = (
        self._infer_output_devices(inputs)
        if any(t.device.type != "cpu" for t in inputs)
        else ["cpu" for _ in self.net.Proto().external_output]
    )

    outputs = []
    for name, c2_output, device in zip(
        self.net.Proto().external_output, c2_outputs, output_devices
    ):
        # A failed run leaves non-ndarray placeholders in the output blobs.
        if not isinstance(c2_output, np.ndarray):
            raise RuntimeError(
                "Invalid output for blob {}, received: {}".format(name, c2_output)
            )
        outputs.append(torch.tensor(c2_output).to(device=device))
    return tuple(outputs)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class ProtobufDetectionModel(torch.nn.Module):
    """
    A class works just like a pytorch meta arch in terms of inference, but running
    caffe2 model under the hood.
    """

    def __init__(self, predict_net, init_net, *, convert_outputs=None):
        """
        Args:
            predict_net, init_net (core.Net): caffe2 nets
            convert_outputs (callable): a function that converts caffe2
                outputs to the same format of the original pytorch model.
                By default, use the one defined in the caffe2 meta_arch.
        """
        super().__init__()
        self.protobuf_model = ProtobufModel(predict_net, init_net)
        # Metadata written into the protobuf at export time
        # (see encode_additional_info of the export meta-archs).
        self.size_divisibility = get_pb_arg_vali(predict_net, "size_divisibility", 0)
        self.device = get_pb_arg_vals(predict_net, "device", b"cpu").decode("ascii")

        if convert_outputs is None:
            # Pick the converter matching the exported meta-architecture.
            meta_arch = get_pb_arg_vals(predict_net, "meta_architecture", b"GeneralizedRCNN")
            meta_arch = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[meta_arch.decode("ascii")]
            self._convert_outputs = meta_arch.get_outputs_converter(predict_net, init_net)
        else:
            self._convert_outputs = convert_outputs

    def _convert_inputs(self, batched_inputs):
        # currently all models convert inputs in the same way
        return convert_batched_inputs_to_c2_format(
            batched_inputs, self.size_divisibility, self.device
        )

    def forward(self, batched_inputs):
        """Convert inputs to caffe2 format, run the net, convert outputs back."""
        c2_inputs = self._convert_inputs(batched_inputs)
        c2_results = self.protobuf_model(c2_inputs)
        # Map each external output blob name to its tensor.
        c2_results = dict(zip(self.protobuf_model.net.Proto().external_output, c2_results))
        return self._convert_outputs(batched_inputs, c2_inputs, c2_results)
|
Leffa/3rdparty/detectron2/export/caffe2_modeling.py
ADDED
|
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
import functools
|
| 4 |
+
import io
|
| 5 |
+
import struct
|
| 6 |
+
import types
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from detectron2.modeling import meta_arch
|
| 10 |
+
from detectron2.modeling.box_regression import Box2BoxTransform
|
| 11 |
+
from detectron2.modeling.roi_heads import keypoint_head
|
| 12 |
+
from detectron2.structures import Boxes, ImageList, Instances, RotatedBoxes
|
| 13 |
+
|
| 14 |
+
from .c10 import Caffe2Compatible
|
| 15 |
+
from .caffe2_patch import ROIHeadsPatcher, patch_generalized_rcnn
|
| 16 |
+
from .shared import (
|
| 17 |
+
alias,
|
| 18 |
+
check_set_pb_arg,
|
| 19 |
+
get_pb_arg_floats,
|
| 20 |
+
get_pb_arg_valf,
|
| 21 |
+
get_pb_arg_vali,
|
| 22 |
+
get_pb_arg_vals,
|
| 23 |
+
mock_torch_nn_functional_interpolate,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def assemble_rcnn_outputs_by_name(image_sizes, tensor_outputs, force_mask_on=False):
    """
    A function to assemble caffe2 model's outputs (i.e. Dict[str, Tensor])
    to detectron2's format (i.e. list of Instances instance).
    This only works when the model follows the Caffe2 detectron's naming convention.

    Args:
        image_sizes (List[List[int, int]]): [H, W] of every image.
        tensor_outputs (Dict[str, Tensor]): external_output to its tensor.

        force_mask_on (Bool): if true, it makes sure there'll be pred_masks even
            if the mask is not found from tensor_outputs (usually due to model crash)
    """

    results = [Instances(image_size) for image_size in image_sizes]

    batch_splits = tensor_outputs.get("batch_splits", None)
    # NOTE(review): truthiness of a multi-element tensor/ndarray raises — this
    # branch presumably only ever sees None here; confirm.
    if batch_splits:
        raise NotImplementedError()
    # Without batch_splits only single-image assembly is supported.
    assert len(image_sizes) == 1
    result = results[0]

    bbox_nms = tensor_outputs["bbox_nms"]
    score_nms = tensor_outputs["score_nms"]
    class_nms = tensor_outputs["class_nms"]
    # Detection will always succeed because Conv supports 0-batch
    assert bbox_nms is not None
    assert score_nms is not None
    assert class_nms is not None
    # 5 columns per box indicates rotated boxes; 4 columns are axis-aligned.
    if bbox_nms.shape[1] == 5:
        result.pred_boxes = RotatedBoxes(bbox_nms)
    else:
        result.pred_boxes = Boxes(bbox_nms)
    result.scores = score_nms
    result.pred_classes = class_nms.to(torch.int64)

    mask_fcn_probs = tensor_outputs.get("mask_fcn_probs", None)
    if mask_fcn_probs is not None:
        # finish the mask pred
        mask_probs_pred = mask_fcn_probs
        num_masks = mask_probs_pred.shape[0]
        class_pred = result.pred_classes
        indices = torch.arange(num_masks, device=class_pred.device)
        # Keep only the mask channel of each instance's predicted class.
        mask_probs_pred = mask_probs_pred[indices, class_pred][:, None]
        result.pred_masks = mask_probs_pred
    elif force_mask_on:
        # NOTE: there's no way to know the height/width of mask here, it won't be
        # used anyway when batch size is 0, so just set them to 0.
        result.pred_masks = torch.zeros([0, 1, 0, 0], dtype=torch.uint8)

    keypoints_out = tensor_outputs.get("keypoints_out", None)
    kps_score = tensor_outputs.get("kps_score", None)
    if keypoints_out is not None:
        # keypoints_out: [N, 4, #kypoints], where 4 is in order of (x, y, score, prob)
        keypoints_tensor = keypoints_out
        # NOTE: it's possible that prob is not calculated if "should_output_softmax"
        # is set to False in HeatmapMaxKeypoint, so just using raw score, seems
        # it doesn't affect mAP. TODO: check more carefully.
        keypoint_xyp = keypoints_tensor.transpose(1, 2)[:, :, [0, 1, 2]]
        result.pred_keypoints = keypoint_xyp
    elif kps_score is not None:
        # keypoint heatmap to sparse data structure
        pred_keypoint_logits = kps_score
        keypoint_head.keypoint_rcnn_inference(pred_keypoint_logits, [result])

    return results
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _cast_to_f32(f64):
|
| 96 |
+
return struct.unpack("f", struct.pack("f", f64))[0]
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def set_caffe2_compatible_tensor_mode(model, enable=True):
    """Set `tensor_mode = enable` on every Caffe2Compatible submodule of `model`."""

    def _toggle(module):
        if isinstance(module, Caffe2Compatible):
            module.tensor_mode = enable

    # nn.Module.apply visits the module itself and all its children.
    model.apply(_toggle)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def convert_batched_inputs_to_c2_format(batched_inputs, size_divisibility, device):
    """
    See get_caffe2_inputs() below.

    Args:
        batched_inputs (list[dict]): each dict has an "image" CHW tensor, and
            optionally "height"/"width" of the desired output resolution.
        size_divisibility (int): padding multiple used when batching the images.
        device: device the returned tensors are moved to.

    Returns:
        tuple: (NCHW batched image tensor, Nx3 float tensor of
            (height, width, scale) rows).
    """
    assert all(isinstance(x, dict) for x in batched_inputs)
    assert all(x["image"].dim() == 3 for x in batched_inputs)

    images = [x["image"] for x in batched_inputs]
    images = ImageList.from_tensors(images, size_divisibility)

    im_info = []
    for input_per_image, image_size in zip(batched_inputs, images.image_sizes):
        target_height = input_per_image.get("height", image_size[0])
        target_width = input_per_image.get("width", image_size[1])  # noqa
        # NOTE: The scale inside im_info is kept as convention and for providing
        # post-processing information if further processing is needed. For
        # current Caffe2 model definitions that don't include post-processing inside
        # the model, this number is not used.
        # NOTE: There can be a slight difference between width and height
        # scales, using a single number can results in numerical difference
        # compared with D2's post-processing.
        scale = target_height / image_size[0]
        im_info.append([image_size[0], image_size[1], scale])
    im_info = torch.Tensor(im_info)

    return images.tensor.to(device), im_info.to(device)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class Caffe2MetaArch(Caffe2Compatible, torch.nn.Module):
    """
    Base class for caffe2-compatible implementation of a meta architecture.
    The forward is traceable and its traced graph can be converted to caffe2
    graph through ONNX.
    """

    def __init__(self, cfg, torch_model, enable_tensor_mode=True):
        """
        Args:
            cfg (CfgNode):
            torch_model (nn.Module): the detectron2 model (meta_arch) to be
                converted.
            enable_tensor_mode (bool): initial tensor_mode for all
                Caffe2Compatible submodules (including self).
        """
        super().__init__()
        self._wrapped_model = torch_model
        # Export is inference-only.
        self.eval()
        set_caffe2_compatible_tensor_mode(self, enable_tensor_mode)

    def get_caffe2_inputs(self, batched_inputs):
        """
        Convert pytorch-style structured inputs to caffe2-style inputs that
        are tuples of tensors.

        Args:
            batched_inputs (list[dict]): inputs to a detectron2 model
                in its standard format. Each dict has "image" (CHW tensor), and optionally
                "height" and "width".

        Returns:
            tuple[Tensor]:
                tuple of tensors that will be the inputs to the
                :meth:`forward` method. For existing models, the first
                is an NCHW tensor (padded and batched); the second is
                a im_info Nx3 tensor, where the rows are
                (height, width, unused legacy parameter)
        """
        return convert_batched_inputs_to_c2_format(
            batched_inputs,
            self._wrapped_model.backbone.size_divisibility,
            self._wrapped_model.device,
        )

    def encode_additional_info(self, predict_net, init_net):
        """
        Save extra metadata that will be used by inference in the output protobuf.
        """
        pass

    def forward(self, inputs):
        """
        Run the forward in caffe2-style. It has to use caffe2-compatible ops
        and the method will be used for tracing.

        Args:
            inputs (tuple[Tensor]): inputs defined by :meth:`get_caffe2_inputs`.
                They will be the inputs of the converted caffe2 graph.

        Returns:
            tuple[Tensor]: output tensors. They will be the outputs of the
                converted caffe2 graph.
        """
        raise NotImplementedError

    def _caffe2_preprocess_image(self, inputs):
        """
        Caffe2 implementation of preprocess_image, which is called inside each MetaArch's forward.
        It normalizes the input images, and the final caffe2 graph assumes the
        inputs have been batched already.
        """
        data, im_info = inputs
        # alias() pins stable blob names onto the tensors in the traced graph.
        data = alias(data, "data")
        im_info = alias(im_info, "im_info")
        mean, std = self._wrapped_model.pixel_mean, self._wrapped_model.pixel_std
        normalized_data = (data - mean) / std
        normalized_data = alias(normalized_data, "normalized_data")

        # Pack (data, im_info) into ImageList which is recognized by self.inference.
        images = ImageList(tensor=normalized_data, image_sizes=im_info)
        return images

    @staticmethod
    def get_outputs_converter(predict_net, init_net):
        """
        Creates a function that converts outputs of the caffe2 model to
        detectron2's standard format.
        The function uses information in `predict_net` and `init_net` that are
        available at inference time. Therefore the function logic can be used in inference.

        The returned function has the following signature:

            def convert(batched_inputs, c2_inputs, c2_results) -> detectron2_outputs

        Where

            * batched_inputs (list[dict]): the original input format of the meta arch
            * c2_inputs (tuple[Tensor]): the caffe2 inputs.
            * c2_results (dict[str, Tensor]): the caffe2 output format,
                corresponding to the outputs of the :meth:`forward` function.
            * detectron2_outputs: the original output format of the meta arch.

        This function can be used to compare the outputs of the original meta arch and
        the converted caffe2 graph.

        Returns:
            callable: a callable of the above signature.
        """
        raise NotImplementedError
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
class Caffe2GeneralizedRCNN(Caffe2MetaArch):
    """Caffe2-compatible wrapper of meta_arch.GeneralizedRCNN for export."""

    def __init__(self, cfg, torch_model, enable_tensor_mode=True):
        assert isinstance(torch_model, meta_arch.GeneralizedRCNN)
        # Swap RPN/ROIPooler submodules for their caffe2-compatible versions.
        torch_model = patch_generalized_rcnn(torch_model)
        super().__init__(cfg, torch_model, enable_tensor_mode)

        # Older configs may not have the EXPORT_CAFFE2 node.
        try:
            use_heatmap_max_keypoint = cfg.EXPORT_CAFFE2.USE_HEATMAP_MAX_KEYPOINT
        except AttributeError:
            use_heatmap_max_keypoint = False
        self.roi_heads_patcher = ROIHeadsPatcher(
            self._wrapped_model.roi_heads, use_heatmap_max_keypoint
        )
        # tensor_mode was set by set_caffe2_compatible_tensor_mode in the base __init__.
        if self.tensor_mode:
            self.roi_heads_patcher.patch_roi_heads()

    def encode_additional_info(self, predict_net, init_net):
        # Metadata read back by ProtobufDetectionModel at inference time.
        size_divisibility = self._wrapped_model.backbone.size_divisibility
        check_set_pb_arg(predict_net, "size_divisibility", "i", size_divisibility)
        check_set_pb_arg(
            predict_net, "device", "s", str.encode(str(self._wrapped_model.device), "ascii")
        )
        check_set_pb_arg(predict_net, "meta_architecture", "s", b"GeneralizedRCNN")

    @mock_torch_nn_functional_interpolate()
    def forward(self, inputs):
        if not self.tensor_mode:
            # Plain pytorch inference path (not used for tracing).
            return self._wrapped_model.inference(inputs)
        images = self._caffe2_preprocess_image(inputs)
        features = self._wrapped_model.backbone(images.tensor)
        proposals, _ = self._wrapped_model.proposal_generator(images, features)
        detector_results, _ = self._wrapped_model.roi_heads(images, features, proposals)
        # Single-image export: flatten the first result into a tuple of tensors.
        return tuple(detector_results[0].flatten())

    @staticmethod
    def get_outputs_converter(predict_net, init_net):
        def f(batched_inputs, c2_inputs, c2_results):
            _, im_info = c2_inputs
            image_sizes = [[int(im[0]), int(im[1])] for im in im_info]
            results = assemble_rcnn_outputs_by_name(image_sizes, c2_results)
            return meta_arch.GeneralizedRCNN._postprocess(results, batched_inputs, image_sizes)

        return f
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
class Caffe2RetinaNet(Caffe2MetaArch):
    """Caffe2-compatible wrapper of meta_arch.RetinaNet for export."""

    def __init__(self, cfg, torch_model):
        assert isinstance(torch_model, meta_arch.RetinaNet)
        super().__init__(cfg, torch_model)

    @mock_torch_nn_functional_interpolate()
    def forward(self, inputs):
        assert self.tensor_mode
        images = self._caffe2_preprocess_image(inputs)

        # explicitly return the images sizes to avoid removing "im_info" by ONNX
        # since it's not used in the forward path
        return_tensors = [images.image_sizes]

        features = self._wrapped_model.backbone(images.tensor)
        features = [features[f] for f in self._wrapped_model.head_in_features]
        for i, feature_i in enumerate(features):
            # Name the feature blobs so they survive the ONNX conversion.
            features[i] = alias(feature_i, "feature_{}".format(i), is_backward=True)
            return_tensors.append(features[i])

        pred_logits, pred_anchor_deltas = self._wrapped_model.head(features)
        for i, (box_cls_i, box_delta_i) in enumerate(zip(pred_logits, pred_anchor_deltas)):
            return_tensors.append(alias(box_cls_i, "box_cls_{}".format(i)))
            return_tensors.append(alias(box_delta_i, "box_delta_{}".format(i)))

        return tuple(return_tensors)

    def encode_additional_info(self, predict_net, init_net):
        # Metadata read back by get_outputs_converter / ProtobufDetectionModel.
        size_divisibility = self._wrapped_model.backbone.size_divisibility
        check_set_pb_arg(predict_net, "size_divisibility", "i", size_divisibility)
        check_set_pb_arg(
            predict_net, "device", "s", str.encode(str(self._wrapped_model.device), "ascii")
        )
        check_set_pb_arg(predict_net, "meta_architecture", "s", b"RetinaNet")

        # Inference parameters:
        check_set_pb_arg(
            predict_net, "score_threshold", "f", _cast_to_f32(self._wrapped_model.test_score_thresh)
        )
        check_set_pb_arg(
            predict_net, "topk_candidates", "i", self._wrapped_model.test_topk_candidates
        )
        check_set_pb_arg(
            predict_net, "nms_threshold", "f", _cast_to_f32(self._wrapped_model.test_nms_thresh)
        )
        check_set_pb_arg(
            predict_net,
            "max_detections_per_image",
            "i",
            self._wrapped_model.max_detections_per_image,
        )

        check_set_pb_arg(
            predict_net,
            "bbox_reg_weights",
            "floats",
            [_cast_to_f32(w) for w in self._wrapped_model.box2box_transform.weights],
        )
        self._encode_anchor_generator_cfg(predict_net)

    def _encode_anchor_generator_cfg(self, predict_net):
        # serialize anchor_generator for future use
        serialized_anchor_generator = io.BytesIO()
        torch.save(self._wrapped_model.anchor_generator, serialized_anchor_generator)
        # Ideally we can put anchor generating inside the model, then we don't
        # need to store this information.
        bytes = serialized_anchor_generator.getvalue()
        check_set_pb_arg(predict_net, "serialized_anchor_generator", "s", bytes)

    @staticmethod
    def get_outputs_converter(predict_net, init_net):
        # Build a namespace that mimics a RetinaNet instance from the metadata
        # stored in the protobuf, then borrow RetinaNet's inference methods.
        self = types.SimpleNamespace()
        serialized_anchor_generator = io.BytesIO(
            get_pb_arg_vals(predict_net, "serialized_anchor_generator", None)
        )
        self.anchor_generator = torch.load(serialized_anchor_generator)
        bbox_reg_weights = get_pb_arg_floats(predict_net, "bbox_reg_weights", None)
        self.box2box_transform = Box2BoxTransform(weights=tuple(bbox_reg_weights))
        self.test_score_thresh = get_pb_arg_valf(predict_net, "score_threshold", None)
        self.test_topk_candidates = get_pb_arg_vali(predict_net, "topk_candidates", None)
        self.test_nms_thresh = get_pb_arg_valf(predict_net, "nms_threshold", None)
        self.max_detections_per_image = get_pb_arg_vali(
            predict_net, "max_detections_per_image", None
        )

        # hack to reuse inference code from RetinaNet
        for meth in [
            "forward_inference",
            "inference_single_image",
            "_transpose_dense_predictions",
            "_decode_multi_level_predictions",
            "_decode_per_level_predictions",
        ]:
            setattr(self, meth, functools.partial(getattr(meta_arch.RetinaNet, meth), self))

        def f(batched_inputs, c2_inputs, c2_results):
            _, im_info = c2_inputs
            image_sizes = [[int(im[0]), int(im[1])] for im in im_info]
            # forward_inference only needs batch size and image sizes, so
            # random data of the right shape is sufficient here.
            dummy_images = ImageList(
                torch.randn(
                    (
                        len(im_info),
                        3,
                    )
                    + tuple(image_sizes[0])
                ),
                image_sizes,
            )

            num_features = len([x for x in c2_results.keys() if x.startswith("box_cls_")])
            pred_logits = [c2_results["box_cls_{}".format(i)] for i in range(num_features)]
            pred_anchor_deltas = [c2_results["box_delta_{}".format(i)] for i in range(num_features)]

            # For each feature level, feature should have the same batch size and
            # spatial dimension as the box_cls and box_delta.
            dummy_features = [x.clone()[:, 0:0, :, :] for x in pred_logits]
            # self.num_classes can be inferred from the channel counts
            # (deltas carry 4 values per anchor).
            self.num_classes = pred_logits[0].shape[1] // (pred_anchor_deltas[0].shape[1] // 4)

            results = self.forward_inference(
                dummy_images, dummy_features, [pred_logits, pred_anchor_deltas]
            )
            return meta_arch.GeneralizedRCNN._postprocess(results, batched_inputs, image_sizes)

        return f
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
# Maps the "meta_architecture" string stored in the exported protobuf to the
# caffe2-compatible meta-arch class used for export and output conversion.
META_ARCH_CAFFE2_EXPORT_TYPE_MAP = {
    "GeneralizedRCNN": Caffe2GeneralizedRCNN,
    "RetinaNet": Caffe2RetinaNet,
}
|
Leffa/3rdparty/detectron2/export/caffe2_patch.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
import contextlib
|
| 4 |
+
from unittest import mock
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from detectron2.modeling import poolers
|
| 8 |
+
from detectron2.modeling.proposal_generator import rpn
|
| 9 |
+
from detectron2.modeling.roi_heads import keypoint_head, mask_head
|
| 10 |
+
from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers
|
| 11 |
+
|
| 12 |
+
from .c10 import (
|
| 13 |
+
Caffe2Compatible,
|
| 14 |
+
Caffe2FastRCNNOutputsInference,
|
| 15 |
+
Caffe2KeypointRCNNInference,
|
| 16 |
+
Caffe2MaskRCNNInference,
|
| 17 |
+
Caffe2ROIPooler,
|
| 18 |
+
Caffe2RPN,
|
| 19 |
+
caffe2_fast_rcnn_outputs_inference,
|
| 20 |
+
caffe2_keypoint_rcnn_inference,
|
| 21 |
+
caffe2_mask_rcnn_inference,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class GenericMixin:
    """
    Marker base class: Caffe2CompatibleConverter treats replacement classes
    derived from this as mixins (combined with the module's original class)
    rather than as full class swaps.
    """

    pass
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class Caffe2CompatibleConverter:
    """
    A GenericUpdater which implements the `create_from` interface, by modifying
    the module object in place and re-assigning its class to `replaceCls`
    (or a dynamically created mixin of it).
    """

    def __init__(self, replaceCls):
        self.replaceCls = replaceCls

    def create_from(self, module):
        """Swap `module`'s class for the replacement class and return the module."""
        assert isinstance(module, torch.nn.Module)
        replacement = self.replaceCls
        if issubclass(replacement, GenericMixin):
            # replacement acts as a mixin: synthesize a class combining it with
            # the module's current class, then re-assign on the fly.
            mixed_name = "{}MixedWith{}".format(
                replacement.__name__, module.__class__.__name__
            )
            module.__class__ = type(mixed_name, (replacement, module.__class__), {})
        else:
            # replacement is a complete class; this allows an arbitrary class swap.
            module.__class__ = replacement

        # Newly Caffe2Compatible modules start with tensor_mode disabled.
        if isinstance(module, Caffe2Compatible):
            module.tensor_mode = False

        return module
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def patch(model, target, updater, *args, **kwargs):
    """
    Recursively (post-order) update all modules with the target type and its
    subclasses: each matching module is replaced by the result of
    `updater.create_from(module, *args, **kwargs)`.
    """
    # Children first (post-order), re-binding each slot with the patched result.
    for child_name, child in model.named_children():
        model._modules[child_name] = patch(child, target, updater, *args, **kwargs)
    if not isinstance(model, target):
        return model
    return updater.create_from(model, *args, **kwargs)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def patch_generalized_rcnn(model):
    """Swap RPN and ROIPooler submodules of `model` for caffe2-compatible versions."""
    convert = Caffe2CompatibleConverter
    model = patch(model, rpn.RPN, convert(Caffe2RPN))
    return patch(model, poolers.ROIPooler, convert(Caffe2ROIPooler))
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@contextlib.contextmanager
def mock_fastrcnn_outputs_inference(
    tensor_mode, check=True, box_predictor_type=FastRCNNOutputLayers
):
    """
    Context manager that mock-patches `box_predictor_type.inference` with the
    caffe2-compatible Caffe2FastRCNNOutputsInference.

    Args:
        tensor_mode (bool): forwarded to Caffe2FastRCNNOutputsInference.
        check (bool): if True, assert the patched method was called at least
            once while the context was active.
        box_predictor_type (type): the class whose `inference` method is patched.
    """
    with mock.patch.object(
        box_predictor_type,
        "inference",
        autospec=True,
        side_effect=Caffe2FastRCNNOutputsInference(tensor_mode),
    ) as mocked_func:
        yield
    if check:
        assert mocked_func.call_count > 0
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@contextlib.contextmanager
def mock_mask_rcnn_inference(tensor_mode, patched_module, check=True):
    """
    Context manager that mock-patches `mask_rcnn_inference` inside the module
    named by `patched_module` with the caffe2-compatible Caffe2MaskRCNNInference.

    NOTE: `tensor_mode` is currently unused here; kept for signature symmetry
    with the other mock helpers.

    Args:
        patched_module (str): dotted module path containing mask_rcnn_inference.
        check (bool): if True, assert the patched function was called at least once.
    """
    with mock.patch(
        "{}.mask_rcnn_inference".format(patched_module), side_effect=Caffe2MaskRCNNInference()
    ) as mocked_func:
        yield
    if check:
        assert mocked_func.call_count > 0
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@contextlib.contextmanager
def mock_keypoint_rcnn_inference(tensor_mode, patched_module, use_heatmap_max_keypoint, check=True):
    """
    Context manager that mock-patches `keypoint_rcnn_inference` inside the module
    named by `patched_module` with the caffe2-compatible Caffe2KeypointRCNNInference.

    NOTE: `tensor_mode` is currently unused here; kept for signature symmetry
    with the other mock helpers.

    Args:
        patched_module (str): dotted module path containing keypoint_rcnn_inference.
        use_heatmap_max_keypoint (bool): forwarded to Caffe2KeypointRCNNInference.
        check (bool): if True, assert the patched function was called at least once.
    """
    with mock.patch(
        "{}.keypoint_rcnn_inference".format(patched_module),
        side_effect=Caffe2KeypointRCNNInference(use_heatmap_max_keypoint),
    ) as mocked_func:
        yield
    if check:
        assert mocked_func.call_count > 0
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class ROIHeadsPatcher:
|
| 118 |
+
def __init__(self, heads, use_heatmap_max_keypoint):
    """
    Args:
        heads: the ROI heads instance whose inference functions get patched.
        use_heatmap_max_keypoint (bool): forwarded to the keypoint inference patch.
    """
    self.heads = heads
    self.use_heatmap_max_keypoint = use_heatmap_max_keypoint
    # Original (pre-patch) inference callables, saved by patch_roi_heads so
    # they can be restored later.
    self.previous_patched = {}
|
| 122 |
+
|
| 123 |
+
@contextlib.contextmanager
def mock_roi_heads(self, tensor_mode=True):
    """
    Patching several inference functions inside ROIHeads and its subclasses

    Args:
        tensor_mode (bool): whether the inputs/outputs are caffe2's tensor
            format or not. Default to True.
    """
    # NOTE: this requires the `keypoint_rcnn_inference` and `mask_rcnn_inference`
    # are called inside the same file as BaseXxxHead due to using mock.patch.
    kpt_heads_mod = keypoint_head.BaseKeypointRCNNHead.__module__
    mask_head_mod = mask_head.BaseMaskRCNNHead.__module__

    # The box-predictor patch is always applied; keypoint/mask patches only
    # when the corresponding head is enabled.
    mock_ctx_managers = [
        mock_fastrcnn_outputs_inference(
            tensor_mode=tensor_mode,
            check=True,
            box_predictor_type=type(self.heads.box_predictor),
        )
    ]
    if getattr(self.heads, "keypoint_on", False):
        mock_ctx_managers += [
            mock_keypoint_rcnn_inference(
                tensor_mode, kpt_heads_mod, self.use_heatmap_max_keypoint
            )
        ]
    if getattr(self.heads, "mask_on", False):
        mock_ctx_managers += [mock_mask_rcnn_inference(tensor_mode, mask_head_mod)]

    # Enter all patches at once; they unwind together when the block exits.
    with contextlib.ExitStack() as stack:  # python 3.3+
        for mgr in mock_ctx_managers:
            stack.enter_context(mgr)
        yield
|
| 157 |
+
|
| 158 |
+
def patch_roi_heads(self, tensor_mode=True):
|
| 159 |
+
self.previous_patched["box_predictor"] = self.heads.box_predictor.inference
|
| 160 |
+
self.previous_patched["keypoint_rcnn"] = keypoint_head.keypoint_rcnn_inference
|
| 161 |
+
self.previous_patched["mask_rcnn"] = mask_head.mask_rcnn_inference
|
| 162 |
+
|
| 163 |
+
def patched_fastrcnn_outputs_inference(predictions, proposal):
|
| 164 |
+
return caffe2_fast_rcnn_outputs_inference(
|
| 165 |
+
True, self.heads.box_predictor, predictions, proposal
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
self.heads.box_predictor.inference = patched_fastrcnn_outputs_inference
|
| 169 |
+
|
| 170 |
+
if getattr(self.heads, "keypoint_on", False):
|
| 171 |
+
|
| 172 |
+
def patched_keypoint_rcnn_inference(pred_keypoint_logits, pred_instances):
|
| 173 |
+
return caffe2_keypoint_rcnn_inference(
|
| 174 |
+
self.use_heatmap_max_keypoint, pred_keypoint_logits, pred_instances
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
keypoint_head.keypoint_rcnn_inference = patched_keypoint_rcnn_inference
|
| 178 |
+
|
| 179 |
+
if getattr(self.heads, "mask_on", False):
|
| 180 |
+
|
| 181 |
+
def patched_mask_rcnn_inference(pred_mask_logits, pred_instances):
|
| 182 |
+
return caffe2_mask_rcnn_inference(pred_mask_logits, pred_instances)
|
| 183 |
+
|
| 184 |
+
mask_head.mask_rcnn_inference = patched_mask_rcnn_inference
|
| 185 |
+
|
| 186 |
+
def unpatch_roi_heads(self):
|
| 187 |
+
self.heads.box_predictor.inference = self.previous_patched["box_predictor"]
|
| 188 |
+
keypoint_head.keypoint_rcnn_inference = self.previous_patched["keypoint_rcnn"]
|
| 189 |
+
mask_head.mask_rcnn_inference = self.previous_patched["mask_rcnn"]
|
Leffa/3rdparty/detectron2/export/flatten.py
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
import collections
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Callable, List, Optional, Tuple
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
|
| 8 |
+
from detectron2.structures import Boxes, Instances, ROIMasks
|
| 9 |
+
from detectron2.utils.registry import _convert_target_to_string, locate
|
| 10 |
+
|
| 11 |
+
from .torchscript_patch import patch_builtin_len
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class Schema:
    """
    A Schema defines how to flatten a possibly hierarchical object into tuple of
    primitive objects, so it can be used as inputs/outputs of PyTorch's tracing.

    PyTorch does not support tracing a function that produces rich output
    structures (e.g. dict, Instances, Boxes). To trace such a function, we
    flatten the rich object into tuple of tensors, and return this tuple of tensors
    instead. Meanwhile, we also need to know how to "rebuild" the original object
    from the flattened results, so we can evaluate the flattened results.
    A Schema defines how to flatten an object, and while flattening it, it records
    necessary schemas so that the object can be rebuilt using the flattened outputs.

    The flattened object and the schema object is returned by ``.flatten`` classmethod.
    Then the original object can be rebuilt with the ``__call__`` method of schema.

    A Schema is a dataclass that can be serialized easily.
    """

    # inspired by FetchMapper in tensorflow/python/client/session.py

    @classmethod
    def flatten(cls, obj):
        """Flatten ``obj`` into (tuple_of_primitives, schema). Subclass hook."""
        raise NotImplementedError

    def __call__(self, values):
        """Rebuild the original object from flattened ``values``. Subclass hook."""
        raise NotImplementedError

    @staticmethod
    def _concat(values):
        """
        Concatenate an iterable of tuples into a single tuple.

        Returns:
            ret (tuple): all elements in order.
            sizes (list[int]): length of each input tuple; needed by ``_split``.
        """
        # Accumulate into a list and convert once: repeated tuple "+" in a
        # loop is quadratic in the total number of elements.
        flat = []
        sizes = []
        for v in values:
            assert isinstance(v, tuple), "Flattened results must be a tuple"
            flat.extend(v)
            sizes.append(len(v))
        return tuple(flat), sizes

    @staticmethod
    def _split(values, sizes):
        """
        Inverse of ``_concat``: slice ``values`` back into consecutive chunks
        of the given sizes. Returns a list of slices of ``values``.
        """
        if len(sizes):
            expected_len = sum(sizes)
            assert (
                len(values) == expected_len
            ), f"Values has length {len(values)} but expect length {expected_len}."
        # Keep a running offset instead of re-summing prefix sizes for every
        # chunk (the prefix-sum-per-chunk form is O(n^2)).
        ret = []
        begin = 0
        for size in sizes:
            ret.append(values[begin : begin + size])
            begin += size
        return ret
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@dataclass
class ListSchema(Schema):
    """Schema for a list: one sub-schema (and flattened length) per element."""

    schemas: List[Schema]  # how to rebuild each element
    sizes: List[int]  # flattened length of each element

    def __call__(self, values):
        chunks = self._split(values, self.sizes)
        if len(chunks) != len(self.schemas):
            raise ValueError(
                f"Values has length {len(chunks)} but schemas " f"has length {len(self.schemas)}!"
            )
        return [schema(chunk) for schema, chunk in zip(self.schemas, chunks)]

    @classmethod
    def flatten(cls, obj):
        flattened = []
        subschemas = []
        for element in obj:
            element_values, element_schema = flatten_to_tuple(element)
            flattened.append(element_values)
            subschemas.append(element_schema)
        values, sizes = cls._concat(flattened)
        return values, cls(subschemas, sizes)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@dataclass
class TupleSchema(ListSchema):
    """Identical to ListSchema except rebuilding produces a tuple."""

    def __call__(self, values):
        rebuilt = super().__call__(values)
        return tuple(rebuilt)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@dataclass
class IdentitySchema(Schema):
    """Schema for an atomic object: flattening wraps it in a 1-tuple."""

    def __call__(self, values):
        first = values[0]
        return first

    @classmethod
    def flatten(cls, obj):
        flattened = (obj,)
        return flattened, cls()
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
@dataclass
class DictSchema(ListSchema):
    """Schema for a str-keyed mapping; values are flattened in sorted-key order."""

    keys: List[str]

    def __call__(self, values):
        rebuilt = super().__call__(values)
        return dict(zip(self.keys, rebuilt))

    @classmethod
    def flatten(cls, obj):
        for key in obj.keys():
            if not isinstance(key, str):
                raise KeyError("Only support flattening dictionaries if keys are str.")
        # Sorting makes the flattened order deterministic and rebuildable.
        sorted_keys = sorted(obj.keys())
        ret, schema = ListSchema.flatten([obj[key] for key in sorted_keys])
        return ret, cls(schema.schemas, schema.sizes, sorted_keys)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
@dataclass
class InstancesSchema(DictSchema):
    """Schema for detectron2 ``Instances``: its fields dict plus a trailing image_size."""

    def __call__(self, values):
        # The image size was appended last by flatten(); peel it off first.
        fields = super().__call__(values[:-1])
        return Instances(values[-1], **fields)

    @classmethod
    def flatten(cls, obj):
        ret, schema = super().flatten(obj.get_fields())
        image_size = obj.image_size
        if not isinstance(image_size, torch.Tensor):
            image_size = torch.tensor(image_size)
        return ret + (image_size,), schema
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
@dataclass
class TensorWrapSchema(Schema):
    """
    For classes that are simple wrapper of tensors, e.g.
    Boxes, RotatedBoxes, BitMasks
    """

    class_name: str  # fully-qualified name, resolved back with ``locate``

    def __call__(self, values):
        wrapper_cls = locate(self.class_name)
        return wrapper_cls(values[0])

    @classmethod
    def flatten(cls, obj):
        qualified_name = _convert_target_to_string(type(obj))
        return (obj.tensor,), cls(qualified_name)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# if more custom structures needed in the future, can allow
|
| 157 |
+
# passing in extra schemas for custom types
|
| 158 |
+
def flatten_to_tuple(obj):
    """
    Flatten an object so it can be used for PyTorch tracing.
    Also returns how to rebuild the original object from the flattened outputs.

    Returns:
        res (tuple): the flattened results that can be used as tracing outputs
        schema: an object with a ``__call__`` method such that ``schema(res) == obj``.
            It is a pure dataclass that can be serialized.
    """
    # Dispatch order matters: str/bytes are sequences, so they must match
    # before list/tuple to be treated as atomic values.
    dispatch = (
        ((str, bytes), IdentitySchema),
        (list, ListSchema),
        (tuple, TupleSchema),
        (collections.abc.Mapping, DictSchema),
        (Instances, InstancesSchema),
        ((Boxes, ROIMasks), TensorWrapSchema),
    )
    schema_cls = IdentitySchema  # fallback for any unrecognized type
    for klass, candidate in dispatch:
        if isinstance(obj, klass):
            schema_cls = candidate
            break
    return schema_cls.flatten(obj)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class TracingAdapter(nn.Module):
    """
    A model may take rich input/output format (e.g. dict or custom classes),
    but `torch.jit.trace` requires tuple of tensors as input/output.
    This adapter flattens input/output format of a model so it becomes traceable.

    It also records the necessary schema to rebuild model's inputs/outputs from flattened
    inputs/outputs.

    Example:
    ::
        outputs = model(inputs)  # inputs/outputs may be rich structure
        adapter = TracingAdapter(model, inputs)

        # can now trace the model, with adapter.flattened_inputs, or another
        # tuple of tensors with the same length and meaning
        traced = torch.jit.trace(adapter, adapter.flattened_inputs)

        # traced model can only produce flattened outputs (tuple of tensors)
        flattened_outputs = traced(*adapter.flattened_inputs)
        # adapter knows the schema to convert it back (new_outputs == outputs)
        new_outputs = adapter.outputs_schema(flattened_outputs)
    """

    flattened_inputs: Tuple[torch.Tensor] = None
    """
    Flattened version of inputs given to this class's constructor.
    """

    inputs_schema: Schema = None
    """
    Schema of the inputs given to this class's constructor.
    """

    outputs_schema: Schema = None
    """
    Schema of the output produced by calling the given model with inputs.
    """

    def __init__(
        self,
        model: nn.Module,
        inputs,
        inference_func: Optional[Callable] = None,
        allow_non_tensor: bool = False,
    ):
        """
        Args:
            model: an nn.Module
            inputs: An input argument or a tuple of input arguments used to call model.
                After flattening, it has to only consist of tensors.
            inference_func: a callable that takes (model, *inputs), calls the
                model with inputs, and return outputs. By default it
                is ``lambda model, *inputs: model(*inputs)``. Can be override
                if you need to call the model differently.
            allow_non_tensor: allow inputs/outputs to contain non-tensor objects.
                This option will filter out non-tensor objects to make the
                model traceable, but ``inputs_schema``/``outputs_schema`` cannot be
                used anymore because inputs/outputs cannot be rebuilt from pure tensors.
                This is useful when you're only interested in the single trace of
                execution (e.g. for flop count), but not interested in
                generalizing the traced graph to new inputs.
        """
        super().__init__()
        # Unwrap (D)DP so tracing sees the bare module.
        if isinstance(model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)):
            model = model.module
        self.model = model
        # Normalize a single input argument into a 1-tuple.
        if not isinstance(inputs, tuple):
            inputs = (inputs,)
        self.inputs = inputs
        self.allow_non_tensor = allow_non_tensor

        if inference_func is None:
            inference_func = lambda model, *inputs: model(*inputs)  # noqa
        self.inference_func = inference_func

        self.flattened_inputs, self.inputs_schema = flatten_to_tuple(inputs)

        # All-tensor inputs: nothing more to do.
        if all(isinstance(x, torch.Tensor) for x in self.flattened_inputs):
            return
        if self.allow_non_tensor:
            # Drop non-tensors; the schema becomes unusable so it is cleared.
            self.flattened_inputs = tuple(
                [x for x in self.flattened_inputs if isinstance(x, torch.Tensor)]
            )
            self.inputs_schema = None
        else:
            for input in self.flattened_inputs:
                if not isinstance(input, torch.Tensor):
                    raise ValueError(
                        "Inputs for tracing must only contain tensors. "
                        f"Got a {type(input)} instead."
                    )

    def forward(self, *args: torch.Tensor):
        with torch.no_grad(), patch_builtin_len():
            if self.inputs_schema is not None:
                # Rebuild the model's rich input format from the flat args.
                inputs_orig_format = self.inputs_schema(args)
            else:
                # No schema (allow_non_tensor dropped some inputs): only the
                # exact same objects captured at construction are acceptable.
                # Identity comparison (is) is intentional here.
                if len(args) != len(self.flattened_inputs) or any(
                    x is not y for x, y in zip(args, self.flattened_inputs)
                ):
                    raise ValueError(
                        "TracingAdapter does not contain valid inputs_schema."
                        " So it cannot generalize to other inputs and must be"
                        " traced with `.flattened_inputs`."
                    )
                inputs_orig_format = self.inputs

            outputs = self.inference_func(self.model, *inputs_orig_format)
            flattened_outputs, schema = flatten_to_tuple(outputs)

            flattened_output_tensors = tuple(
                [x for x in flattened_outputs if isinstance(x, torch.Tensor)]
            )
            if len(flattened_output_tensors) < len(flattened_outputs):
                if self.allow_non_tensor:
                    # Drop non-tensor outputs; the schema can no longer
                    # rebuild them, so it is cleared.
                    flattened_outputs = flattened_output_tensors
                    self.outputs_schema = None
                else:
                    raise ValueError(
                        "Model cannot be traced because some model outputs "
                        "cannot flatten to tensors."
                    )
            else:  # schema is valid
                # Cache the schema on first call; later calls must produce
                # the identical structure or tracing would be inconsistent.
                if self.outputs_schema is None:
                    self.outputs_schema = schema
                else:
                    assert self.outputs_schema == schema, (
                        "Model should always return outputs with the same "
                        "structure so it can be traced!"
                    )
            return flattened_outputs

    def _create_wrapper(self, traced_model):
        """
        Return a function that has an input/output interface the same as the
        original model, but it calls the given traced model under the hood.
        """

        def forward(*args):
            flattened_inputs, _ = flatten_to_tuple(args)
            flattened_outputs = traced_model(*flattened_inputs)
            return self.outputs_schema(flattened_outputs)

        return forward
|
Leffa/3rdparty/detectron2/export/shared.py
ADDED
|
@@ -0,0 +1,1039 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
import collections
|
| 4 |
+
import copy
|
| 5 |
+
import functools
|
| 6 |
+
import logging
|
| 7 |
+
import numpy as np
|
| 8 |
+
import os
|
| 9 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 10 |
+
from unittest import mock
|
| 11 |
+
import caffe2.python.utils as putils
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from caffe2.proto import caffe2_pb2
|
| 15 |
+
from caffe2.python import core, net_drawer, workspace
|
| 16 |
+
from torch.nn.functional import interpolate as interp
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# ==== torch/utils_toffee/cast.py =======================================
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def to_device(t, device_str):
    """
    This function is a replacement of .to(another_device) such that it allows the
    casting to be traced properly by explicitly calling the underlying copy ops.
    It also avoids introducing unncessary op when casting to the same device.
    """
    src = t.device
    dst = torch.device(device_str)
    if src == dst:
        # Same device: return the tensor untouched so no op shows up in traces.
        return t
    if (src.type, dst.type) == ("cuda", "cpu"):
        return torch.ops._caffe2.CopyGPUToCPU(t)
    if (src.type, dst.type) == ("cpu", "cuda"):
        return torch.ops._caffe2.CopyCPUToGPU(t)
    raise RuntimeError("Can't cast tensor from device {} to device {}".format(src, dst))
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# ==== torch/utils_toffee/interpolate.py =======================================
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# Note: borrowed from vision/detection/fair/detectron/detectron/modeling/detector.py
|
| 47 |
+
def BilinearInterpolation(tensor_in, up_scale):
    """
    Upsample ``tensor_in`` (NCHW) by integer factor ``up_scale`` using a fixed
    (non-learned) bilinear kernel applied through ``conv_transpose2d``.
    """
    assert up_scale % 2 == 0, "Scale should be even"

    def upsample_filt(size):
        factor = (size + 1) // 2
        # Kernel center; half-pixel offset when the kernel size is even.
        center = factor - 1 if size % 2 == 1 else factor - 0.5
        og = np.ogrid[:size, :size]
        return (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)

    scale = int(up_scale)
    kernel_size = scale * 2
    bil_filt = upsample_filt(kernel_size)

    num_channels = int(tensor_in.shape[1])
    # Each channel is upsampled independently: the bilinear filter sits on
    # the diagonal of the (out_ch, in_ch) kernel.
    kernel = np.zeros((num_channels, num_channels, kernel_size, kernel_size), dtype=np.float32)
    kernel[range(num_channels), range(num_channels), :, :] = bil_filt

    return F.conv_transpose2d(
        tensor_in,
        weight=to_device(torch.Tensor(kernel), tensor_in.device),
        bias=None,
        stride=scale,
        padding=scale // 2,
    )
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# NOTE: ONNX is incompatible with traced torch.nn.functional.interpolate if
# using dynamic `scale_factor` rather than static `size`. (T43166860)
# NOTE: Caffe2 Int8 conversion might not be able to quantize `size` properly.
def onnx_compatibale_interpolate(
    input, size=None, scale_factor=None, mode="nearest", align_corners=None
):
    """
    Drop-in replacement for ``F.interpolate`` used during ONNX export.

    For 4-D input resized by an explicit ``scale_factor``, dispatches to
    caffe2-exportable ops (``ResizeNearest``, or a fixed bilinear deconv via
    ``BilinearInterpolation``); everything else falls back to the regular
    ``F.interpolate``.

    NOTE: the misspelled name "compatibale" is kept intentionally — callers
    reference this exact symbol.
    """
    # NOTE: The input dimensions are interpreted in the form:
    # `mini-batch x channels x [optional depth] x [optional height] x width`.
    if size is None and scale_factor is not None:
        if input.dim() == 4:
            if isinstance(scale_factor, (int, float)):
                height_scale, width_scale = (scale_factor, scale_factor)
            else:
                assert isinstance(scale_factor, (tuple, list))
                assert len(scale_factor) == 2
                height_scale, width_scale = scale_factor

            assert not align_corners, "No matching C2 op for align_corners == True"
            if mode == "nearest":
                return torch.ops._caffe2.ResizeNearest(
                    input, order="NCHW", width_scale=width_scale, height_scale=height_scale
                )
            elif mode == "bilinear":
                logger.warning(
                    "Use F.conv_transpose2d for bilinear interpolate"
                    " because there's no such C2 op, this may cause significant"
                    " slowdown and the boundary pixels won't be as same as"
                    " using F.interpolate due to padding."
                )
                assert height_scale == width_scale
                return BilinearInterpolation(input, up_scale=height_scale)
        # Reached for non-4D input or an unsupported mode: the dynamic scale
        # may not convert to ONNX cleanly, hence the warning.
        logger.warning("Output size is not static, it might cause ONNX conversion issue")

    return interp(input, size, scale_factor, mode, align_corners)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def mock_torch_nn_functional_interpolate():
    """
    Decorator factory: while the wrapped function runs during ONNX export,
    ``torch.nn.functional.interpolate`` is swapped for the caffe2/ONNX
    compatible implementation; outside of export the function runs unchanged.
    """

    def decorator(func):
        @functools.wraps(func)
        def _mock_torch_nn_functional_interpolate(*args, **kwargs):
            if not torch.onnx.is_in_onnx_export():
                # Normal eager execution: nothing to patch.
                return func(*args, **kwargs)
            with mock.patch(
                "torch.nn.functional.interpolate", side_effect=onnx_compatibale_interpolate
            ):
                return func(*args, **kwargs)

        return _mock_torch_nn_functional_interpolate

    return decorator
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# ==== torch/utils_caffe2/ws_utils.py ==========================================
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class ScopedWS:
    """
    Context manager that temporarily switches to a named caffe2 workspace,
    optionally resetting it on entry and/or cleaning it up on exit, and
    restores the previously active workspace afterwards.
    """

    def __init__(self, ws_name, is_reset, is_cleanup=False):
        # ws_name: workspace to switch to; None means stay on the current one.
        # is_reset: reset the workspace right after entering.
        # is_cleanup: reset the workspace again just before exiting.
        self.ws_name = ws_name
        self.is_reset = is_reset
        self.is_cleanup = is_cleanup
        self.org_ws = ""

    def __enter__(self):
        # Remember the active workspace so __exit__ can switch back to it.
        self.org_ws = workspace.CurrentWorkspace()
        if self.ws_name is not None:
            # Second arg True: create the workspace if it does not exist.
            workspace.SwitchWorkspace(self.ws_name, True)
        if self.is_reset:
            workspace.ResetWorkspace()

        return workspace

    def __exit__(self, *args):
        if self.is_cleanup:
            workspace.ResetWorkspace()
        if self.ws_name is not None:
            workspace.SwitchWorkspace(self.org_ws)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def fetch_any_blob(name):
    """
    Fetch a blob from the current caffe2 workspace, falling back to the Int8
    fetcher when the regular one rejects the blob type. Any other failure is
    logged and None is returned (deliberate best-effort behavior).
    """
    blob = None
    try:
        blob = workspace.FetchBlob(name)
    except TypeError:
        # Int8 blobs make FetchBlob raise TypeError; use the dedicated fetcher.
        blob = workspace.FetchInt8Blob(name)
    except Exception as e:
        logger.error("Get blob {} error: {}".format(name, e))

    return blob
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
# ==== torch/utils_caffe2/protobuf.py ==========================================
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def get_pb_arg(pb, arg_name):
    """Return the first entry of ``pb.arg`` whose name is *arg_name*, or None."""
    return next((candidate for candidate in pb.arg if candidate.name == arg_name), None)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def get_pb_arg_valf(pb, arg_name, default_val):
    """Return the float field ``f`` of argument *arg_name*, or *default_val* if absent."""
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        return default_val
    return arg.f
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def get_pb_arg_floats(pb, arg_name, default_val):
    """Return the repeated-float (`floats`) field of argument *arg_name* as a
    list of Python floats, or *default_val* when the argument is absent."""
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        return default_val
    return [float(v) for v in arg.floats]
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def get_pb_arg_ints(pb, arg_name, default_val):
    """Return the repeated-int (`ints`) field of argument *arg_name* as a
    list of Python ints, or *default_val* when the argument is absent."""
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        return default_val
    return [int(v) for v in arg.ints]
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def get_pb_arg_vali(pb, arg_name, default_val):
    """Return the integer (`i`) field of argument *arg_name*, or
    *default_val* when the argument is absent."""
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        return default_val
    return arg.i
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def get_pb_arg_vals(pb, arg_name, default_val):
    """Return the string (`s`) field of argument *arg_name* (bytes in py3),
    or *default_val* when the argument is absent."""
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        return default_val
    return arg.s
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def get_pb_arg_valstrings(pb, arg_name, default_val):
    """Return the repeated-string (`strings`) field of argument *arg_name*
    as a list, or *default_val* when the argument is absent."""
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        return default_val
    return list(arg.strings)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def check_set_pb_arg(pb, arg_name, arg_attr, arg_value, allow_override=False):
    """
    Set argument *arg_name* on protobuf *pb* to *arg_value*, stored in the
    Argument field named *arg_attr* (e.g. "i", "f", "s").

    If the argument does not exist it is created. If it exists with a
    different value, it is overridden with a warning when
    allow_override=True, otherwise an AssertionError is raised.
    """
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        # Create the argument; MakeArgument picks the proper field from the
        # value's type, so verify it landed in the expected attribute.
        arg = putils.MakeArgument(arg_name, arg_value)
        assert hasattr(arg, arg_attr)
        pb.arg.extend([arg])
    if allow_override and getattr(arg, arg_attr) != arg_value:
        logger.warning(
            "Override argument {}: {} -> {}".format(arg_name, getattr(arg, arg_attr), arg_value)
        )
        setattr(arg, arg_attr, arg_value)
    else:
        # Either freshly created (already holds arg_value) or pre-existing,
        # in which case it must match exactly.
        assert arg is not None
        assert getattr(arg, arg_attr) == arg_value, "Existing value {}, new value {}".format(
            getattr(arg, arg_attr), arg_value
        )
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def _create_const_fill_op_from_numpy(name, tensor, device_option=None):
    """Create a GivenTensor*Fill operator that fills blob *name* with the
    contents of the NumPy array *tensor*."""
    assert type(tensor) == np.ndarray
    # dtype -> fill-op name; only these dtypes are supported, anything else
    # raises KeyError below.
    kTypeNameMapper = {
        np.dtype("float32"): "GivenTensorFill",
        np.dtype("int32"): "GivenTensorIntFill",
        np.dtype("int64"): "GivenTensorInt64Fill",
        np.dtype("uint8"): "GivenTensorStringFill",
    }

    args_dict = {}
    if tensor.dtype == np.dtype("uint8"):
        # uint8 arrays are serialized as a single string value.
        # NOTE(review): str(tensor.data) stringifies a memoryview object, not
        # the raw bytes — presumably the consumer relies on this exact form;
        # confirm before changing (compare tobytes() in the Int8 variant).
        args_dict.update({"values": [str(tensor.data)], "shape": [1]})
    else:
        args_dict.update({"values": tensor, "shape": tensor.shape})

    if device_option is not None:
        args_dict["device_option"] = device_option

    return core.CreateOperator(kTypeNameMapper[tensor.dtype], [], [name], **args_dict)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def _create_const_fill_op_from_c2_int8_tensor(name, int8_tensor):
    """Create an Int8Given*TensorFill operator that fills blob *name* with
    the quantized payload, scale and zero-point of *int8_tensor*."""
    assert type(int8_tensor) == workspace.Int8Tensor
    fill_op_by_dtype = {
        np.dtype("int32"): "Int8GivenIntTensorFill",
        np.dtype("uint8"): "Int8GivenTensorFill",
    }

    data = int8_tensor.data
    assert data.dtype in [np.dtype("uint8"), np.dtype("int32")]
    # uint8 payloads travel as raw bytes; int32 payloads stay numeric.
    serialized = data.tobytes() if data.dtype == np.dtype("uint8") else data

    return core.CreateOperator(
        fill_op_by_dtype[data.dtype],
        [],
        [name],
        values=serialized,
        shape=data.shape,
        Y_scale=int8_tensor.scale,
        Y_zero_point=int8_tensor.zero_point,
    )
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def create_const_fill_op(
    name: str,
    blob: Union[np.ndarray, workspace.Int8Tensor],
    device_option: Optional[caffe2_pb2.DeviceOption] = None,
) -> caffe2_pb2.OperatorDef:
    """
    Build the Caffe2 operator that materializes *blob* as a constant named
    *name*. NumPy arrays and Caffe2 Int8Tensors are the supported types.
    """

    blob_type = type(blob)
    assert blob_type in [
        np.ndarray,
        workspace.Int8Tensor,
    ], 'Error when creating const fill op for "{}", unsupported blob type: {}'.format(
        name, type(blob)
    )

    if blob_type == np.ndarray:
        return _create_const_fill_op_from_numpy(name, blob, device_option)
    elif blob_type == workspace.Int8Tensor:
        # Int8 tensors carry their own quantization params; a device option
        # is not meaningful here.
        assert device_option is None
        return _create_const_fill_op_from_c2_int8_tensor(name, blob)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def construct_init_net_from_params(
    params: Dict[str, Any], device_options: Optional[Dict[str, caffe2_pb2.DeviceOption]] = None
) -> caffe2_pb2.NetDef:
    """
    Construct the init_net from params dictionary

    Each param becomes a const-fill op (and an external_output) of the
    returned net; *device_options* optionally assigns a per-blob device.
    String-valued entries cannot be expressed as const-fill ops and are
    skipped with a warning.
    """
    init_net = caffe2_pb2.NetDef()
    device_options = device_options or {}
    for name, blob in params.items():
        if isinstance(blob, str):
            logger.warning(
                (
                    "Blob {} with type {} is not supported in generating init net,"
                    " skipped.".format(name, type(blob))
                )
            )
            continue
        init_net.op.extend(
            [create_const_fill_op(name, blob, device_option=device_options.get(name, None))]
        )
        # Expose the blob so consumers (and get_params_from_init_net) can
        # fetch it after running the net.
        init_net.external_output.append(name)
    return init_net
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def get_producer_map(ssa):
    """
    Return dict from versioned blob to (i, j),
    where i is index of producer op, j is the index of output of that op.
    """
    return {
        outp: (op_idx, out_idx)
        for op_idx, (_, outputs) in enumerate(ssa)
        for out_idx, outp in enumerate(outputs)
    }
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def get_consumer_map(ssa):
    """
    Return dict from versioned blob to list of (i, j),
    where i is index of consumer op, j is the index of input of that op.
    """
    consumers = collections.defaultdict(list)
    for op_idx, (inputs, _) in enumerate(ssa):
        for in_idx, inp in enumerate(inputs):
            consumers[inp].append((op_idx, in_idx))
    return consumers
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def get_params_from_init_net(
    init_net: caffe2_pb2.NetDef,
) -> [Dict[str, Any], Dict[str, caffe2_pb2.DeviceOption]]:
    """
    Take the output blobs from init_net by running it.
    Outputs:
        params: dict from blob name to numpy array
        device_options: dict from blob name to the device option of its creating op
    """
    # NOTE: this assumes that the params is determined by producer op with the
    # only exception be CopyGPUToCPU which is CUDA op but returns CPU tensor.
    def _get_device_option(producer_op):
        if producer_op.type == "CopyGPUToCPU":
            # Output is a CPU tensor even though the op itself runs on CUDA,
            # so report the default (CPU) device option.
            return caffe2_pb2.DeviceOption()
        else:
            return producer_op.device_option

    # Run inside a scratch workspace (reset on entry, cleaned on exit) so
    # the caller's current workspace is left untouched.
    with ScopedWS("__get_params_from_init_net__", is_reset=True, is_cleanup=True) as ws:
        ws.RunNetOnce(init_net)
        params = {b: fetch_any_blob(b) for b in init_net.external_output}
    ssa, versions = core.get_ssa(init_net)
    producer_map = get_producer_map(ssa)
    device_options = {
        b: _get_device_option(init_net.op[producer_map[(b, versions[b])][0]])
        for b in init_net.external_output
    }
    return params, device_options
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def _updater_raise(op, input_types, output_types):
|
| 375 |
+
raise RuntimeError(
|
| 376 |
+
"Failed to apply updater for op {} given input_types {} and"
|
| 377 |
+
" output_types {}".format(op, input_types, output_types)
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def _generic_status_identifier(
    predict_net: caffe2_pb2.NetDef,
    status_updater: Callable,
    known_status: Dict[Tuple[str, int], Any],
) -> Dict[Tuple[str, int], Any]:
    """
    Statically infer the status of each blob, the status can be such as device type
    (CPU/GPU), layout (NCHW/NHWC), data type (float32/int8), etc. "Blob" here
    is versioned blob (Tuple[str, int]) in the format compatible with ssa.
    Inputs:
        predict_net: the caffe2 network
        status_updater: a callable, given an op and the status of its input/output,
            it returns the updated status of input/output. `None` is used for
            representing unknown status.
        known_status: a dict containing known status, used as initialization.
    Outputs:
        A dict mapping from versioned blob to its status
    """
    ssa, versions = core.get_ssa(predict_net)
    versioned_ext_input = [(b, 0) for b in predict_net.external_input]
    versioned_ext_output = [(b, versions[b]) for b in predict_net.external_output]
    all_versioned_blobs = set().union(*[set(x[0] + x[1]) for x in ssa])

    # Every seeded status must refer to a blob that actually appears in the
    # net, and seeds must be concrete (None means "unknown").
    allowed_vbs = all_versioned_blobs.union(versioned_ext_input).union(versioned_ext_output)
    assert all(k in allowed_vbs for k in known_status)
    assert all(v is not None for v in known_status.values())
    # Deep-copied so the caller's dict is never mutated.
    _known_status = copy.deepcopy(known_status)

    def _check_and_update(key, value):
        # Record `value` for `key`; raise if it contradicts a prior inference.
        assert value is not None
        if key in _known_status:
            if not _known_status[key] == value:
                raise RuntimeError(
                    "Confilict status for {}, existing status {}, new status {}".format(
                        key, _known_status[key], value
                    )
                )
        _known_status[key] = value

    def _update_i(op, ssa_i):
        # Run the updater on one op and merge any newly inferred statuses.
        versioned_inputs = ssa_i[0]
        versioned_outputs = ssa_i[1]

        inputs_status = [_known_status.get(b, None) for b in versioned_inputs]
        outputs_status = [_known_status.get(b, None) for b in versioned_outputs]

        new_inputs_status, new_outputs_status = status_updater(op, inputs_status, outputs_status)

        for versioned_blob, status in zip(
            versioned_inputs + versioned_outputs, new_inputs_status + new_outputs_status
        ):
            if status is not None:
                _check_and_update(versioned_blob, status)

    # One forward pass propagates status downstream, one backward pass
    # propagates it upstream.
    for op, ssa_i in zip(predict_net.op, ssa):
        _update_i(op, ssa_i)
    for op, ssa_i in zip(reversed(predict_net.op), reversed(ssa)):
        _update_i(op, ssa_i)

    # NOTE: This strictly checks all the blob from predict_net must be assigned
    # a known status. However sometimes it's impossible (eg. having deadend op),
    # we may relax this constraint if
    for k in all_versioned_blobs:
        if k not in _known_status:
            raise NotImplementedError(
                "Can not infer the status for {}. Currently only support the case where"
                " a single forward and backward pass can identify status for all blobs.".format(k)
            )

    return _known_status
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def infer_device_type(
    predict_net: caffe2_pb2.NetDef,
    known_status: Dict[Tuple[str, int], Any],
    device_name_style: str = "caffe2",
) -> Dict[Tuple[str, int], str]:
    """Return the device type ("cpu" or "gpu"/"cuda") of each (versioned) blob"""

    assert device_name_style in ["caffe2", "pytorch"]
    _CPU_STR = "cpu"
    # caffe2 spells the device "gpu"; pytorch spells it "cuda".
    _GPU_STR = "gpu" if device_name_style == "caffe2" else "cuda"

    def _copy_cpu_to_gpu_updater(op, input_types, output_types):
        # CopyCPUToGPU pins its input to CPU and its output to GPU.
        if input_types[0] == _GPU_STR or output_types[0] == _CPU_STR:
            _updater_raise(op, input_types, output_types)
        return ([_CPU_STR], [_GPU_STR])

    def _copy_gpu_to_cpu_updater(op, input_types, output_types):
        # CopyGPUToCPU pins its input to GPU and its output to CPU.
        if input_types[0] == _CPU_STR or output_types[0] == _GPU_STR:
            _updater_raise(op, input_types, output_types)
        return ([_GPU_STR], [_CPU_STR])

    def _other_ops_updater(op, input_types, output_types):
        # Any other op keeps all inputs/outputs on one shared device: take
        # the first known type and require the rest to agree.
        non_none_types = [x for x in input_types + output_types if x is not None]
        if len(non_none_types) > 0:
            the_type = non_none_types[0]
            if not all(x == the_type for x in non_none_types):
                _updater_raise(op, input_types, output_types)
        else:
            the_type = None
        return ([the_type for _ in op.input], [the_type for _ in op.output])

    def _device_updater(op, *args, **kwargs):
        # Dispatch per op type; default is the same-device rule.
        return {
            "CopyCPUToGPU": _copy_cpu_to_gpu_updater,
            "CopyGPUToCPU": _copy_gpu_to_cpu_updater,
        }.get(op.type, _other_ops_updater)(op, *args, **kwargs)

    return _generic_status_identifier(predict_net, _device_updater, known_status)
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
# ==== torch/utils_caffe2/vis.py ===============================================
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
def _modify_blob_names(ops, blob_rename_f):
|
| 497 |
+
ret = []
|
| 498 |
+
|
| 499 |
+
def _replace_list(blob_list, replaced_list):
|
| 500 |
+
del blob_list[:]
|
| 501 |
+
blob_list.extend(replaced_list)
|
| 502 |
+
|
| 503 |
+
for x in ops:
|
| 504 |
+
cur = copy.deepcopy(x)
|
| 505 |
+
_replace_list(cur.input, list(map(blob_rename_f, cur.input)))
|
| 506 |
+
_replace_list(cur.output, list(map(blob_rename_f, cur.output)))
|
| 507 |
+
ret.append(cur)
|
| 508 |
+
|
| 509 |
+
return ret
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def _rename_blob(name, blob_sizes, blob_ranges):
|
| 513 |
+
def _list_to_str(bsize):
|
| 514 |
+
ret = ", ".join([str(x) for x in bsize])
|
| 515 |
+
ret = "[" + ret + "]"
|
| 516 |
+
return ret
|
| 517 |
+
|
| 518 |
+
ret = name
|
| 519 |
+
if blob_sizes is not None and name in blob_sizes:
|
| 520 |
+
ret += "\n" + _list_to_str(blob_sizes[name])
|
| 521 |
+
if blob_ranges is not None and name in blob_ranges:
|
| 522 |
+
ret += "\n" + _list_to_str(blob_ranges[name])
|
| 523 |
+
|
| 524 |
+
return ret
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
# graph_name could not contain word 'graph'
def save_graph(net, file_name, graph_name="net", op_only=True, blob_sizes=None, blob_ranges=None):
    """Render *net* to an image file, annotating each blob with its size
    and/or value range when provided."""
    renamer = functools.partial(_rename_blob, blob_sizes=blob_sizes, blob_ranges=blob_ranges)
    return save_graph_base(net, file_name, graph_name, op_only, renamer)
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def save_graph_base(net, file_name, graph_name="net", op_only=True, blob_rename_func=None):
    """
    Draw *net* with pydot and write it to *file_name* (.png/.pdf/.svg).

    Args:
        net: caffe2 NetDef to draw.
        file_name: output path; the extension selects the image format.
        graph_name: pydot graph name (must not contain the word 'graph').
        op_only: if True, draw the minimal op-only graph.
        blob_rename_func: optional callable applied to every blob name
            before drawing (ops are deep-copied, the net is not modified).
    Returns:
        The pydot graph object (also returned when writing the file failed).
    """
    ops = net.op
    if blob_rename_func is not None:
        ops = _modify_blob_names(ops, blob_rename_func)
    if not op_only:
        graph = net_drawer.GetPydotGraph(ops, graph_name, rankdir="TB")
    else:
        graph = net_drawer.GetPydotGraphMinimal(
            ops, graph_name, rankdir="TB", minimal_dependency=True
        )

    try:
        par_dir = os.path.dirname(file_name)
        # `par_dir` is "" for a bare filename; os.makedirs("") raises, which
        # previously aborted the write inside this try. Guard it, and use
        # exist_ok to avoid the exists()/makedirs() race.
        if par_dir:
            os.makedirs(par_dir, exist_ok=True)

        # Renamed from `format` to avoid shadowing the builtin.
        ext = os.path.splitext(os.path.basename(file_name))[-1]
        if ext == ".png":
            graph.write_png(file_name)
        elif ext == ".pdf":
            graph.write_pdf(file_name)
        elif ext == ".svg":
            graph.write_svg(file_name)
        else:
            print("Incorrect format {}".format(ext))
    except Exception as e:
        # Best-effort: drawing failures shouldn't crash export.
        print("Error when writing graph to image {}".format(e))

    return graph
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
# ==== torch/utils_toffee/aten_to_caffe2.py ====================================
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
def group_norm_replace_aten_with_caffe2(predict_net: caffe2_pb2.NetDef):
    """
    For ONNX exported model, GroupNorm will be represented as ATen op,
    this can be a drop in replacement from ATen to GroupNorm.

    Modifies *predict_net* in place: rewrites each ATen/group_norm op into a
    native GroupNorm op, translating `num_groups` to the `group` argument
    and dropping the torch-only `operator`/`cudnn_enabled` arguments.
    """
    count = 0
    for op in predict_net.op:
        if op.type == "ATen":
            op_name = get_pb_arg_vals(op, "operator", None)  # return byte in py3
            if op_name and op_name.decode() == "group_norm":
                op.arg.remove(get_pb_arg(op, "operator"))

                # torch-specific flag, meaningless for caffe2's GroupNorm.
                if get_pb_arg_vali(op, "cudnn_enabled", None):
                    op.arg.remove(get_pb_arg(op, "cudnn_enabled"))

                num_groups = get_pb_arg_vali(op, "num_groups", None)
                if num_groups is not None:
                    op.arg.remove(get_pb_arg(op, "num_groups"))
                    check_set_pb_arg(op, "group", "i", num_groups)

                op.type = "GroupNorm"
                count += 1
    # Was `count > 1`, which silently skipped logging when exactly one op
    # was replaced.
    if count > 0:
        logger.info("Replaced {} ATen operator to GroupNormOp".format(count))
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
# ==== torch/utils_toffee/alias.py =============================================
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
def alias(x, name, is_backward=False):
    """Tag tensor *x* with blob name *name* during ONNX export via the
    AliasWithName placeholder op; outside of export this is a no-op."""
    if torch.onnx.is_in_onnx_export():
        assert isinstance(x, torch.Tensor)
        return torch.ops._caffe2.AliasWithName(x, name, is_backward=is_backward)
    return x
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
def fuse_alias_placeholder(predict_net, init_net):
    """Remove AliasWithName placeholder and rename the input/output of it.

    Each AliasWithName op carries the intended blob name in its "name"
    argument; this renames the aliased blob accordingly in predict_net (and
    init_net when needed), then strips the now-redundant placeholder ops.
    Both nets are modified in place.
    """
    # First we finish all the re-naming
    for i, op in enumerate(predict_net.op):
        if op.type == "AliasWithName":
            assert len(op.input) == 1
            assert len(op.output) == 1
            name = get_pb_arg_vals(op, "name", None).decode()
            is_backward = bool(get_pb_arg_vali(op, "is_backward", 0))
            # Backward aliases rename through the producer so every consumer
            # of the blob picks up the new name.
            rename_op_input(predict_net, init_net, i, 0, name, from_producer=is_backward)
            rename_op_output(predict_net, i, 0, name)

    # Remove AliasWithName, should be very safe since it's a non-op
    new_ops = []
    for op in predict_net.op:
        if op.type != "AliasWithName":
            new_ops.append(op)
        else:
            # safety check: after renaming, the alias must be a pure identity
            # whose blob already carries the requested name.
            assert op.input == op.output
            assert op.input[0] == op.arg[0].s.decode()
    del predict_net.op[:]
    predict_net.op.extend(new_ops)
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
# ==== torch/utils_caffe2/graph_transform.py ===================================
|
| 630 |
+
|
| 631 |
+
|
| 632 |
+
class IllegalGraphTransformError(ValueError):
    """Raised when a requested graph transform cannot be carried out."""
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
def _rename_versioned_blob_in_proto(
    proto: caffe2_pb2.NetDef,
    old_name: str,
    new_name: str,
    version: int,
    ssa: List[Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]],
    start_versions: Dict[str, int],
    end_versions: Dict[str, int],
):
    """In given proto, rename all blobs with matched version"""
    target = (old_name, version)

    # Operator inputs/outputs whose SSA entry matches the target version.
    for op, (versioned_inputs, versioned_outputs) in zip(proto.op, ssa):
        for i, vblob in enumerate(versioned_inputs):
            if vblob == target:
                op.input[i] = new_name
        for i, vblob in enumerate(versioned_outputs):
            if vblob == target:
                op.output[i] = new_name

    # external_input refers to the blob's version at graph entry.
    if start_versions.get(old_name, 0) == version:
        for i, blob in enumerate(proto.external_input):
            if blob == old_name:
                proto.external_input[i] = new_name

    # external_output refers to the blob's version at graph exit.
    if end_versions.get(old_name, 0) == version:
        for i, blob in enumerate(proto.external_output):
            if blob == old_name:
                proto.external_output[i] = new_name
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
def rename_op_input(
    predict_net: caffe2_pb2.NetDef,
    init_net: caffe2_pb2.NetDef,
    op_id: int,
    input_id: int,
    new_name: str,
    from_producer: bool = False,
):
    """
    Rename the op_id-th operator in predict_net, change it's input_id-th input's
    name to the new_name. It also does automatic re-route and change
    external_input and init_net if necessary.
    - It requires the input is only consumed by this op.
    - This function modifies predict_net and init_net in-place.
    - When from_producer is enable, this also updates other operators that consumes
        the same input. Be cautious because may trigger unintended behavior.
    """
    assert isinstance(predict_net, caffe2_pb2.NetDef)
    assert isinstance(init_net, caffe2_pb2.NetDef)

    # predict_net's SSA is seeded with init_net's final versions so blob
    # versions are consistent across the two nets.
    init_net_ssa, init_net_versions = core.get_ssa(init_net)
    predict_net_ssa, predict_net_versions = core.get_ssa(
        predict_net, copy.deepcopy(init_net_versions)
    )

    versioned_inputs, versioned_outputs = predict_net_ssa[op_id]
    old_name, version = versioned_inputs[input_id]

    if from_producer:
        # Delegate to rename_op_output on the producing op so all consumers
        # of this blob (not just op_id) see the new name.
        producer_map = get_producer_map(predict_net_ssa)
        if not (old_name, version) in producer_map:
            raise NotImplementedError(
                "Can't find producer, the input {} is probably from"
                " init_net, this is not supported yet.".format(old_name)
            )
        producer = producer_map[(old_name, version)]
        rename_op_output(predict_net, producer[0], producer[1], new_name)
        return

    def contain_targets(op_ssa):
        # True when this op consumes the (old_name, version) blob.
        return (old_name, version) in op_ssa[0]

    # Renaming only this op's input would desynchronize other consumers of
    # the same blob, so refuse in that case.
    is_consumer = [contain_targets(op_ssa) for op_ssa in predict_net_ssa]
    if sum(is_consumer) > 1:
        raise IllegalGraphTransformError(
            (
                "Input '{}' of operator(#{}) are consumed by other ops, please use"
                + " rename_op_output on the producer instead. Offending op: \n{}"
            ).format(old_name, op_id, predict_net.op[op_id])
        )

    # update init_net
    _rename_versioned_blob_in_proto(
        init_net, old_name, new_name, version, init_net_ssa, {}, init_net_versions
    )
    # update predict_net
    _rename_versioned_blob_in_proto(
        predict_net,
        old_name,
        new_name,
        version,
        predict_net_ssa,
        init_net_versions,
        predict_net_versions,
    )
|
| 732 |
+
|
| 733 |
+
|
| 734 |
+
def rename_op_output(predict_net: caffe2_pb2.NetDef, op_id: int, output_id: int, new_name: str):
    """
    Rename the output_id-th output of the op_id-th operator in predict_net
    to *new_name*, re-routing every consumer of that blob and updating
    external_output as needed.
    - Multiple consumers of the output are allowed.
    - predict_net is modified in place; init_net is not involved.
    """
    assert isinstance(predict_net, caffe2_pb2.NetDef)

    ssa, blob_versions = core.get_ssa(predict_net)
    old_name, version = ssa[op_id][1][output_id]

    _rename_versioned_blob_in_proto(
        predict_net, old_name, new_name, version, ssa, {}, blob_versions
    )
|
| 753 |
+
|
| 754 |
+
|
| 755 |
+
def get_sub_graph_external_input_output(
    predict_net: caffe2_pb2.NetDef, sub_graph_op_indices: List[int]
) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
    """
    Return the list of external input/output of sub-graph,
    each element is tuple of the name and corresponding version in predict_net.

    external input/output is defined the same way as caffe2 NetDef.
    """
    ssa, versions = core.get_ssa(predict_net)

    # Collect every (deduplicated) input and every output of the sub-graph.
    all_inputs = []
    all_outputs = []
    for op_id in sub_graph_op_indices:
        all_inputs += [inp for inp in ssa[op_id][0] if inp not in all_inputs]
        all_outputs += list(ssa[op_id][1])  # ssa output won't repeat

    # for versioned blobs, external inputs are just those blob in all_inputs
    # but not in all_outputs
    ext_inputs = [inp for inp in all_inputs if inp not in all_outputs]

    # external outputs are essentially outputs of this subgraph that are used
    # outside of this sub-graph (including predict_net.external_output)
    all_other_inputs = sum(
        (ssa[i][0] for i in range(len(ssa)) if i not in sub_graph_op_indices),
        [(outp, versions[outp]) for outp in predict_net.external_output],
    )
    ext_outputs = [outp for outp in all_outputs if outp in set(all_other_inputs)]

    return ext_inputs, ext_outputs
|
| 785 |
+
|
| 786 |
+
|
| 787 |
+
class DiGraph:
    """A DAG representation of caffe2 graph, each vertice is a versioned blob."""

    def __init__(self):
        self.vertices = set()
        self.graph = collections.defaultdict(list)

    def add_edge(self, u, v):
        """Add directed edge u -> v, registering both endpoints."""
        self.graph[u].append(v)
        self.vertices.update((u, v))

    # grab from https://www.geeksforgeeks.org/find-paths-given-source-destination/
    def get_all_paths(self, s, d):
        """Return every simple path from s to d (inclusive) via DFS."""
        visited = dict.fromkeys(self.vertices, False)
        current_path = []
        found_paths = []

        def _dfs(node):
            visited[node] = True
            current_path.append(node)
            if node == d:
                found_paths.append(copy.deepcopy(current_path))
            else:
                for neighbor in self.graph[node]:
                    if not visited[neighbor]:
                        _dfs(neighbor)
            current_path.pop()
            visited[node] = False

        _dfs(s)
        return found_paths

    @staticmethod
    def from_ssa(ssa):
        """Build a DiGraph with an edge from each op input to each output."""
        dag = DiGraph()
        for inputs, outputs in ssa:
            for inp in inputs:
                for outp in outputs:
                    dag.add_edge(inp, outp)
        return dag
|
| 828 |
+
|
| 829 |
+
|
| 830 |
+
def _get_dependency_chain(ssa, versioned_target, versioned_source):
    """
    Return the index list of relevant operator to produce target blob from source blob,
    if there's no dependency, return empty list.
    """

    # finding all paths between nodes can be O(N!), thus we can only search
    # in the subgraph using the op starting from the first consumer of source blob
    # to the producer of the target blob.
    consumer_map = get_consumer_map(ssa)
    producer_map = get_producer_map(ssa)
    # Pad the window by 15 ops on each side as a heuristic margin; clamp at 0
    # so the padding can't turn into a negative index, which would silently
    # slice from the END of ssa and search the wrong window.
    start_op = max(0, min(x[0] for x in consumer_map[versioned_source]) - 15)
    end_op = (
        producer_map[versioned_target][0] + 15 if versioned_target in producer_map else start_op
    )
    sub_graph_ssa = ssa[start_op : end_op + 1]
    if len(sub_graph_ssa) > 30:
        # Fixed typo "bebetween" in the original warning.
        logger.warning(
            "Subgraph between {} and {} is large (from op#{} to op#{}), it"
            " might take non-trival time to find all paths between them.".format(
                versioned_source, versioned_target, start_op, end_op
            )
        )

    dag = DiGraph.from_ssa(sub_graph_ssa)
    paths = dag.get_all_paths(versioned_source, versioned_target)  # include two ends
    # Map each blob on a path back to the index of its producer op in the
    # FULL ssa (the source blob itself, path[0], has no producer here).
    ops_in_paths = [[producer_map[blob][0] for blob in path[1:]] for path in paths]
    return sorted(set().union(*[set(ops) for ops in ops_in_paths]))
|
| 858 |
+
|
| 859 |
+
|
| 860 |
+
def identify_reshape_sub_graph(predict_net: caffe2_pb2.NetDef) -> List[List[int]]:
    """
    Idenfity the reshape sub-graph in a protobuf.
    The reshape sub-graph is defined as matching the following pattern:

    (input_blob) -> Op_1 -> ... -> Op_N -> (new_shape) -─┐
        └-------------------------------------------> Reshape -> (output_blob)

    Return:
        List of sub-graphs, each sub-graph is represented as a list of indices
        of the relavent ops, [Op_1, Op_2, ..., Op_N, Reshape]
    """

    ssa, _ = core.get_ssa(predict_net)

    ret = []
    for i, op in enumerate(predict_net.op):
        if op.type == "Reshape":
            # Reshape takes (data, new_shape).
            assert len(op.input) == 2
            input_ssa = ssa[i][0]
            data_source = input_ssa[0]
            shape_source = input_ssa[1]
            # Ops on the chain that derives the shape input from the data
            # input form the sub-graph; the Reshape op itself is last.
            op_indices = _get_dependency_chain(ssa, shape_source, data_source)
            ret.append(op_indices + [i])
    return ret
|
| 885 |
+
|
| 886 |
+
|
| 887 |
+
def remove_reshape_for_fc(predict_net, params):
    """
    In PyTorch nn.Linear has to take 2D tensor, this often leads to reshape
    a 4D tensor to 2D by calling .view(). However this (dynamic) reshaping
    doesn't work well with ONNX and Int8 tools, and cause using extra
    ops (eg. ExpandDims) that might not be available on mobile.
    Luckily Caffe2 supports 4D tensor for FC, so we can remove those reshape
    after exporting ONNX model.

    Returns a (new predict_net, params) pair; `params` is mutated in place
    (removed weights are deleted from it).
    """
    from caffe2.python import core

    # find all reshape sub-graph that can be removed, which is now all Reshape
    # sub-graph whose output is only consumed by FC.
    # TODO: to make it safer, we may need the actual value to better determine
    # if a Reshape before FC is removable.
    reshape_sub_graphs = identify_reshape_sub_graph(predict_net)
    sub_graphs_to_remove = []
    for reshape_sub_graph in reshape_sub_graphs:
        reshape_op_id = reshape_sub_graph[-1]
        assert predict_net.op[reshape_op_id].type == "Reshape"
        ssa, _ = core.get_ssa(predict_net)
        reshape_output = ssa[reshape_op_id][1][0]
        # All ops that consume the Reshape's output (versioned blob).
        consumers = [i for i in range(len(ssa)) if reshape_output in ssa[i][0]]
        if all(predict_net.op[consumer].type == "FC" for consumer in consumers):
            # safety check if the sub-graph is isolated, for this reshape sub-graph,
            # it means it has one non-param external input and one external output.
            ext_inputs, ext_outputs = get_sub_graph_external_input_output(
                predict_net, reshape_sub_graph
            )
            # version 0 inputs are parameters (weights); others are activations
            non_params_ext_inputs = [inp for inp in ext_inputs if inp[1] != 0]
            if len(non_params_ext_inputs) == 1 and len(ext_outputs) == 1:
                sub_graphs_to_remove.append(reshape_sub_graph)

    # perform removing subgraph by:
    # 1: rename the Reshape's output to its input, then the graph can be
    # seen as in-place identity, meaning whose external input/output are the same.
    # 2: simply remove those ops.
    remove_op_ids = []
    params_to_remove = []
    for sub_graph in sub_graphs_to_remove:
        logger.info(
            "Remove Reshape sub-graph:\n{}".format(
                "".join(["(#{:>4})\n{}".format(i, predict_net.op[i]) for i in sub_graph])
            )
        )
        reshape_op_id = sub_graph[-1]
        new_reshap_output = predict_net.op[reshape_op_id].input[0]
        rename_op_output(predict_net, reshape_op_id, 0, new_reshap_output)
        # After renaming, the sub-graph must look like an in-place identity:
        # same blob in and out, with the version bumped by exactly one.
        ext_inputs, ext_outputs = get_sub_graph_external_input_output(predict_net, sub_graph)
        non_params_ext_inputs = [inp for inp in ext_inputs if inp[1] != 0]
        params_ext_inputs = [inp for inp in ext_inputs if inp[1] == 0]
        assert len(non_params_ext_inputs) == 1 and len(ext_outputs) == 1
        assert ext_outputs[0][0] == non_params_ext_inputs[0][0]
        assert ext_outputs[0][1] == non_params_ext_inputs[0][1] + 1
        remove_op_ids.extend(sub_graph)
        params_to_remove.extend(params_ext_inputs)

    # Work on a copy so callers holding the old proto are unaffected.
    predict_net = copy.deepcopy(predict_net)
    new_ops = [op for i, op in enumerate(predict_net.op) if i not in remove_op_ids]
    del predict_net.op[:]
    predict_net.op.extend(new_ops)
    for versioned_params in params_to_remove:
        name = versioned_params[0]
        logger.info("Remove params: {} from init_net and predict_net.external_input".format(name))
        del params[name]
        predict_net.external_input.remove(name)

    return predict_net, params
|
| 955 |
+
|
| 956 |
+
|
| 957 |
+
def fuse_copy_between_cpu_and_gpu(predict_net: caffe2_pb2.NetDef):
    """
    In-place fuse extra copy ops between cpu/gpu for the following case:
        a -CopyAToB-> b -CopyBToA> c1 -NextOp1-> d1
                        -CopyBToA> c2 -NextOp2-> d2
    The fused network will look like:
        a -NextOp1-> d1
          -NextOp2-> d2
    """

    _COPY_OPS = ["CopyCPUToGPU", "CopyGPUToCPU"]

    def _fuse_once(predict_net):
        # Performs at most one fusion, then returns. SSA / consumer maps are
        # recomputed each call because the op list changes after a fusion.
        ssa, blob_versions = core.get_ssa(predict_net)
        consumer_map = get_consumer_map(ssa)
        versioned_external_output = [
            (name, blob_versions[name]) for name in predict_net.external_output
        ]

        for op_id, op in enumerate(predict_net.op):
            if op.type in _COPY_OPS:
                fw_copy_versioned_output = ssa[op_id][1][0]
                consumer_ids = [x[0] for x in consumer_map[fw_copy_versioned_output]]
                # The opposite-direction copy op (CPUToGPU <-> GPUToCPU).
                reverse_op_type = _COPY_OPS[1 - _COPY_OPS.index(op.type)]

                # Fusable only when every consumer of the forward copy is a
                # reverse copy, and neither the intermediate blob nor the
                # reverse copies' outputs are external outputs of the net.
                is_fusable = (
                    len(consumer_ids) > 0
                    and fw_copy_versioned_output not in versioned_external_output
                    and all(
                        predict_net.op[_op_id].type == reverse_op_type
                        and ssa[_op_id][1][0] not in versioned_external_output
                        for _op_id in consumer_ids
                    )
                )

                if is_fusable:
                    for rv_copy_op_id in consumer_ids:
                        # making each NextOp uses "a" directly and removing Copy ops
                        rs_copy_versioned_output = ssa[rv_copy_op_id][1][0]
                        next_op_id, inp_id = consumer_map[rs_copy_versioned_output][0]
                        predict_net.op[next_op_id].input[inp_id] = op.input[0]
                    # remove CopyOps
                    new_ops = [
                        op
                        for i, op in enumerate(predict_net.op)
                        if i != op_id and i not in consumer_ids
                    ]
                    del predict_net.op[:]
                    predict_net.op.extend(new_ops)
                    return True

        return False

    # _fuse_once returns False if nothing can be fused
    while _fuse_once(predict_net):
        pass
|
| 1013 |
+
|
| 1014 |
+
|
| 1015 |
+
def remove_dead_end_ops(net_def: caffe2_pb2.NetDef):
    """remove ops if its output is not used or not in external_output"""
    ssa, versions = core.get_ssa(net_def)
    versioned_external_output = [(name, versions[name]) for name in net_def.external_output]
    consumer_map = get_consumer_map(ssa)
    removed_op_ids = set()

    def _is_dead_end(versioned_blob):
        # A blob is a dead end unless it's an external output, or it has
        # consumers none of which have been marked removed so far.
        # NOTE(review): `all(... not in removed_op_ids ...)` treats a blob as
        # dead as soon as ANY consumer is removed — verify against upstream
        # intent before changing.
        return not (
            versioned_blob in versioned_external_output
            or (
                len(consumer_map[versioned_blob]) > 0
                and all(x[0] not in removed_op_ids for x in consumer_map[versioned_blob])
            )
        )

    # Scan ops in reverse so removals propagate backwards through chains
    # of dead-end producers.
    for i, ssa_i in reversed(list(enumerate(ssa))):
        versioned_outputs = ssa_i[1]
        if all(_is_dead_end(outp) for outp in versioned_outputs):
            removed_op_ids.add(i)

    # simply removing those deadend ops should have no effect to external_output
    new_ops = [op for i, op in enumerate(net_def.op) if i not in removed_op_ids]
    del net_def.op[:]
    net_def.op.extend(new_ops)
|
Leffa/3rdparty/detectron2/export/torchscript.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from detectron2.utils.file_io import PathManager
|
| 7 |
+
|
| 8 |
+
from .torchscript_patch import freeze_training_mode, patch_instances
|
| 9 |
+
|
| 10 |
+
__all__ = ["scripting_with_instances", "dump_torchscript_IR"]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def scripting_with_instances(model, fields):
    """
    Run :func:`torch.jit.script` on a model that uses the :class:`Instances` class.

    Attributes of :class:`Instances` are added dynamically in eager mode, which
    the scripting compiler cannot handle out of the box. This helper makes such
    a model scriptable by:

    1. Generating a scriptable replacement class whose attributes are statically
       declared through the ``fields`` argument.
    2. Registering that class so the compiler uses it wherever ``Instances``
       appears.

    The patching is reverted on exit, so a different model with different
    fields can be scripted afterwards.

    Example:
        Assume that ``Instances`` in the model consist of two attributes named
        ``proposal_boxes`` and ``objectness_logits`` with type :class:`Boxes` and
        :class:`Tensor` respectively during inference. You can call this function like:
        ::
            fields = {"proposal_boxes": Boxes, "objectness_logits": torch.Tensor}
            torchscipt_model = scripting_with_instances(model, fields)

    Note:
        It only support models in evaluation mode.

    Args:
        model (nn.Module): The input model to be exported by scripting.
        fields (Dict[str, type]): Attribute names and corresponding type that
            ``Instances`` will use in the model. Note that all attributes used in
            ``Instances`` need to be added, regardless of whether they are
            inputs/outputs of the model. Data type not defined in detectron2 is
            not supported for now.

    Returns:
        torch.jit.ScriptModule: the model in torchscript format
    """
    assert not model.training, (
        "Currently we only support exporting models in evaluation mode to torchscript"
    )

    with freeze_training_mode(model), patch_instances(fields):
        return torch.jit.script(model)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# alias for old name
|
| 60 |
+
export_torchscript_with_instances = scripting_with_instances
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def dump_torchscript_IR(model, dir):
    """
    Dump IR of a TracedModule/ScriptModule/Function in various format (code, graph,
    inlined graph). Useful for debugging.

    Writes up to four files into ``dir``: ``model_ts_code.txt`` (pretty-printed
    code), ``model_ts_IR.txt`` (per-module graph IR), ``model_ts_IR_inlined.txt``
    (fully inlined graph), and ``model.txt`` (pytorch-style module structure).

    Args:
        model (TracedModule/ScriptModule/ScriptFunction): traced or scripted module
        dir (str): output directory to dump files.
    """
    dir = os.path.expanduser(dir)
    PathManager.mkdirs(dir)

    def _get_script_mod(mod):
        # TracedModule wraps the actual ScriptModule; unwrap to reach its
        # private C++ handle.
        if isinstance(mod, torch.jit.TracedModule):
            return mod._actual_script_module
        return mod

    # Dump pretty-printed code: https://pytorch.org/docs/stable/jit.html#inspecting-code
    with PathManager.open(os.path.join(dir, "model_ts_code.txt"), "w") as f:

        def get_code(mod):
            # Try a few ways to get code using private attributes.
            try:
                # This contains more information than just `mod.code`
                return _get_script_mod(mod)._c.code
            except AttributeError:
                pass
            try:
                return mod.code
            except AttributeError:
                return None

        def dump_code(prefix, mod):
            code = get_code(mod)
            name = prefix or "root model"
            if code is None:
                # Fixed grammar of this diagnostic ("Could not found" -> "find").
                f.write(f"Could not find code for {name} (type={mod.original_name})\n")
                f.write("\n")
            else:
                f.write(f"\nCode for {name}, type={mod.original_name}:\n")
                f.write(code)
                f.write("\n")
                f.write("-" * 80)

            # Recurse so every submodule's code is dumped as well.
            for name, m in mod.named_children():
                dump_code(prefix + "." + name, m)

        if isinstance(model, torch.jit.ScriptFunction):
            f.write(get_code(model))
        else:
            dump_code("", model)

    def _get_graph(model):
        try:
            # Recursively dump IR of all modules
            return _get_script_mod(model)._c.dump_to_str(True, False, False)
        except AttributeError:
            # ScriptFunction has no _c handle; fall back to its graph.
            return model.graph.str()

    with PathManager.open(os.path.join(dir, "model_ts_IR.txt"), "w") as f:
        f.write(_get_graph(model))

    # Dump IR of the entire graph (all submodules inlined)
    with PathManager.open(os.path.join(dir, "model_ts_IR_inlined.txt"), "w") as f:
        f.write(str(model.inlined_graph))

    if not isinstance(model, torch.jit.ScriptFunction):
        # Dump the model structure in pytorch style
        with PathManager.open(os.path.join(dir, "model.txt"), "w") as f:
            f.write(str(model))
|
Leffa/3rdparty/detectron2/export/torchscript_patch.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import tempfile
|
| 6 |
+
from contextlib import ExitStack, contextmanager
|
| 7 |
+
from copy import deepcopy
|
| 8 |
+
from unittest import mock
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
|
| 12 |
+
# need some explicit imports due to https://github.com/pytorch/pytorch/issues/38964
|
| 13 |
+
import detectron2 # noqa F401
|
| 14 |
+
from detectron2.structures import Boxes, Instances
|
| 15 |
+
from detectron2.utils.env import _import_file
|
| 16 |
+
|
| 17 |
+
_counter = 0
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _clear_jit_cache():
|
| 21 |
+
from torch.jit._recursive import concrete_type_store
|
| 22 |
+
from torch.jit._state import _jit_caching_layer
|
| 23 |
+
|
| 24 |
+
concrete_type_store.type_store.clear() # for modules
|
| 25 |
+
_jit_caching_layer.clear() # for free functions
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _add_instances_conversion_methods(newInstances):
    """
    Attach a ``from_instances`` helper to the generated scripted Instances class.
    """
    cls_name = newInstances.__name__

    @torch.jit.unused
    def from_instances(instances: Instances):
        """
        Create scripted Instances from original Instances
        """
        ret = newInstances(instances.image_size)
        for name, val in instances.get_fields().items():
            # The generated class stores each field in a `_<name>` slot; the
            # assert catches fields that were not declared when generating it.
            assert hasattr(ret, f"_{name}"), f"No attribute named {name} in {cls_name}"
            setattr(ret, name, deepcopy(val))
        return ret

    newInstances.from_instances = from_instances
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@contextmanager
def patch_instances(fields):
    """
    A contextmanager, under which the Instances class in detectron2 is replaced
    by a statically-typed scriptable class, defined by `fields`.
    See more in `scripting_with_instances`.

    Yields:
        The generated replacement class.
    """

    # The generated class source is written to a real .py file so that
    # torchscript can read it (and show it in error messages).
    with tempfile.TemporaryDirectory(prefix="detectron2") as dir, tempfile.NamedTemporaryFile(
        mode="w", encoding="utf-8", suffix=".py", dir=dir, delete=False
    ) as f:
        try:
            # Objects that use Instances should not reuse previously-compiled
            # results in cache, because `Instances` could be a new class each time.
            _clear_jit_cache()

            cls_name, s = _gen_instance_module(fields)
            f.write(s)
            f.flush()
            # Close before importing so the content is visible on all platforms.
            f.close()

            module = _import(f.name)
            new_instances = getattr(module, cls_name)
            # Compile the generated class eagerly so errors surface here.
            _ = torch.jit.script(new_instances)
            # let torchscript think Instances was scripted already
            Instances.__torch_script_class__ = True
            # let torchscript find new_instances when looking for the jit type of Instances
            Instances._jit_override_qualname = torch._jit_internal._qualified_name(new_instances)

            _add_instances_conversion_methods(new_instances)
            yield new_instances
        finally:
            # Revert the patching so another model (with different fields)
            # can be scripted later.
            try:
                del Instances.__torch_script_class__
                del Instances._jit_override_qualname
            except AttributeError:
                pass
            sys.modules.pop(module.__name__)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _gen_instance_class(fields):
    """
    Generate the source code of a statically-typed, scriptable stand-in for
    the ``Instances`` class.

    Args:
        fields (dict[name: type]): attribute name -> attribute type.

    Returns:
        (str, str): the generated class name and its source code.
    """

    class _FieldType:
        # Small record pairing a field name with its type and the string
        # annotation used in the generated source.
        def __init__(self, name, type_):
            assert isinstance(name, str), f"Field name must be str, got {name}"
            self.name = name
            self.type_ = type_
            self.annotation = f"{type_.__module__}.{type_.__name__}"

    fields = [_FieldType(k, v) for k, v in fields.items()]

    def indent(level, s):
        # 4-space indentation helper for generated lines.
        return " " * 4 * level + s

    lines = []

    # Each generated class gets a unique name so repeated patching never
    # collides in torchscript's registry.
    global _counter
    _counter += 1

    cls_name = "ScriptedInstances{}".format(_counter)

    field_names = tuple(x.name for x in fields)
    extra_args = ", ".join([f"{f.name}: Optional[{f.annotation}] = None" for f in fields])
    lines.append(
        f"""
class {cls_name}:
    def __init__(self, image_size: Tuple[int, int], {extra_args}):
        self.image_size = image_size
        self._field_names = {field_names}
"""
    )

    # Declare each field as an Optional slot so "unset" can be represented.
    for f in fields:
        lines.append(
            indent(2, f"self._{f.name} = torch.jit.annotate(Optional[{f.annotation}], {f.name})")
        )

    # property getter/setter pair per field
    for f in fields:
        lines.append(
            f"""
    @property
    def {f.name}(self) -> {f.annotation}:
        # has to use a local for type refinement
        # https://pytorch.org/docs/stable/jit_language_reference.html#optional-type-refinement
        t = self._{f.name}
        assert t is not None, "{f.name} is None and cannot be accessed!"
        return t

    @{f.name}.setter
    def {f.name}(self, value: {f.annotation}) -> None:
        self._{f.name} = value
"""
        )

    # support method `__len__`
    lines.append(
        """
    def __len__(self) -> int:
"""
    )
    for f in fields:
        lines.append(
            f"""
        t = self._{f.name}
        if t is not None:
            return len(t)
"""
        )
    lines.append(
        """
        raise NotImplementedError("Empty Instances does not support __len__!")
"""
    )

    # support method `has`
    lines.append(
        """
    def has(self, name: str) -> bool:
"""
    )
    for f in fields:
        lines.append(
            f"""
        if name == "{f.name}":
            return self._{f.name} is not None
"""
        )
    lines.append(
        """
        return False
"""
    )

    # support method `to`
    none_args = ", None" * len(fields)
    lines.append(
        f"""
    def to(self, device: torch.device) -> "{cls_name}":
        ret = {cls_name}(self.image_size{none_args})
"""
    )
    for f in fields:
        if hasattr(f.type_, "to"):
            lines.append(
                f"""
        t = self._{f.name}
        if t is not None:
            ret._{f.name} = t.to(device)
"""
            )
        else:
            # For now, ignore fields that cannot be moved to devices.
            # Maybe can support other tensor-like classes (e.g. __torch_function__)
            pass
    lines.append(
        """
        return ret
"""
    )

    # support method `getitem`
    none_args = ", None" * len(fields)
    lines.append(
        f"""
    def __getitem__(self, item) -> "{cls_name}":
        ret = {cls_name}(self.image_size{none_args})
"""
    )
    for f in fields:
        lines.append(
            f"""
        t = self._{f.name}
        if t is not None:
            ret._{f.name} = t[item]
"""
        )
    lines.append(
        """
        return ret
"""
    )

    # support method `cat`
    # this version does not contain checks that all instances have same size and fields
    none_args = ", None" * len(fields)
    lines.append(
        f"""
    def cat(self, instances: List["{cls_name}"]) -> "{cls_name}":
        ret = {cls_name}(self.image_size{none_args})
"""
    )
    for f in fields:
        lines.append(
            f"""
        t = self._{f.name}
        if t is not None:
            values: List[{f.annotation}] = [x.{f.name} for x in instances]
            if torch.jit.isinstance(t, torch.Tensor):
                ret._{f.name} = torch.cat(values, dim=0)
            else:
                ret._{f.name} = t.cat(values)
"""
        )
    lines.append(
        """
        return ret"""
    )

    # support method `get_fields()`
    lines.append(
        """
    def get_fields(self) -> Dict[str, Tensor]:
        ret = {}
    """
    )
    for f in fields:
        # Only Boxes and plain Tensors can be exported as raw tensors;
        # any other field type makes the generated method assert at runtime.
        if f.type_ == Boxes:
            stmt = "t.tensor"
        elif f.type_ == torch.Tensor:
            stmt = "t"
        else:
            stmt = f'assert False, "unsupported type {str(f.type_)}"'
        lines.append(
            f"""
        t = self._{f.name}
        if t is not None:
            ret["{f.name}"] = {stmt}
    """
        )
    lines.append(
        """
        return ret"""
    )
    return cls_name, os.linesep.join(lines)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def _gen_instance_module(fields):
    """
    Build the full source of a throwaway module that defines the generated
    scripted-Instances class, including the imports the class body needs.

    Returns:
        (str, str): the generated class name and the module source.
    """
    # TODO: find a more automatic way to enable import of other classes
    header = """
from copy import deepcopy
import torch
from torch import Tensor
import typing
from typing import *

import detectron2
from detectron2.structures import Boxes, Instances

"""

    cls_name, cls_def = _gen_instance_class(fields)
    return cls_name, header + cls_def
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def _import(path):
    """
    Import the generated .py file at ``path`` under a unique module name
    (this module's name plus the current generation counter).
    """
    unique_name = f"{sys.modules[__name__].__name__}{_counter}"
    return _import_file(unique_name, path, make_importable=True)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
@contextmanager
def patch_builtin_len(modules=()):
    """
    Patch the builtin len() function of a few detectron2 modules
    to use __len__ instead, because __len__ does not convert values to
    integers and therefore is friendly to tracing.

    Args:
        modules (list[str]): names of extra modules to patch len(), in
            addition to those in detectron2.
    """

    def _new_len(obj):
        # Delegate straight to __len__ without the int() coercion that the
        # builtin len() performs.
        return obj.__len__()

    with ExitStack() as stack:
        MODULES = [
            "detectron2.modeling.roi_heads.fast_rcnn",
            "detectron2.modeling.roi_heads.mask_head",
            "detectron2.modeling.roi_heads.keypoint_head",
        ] + list(modules)
        # Patch `len` as looked up inside each module; all patches are
        # reverted together when the ExitStack unwinds.
        ctxs = [stack.enter_context(mock.patch(mod + ".len")) for mod in MODULES]
        for m in ctxs:
            m.side_effect = _new_len
        yield
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def patch_nonscriptable_classes():
    """
    Apply patches on a few nonscriptable detectron2 classes.
    Should not have side-effects on eager usage.
    """
    # __prepare_scriptable__ can also be added to models for easier maintenance.
    # But it complicates the clean model code.

    from detectron2.modeling.backbone import ResNet, FPN

    # Due to https://github.com/pytorch/pytorch/issues/36061,
    # we change backbone to use ModuleList for scripting.
    # (note: this changes param names in state_dict)

    def prepare_resnet(self):
        # Return a scripting-friendly deep copy; the original model is untouched.
        ret = deepcopy(self)
        ret.stages = nn.ModuleList(ret.stages)
        for k in self.stage_names:
            delattr(ret, k)
        return ret

    ResNet.__prepare_scriptable__ = prepare_resnet

    def prepare_fpn(self):
        # Same deep-copy approach for FPN: collect convs into ModuleLists and
        # drop the per-level attributes.
        ret = deepcopy(self)
        ret.lateral_convs = nn.ModuleList(ret.lateral_convs)
        ret.output_convs = nn.ModuleList(ret.output_convs)
        for name, _ in self.named_children():
            if name.startswith("fpn_"):
                delattr(ret, name)
        return ret

    FPN.__prepare_scriptable__ = prepare_fpn

    # Annotate some attributes to be constants for the purpose of scripting,
    # even though they are not constants in eager mode.
    from detectron2.modeling.roi_heads import StandardROIHeads

    if hasattr(StandardROIHeads, "__annotations__"):
        # copy first to avoid editing annotations of base class
        StandardROIHeads.__annotations__ = deepcopy(StandardROIHeads.__annotations__)
        StandardROIHeads.__annotations__["mask_on"] = torch.jit.Final[bool]
        StandardROIHeads.__annotations__["keypoint_on"] = torch.jit.Final[bool]
|
| 386 |
+
|
| 387 |
+
# These patches are not supposed to have side-effects.
|
| 388 |
+
patch_nonscriptable_classes()
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
@contextmanager
def freeze_training_mode(model):
    """
    A context manager that annotates the "training" attribute of every submodule
    to constant, so that the training codepath in these modules can be
    meta-compiled away. Upon exiting, the annotations are reverted.
    """
    classes = {type(x) for x in model.modules()}
    # __constants__ is the old way to annotate constants and not compatible
    # with __annotations__ .
    classes = {x for x in classes if not hasattr(x, "__constants__")}
    for cls in classes:
        cls.__annotations__["training"] = torch.jit.Final[bool]
    yield
    # NOTE(review): restores the annotation to plain `bool` rather than
    # deleting it, so classes end up annotated even if they weren't before.
    for cls in classes:
        cls.__annotations__["training"] = bool
|
Leffa/SCHP/__init__.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from SCHP import networks
|
| 8 |
+
from SCHP.utils.transforms import get_affine_transform, transform_logits
|
| 9 |
+
from torchvision import transforms
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_palette(num_cls):
    """Return a flat RGB color map (PASCAL-VOC style) for *num_cls* labels.

    Each label index is expanded bit-by-bit into the high bits of its
    R, G and B components, producing visually distinct colors.

    Args:
        num_cls: Number of classes.

    Returns:
        A flat list of length ``num_cls * 3`` holding R, G, B bytes
        for each class, in class-index order.
    """
    palette = [0] * (num_cls * 3)
    for cls_idx in range(num_cls):
        value = cls_idx
        shift = 0
        # Distribute the label's bits across the three channels, starting
        # from the most significant bit of each byte.
        while value:
            for channel in range(3):
                palette[cls_idx * 3 + channel] |= ((value >> channel) & 1) << (7 - shift)
            shift += 1
            value >>= 3
    return palette
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Per-dataset inference settings for the SCHP parser.
# "input_size" is the (height, width) fed to the network; "num_classes" is
# the number of output channels; "label" maps class index -> class name.
dataset_settings = {
    # LIP: Look Into Person benchmark (20 classes).
    "lip": {
        "input_size": [473, 473],
        "num_classes": 20,
        "label": [
            "Background",
            "Hat",
            "Hair",
            "Glove",
            "Sunglasses",
            "Upper-clothes",
            "Dress",
            "Coat",
            "Socks",
            "Pants",
            "Jumpsuits",
            "Scarf",
            "Skirt",
            "Face",
            "Left-arm",
            "Right-arm",
            "Left-leg",
            "Right-leg",
            "Left-shoe",
            "Right-shoe",
        ],
    },
    # ATR: single-person clothing parsing (18 classes).
    "atr": {
        "input_size": [512, 512],
        "num_classes": 18,
        "label": [
            "Background",
            "Hat",
            "Hair",
            "Sunglasses",
            "Upper-clothes",
            "Skirt",
            "Pants",
            "Dress",
            "Belt",
            "Left-shoe",
            "Right-shoe",
            "Face",
            "Left-leg",
            "Right-leg",
            "Left-arm",
            "Right-arm",
            "Bag",
            "Scarf",
        ],
    },
    # PASCAL-Person-Part: coarse body parts (7 classes).
    "pascal": {
        "input_size": [512, 512],
        "num_classes": 7,
        "label": [
            "Background",
            "Head",
            "Torso",
            "Upper Arms",
            "Lower Arms",
            "Upper Legs",
            "Lower Legs",
        ],
    },
}
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class SCHP:
    """Self-Correction Human Parsing (SCHP) inference wrapper.

    Wraps a pretrained CE2P/ResNet-101 parser. Given an image (file path
    or ``PIL.Image``) it produces a palette-indexed PIL segmentation mask
    at the original image resolution.
    """

    def __init__(self, ckpt_path, device):
        """Build the model for the dataset flavor encoded in *ckpt_path*.

        Args:
            ckpt_path: Checkpoint file path; must contain "lip", "atr" or
                "pascal" so the label set / input size can be selected.
            device: Torch device the model should run on.
        """
        dataset_type = None
        if "lip" in ckpt_path:
            dataset_type = "lip"
        elif "atr" in ckpt_path:
            dataset_type = "atr"
        elif "pascal" in ckpt_path:
            dataset_type = "pascal"
        assert dataset_type is not None, "Dataset type not found in checkpoint path"
        self.device = device
        self.num_classes = dataset_settings[dataset_type]["num_classes"]
        self.input_size = dataset_settings[dataset_type]["input_size"]
        # width / height ratio of the network input; used to pad crop boxes
        # so they match the model's aspect ratio.
        self.aspect_ratio = self.input_size[1] * 1.0 / self.input_size[0]
        self.palette = get_palette(self.num_classes)

        self.label = dataset_settings[dataset_type]["label"]
        self.model = networks.init_model(
            "resnet101", num_classes=self.num_classes, pretrained=None
        ).to(device)
        self.load_ckpt(ckpt_path)
        self.model.eval()

        # NOTE(review): mean/std are in BGR channel order, matching
        # cv2.imread input; PIL inputs arrive as RGB yet are normalized with
        # the same constants (pre-existing behavior) -- confirm upstream.
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229]
                ),
            ]
        )
        self.upsample = torch.nn.Upsample(
            size=self.input_size, mode="bilinear", align_corners=True
        )

    def load_ckpt(self, ckpt_path):
        """Load weights, stripping the DataParallel ``module.`` prefix.

        Some layer indices in the checkpoint are shifted relative to this
        port (extra BatchNorm layers), so a few keys are remapped first.
        Loading is non-strict: unmatched keys are silently ignored.
        """
        # checkpoint key -> key expected by this implementation
        rename_map = {
            "decoder.conv3.2.weight": "decoder.conv3.3.weight",
            "decoder.conv3.3.weight": "decoder.conv3.4.weight",
            "decoder.conv3.3.bias": "decoder.conv3.4.bias",
            "decoder.conv3.3.running_mean": "decoder.conv3.4.running_mean",
            "decoder.conv3.3.running_var": "decoder.conv3.4.running_var",
            "fushion.3.weight": "fushion.4.weight",
            "fushion.3.bias": "fushion.4.bias",
        }
        state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]  # remove `module.`
            new_state_dict[name] = v
        # Apply the renames; each source key is visited exactly once, so the
        # chained entries in rename_map cannot clobber each other.
        new_state_dict_ = OrderedDict()
        for k, v in list(new_state_dict.items()):
            if k in rename_map:
                new_state_dict_[rename_map[k]] = v
            else:
                new_state_dict_[k] = v
        self.model.load_state_dict(new_state_dict_, strict=False)

    def _box2cs(self, box):
        """Convert an (x, y, w, h) box to (center, scale)."""
        x, y, w, h = box[:4]
        return self._xywh2cs(x, y, w, h)

    def _xywh2cs(self, x, y, w, h):
        """Compute box center and a scale padded to the model aspect ratio."""
        center = np.zeros((2), dtype=np.float32)
        center[0] = x + w * 0.5
        center[1] = y + h * 0.5
        # Grow the shorter side so the crop matches self.aspect_ratio.
        if w > self.aspect_ratio * h:
            h = w * 1.0 / self.aspect_ratio
        elif w < self.aspect_ratio * h:
            w = h * self.aspect_ratio
        scale = np.array([w, h], dtype=np.float32)
        return center, scale

    def preprocess(self, image):
        """Load/convert *image* and warp it to the network input size.

        Args:
            image: A file path (read with cv2, BGR) or a ``PIL.Image``.

        Returns:
            ``(input_tensor, meta)`` where *meta* records the affine
            parameters needed to map logits back to the source resolution.

        Raises:
            TypeError: If *image* is neither a path nor a PIL image.
        """
        if isinstance(image, str):
            img = cv2.imread(image, cv2.IMREAD_COLOR)
        elif isinstance(image, Image.Image):
            # to cv2 format
            img = np.array(image)
        else:
            # Previously fell through and crashed with UnboundLocalError on
            # `img`; fail with a clear message instead.
            raise TypeError(
                "image must be a file path or PIL.Image, got {}".format(type(image))
            )

        h, w, _ = img.shape
        # Get person center and scale
        person_center, s = self._box2cs([0, 0, w - 1, h - 1])
        r = 0
        trans = get_affine_transform(person_center, s, r, self.input_size)
        input = cv2.warpAffine(
            img,
            trans,
            (int(self.input_size[1]), int(self.input_size[0])),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(0, 0, 0),
        )

        input = self.transform(input).to(self.device).unsqueeze(0)
        meta = {
            "center": person_center,
            "height": h,
            "width": w,
            "scale": s,
            "rotation": r,
        }
        return input, meta

    def __call__(self, image_or_path):
        """Parse one image or a list of images.

        Returns:
            A palette-indexed ``PIL.Image`` mask, or a list of masks when a
            list of inputs was given (a single-element list is unwrapped).
        """
        if isinstance(image_or_path, list):
            image_list = []
            meta_list = []
            for image in image_or_path:
                image, meta = self.preprocess(image)
                image_list.append(image)
                meta_list.append(meta)
            image = torch.cat(image_list, dim=0)
        else:
            image, meta = self.preprocess(image_or_path)
            meta_list = [meta]

        output = self.model(image)
        # upsample_outputs = self.upsample(output[0][-1])
        upsample_outputs = self.upsample(output)
        upsample_outputs = upsample_outputs.permute(0, 2, 3, 1)  # BCHW -> BHWC

        output_img_list = []
        for upsample_output, meta in zip(upsample_outputs, meta_list):
            c, s, w, h = meta["center"], meta["scale"], meta["width"], meta["height"]
            # Warp logits back to the original image resolution before argmax.
            logits_result = transform_logits(
                upsample_output.data.cpu().numpy(),
                c,
                s,
                w,
                h,
                input_size=self.input_size,
            )
            parsing_result = np.argmax(logits_result, axis=2)
            output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
            output_img.putpalette(self.palette)
            output_img_list.append(output_img)

        return output_img_list[0] if len(output_img_list) == 1 else output_img_list
|
Leffa/SCHP/networks/AugmentCE2P.py
ADDED
|
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- encoding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
@Author : Peike Li
|
| 6 |
+
@Contact : peike.li@yahoo.com
|
| 7 |
+
@File : AugmentCE2P.py
|
| 8 |
+
@Time : 8/4/19 3:35 PM
|
| 9 |
+
@Desc :
|
| 10 |
+
@License : This source code is licensed under the license found in the
|
| 11 |
+
LICENSE file in the root directory of this source tree.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
|
| 17 |
+
from torch.nn import BatchNorm2d, functional as F, LeakyReLU
|
| 18 |
+
|
| 19 |
+
# When True, BatchNorm layers built for ResNet downsample paths are affine.
affine_par = True
# ImageNet preprocessing metadata for the ResNet-101 backbone.
# NOTE: mean/std are listed in BGR order, consistent with "input_space".
pretrained_settings = {
    "resnet101": {
        "imagenet": {
            "input_space": "BGR",
            "input_size": [3, 224, 224],
            "input_range": [0, 1],
            "mean": [0.406, 0.456, 0.485],
            "std": [0.225, 0.224, 0.229],
            "num_classes": 1000,
        }
    },
}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def conv3x3(in_planes, out_planes, stride=1):
    """Build a padded, bias-free 3x3 convolution."""
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 (dilated) -> 1x1 expand,
    with a residual shortcut (optionally downsampled)."""

    expansion = 4

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        dilation=1,
        downsample=None,
        fist_dilation=1,
        multi_grid=1,
    ):
        super(Bottleneck, self).__init__()
        # 1x1 channel reduction.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        # 3x3 conv; multi_grid scales the dilation (and padding, to keep
        # the spatial size) for the deepest stage.
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=dilation * multi_grid,
            dilation=dilation * multi_grid,
            bias=False,
        )
        self.bn2 = BatchNorm2d(planes)
        # 1x1 channel expansion back to planes * expansion.
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=False)
        self.relu_inplace = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.dilation = dilation
        self.stride = stride

    def forward(self, x):
        """Run the bottleneck path and add the (possibly projected) shortcut."""
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        return self.relu_inplace(y + shortcut)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class PSPModule(nn.Module):
    """
    Pyramid Scene Parsing pooling head: pools the input at several grid
    sizes, projects each level, upsamples back, and fuses with the input.

    Reference:
        Zhao, Hengshuang, et al. *"Pyramid scene parsing network."*
    """

    def __init__(self, features, out_features=512, sizes=(1, 2, 3, 6)):
        super(PSPModule, self).__init__()

        # NOTE: this list is immediately replaced by the ModuleList below;
        # kept byte-identical for compatibility with the original code.
        self.stages = []
        # One pooling/projection stage per pyramid size.
        self.stages = nn.ModuleList(
            [self._make_stage(features, out_features, size) for size in sizes]
        )
        # Fuses the concatenation of the raw input plus one upsampled map
        # per pyramid level down to `out_features` channels.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(
                features + len(sizes) * out_features,
                out_features,
                kernel_size=3,
                padding=1,
                dilation=1,
                bias=False,
            ),
            BatchNorm2d(out_features),
            LeakyReLU(),
        )

    def _make_stage(self, features, out_features, size):
        # One pyramid level: adaptive-pool to (size, size), then 1x1 project.
        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
        conv = nn.Conv2d(features, out_features, kernel_size=1, bias=False)
        return nn.Sequential(
            prior,
            conv,
            # bn
            BatchNorm2d(out_features),
            LeakyReLU(),
        )

    def forward(self, feats):
        # Upsample every pooled level back to the input resolution and
        # concatenate with the untouched input features before fusing.
        h, w = feats.size(2), feats.size(3)
        priors = [
            F.interpolate(
                input=stage(feats), size=(h, w), mode="bilinear", align_corners=True
            )
            for stage in self.stages
        ] + [feats]
        bottle = self.bottleneck(torch.cat(priors, 1))
        return bottle
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class ASPPModule(nn.Module):
    """
    Atrous Spatial Pyramid Pooling: four parallel convolution branches at
    different dilation rates plus a global-pooling branch, fused by a 1x1
    bottleneck.

    Reference:
        Chen, Liang-Chieh, et al. *"Rethinking Atrous Convolution for Semantic Image Segmentation."*
    """

    def __init__(
        self, features, inner_features=256, out_features=512, dilations=(12, 24, 36)
    ):
        super(ASPPModule, self).__init__()

        # Branch 1: image-level context (global pooling + 1x1 projection).
        self.conv1 = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(
                features,
                inner_features,
                kernel_size=1,
                padding=0,
                dilation=1,
                bias=False,
            ),
            # InPlaceABNSync(inner_features)
            BatchNorm2d(inner_features),
            LeakyReLU(),
        )
        # Branch 2: plain 1x1 projection.
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                features,
                inner_features,
                kernel_size=1,
                padding=0,
                dilation=1,
                bias=False,
            ),
            BatchNorm2d(inner_features),
            LeakyReLU(),
        )
        # Branches 3-5: 3x3 atrous convolutions at increasing dilation rates.
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                features,
                inner_features,
                kernel_size=3,
                padding=dilations[0],
                dilation=dilations[0],
                bias=False,
            ),
            BatchNorm2d(inner_features),
            LeakyReLU(),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(
                features,
                inner_features,
                kernel_size=3,
                padding=dilations[1],
                dilation=dilations[1],
                bias=False,
            ),
            BatchNorm2d(inner_features),
            LeakyReLU(),
        )
        self.conv5 = nn.Sequential(
            nn.Conv2d(
                features,
                inner_features,
                kernel_size=3,
                padding=dilations[2],
                dilation=dilations[2],
                bias=False,
            ),
            BatchNorm2d(inner_features),
            LeakyReLU(),
        )

        # Fuse the five concatenated branches down to `out_features` channels.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(
                inner_features * 5,
                out_features,
                kernel_size=1,
                padding=0,
                dilation=1,
                bias=False,
            ),
            # BUG FIX: the conv above emits `out_features` channels, so the
            # BatchNorm must match. It was BatchNorm2d(inner_features), which
            # crashed at runtime whenever inner_features != out_features
            # (true for the defaults 256 vs 512).
            BatchNorm2d(out_features),
            LeakyReLU(),
            nn.Dropout2d(0.1),
        )

    def forward(self, x):
        """Apply all five branches at the input resolution and fuse them."""
        _, _, h, w = x.size()

        # Broadcast the 1x1 global-context feature back to (h, w).
        feat1 = F.interpolate(
            self.conv1(x), size=(h, w), mode="bilinear", align_corners=True
        )

        feat2 = self.conv2(x)
        feat3 = self.conv3(x)
        feat4 = self.conv4(x)
        feat5 = self.conv5(x)
        out = torch.cat((feat1, feat2, feat3, feat4, feat5), 1)

        bottle = self.bottleneck(out)
        return bottle
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class Edge_Module(nn.Module):
    """
    Edge Learning Branch.

    Projects three backbone feature maps to a common channel width,
    upsamples the coarser two to the finest resolution, and returns the
    concatenated edge features.
    """

    def __init__(self, in_fea=[256, 512, 1024], mid_fea=256, out_fea=2):
        super(Edge_Module, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_fea[0], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False
            ),
            BatchNorm2d(mid_fea),
            LeakyReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_fea[1], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False
            ),
            BatchNorm2d(mid_fea),
            LeakyReLU(),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                in_fea[2], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False
            ),
            BatchNorm2d(mid_fea),
            LeakyReLU(),
        )
        # conv4 produced per-scale edge predictions in the original CE2P.
        # Its output is no longer returned, but the layer is kept so
        # existing checkpoints still load without missing keys.
        self.conv4 = nn.Conv2d(
            mid_fea, out_fea, kernel_size=3, padding=1, dilation=1, bias=True
        )
        # self.conv5 = nn.Conv2d(out_fea * 3, out_fea, kernel_size=1, padding=0, dilation=1, bias=True)

    def forward(self, x1, x2, x3):
        """Return (3 * mid_fea)-channel edge features at x1's resolution."""
        _, _, h, w = x1.size()

        edge1_fea = self.conv1(x1)
        # Coarser scales are brought up to x1's spatial size before fusing.
        edge2_fea = F.interpolate(
            self.conv2(x2), size=(h, w), mode="bilinear", align_corners=True
        )
        edge3_fea = F.interpolate(
            self.conv3(x3), size=(h, w), mode="bilinear", align_corners=True
        )

        # NOTE: the per-scale edge predictions (self.conv4 applied to the
        # projected features) were dead code -- computed and upsampled but
        # never used -- so that work is dropped here.
        return torch.cat([edge1_fea, edge2_fea, edge3_fea], dim=1)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
class Decoder_Module(nn.Module):
    """
    Parsing Branch Decoder Module: fuses deep context features with
    low-level backbone features at the low-level spatial resolution.
    """

    def __init__(self, num_classes):
        super(Decoder_Module, self).__init__()
        # Project the deep context features (512 -> 256).
        self.conv1 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=1, padding=0, dilation=1, bias=False),
            BatchNorm2d(256),
            LeakyReLU(),
        )
        # Project the low-level features (256 -> 48).
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                256, 48, kernel_size=1, stride=1, padding=0, dilation=1, bias=False
            ),
            BatchNorm2d(48),
            LeakyReLU(),
        )
        # Refine the concatenated (256 + 48 = 304)-channel features.
        self.conv3 = nn.Sequential(
            nn.Conv2d(304, 256, kernel_size=1, padding=0, dilation=1, bias=False),
            BatchNorm2d(256),
            LeakyReLU(),
            nn.Conv2d(256, 256, kernel_size=1, padding=0, dilation=1, bias=False),
            BatchNorm2d(256),
            LeakyReLU(),
        )

        # self.conv4 = nn.Conv2d(256, num_classes, kernel_size=1, padding=0, dilation=1, bias=True)

    def forward(self, xt, xl):
        """Fuse context features *xt* with low-level features *xl*.

        Returns a 256-channel map at *xl*'s spatial resolution.
        """
        _, _, low_h, low_w = xl.size()
        context = self.conv1(xt)
        context = F.interpolate(
            context, size=(low_h, low_w), mode="bilinear", align_corners=True
        )
        low_level = self.conv2(xl)
        fused = torch.cat([context, low_level], dim=1)
        return self.conv3(fused)
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
class ResNet(nn.Module):
    """CE2P parsing network on a dilated ResNet backbone.

    The backbone keeps output stride 16 (layer4 uses stride 1 with
    dilation 2); a PSP head, an edge branch and a decoder are fused by a
    final classifier ("fushion" spelling kept for checkpoint key
    compatibility).
    """

    def __init__(self, block, layers, num_classes):
        # Deep-stem variant: three 3x3 convs replace the usual 7x7 stem,
        # leaving 128 channels before layer1.
        self.inplanes = 128
        super(ResNet, self).__init__()
        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=False)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=False)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=False)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        # layer4: no further downsampling; dilation preserves resolution.
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=1, dilation=2, multi_grid=(1, 1, 1)
        )

        # Pyramid pooling over the 2048-channel layer4 output.
        self.context_encoding = PSPModule(2048, 512)

        self.edge = Edge_Module()
        self.decoder = Decoder_Module(num_classes)

        # Final classifier over decoder (256) + edge (768) = 1024 channels.
        self.fushion = nn.Sequential(
            nn.Conv2d(1024, 256, kernel_size=1, padding=0, dilation=1, bias=False),
            BatchNorm2d(256),
            LeakyReLU(),
            nn.Dropout2d(0.1),
            nn.Conv2d(
                256, num_classes, kernel_size=1, padding=0, dilation=1, bias=True
            ),
        )

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, multi_grid=1):
        """Stack `blocks` bottlenecks; the first may downsample/project."""
        downsample = None
        # Projection shortcut when the spatial size or channel count changes.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                BatchNorm2d(planes * block.expansion, affine=affine_par),
            )

        layers = []
        # Per-block dilation multiplier: cycles through `grids` when a tuple
        # is given (multi-grid trick), otherwise 1.
        generate_multi_grid = lambda index, grids: (
            grids[index % len(grids)] if isinstance(grids, tuple) else 1
        )
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                dilation=dilation,
                downsample=downsample,
                multi_grid=generate_multi_grid(0, multi_grid),
            )
        )
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    dilation=dilation,
                    multi_grid=generate_multi_grid(i, multi_grid),
                )
            )

        return nn.Sequential(*layers)

    def forward(self, x):
        # Stem + max-pool: overall stride 4 before layer1.
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)
        x2 = self.layer1(x)
        x3 = self.layer2(x2)
        x4 = self.layer3(x3)
        x5 = self.layer4(x4)
        x = self.context_encoding(x5)
        # parsing_result, parsing_fea = self.decoder(x, x2)
        parsing_fea = self.decoder(x, x2)
        # Edge Branch
        # edge_result, edge_fea = self.edge(x2, x3, x4)
        edge_fea = self.edge(x2, x3, x4)
        # Fusion Branch
        x = torch.cat([parsing_fea, edge_fea], dim=1)
        fusion_result = self.fushion(x)
        # return [[parsing_result, fusion_result], [edge_result]]
        return fusion_result
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def initialize_pretrained_model(
    model, settings, pretrained="./models/resnet101-imagenet.pth"
):
    """Attach preprocessing metadata to *model* and optionally load weights.

    Args:
        model: The network to initialize (attributes are set in place).
        settings: Dict with input_space/input_size/input_range/mean/std.
        pretrained: Path to an ImageNet checkpoint, or None to skip loading.
    """
    model.input_space = settings["input_space"]
    model.input_size = settings["input_size"]
    model.input_range = settings["input_range"]
    model.mean = settings["mean"]
    model.std = settings["std"]

    if pretrained is not None:
        # map_location keeps loading working on CPU-only machines even when
        # the checkpoint was saved from GPU tensors.
        saved_state_dict = torch.load(pretrained, map_location="cpu")
        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            # Skip the ImageNet classifier head; copy everything else.
            # (The previous split/re-join of the key was a no-op.)
            if key.split(".")[0] != "fc":
                new_params[key] = saved_state_dict[key]
        model.load_state_dict(new_params)
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
def resnet101(num_classes=20, pretrained="./models/resnet101-imagenet.pth"):
    """Build the CE2P ResNet-101 parser, optionally loading ImageNet weights.

    Args:
        num_classes: Number of parsing classes.
        pretrained: ImageNet checkpoint path, or None to skip loading.
    """
    settings = pretrained_settings["resnet101"]["imagenet"]
    model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes)
    initialize_pretrained_model(model, settings, pretrained)
    return model
|
Leffa/SCHP/networks/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import absolute_import
|
| 2 |
+
|
| 3 |
+
from SCHP.networks.AugmentCE2P import resnet101
|
| 4 |
+
|
| 5 |
+
# Registry mapping architecture names to their constructor functions.
__factory = {
    "resnet101": resnet101,
}
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def init_model(name, *args, **kwargs):
    """Instantiate a registered architecture by name.

    Args:
        name: Architecture key, e.g. "resnet101".
        *args, **kwargs: Forwarded to the architecture's constructor.

    Raises:
        KeyError: If *name* is not a registered architecture.
    """
    try:
        factory_fn = __factory[name]
    except KeyError:
        raise KeyError("Unknown model arch: {}".format(name))
    return factory_fn(*args, **kwargs)
|
Leffa/SCHP/utils/transforms.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------------
|
| 2 |
+
# Copyright (c) Microsoft
|
| 3 |
+
# Licensed under the MIT License.
|
| 4 |
+
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
|
| 5 |
+
# ------------------------------------------------------------------------------
|
| 6 |
+
|
| 7 |
+
from __future__ import absolute_import, division, print_function
|
| 8 |
+
|
| 9 |
+
import cv2
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class BRG2Tensor_transform(object):
    """Convert an HWC numpy image into a CHW torch tensor.

    Byte images are promoted to float; values are NOT rescaled to [0, 1].
    """

    def __call__(self, pic):
        chw = torch.from_numpy(pic.transpose((2, 0, 1)))
        if not isinstance(chw, torch.ByteTensor):
            return chw
        return chw.float()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class BGR2RGB_transform(object):
    """Reverse the channel order of a CHW tensor (BGR <-> RGB)."""

    def __call__(self, tensor):
        channel_order = [2, 1, 0]
        return tensor[channel_order, :, :]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def flip_back(output_flipped, matched_parts):
    """Undo a horizontal flip on heatmaps and swap left/right joint channels.

    Args:
        output_flipped: numpy.ndarray(batch_size, num_joints, height, width).
        matched_parts: Pairs of joint indices to exchange.
    """
    assert (
        output_flipped.ndim == 4
    ), "output_flipped should be [batch_size, num_joints, height, width]"

    # Mirror back along the width axis.
    restored = output_flipped[:, :, :, ::-1]

    # Exchange each left/right joint channel pair.
    for left, right in matched_parts:
        left_maps = restored[:, left, :, :].copy()
        restored[:, left, :, :] = restored[:, right, :, :]
        restored[:, right, :, :] = left_maps

    return restored
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def fliplr_joints(joints, joints_vis, width, matched_parts):
    """Mirror joint coordinates horizontally and swap left/right pairs.

    Args:
        joints: (num_joints, 2+) array of coordinates, x in column 0.
        joints_vis: Visibility flags, same leading shape as *joints*.
        width: Image width in pixels.
        matched_parts: Pairs of joint indices to exchange.

    Returns:
        (masked_joints, joints_vis) where invisible joints are zeroed.
    """
    # Mirror x about the image center (pixel-index convention).
    joints[:, 0] = width - joints[:, 0] - 1

    # Exchange each left/right joint pair via fancy-index row swap.
    for left, right in matched_parts:
        joints[[left, right], :] = joints[[right, left], :]
        joints_vis[[left, right], :] = joints_vis[[right, left], :]

    return joints * joints_vis, joints_vis
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def transform_preds(coords, center, scale, input_size):
    """Map predicted coordinates from network space back to the source image."""
    inv_trans = get_affine_transform(center, scale, 0, input_size, inv=1)
    mapped = np.zeros(coords.shape)
    for idx in range(coords.shape[0]):
        mapped[idx, 0:2] = affine_transform(coords[idx, 0:2], inv_trans)
    return mapped
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def transform_parsing(pred, center, scale, width, height, input_size):
    """Warp a parsing label map back to the original image resolution."""
    inv_trans = get_affine_transform(center, scale, 0, input_size, inv=1)
    # Nearest-neighbor keeps label ids intact (no blending between classes).
    return cv2.warpAffine(
        pred,
        inv_trans,
        (int(width), int(height)),
        flags=cv2.INTER_NEAREST,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=(0),
    )
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def transform_logits(logits, center, scale, width, height, input_size):
    """Warp per-class logit maps (H x W x C) back to the original image size.

    Each channel is warped independently with bilinear interpolation and the
    results are re-stacked along the channel axis.
    """
    inv_trans = get_affine_transform(center, scale, 0, input_size, inv=1)
    warped_channels = [
        cv2.warpAffine(
            logits[:, :, c],
            inv_trans,
            (int(width), int(height)),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=0,
        )
        for c in range(logits.shape[2])
    ]
    return np.stack(warped_channels, axis=2)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def get_affine_transform(
    center, scale, rot, output_size, shift=np.array([0, 0], dtype=np.float32), inv=0
):
    """Compute the 2x3 affine matrix mapping a scaled/rotated crop around
    ``center`` to an ``output_size`` image (or the inverse when ``inv``).

    Args:
        center: (x, y) centre of the region of interest in the source image.
        scale: region size; a scalar is promoted to ``[scale, scale]``.
        rot: rotation in degrees.
        output_size: destination size — this code reads height from index 0
            and width from index 1, i.e. (height, width) ordering.
        shift: fractional shift of the source box, in units of ``scale``.
        inv: when truthy, return the inverse mapping (dst -> src).

    Returns:
        The 2x3 affine matrix from ``cv2.getAffineTransform``.
    """
    # Promote a scalar scale to an isotropic pair.
    # (Fix: removed a leftover debug print(scale) that polluted stdout.)
    if not isinstance(scale, (np.ndarray, list)):
        scale = np.array([scale, scale])

    scale_tmp = scale

    src_w = scale_tmp[0]
    dst_w = output_size[1]
    dst_h = output_size[0]

    # The transform is pinned by three point correspondences: the centre,
    # a point "above" the centre (rotated by rot), and a third point at a
    # right angle to fix scale/shear.
    rot_rad = np.pi * rot / 180
    src_dir = get_dir([0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0, (dst_w - 1) * -0.5], np.float32)

    src = np.zeros((3, 2), dtype=np.float32)
    dst = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    dst[0, :] = [(dst_w - 1) * 0.5, (dst_h - 1) * 0.5]
    dst[1, :] = np.array([(dst_w - 1) * 0.5, (dst_h - 1) * 0.5]) + dst_dir

    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def affine_transform(pt, t):
    """Apply a 2x3 affine matrix ``t`` to the 2-D point ``pt`` and return
    the transformed (x, y) pair."""
    # Lift the point to homogeneous coordinates so the translation column
    # of `t` participates in the dot product.
    homogeneous = np.array([pt[0], pt[1], 1.0])
    return np.dot(t, homogeneous)[:2]
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def get_3rd_point(a, b):
    """Return the point completing a right angle at ``b``: ``b`` plus the
    vector ``a - b`` rotated by 90 degrees."""
    dx = a[0] - b[0]
    dy = a[1] - b[1]
    # Rotating (dx, dy) by a quarter turn gives (-dy, dx).
    return b + np.array([-dy, dx], dtype=np.float32)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def get_dir(src_point, rot_rad):
    """Rotate the 2-D point ``src_point`` by ``rot_rad`` radians and return
    the result as an ``[x, y]`` list."""
    sn = np.sin(rot_rad)
    cs = np.cos(rot_rad)
    x, y = src_point[0], src_point[1]
    # Standard 2x2 rotation applied componentwise.
    return [x * cs - y * sn, x * sn + y * cs]
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def crop(img, center, scale, output_size, rot=0):
    """Cut out the region described by (center, scale, rot) and resample it
    into an ``output_size`` image using bilinear interpolation."""
    warp_mat = get_affine_transform(center, scale, rot, output_size)
    # cv2.warpAffine wants (width, height); output_size is indexed (h, w)
    # here, matching get_affine_transform's convention.
    return cv2.warpAffine(
        img,
        warp_mat,
        (int(output_size[1]), int(output_size[0])),
        flags=cv2.INTER_LINEAR,
    )
|