File size: 14,738 Bytes
d670799 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, NonLocal3d
from mmengine.logging import MMLogger
from mmengine.runner.checkpoint import _load_checkpoint
from torch.nn.modules.utils import _ntuple
from mmaction.registry import MODELS
from .resnet import ResNet
class NL3DWrapper(nn.Module):
    """3D Non-local wrapper for ResNet50.

    Wrap a ResNet block with a 3D NonLocal module. The wrapped block runs on
    frame-level features; its output is regrouped into clips of
    ``num_segments`` frames, passed through the non-local module, and
    flattened back to frame level.

    Args:
        block (nn.Module): Residual blocks to be built. ``block.conv3.norm``
            must expose ``num_features``, which is used as the number of
            input channels of the non-local module.
        num_segments (int): Number of frame segments.
        non_local_cfg (dict, optional): Config for non-local layers, passed
            as keyword arguments to ``NonLocal3d``. ``None`` is treated as
            an empty config. Default: ``None``.
    """

    def __init__(self, block, num_segments, non_local_cfg=None):
        super().__init__()
        self.block = block
        # ``None`` instead of a shared ``dict()`` default; behaves the same.
        self.non_local_cfg = (
            dict() if non_local_cfg is None else non_local_cfg)
        self.non_local_block = NonLocal3d(self.block.conv3.norm.num_features,
                                          **self.non_local_cfg)
        self.num_segments = num_segments

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): Input of ``self.block``. The block output must
                be 4D ``(N, C, H, W)`` with ``N`` divisible by
                ``num_segments``.

        Returns:
            torch.Tensor: Tensor reshaped back to ``(N, C, H, W)`` (assumes
            the non-local block preserves its input shape).
        """
        x = self.block(x)
        n, c, h, w = x.size()
        # (N, C, H, W) -> (N // T, C, T, H, W): 5D clip layout so the 3D
        # non-local op can relate features across segments.
        x = x.view(n // self.num_segments, self.num_segments, c, h,
                   w).transpose(1, 2).contiguous()
        x = self.non_local_block(x)
        # (N // T, C, T, H, W) -> (N, C, H, W) for the next 2D block.
        x = x.transpose(1, 2).contiguous().view(n, c, h, w)
        return x
class TemporalShift(nn.Module):
    """Temporal shift module.

    This module is proposed in
    `TSM: Temporal Shift Module for Efficient Video Understanding
    <https://arxiv.org/abs/1811.08383>`_

    The input is shifted along the segment axis (see :meth:`shift`) and the
    result is fed to the wrapped ``net``.

    Args:
        net (nn.module): Module to make temporal shift.
        num_segments (int): Number of frame segments. Default: 3.
        shift_div (int): Number of divisions for shift. Default: 8.
    """

    def __init__(self, net, num_segments=3, shift_div=8):
        super().__init__()
        self.net = net
        self.num_segments = num_segments
        self.shift_div = shift_div

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The output of the module.
        """
        shifted = self.shift(x, self.num_segments, shift_div=self.shift_div)
        return self.net(shifted)

    @staticmethod
    def shift(x, num_segments, shift_div=3):
        """Perform temporal shift operation on the feature.

        The first ``C // shift_div`` channels are moved one segment towards
        earlier segments, the next ``C // shift_div`` channels one segment
        towards later segments, and the rest are left as they are. Segment
        slots vacated by a shift are zero-filled.

        Args:
            x (torch.Tensor): The input feature to be shifted, ``(N, C, H, W)``.
            num_segments (int): Number of frame segments.
            shift_div (int): Number of divisions for shift. Default: 3.

        Returns:
            torch.Tensor: The shifted feature, same shape as ``x``.
        """
        n, c, h, w = x.size()
        # Group frames into clips: (N // num_segments, num_segments, C, H*W).
        # Kept 4D on purpose: 5D arrays are not usable on the PPL2D backend
        # for caffe.
        clips = x.view(-1, num_segments, c, h * w)
        fold = c // shift_div

        # Channels [0, fold): segment t takes the value of segment t + 1.
        head = clips[:, :, :fold, :]
        # Zero padding is derived from the data rather than torch.zeros(...)
        # because caffe inference needs every array to come from a
        # computation.
        head_pad = (head - head)[:, :1, :, :]
        head = torch.cat((head[:, 1:, :, :], head_pad), 1)

        # Channels [fold, 2 * fold): segment t takes the value of t - 1.
        middle = clips[:, :, fold:2 * fold, :]
        middle_pad = (middle - middle)[:, :1, :, :]
        middle = torch.cat((middle_pad, middle[:, :-1, :, :]), 1)

        # Remaining channels pass through unshifted.
        tail = clips[:, :, 2 * fold:, :]

        # Reassemble along the channel axis and restore (N, C, H, W).
        return torch.cat((head, middle, tail), 2).view(n, c, h, w)
