|
|
| from unittest.mock import MagicMock
|
|
|
| import torch
|
|
|
| from mmaction.registry import MODELS
|
| from mmaction.structures import ActionDataSample
|
| from mmaction.testing import get_recognizer_cfg
|
| from mmaction.utils import register_all_modules
|
|
|
|
|
def _make_sample(input_shape, num_items=2, label=2):
    """Build one fake batch dict with random integer inputs and fixed labels.

    Args:
        input_shape (tuple[int]): Shape of every random input tensor.
        num_items (int): Number of (input, data_sample) pairs in the batch.
        label (int): Ground-truth label attached to every data sample.

    Returns:
        dict: ``{'inputs': [...], 'data_samples': [...]}`` with
        ``num_items`` entries in each list.
    """
    return {
        # torch.randint's upper bound is exclusive, so 256 is needed for
        # the generated values to span the full 0..255 range.
        'inputs': [
            torch.randint(0, 256, input_shape) for _ in range(num_items)
        ],
        'data_samples': [
            ActionDataSample().set_gt_label(label) for _ in range(num_items)
        ]
    }


def test_omni_resnet():
    """Smoke-test the OmniSource SlowOnly recognizer on video and image data.

    Builds the recognizer from the omnisource config, then:
      * runs ``train_step`` with a video batch and an image batch, in both
        orders, and checks that two classification losses are reported;
      * runs ``test_step`` on the video batch and checks that every returned
        prediction has scores within [0, 1].
    """
    register_all_modules()
    config = get_recognizer_cfg(
        'omnisource/slowonly_r50_8xb16-8x8x1-256e_imagenet-kinetics400-rgb.py')
    recognizer = MODELS.build(config.model)

    # NOTE(review): shapes are presumably (N, C, T, H, W) for video and
    # (N, C, H, W) for images -- confirm against the data preprocessor.
    video_sample = _make_sample((1, 3, 8, 224, 224))
    image_sample = _make_sample((1, 3, 224, 224))

    # The optimizer wrapper is mocked, so only the returned loss dict is
    # checked. Both orders of data sources are exercised.
    optim_wrapper = MagicMock()
    for batches in ([video_sample, image_sample],
                    [image_sample, video_sample]):
        loss_vars = recognizer.train_step(batches, optim_wrapper)
        # Presumably one classification loss per data source, suffixed by
        # the source's position in ``batches``.
        assert 'loss_cls_0' in loss_vars
        assert 'loss_cls_1' in loss_vars

    with torch.no_grad():
        predictions = recognizer.test_step(video_sample)

    # Check the count before indexing, so an empty result fails with a clear
    # assertion rather than an IndexError.
    assert len(predictions) == 2
    # Validate every prediction, not just the first one.
    for prediction in predictions:
        score = prediction.pred_score
        assert torch.min(score) >= 0
        assert torch.max(score) <= 1
|
|
|