Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import pytest | |
| import torch | |
| from mmocr.models.textrecog.encoders import (ABIVisionModel, BaseEncoder, | |
| NRTREncoder, SAREncoder, | |
| SatrnEncoder, TransformerEncoder) | |
def test_sar_encoder():
    """Check SAREncoder argument validation and its forward output shape."""
    # Every one of these constructor arguments is expected to trip an
    # assertion inside SAREncoder.
    invalid_kwargs = (
        dict(enc_bi_rnn='bi'),
        dict(enc_do_rnn=2),
        dict(enc_gru='gru'),
        dict(d_model=512.5),
        dict(d_enc=200.5),
        dict(mask='mask'),
    )
    for kwargs in invalid_kwargs:
        with pytest.raises(AssertionError):
            SAREncoder(**kwargs)

    sar_encoder = SAREncoder()
    sar_encoder.init_weights()
    sar_encoder.train()

    feature_map = torch.randn(1, 512, 4, 40)
    metas = [{'valid_ratio': 1.0}]

    # Two meta entries for a batch of one sample must be rejected.
    with pytest.raises(AssertionError):
        sar_encoder(feature_map, metas * 2)

    encoded = sar_encoder(feature_map, metas)
    assert encoded.shape == torch.Size([1, 512])
def test_nrtr_encoder():
    """Check that NRTREncoder turns a (1, 512, 1, 25) feature map into a
    (1, 25, 512) output."""
    tf_encoder = NRTREncoder()
    tf_encoder.init_weights()
    tf_encoder.train()

    feat = torch.randn(1, 512, 1, 25)
    out_enc = tf_encoder(feat)
    # The shape assertion is the check; no debug output is printed.
    assert out_enc.shape == torch.Size([1, 25, 512])
def test_satrn_encoder():
    """Check the output shape of SatrnEncoder in training mode."""
    encoder = SatrnEncoder()
    encoder.init_weights()
    encoder.train()

    feature_map = torch.randn(1, 512, 8, 25)
    encoded = encoder(feature_map)

    # 200 matches 8 * 25 -- presumably the flattened spatial positions;
    # confirm against SatrnEncoder.
    expected_shape = torch.Size([1, 200, 512])
    assert encoded.shape == expected_shape
def test_base_encoder():
    """Check that BaseEncoder returns an output shaped like its input."""
    base_encoder = BaseEncoder()
    base_encoder.init_weights()
    base_encoder.train()

    feature_map = torch.randn(1, 256, 4, 40)
    encoded = base_encoder(feature_map)

    expected_shape = torch.Size([1, 256, 4, 40])
    assert encoded.shape == expected_shape
def test_transformer_encoder():
    """Check that TransformerEncoder keeps the (10, 512, 8, 32) shape."""
    encoder = TransformerEncoder()
    features = torch.randn(10, 512, 8, 32)

    encoded = encoder(features)
    expected_shape = torch.Size([10, 512, 8, 32])
    assert encoded.shape == expected_shape
def test_abi_vision_model():
    """Check the shapes of the entries returned by ABIVisionModel."""
    decoder_cfg = dict(
        type='ABIVisionDecoder', max_seq_len=10, use_result=None)
    vision_model = ABIVisionModel(decoder=decoder_cfg)

    features = torch.randn(1, 512, 8, 32)
    outputs = vision_model(features)

    # Expected shape of each output entry, checked in this order.
    expected_shapes = (
        ('feature', [1, 10, 512]),
        ('logits', [1, 10, 90]),
        ('attn_scores', [1, 10, 8, 32]),
    )
    for key, shape in expected_shapes:
        assert outputs[key].shape == torch.Size(shape)