| |
|
| | import os
|
| | import tempfile
|
| |
|
| | import torch
|
| | from mmengine.runner import load_checkpoint, save_checkpoint
|
| | from mmengine.runner.checkpoint import _load_checkpoint_with_prefix
|
| |
|
| | from mmaction.models.backbones.mobileone_tsm import MobileOneTSM
|
| | from mmaction.testing import generate_backbone_demo_inputs
|
| |
|
| |
|
def test_mobileone_tsm_backbone():
    """Check TemporalShift wrapping and output shapes of MobileOneTSM.

    For the 's0' arch, every ``TemporalShift`` module found in the backbone
    must wrap a ``MobileOneBlock`` and carry the same ``num_segments`` and
    ``shift_div`` as the backbone. Each arch from 's0' to 's4' is then run
    on the same (8, 3, 64, 64) demo input and its output shape is checked.
    """
    from mmpretrain.models.backbones.mobileone import MobileOneBlock

    from mmaction.models.backbones.resnet_tsm import TemporalShift

    backbone = MobileOneTSM('s0', pretrained2d=False)
    backbone.init_weights()
    shift_wrappers = (
        module for module in backbone.modules()
        if isinstance(module, TemporalShift))
    for wrapper in shift_wrappers:
        assert isinstance(wrapper.net, MobileOneBlock)
        assert wrapper.num_segments == backbone.num_segments
        assert wrapper.shift_div == backbone.shift_div

    inputs = generate_backbone_demo_inputs((8, 3, 64, 64))

    feat = backbone(inputs)
    assert feat.shape == torch.Size([8, 1024, 2, 2])

    # Remaining archs only differ in the expected number of output channels.
    expected_channels = (('s1', 1280), ('s2', 2048), ('s3', 2048),
                         ('s4', 2048))
    for arch, channels in expected_channels:
        backbone = MobileOneTSM(arch, pretrained2d=False)
        feat = backbone(inputs)
        assert feat.shape == torch.Size([8, channels, 2, 2])
|
| |
|
| |
|
def test_mobileone_init_weight():
    """Check that ``init_weights`` loads the configured pretrained weights.

    A MobileOneTSM 's0' backbone is built with a ``Pretrained`` init_cfg
    (checkpoint URL, prefix 'backbone'). After initialization, every model
    parameter must be close to the checkpoint tensor whose key is the
    parameter name with '.net' removed.
    """
    checkpoint = ('https://download.openmmlab.com/mmclassification/v0'
                  '/mobileone/mobileone-s0_8xb32_in1k_20221110-0bc94952.pth')
    init_cfg = dict(
        type='Pretrained', checkpoint=checkpoint, prefix='backbone')

    model = MobileOneTSM(arch='s0', init_cfg=init_cfg)
    model.init_weights()

    reference = _load_checkpoint_with_prefix(
        'backbone', model.init_cfg['checkpoint'], map_location='cpu')
    for name, param in model.named_parameters():
        # Checkpoint keys are looked up without the '.net' part of the name.
        ckpt_key = name.replace('.net', '')
        matches = torch.allclose(param, reference[ckpt_key])
        assert matches, \
            f'layer {name} fail to load from pretrained checkpoint'
|
| |
|
| |
|
def test_load_deploy_mobileone():
    """Check that a deploy-mode model reproduces a switched model's output.

    An 's0' backbone is switched to deploy mode and run in eval mode on a
    demo input. Its state dict is saved to a checkpoint and loaded into a
    second backbone constructed with ``deploy=True``; both models must give
    close outputs on the same input.
    """
    model = MobileOneTSM('s0', pretrained2d=False)
    inputs = generate_backbone_demo_inputs((8, 3, 64, 64))
    model.switch_to_deploy()
    model.eval()
    outputs = model(inputs)

    model_deploy = MobileOneTSM('s0', pretrained2d=False, deploy=True)
    # Use a private temporary directory rather than a fixed file name in the
    # shared system temp dir: concurrent test runs cannot clobber each
    # other's checkpoint, and the file is removed even if a later step fails.
    with tempfile.TemporaryDirectory() as tmpdir:
        ckpt_path = os.path.join(tmpdir, 'ckpt.pth')
        save_checkpoint(model.state_dict(), ckpt_path)
        load_checkpoint(model_deploy, ckpt_path)

    outputs_load = model_deploy(inputs)
    for feat, feat_load in zip(outputs, outputs_load):
        assert torch.allclose(feat, feat_load)
|
| |
|