Spaces:
Runtime error
Runtime error
| # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import os | |
| import sys | |
| import platform | |
| import yaml | |
| import time | |
| import datetime | |
| import paddle | |
| import paddle.distributed as dist | |
| from tqdm import tqdm | |
| import cv2 | |
| import numpy as np | |
| from argparse import ArgumentParser, RawDescriptionHelpFormatter | |
| from ppocr.utils.stats import TrainingStats | |
| from ppocr.utils.save_load import save_model | |
| from ppocr.utils.utility import print_dict, AverageMeter | |
| from ppocr.utils.logging import get_logger | |
| from ppocr.utils.loggers import VDLLogger, WandbLogger, Loggers | |
| from ppocr.utils import profiler | |
| from ppocr.data import build_dataloader | |
class ArgsParser(ArgumentParser):
    """Command-line parser for the training/eval tools.

    Supports a required ``-c/--config`` YAML path, repeated ``-o key=value``
    configuration overrides, and an optional ``-p`` profiler-options string.
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")
        self.add_argument(
            '-p',
            '--profiler_options',
            type=str,
            default=None,
            help='The option of profiler, which should be in format '
            '\"key1=value1;key2=value2;key3=value3\".')

    def parse_args(self, argv=None):
        """Parse ``argv``; require --config and convert -o pairs to a dict."""
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Turn ``['k=v', ...]`` into ``{k: parsed_v}`` via YAML scalar parsing.

        Returns an empty dict when no -o options were given.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            # Split only on the FIRST '=' so option values may themselves
            # contain '=' (e.g. paths or nested key=value strings); a plain
            # split('=') raised "too many values to unpack" in that case.
            k, v = s.split('=', 1)
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config
def load_config(file_path):
    """
    Load config from yml/yaml file.
    Args:
        file_path (str): Path of the config file to be loaded.
    Returns: global config
    """
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    # Use a context manager so the file handle is closed deterministically;
    # the original `yaml.load(open(...))` leaked the handle until GC.
    with open(file_path, 'rb') as f:
        config = yaml.load(f, Loader=yaml.Loader)
    return config
def merge_config(config, opts):
    """
    Merge config into global config.
    Args:
        config (dict): Config to be merged.
    Returns: global config
    """
    for key, value in opts.items():
        if "." not in key:
            # Plain key: merge dict values into an existing dict entry,
            # otherwise overwrite/insert the value wholesale.
            if isinstance(value, dict) and key in config:
                config[key].update(value)
            else:
                config[key] = value
            continue
        # Dotted key such as 'Global.epoch_num': walk down to the parent
        # node of the final segment, then assign there.
        parts = key.split('.')
        assert (
            parts[0] in config
        ), "the sub_keys can only be one of global_config: {}, but get: " \
            "{}, please check your running command".format(
            config.keys(), parts[0])
        node = config[parts[0]]
        for part in parts[1:-1]:
            node = node[part]
        node[parts[-1]] = value
    return config
def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False):
    """
    Log error and exit when set use_gpu=true in paddlepaddle
    cpu version.

    Args:
        use_gpu (bool): whether the config requests CUDA.
        use_xpu (bool): whether the config requests Kunlun XPU.
        use_npu (bool): whether the config requests Ascend NPU.
        use_mlu (bool): whether the config requests Cambricon MLU.
    """
    err = "Config {} cannot be set as true while your paddle " \
          "is not compiled with {} ! \nPlease try: \n" \
          "\t1. Install paddlepaddle to run model on {} \n" \
          "\t2. Set {} as false in config file to run " \
          "model on CPU"
    try:
        if use_gpu and use_xpu:
            # Fixed typo in the user-facing message ("ture" -> "true").
            # NOTE: this is only a warning; execution deliberately continues.
            print("use_xpu and use_gpu can not both be true.")
        if use_gpu and not paddle.is_compiled_with_cuda():
            print(err.format("use_gpu", "cuda", "gpu", "use_gpu"))
            sys.exit(1)
        if use_xpu and not paddle.device.is_compiled_with_xpu():
            print(err.format("use_xpu", "xpu", "xpu", "use_xpu"))
            sys.exit(1)
        if use_npu:
            # is_compiled_with_npu() was replaced by the custom-device API
            # after paddle 2.4, so branch on the installed paddle version.
            if int(paddle.version.major) != 0 and int(
                    paddle.version.major) <= 2 and int(
                        paddle.version.minor) <= 4:
                if not paddle.device.is_compiled_with_npu():
                    print(err.format("use_npu", "npu", "npu", "use_npu"))
                    sys.exit(1)
            else:
                if not paddle.device.is_compiled_with_custom_device("npu"):
                    print(err.format("use_npu", "npu", "npu", "use_npu"))
                    sys.exit(1)
        if use_mlu and not paddle.device.is_compiled_with_mlu():
            print(err.format("use_mlu", "mlu", "mlu", "use_mlu"))
            sys.exit(1)
    except Exception:
        # Best-effort check: older/newer paddle builds may lack some of the
        # is_compiled_with_* probes, so failures here are ignored on purpose.
        pass
def to_float32(preds):
    """Recursively cast every paddle.Tensor inside ``preds`` to float32.

    Dicts and lists are traversed in place; non-tensor leaves are left
    untouched. A bare tensor is cast and returned directly.
    """
    if isinstance(preds, dict):
        for key in preds:
            value = preds[key]
            if isinstance(value, (dict, list)):
                preds[key] = to_float32(value)
            elif isinstance(value, paddle.Tensor):
                preds[key] = value.astype(paddle.float32)
    elif isinstance(preds, list):
        for i, value in enumerate(preds):
            if isinstance(value, (dict, list)):
                preds[i] = to_float32(value)
            elif isinstance(value, paddle.Tensor):
                preds[i] = value.astype(paddle.float32)
    elif isinstance(preds, paddle.Tensor):
        preds = preds.astype(paddle.float32)
    return preds
def train(config,
          train_dataloader,
          valid_dataloader,
          device,
          model,
          loss_class,
          optimizer,
          lr_scheduler,
          post_process_class,
          eval_class,
          pre_best_model_dict,
          logger,
          log_writer=None,
          scaler=None,
          amp_level='O2',
          amp_custom_black_list=[],
          amp_custom_white_list=[],
          amp_dtype='float16'):
    """Main training loop.

    Trains ``model`` for ``Global.epoch_num`` epochs, optionally under AMP
    (when ``scaler`` is given), logs smoothed stats every
    ``Global.print_batch_step`` iterations, evaluates every
    ``eval_batch_step`` iterations after ``start_eval_step`` (rank 0 only),
    and checkpoints 'latest', 'best_accuracy' and per-epoch snapshots under
    ``Global.save_model_dir``.

    Args:
        config (dict): merged global configuration.
        train_dataloader / valid_dataloader: paddle dataloaders.
        device: device handle from ``paddle.set_device``.
        model: network to train (already distributed-wrapped if needed).
        loss_class: callable returning a dict of losses with key 'loss'.
        optimizer / lr_scheduler: optimization objects.
        post_process_class: decodes raw predictions for metrics.
        eval_class: metric accumulator exposing ``main_indicator``.
        pre_best_model_dict (dict): resumed best metrics / global_step.
        logger: python logger.
        log_writer: optional metrics writer (VisualDL / W&B wrapper).
        scaler: optional ``paddle.amp.GradScaler``; enables the AMP path.
        amp_level, amp_custom_black_list, amp_custom_white_list, amp_dtype:
            settings forwarded to ``paddle.amp.auto_cast``.
    """
    cal_metric_during_train = config['Global'].get('cal_metric_during_train',
                                                   False)
    calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1)
    log_smooth_window = config['Global']['log_smooth_window']
    epoch_num = config['Global']['epoch_num']
    print_batch_step = config['Global']['print_batch_step']
    eval_batch_step = config['Global']['eval_batch_step']
    profiler_options = config['profiler_options']

    global_step = 0
    if 'global_step' in pre_best_model_dict:
        global_step = pre_best_model_dict['global_step']
    start_eval_step = 0
    if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
        start_eval_step = eval_batch_step[0]
        eval_batch_step = eval_batch_step[1]
        if len(valid_dataloader) == 0:
            logger.info(
                'No Images in eval dataset, evaluation during training '
                'will be disabled'
            )
            # Effectively "never": pushes the first eval past any real step.
            start_eval_step = 1e111
        logger.info(
            "During the training process, after the {}th iteration, "
            "an evaluation is run every {} iterations".
            format(start_eval_step, eval_batch_step))
    save_epoch_step = config['Global']['save_epoch_step']
    save_model_dir = config['Global']['save_model_dir']
    if not os.path.exists(save_model_dir):
        os.makedirs(save_model_dir)
    main_indicator = eval_class.main_indicator
    best_model_dict = {main_indicator: 0}
    best_model_dict.update(pre_best_model_dict)
    train_stats = TrainingStats(log_smooth_window, ['lr'])
    model_average = False
    model.train()

    use_srn = config['Architecture']['algorithm'] == "SRN"
    # Algorithms whose forward pass needs the extra ground-truth inputs
    # (batch[1:]) in addition to the image tensor.
    extra_input_models = [
        "SRN", "NRTR", "SAR", "SEED", "SVTR", "SVTR_LCNet", "SPIN", "VisionLAN",
        "RobustScanner", "RFL", 'DRRG', 'SATRN', 'SVTR_HGNet'
    ]
    extra_input = False
    if config['Architecture']['algorithm'] == 'Distillation':
        for key in config['Architecture']["Models"]:
            extra_input = extra_input or config['Architecture']['Models'][key][
                'algorithm'] in extra_input_models
    else:
        extra_input = config['Architecture']['algorithm'] in extra_input_models
    try:
        model_type = config['Architecture']['model_type']
    except Exception:  # narrowed from a bare except: model_type is optional
        model_type = None

    algorithm = config['Architecture']['algorithm']

    start_epoch = best_model_dict[
        'start_epoch'] if 'start_epoch' in best_model_dict else 1

    total_samples = 0
    train_reader_cost = 0.0
    train_batch_cost = 0.0
    reader_start = time.time()
    eta_meter = AverageMeter()

    # On Windows the last batch is skipped (dataloader workaround).
    max_iter = len(train_dataloader) - 1 if platform.system(
    ) == "Windows" else len(train_dataloader)

    for epoch in range(start_epoch, epoch_num + 1):
        if train_dataloader.dataset.need_reset:
            # Rebuild the dataloader so ratio-sampled datasets reshuffle
            # with a per-epoch seed.
            train_dataloader = build_dataloader(
                config, 'Train', device, logger, seed=epoch)
            max_iter = len(train_dataloader) - 1 if platform.system(
            ) == "Windows" else len(train_dataloader)

        for idx, batch in enumerate(train_dataloader):
            profiler.add_profiler_step(profiler_options)
            train_reader_cost += time.time() - reader_start
            if idx >= max_iter:
                break
            lr = optimizer.get_lr()
            images = batch[0]
            if use_srn:
                model_average = True
            # use amp
            if scaler:
                with paddle.amp.auto_cast(
                        level=amp_level,
                        custom_black_list=amp_custom_black_list,
                        custom_white_list=amp_custom_white_list,
                        dtype=amp_dtype):
                    if model_type == 'table' or extra_input:
                        preds = model(images, data=batch[1:])
                    elif model_type in ["kie"]:
                        preds = model(batch)
                    elif algorithm in ['CAN']:
                        preds = model(batch[:3])
                    else:
                        preds = model(images)
                # Cast AMP outputs back so the loss runs in float32.
                preds = to_float32(preds)
                loss = loss_class(preds, batch)
                avg_loss = loss['loss']
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)
            else:
                if model_type == 'table' or extra_input:
                    preds = model(images, data=batch[1:])
                elif model_type in ["kie", 'sr']:
                    preds = model(batch)
                elif algorithm in ['CAN']:
                    preds = model(batch[:3])
                else:
                    preds = model(images)
                loss = loss_class(preds, batch)
                avg_loss = loss['loss']
                avg_loss.backward()
                optimizer.step()
            optimizer.clear_grad()

            if cal_metric_during_train and epoch % calc_epoch_interval == 0:  # only rec and cls need
                batch = [item.numpy() for item in batch]
                if model_type in ['kie', 'sr']:
                    eval_class(preds, batch)
                elif model_type in ['table']:
                    post_result = post_process_class(preds, batch)
                    eval_class(post_result, batch)
                elif algorithm in ['CAN']:
                    model_type = 'can'
                    eval_class(preds[0], batch[2:], epoch_reset=(idx == 0))
                else:
                    if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2'
                                                  ]:  # for multi head loss
                        post_result = post_process_class(
                            preds['ctc'], batch[1])  # for CTC head out
                    elif config['Loss']['name'] in ['VLLoss']:
                        post_result = post_process_class(preds, batch[1],
                                                         batch[-1])
                    else:
                        post_result = post_process_class(preds, batch[1])
                    eval_class(post_result, batch)
                metric = eval_class.get_metric()
                train_stats.update(metric)

            train_batch_time = time.time() - reader_start
            train_batch_cost += train_batch_time
            eta_meter.update(train_batch_time)
            global_step += 1
            total_samples += len(images)

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            # logger and visualdl
            stats = {
                k: float(v) if v.shape == [] else v.numpy().mean()
                for k, v in loss.items()
            }
            stats['lr'] = lr
            train_stats.update(stats)

            if log_writer is not None and dist.get_rank() == 0:
                log_writer.log_metrics(
                    metrics=train_stats.get(), prefix="TRAIN", step=global_step)

            if dist.get_rank() == 0 and (
                (global_step > 0 and global_step % print_batch_step == 0) or
                (idx >= len(train_dataloader) - 1)):
                logs = train_stats.log()

                eta_sec = ((epoch_num + 1 - epoch) *
                           len(train_dataloader) - idx - 1) * eta_meter.avg
                eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec)))
                strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \
                       '{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \
                       'ips: {:.5f} samples/s, eta: {}'.format(
                           epoch, epoch_num, global_step, logs,
                           train_reader_cost / print_batch_step,
                           train_batch_cost / print_batch_step,
                           total_samples / print_batch_step,
                           total_samples / train_batch_cost, eta_sec_format)
                logger.info(strs)

                total_samples = 0
                train_reader_cost = 0.0
                train_batch_cost = 0.0
            # eval
            if global_step > start_eval_step and \
                    (global_step - start_eval_step) % eval_batch_step == 0 \
                    and dist.get_rank() == 0:
                if model_average:
                    Model_Average = paddle.incubate.optimizer.ModelAverage(
                        0.15,
                        parameters=model.parameters(),
                        min_average_window=10000,
                        max_average_window=15625)
                    Model_Average.apply()
                cur_metric = eval(
                    model,
                    valid_dataloader,
                    post_process_class,
                    eval_class,
                    model_type,
                    extra_input=extra_input,
                    scaler=scaler,
                    amp_level=amp_level,
                    amp_custom_black_list=amp_custom_black_list,
                    amp_custom_white_list=amp_custom_white_list,
                    amp_dtype=amp_dtype)
                cur_metric_str = 'cur metric, {}'.format(', '.join(
                    ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
                logger.info(cur_metric_str)

                # logger metric
                if log_writer is not None:
                    log_writer.log_metrics(
                        metrics=cur_metric, prefix="EVAL", step=global_step)

                if cur_metric[main_indicator] >= best_model_dict[
                        main_indicator]:
                    best_model_dict.update(cur_metric)
                    best_model_dict['best_epoch'] = epoch
                    save_model(
                        model,
                        optimizer,
                        save_model_dir,
                        logger,
                        config,
                        is_best=True,
                        prefix='best_accuracy',
                        best_model_dict=best_model_dict,
                        epoch=epoch,
                        global_step=global_step)
                best_str = 'best metric, {}'.format(', '.join([
                    '{}: {}'.format(k, v) for k, v in best_model_dict.items()
                ]))
                logger.info(best_str)
                # logger best metric
                if log_writer is not None:
                    log_writer.log_metrics(
                        metrics={
                            "best_{}".format(main_indicator):
                            best_model_dict[main_indicator]
                        },
                        prefix="EVAL",
                        step=global_step)

                    log_writer.log_model(
                        is_best=True,
                        prefix="best_accuracy",
                        metadata=best_model_dict)

            reader_start = time.time()

        if dist.get_rank() == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                config,
                is_best=False,
                prefix='latest',
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)

            if log_writer is not None:
                log_writer.log_model(is_best=False, prefix="latest")

        if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                config,
                is_best=False,
                # Restored per-epoch prefix; the hard-coded 'a' was a
                # debugging leftover that made every epoch checkpoint
                # overwrite the previous one.
                prefix='iter_epoch_{}'.format(epoch),
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)
            if log_writer is not None:
                log_writer.log_model(
                    is_best=False, prefix='iter_epoch_{}'.format(epoch))

    best_str = 'best metric, {}'.format(', '.join(
        ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
    logger.info(best_str)
    if dist.get_rank() == 0 and log_writer is not None:
        log_writer.close()
    return
def eval(model,
         valid_dataloader,
         post_process_class,
         eval_class,
         model_type=None,
         extra_input=False,
         scaler=None,
         amp_level='O2',
         amp_custom_black_list=[],
         amp_custom_white_list=[],
         amp_dtype='float16'):
    """Evaluate ``model`` on ``valid_dataloader`` and return a metric dict.

    Runs under ``paddle.no_grad()``; restores ``model.train()`` before
    returning. The returned dict is ``eval_class.get_metric()`` plus an
    'fps' entry measuring forward-pass throughput.

    Args:
        model: network to evaluate.
        valid_dataloader: paddle dataloader over the eval set.
        post_process_class: decodes raw predictions (may be None for some
            table/kie metrics).
        eval_class: metric accumulator.
        model_type (str|None): 'table', 'kie', 'can', 'sr' or None.
        extra_input (bool): forward needs batch[1:] in addition to images.
        scaler: optional GradScaler; its presence enables the AMP path.
        amp_level / amp_custom_black_list / amp_custom_white_list /
        amp_dtype: forwarded to ``paddle.amp.auto_cast``.
    """
    model.eval()
    with paddle.no_grad():
        total_frame = 0.0
        total_time = 0.0
        pbar = tqdm(
            total=len(valid_dataloader),
            desc='eval model:',
            position=0,
            leave=True)
        # Skip the last batch on Windows (dataloader workaround).
        max_iter = len(valid_dataloader) - 1 if platform.system(
        ) == "Windows" else len(valid_dataloader)
        sum_images = 0
        for idx, batch in enumerate(valid_dataloader):
            if idx >= max_iter:
                break
            images = batch[0]
            start = time.time()

            # use amp
            if scaler:
                with paddle.amp.auto_cast(
                        level=amp_level,
                        custom_black_list=amp_custom_black_list,
                        # Fix: forward the white list too, consistent with
                        # train(); it was previously accepted as a parameter
                        # but silently ignored here.
                        custom_white_list=amp_custom_white_list,
                        dtype=amp_dtype):
                    if model_type == 'table' or extra_input:
                        preds = model(images, data=batch[1:])
                    elif model_type in ["kie"]:
                        preds = model(batch)
                    elif model_type in ['can']:
                        preds = model(batch[:3])
                    elif model_type in ['sr']:
                        preds = model(batch)
                        # NOTE(review): sr_img/lr_img are assigned but never
                        # read in this function — kept for parity.
                        sr_img = preds["sr_img"]
                        lr_img = preds["lr_img"]
                    else:
                        preds = model(images)
                preds = to_float32(preds)
            else:
                if model_type == 'table' or extra_input:
                    preds = model(images, data=batch[1:])
                elif model_type in ["kie"]:
                    preds = model(batch)
                elif model_type in ['can']:
                    preds = model(batch[:3])
                elif model_type in ['sr']:
                    preds = model(batch)
                    sr_img = preds["sr_img"]
                    lr_img = preds["lr_img"]
                else:
                    preds = model(images)

            batch_numpy = []
            for item in batch:
                if isinstance(item, paddle.Tensor):
                    batch_numpy.append(item.numpy())
                else:
                    batch_numpy.append(item)
            # Obtain usable results from post-processing methods
            total_time += time.time() - start
            # Evaluate the results of the current batch
            if model_type in ['table', 'kie']:
                if post_process_class is None:
                    eval_class(preds, batch_numpy)
                else:
                    post_result = post_process_class(preds, batch_numpy)
                    eval_class(post_result, batch_numpy)
            elif model_type in ['sr']:
                eval_class(preds, batch_numpy)
            elif model_type in ['can']:
                eval_class(preds[0], batch_numpy[2:], epoch_reset=(idx == 0))
            else:
                post_result = post_process_class(preds, batch_numpy[1])
                eval_class(post_result, batch_numpy)

            pbar.update(1)
            total_frame += len(images)
            sum_images += 1
        # Get final metric,eg. acc or hmean
        metric = eval_class.get_metric()
        pbar.close()
    model.train()
    # Guard against an empty loader: total_time stays 0.0 when the batch
    # loop never ran, which previously raised ZeroDivisionError.
    metric['fps'] = total_frame / total_time if total_time > 0 else 0.0
    return metric
def update_center(char_center, post_result, preds):
    """Accumulate per-character feature centers from correct predictions.

    For every sample whose decoded text equals its label, folds the
    timestep features into a running mean keyed by the predicted character
    index. ``char_center`` maps index -> [mean_feature, sample_count].
    """
    result, label = post_result
    feats, logits = preds
    logits = paddle.argmax(logits, axis=-1)
    feats = feats.numpy()
    logits = logits.numpy()

    for sample_id in range(len(label)):
        # Only correctly recognized samples contribute to the centers.
        if result[sample_id][0] != label[sample_id][0]:
            continue
        feat = feats[sample_id]
        logit = logits[sample_id]
        for t in range(len(logit)):
            char_idx = logit[t]
            if char_idx in char_center:
                mean, count = char_center[char_idx]
                # Incremental running mean over count + 1 observations.
                char_center[char_idx][0] = (mean * count + feat[t]) / (count + 1)
                char_center[char_idx][1] += 1
            else:
                char_center[char_idx] = [feat[t], 1]
    return char_center
def get_center(model, eval_dataloader, post_process_class):
    """Collect character feature centers over an evaluation dataset.

    Runs the model over ``eval_dataloader``, decodes predictions, and
    aggregates per-character running means via ``update_center``. Returns
    a dict mapping character index -> mean feature vector.
    """
    pbar = tqdm(total=len(eval_dataloader), desc='get center:')
    # Skip the last batch on Windows (dataloader workaround).
    is_windows = platform.system() == "Windows"
    max_iter = len(eval_dataloader) - 1 if is_windows else len(eval_dataloader)
    char_center = dict()
    for batch_id, batch in enumerate(eval_dataloader):
        if batch_id >= max_iter:
            break
        images = batch[0]
        start = time.time()
        preds = model(images)
        batch = [item.numpy() for item in batch]
        # Obtain usable results from post-processing methods
        post_result = post_process_class(preds, batch[1])
        # update char_center
        char_center = update_center(char_center, post_result, preds)
        pbar.update(1)
    pbar.close()
    # Keep only the mean feature; drop the sample count bookkeeping.
    for key in char_center.keys():
        char_center[key] = char_center[key][0]
    return char_center
def preprocess(is_train=False):
    """Parse CLI args, load/merge config, pick the device and set up loggers.

    Args:
        is_train (bool): when True, the merged config is dumped to
            ``Global.save_model_dir/config.yml`` and logging goes to
            ``train.log`` in the same directory.

    Returns:
        tuple: (config, device, logger, log_writer) where ``log_writer`` is
        a ``Loggers`` wrapper or None if no metric backend is enabled.
    """
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)
    profile_dic = {"profiler_options": FLAGS.profiler_options}
    config = merge_config(config, profile_dic)

    if is_train:
        # save_config
        save_model_dir = config['Global']['save_model_dir']
        os.makedirs(save_model_dir, exist_ok=True)
        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
            yaml.dump(
                dict(config), f, default_flow_style=False, sort_keys=False)
        log_file = '{}/train.log'.format(save_model_dir)
    else:
        log_file = None
    logger = get_logger(log_file=log_file)

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global'].get('use_gpu', False)
    use_xpu = config['Global'].get('use_xpu', False)
    use_npu = config['Global'].get('use_npu', False)
    use_mlu = config['Global'].get('use_mlu', False)

    alg = config['Architecture']['algorithm']
    assert alg in [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
        'SVTR', 'SVTR_LCNet', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN',
        'VisionLAN', 'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG',
        'CAN', 'Telescope', 'SATRN', 'SVTR_HGNet'
    ]

    if use_xpu:
        device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
    elif use_npu:
        device = 'npu:{0}'.format(os.getenv('FLAGS_selected_npus', 0))
    elif use_mlu:
        device = 'mlu:{0}'.format(os.getenv('FLAGS_selected_mlus', 0))
    else:
        device = 'gpu:{}'.format(dist.ParallelEnv()
                                 .dev_id) if use_gpu else 'cpu'
    check_device(use_gpu, use_xpu, use_npu, use_mlu)

    device = paddle.set_device(device)

    config['Global']['distributed'] = dist.get_world_size() != 1

    # Collect enabled metric-logger backends; the combined writer is built
    # once at the end (the per-branch log_writer locals in the original
    # were dead stores, always overwritten below).
    loggers = []
    if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']:
        loggers.append(VDLLogger(config['Global']['save_model_dir']))
    if ('use_wandb' in config['Global'] and
            config['Global']['use_wandb']) or 'wandb' in config:
        save_dir = config['Global']['save_model_dir']
        if "wandb" in config:
            wandb_params = config['wandb']
        else:
            wandb_params = dict()
        wandb_params.update({'save_dir': save_dir})
        loggers.append(WandbLogger(**wandb_params, config=config))

    print_dict(config, logger)

    log_writer = Loggers(loggers) if loggers else None

    logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                            device))
    return config, device, logger, log_writer