Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import pytest | |
| import torch | |
| from mmocr.models.textrecog.backbones import (ResNet, ResNet31OCR, ResNetABI, | |
| ShallowCNN, VeryDeepVgg) | |
def test_resnet31_ocr_backbone():
    """Test the ResNet31OCR backbone.

    First checks that several invalid constructor arguments are rejected with
    an ``AssertionError``, then runs a training-mode forward pass on a random
    batch and checks the shape of the returned feature map.
    """
    # A non-integer first positional argument must be rejected.
    with pytest.raises(AssertionError):
        ResNet31OCR(2.5)
    # A bare int for ``layers`` must be rejected; presumably a per-stage
    # sequence is required (TODO confirm the exact type contract).
    with pytest.raises(AssertionError):
        ResNet31OCR(3, layers=5)
    # Likewise a bare int for ``channels`` must be rejected.
    with pytest.raises(AssertionError):
        ResNet31OCR(3, channels=5)

    # Test ResNet31OCR forward with default constructor arguments.
    model = ResNet31OCR()
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 32, 160)
    feat = model(imgs)
    # A 1x3x32x160 input must yield a 1x512x4x40 feature map.
    assert feat.shape == torch.Size([1, 512, 4, 40])
def test_vgg_deep_vgg_ocr_backbone():
    """Check the output shape of a default ``VeryDeepVgg`` backbone.

    The network is built with default arguments, initialised, put into
    training mode and fed a single random 3-channel 32x160 image.
    """
    backbone = VeryDeepVgg()
    backbone.init_weights()
    backbone.train()

    batch = torch.randn(1, 3, 32, 160)
    output = backbone(batch)

    expected_shape = torch.Size([1, 512, 1, 41])
    assert output.shape == expected_shape
def test_shallow_cnn_ocr_backbone():
    """Check the output shape of a default ``ShallowCNN`` backbone.

    A single-channel 32x100 random image is pushed through a freshly
    initialised network in training mode.
    """
    net = ShallowCNN()
    net.init_weights()
    net.train()

    out = net(torch.randn(1, 1, 32, 100))
    assert out.shape == torch.Size([1, 512, 8, 25])
def test_resnet_abi():
    """Test the ResNetABI backbone.

    Each invalid constructor call below must raise ``AssertionError``; a
    default model in training mode must map a 1x3x32x160 batch to a
    1x512x8x40 feature map.
    """
    rejected_constructions = (
        lambda: ResNetABI(2.5),
        lambda: ResNetABI(3, arch_settings=5),
        lambda: ResNetABI(3, stem_channels=None),
        # ``strides`` has one more entry than ``arch_settings`` here.
        lambda: ResNetABI(arch_settings=[3, 4, 6, 6], strides=[1, 2, 1, 2, 1]),
    )
    for construct in rejected_constructions:
        with pytest.raises(AssertionError):
            construct()

    # Forward pass with default settings.
    net = ResNetABI()
    net.train()
    out = net(torch.randn(1, 3, 32, 160))
    assert out.shape == torch.Size([1, 512, 8, 40])
def test_resnet():
    """Test generic ``ResNet`` configurations mirroring known OCR backbones.

    Three variants (named after the ASTER, ABINet and ResNet31 layouts) are
    built from the same ``ResNet`` class. Each is checked for the shape of the
    feature map it produces from one random 1x3x32x100 image.
    """
    # NOTE: ``use_conv1x1`` is presumably a boolean flag, so pass a real bool.
    # The previous string 'True' only worked because any non-empty string
    # (including 'False') is truthy.
    resnet45_aster = ResNet(
        in_channels=3,
        stem_channels=[64, 128],
        block_cfgs=dict(type='BasicBlock', use_conv1x1=True),
        arch_layers=[3, 4, 6, 6, 3],
        arch_channels=[32, 64, 128, 256, 512],
        # Per-stage strides as pairs; presumably (height, width) ordering,
        # so the (2, 1) stages would halve height only -- TODO confirm.
        strides=[(2, 2), (2, 2), (2, 1), (2, 1), (2, 1)])

    resnet45_abi = ResNet(
        in_channels=3,
        stem_channels=32,
        block_cfgs=dict(type='BasicBlock', use_conv1x1=True),
        arch_layers=[3, 4, 6, 6, 3],
        arch_channels=[32, 64, 128, 256, 512],
        strides=[2, 1, 2, 1, 1])

    resnet_31 = ResNet(
        in_channels=3,
        stem_channels=[64, 128],
        block_cfgs=dict(type='BasicBlock'),
        arch_layers=[1, 2, 5, 3],
        arch_channels=[256, 256, 512, 512],
        strides=[1, 1, 1, 1],
        # All downsampling and the post-stage conv come from plugins; the
        # ``stages`` tuple selects which of the four stages each applies to.
        plugins=[
            # 2x2 max-pool before the first two stages.
            dict(
                cfg=dict(type='Maxpool2d', kernel_size=2, stride=(2, 2)),
                stages=(True, True, False, False),
                position='before_stage'),
            # (2, 1) max-pool before the third stage.
            dict(
                cfg=dict(type='Maxpool2d', kernel_size=(2, 1), stride=(2, 1)),
                stages=(False, False, True, False),
                position='before_stage'),
            # 3x3 Conv-BN-ReLU after every stage.
            dict(
                cfg=dict(
                    type='ConvModule',
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    norm_cfg=dict(type='BN'),
                    act_cfg=dict(type='ReLU')),
                stages=(True, True, True, True),
                position='after_stage')
        ])

    img = torch.rand(1, 3, 32, 100)
    assert resnet45_aster(img).shape == torch.Size([1, 512, 1, 25])
    assert resnet45_abi(img).shape == torch.Size([1, 512, 8, 25])
    assert resnet_31(img).shape == torch.Size([1, 512, 4, 25])