File size: 789 Bytes
d670799 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torchvision
from mmaction.models import OmniResNet
from mmaction.testing import generate_backbone_demo_inputs
def test_x3d_backbone():
    """Test the OmniResNet backbone on 5-D and 4-D inputs.

    The test checks that:

    * ``OmniResNet`` can be constructed with its default arguments;
    * it can be constructed from a torchvision ResNet-50 state dict passed
      through ``pretrain_2d``;
    * a 5-D input of shape ``(2, 3, 8, 64, 64)`` yields a feature map of
      shape ``(2, 2048, 8, 2, 2)``;
    * a 4-D input of shape ``(2, 3, 64, 64)`` yields a feature map of
      shape ``(2, 2048, 2, 2)``.

    NOTE(review): the function name still refers to "x3d" although the body
    exercises ``OmniResNet``; the name is kept so existing test selection
    by name keeps working.
    """
    # Function-scope imports keep this block self-contained.
    import os.path as osp
    import tempfile

    # Default construction must succeed without any 2D checkpoint.
    _ = OmniResNet()

    resnet50 = torchvision.models.resnet50()
    params = resnet50.state_dict()

    # Store the checkpoint in a temporary directory instead of the current
    # working directory, so the test leaves no ``r50.pth`` behind and cannot
    # collide with other test runs. The model is used inside the context so
    # the file exists for as long as the model might need it.
    with tempfile.TemporaryDirectory() as tmp_dir:
        ckpt_path = osp.join(tmp_dir, 'r50.pth')
        torch.save(params, ckpt_path)
        model = OmniResNet(pretrain_2d=ckpt_path)

        # 5-D input (named ``videos`` here).
        input_shape = (2, 3, 8, 64, 64)
        videos = generate_backbone_demo_inputs(input_shape)
        feat = model(videos)
        assert feat.shape == torch.Size([2, 2048, 8, 2, 2])

        # 4-D input (named ``images`` here) through the same model.
        input_shape = (2, 3, 64, 64)
        images = generate_backbone_demo_inputs(input_shape)
        feat = model(images)
        assert feat.shape == torch.Size([2, 2048, 2, 2])
|