|
|
| import torch
|
| import torchvision
|
|
|
| from mmaction.models import OmniResNet
|
| from mmaction.testing import generate_backbone_demo_inputs
|
|
|
|
|
def test_x3d_backbone():
    """Test the OmniResNet backbone on both video and image inputs.

    Despite the function name (kept so test discovery stays unchanged), this
    test exercises ``OmniResNet``, not X3D. It checks that:

    * the backbone can be constructed with default arguments;
    * it can be constructed from a 2D torchvision ResNet-50 state dict passed
      via ``pretrain_2d``;
    * a 5D video batch ``(N, C, T, H, W)`` and a 4D image batch
      ``(N, C, H, W)`` both produce feature maps of the expected shape.
    """
    import os
    import tempfile

    # Default construction should succeed without any pretrained weights.
    _ = OmniResNet()

    # Save a (randomly initialised) torchvision ResNet-50 state dict to a
    # temporary directory rather than the current working directory, so the
    # test leaves no stray ``r50.pth`` behind and cannot collide with other
    # runs. The directory stays alive for the whole test in case the
    # checkpoint is read lazily rather than inside the constructor.
    resnet50 = torchvision.models.resnet50()
    with tempfile.TemporaryDirectory() as tmp_dir:
        ckpt_path = os.path.join(tmp_dir, 'r50.pth')
        torch.save(resnet50.state_dict(), ckpt_path)
        model = OmniResNet(pretrain_2d=ckpt_path)

        # Video input: (batch, channels, frames, height, width).
        video_shape = (2, 3, 8, 64, 64)
        videos = generate_backbone_demo_inputs(video_shape)
        feat = model(videos)
        # Temporal length is preserved; spatial size is downsampled 32x.
        assert feat.shape == torch.Size([2, 2048, 8, 2, 2])

        # Image input: (batch, channels, height, width).
        image_shape = (2, 3, 64, 64)
        images = generate_backbone_demo_inputs(image_shape)
        feat = model(images)
        assert feat.shape == torch.Size([2, 2048, 2, 2])
|
|
|