@MODELS.register_module()
class ResNetTSM(ResNet):
    """ResNet backbone for TSM.

    Args:
        depth (int): Depth of the ResNet, forwarded to :class:`ResNet`.
        num_segments (int): Number of frame segments. Defaults to 8.
        is_shift (bool): Whether to make temporal shift in resnet layers.
            Defaults to True.
        non_local (Sequence[int | Sequence[int]]): Determine whether to apply
            non-local module in the corresponding block of each stage. Each
            entry is either a per-block sequence of 0/1 flags or a single
            int that is broadcast to every block of that stage.
            Defaults to (0, 0, 0, 0).
        non_local_cfg (dict, optional): Config for non-local module.
            Non-local modules are only built when this is non-empty. ``None``
            is treated as an empty dict. Defaults to None.
        shift_div (int): Number of div for shift. Defaults to 8.
        shift_place (str): Places in resnet layers for shift, which is chosen
            from ['block', 'blockres'].
            If set to 'block', it will apply temporal shift to all child blocks
            in each resnet layer.
            If set to 'blockres', it will apply temporal shift to each `conv1`
            layer of all child blocks in each resnet layer.
            Defaults to 'blockres'.
        temporal_pool (bool): Whether to add temporal pooling before layer2.
            When enabled, layer2 and later stages operate on
            ``ceil(num_segments / 2)`` segments. Defaults to False.
        pretrained2d (bool): Whether to load pretrained 2D model.
            Defaults to True.
        **kwargs (keyword arguments, optional): Arguments for ResNet.
    """

    def __init__(self,
                 depth,
                 num_segments=8,
                 is_shift=True,
                 non_local=(0, 0, 0, 0),
                 non_local_cfg=None,
                 shift_div=8,
                 shift_place='blockres',
                 temporal_pool=False,
                 pretrained2d=True,
                 **kwargs):
        super().__init__(depth, **kwargs)
        self.num_segments = num_segments
        self.is_shift = is_shift
        self.shift_div = shift_div
        self.shift_place = shift_place
        self.temporal_pool = temporal_pool
        self.non_local = non_local
        self.non_local_stages = _ntuple(self.num_stages)(non_local)
        # ``None`` instead of a shared ``dict()`` default; behaves the same.
        self.non_local_cfg = (
            dict() if non_local_cfg is None else non_local_cfg)
        self.pretrained2d = pretrained2d
        self.init_structure()

    def init_structure(self):
        """Initialize structure for tsm."""
        if self.is_shift:
            self.make_temporal_shift()
        if self.non_local_cfg:
            self.make_non_local()
        if self.temporal_pool:
            self.make_temporal_pool()

    def _stage_num_segments(self):
        """Return the number of temporal segments seen by layer1..layer4."""
        if not self.temporal_pool:
            return [self.num_segments] * 4
        # The pool inserted before layer2 (kernel 3, stride 2, padding 1)
        # maps T segments to ceil(T / 2), so layer2-4 see the pooled length.
        pooled = (self.num_segments + 1) // 2
        return [self.num_segments, pooled, pooled, pooled]

    def make_temporal_shift(self):
        """Make temporal shift for some layers."""
        num_segment_list = self._stage_num_segments()
        if num_segment_list[-1] <= 0:
            raise ValueError('num_segment_list[-1] must be positive')

        if self.shift_place == 'block':

            def make_block_temporal(stage, num_segments):
                """Wrap every child block of ``stage`` in TemporalShift.

                Args:
                    stage (nn.Module): Model layers to be shifted.
                    num_segments (int): Number of frame segments.

                Returns:
                    nn.Module: The shifted blocks.
                """
                blocks = [
                    TemporalShift(
                        b, num_segments=num_segments, shift_div=self.shift_div)
                    for b in stage.children()
                ]
                return nn.Sequential(*blocks)

        elif 'blockres' in self.shift_place:
            # Deep nets (layer3 with >= 23 blocks, e.g. ResNet-101) only
            # shift every other block.
            n_round = 2 if len(list(self.layer3.children())) >= 23 else 1

            def make_block_temporal(stage, num_segments):
                """Shift the ``conv1`` of (every ``n_round``-th) child block.

                Args:
                    stage (nn.Module): Model layers to be shifted.
                    num_segments (int): Number of frame segments.

                Returns:
                    nn.Module: The shifted blocks.
                """
                blocks = list(stage.children())
                for i, b in enumerate(blocks):
                    if i % n_round == 0:
                        # Modified in place: only the residual branch's first
                        # conv is shifted; the identity path is untouched.
                        b.conv1.conv = TemporalShift(
                            b.conv1.conv,
                            num_segments=num_segments,
                            shift_div=self.shift_div)
                return nn.Sequential(*blocks)

        else:
            raise NotImplementedError(
                f'Unsupported shift_place: {self.shift_place}')

        for i, num_segments in enumerate(num_segment_list):
            layer_name = f'layer{i + 1}'
            setattr(self, layer_name,
                    make_block_temporal(getattr(self, layer_name),
                                        num_segments))

    def make_temporal_pool(self):
        """Make temporal pooling between layer1 and layer2, using a 3D max
        pooling layer."""

        class TemporalPool(nn.Module):
            """Temporal pool module.

            Wrap layer2 in ResNet50 with a 3D max pooling layer.

            Args:
                net (nn.Module): Module to make temporal pool.
                num_segments (int): Number of frame segments.
            """

            def __init__(self, net, num_segments):
                super().__init__()
                self.net = net
                self.num_segments = num_segments
                self.max_pool3d = nn.MaxPool3d(
                    kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))

            def forward(self, x):
                """Defines the computation performed at every call."""
                # [N, C, H, W]
                n, c, h, w = x.size()
                # [N // num_segments, C, num_segments, H, W]
                x = x.view(n // self.num_segments, self.num_segments, c, h,
                           w).transpose(1, 2)
                # [N // num_segments, C, ceil(num_segments / 2), H, W]
                x = self.max_pool3d(x)
                # [N // num_segments * ceil(num_segments / 2), C, H, W]
                # ``-1`` rather than ``n // 2`` so odd segment counts work.
                x = x.transpose(1, 2).contiguous().view(-1, c, h, w)
                return self.net(x)

        self.layer2 = TemporalPool(self.layer2, self.num_segments)

    def make_non_local(self):
        """Wrap resnet layer into non local wrapper."""
        # Each stage gets the segment count it actually sees (layer2-4 are
        # halved when temporal pooling is enabled).
        stage_segments = self._stage_num_segments()
        for i in range(self.num_stages):
            layer_name = f'layer{i + 1}'
            res_layer = getattr(self, layer_name)
            non_local_stage = self.non_local_stages[i]
            if isinstance(non_local_stage, int):
                # A single flag applies to every block of the stage.
                non_local_stage = (non_local_stage, ) * len(res_layer)
            if sum(non_local_stage) == 0:
                continue
            for idx, non_local in enumerate(non_local_stage):
                if non_local:
                    res_layer[idx] = NL3DWrapper(res_layer[idx],
                                                 stage_segments[i],
                                                 self.non_local_cfg)

    def _get_wrap_prefix(self):
        """Name fragments inserted by TSM wrappers (TemporalShift.net,
        NL3DWrapper.block) that torchvision checkpoints do not have."""
        return ['.net', '.block']

    def load_original_weights(self, logger):
        """Load weights from original checkpoint, which required converting
        keys.

        Keys of the 2D checkpoint ``self.pretrained`` are renamed to match
        this model: torchvision-style ``convN`` / ``bnN`` / ``downsample.0|1``
        keys become ConvModule ``.conv`` / ``.bn`` keys, and the wrapper
        levels from :meth:`_get_wrap_prefix` are re-inserted. The result is
        loaded non-strictly and the load report is logged.

        Args:
            logger: Logger used for the load report (info) and per-key
                remapping details (debug).
        """
        state_dict_torchvision = _load_checkpoint(
            self.pretrained, map_location='cpu')
        if 'state_dict' in state_dict_torchvision:
            state_dict_torchvision = state_dict_torchvision['state_dict']

        # Unwrapped module name -> real (wrapped) module name in this model.
        wrapped_layers_map = dict()
        for name, module in self.named_modules():
            # convert torchvision keys
            ori_name = name
            for wrap_prefix in self._get_wrap_prefix():
                if wrap_prefix in ori_name:
                    ori_name = ori_name.replace(wrap_prefix, '')
            wrapped_layers_map[ori_name] = name

            if not isinstance(module, ConvModule):
                continue
            if 'downsample' in ori_name:
                # layer{X}.{Y}.downsample.conv->layer{X}.{Y}.downsample.0
                tv_conv_name = ori_name + '.0'
                # layer{X}.{Y}.downsample.bn->layer{X}.{Y}.downsample.1
                tv_bn_name = ori_name + '.1'
            else:
                # layer{X}.{Y}.conv{n}.conv->layer{X}.{Y}.conv{n}
                tv_conv_name = ori_name
                # layer{X}.{Y}.conv{n}.bn->layer{X}.{Y}.bn{n}
                tv_bn_name = ori_name.replace('conv', 'bn')
            for conv_param in ['.weight', '.bias']:
                if tv_conv_name + conv_param in state_dict_torchvision:
                    state_dict_torchvision[ori_name + '.conv' + conv_param] = \
                        state_dict_torchvision.pop(tv_conv_name + conv_param)
            for bn_param in [
                    '.weight', '.bias', '.running_mean', '.running_var'
            ]:
                if tv_bn_name + bn_param in state_dict_torchvision:
                    state_dict_torchvision[ori_name + '.bn' + bn_param] = \
                        state_dict_torchvision.pop(tv_bn_name + bn_param)

        # convert wrapped keys
        for param_name in list(state_dict_torchvision.keys()):
            layer_name = '.'.join(param_name.split('.')[:-1])
            wrapped_layer = wrapped_layers_map.get(layer_name)
            # Unknown layers and unwrapped layers keep their key as-is.
            if wrapped_layer is None or wrapped_layer == layer_name:
                continue
            wrapped_name = wrapped_layer + param_name[len(layer_name):]
            logger.debug('wrapped_name %s', wrapped_name)
            state_dict_torchvision[wrapped_name] = \
                state_dict_torchvision.pop(param_name)

        msg = self.load_state_dict(state_dict_torchvision, strict=False)
        logger.info(msg)

    def init_weights(self):
        """Initiate the parameters either from existing checkpoint or from
        scratch.

        With ``pretrained2d`` and a ``pretrained`` checkpoint, 2D weights are
        converted and loaded by :meth:`load_original_weights`. Otherwise the
        checkpoint (if any) is registered as a ``Pretrained`` ``init_cfg``
        and the parent ``init_weights`` runs.
        """
        if self.pretrained2d and self.pretrained:
            logger = MMLogger.get_current_instance()
            self.load_original_weights(logger)
            return
        # No checkpoint to convert: avoid _load_checkpoint(None) and fall
        # back to the regular initialization path.
        if self.pretrained:
            self.init_cfg = dict(
                type='Pretrained', checkpoint=self.pretrained)
        super().init_weights()
|