AZIIIIIIIIZ's picture
Upload 1039 files
d670799 verified
raw
history blame
4.53 kB
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn as nn
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmaction.models import ResNet
from mmaction.testing import check_norm_state, generate_backbone_demo_inputs
def test_resnet_backbone():
    """Test the ResNet backbone.

    Covers constructor/argument validation, ``init_weights`` (random and
    torchvision-pretrained), ``norm_eval``, ``frozen_stages`` and
    ``partial_bn`` behaviour in train mode, and output feature shapes for
    depths 18 and 50 (pytorch and caffe style, 3- and 10-channel input).
    """
    with pytest.raises(KeyError):
        # ResNet depth should be in [18, 34, 50, 101, 152]
        ResNet(20)

    with pytest.raises(AssertionError):
        # In ResNet: 1 <= num_stages <= 4
        ResNet(50, num_stages=0)

    with pytest.raises(AssertionError):
        # In ResNet: 1 <= num_stages <= 4
        ResNet(50, num_stages=5)

    with pytest.raises(AssertionError):
        # len(strides) == len(dilations) == num_stages
        ResNet(50, strides=(1, ), dilations=(1, 1), num_stages=3)

    with pytest.raises(TypeError):
        # a non-str ``pretrained`` value (here an int) must be rejected
        resnet50 = ResNet(50, pretrained=0)
        resnet50.init_weights()

    with pytest.raises(AssertionError):
        # style must be in ['pytorch', 'caffe']
        ResNet(18, style='tensorflow')

    with pytest.raises(AssertionError):
        # ResNet-18 must reject with_cp=True
        ResNet(18, with_cp=True)

    # resnet with depth 18, norm_eval False, initial weights
    resnet18 = ResNet(18)
    resnet18.init_weights()

    # resnet with depth 50, norm_eval True
    resnet50 = ResNet(50, norm_eval=True)
    resnet50.init_weights()
    resnet50.train()
    assert check_norm_state(resnet50.modules(), False)

    # resnet with depth 50, norm_eval True, pretrained
    resnet50_pretrain = ResNet(
        pretrained='torchvision://resnet50', depth=50, norm_eval=True)
    resnet50_pretrain.init_weights()
    resnet50_pretrain.train()
    assert check_norm_state(resnet50_pretrain.modules(), False)

    # resnet with depth 50, frozen_stages 1 (norm_eval not set here)
    frozen_stages = 1
    resnet50_frozen = ResNet(50, frozen_stages=frozen_stages)
    resnet50_frozen.init_weights()
    resnet50_frozen.train()
    # Even after train(), the stem (conv1) and layer1..layer{frozen_stages}
    # must keep their BN layers in eval mode and have no trainable params.
    assert resnet50_frozen.conv1.bn.training is False
    # parameters() already recurses into submodules, so this visits every
    # stem parameter exactly once.
    for param in resnet50_frozen.conv1.parameters():
        assert param.requires_grad is False
    for i in range(1, frozen_stages + 1):
        layer = getattr(resnet50_frozen, f'layer{i}')
        for mod in layer.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in layer.parameters():
            assert param.requires_grad is False

    # resnet with depth 50, partial batchnorm: after train(), the first
    # BatchNorm2d stays trainable and in train mode; every later one is
    # frozen and in eval mode.
    resnet_pbn = ResNet(50, partial_bn=True)
    resnet_pbn.train()
    count_bn = 0
    for m in resnet_pbn.modules():
        if isinstance(m, nn.BatchNorm2d):
            count_bn += 1
            if count_bn >= 2:
                assert m.weight.requires_grad is False
                assert m.bias.requires_grad is False
                assert m.training is False
            else:
                assert m.weight.requires_grad is True
                assert m.bias.requires_grad is True
                assert m.training is True

    input_shape = (1, 3, 64, 64)
    imgs = generate_backbone_demo_inputs(input_shape)

    # resnet with depth 18 inference
    resnet18 = ResNet(18, norm_eval=False)
    resnet18.init_weights()
    resnet18.train()
    feat = resnet18(imgs)
    assert feat.shape == torch.Size([1, 512, 2, 2])

    # resnet with depth 50 inference
    resnet50 = ResNet(50, norm_eval=False)
    resnet50.init_weights()
    resnet50.train()
    feat = resnet50(imgs)
    assert feat.shape == torch.Size([1, 2048, 2, 2])

    # resnet with depth 50 in caffe style inference
    resnet50_caffe = ResNet(50, style='caffe', norm_eval=False)
    resnet50_caffe.init_weights()
    resnet50_caffe.train()
    feat = resnet50_caffe(imgs)
    assert feat.shape == torch.Size([1, 2048, 2, 2])

    # torchvision-pretrained resnet with a 10-channel input (flow-style).
    # Call init_weights() so the ``pretrained`` checkpoint is actually
    # applied, as done for resnet50_pretrain above; otherwise the argument
    # is never exercised.
    # NOTE(review): the 3-channel torchvision stem weights cannot match a
    # 10-channel conv1 — this relies on the loader skipping/adapting that
    # tensor rather than raising; confirm.
    resnet50_flow = ResNet(
        depth=50, pretrained='torchvision://resnet50', in_channels=10)
    resnet50_flow.init_weights()
    input_shape = (1, 10, 64, 64)
    imgs = generate_backbone_demo_inputs(input_shape)
    feat = resnet50_flow(imgs)
    assert feat.shape == torch.Size([1, 2048, 2, 2])

    # torchvision-pretrained resnet with the standard 3-channel input
    resnet50 = ResNet(
        depth=50, pretrained='torchvision://resnet50', in_channels=3)
    resnet50.init_weights()
    input_shape = (1, 3, 64, 64)
    imgs = generate_backbone_demo_inputs(input_shape)
    feat = resnet50(imgs)
    assert feat.shape == torch.Size([1, 2048, 2, 2])