"""Training on a single process."""
import torch
import sys

from onmt.utils.logging import init_logger, logger
from onmt.utils.parse import ArgumentParser
from onmt.constants import CorpusTask
from onmt.transforms import (
    make_transforms,
    save_transforms,
    get_specials,
    get_transforms_cls,
)
from onmt.inputters import build_vocab, IterOnDevice
from onmt.inputters.inputter import dict_to_vocabs, vocabs_to_dict
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
from onmt.inputters.text_corpus import save_transformed_sample
from onmt.model_builder import build_model
from onmt.models.model_saver import load_checkpoint
from onmt.utils.optimizers import Optimizer
from onmt.utils.misc import set_random_seed
from onmt.trainer import build_trainer
from onmt.models import build_model_saver
from onmt.modules.embeddings import prepare_pretrained_embeddings


def prepare_transforms_vocabs(opt, transforms_cls):
    """Prepare or dump transforms before training."""
    validset_transforms = opt.data.get("valid", {}).get("transforms", None)
    if validset_transforms:
        opt.transforms = validset_transforms
        if opt.data.get("valid", {}).get("tgt_prefix", None):
            opt.tgt_prefix = opt.data.get("valid", {}).get("tgt_prefix", None)
        if opt.data.get("valid", {}).get("src_prefix", None):
            opt.src_prefix = opt.data.get("valid", {}).get("src_prefix", None)
        if opt.data.get("valid", {}).get("tgt_suffix", None):
            opt.tgt_suffix = opt.data.get("valid", {}).get("tgt_suffix", None)
        if opt.data.get("valid", {}).get("src_suffix", None):
            opt.src_suffix = opt.data.get("valid", {}).get("src_suffix", None)
    specials = get_specials(opt, transforms_cls)

    vocabs = build_vocab(opt, specials)
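
    # Prepare pretrained src/tgt embeddings for the new vocabs, if configured.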
    prepare_pretrained_embeddings(opt, vocabs)
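
    # Optionally dump the built transforms and/or a few transformed samples,
    # then exit before any training starts.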
    if opt.dump_transforms or opt.n_sample != 0:
        transforms = make_transforms(opt, transforms_cls, vocabs)
    if opt.dump_transforms:
        save_transforms(transforms, opt.save_data, overwrite=opt.overwrite)
    if opt.n_sample != 0:
        logger.warning(
            "`-n_sample` != 0: Training will not be started. "
            f"Stop after saving {opt.n_sample} samples/corpus."
        )
        save_transformed_sample(opt, transforms, n_sample=opt.n_sample)
        logger.info("Sample saved, please check it before restarting training.")
        sys.exit()
    logger.info(
        "The first 10 tokens of the vocabs are: "
        f"{vocabs_to_dict(vocabs)['src'][0:10]}"
    )
    logger.info(f"The decoder start token is: {opt.decoder_start_token}")
    return vocabs


def _init_train(opt):
| """Common initilization stuff for all training process. |
| We need to build or rebuild the vocab in 3 cases: |
| - training from scratch (train_from is false) |
| - resume training but transforms have changed |
| - resume training but vocab file has been modified |
| """ |
| ArgumentParser.validate_prepare_opts(opt) |
| transforms_cls = get_transforms_cls(opt._all_transform) |
| if opt.train_from: |
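        # Resume training: restore the vocabs stored in the checkpoint.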
        checkpoint = load_checkpoint(ckpt_path=opt.train_from)
        vocabs = dict_to_vocabs(checkpoint["vocab"])
        if (
            hasattr(checkpoint["opt"], "_all_transform")
            and len(
                opt._all_transform.symmetric_difference(
                    checkpoint["opt"]._all_transform
                )
            )
            != 0
        ):
            _msg = "configured transforms are different from checkpoint:"
            new_transf = opt._all_transform.difference(
                checkpoint["opt"]._all_transform
            )
            old_transf = checkpoint["opt"]._all_transform.difference(
                opt._all_transform
            )
            if len(new_transf) != 0:
                _msg += f" +{new_transf}"
            if len(old_transf) != 0:
                _msg += f" -{old_transf}."
            logger.warning(_msg)
            vocabs = prepare_transforms_vocabs(opt, transforms_cls)
        if opt.update_vocab:
            logger.info("Updating checkpoint vocabulary with new vocabulary")
            vocabs = prepare_transforms_vocabs(opt, transforms_cls)
    else:
        checkpoint = None
        vocabs = prepare_transforms_vocabs(opt, transforms_cls)

    return checkpoint, vocabs, transforms_cls


def configure_process(opt, device_id):
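    """Select the CUDA device for this process (if any) and seed the RNGs."""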
    if device_id >= 0:
        torch.cuda.set_device(device_id)
    set_random_seed(opt.seed, device_id >= 0)


def _get_model_opts(opt, checkpoint=None):
| """Get `model_opt` to build model, may load from `checkpoint` if any.""" |
| if checkpoint is not None: |
| model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"]) |
| if opt.override_opts: |
| logger.info("Over-ride model option set to true - use with care") |
            args = list(opt.__dict__.keys())
            model_args = list(model_opt.__dict__.keys())
            for arg in args:
                if arg in model_args and getattr(opt, arg) != getattr(model_opt, arg):
                    logger.info(
                        "Option: %s , value: %s overriding model: %s"
                        % (arg, getattr(opt, arg), getattr(model_opt, arg))
                    )
            model_opt = opt
        else:
            model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
            ArgumentParser.update_model_opts(model_opt)
            ArgumentParser.validate_model_opts(model_opt)
            if opt.tensorboard_log_dir == model_opt.tensorboard_log_dir and hasattr(
                model_opt, "tensorboard_log_dir_dated"
            ):
                # Keep writing tensorboard logs to the dated log directory of
                # the run being resumed.
                opt.tensorboard_log_dir_dated = model_opt.tensorboard_log_dir_dated
            # These options follow the current command line, not the checkpoint.
            model_opt.update_vocab = opt.update_vocab
            model_opt.freeze_encoder = opt.freeze_encoder
            model_opt.freeze_decoder = opt.freeze_decoder
    else:
        model_opt = opt
    return model_opt


def main(opt, device_id):
| """Start training on `device_id`.""" |
| |
| |
|
|
| configure_process(opt, device_id) |
| init_logger(opt.log_file) |
| checkpoint, vocabs, transforms_cls = _init_train(opt) |
| model_opt = _get_model_opts(opt, checkpoint=checkpoint) |
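
    # Build the model; weights are restored from the checkpoint when resuming.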
    model = build_model(model_opt, opt, vocabs, checkpoint, device_id)

    model.count_parameters(log=logger.info)
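    # Tally trainable vs. frozen parameters per dtype for the log lines below.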
    trainable = {
        "torch.float32": 0,
        "torch.float16": 0,
        "torch.uint8": 0,
        "torch.int8": 0,
    }
    non_trainable = {
        "torch.float32": 0,
        "torch.float16": 0,
        "torch.uint8": 0,
        "torch.int8": 0,
    }
    for n, p in model.named_parameters():
        if p.requires_grad:
            trainable[str(p.dtype)] += p.numel()
        else:
            non_trainable[str(p.dtype)] += p.numel()
    logger.info("Trainable parameters = %s" % str(trainable))
    logger.info("Non trainable parameters = %s" % str(non_trainable))
    logger.info(" * src vocab size = %d" % len(vocabs["src"]))
    logger.info(" * tgt vocab size = %d" % len(vocabs["tgt"]))
    if "src_feats" in vocabs:
        for i, feat_vocab in enumerate(vocabs["src_feats"]):
            logger.info(f"* src_feat {i} vocab size = {len(feat_vocab)}")
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    del checkpoint
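
    # Build the model saver used for periodic checkpointing.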
    model_saver = build_model_saver(model_opt, opt, model, vocabs, optim, device_id)

    trainer = build_trainer(
        opt, device_id, model, vocabs, optim, model_saver=model_saver
    )
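
    # In data_parallel mode, each rank reads an interleaved shard of the corpus
    # (every `stride`-th example, starting at `offset`).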
    offset = max(0, device_id) if opt.parallel_mode == "data_parallel" else 0
    stride = max(1, len(opt.gpu_ranks)) if opt.parallel_mode == "data_parallel" else 1
    _train_iter = build_dynamic_dataset_iter(
        opt,
        transforms_cls,
        vocabs,
        task=CorpusTask.TRAIN,
        copy=opt.copy_attn,
        stride=stride,
        offset=offset,
    )
    train_iter = IterOnDevice(_train_iter, device_id)

    valid_iter = build_dynamic_dataset_iter(
        opt, transforms_cls, vocabs, task=CorpusTask.VALID, copy=opt.copy_attn
    )

    if valid_iter is not None:
        valid_iter = IterOnDevice(valid_iter, device_id)
    if len(opt.gpu_ranks):
        logger.info("Starting training on GPU: %s" % opt.gpu_ranks)
    else:
        logger.info("Starting training on CPU, could be very slow")
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
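
    # Run the training loop; validation and checkpointing follow the configured
    # step schedules.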
    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps,
    )

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()