#!/usr/bin/env python3
| """ |
| Usage: |
| |
| export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" |
| |
| # For hubert model pretraining: |
| ./hubert/pretrain.py \ |
| --world-size 8 \ |
| --num-epochs 400 \ |
| --start-epoch 1 \ |
| --use-fp16 1 \ |
| --exp-dir hubert/exp \ |
| --full-libri 1 \ |
| --max-duration 87.5 \ |
| --accum-grad 4 |
| """ |
|
|
|
|
import argparse
import copy
import logging
import sys
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union

import k2
import optim
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from hubert_ce import HubertModel, add_hubert_arguments
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from optim import Eden, ScaledAdam
from ssl_datamodule import LibriSpeechDataModule
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.functional import pad
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
    save_checkpoint_with_global_batch_idx,
    update_averaged_model,
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
    AttributeDict,
    MetricsTracker,
    get_parameter_groups_with_lrs,
    setup_logger,
    str2bool,
    torch_autocast,
)

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]


def get_adjusted_batch_count(params: AttributeDict) -> float:
    # Return the number of batches we would have used so far if we had used
    # the reference duration; used by set_batch_count() to set various
    # schedules inside the model.
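    # For example, with the flags from the usage example above
    # (--world-size 8, --max-duration 87.5, --accum-grad 4) and the default
    # --ref-duration 80, each optimizer step advances this count by
    # 4 * (87.5 * 8) / 80 = 35.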
    return (
        params.batch_idx_train
        * params.accum_grad
        * (params.max_duration * params.world_size)
        / params.ref_duration
    )


def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
    if isinstance(model, DDP):
        # get underlying nn.Module
        model = model.module
    for name, module in model.named_modules():
        if hasattr(module, "batch_count"):
            module.batch_count = batch_count
        if hasattr(module, "name"):
            module.name = name


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--world-size",
        type=int,
        default=1,
        help="Number of GPUs for DDP training.",
    )

    parser.add_argument(
        "--master-port",
        type=int,
        default=12354,
        help="Master port to use for DDP training.",
    )

    parser.add_argument(
        "--tensorboard",
        type=str2bool,
        default=True,
        help="Whether to log various information to tensorboard.",
    )

    parser.add_argument(
        "--num-epochs",
        type=int,
        default=400,
        help="Number of epochs to train.",
    )

    parser.add_argument(
        "--start-epoch",
        type=int,
        default=1,
        help="""Resume training from this epoch. It should be positive.
        If larger than 1, it will load the checkpoint from
        exp-dir/epoch-{start_epoch-1}.pt.
        """,
    )

    parser.add_argument(
        "--start-batch",
        type=int,
        default=0,
        help="""If positive, --start-epoch is ignored and
        the checkpoint is loaded from exp-dir/checkpoint-{start_batch}.pt.
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="hubert/exp",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, logs, etc., are saved.
        """,
    )

    parser.add_argument(
        "--base-lr", type=float, default=0.045, help="The base learning rate."
    )

    parser.add_argument(
        "--lr-batches",
        type=float,
        default=7500,
        help="""Number of steps that affects how rapidly the learning rate
        decreases. We suggest not changing this.""",
    )

    parser.add_argument(
        "--lr-epochs",
        type=float,
        default=10.5,
        help="""Number of epochs that affects how rapidly the learning rate
        decreases.
        """,
    )

    parser.add_argument(
        "--warmup-batches",
        type=float,
        default=5000,
        help="Number of Eden warmup steps.",
    )

    parser.add_argument(
        "--warmup-start",
        type=float,
        default=0,
        help="Eden warmup start learning rate.",
    )

    parser.add_argument(
        "--ref-duration",
        type=float,
        default=80,
        help="Reference batch duration for purposes of adjusting batch counts "
        "for setting various schedules inside the model.",
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="The seed for random generators; used for reproducibility.",
    )

    parser.add_argument(
        "--print-diagnostics",
        type=str2bool,
        default=False,
        help="Accumulate stats on activations, print them and exit.",
    )

    parser.add_argument(
        "--sanity-check",
        type=str2bool,
        default=False,
        help="Check if any of the batches in epoch 1 would cause OOM.",
    )

    parser.add_argument(
        "--inf-check",
        type=str2bool,
        default=False,
        help="Add hooks to check for infinite module outputs and gradients.",
    )

    parser.add_argument(
        "--save-every-n",
        type=int,
        default=100000,
        help="""Save a checkpoint after processing this number of batches
        periodically. We save a checkpoint to exp-dir/ whenever
        params.batch_idx_train % save_every_n == 0. The checkpoint filename
        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'.
        Note: It also saves a checkpoint to `exp-dir/epoch-xxx.pt` at the
        end of each epoch where `xxx` is the epoch number counting from 1.
        """,
    )

    parser.add_argument(
        "--keep-last-k",
        type=int,
        default=30,
        help="""Only keep this number of checkpoints on disk.
        For instance, if it is 3, there are only 3 checkpoints
        in the exp-dir with filenames `checkpoint-xxx.pt`.
        It does not affect checkpoints with name `epoch-xxx.pt`.
        """,
    )

    parser.add_argument(
        "--average-period",
        type=int,
        default=200,
        help="""Update the averaged model, namely `model_avg`, after processing
        this number of batches. `model_avg` is a separate version of model,
        in which each floating-point parameter is the average of all the
        parameters from the start of training. Each time we take the average,
        we do: `model_avg = model * (average_period / batch_idx_train) +
        model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
        """,
    )

    parser.add_argument(
        "--accum-grad",
        type=int,
        default=4,
        help="""Number of batches over which gradients are accumulated before
        each optimizer step, i.e., parameters are updated when
        batch_idx_train % accum_grad == 0.
        """,
    )

    parser.add_argument(
        "--use-fp16",
        type=str2bool,
        default=False,
        help="Whether to use half precision training.",
    )

    parser.add_argument(
        "--max-keep-size",
        type=int,
        default=sys.maxsize,
        help="Exclude samples longer than this (in samples).",
    )

    parser.add_argument(
        "--min-keep-size",
        type=float,
        default=32000,
        help="Exclude samples shorter than this (in samples).",
    )

    parser.add_argument(
        "--max-sample-size",
        type=float,
        default=250000,
        help="Maximum sample size (in samples) to crop to for batching.",
    )

    add_hubert_arguments(parser)

    return parser


def get_params() -> AttributeDict:
    """Return a dict containing training parameters.

    All training related parameters that are not passed from the commandline
    are saved in the variable `params`.

    Commandline options are merged into `params` after they are parsed, so
    you can also access them via `params`.

    Explanation of options saved in `params`:

      - best_train_loss: Best training loss so far. It is used to select
            the model that has the lowest training loss. It is
            updated during the training.

      - best_valid_loss: Best validation loss so far. It is used to select
            the model that has the lowest validation loss. It is
            updated during the training.

      - best_train_epoch: It is the epoch that has the best training loss.

      - best_valid_epoch: It is the epoch that has the best validation loss.

      - batch_idx_train: Used for writing statistics to tensorboard. It
            contains the number of optimizer updates applied to the model so
            far across epochs.

      - sub_batch_idx_train: It contains the number of batches trained so far
            across epochs.

      - log_interval: Print training loss if batch_idx % log_interval is 0.

      - reset_interval: Reset statistics if batch_idx % reset_interval is 0.

      - valid_interval: Run validation if batch_idx % valid_interval is 0.
    """
    params = AttributeDict(
        {
            "best_train_loss": float("inf"),
            "best_valid_loss": float("inf"),
            "best_train_epoch": -1,
            "best_valid_epoch": -1,
            "batch_idx_train": 0,
            "sub_batch_idx_train": 0,
            "log_interval": 50,
            "reset_interval": 200,
            "valid_interval": 3000,
            "env_info": get_env_info(),
        }
    )

    return params


def _to_int_tuple(s: str):
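    # e.g., "2,4,3" -> (2, 4, 3)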
    return tuple(map(int, s.split(",")))


def get_model(params: AttributeDict) -> nn.Module:
    model = HubertModel(params)
    return model


def load_checkpoint_if_available(
    params: AttributeDict,
    model: nn.Module,
    model_avg: Optional[nn.Module] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[LRSchedulerType] = None,
) -> Optional[Dict[str, Any]]:
    """Load checkpoint from file.

    If params.start_batch is positive, it will load the checkpoint from
    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
    params.start_epoch is larger than 1, it will load the checkpoint from
    `params.exp_dir/epoch-{params.start_epoch-1}.pt`.

    Apart from loading state dict for `model` and `optimizer` it also updates
    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
    and `best_valid_loss` in `params`.

    Args:
      params:
        The return value of :func:`get_params`.
      model:
        The training model.
      model_avg:
        The stored model averaged from the start of training.
      optimizer:
        The optimizer that we are using.
      scheduler:
        The scheduler that we are using.
    Returns:
      Return a dict containing previously saved training info.
    """
    if params.start_batch > 0:
        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
    elif params.start_epoch > 1:
        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
    else:
        return None

    assert filename.is_file(), f"{filename} does not exist!"

    saved_params = load_checkpoint(
        filename,
        model=model,
        model_avg=model_avg,
        optimizer=optimizer,
        scheduler=scheduler,
    )

    keys = [
        "best_train_epoch",
        "best_valid_epoch",
        "batch_idx_train",
        "best_train_loss",
        "best_valid_loss",
    ]
    for k in keys:
        params[k] = saved_params[k]

    if params.start_batch > 0:
        if "cur_epoch" in saved_params:
            params["start_epoch"] = saved_params["cur_epoch"]

    return saved_params


def save_checkpoint(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    model_avg: Optional[nn.Module] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[LRSchedulerType] = None,
    sampler: Optional[CutSampler] = None,
    scaler: Optional[GradScaler] = None,
    rank: int = 0,
) -> None:
    """Save model, optimizer, scheduler and training stats to file.

    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The training model.
      model_avg:
        The stored model averaged from the start of training.
      optimizer:
        The optimizer used in the training.
      scheduler:
        The learning rate scheduler used in the training.
      sampler:
        The sampler for the training dataset.
      scaler:
        The scaler used for mixed precision training.
    """
    if rank != 0:
        return
    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
    save_checkpoint_impl(
        filename=filename,
        model=model,
        model_avg=model_avg,
        params=params,
        optimizer=optimizer,
        scheduler=scheduler,
        sampler=sampler,
        scaler=scaler,
        rank=rank,
    )

    if params.best_train_epoch == params.cur_epoch:
        best_train_filename = params.exp_dir / "best-train-loss.pt"
        copyfile(src=filename, dst=best_train_filename)

    if params.best_valid_epoch == params.cur_epoch:
        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
        copyfile(src=filename, dst=best_valid_filename)


def compute_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    batch: dict,
    is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
    """
    Compute the loss given the model and its inputs.

    Args:
      params:
        Parameters for training. See :func:`get_params`.
      model:
        The model for training. It is an instance of HubertModel in our case.
      batch:
        A batch of data. See `dataset.HubertDataset()`
        for the content in it.
      is_training:
        True for training. False for validation. When it is True, this
        function enables autograd during computation; when it is False, it
        disables autograd.
    """
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
    audio = batch["audio"].to(device)
    padding_mask = batch["padding_mask"].to(device)
    kmeans = batch["kmeans"].to(device)

    with torch.set_grad_enabled(is_training):
        loss, num_masked_tokens, logging_output = model(
            source=audio, target_list=[kmeans], padding_mask=padding_mask
        )

    assert loss.requires_grad == is_training

    info = MetricsTracker()
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # "frames" counts masked tokens here, so losses logged downstream are
        # normalized per masked token.
        info["frames"] = num_masked_tokens
    for item in logging_output:
        info[item] = logging_output[item]
    return loss, info


def compute_validation_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    valid_dl: torch.utils.data.DataLoader,
    world_size: int = 1,
) -> MetricsTracker:
    """Run the validation process."""
    model.eval()

    tot_loss = MetricsTracker()

    for batch_idx, batch in enumerate(valid_dl):
        loss, loss_info = compute_loss(
            params=params,
            model=model,
            batch=batch,
            is_training=False,
        )
        assert loss.requires_grad is False
        tot_loss = tot_loss + loss_info

    if world_size > 1:
        tot_loss.reduce(loss.device)

    loss_value = tot_loss["loss"] / tot_loss["frames"]
    if loss_value < params.best_valid_loss:
        params.best_valid_epoch = params.cur_epoch
        params.best_valid_loss = loss_value

    return tot_loss


def train_one_epoch(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    optimizer: torch.optim.Optimizer,
    scheduler: LRSchedulerType,
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    scaler: GradScaler,
    model_avg: Optional[nn.Module] = None,
    tb_writer: Optional[SummaryWriter] = None,
    world_size: int = 1,
    rank: int = 0,
) -> None:
    """Train the model for one epoch.

    The training loss, averaged over all frames, is saved in
    `params.train_loss`. It runs the validation process every
    `params.valid_interval` batches.

    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The model for training.
      optimizer:
        The optimizer we are using.
      scheduler:
        The learning rate scheduler; we call step() on every optimizer step.
      train_dl:
        Dataloader for the training dataset.
      valid_dl:
        Dataloader for the validation dataset.
      scaler:
        The scaler used for mixed precision training.
      model_avg:
        The stored model averaged from the start of training.
      tb_writer:
        Writer to write log messages to tensorboard.
      world_size:
        Number of processes in DDP training. If it is 1, DDP is disabled.
      rank:
        The rank of the process in DDP training. If no DDP is used, it should
        be set to 0.
    """
    model.train()

    tot_loss = MetricsTracker()

    saved_bad_model = False

    def save_bad_model(suffix: str = ""):
        save_checkpoint_impl(
            filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
            model=model,
            model_avg=model_avg,
            params=params,
            optimizer=optimizer,
            scheduler=scheduler,
            sampler=train_dl.sampler,
            scaler=scaler,
            rank=0,
        )

    for sub_batch_idx, batch in enumerate(train_dl):
        params.sub_batch_idx_train += 1
        batch_idx = sub_batch_idx // params.accum_grad

        if batch_idx % 10 == 0:
            set_batch_count(model, get_adjusted_batch_count(params))

        batch_size = batch["kmeans"].shape[0]

        try:
            with torch_autocast(enabled=params.use_fp16):
                loss, loss_info = compute_loss(
                    params=params,
                    model=model,
                    batch=batch,
                    is_training=True,
                )

            # tot_loss is an exponentially-decayed running sum of recent
            # loss stats.
            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

            # The loss is an un-normalized sum; it is divided by accum_grad
            # here so that accumulated gradients match a single large batch.
            scaler.scale(loss / params.accum_grad).backward()
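
            # With the flags from the usage example (--accum-grad 4,
            # --world-size 8, --max-duration 87.5), one optimizer step below
            # sees roughly 4 * 8 * 87.5 = 2800 seconds of audio in total.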
            if sub_batch_idx % params.accum_grad == params.accum_grad - 1:
                params.batch_idx_train += 1
                scheduler.step_batch(params.batch_idx_train)

                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
            else:
                continue

        except:  # noqa
            save_bad_model()
            display_and_save_batch(batch, params=params)
            raise

        if params.print_diagnostics and batch_idx == 5:
            return

        if (
            rank == 0
            and params.batch_idx_train > 0
            and params.batch_idx_train % params.average_period == 0
        ):
            update_averaged_model(
                params=params,
                model_cur=model,
                model_avg=model_avg,
            )

        if (
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
            save_checkpoint_with_global_batch_idx(
                out_dir=params.exp_dir,
                global_batch_idx=params.batch_idx_train,
                model=model,
                model_avg=model_avg,
                params=params,
                optimizer=optimizer,
                scheduler=scheduler,
                sampler=train_dl.sampler,
                scaler=scaler,
                rank=rank,
            )
            remove_checkpoints(
                out_dir=params.exp_dir,
                topk=params.keep_last_k,
                rank=rank,
            )

        if batch_idx % 100 == 0 and params.use_fp16:
            # Periodically try to grow the grad scale when it is small;
            # GradScaler's own growth interval cannot condition on the
            # current scale, hence the manual doubling here via the private
            # _scale attribute.
            cur_grad_scale = scaler._scale.item()

            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
                scaler.update(cur_grad_scale * 2.0)
            if cur_grad_scale < 0.01:
                if not saved_bad_model:
                    save_bad_model(suffix="-first-warning")
                    saved_bad_model = True
                logging.warning(f"Grad scale is small: {cur_grad_scale}")
                if cur_grad_scale < 1.0e-05:
                    save_bad_model()
                    raise RuntimeError(
                        f"grad_scale is too small, exiting: {cur_grad_scale}"
                    )

        if batch_idx % params.log_interval == 0:
            cur_lr = max(scheduler.get_last_lr())
            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0

            logging.info(
                f"Epoch {params.cur_epoch}, "
                f"batch {batch_idx}, loss[{loss_info}], "
                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                f"lr: {cur_lr:.2e}, "
                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
            )

            if tb_writer is not None:
                tb_writer.add_scalar(
                    "train/learning_rate", cur_lr, params.batch_idx_train
                )

                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
                )
                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                if params.use_fp16:
                    tb_writer.add_scalar(
                        "train/grad_scale",
                        cur_grad_scale,
                        params.batch_idx_train,
                    )

        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
            logging.info("Computing validation loss")
            valid_info = compute_validation_loss(
                params=params,
                model=model,
                valid_dl=valid_dl,
                world_size=world_size,
            )
            model.train()
            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
            logging.info(
                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
            )
            if tb_writer is not None:
                valid_info.write_summary(
                    tb_writer, "train/valid_", params.batch_idx_train
                )

    # If the epoch ended in the middle of an accumulation cycle, discard the
    # partially accumulated gradients.
    if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
        optimizer.zero_grad()
    loss_value = tot_loss["loss"] / tot_loss["frames"]
    params.train_loss = loss_value
    if params.train_loss < params.best_train_loss:
        params.best_train_epoch = params.cur_epoch
        params.best_train_loss = params.train_loss


def run(rank, world_size, args):
    """
    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        The process with rank 0 is responsible for saving checkpoints.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    """
    params = get_params()
    params.update(vars(args))

    fix_random_seed(params.seed)
    if world_size > 1:
        setup_dist(rank, world_size, params.master_port)

    setup_logger(f"{params.exp_dir}/log/log-train")
    logging.info("Training started")

    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
    else:
        tb_writer = None

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", rank)
    logging.info(f"Device: {device}")
    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    assert params.save_every_n >= params.average_period
    model_avg: Optional[nn.Module] = None
    if rank == 0:
        # model_avg is only used with rank 0
        model_avg = copy.deepcopy(model).to(torch.float64)

    assert params.start_epoch > 0, params.start_epoch
    checkpoints = load_checkpoint_if_available(
        params=params, model=model, model_avg=model_avg
    )

    model.to(device)
    if world_size > 1:
        logging.info("Using DDP")
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    optimizer = ScaledAdam(
        get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
        lr=params.base_lr,
        clipping_scale=2.0,
    )

    scheduler = Eden(
        optimizer,
        params.lr_batches,
        params.lr_epochs,
        params.warmup_batches,
        params.warmup_start,
    )
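    # Eden decays the base LR roughly as
    #   ((step^2 + lr_batches^2) / lr_batches^2) ** -0.25
    #   * ((epoch^2 + lr_epochs^2) / lr_epochs^2) ** -0.25,
    # with a linear warmup from warmup_start over the first warmup_batches
    # steps; see optim.py for the exact schedule.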

    if checkpoints and "optimizer" in checkpoints:
        logging.info("Loading optimizer state dict")
        optimizer.load_state_dict(checkpoints["optimizer"])

    if (
        checkpoints
        and "scheduler" in checkpoints
        and checkpoints["scheduler"] is not None
    ):
        logging.info("Loading scheduler state dict")
        scheduler.load_state_dict(checkpoints["scheduler"])

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(512)
        diagnostic = diagnostics.attach_diagnostics(model, opts)

    if params.inf_check:
        register_inf_check_hooks(model)

    librispeech = LibriSpeechDataModule(args)

    train_cuts = (
        librispeech.train_all_shuf_cuts()
        if params.full_libri
        else librispeech.train_clean_100_cuts()
    )

    def remove_short_and_long_utt(c: Cut):
        # Keep only utterances whose length in samples lies in
        # [min_keep_size, max_keep_size]; cut durations are stored in seconds,
        # hence the division by the sample rate.
        if (
            c.duration < params.min_keep_size / params.sample_rate
            or c.duration > params.max_keep_size / params.sample_rate
        ):
            logging.warning(
                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
            )
            return False

        return True
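
    # With the default --min-keep-size 32000 and 16 kHz LibriSpeech audio,
    # this drops cuts shorter than 32000 / 16000 = 2 seconds.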
    train_cuts = train_cuts.filter(remove_short_and_long_utt)

    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
        # We only load the sampler's state dict when resuming from a
        # checkpoint saved in the middle of an epoch.
        sampler_state_dict = checkpoints["sampler"]
    else:
        sampler_state_dict = None

    train_dl = librispeech.train_dataloaders(
        train_cuts,
        max_sample_size=params.max_sample_size,
        sample_rate=params.sample_rate,
        label_rate=params.label_rate,
        random_crop=params.random_crop,
        pad_audio=False,
        num_classes=params.num_classes,
        do_normalize=params.do_normalize,
        sampler_state_dict=sampler_state_dict,
    )

    valid_cuts = librispeech.dev_clean_cuts()
    valid_cuts = valid_cuts.filter(remove_short_and_long_utt)

    valid_dl = librispeech.valid_dataloaders(
        valid_cuts,
        max_sample_size=params.max_sample_size,
        sample_rate=params.sample_rate,
        label_rate=params.label_rate,
        random_crop=params.random_crop,
        pad_audio=False,
        num_classes=params.num_classes,
        do_normalize=params.do_normalize,
    )

    if params.sanity_check and not params.print_diagnostics:
        scan_pessimistic_batches_for_oom(
            model=model,
            train_dl=train_dl,
            optimizer=optimizer,
            params=params,
        )

    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
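    # init_scale=1.0 starts the scale low; train_one_epoch nudges it upward
    # whenever it stays below 8 (see the fp16 block there).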
| if checkpoints and "grad_scaler" in checkpoints: |
| logging.info("Loading grad scaler state dict") |
| scaler.load_state_dict(checkpoints["grad_scaler"]) |
|
|

    for epoch in range(params.start_epoch, params.num_epochs + 1):
        scheduler.step_epoch(epoch - 1)
        fix_random_seed(params.seed + epoch - 1)
        train_dl.sampler.set_epoch(epoch - 1)

        if tb_writer is not None:
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

        params.cur_epoch = epoch

        train_one_epoch(
            params=params,
            model=model,
            model_avg=model_avg,
            optimizer=optimizer,
            scheduler=scheduler,
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,
        )

        if params.print_diagnostics:
            diagnostic.print_diagnostics()
            break

        save_checkpoint(
            params=params,
            model=model,
            model_avg=model_avg,
            optimizer=optimizer,
            scheduler=scheduler,
            sampler=train_dl.sampler,
            scaler=scaler,
            rank=rank,
        )

    logging.info("Done!")

    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()


def display_and_save_batch(
    batch: dict,
    params: AttributeDict,
) -> None:
    """Display the batch statistics and save the batch to disk.

    Args:
      batch:
        A batch of data. See `dataset.HubertDataset()`
        for the content in it.
      params:
        Parameters for training. See :func:`get_params`.
    """
    from lhotse.utils import uuid4

    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
    logging.info(f"Saving batch to {filename}")
    torch.save(batch, filename)

    audio = batch["audio"]
    logging.info(f"audio shape: {audio.shape}")


def scan_pessimistic_batches_for_oom(
    model: Union[nn.Module, DDP],
    train_dl: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    params: AttributeDict,
):
    from lhotse.dataset import find_pessimistic_batches

    logging.info(
        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
    )
    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
    for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]
        try:
            with torch_autocast(enabled=params.use_fp16):
                loss, _ = compute_loss(
                    params=params,
                    model=model,
                    batch=batch,
                    is_training=True,
                )
            loss.backward()
            optimizer.zero_grad()
        except Exception as e:
            if "CUDA out of memory" in str(e):
                logging.error(
                    "Your GPU ran out of memory with the current "
                    "max_duration setting. We recommend decreasing "
                    "max_duration and trying again.\n"
                    f"Failing criterion: {criterion} "
                    f"(={crit_values[criterion]}) ..."
                )
            display_and_save_batch(batch, params=params)
            raise
    logging.info(
        f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
    )


def main():
    parser = get_parser()
    LibriSpeechDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    world_size = args.world_size
    assert world_size >= 1
    if world_size > 1:
        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1, args=args)


torch.set_num_threads(1)
torch.set_num_interop_threads(1)


if __name__ == "__main__":
    main()