File size: 3,544 Bytes
d670799 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmaction.models import ResNet3dSlowFast
from mmaction.testing import generate_backbone_demo_inputs
def _slow_pathway_cfg(**overrides):
    """Build a fresh ``resnet3d`` slow-pathway config for ResNet3dSlowFast.

    The defaults are the settings shared by every slow pathway built in
    ``test_slowfast_backbone``; ``overrides`` replace or extend them.
    A new dict is returned on every call, so a config consumed by one
    model can never leak into another.
    """
    cfg = dict(
        type='resnet3d',
        depth=50,
        pretrained=None,
        lateral=True,
        conv1_kernel=(1, 7, 7),
        dilations=(1, 1, 1, 1),
        conv1_stride_t=1,
        pool1_stride_t=1,
        inflate=(0, 0, 1, 1))
    cfg.update(overrides)
    return cfg


def _check_slowfast_inference(model):
    """Run a (1, 3, 8, 64, 64) demo input through ``model`` and check outputs.

    Asserts that the backbone returns a tuple whose first two feature maps
    have the shapes expected for a depth-50 SlowFast on that input.
    """
    imgs = generate_backbone_demo_inputs((1, 3, 8, 64, 64))
    # Only output shapes are checked, so skip building the autograd graph.
    with torch.no_grad():
        feat = model(imgs)
    assert isinstance(feat, tuple)
    assert feat[0].shape == torch.Size([1, 2048, 1, 2, 2])
    assert feat[1].shape == torch.Size([1, 256, 8, 2, 2])


def test_slowfast_backbone():
    """Test SlowFast backbone."""
    with pytest.raises(TypeError):
        # cfg should be a dict
        ResNet3dSlowFast(slow_pathway=['foo', 'bar'])
    with pytest.raises(KeyError):
        # pathway type should be implemented
        ResNet3dSlowFast(slow_pathway=dict(type='resnext'))

    # test slowfast with slow inflated
    # NOTE(review): pretrained='torchvision://resnet50' presumably makes
    # init_weights() fetch torchvision weights (network or local cache
    # required) -- confirm.
    sf_50_inflate = ResNet3dSlowFast(
        slow_pathway=_slow_pathway_cfg(
            pretrained='torchvision://resnet50', pretrained2d=True))
    sf_50_inflate.init_weights()
    sf_50_inflate.train()

    # test slowfast with no lateral connection
    # The leading positional ``None`` is kept from the original call; it
    # presumably binds the constructor's ``pretrained`` argument -- confirm.
    sf_50_wo_lateral = ResNet3dSlowFast(
        None, slow_pathway=_slow_pathway_cfg(lateral=False))
    sf_50_wo_lateral.init_weights()
    sf_50_wo_lateral.train()
    # slowfast w/o lateral connection inference test
    _check_slowfast_inference(sf_50_wo_lateral)

    # test slowfast with frozen stages config
    frozen_slow = 3
    sf_50 = ResNet3dSlowFast(
        None,
        slow_pathway=_slow_pathway_cfg(
            pretrained2d=True, frozen_stages=frozen_slow))
    sf_50.init_weights()
    sf_50.train()
    slow_path = sf_50.slow_path
    # NOTE(review): if num_stages is 4 (as expected for depth 50), every
    # stage visited here is <= frozen_slow, so the trainable branch is
    # never exercised -- confirm and consider a smaller frozen_slow.
    for stage in range(1, slow_path.num_stages):
        # Stages are numbered from 1; lateral_connections is 0-indexed.
        lateral_name = slow_path.lateral_connections[stage - 1]
        conv_lateral = getattr(slow_path, lateral_name)
        # Laterals of frozen stages must stay in eval mode and expose no
        # trainable parameters even though train() was called above.
        expect_trainable = stage > frozen_slow
        for mod in conv_lateral.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is expect_trainable
        for param in conv_lateral.parameters():
            assert param.requires_grad is expect_trainable

    # test slowfast with normal config
    sf_50 = ResNet3dSlowFast()
    sf_50.init_weights()
    sf_50.train()
    # slowfast inference test
    _check_slowfast_inference(sf_50)
|