File size: 6,133 Bytes
d3dbf03 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Sequence, Tuple, Union
import torch
from mmengine.model import BaseDataPreprocessor, stack_batch
from mmaction.registry import MODELS
from mmaction.utils import SampleList
@MODELS.register_module()
class ActionDataPreprocessor(BaseDataPreprocessor):
    """Data pre-processor for action recognition tasks.

    Stacks per-sample inputs into a single batch tensor and optionally
    applies BGR->RGB channel reordering, mean/std normalization and
    training-time batch blending.

    Args:
        mean (Sequence[float or int], optional): The pixel mean of channels
            of images or stacked optical flow. Defaults to None.
        std (Sequence[float or int], optional): The pixel standard deviation
            of channels of images or stacked optical flow. Defaults to None.
        to_rgb (bool): Whether to convert image from BGR to RGB.
            Defaults to False.
        to_float32 (bool): Whether to convert data to float32.
            Defaults to True.
        blending (dict, optional): Config for batch blending.
            Defaults to None.
        format_shape (str): Format shape of input data.
            Defaults to ``'NCHW'``.
    """

    def __init__(self,
                 mean: Optional[Sequence[Union[float, int]]] = None,
                 std: Optional[Sequence[Union[float, int]]] = None,
                 to_rgb: bool = False,
                 to_float32: bool = True,
                 blending: Optional[dict] = None,
                 format_shape: str = 'NCHW') -> None:
        super().__init__()
        self.to_rgb = to_rgb
        self.to_float32 = to_float32
        self.format_shape = format_shape

        if mean is None:
            # No statistics given: normalization is disabled entirely.
            self._enable_normalize = False
        else:
            assert std is not None, 'To enable the normalization in ' \
                                    'preprocessing, please specify both ' \
                                    '`mean` and `std`.'
            self._enable_normalize = True
            # Pick a broadcastable shape for the channel statistics.
            # MIX2d3d stores the 3D (NCTHW) shape; 2D batches re-view it
            # on the fly in ``preprocess``.
            if self.format_shape == 'NCHW':
                stat_shape = (-1, 1, 1)
            elif self.format_shape in ('NCTHW', 'MIX2d3d'):
                stat_shape = (-1, 1, 1, 1)
            else:
                raise ValueError(f'Invalid format shape: {format_shape}')
            # Registered as non-persistent buffers: they follow the module
            # across devices but are not written to the state dict.
            for buf_name, buf_value in (('mean', mean), ('std', std)):
                self.register_buffer(
                    buf_name,
                    torch.tensor(buf_value,
                                 dtype=torch.float32).view(stat_shape),
                    False)

        # Optional batch augmentation (e.g. mixup/cutmix), built lazily
        # from config and applied only during training.
        self.blending = MODELS.build(
            blending) if blending is not None else None

    def forward(self,
                data: Union[dict, Tuple[dict]],
                training: bool = False) -> Union[dict, Tuple[dict]]:
        """Perform normalization, padding, bgr2rgb conversion and batch
        augmentation based on ``BaseDataPreprocessor``.

        Args:
            data (dict or Tuple[dict]): data sampled from dataloader.
            training (bool): Whether to enable training time augmentation.

        Returns:
            dict or Tuple[dict]: Data in the same format as the model input.
        """
        data = self.cast_data(data)
        if isinstance(data, dict):
            return self.forward_onesample(data, training=training)
        if isinstance(data, (tuple, list)):
            # Multiple views per iteration: process each independently and
            # always hand back a tuple.
            return tuple(
                self.forward_onesample(sample, training=training)
                for sample in data)
        raise TypeError(f'Unsupported data type: {type(data)}!')

    def forward_onesample(self, data, training: bool = False) -> dict:
        """Perform normalization, padding, bgr2rgb conversion and batch
        augmentation on one data sample.

        Args:
            data (dict): data sampled from dataloader.
            training (bool): Whether to enable training time augmentation.

        Returns:
            dict: Data in the same format as the model input.
        """
        processed_inputs, processed_samples = self.preprocess(
            data['inputs'], data['data_samples'], training)
        data['inputs'] = processed_inputs
        data['data_samples'] = processed_samples
        return data

    def preprocess(self,
                   inputs: List[torch.Tensor],
                   data_samples: SampleList,
                   training: bool = False) -> Tuple:
        """Stack, convert, normalize and (optionally) blend one batch.

        Args:
            inputs (List[torch.Tensor]): per-sample input tensors.
            data_samples (SampleList): corresponding data samples.
            training (bool): Whether to enable training time augmentation.

        Returns:
            Tuple: the batched input tensor and the (possibly blended)
            data samples.
        """
        # --- Pad and stack --
        batch = stack_batch(inputs)

        # Resolve the effective layout. MIX2d3d decides per batch from the
        # tensor rank; 4-D means NCHW and needs the stored (-1, 1, 1, 1)
        # statistics re-viewed as (-1, 1, 1).
        if self.format_shape == 'MIX2d3d':
            if batch.ndim == 4:
                layout, stat_view = 'NCHW', (-1, 1, 1)
            else:
                layout, stat_view = 'NCTHW', None
        else:
            layout, stat_view = self.format_shape, None

        # ------ To RGB ------
        if self.to_rgb:
            # Reverse the channel axis (assumes 3-channel BGR input).
            if layout == 'NCHW':
                batch = batch[..., [2, 1, 0], :, :]
            elif layout == 'NCTHW':
                batch = batch[..., [2, 1, 0], :, :, :]
            else:
                raise ValueError(f'Invalid format shape: {layout}')

        # -- Normalization ---
        if self._enable_normalize:
            if stat_view is None:
                batch = (batch - self.mean) / self.std
            else:
                batch = (batch - self.mean.view(stat_view)) / \
                    self.std.view(stat_view)
        elif self.to_float32:
            # Normalization already yields float; only cast explicitly
            # when it is disabled.
            batch = batch.to(torch.float32)

        # ----- Blending -----
        if training and self.blending is not None:
            batch, data_samples = self.blending(batch, data_samples)

        return batch, data_samples
|