File size: 15,757 Bytes
d670799 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence, Tuple
import numpy as np
import torch
from mmcv.transforms import BaseTransform, to_tensor
from mmengine.structures import InstanceData
from mmaction.registry import TRANSFORMS
from mmaction.structures import ActionDataSample
@TRANSFORMS.register_module()
class PackActionInputs(BaseTransform):
    """Pack the inputs data.

    Args:
        collect_keys (tuple[str], optional): The keys to be collected
            to ``packed_results['inputs']``. When given, ``inputs`` is a
            dict mapping each key to its tensor. Defaults to ``None``, in
            which case the first present key among ``imgs``,
            ``heatmap_imgs``, ``keypoint``, ``audios`` and ``text`` (in that
            priority order) is packed as ``inputs``.
        meta_keys (Sequence[str]): The meta keys to be saved in the
            `metainfo` of the `data_sample`.
            Defaults to ``('img_shape', 'img_key', 'video_id', 'timestamp')``.
        algorithm_keys (Sequence[str]): The keys of custom elements to be used
            in the algorithm. Defaults to an empty tuple.
    """

    # Result-dict key -> field name inside ``data_sample.gt_instances``.
    mapping_table = {
        'gt_bboxes': 'bboxes',
        'gt_labels': 'labels',
    }

    # Fallback input keys, checked in priority order when ``collect_keys``
    # is None; the first one present in the result dict wins.
    _default_input_keys = ('imgs', 'heatmap_imgs', 'keypoint', 'audios',
                           'text')

    def __init__(
        self,
        collect_keys: Optional[Tuple[str]] = None,
        meta_keys: Sequence[str] = ('img_shape', 'img_key', 'video_id',
                                    'timestamp'),
        algorithm_keys: Sequence[str] = (),
    ) -> None:
        self.collect_keys = collect_keys
        self.meta_keys = meta_keys
        self.algorithm_keys = algorithm_keys

    def transform(self, results: Dict) -> Dict:
        """The transform function of :class:`PackActionInputs`.

        Args:
            results (dict): The result dict.

        Returns:
            dict: A dict with ``'inputs'`` (a tensor, or a dict of tensors
            when ``collect_keys`` is set) and ``'data_samples'`` (an
            :obj:`ActionDataSample`).

        Raises:
            ValueError: If ``collect_keys`` is None and none of the default
                input keys is present in ``results``.
        """
        packed_results = dict()
        if self.collect_keys is not None:
            packed_results['inputs'] = {
                key: to_tensor(results[key])
                for key in self.collect_keys
            }
        else:
            for key in self._default_input_keys:
                if key in results:
                    packed_results['inputs'] = to_tensor(results[key])
                    break
            else:
                raise ValueError(
                    'Cannot get `imgs`, `keypoint`, `heatmap_imgs`, '
                    '`audios` or `text` in the input dict of '
                    '`PackActionInputs`.')

        data_sample = ActionDataSample()

        if 'gt_bboxes' in results:
            # Every key in ``mapping_table`` is required once gt_bboxes
            # exist; a missing one raises KeyError.
            instance_data = InstanceData()
            for src_key, dst_key in self.mapping_table.items():
                instance_data[dst_key] = to_tensor(results[src_key])
            data_sample.gt_instances = instance_data

        # Proposals are packed independently of ground-truth boxes so that
        # annotation-free results still carry them.
        if 'proposals' in results:
            data_sample.proposals = InstanceData(
                bboxes=to_tensor(results['proposals']))

        if 'label' in results:
            data_sample.set_gt_label(results['label'])

        # Set custom algorithm keys
        for key in self.algorithm_keys:
            if key in results:
                data_sample.set_field(results[key], key)

        # Set meta keys
        img_meta = {k: results[k] for k in self.meta_keys if k in results}
        data_sample.set_metainfo(img_meta)
        packed_results['data_samples'] = data_sample
        return packed_results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(collect_keys={self.collect_keys}, '
        repr_str += f'meta_keys={self.meta_keys}, '
        repr_str += f'algorithm_keys={self.algorithm_keys})'
        return repr_str
@TRANSFORMS.register_module()
class PackLocalizationInputs(BaseTransform):
    """Pack the inputs data for temporal action localization.

    Args:
        keys (Sequence[str]): Keys of the result dict to pack into the data
            sample. ``'proposals'`` is stored in ``data_sample.proposals``;
            every other key is stored as a field of
            ``data_sample.gt_instances``. Keys missing from the result dict
            are skipped. Defaults to ``()``.
        meta_keys (Sequence[str]): The meta keys to be saved in the
            ``metainfo`` of the data sample. Defaults to ``('video_name', )``.
    """

    def __init__(self, keys=(), meta_keys=('video_name', )):
        self.keys = keys
        self.meta_keys = meta_keys

    def transform(self, results):
        """Method to pack the input data.

        Args:
            results (dict): Result dict from the data pipeline.

        Returns:
            dict:

            - 'inputs' (obj:`torch.Tensor`): The forward data of models.
              When only ``bsp_feature`` is present this is a scalar
              placeholder tensor.
            - 'data_samples' (obj:`ActionDataSample`): The annotation info
              of the sample.

        Raises:
            ValueError: If neither ``raw_feature`` nor ``bsp_feature`` is
                present in ``results``.
        """
        packed_results = dict()
        if 'raw_feature' in results:
            packed_results['inputs'] = to_tensor(results['raw_feature'])
        elif 'bsp_feature' in results:
            # The feature itself is not packed as inputs here; a scalar
            # placeholder keeps the ``inputs`` entry populated.
            packed_results['inputs'] = torch.tensor(0.)
        else:
            raise ValueError(
                'Cannot get "raw_feature" or "bsp_feature" in the input '
                'dict of `PackLocalizationInputs`.')

        data_sample = ActionDataSample()
        for key in self.keys:
            if key not in results:
                continue
            value = to_tensor(results[key])
            if key == 'proposals':
                instance_data = InstanceData()
                instance_data[key] = value
                data_sample.proposals = instance_data
            elif hasattr(data_sample, 'gt_instances'):
                # Extend the gt_instances created by an earlier key.
                data_sample.gt_instances[key] = value
            else:
                instance_data = InstanceData()
                instance_data[key] = value
                data_sample.gt_instances = instance_data

        img_meta = {k: results[k] for k in self.meta_keys if k in results}
        data_sample.set_metainfo(img_meta)
        packed_results['data_samples'] = data_sample
        return packed_results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(keys={self.keys}, meta_keys={self.meta_keys})'
        return repr_str
@TRANSFORMS.register_module()
class Transpose(BaseTransform):
    """Permute the axes of selected entries of the result dict.

    Args:
        keys (Sequence[str]): Names of the entries to permute. Each one must
            exist in the result dict and provide a ``transpose`` method.
        order (Sequence[int]): The axis permutation handed to ``transpose``.
    """

    def __init__(self, keys, order):
        self.keys = keys
        self.order = order

    def transform(self, results):
        """Permute every configured entry in place.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.

        Returns:
            dict: The same dict, with each configured entry replaced by its
            transposed version.
        """
        axes = self.order
        for name in self.keys:
            array = results[name]
            results[name] = array.transpose(axes)
        return results

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f'{cls_name}(keys={self.keys}, order={self.order})'
@TRANSFORMS.register_module()
class FormatShape(BaseTransform):
    """Format final imgs shape to the given input_format.

    Required keys:

        - imgs (optional)
        - heatmap_imgs (optional)
        - modality (optional)
        - num_clips
        - clip_len
        - num_proposals (only for ``'NPTCHW'``)

    Modified Keys:

        - imgs
        - heatmap_imgs (optional)

    Added Keys:

        - input_shape
        - heatmap_input_shape (optional)

    Args:
        input_format (str): Define the final data format. One of
            ``'NCTHW'``, ``'NCHW'``, ``'NCTHW_Heatmap'`` or ``'NPTCHW'``.
        collapse (bool): To collapse input_format N... to ... (NCTHW to CTHW,
            etc.) if N is 1. Should be set as True when training and testing
            detectors. Defaults to False.
    """

    _valid_formats = ('NCTHW', 'NCHW', 'NCTHW_Heatmap', 'NPTCHW')

    def __init__(self, input_format: str, collapse: bool = False) -> None:
        self.input_format = input_format
        self.collapse = collapse
        if self.input_format not in self._valid_formats:
            raise ValueError(
                f'The input format {self.input_format} is invalid.')

    @staticmethod
    def _group_clips(imgs: np.ndarray, num_clips: int, clip_len: int,
                     axes: Tuple[int, ...]) -> np.ndarray:
        """Split stacked frames into clips, permute, then merge crops/clips.

        ``imgs`` is reshaped to ``(-1, num_clips, clip_len, *imgs.shape[1:])``,
        transposed with ``axes``, and the two leading axes are merged into one
        (M' = N_crops x N_clips).
        """
        imgs = imgs.reshape((-1, num_clips, clip_len) + imgs.shape[1:])
        imgs = np.transpose(imgs, axes)
        return imgs.reshape((-1, ) + imgs.shape[2:])

    def transform(self, results: Dict) -> Dict:
        """Performs the FormatShape formatting.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.

        Returns:
            dict: The same dict with reshaped frames and their shapes.
        """
        # Both image keys are optional, so only convert the ones present;
        # list-of-frames inputs become a single stacked array.
        for key in ('imgs', 'heatmap_imgs'):
            if key in results and not isinstance(results[key], np.ndarray):
                results[key] = np.array(results[key])

        # [M x H x W x C]
        # M = 1 * N_crops * N_clips * T
        if self.collapse:
            assert results['num_clips'] == 1

        if self.input_format == 'NCTHW':
            if 'imgs' in results:
                clip_len = results['clip_len']
                if isinstance(clip_len, dict):
                    clip_len = clip_len['RGB']
                # N_crops x N_clips x T x H x W x C
                # -> N_crops x N_clips x C x T x H x W -> M' x C x T x H x W
                imgs = self._group_clips(results['imgs'],
                                         results['num_clips'], clip_len,
                                         (0, 1, 5, 2, 3, 4))
                results['imgs'] = imgs
                results['input_shape'] = imgs.shape
            if 'heatmap_imgs' in results:
                # clip_len must be a dict holding the 'Pose' clip length
                clip_len = results['clip_len']['Pose']
                # N_crops x N_clips x T x C x H x W
                # -> N_crops x N_clips x C x T x H x W -> M' x C x T x H x W
                imgs = self._group_clips(results['heatmap_imgs'],
                                         results['num_clips'], clip_len,
                                         (0, 1, 3, 2, 4, 5))
                results['heatmap_imgs'] = imgs
                results['heatmap_input_shape'] = imgs.shape
        elif self.input_format == 'NCTHW_Heatmap':
            # N_crops x N_clips x T x C x H x W
            # -> N_crops x N_clips x C x T x H x W -> M' x C x T x H x W
            imgs = self._group_clips(results['imgs'], results['num_clips'],
                                     results['clip_len'], (0, 1, 3, 2, 4, 5))
            results['imgs'] = imgs
            results['input_shape'] = imgs.shape
        elif self.input_format == 'NCHW':
            imgs = np.transpose(results['imgs'], (0, 3, 1, 2))
            if results.get('modality') == 'Flow':
                # Stack the clip_len consecutive flow frames along channels.
                clip_len = results['clip_len']
                imgs = imgs.reshape((-1, clip_len * imgs.shape[1]) +
                                    imgs.shape[2:])
            # M x C x H x W
            results['imgs'] = imgs
            results['input_shape'] = imgs.shape
        elif self.input_format == 'NPTCHW':
            num_proposals = results['num_proposals']
            num_clips = results['num_clips']
            clip_len = results['clip_len']
            imgs = results['imgs']
            # P x M x H x W x C, M = N_clips x T
            imgs = imgs.reshape((num_proposals, num_clips * clip_len) +
                                imgs.shape[1:])
            # P x M x C x H x W
            imgs = np.transpose(imgs, (0, 1, 4, 2, 3))
            results['imgs'] = imgs
            results['input_shape'] = imgs.shape

        if self.collapse:
            assert results['imgs'].shape[0] == 1
            results['imgs'] = results['imgs'].squeeze(0)
            results['input_shape'] = results['imgs'].shape

        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += (f"(input_format='{self.input_format}', "
                     f'collapse={self.collapse})')
        return repr_str
@TRANSFORMS.register_module()
class FormatAudioShape(BaseTransform):
    """Insert a singleton channel axis into the audio features.

    Required Keys:

        - audios

    Modified Keys:

        - audios

    Added Keys:

        - input_shape

    Args:
        input_format (str): Define the final audio format. Only ``'NCTF'``
            is accepted.
    """

    def __init__(self, input_format: str) -> None:
        if input_format not in ('NCTF', ):
            raise ValueError(f'The input format {input_format} is invalid.')
        self.input_format = input_format

    def transform(self, results: Dict) -> Dict:
        """Performs the FormatAudioShape formatting.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.

        Returns:
            dict: The same dict with ``audios`` reshaped and ``input_shape``
            recorded.
        """
        # The unpack enforces a 3-D (clip, sample, freq) array.
        n_clips, n_samples, n_freqs = results['audios'].shape
        # clip x sample x freq -> clip x channel x sample x freq
        formatted = results['audios'].reshape(n_clips, 1, n_samples, n_freqs)
        results['audios'] = formatted
        results['input_shape'] = formatted.shape
        return results

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(input_format='{self.input_format}')"
@TRANSFORMS.register_module()
class FormatGCNInput(BaseTransform):
    """Format final skeleton shape.

    Required Keys:

        - keypoint
        - keypoint_score (optional)
        - num_clips (optional)

    Modified Key:

        - keypoint

    The person axis is padded or truncated to ``num_person``, then the frame
    axis is split into ``num_clips`` clips, giving an array of shape
    ``(num_clips, num_person, T // num_clips, V, C)``.

    Args:
        num_person (int): The maximum number of people. Defaults to 2.
        mode (str): The padding mode, ``'zero'`` or ``'loop'``. With
            ``'loop'``, a single detected person is copied into every padded
            slot. Defaults to ``'zero'``.
    """

    def __init__(self, num_person: int = 2, mode: str = 'zero') -> None:
        self.num_person = num_person
        assert mode in ['zero', 'loop']
        self.mode = mode

    def transform(self, results: Dict) -> Dict:
        """The transform function of :class:`FormatGCNInput`.

        Args:
            results (dict): The result dict.

        Returns:
            dict: The result dict.
        """
        kpts = results['keypoint']
        if 'keypoint_score' in results:
            # Append the score as an extra trailing channel.
            scores = results['keypoint_score'][..., None]
            kpts = np.concatenate((kpts, scores), axis=-1)

        n_found = kpts.shape[0]
        if n_found > self.num_person:
            kpts = kpts[:self.num_person]
        elif n_found < self.num_person:
            # Zero-pad the person axis only.
            pad_width = [(0, self.num_person - n_found)]
            pad_width += [(0, 0)] * (kpts.ndim - 1)
            kpts = np.pad(kpts, pad_width, mode='constant')
            if self.mode == 'loop' and n_found == 1:
                kpts[1:] = kpts[0]

        n_person, n_frames, n_joints, n_channels = kpts.shape
        n_clips = results.get('num_clips', 1)
        assert n_frames % n_clips == 0
        kpts = kpts.reshape(
            (n_person, n_clips, n_frames // n_clips, n_joints, n_channels))
        kpts = kpts.transpose(1, 0, 2, 3, 4)
        results['keypoint'] = np.ascontiguousarray(kpts)
        return results

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}('
                f'num_person={self.num_person}, mode={self.mode})')
|