File size: 7,496 Bytes
d670799 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmaction.registry import MODELS
from mmaction.utils import SampleList
from .base import BaseRecognizer
@MODELS.register_module()
class Recognizer2D(BaseRecognizer):
    """2D recognizer model framework."""

    def extract_feat(self,
                     inputs: torch.Tensor,
                     stage: str = 'neck',
                     data_samples: SampleList = None,
                     test_mode: bool = False) -> tuple:
        """Extract features of different stages.

        Args:
            inputs (Tensor): The input data, laid out as
                ``[N, num_crops * num_segs, C, H, W]``.
            stage (str): Which stage to output the feature: ``backbone``,
                ``neck`` or ``head``. Defaults to ``neck``.
            data_samples (List[:obj:`ActionDataSample`]): Action data
                samples, which are only needed in training. Defaults to None.
            test_mode (bool): Whether in test mode. Defaults to False.

        Returns:
            tuple: ``(feats, loss_predict_kwargs)``. ``feats`` is the output
            of the requested stage (a Tensor or a tuple of Tensors; raw
            logits for ``head`` when a cls head exists).
            ``loss_predict_kwargs`` is a dict recording the kwargs for the
            downstream pipeline. These keys are usually included:
            ``num_segs``, ``fcn_test`` (test mode only), ``loss_aux``.
        """
        # Record the kwargs required by `loss` and `predict`.
        loss_predict_kwargs = {}
        num_segs = inputs.shape[1]
        loss_predict_kwargs['num_segs'] = num_segs

        # [N, num_crops * num_segs, C, H, W] ->
        # [N * num_crops * num_segs, C, H, W]
        # `num_crops` is calculated by:
        # 1) `twice_sample` in `SampleFrames`
        # 2) `num_sample_positions` in `DenseSampleFrames`
        # 3) `ThreeCrop/TenCrop` in `test_pipeline`
        # 4) `num_clips` in `SampleFrames` or its subclass if `clip_len != 1`
        # `reshape` matches `view` on contiguous inputs but, unlike `view`,
        # also accepts non-contiguous ones.
        inputs = inputs.reshape((-1, ) + inputs.shape[2:])

        # Check settings of `fcn_test` and `max_testing_views`.
        fcn_test = False
        max_testing_views = None
        if test_mode:
            if self.test_cfg is not None and self.test_cfg.get(
                    'fcn_test', False):
                fcn_test = True
                # Only read `backbone.num_segments` when the config does not
                # give `num_segs`: a `dict.get` default is evaluated eagerly
                # and would raise for backbones lacking that attribute.
                if 'num_segs' in self.test_cfg:
                    num_segs = self.test_cfg['num_segs']
                else:
                    num_segs = self.backbone.num_segments
            loss_predict_kwargs['fcn_test'] = fcn_test
            if self.test_cfg is not None:
                max_testing_views = self.test_cfg.get(
                    'max_testing_views', False)

        if test_mode and max_testing_views:
            # Inference with batch size of `max_testing_views` if set.
            x = self._extract_backbone_feat_in_chunks(inputs,
                                                      max_testing_views)
        else:
            x = self._extract_backbone_feat(inputs)

        # Return features extracted through backbone.
        if stage == 'backbone':
            return x, loss_predict_kwargs

        loss_aux = {}
        if self.with_neck:
            # x is a tuple with multiple feature maps.
            # [N * num_crops * num_segs, C', H', W'] ->
            # [N * num_crops, C', num_segs, H', W']
            x = [
                each.reshape((-1, num_segs) +
                             each.shape[1:]).transpose(1, 2).contiguous()
                for each in x
            ]
            x, loss_aux = self.neck(x, data_samples=data_samples)
            if not fcn_test:
                # Drop the (now size-1) temporal axis.
                x = x.squeeze(2)
                loss_predict_kwargs['num_segs'] = 1
        elif fcn_test:
            # full convolution (fcn) testing when no neck
            # [N * num_crops * num_segs, C', H', W'] ->
            # [N * num_crops, C', num_segs, H', W']
            x = x.reshape((-1, num_segs) +
                          x.shape[1:]).transpose(1, 2).contiguous()

        loss_predict_kwargs['loss_aux'] = loss_aux

        # Return features extracted through neck.
        if stage == 'neck':
            return x, loss_predict_kwargs

        # Return raw logits through head.
        if self.with_cls_head and stage == 'head':
            # [N * num_crops, num_classes]
            x = self.cls_head(x, **loss_predict_kwargs)
        # Always return a tuple, as the signature promises.
        return x, loss_predict_kwargs

    def _extract_backbone_feat(self, batch_imgs: torch.Tensor):
        """Run the backbone on one batch of frames.

        ``mmcls``/``mmpretrain`` backbones must yield a single feature map,
        which is unwrapped from its 1-tuple. ``torchvision``/``timm``
        outputs are globally average-pooled and reshaped to
        ``B x C x 1 x 1``.
        """
        if (hasattr(self.backbone, 'features')
                and self.backbone_from == 'torchvision'):
            x = self.backbone.features(batch_imgs)
        elif self.backbone_from == 'timm':
            x = self.backbone.forward_features(batch_imgs)
        elif self.backbone_from in ['mmcls', 'mmpretrain']:
            x = self.backbone(batch_imgs)
            if isinstance(x, tuple):
                assert len(x) == 1
                x = x[0]
        else:
            x = self.backbone(batch_imgs)

        if self.backbone_from in ['torchvision', 'timm']:
            if not self.feature_shape:
                # Infer (and cache) the layout from the first output's rank.
                if x.dim() == 3:
                    # Transformer-based feature shape: B x L x C.
                    self.feature_shape = 'NLC'
                elif x.dim() == 4:
                    # Resnet-based feature shape: B x C x Hs x Ws.
                    self.feature_shape = 'NCHW'
            if self.feature_shape == 'NHWC':
                # B x H x W x C -> B x C x 1 x 1
                x = nn.functional.adaptive_avg_pool2d(
                    x.permute(0, 3, 1, 2), 1)
            elif self.feature_shape == 'NCHW':
                x = nn.functional.adaptive_avg_pool2d(x, 1)  # B x C x 1 x 1
            elif self.feature_shape == 'NLC':
                # B x L x C -> B x C x 1
                x = nn.functional.adaptive_avg_pool1d(x.transpose(1, 2), 1)
            x = x.reshape((x.shape[0], -1))  # B x C
            x = x.reshape(x.shape + (1, 1))  # B x C x 1 x 1
        return x

    def _extract_backbone_feat_in_chunks(self, inputs: torch.Tensor,
                                         max_testing_views: int):
        """Run the backbone over ``inputs`` in chunks along dim 0.

        Each chunk holds at most ``max_testing_views`` views. The per-chunk
        backbone outputs are concatenated back along dim 0. The neck is NOT
        applied here; the caller applies it once on the full result.
        """
        assert isinstance(max_testing_views, int)
        assert max_testing_views > 0, \
            f'max_testing_views must be positive, got {max_testing_views}'
        # backbone specify num_segments
        num_segments = getattr(self.backbone, 'num_segments', None)
        if num_segments is not None:
            assert max_testing_views % num_segments == 0, \
                'make sure that max_testing_views is a multiple of ' \
                f'num_segments, but got {max_testing_views} and ' \
                f'{num_segments}'

        step = max_testing_views
        feats = [
            self._extract_backbone_feat(inputs[start:start + step])
            for start in range(0, inputs.shape[0], step)
        ]
        if isinstance(feats[0], (tuple, list)):
            return self._concat_chunk_feats(feats)
        return torch.cat(feats)

    @classmethod
    def _concat_chunk_feats(cls, feats: list) -> tuple:
        """Concatenate nested per-chunk features leaf-wise along dim 0.

        ``feats`` is a list of items that share the same (possibly nested)
        tuple/list structure of tensors. The result is a tuple with that
        structure in which every tensor leaf is the ``torch.cat`` of the
        matching leaves across chunks.
        """
        out_feats = []
        for e_idx, elem in enumerate(feats[0]):
            batch_elem = [feat[e_idx] for feat in feats]
            if isinstance(elem, torch.Tensor):
                out_feats.append(torch.cat(batch_elem))
            else:
                out_feats.append(cls._concat_chunk_feats(batch_elem))
        return tuple(out_feats)
|