# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmengine.logging import MMLogger
from mmengine.runner.checkpoint import (_load_checkpoint,
                                        _load_checkpoint_with_prefix)
from mmpretrain.models import MobileOne

from mmaction.registry import MODELS
from .resnet_tsm import TemporalShift


@MODELS.register_module()
class MobileOneTSM(MobileOne):
"""MobileOne backbone for TSM.
Args:
arch (str | dict): MobileOne architecture. If use string, choose
from 's0', 's1', 's2', 's3' and 's4'. If use dict, it should
have below keys:
- num_blocks (Sequence[int]): Number of blocks in each stage.
- width_factor (Sequence[float]): Width factor in each stage.
- num_conv_branches (Sequence[int]): Number of conv branches
in each stage.
- num_se_blocks (Sequence[int]): Number of SE layers in each
stage, all the SE layers are placed in the subsequent order
in each stage.
Defaults to 's0'.
num_segments (int): Number of frame segments. Defaults to 8.
is_shift (bool): Whether to make temporal shift in reset layers.
Defaults to True.
shift_div (int): Number of div for shift. Defaults to 8.
pretraind2d (bool): Whether to load pretrained 2D model.
Defaults to True.
**kwargs (keyword arguments, optional): Arguments for MobileOne.
"""

    def __init__(self,
                 arch: str,
                 num_segments: int = 8,
                 is_shift: bool = True,
                 shift_div: int = 8,
                 pretrained2d: bool = True,
                 **kwargs):
        super().__init__(arch, **kwargs)
        self.num_segments = num_segments
        self.is_shift = is_shift
        self.shift_div = shift_div
        self.pretrained2d = pretrained2d
        self.init_structure()

    def make_temporal_shift(self):
        """Make temporal shift for some layers.

        To make reparameterization work, we can only build the shift layer
        before the 'block', instead of the 'blockres'.
        """

        def make_block_temporal(stage, num_segments):
            """Make temporal shift on some blocks.

            Args:
                stage (nn.Module): Model layers to be shifted.
                num_segments (int): Number of frame segments.

            Returns:
                nn.Module: The shifted blocks.
            """
            blocks = list(stage.children())
            for i, b in enumerate(blocks):
                blocks[i] = TemporalShift(
                    b, num_segments=num_segments, shift_div=self.shift_div)
            return nn.Sequential(*blocks)

        # stage0 is a single block, so wrap it in a temporary Sequential
        # and unwrap the shifted result afterwards.
        self.stage0 = make_block_temporal(
            nn.Sequential(self.stage0), self.num_segments)[0]
        for i in range(1, 5):
            temporal_stage = make_block_temporal(
                getattr(self, f'stage{i}'), self.num_segments)
            setattr(self, f'stage{i}', temporal_stage)
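
    # Note: after make_temporal_shift, every original block is reachable as
    # `<stage>.<idx>.net` inside a TemporalShift wrapper, so parameter names
    # gain a '.net' segment (e.g. a hypothetical 'stage1.0.<param>' becomes
    # 'stage1.0.net.<param>'). load_original_weights below reverses exactly
    # this renaming when loading a plain 2D checkpoint.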

    def init_structure(self):
        """Rebuild the model structure: insert temporal shift modules if
        ``is_shift`` is enabled."""
        if self.is_shift:
            self.make_temporal_shift()

    def load_original_weights(self, logger):
        """Load a pretrained 2D checkpoint into the shifted model.

        ``TemporalShift`` wraps each original block under a ``.net``
        attribute, so the 2D checkpoint keys are rewritten to include the
        extra ``.net`` segment before loading.
        """
        assert self.init_cfg.get('type') == 'Pretrained', (
            'Please specify init_cfg to use pretrained 2d checkpoint')
        self.pretrained = self.init_cfg.get('checkpoint')
        prefix = self.init_cfg.get('prefix')
        if prefix is not None:
            original_state_dict = _load_checkpoint_with_prefix(
                prefix, self.pretrained, map_location='cpu')
        else:
            original_state_dict = _load_checkpoint(
                self.pretrained, map_location='cpu')
        if 'state_dict' in original_state_dict:
            original_state_dict = original_state_dict['state_dict']

        # Map each module name with its '.net' wrapper segment removed back
        # to the wrapped name, e.g. 'stage1.0.xxx' -> 'stage1.0.net.xxx'.
        wrapped_layers_map = dict()
        for name, module in self.named_modules():
            ori_name = name
            for wrap_prefix in ['.net']:
                if wrap_prefix in ori_name:
                    ori_name = ori_name.replace(wrap_prefix, '')
                    wrapped_layers_map[ori_name] = name

        # convert wrapped keys
        for param_name in list(original_state_dict.keys()):
            layer_name = '.'.join(param_name.split('.')[:-1])
            if layer_name in wrapped_layers_map:
                wrapped_name = param_name.replace(
                    layer_name, wrapped_layers_map[layer_name])
                original_state_dict[wrapped_name] = original_state_dict.pop(
                    param_name)

        msg = self.load_state_dict(original_state_dict, strict=True)
        logger.info(msg)

    def init_weights(self):
        """Initiate the parameters either from existing checkpoint or from
        scratch."""
        if self.pretrained2d:
            logger = MMLogger.get_current_instance()
            self.load_original_weights(logger)
        else:
            super().init_weights()

    def forward(self, x):
        """Forward call; unpack the single-element tuple result."""
        x = super().forward(x)
        if isinstance(x, tuple):
            assert len(x) == 1
            x = x[0]
        return x
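

# A minimal smoke-test sketch, assuming the module is executed in a context
# where its relative imports resolve (e.g. via
# `python -m mmaction.models.backbones.mobileone_tsm`). The input shape
# follows the (N * num_segments, C, H, W) layout that TSM uses for 2D
# backbones; the arch choice and shapes are illustrative.
if __name__ == '__main__':
    import torch

    model = MobileOneTSM(arch='s0', num_segments=8, pretrained2d=False)
    model.init_weights()
    # One clip of 8 segments folded into the batch dimension.
    clip = torch.rand(8, 3, 224, 224)
    feat = model(clip)
    print(feat.shape)  # expected: (8, C_out, 7, 7) for a 224x224 input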