| |
| """Train models with dynamic data.""" |
| import torch |
| from functools import partial |
| from onmt.utils.distributed import ErrorHandler, spawned_train |
| from onmt.utils.misc import set_random_seed |
| from onmt.utils.logging import init_logger, logger |
| from onmt.utils.parse import ArgumentParser |
| from onmt.opts import train_opts |
| from onmt.train_single import main as single_main |
|
|
|
|
| |
| |
|
|
|
|
def train(opt):
    """Launch training on the configured devices.

    Validates and normalizes the parsed options, then runs ``single_main``
    either in multiple spawned processes (one per GPU rank when
    ``opt.world_size > 1``), on a single GPU, or on CPU.

    Args:
        opt: parsed training options (an ``argparse.Namespace`` produced
            by the parser from ``_get_parser``).
    """
    init_logger(opt.log_file)

    # Validate/normalize options in place before any worker starts, so
    # every spawned process sees the same, already-checked configuration.
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    set_random_seed(opt.seed, False)

    # No arguments need pre-binding, so refer to single_main directly
    # (the previous functools.partial wrapper was a no-op).
    train_process = single_main

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        # Multi-GPU: the "spawn" start method is required for CUDA;
        # child errors are funneled through a queue so the parent can
        # surface their tracebacks.
        mp = torch.multiprocessing.get_context("spawn")
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)

        procs = []
        for device_id in range(nb_gpu):
            procs.append(
                mp.Process(
                    target=spawned_train,
                    args=(train_process, opt, device_id, error_queue),
                    daemon=False,
                )
            )
            procs[device_id].start()
            logger.info(" Starting process pid: %d " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        for p in procs:
            p.join()

    elif nb_gpu == 1:  # single GPU
        train_process(opt, device_id=0)
    else:  # CPU only
        train_process(opt, device_id=-1)
|
|
|
|
def _get_parser():
    """Build and return the argument parser holding all training options."""
    train_parser = ArgumentParser(description="train.py")
    train_opts(train_parser)
    return train_parser
|
|
|
|
def main():
    """Command-line entry point: parse options and start training.

    Uses ``parse_known_args`` so unrecognized CLI flags are tolerated
    (e.g. when a wrapper script forwards extra arguments); the leftover
    arguments are deliberately discarded.
    """
    parser = _get_parser()

    opt, _unknown = parser.parse_known_args()
    train(opt)
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|