File size: 22,104 Bytes
d670799 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 |
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.logging import MMLogger, print_log
from mmengine.model import BaseModule
from mmengine.model.weight_init import kaiming_init
from mmengine.runner.checkpoint import _load_checkpoint, load_checkpoint
from mmaction.registry import MODELS
from .resnet3d import ResNet3d
class DeConvModule(BaseModule):
    """A deconv module that bundles deconv/norm/activation layers.

    The transposed convolution is always asked to produce an output whose
    temporal/spatial size is exactly ``stride * input_size`` along each axis.

    Args:
        in_channels (int): Number of channels in the input feature map.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int | tuple[int]): Size of the convolving kernel.
        stride (int | tuple[int]): Stride of the convolution. An int is
            expanded to the same stride on all three axes.
            Defaults to ``(1, 1, 1)``.
        padding (int | tuple[int]): Zero-padding added to both sides of
            the input. Defaults to 0.
        bias (bool): Whether to add a learnable bias to the output.
            Defaults to False.
        with_bn (bool): Whether to add a BN layer. Defaults to True.
        with_relu (bool): Whether to add a ReLU layer. Defaults to True.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int]],
                 stride: Union[int, Tuple[int]] = (1, 1, 1),
                 padding: Union[int, Tuple[int]] = 0,
                 bias: bool = False,
                 with_bn: bool = True,
                 with_relu: bool = True) -> None:
        super().__init__()
        # ``forward`` reads the stride per axis, so an int stride has to be
        # expanded here; otherwise ``self.stride[0]`` raises a TypeError.
        if isinstance(stride, int):
            stride = (stride, ) * 3
        else:
            stride = tuple(stride)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.bias = bias
        self.with_bn = with_bn
        self.with_relu = with_relu

        self.conv = nn.ConvTranspose3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=bias)
        # Both layers are created regardless of the flags so that the set of
        # sub-modules (and therefore the state_dict keys) stays the same;
        # ``with_bn`` / ``with_relu`` only gate their use in ``forward``.
        self.bn = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): A 5-d input tensor, unpacked as
                ``(N, C, T, H, W)``.

        Returns:
            torch.Tensor: The upsampled feature with shape
            ``(N, out_channels, stride[0] * T, stride[1] * H,
            stride[2] * W)``.
        """
        # x should be a 5-d tensor
        assert len(x.shape) == 5
        N, C, T, H, W = x.shape
        # Request the exact output size so the result is an integer multiple
        # of the input along every axis.
        out_shape = (N, self.out_channels, self.stride[0] * T,
                     self.stride[1] * H, self.stride[2] * W)
        x = self.conv(x, output_size=out_shape)
        if self.with_bn:
            x = self.bn(x)
        if self.with_relu:
            x = self.relu(x)
        return x
