import json
import os.path as osp
import tempfile

import torch

from mmocr.datasets.ner_dataset import NerDataset
from mmocr.models.ner.convertors.ner_convertor import NerConvertor
from mmocr.utils import list_to_file


def _create_dummy_ann_file(ann_file):
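    # One line of CLUENER-style annotation: the raw text plus a label dict
    # mapping each category to {entity: [[start, end]]} character spans
    # ('台湾' covers characters 15-16, '彭小军' characters 0-2).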
    data = {
        'text': '彭小军认为,国内银行现在走的是台湾的发卡模式',
        'label': {
            'address': {
                '台湾': [[15, 16]]
            },
            'name': {
                '彭小军': [[0, 2]]
            }
        }
    }

    list_to_file(ann_file, [json.dumps(data, ensure_ascii=False)])


def _create_dummy_vocab_file(vocab_file):
    # Write the dummy vocab (a-z, one character per line) in a single call,
    # since every call to list_to_file rewrites the file from scratch.
    chars = list(map(chr, range(ord('a'), ord('z') + 1)))
    list_to_file(vocab_file, chars)


def _create_dummy_loader():
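    # HardDiskLoader reads the annotation file from disk line by line and
    # LineJsonParser keeps the 'text' and 'label' keys of each JSON line.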
    loader = dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(type='LineJsonParser', keys=['text', 'label']))
    return loader


def test_ner_dataset():
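    # test initialization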
    loader = _create_dummy_loader()
    categories = [
        'address', 'book', 'company', 'game', 'government', 'movie', 'name',
        'organization', 'position', 'scene'
    ]
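
    # create dummy data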
    tmp_dir = tempfile.TemporaryDirectory()
    ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
    vocab_file = osp.join(tmp_dir.name, 'fake_vocab.txt')
    _create_dummy_ann_file(ann_file)
    _create_dummy_vocab_file(vocab_file)

    max_len = 128
    ner_convertor = dict(
        type='NerConvertor',
        annotation_type='bio',
        vocab_file=vocab_file,
        categories=categories,
        max_len=max_len)

    test_pipeline = [
        dict(
            type='NerTransform',
            label_convertor=ner_convertor,
            max_len=max_len),
        dict(type='ToTensorNER')
    ]
    dataset = NerDataset(ann_file, loader, pipeline=test_pipeline)
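
    # test pre_pipeline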
    img_info = dataset.data_infos[0]
    results = dict(img_info=img_info)
    dataset.pre_pipeline(results)
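
    # test prepare_train_img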
    dataset.prepare_train_img(0)
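
    # test evaluation with predictions matching the ground truth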
    result = [[['address', 15, 16], ['name', 0, 2]]]

    dataset.evaluate(result)
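
    # test pred convert: with 10 categories the BIO convertor uses ids 1-10
    # for B-*, 11-20 for I-* and 21 for 'O'; below, 1/11 tag 台湾 as an
    # address and 7/17 tag 彭小军 as a name (offset by one for [CLS]).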
    pred = [
        21, 7, 17, 17, 21, 21, 21, 21, 21, 21, 13, 21, 21, 21, 21, 21, 1, 11,
        21, 21, 7, 17, 17, 21, 21, 21, 21, 21, 21, 13, 21, 21, 21, 21, 21, 1,
        11, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 1, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 1, 21, 21, 21, 21,
        21, 21
    ]
    preds = [pred[:128]]
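    # Only the first ten token positions are marked as valid.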
    mask = [0] * 128
    for i in range(10):
        mask[i] = 1
    assert len(preds[0]) == len(mask)
    masks = torch.tensor([mask])
    convertor = NerConvertor(
        annotation_type='bio',
        vocab_file=vocab_file,
        categories=categories,
        max_len=128)
    all_entities = convertor.convert_pred2entities(preds=preds, masks=masks)
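    # Each decoded entity is a [label, start, end] triple.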
    assert len(all_entities[0][0]) == 3

    tmp_dir.cleanup()