| | """ |
| | Train a new model on one or across multiple GPUs. |
| | """ |
| |
|
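
# Example invocation (hypothetical; the dataset path, architecture, and
# hyperparameter values below are illustrative, not prescriptive):
#
#   fairseq-train data-bin/iwslt14.tokenized.de-en \
#       --arch transformer_iwslt_de_en \
#       --optimizer adam --lr 5e-4 --lr-scheduler inverse_sqrt \
#       --max-tokens 4096 \
#       --criterion label_smoothed_cross_entropy --label-smoothing 0.1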

import argparse
import logging
import math
import os
import random
import sys

import numpy as np
import torch
from fairseq import (
    checkpoint_utils,
    distributed_utils,
    options,
    quantization_utils,
    tasks,
    utils,
)
from fairseq.data import iterators
from fairseq.logging import meters, metrics, progress_bar
from fairseq.model_parallel.megatron_trainer import MegatronTrainer
from fairseq.trainer import Trainer


logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("fairseq_cli.train")


def main(args):
    utils.import_user_module(args)  # load extensions from --user-dir, if any

    assert (
        args.max_tokens is not None or args.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"

    metrics.reset()

    np.random.seed(args.seed)
    utils.set_torch_seed(args.seed)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(","):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info("task: {} ({})".format(args.task, task.__class__.__name__))
    logger.info("model: {} ({})".format(args.arch, model.__class__.__name__))
    logger.info(
        "criterion: {} ({})".format(args.criterion, criterion.__class__.__name__)
    )
    logger.info(
        "num. model params: {} (num. trained: {})".format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad),
        )
    )

    # (optionally) Configure quantization
    if args.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=args.quantization_config_path,
            max_epoch=args.max_epoch,
            max_update=args.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if args.model_parallel_size == 1:
        trainer = Trainer(args, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(args, task, model, criterion)

    logger.info(
        "training on {} devices (GPUs/TPUs)".format(args.distributed_world_size)
    )
    logger.info(
        "max tokens per GPU = {} and max sentences per GPU = {}".format(
            args.max_tokens, args.batch_size
        )
    )
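
    # Illustrative arithmetic (assumed flag values, not defaults): with
    # --max-tokens 4096 on 8 GPUs and --update-freq 2, each optimizer step
    # aggregates roughly 4096 * 8 * 2 = 65536 tokens, since update_freq
    # batches are accumulated before every parameter update.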

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        args,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
        # train for one epoch
        valid_losses, should_stop = train(args, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: load the dataset for the next epoch
            load_dataset=task.has_sharded_data("train"),
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))


def should_stop_early(args, valid_loss):
    # skip check if no validation was done in the current epoch
    if valid_loss is None:
        return False
    if args.patience <= 0:
        return False

    def is_better(a, b):
        return a > b if args.maximize_best_checkpoint_metric else a < b

    # the best validation loss and a run counter are stored as attributes on
    # the function itself, so state persists across calls
    prev_best = getattr(should_stop_early, "best", None)
    if prev_best is None or is_better(valid_loss, prev_best):
        should_stop_early.best = valid_loss
        should_stop_early.num_runs = 0
        return False
    else:
        should_stop_early.num_runs += 1
        if should_stop_early.num_runs >= args.patience:
            logger.info(
                "early stop since valid performance hasn't improved for last {} runs".format(
                    args.patience
                )
            )
            return True
        else:
            return False
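
# A worked example of the patience bookkeeping above (illustrative): with
# --patience 3 and successive valid losses 1.0, 0.9, 0.95, 0.92, 0.93 (lower is
# better), "best" settles at 0.9 and num_runs counts 1, 2, 3 over the
# non-improving runs, so the third one triggers the early stop.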


@metrics.aggregate("train")
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    # group update_freq batches per step so gradients are accumulated over them
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, "tpu", False):
        itr = utils.tpu_data_loader(itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
    )

    trainer.begin_epoch(epoch_itr.epoch)

    valid_losses = [None]
    valid_subsets = args.valid_subset.split(",")
    should_stop = False
    num_updates = trainer.get_num_updates()
    for i, samples in enumerate(progress):
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
            "train_step-%d" % i
        ):
            log_output = trainer.train_step(samples)

        if log_output is not None:  # not OOM, overflow, ...
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % args.log_interval == 0:
                stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)

                # reset mid-epoch stats after each log interval
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(
            args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
        )

        if should_stop:
            break

    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop


def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoch):
    num_updates = trainer.get_num_updates()
    max_update = args.max_update or math.inf
    do_save = (
        (end_of_epoch and epoch_itr.epoch % args.save_interval == 0)
        or num_updates >= max_update
        or (
            args.save_interval_updates > 0
            and num_updates > 0
            and num_updates % args.save_interval_updates == 0
            and num_updates >= args.validate_after_updates
        )
    )
    do_validate = (
        (not end_of_epoch and do_save)  # validate during mid-epoch saves
        or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0)
        or num_updates >= max_update
        or (
            args.validate_interval_updates > 0
            and num_updates > 0
            and num_updates % args.validate_interval_updates == 0
        )
    ) and not args.disable_validation
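
    # Cadence example (assumed flag value): with --save-interval-updates 1000,
    # a mid-epoch checkpoint is written every 1000 updates (once past
    # --validate-after-updates), and the "not end_of_epoch and do_save" clause
    # above runs validation at the same cadence, on top of the end-of-epoch
    # save/validate.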

    # Validate
    valid_losses = [None]
    if do_validate:
        valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

    # Stopping conditions
    should_stop = (
        should_stop_early(args, valid_losses[0])
        or num_updates >= max_update
        or (
            args.stop_time_hours > 0
            and trainer.cumulative_training_time() / (60 * 60) > args.stop_time_hours
        )
    )

    # Save checkpoint
    if do_save or should_stop:
        logger.info("begin save checkpoint")
        checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

    return valid_losses, should_stop


def get_training_stats(stats):
    stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0)
    return stats


def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""

    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    trainer.begin_valid_epoch(epoch_itr.epoch)
    valid_losses = []
    for subset in subsets:
        logger.info('begin validation on "{}" subset'.format(subset))

        # Initialize data iterator
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
        if getattr(args, "tpu", False):
            itr = utils.tpu_data_loader(itr)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                args.tensorboard_logdir if distributed_utils.is_master(args) else None
            ),
            default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[args.best_checkpoint_metric])
    return valid_losses


def get_valid_stats(args, trainer, stats):
    stats["num_updates"] = trainer.get_num_updates()
    if hasattr(checkpoint_utils.save_checkpoint, "best"):
        key = "best_{0}".format(args.best_checkpoint_metric)
        best_function = max if args.maximize_best_checkpoint_metric else min
        stats[key] = best_function(
            checkpoint_utils.save_checkpoint.best, stats[args.best_checkpoint_metric]
        )
    return stats


def cli_main(modify_parser=None):
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser, modify_parser=modify_parser)
    if args.profile:
        with torch.cuda.profiler.profile():
            with torch.autograd.profiler.emit_nvtx():
                distributed_utils.call_main(args, main)
    else:
        distributed_utils.call_main(args, main)
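
# Hypothetical use of the modify_parser hook (illustrative; add_my_args is not
# part of fairseq): callers can register extra CLI flags before parsing, e.g.
#
#     def add_my_args(parser):
#         parser.add_argument("--my-flag", action="store_true")
#
#     cli_main(modify_parser=add_my_args)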


if __name__ == "__main__":
    cli_main()