|
|
|
|
|
import pytest
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
|
|
|
|
|
from mmaction.models import ResNet
|
|
|
from mmaction.testing import check_norm_state, generate_backbone_demo_inputs
|
|
|
|
|
|
|
|
|
def test_resnet_backbone():
    """Exercise ResNet construction, freezing options and output shapes.

    Covered in order: constructor / initialisation arguments that are
    expected to raise, ``norm_eval``, ``frozen_stages``, ``partial_bn``, and
    forward passes on 64x64 inputs for depth 18 and depth 50 models
    (including the caffe style and the torchvision-pretrained settings).
    """

    def _init_and_train(model):
        """Call ``init_weights()`` then ``train()`` on *model*; return it."""
        model.init_weights()
        model.train()
        return model

    def _init_with_int_pretrained():
        """Build a model with ``pretrained=0`` and initialise its weights."""
        ResNet(50, pretrained=0).init_weights()

    # Each entry pairs the exception this test expects with a zero-argument
    # callable that must raise it.
    invalid_cases = (
        (KeyError, lambda: ResNet(20)),
        (AssertionError, lambda: ResNet(50, num_stages=0)),
        (AssertionError, lambda: ResNet(50, num_stages=5)),
        (AssertionError,
         lambda: ResNet(50, strides=(1, ), dilations=(1, 1), num_stages=3)),
        (TypeError, _init_with_int_pretrained),
        (AssertionError, lambda: ResNet(18, style='tensorflow')),
        (AssertionError, lambda: ResNet(18, with_cp=True)),
    )
    for expected_error, trigger in invalid_cases:
        with pytest.raises(expected_error):
            trigger()

    # Smoke check: a default depth-18 model can initialise its weights.
    ResNet(18).init_weights()

    # norm_eval=True, without and with torchvision weights. After train()
    # the norm layers are checked against the state ``False``.
    # NOTE(review): the exact semantics of check_norm_state live in
    # mmaction.testing -- confirm there.
    norm_eval_model = _init_and_train(ResNet(50, norm_eval=True))
    assert check_norm_state(norm_eval_model.modules(), False)

    norm_eval_pretrained = _init_and_train(
        ResNet(pretrained='torchvision://resnet50', depth=50, norm_eval=True))
    assert check_norm_state(norm_eval_pretrained.modules(), False)

    # frozen_stages=1: the stem and layer1 must stay frozen after train().
    frozen_stages = 1
    frozen_model = _init_and_train(ResNet(50, frozen_stages=frozen_stages))
    stem = frozen_model.conv1
    assert stem.bn.training is False
    assert all(weight.requires_grad is False
               for submodule in stem.modules()
               for weight in submodule.parameters())
    for stage_idx in range(1, frozen_stages + 1):
        stage = getattr(frozen_model, f'layer{stage_idx}')
        stage_norms = [
            sub for sub in stage.modules() if isinstance(sub, _BatchNorm)
        ]
        assert all(norm.training is False for norm in stage_norms)
        assert all(weight.requires_grad is False
                   for weight in stage.parameters())

    # partial_bn=True: only the first BatchNorm2d encountered stays
    # trainable and in training mode; every later one must be frozen.
    partial_bn_model = ResNet(50, partial_bn=True)
    partial_bn_model.train()
    batch_norms = [
        sub for sub in partial_bn_model.modules()
        if isinstance(sub, nn.BatchNorm2d)
    ]
    for position, norm in enumerate(batch_norms, start=1):
        stays_trainable = position < 2
        assert norm.weight.requires_grad is stays_trainable
        assert norm.bias.requires_grad is stays_trainable
        assert norm.training is stays_trainable

    # Forward passes on a (1, 3, 64, 64) input: (depth, extra kwargs,
    # expected number of output channels).
    inputs_3ch = generate_backbone_demo_inputs((1, 3, 64, 64))
    forward_cases = (
        (18, dict(norm_eval=False), 512),
        (50, dict(norm_eval=False), 2048),
        (50, dict(style='caffe', norm_eval=False), 2048),
    )
    for depth, extra_kwargs, out_channels in forward_cases:
        backbone = _init_and_train(ResNet(depth, **extra_kwargs))
        output = backbone(inputs_3ch)
        assert output.shape == torch.Size([1, out_channels, 2, 2])

    # Models configured with torchvision weights, run without calling
    # init_weights(), first with 10 input channels and then with 3.
    for in_channels in (10, 3):
        backbone = ResNet(
            depth=50,
            pretrained='torchvision://resnet50',
            in_channels=in_channels)
        inputs = generate_backbone_demo_inputs((1, in_channels, 64, 64))
        output = backbone(inputs)
        assert output.shape == torch.Size([1, 2048, 2, 2])
|
|
|
|