class ResNet3dPathway(ResNet3d):
    """A pathway of Slowfast based on ResNet3d.

    Args:
        lateral (bool): Determines whether to enable the lateral connection
            from another pathway. Defaults to False.
        lateral_inv (bool): Whether to use deconv to upscale the time
            dimension of features from another pathway. Defaults to False.
        lateral_norm (bool): Determines whether to enable the lateral norm
            in lateral layers. Defaults to False.
        speed_ratio (int): Speed ratio indicating the ratio between time
            dimension of the fast and slow pathway, corresponding to the
            ``alpha`` in the paper. Defaults to 8.
        channel_ratio (int): Reduce the channel number of fast pathway
            by ``channel_ratio``, corresponding to ``beta`` in the paper.
            Defaults to 8.
        fusion_kernel (int): The kernel size of lateral fusion.
            Defaults to 5.
        lateral_infl (int): The ratio of the inflated channels.
            Defaults to 2.
        lateral_activate (list[int]): Flags for activating the lateral
            connection. Index 0 controls ``conv1_lateral`` and index ``i``
            controls ``layer{i}_lateral``. Defaults to ``[1, 1, 1, 1]``.
        **kwargs: Remaining keyword arguments, forwarded to ``ResNet3d``.
    """

    def __init__(self,
                 lateral: bool = False,
                 lateral_inv: bool = False,
                 lateral_norm: bool = False,
                 speed_ratio: int = 8,
                 channel_ratio: int = 8,
                 fusion_kernel: int = 5,
                 lateral_infl: int = 2,
                 lateral_activate: List[int] = [1, 1, 1, 1],
                 **kwargs) -> None:
        self.lateral = lateral
        self.lateral_inv = lateral_inv
        self.lateral_norm = lateral_norm
        self.speed_ratio = speed_ratio
        self.channel_ratio = channel_ratio
        self.fusion_kernel = fusion_kernel
        self.lateral_infl = lateral_infl
        # Store a copy so instances never alias the shared default list (or
        # a list owned by the caller).
        self.lateral_activate = list(lateral_activate)
        # ``self.lateral_inplanes`` is computed before the parent constructor
        # runs. NOTE(review): presumably ResNet3d reads it while building its
        # layers -- confirm against resnet3d.py.
        self._calculate_lateral_inplanes(kwargs)
        super().__init__(**kwargs)

        self.inplanes = self.base_channels
        if self.lateral and self.lateral_activate[0] == 1:
            self.conv1_lateral = self._make_lateral_layer()

        # Names of the per-stage lateral layers that were actually created,
        # in stage order. Inactive stages are simply absent from this list.
        self.lateral_connections = []
        for i in range(len(self.stage_blocks)):
            planes = self.base_channels * 2**i
            self.inplanes = planes * self.block.expansion

            if lateral and i != self.num_stages - 1 \
                    and self.lateral_activate[i + 1]:
                # no lateral connection needed in final stage
                lateral_name = f'layer{(i + 1)}_lateral'
                setattr(self, lateral_name, self._make_lateral_layer())
                self.lateral_connections.append(lateral_name)

    def _make_lateral_layer(self) -> nn.Module:
        """Build one lateral fusion layer sized from ``self.inplanes``.

        Returns:
            nn.Module: A ``DeConvModule`` when ``lateral_inv`` is set,
            otherwise a ``ConvModule``. Both use a ``(fusion_kernel, 1, 1)``
            kernel and a temporal stride of ``speed_ratio``.
        """
        kernel_size = (self.fusion_kernel, 1, 1)
        stride = (self.speed_ratio, 1, 1)
        padding = ((self.fusion_kernel - 1) // 2, 0, 0)
        if self.lateral_inv:
            return DeConvModule(
                self.inplanes * self.channel_ratio,
                self.inplanes * self.channel_ratio // self.lateral_infl,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=False,
                with_bn=True,
                with_relu=True)
        # Norm and activation are only attached when ``lateral_norm`` is on.
        return ConvModule(
            self.inplanes // self.channel_ratio,
            self.inplanes * self.lateral_infl // self.channel_ratio,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg if self.lateral_norm else None,
            act_cfg=self.act_cfg if self.lateral_norm else None)

    def _calculate_lateral_inplanes(self, kwargs):
        """Calculate inplanes for lateral connection.

        Stores one entry per stage in ``self.lateral_inplanes``; the entry is
        0 when the pathway has no lateral connection or the stage is
        switched off in ``lateral_activate``.

        Args:
            kwargs (dict): The keyword arguments destined for ``ResNet3d``.
                Only ``depth`` (default 50), ``base_channels`` (default 64)
                and ``num_stages`` (default 4) are read.
        """
        depth = kwargs.get('depth', 50)
        expansion = 1 if depth < 50 else 4
        base_channels = kwargs.get('base_channels', 64)
        lateral_inplanes = []
        for i in range(kwargs.get('num_stages', 4)):
            # ``(i > 0)`` is used as a 0/1 exponent: the extra factor only
            # applies from the second stage on.
            if expansion % 2 == 0:
                planes = base_channels * (2 ** i) * \
                    ((expansion // 2) ** (i > 0))
            else:
                planes = base_channels * (2**i) // (2**(i > 0))
            if self.lateral and self.lateral_activate[i]:
                if self.lateral_inv:
                    lateral_inplane = planes * \
                        self.channel_ratio // self.lateral_infl
                else:
                    lateral_inplane = planes * \
                        self.lateral_infl // self.channel_ratio
            else:
                lateral_inplane = 0
            lateral_inplanes.append(lateral_inplane)
        self.lateral_inplanes = lateral_inplanes

    def inflate_weights(self, logger: MMLogger) -> None:
        """Inflate the resnet2d parameters to resnet3d pathway.

        The differences between resnet3d and resnet2d mainly lie in an extra
        axis of conv kernel. To utilize the pretrained parameters in 2d model,
        the weight of conv2d models should be inflated to fit in the shapes of
        the 3d counterpart. For pathway the ``lateral_connection`` part should
        not be inflated from 2d weights.

        Args:
            logger (MMLogger): The logger used to print
                debugging information.
        """
        state_dict_r2d = _load_checkpoint(self.pretrained, map_location='cpu')
        if 'state_dict' in state_dict_r2d:
            state_dict_r2d = state_dict_r2d['state_dict']

        inflated_param_names = []
        for name, module in self.named_modules():
            # Lateral layers have no 2d counterpart.
            if 'lateral' in name:
                continue
            if isinstance(module, ConvModule):
                # we use a ConvModule to wrap conv+bn+relu layers, thus the
                # name mapping is needed
                if 'downsample' in name:
                    # layer{X}.{Y}.downsample.conv->layer{X}.{Y}.downsample.0
                    original_conv_name = name + '.0'
                    # layer{X}.{Y}.downsample.bn->layer{X}.{Y}.downsample.1
                    original_bn_name = name + '.1'
                else:
                    # layer{X}.{Y}.conv{n}.conv->layer{X}.{Y}.conv{n}
                    original_conv_name = name
                    # layer{X}.{Y}.conv{n}.bn->layer{X}.{Y}.bn{n}
                    original_bn_name = name.replace('conv', 'bn')
                if original_conv_name + '.weight' not in state_dict_r2d:
                    logger.warning(f'Module not exist in the state_dict_r2d'
                                   f': {original_conv_name}')
                else:
                    self._inflate_conv_params(module.conv, state_dict_r2d,
                                              original_conv_name,
                                              inflated_param_names)
                if original_bn_name + '.weight' not in state_dict_r2d:
                    logger.warning(f'Module not exist in the state_dict_r2d'
                                   f': {original_bn_name}')
                else:
                    self._inflate_bn_params(module.bn, state_dict_r2d,
                                            original_bn_name,
                                            inflated_param_names)

        # check if any parameters in the 2d checkpoint are not loaded
        remaining_names = set(
            state_dict_r2d.keys()) - set(inflated_param_names)
        if remaining_names:
            logger.info(f'These parameters in the 2d checkpoint are not loaded'
                        f': {remaining_names}')

    def _inflate_conv_params(self, conv3d: nn.Module,
                             state_dict_2d: OrderedDict, module_name_2d: str,
                             inflated_param_names: List[str]) -> None:
        """Inflate a conv module from 2d to 3d.

        The differences of conv modules between 2d and 3d in Pathway
        mainly lie in the inplanes due to lateral connections. To fit the
        shapes of the lateral connection counterpart, it will expand
        parameters by concatenating conv2d parameters and extra zero paddings.

        Args:
            conv3d (nn.Module): The destination conv3d module.
            state_dict_2d (OrderedDict): The state dict of pretrained 2d model.
            module_name_2d (str): The name of corresponding conv module in the
                2d model.
            inflated_param_names (list[str]): List of parameters that have been
                inflated. Names of the parameters loaded here are appended.
        """
        weight_2d_name = module_name_2d + '.weight'
        conv2d_weight = state_dict_2d[weight_2d_name]
        old_shape = conv2d_weight.shape
        new_shape = conv3d.weight.data.shape
        kernel_t = new_shape[2]

        if new_shape[1] != old_shape[1]:
            if new_shape[1] < old_shape[1]:
                # The 3d layer has fewer input channels than the 2d one, so
                # the weight cannot be padded to fit; skip this module.
                warnings.warn(f'The parameter of {module_name_2d} is not '
                              'loaded due to incompatible shapes.')
                return
            # Inplanes may be different due to lateral connections
            new_channels = new_shape[1] - old_shape[1]
            pad_shape = old_shape
            pad_shape = pad_shape[:1] + (new_channels, ) + pad_shape[2:]
            # Expand parameters by concat extra channels; the padding shares
            # the dtype and device of the 2d weight.
            conv2d_weight = torch.cat(
                (conv2d_weight, conv2d_weight.new_zeros(pad_shape)), dim=1)

        # Repeat the 2d kernel along the new temporal axis and rescale by the
        # temporal kernel size.
        new_weight = conv2d_weight.data.unsqueeze(2).expand_as(
            conv3d.weight) / kernel_t
        conv3d.weight.data.copy_(new_weight)
        inflated_param_names.append(weight_2d_name)

        if conv3d.bias is not None:
            bias_2d_name = module_name_2d + '.bias'
            conv3d.bias.data.copy_(state_dict_2d[bias_2d_name])
            inflated_param_names.append(bias_2d_name)

    def _freeze_stages(self) -> None:
        """Prevent all the parameters from being optimized before
        `self.frozen_stages`."""
        if self.frozen_stages >= 0:
            self.conv1.eval()
            for param in self.conv1.parameters():
                param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, f'layer{i}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

            # Look the lateral layer up by name instead of by list position:
            # ``lateral_connections`` omits stages disabled through
            # ``lateral_activate`` (and always omits the final stage), so a
            # positional index can point at the wrong stage or overflow.
            lateral_name = f'layer{i}_lateral'
            if self.lateral and lateral_name in self.lateral_connections:
                conv_lateral = getattr(self, lateral_name)
                conv_lateral.eval()
                for param in conv_lateral.parameters():
                    param.requires_grad = False

    def init_weights(self, pretrained: Optional[str] = None) -> None:
        """Initiate the parameters either from existing checkpoint or from
        scratch.

        Args:
            pretrained (str, optional): If given (and truthy), replaces
                ``self.pretrained`` before the parent initialisation runs.
                Defaults to None.
        """
        if pretrained:
            self.pretrained = pretrained

        # Override the init_weights of i3d
        super().init_weights()
        # Re-initialise the conv layers inside the per-stage lateral
        # connections listed in ``lateral_connections``.
        for module_name in self.lateral_connections:
            layer = getattr(self, module_name)
            for m in layer.modules():
                if isinstance(m, (nn.Conv3d, nn.Conv2d)):
                    kaiming_init(m)
# Maps the ``type`` string of a pathway config to the class that
# ``build_pathway`` instantiates.
pathway_cfg = dict(
    resnet3d=ResNet3dPathway,
    # TODO: BNInceptionPathway
)
def build_pathway(cfg: Dict, *args, **kwargs) -> nn.Module:
    """Build pathway.

    Args:
        cfg (dict): cfg should contain:
            - type (str): identify backbone type.
            Every other key is passed to the pathway class as a keyword
            argument. ``cfg`` itself is left unmodified.
        *args: Positional arguments forwarded to the pathway class.
        **kwargs: Extra keyword arguments forwarded to the pathway class.

    Returns:
        nn.Module: Created pathway.

    Raises:
        TypeError: If ``cfg`` is not a dict or lacks the key ``type``.
        KeyError: If ``cfg['type']`` is not a key of ``pathway_cfg``.
    """
    has_type_key = isinstance(cfg, dict) and 'type' in cfg
    if not has_type_key:
        raise TypeError('cfg must be a dict containing the key "type"')

    # Pop ``type`` from a copy so the caller's config stays intact.
    init_cfg = cfg.copy()
    pathway_type = init_cfg.pop('type')

    try:
        pathway_cls = pathway_cfg[pathway_type]
    except KeyError:
        raise KeyError(f'Unrecognized pathway type {pathway_type}') from None

    return pathway_cls(*args, **kwargs, **init_cfg)
@MODELS.register_module()
class ResNet3dSlowFast(BaseModule):
    """Slowfast backbone.

    This module is proposed in `SlowFast Networks for Video Recognition
    <https://arxiv.org/abs/1812.03982>`_

    Args:
        pretrained (str): The file path to a pretrained model.
        resample_rate (int): A large temporal stride ``resample_rate``
            on input frames. The actual resample rate is calculated by
            multiplying the ``interval`` in ``SampleFrames`` in the
            pipeline with ``resample_rate``, equivalent to the :math:`\\tau`
            in the paper, i.e. it processes only one out of
            ``resample_rate * interval`` frames. Defaults to 8.
        speed_ratio (int): Speed ratio indicating the ratio between time
            dimension of the fast and slow pathway, corresponding to the
            :math:`\\alpha` in the paper. Defaults to 8.
        channel_ratio (int): Reduce the channel number of fast pathway
            by ``channel_ratio``, corresponding to :math:`\\beta` in the paper.
            Defaults to 8.
        slow_pathway (dict): Configuration of slow branch. Defaults to
            ``dict(type='resnet3d', depth=50, pretrained=None, lateral=True,
            conv1_kernel=(1, 7, 7), conv1_stride_t=1, pool1_stride_t=1,
            inflate=(0, 0, 1, 1))``.
        fast_pathway (dict): Configuration of fast branch. Defaults to
            ``dict(type='resnet3d', depth=50, pretrained=None, lateral=False,
            base_channels=8, conv1_kernel=(5, 7, 7), conv1_stride_t=1,
            pool1_stride_t=1)``.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 resample_rate: int = 8,
                 speed_ratio: int = 8,
                 channel_ratio: int = 8,
                 slow_pathway: Dict = dict(
                     type='resnet3d',
                     depth=50,
                     pretrained=None,
                     lateral=True,
                     conv1_kernel=(1, 7, 7),
                     conv1_stride_t=1,
                     pool1_stride_t=1,
                     inflate=(0, 0, 1, 1)),
                 fast_pathway: Dict = dict(
                     type='resnet3d',
                     depth=50,
                     pretrained=None,
                     lateral=False,
                     base_channels=8,
                     conv1_kernel=(5, 7, 7),
                     conv1_stride_t=1,
                     pool1_stride_t=1),
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.pretrained = pretrained
        self.resample_rate = resample_rate
        self.speed_ratio = speed_ratio
        self.channel_ratio = channel_ratio

        # Copy before injecting the ratios: writing into the argument would
        # modify the caller's config, or the default dict shared by every
        # instance.
        slow_pathway = slow_pathway.copy()
        # A config without ``lateral`` is treated as ``lateral=False``, the
        # same default the pathway class itself uses.
        if slow_pathway.get('lateral', False):
            slow_pathway['speed_ratio'] = speed_ratio
            slow_pathway['channel_ratio'] = channel_ratio

        self.slow_path = build_pathway(slow_pathway)
        self.fast_path = build_pathway(fast_pathway)

    def init_weights(self, pretrained: Optional[str] = None) -> None:
        """Initiate the parameters either from existing checkpoint or from
        scratch.

        Args:
            pretrained (str, optional): If given (and truthy), replaces
                ``self.pretrained``. Defaults to None.

        Raises:
            TypeError: If ``self.pretrained`` is neither a str nor None.
        """
        if pretrained:
            self.pretrained = pretrained

        if isinstance(self.pretrained, str):
            logger = MMLogger.get_current_instance()
            msg = f'load model from: {self.pretrained}'
            print_log(msg, logger=logger)
            # Directly load 3D model.
            load_checkpoint(self, self.pretrained, strict=True, logger=logger)
        elif self.pretrained is None:
            # Init two branch separately.
            self.fast_path.init_weights()
            self.slow_path.init_weights()
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x: torch.Tensor) -> tuple:
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            tuple[torch.Tensor]: The feature of the input samples
            extracted by the backbone, as ``(x_slow, x_fast)``.
        """
        # Slow branch: keep one out of ``resample_rate`` frames.
        x_slow = nn.functional.interpolate(
            x,
            mode='nearest',
            scale_factor=(1.0 / self.resample_rate, 1.0, 1.0))
        x_slow = self.slow_path.conv1(x_slow)
        x_slow = self.slow_path.maxpool(x_slow)

        # Fast branch: keep one out of ``resample_rate // speed_ratio``
        # frames. NOTE: the integer division yields 0 (and this line raises
        # ZeroDivisionError) when ``speed_ratio > resample_rate``.
        x_fast = nn.functional.interpolate(
            x,
            mode='nearest',
            scale_factor=(1.0 / (self.resample_rate // self.speed_ratio), 1.0,
                          1.0))
        x_fast = self.fast_path.conv1(x_fast)
        x_fast = self.fast_path.maxpool(x_fast)

        # ``conv1_lateral`` only exists when the pathway enabled it through
        # ``lateral_activate[0]``.
        if self.slow_path.lateral and hasattr(self.slow_path,
                                              'conv1_lateral'):
            x_fast_lateral = self.slow_path.conv1_lateral(x_fast)
            x_slow = torch.cat((x_slow, x_fast_lateral), dim=1)

        for i, layer_name in enumerate(self.slow_path.res_layers):
            res_layer = getattr(self.slow_path, layer_name)
            x_slow = res_layer(x_slow)
            res_layer_fast = getattr(self.fast_path, layer_name)
            x_fast = res_layer_fast(x_fast)

            # Fuse by layer name rather than list position: the pathway only
            # registers ``layer{k}_lateral`` for active, non-final stages, so
            # ``lateral_connections[i]`` is not guaranteed to belong to
            # stage ``i``. No fusion happens in the final stage.
            lateral_name = f'layer{i + 1}_lateral'
            if (self.slow_path.lateral
                    and lateral_name in self.slow_path.lateral_connections):
                conv_lateral = getattr(self.slow_path, lateral_name)
                x_fast_lateral = conv_lateral(x_fast)
                x_slow = torch.cat((x_slow, x_fast_lateral), dim=1)

        out = (x_slow, x_fast)

        return out
|