| |
|
| | from unittest.mock import MagicMock
|
| |
|
| | import torch
|
| |
|
| | from mmaction.registry import MODELS
|
| | from mmaction.structures import ActionDataSample
|
| | from mmaction.testing import get_skeletongcn_cfg
|
| | from mmaction.utils import register_all_modules
|
| |
|
| |
|
def train_test_step(cfg, input_shape):
    """Build a recognizer from ``cfg`` and exercise its train and test steps.

    Three checks are performed:

    1. One ``train_step`` with a mocked optimizer wrapper; the returned loss
       dict must contain ``'loss'`` and ``'loss_cls'``, and
       ``update_params`` must be called exactly once.
    2. One ``test_step`` on the given input; a single prediction must be
       returned whose ``pred_score`` has shape ``(num_classes,)`` and whose
       values lie in ``[0, 1]``.
    3. One ``test_step`` on a 3-clip input after setting
       ``cls_head.average_clips = None``; the single prediction's
       ``pred_score`` must have shape ``(num_clips, num_classes)``.

    Args:
        cfg: Config whose ``model`` entry is buildable via ``MODELS`` and
            exposes ``cls_head.num_classes``.
        input_shape (tuple[int]): Shape of the single input tensor. For the
            multi-clip check its first dimension is replaced by the clip
            count and the remaining dimensions are kept.

    Returns:
        tuple: ``(loss_vars, predictions)``, i.e. the train-step loss dict and
        the predictions from the final (multi-clip) test step.
    """
    recognizer = MODELS.build(cfg.model)
    num_classes = cfg.model.cls_head.num_classes
    data_batch = {
        'inputs': [torch.randn(input_shape)],
        'data_samples': [ActionDataSample().set_gt_label(2)],
    }

    # Train step: the optimizer wrapper is a mock, so only the loss keys and
    # the single parameter-update call can be checked.
    optim_wrapper = MagicMock()
    loss_vars = recognizer.train_step(data_batch, optim_wrapper)
    assert 'loss' in loss_vars
    assert 'loss_cls' in loss_vars
    optim_wrapper.update_params.assert_called_once()

    # Test step with the head's configured clip averaging.
    with torch.no_grad():
        predictions = recognizer.test_step(data_batch)
    # Check the length before indexing so an empty result fails as an
    # assertion rather than an IndexError.
    assert len(predictions) == 1
    score = predictions[0].pred_score
    assert score.shape == torch.Size([num_classes])
    assert torch.min(score) >= 0
    assert torch.max(score) <= 1

    # With averaging disabled, one score row per clip is expected.
    recognizer.cls_head.average_clips = None
    num_clips = 3
    multi_clip_shape = (num_clips, *input_shape[1:])
    data_batch['inputs'] = [torch.randn(multi_clip_shape)]
    with torch.no_grad():
        predictions = recognizer.test_step(data_batch)
    assert len(predictions) == 1
    score = predictions[0].pred_score
    assert score.shape == torch.Size([num_clips, num_classes])

    return loss_vars, predictions
|
| |
|
| |
|
def test_stgcn():
    """Smoke-test ST-GCN built from its NTU60 x-sub 2D-keypoint config."""
    register_all_modules()
    cfg_path = 'stgcn/stgcn_8xb16-joint-u100-80e_ntu60-xsub-keypoint-2d.py'
    # One clip of skeleton data; the axis order is presumably
    # (clips, persons, frames, joints, channels) -- TODO confirm.
    shape = (1, 2, 30, 17, 3)
    train_test_step(get_skeletongcn_cfg(cfg_path), input_shape=shape)
|
| |
|
| |
|
def test_agcn():
    """Smoke-test 2s-AGCN built from its NTU60 x-sub 2D-keypoint config."""
    register_all_modules()
    cfg_path = '2s-agcn/2s-agcn_8xb16-joint-u100-80e_ntu60-xsub-keypoint-2d.py'
    # One clip of skeleton data; the axis order is presumably
    # (clips, persons, frames, joints, channels) -- TODO confirm.
    shape = (1, 2, 30, 17, 3)
    train_test_step(get_skeletongcn_cfg(cfg_path), input_shape=shape)
|
| |
|
| |
|
def test_stgcn_plusplus():
    """Smoke-test ST-GCN++ built from its NTU60 x-sub 2D-keypoint config."""
    register_all_modules()
    cfg_path = 'stgcnpp/stgcnpp_8xb16-joint-u100-80e_ntu60-xsub-keypoint-2d.py'
    # One clip of skeleton data; the axis order is presumably
    # (clips, persons, frames, joints, channels) -- TODO confirm.
    shape = (1, 2, 30, 17, 3)
    train_test_step(get_skeletongcn_cfg(cfg_path), input_shape=shape)
|
| |
|