|
|
|
|
|
import pytest
|
|
|
import torch
|
|
|
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
|
|
|
|
|
from mmaction.models import ResNet3dSlowFast
|
|
|
from mmaction.testing import generate_backbone_demo_inputs
|
|
|
|
|
|
|
|
|
def test_slowfast_backbone():
    """Check construction, stage freezing and forward output of SlowFast.

    The test walks through these cases in order:

    1. ``slow_pathway`` given as a list raises ``TypeError``; a config with
       ``type='resnext'`` raises ``KeyError``.
    2. A model whose slow pathway uses ``pretrained='torchvision://resnet50'``
       with ``pretrained2d=True`` and ``lateral=True`` can be built,
       initialised and put into train mode.
    3. A model with ``lateral=False`` yields a tuple of two feature maps
       with the expected shapes for a ``(1, 3, 8, 64, 64)`` input.
    4. With ``frozen_stages=3`` on the slow pathway, the lateral-connection
       modules of stages up to 3 have BatchNorm layers in eval mode and
       parameters with ``requires_grad`` off; later stages stay trainable.
    5. A model built with default arguments yields the same output shapes.
    """

    def slow_cfg(**overrides):
        # Settings common to every explicit slow-pathway config below.
        # A new dict is created on each call so that no two models are
        # handed the same config object (the original used separate
        # inline dicts for each model).
        cfg = {
            'type': 'resnet3d',
            'depth': 50,
            'conv1_kernel': (1, 7, 7),
            'dilations': (1, 1, 1, 1),
            'conv1_stride_t': 1,
            'pool1_stride_t': 1,
            'inflate': (0, 0, 1, 1),
        }
        cfg.update(overrides)
        return cfg

    clip_shape = (1, 3, 8, 64, 64)
    expected_shapes = (
        torch.Size([1, 2048, 1, 2, 2]),
        torch.Size([1, 256, 8, 2, 2]),
    )

    # --- invalid configs -------------------------------------------------
    # A pathway config that is a list instead of a dict must be rejected.
    with pytest.raises(TypeError):
        ResNet3dSlowFast(slow_pathway=['foo', 'bar'])
    # A pathway type of 'resnext' must be rejected.
    with pytest.raises(KeyError):
        ResNet3dSlowFast(slow_pathway={'type': 'resnext'})

    # --- slow pathway with 2D pretrained weights + lateral connections ---
    inflated_net = ResNet3dSlowFast(
        slow_pathway=slow_cfg(
            pretrained='torchvision://resnet50',
            pretrained2d=True,
            lateral=True,
        ))
    inflated_net.init_weights()
    inflated_net.train()

    # --- slow pathway without lateral connections ------------------------
    # NOTE(review): the leading positional ``None`` is presumably the
    # top-level ``pretrained`` argument - confirm against the constructor.
    no_lateral_net = ResNet3dSlowFast(
        None,
        slow_pathway=slow_cfg(pretrained=None, lateral=False))
    no_lateral_net.init_weights()
    no_lateral_net.train()

    outputs = no_lateral_net(generate_backbone_demo_inputs(clip_shape))

    assert isinstance(outputs, tuple)
    assert outputs[0].shape == expected_shapes[0]
    assert outputs[1].shape == expected_shapes[1]

    # --- frozen stages on the slow pathway -------------------------------
    frozen_slow = 3
    frozen_net = ResNet3dSlowFast(
        None,
        slow_pathway=slow_cfg(
            pretrained=None,
            pretrained2d=True,
            lateral=True,
            frozen_stages=frozen_slow,
        ))
    frozen_net.init_weights()
    frozen_net.train()

    slow_path = frozen_net.slow_path
    for stage in range(1, slow_path.num_stages):
        # ``lateral_connections`` holds attribute names; resolve the module.
        lateral = getattr(slow_path, slow_path.lateral_connections[stage - 1])
        # Stages numbered <= frozen_slow must be frozen, the rest trainable.
        expect_trainable = stage > frozen_slow
        for module in lateral.modules():
            if isinstance(module, _BatchNorm):
                assert module.training is expect_trainable
        for weight in lateral.parameters():
            assert weight.requires_grad is expect_trainable

    # --- default configuration -------------------------------------------
    default_net = ResNet3dSlowFast()
    default_net.init_weights()
    default_net.train()

    # Generate a fresh demo input rather than reusing the earlier one.
    outputs = default_net(generate_backbone_demo_inputs(clip_shape))

    assert isinstance(outputs, tuple)
    assert outputs[0].shape == expected_shapes[0]
    assert outputs[1].shape == expected_shapes[1]
|
|
|
|