| |
|
| | import platform
|
| |
|
| | import numpy as np
|
| | import pytest
|
| | import torch
|
| | from mmcv.transforms import to_tensor
|
| | from mmengine.structures import InstanceData
|
| |
|
| | from mmaction.registry import MODELS
|
| | from mmaction.structures import ActionDataSample
|
| | from mmaction.testing import get_localizer_cfg
|
| | from mmaction.utils import register_all_modules
|
| |
|
| | register_all_modules()
|
| |
|
| |
|
| | def get_localization_data_sample():
|
| | gt_bbox = np.array([[0.1, 0.3], [0.375, 0.625]])
|
| | data_sample = ActionDataSample()
|
| | instance_data = InstanceData()
|
| | instance_data['gt_bbox'] = to_tensor(gt_bbox)
|
| | data_sample.gt_instances = instance_data
|
| | data_sample.set_metainfo(
|
| | dict(
|
| | video_name='v_test',
|
| | duration_second=100,
|
| | duration_frame=960,
|
| | feature_frame=960))
|
| | return data_sample
|
| |
|
| |
|
| | @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
| | def test_tem():
|
| | model_cfg = get_localizer_cfg(
|
| | 'bsn/bsn_tem_1xb16-400x100-20e_activitynet-feature.py')
|
| |
|
| | localizer_tem = MODELS.build(model_cfg.model)
|
| | raw_feature = torch.rand(8, 400, 100)
|
| |
|
| | data_samples = [get_localization_data_sample()] * 8
|
| | losses = localizer_tem(raw_feature, data_samples, mode='loss')
|
| | assert isinstance(losses, dict)
|
| |
|
| |
|
| | with torch.no_grad():
|
| | for one_raw_feature in raw_feature:
|
| | one_raw_feature = one_raw_feature.reshape(1, 400, 100)
|
| | data_samples = [get_localization_data_sample()]
|
| | localizer_tem(one_raw_feature, data_samples, mode='predict')
|
| |
|