import pytest
import torch
import torch.nn as nn
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm

from mmaction.models import ResNet3dCSN
from mmaction.testing import generate_backbone_demo_inputs


def _check_csn_variant(imgs, bottleneck_mode, conv2_len, grouped_idx):
    """Build a depth-152 ResNet3dCSN and verify its blocks and output shape.

    Args:
        imgs (torch.Tensor): Demo input batch fed through the backbone.
        bottleneck_mode (str): Forwarded to ``ResNet3dCSN`` as
            ``bottleneck_mode``.
        conv2_len (int): Expected number of modules in every block's
            ``conv2`` ``nn.Sequential``.
        grouped_idx (int): Index inside ``conv2`` of the conv whose
            ``groups`` must equal the block's ``planes``.
    """
    model = ResNet3dCSN(152, None, bottleneck_mode=bottleneck_mode)
    model.init_weights()
    model.train()

    for i, layer_name in enumerate(model.res_layers):
        layers = getattr(model, layer_name)
        num_blocks = model.stage_blocks[i]
        assert len(layers) == num_blocks
        for layer in layers:
            assert isinstance(layer.conv2, nn.Sequential)
            assert len(layer.conv2) == conv2_len
            assert layer.conv2[grouped_idx].groups == layer.planes

    # NOTE(review): indentation was lost in the source this was recovered
    # from; the CPU forward pass is assumed to be the `else` of the version
    # check (not of the CUDA check) -- confirm against upstream.
    if torch.__version__ == 'parrots':
        # When torch reports the 'parrots' version string, the forward pass
        # is only exercised if a GPU is available; otherwise it is skipped.
        if not torch.cuda.is_available():
            return
        model = model.cuda()
        imgs = imgs.cuda()

    feat = model(imgs)
    assert feat.shape == torch.Size([2, 2048, 1, 2, 2])


def test_resnet_csn_backbone():
    """Test resnet_csn backbone."""
    with pytest.raises(ValueError):
        # 'id' is not an accepted bottleneck_mode (this test uses only
        # 'ip' and 'ir' as valid values).
        ResNet3dCSN(152, None, bottleneck_mode='id')

    input_shape = (2, 3, 6, 64, 64)
    imgs = generate_backbone_demo_inputs(input_shape)

    # With bn_frozen=True, BatchNorm parameters must not require gradients
    # even after the model is switched to training mode.
    resnet3d_csn_frozen = ResNet3dCSN(
        152, None, bn_frozen=True, norm_eval=True)
    resnet3d_csn_frozen.train()
    for m in resnet3d_csn_frozen.modules():
        if isinstance(m, _BatchNorm):
            for param in m.parameters():
                assert param.requires_grad is False

    # 'ip' blocks: conv2 holds two modules, the grouped conv is the second.
    _check_csn_variant(imgs, 'ip', conv2_len=2, grouped_idx=1)

    # 'ir' blocks: conv2 holds a single module, which is the grouped conv.
    _check_csn_variant(imgs, 'ir', conv2_len=1, grouped_idx=0)

    # train(False) must put every direct child module in eval mode.
    resnet3d_csn_ip = ResNet3dCSN(152, None, bottleneck_mode='ip')
    resnet3d_csn_ip.init_weights()
    resnet3d_csn_ip.train(False)
    for module in resnet3d_csn_ip.children():
        assert module.training is False