| # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| from __future__ import unicode_literals | |
| import os | |
| import sys | |
| import numpy as np | |
| import skimage | |
| import paddle | |
| import signal | |
| import random | |
| __dir__ = os.path.dirname(os.path.abspath(__file__)) | |
| sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) | |
| import copy | |
| from paddle.io import Dataset, DataLoader, BatchSampler, DistributedBatchSampler | |
| import paddle.distributed as dist | |
| from ppocr.data.imaug import transform, create_operators | |
| from ppocr.data.simple_dataset import SimpleDataSet, MultiScaleDataSet | |
| from ppocr.data.lmdb_dataset import LMDBDataSet, LMDBDataSetSR, LMDBDataSetTableMaster | |
| from ppocr.data.pgnet_dataset import PGDataSet | |
| from ppocr.data.pubtab_dataset import PubTabDataSet | |
| from ppocr.data.multi_scale_sampler import MultiScaleSampler | |
| # for PaddleX dataset_type | |
| TextDetDataset = SimpleDataSet | |
| TextRecDataset = SimpleDataSet | |
| MSTextRecDataset = MultiScaleDataSet | |
| PubTabTableRecDataset = PubTabDataSet | |
| KieDataset = SimpleDataSet | |
| __all__ = [ | |
| 'build_dataloader', 'transform', 'create_operators', 'set_signal_handlers' | |
| ] | |
def term_mp(sig_num, frame):
    """Signal handler that SIGKILLs every process in this process group.

    Used to tear down DataLoader worker subprocesses when the main
    process receives SIGINT/SIGTERM.

    Args:
        sig_num: delivered signal number (unused).
        frame: current stack frame (unused).
    """
    own_pid = os.getpid()
    group_id = os.getpgid(os.getpid())
    print("main proc {} exit, kill process group " "{}".format(own_pid, group_id))
    os.killpg(group_id, signal.SIGKILL)
def set_signal_handlers():
    """Install SIGINT/SIGTERM handlers that kill the whole process group.

    Handlers are installed only when both conditions hold:
    * ``os.getpgid`` exists on this platform (otherwise no safe cleanup
      is possible and nothing is installed), and
    * the current process is its own process-group leader — ``term_mp``
      kills the entire group, which could otherwise take down the parent
      of this process as well.
    """
    current_pid = os.getpid()
    getpgid = getattr(os, "getpgid", None)
    if getpgid is None:
        # No way to query the process group -> skip handler installation,
        # because we cannot do safe cleanup.
        return
    # XXX: `term_mp` kills all processes in the process group, which in
    # some cases includes the parent process of current process and may
    # cause unexpected results. Only install the handlers when this
    # process is the group leader. In the future, it would be better to
    # consider killing only descendants of the current process.
    if getpgid(current_pid) == current_pid:
        # support exit using ctrl+c
        signal.signal(signal.SIGINT, term_mp)
        signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger, seed=None):
    """Build a ``paddle.io.DataLoader`` for one section of a PaddleOCR config.

    Args:
        config: full config dict; ``config[mode]['dataset']`` and
            ``config[mode]['loader']`` are read (a deep copy is made, the
            caller's dict is not mutated).
        mode: one of ``'Train'``, ``'Eval'``, ``'Test'``.
        device: place(s) passed to ``DataLoader(places=...)``.
        logger: logger forwarded to the dataset constructor.
        seed: optional seed forwarded to the dataset constructor.

    Returns:
        A configured ``paddle.io.DataLoader``.

    Raises:
        AssertionError: if the dataset name or ``mode`` is unsupported.
    """
    config = copy.deepcopy(config)

    # Explicit name -> class mapping instead of eval(), so an arbitrary
    # expression in the YAML cannot be executed via the dataset name.
    support_dict = {
        'SimpleDataSet': SimpleDataSet,
        'LMDBDataSet': LMDBDataSet,
        'PGDataSet': PGDataSet,
        'PubTabDataSet': PubTabDataSet,
        'LMDBDataSetSR': LMDBDataSetSR,
        'LMDBDataSetTableMaster': LMDBDataSetTableMaster,
        'MultiScaleDataSet': MultiScaleDataSet,
        'TextDetDataset': TextDetDataset,
        'TextRecDataset': TextRecDataset,
        'MSTextRecDataset': MSTextRecDataset,
        'PubTabTableRecDataset': PubTabTableRecDataset,
        'KieDataset': KieDataset,
    }
    module_name = config[mode]['dataset']['name']
    # Plain message string as the assert payload (the original passed an
    # Exception instance, which was only ever used as a message anyway).
    assert module_name in support_dict, \
        'DataSet only support {}'.format(list(support_dict))
    assert mode in ['Train', 'Eval', 'Test'
                    ], "Mode should be Train, Eval or Test."

    dataset = support_dict[module_name](config, mode, logger, seed)

    loader_config = config[mode]['loader']
    batch_size = loader_config['batch_size_per_card']
    drop_last = loader_config['drop_last']
    shuffle = loader_config['shuffle']
    num_workers = loader_config['num_workers']
    # Shared memory between workers defaults to on unless disabled in config.
    use_shared_memory = loader_config.get('use_shared_memory', True)

    if mode == "Train":
        # Distribute data to multiple cards
        if 'sampler' in config[mode]:
            config_sampler = config[mode]['sampler']
            sampler_name = config_sampler.pop("name")
            # NOTE(review): eval() on a config-supplied sampler name is kept
            # for backward compatibility (custom samplers may be any name in
            # scope), but it will execute arbitrary expressions from the YAML.
            batch_sampler = eval(sampler_name)(dataset, **config_sampler)
        else:
            batch_sampler = DistributedBatchSampler(
                dataset=dataset,
                batch_size=batch_size,
                shuffle=shuffle,
                drop_last=drop_last)
    else:
        # Distribute data to single card
        batch_sampler = BatchSampler(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last)

    if 'collate_fn' in loader_config:
        # Alias the import so the local variable does not shadow the module.
        from . import collate_fn as collate_fn_module
        collate_fn = getattr(collate_fn_module, loader_config['collate_fn'])()
    else:
        collate_fn = None

    data_loader = DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        places=device,
        num_workers=num_workers,
        return_list=True,
        use_shared_memory=use_shared_memory,
        collate_fn=collate_fn)
    return data_loader