| # Copyright (c) OpenMMLab. All rights reserved. | |
| import pytest | |
| import torch | |
| from mmocr.models.textrecog import SegHead | |
def test_seg_head():
    """Check SegHead's ``num_classes`` validation and the shape of its output."""
    # A class count given as a string is expected to be rejected.
    with pytest.raises(AssertionError):
        SegHead(num_classes='100')
    # A negative class count is expected to be rejected too.
    with pytest.raises(AssertionError):
        SegHead(num_classes=-1)

    head = SegHead(num_classes=37)
    # The head is called with a tuple holding a single neck feature map.
    neck_outputs = (torch.rand(1, 128, 32, 32), )
    result = head(neck_outputs)

    # One channel per class; batch and spatial sizes match the input.
    expected_shape = torch.Size([1, 37, 32, 32])
    assert result.shape == expected_shape