Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| from torch import Tensor | |
| from torch.cuda.amp import GradScaler | |
| from torch.utils.data import Sampler | |
| from torch.nn.parallel import DistributedDataParallel as DDP | |
| from torch.optim import Optimizer | |
| from torch import distributed as dist | |
| from torch.utils.tensorboard import SummaryWriter | |
| import logging | |
| import os | |
| import re | |
| import glob | |
| import collections | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Union, Tuple | |
| from datetime import datetime | |
| from pathlib import Path | |
| import argparse | |
# LR schedulers are handled by duck typing, because several different
# scheduler implementations may be passed in (the original note points at a
# project `LRScheduler` class); the alias therefore accepts any object.
LRSchedulerType = object

# A filesystem path given either as a plain string or as a pathlib.Path.
Pathlike = Union[str, Path]
def average_state_dict(
    state_dict_1: Dict[str, Tensor],
    state_dict_2: Dict[str, Tensor],
    weight_1: float,
    weight_2: float,
    scaling_factor: float = 1.0,
) -> Dict[str, Tensor]:
    """Average two state dicts in place with the given weights.

    For every floating-point entry::

        state_dict_1[k] = (state_dict_1[k] * weight_1
                           + state_dict_2[k] * weight_2) * scaling_factor

    Non-floating-point entries (e.g. integer buffers) are left untouched.
    Entries of ``state_dict_1`` that share storage (same ``data_ptr()``,
    e.g. tied weights) are updated only once, via the first key seen.

    Args:
      state_dict_1:
        The state dict to update; its tensors are modified in place.
      state_dict_2:
        The state dict to blend in. It must contain every de-duplicated key
        of ``state_dict_1``; each tensor is moved to the device of the
        matching ``state_dict_1`` tensor before use.
      weight_1:
        Weight applied to ``state_dict_1``.
      weight_2:
        Weight applied to ``state_dict_2``.
      scaling_factor:
        Factor applied to the weighted sum.

    Returns:
      ``state_dict_1`` itself (the same object, updated in place), as the
      return annotation promises.
    """
    # Keep only the first key for each distinct storage so shared (tied)
    # parameters are not scaled more than once.
    uniqued: Dict[int, str] = {}
    for k, v in state_dict_1.items():
        uniqued.setdefault(v.data_ptr(), k)

    for k in uniqued.values():
        v = state_dict_1[k]
        if torch.is_floating_point(v):
            v *= weight_1
            v += state_dict_2[k].to(device=v.device) * weight_2
            v *= scaling_factor
    return state_dict_1
def load_checkpoint(
    filename: Path,
    model: nn.Module,
    model_avg: Optional[nn.Module] = None,
    optimizer: Optional[Optimizer] = None,
    scheduler: Optional["LRSchedulerType"] = None,
    scaler: Optional[GradScaler] = None,
    sampler: Optional[Sampler] = None,
    strict: bool = False,
) -> Dict[str, Any]:
    """Load a checkpoint file into the given model and training objects.

    The file is expected to hold a dict like the one written by
    ``save_checkpoint_impl``: a ``"model"`` state dict plus optional
    ``"model_avg"``, ``"optimizer"``, ``"scheduler"``, ``"grad_scaler"`` and
    ``"sampler"`` entries and any user-defined keys.

    Args:
      filename:
        Path of the checkpoint to load (tensors are mapped onto CPU).
      model:
        Model that receives ``checkpoint["model"]``. If the state dict was
        saved from a DDP-wrapped model (keys prefixed with ``"module."``),
        the prefix is stripped first.
      model_avg:
        If given and the checkpoint has ``"model_avg"``, it receives that
        state dict.
      optimizer, scheduler, scaler, sampler:
        Each one is restored only if it is not None and the checkpoint
        holds a non-empty state for it. The states are read from
        ``"optimizer"``, ``"scheduler"``, ``"grad_scaler"`` and
        ``"sampler"`` respectively, and restored via ``load_state_dict``.
      strict:
        Passed to ``load_state_dict`` for ``model`` and ``model_avg``.
        This also applies to DDP-saved checkpoints: with ``strict=False``,
        missing or unexpected keys are tolerated.

    Returns:
      The remaining checkpoint dict. Every entry that was loaded into an
      object is removed, so what is left are user-defined values (e.g. the
      ``params`` given when saving) and any states that were not loaded.
    """
    logging.info(f"Loading checkpoint from {filename}")
    # NOTE(review): torch.load unpickles arbitrary Python objects unless
    # weights_only=True is used; only load checkpoints from trusted sources.
    checkpoint = torch.load(filename, map_location="cpu")

    model_state = checkpoint.pop("model")
    prefix = "module."
    # The default "" keeps an empty state dict from raising StopIteration.
    if next(iter(model_state), "").startswith(prefix):
        logging.info("Loading checkpoint saved by DDP")
        # Strip the DDP wrapper prefix so keys match the bare model, and let
        # load_state_dict enforce `strict`.
        model_state = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in model_state.items()
        }
    model.load_state_dict(model_state, strict=strict)

    if model_avg is not None and "model_avg" in checkpoint:
        logging.info("Loading averaged model")
        model_avg.load_state_dict(checkpoint.pop("model_avg"), strict=strict)

    def load(name, obj):
        s = checkpoint.get(name, None)
        # Test `obj is not None` rather than truthiness: bool(obj) falls back
        # to __len__, which some samplers implement expensively or not at all.
        if obj is not None and s:
            obj.load_state_dict(s)
            checkpoint.pop(name)

    load("optimizer", optimizer)
    load("scheduler", scheduler)
    load("grad_scaler", scaler)
    load("sampler", sampler)

    return checkpoint
def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
    """Find all available checkpoints in a directory.

    The checkpoint filenames have the form ``checkpoint-xxx.pt``, where
    ``xxx`` is a non-negative integer. Files that match the glob but not
    that exact form (e.g. ``checkpoint-12x.pt``) are skipped with a warning.

    Assume you have the following checkpoints in the folder ``foo``:

        - checkpoint-1.pt
        - checkpoint-20.pt
        - checkpoint-300.pt
        - checkpoint-4000.pt

    Case 1 (return all checkpoints)::

        find_checkpoints(out_dir='foo')

    Case 2 (return checkpoints at or newer than checkpoint-20.pt, i.e.,
    checkpoint-4000.pt, checkpoint-300.pt, and checkpoint-20.pt)::

        find_checkpoints(out_dir='foo', iteration=20)

    Case 3 (return checkpoints at or older than checkpoint-20.pt, i.e.,
    checkpoint-20.pt and checkpoint-1.pt)::

        find_checkpoints(out_dir='foo', iteration=-20)

    Args:
      out_dir:
        The directory in which to search for checkpoints. It is taken
        literally, so glob metacharacters in it (e.g. ``[``) are safe.
      iteration:
        If it is 0, return all available checkpoints.
        If it is positive, return the checkpoints whose iteration number is
        greater than or equal to ``iteration``.
        If it is negative, return the checkpoints whose iteration number is
        less than or equal to ``-iteration``.

    Returns:
      A list of checkpoint filenames, sorted in descending order by the
      numerical value in the filename.
    """
    # Escape the directory so only the filename part is treated as a pattern.
    candidates = glob.glob(
        os.path.join(glob.escape(str(out_dir)), "checkpoint-[0-9]*.pt")
    )
    # Anchored match on the basename with a literal dot, so neither the
    # directory part nor names like "checkpoint-12x.pt" are accepted.
    pattern = re.compile(r"checkpoint-([0-9]+)\.pt")

    # Pairs of (iteration_number, checkpoint_path).
    iter_checkpoints = []
    for c in candidates:
        result = pattern.fullmatch(os.path.basename(c))
        if result is None:
            logging.warning(f"Invalid checkpoint filename {c}")
            continue
        iter_checkpoints.append((int(result.group(1)), c))

    iter_checkpoints.sort(key=lambda x: x[0], reverse=True)

    if iteration >= 0:
        return [c for n, c in iter_checkpoints if n >= iteration]
    return [c for n, c in iter_checkpoints if n <= -iteration]
def remove_checkpoints(
    out_dir: Path,
    topk: int,
):
    """Remove all but the newest ``topk`` checkpoints from a directory.

    Checkpoint filenames are assumed to have the form ``checkpoint-xxx.pt``,
    where ``xxx`` is a number representing the number of processed batches
    when that checkpoint was saved. Checkpoints are ordered by that number
    (via ``find_checkpoints``, which returns them in descending order).
    Only the ``topk`` with the largest numbers are kept; the rest are
    deleted.

    Args:
      out_dir:
        The directory containing checkpoints to be removed.
      topk:
        Number of checkpoints to keep. Must be at least 1.
    """
    assert topk >= 1, topk

    checkpoints = find_checkpoints(out_dir)
    if not checkpoints:
        logging.warning(f"No checkpoints found in {out_dir}")
        return

    # Newest first, so everything past index topk is older; the slice is
    # empty when there are at most topk checkpoints.
    for c in checkpoints[topk:]:
        os.remove(c)
def save_checkpoint_impl(
    filename: Path,
    model: Union[nn.Module, DDP],
    model_avg: Optional[nn.Module] = None,
    params: Optional[Dict[str, Any]] = None,
    optimizer: Optional[Optimizer] = None,
    scheduler: Optional["LRSchedulerType"] = None,
    scaler: Optional[GradScaler] = None,
    sampler: Optional[Sampler] = None,
    rank: int = 0,
) -> None:
    """Save training information to a file.

    Args:
      filename:
        The checkpoint filename.
      model:
        The model to be saved. Only its ``state_dict()`` is saved; a DDP
        wrapper is unwrapped first so keys carry no ``"module."`` prefix.
      model_avg:
        The stored model averaged from the start of training. Its state is
        saved under ``"model_avg"`` with floating-point entries cast to
        float32; the module itself is not modified.
      params:
        User-defined values (e.g. epoch, loss) stored as extra top-level
        keys. They must not collide with the reserved keys ``"model"``,
        ``"optimizer"``, ``"scheduler"``, ``"grad_scaler"``, ``"sampler"``
        or ``"model_avg"``.
      optimizer:
        The optimizer to be saved. Only its ``state_dict()`` is saved.
      scheduler:
        The scheduler to be saved. Only its ``state_dict()`` is saved.
      scaler:
        The GradScaler to be saved. Only its ``state_dict()`` is saved.
      sampler:
        The sampler to be saved. Only its ``state_dict()`` is saved.
      rank:
        Used in DDP. The checkpoint is saved only by the process whose rank
        is 0; other ranks return immediately.

    Returns:
      Return None.
    """
    if rank != 0:
        return

    logging.info(f"Saving checkpoint to {filename}")

    if isinstance(model, DDP):
        model = model.module

    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "scheduler": scheduler.state_dict() if scheduler is not None else None,
        "grad_scaler": scaler.state_dict() if scaler is not None else None,
        "sampler": sampler.state_dict() if sampler is not None else None,
    }

    if model_avg is not None:
        # Cast copies of the entries rather than calling
        # model_avg.to(torch.float32), which converts the caller's module in
        # place and would lose precision for later averaging if it is kept in
        # a wider dtype such as float64.
        checkpoint["model_avg"] = {
            k: v.to(torch.float32) if torch.is_floating_point(v) else v
            for k, v in model_avg.state_dict().items()
        }

    for k, v in (params or {}).items():
        assert k not in checkpoint, f"params key {k!r} collides with a reserved key"
        checkpoint[k] = v

    torch.save(checkpoint, filename)
def save_checkpoint_with_global_batch_idx(
    out_dir: Path,
    global_batch_idx: int,
    model: nn.Module,
    model_avg: Optional[nn.Module] = None,
    params: Optional[Dict[str, Any]] = None,
    optimizer: Optional[Optimizer] = None,
    scheduler: Optional[LRSchedulerType] = None,
    scaler: Optional[GradScaler] = None,
    sampler: Optional[Sampler] = None,
    rank: int = 0,
):
    """Save training info after processing a given number of batches.

    The checkpoint is always written to the fixed path
    ``out_dir / "checkpoint-global-batch.pt"``, so each call overwrites the
    previous one. That name has no numeric suffix, so it is not matched by
    the ``checkpoint-<number>.pt`` pattern used by ``find_checkpoints`` and
    hence is never pruned by ``remove_checkpoints``.

    ``out_dir`` is created and the progress line is printed on every rank;
    the file itself is written only when ``rank == 0``.

    Args:
      out_dir:
        The directory to save the checkpoint in (created if missing).
      global_batch_idx:
        The number of batches processed so far from the very start of
        training. It is only printed; it is not part of the filename and is
        not saved unless the caller also includes it in ``params``.
      model:
        The neural network model whose ``state_dict`` will be saved in the
        checkpoint.
      model_avg:
        The stored model averaged from the start of training.
      params:
        A dict of training configurations to be saved.
      optimizer:
        The optimizer used in the training. Its ``state_dict`` will be saved.
      scheduler:
        The learning rate scheduler used in the training. Its ``state_dict``
        will be saved.
      scaler:
        The scaler used for mixed precision training. Its ``state_dict``
        will be saved.
      sampler:
        The sampler used in the training dataset. Its ``state_dict`` will be
        saved.
      rank:
        The rank ID used in DDP training of the current node. Set it to 0
        if DDP is not used.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    print(f"Saving model checkpoint with global batch IDX is {global_batch_idx}")
    # NOTE(review): the original docstring described the per-step name
    # f"checkpoint-{global_batch_idx}.pt"; confirm that overwriting a single
    # fixed file is intended.
    filename = out_dir / "checkpoint-global-batch.pt"
    save_checkpoint_impl(
        filename=filename,
        model=model,
        model_avg=model_avg,
        params=params,
        optimizer=optimizer,
        scheduler=scheduler,
        scaler=scaler,
        sampler=sampler,
        rank=rank,
    )
def update_averaged_model(
    params: "AttributeDict",
    model_cur: nn.Module,
    model_avg: nn.Module,
) -> None:
    """Update the averaged model in place::

        model_avg = model_cur * (average_period / batch_idx_train)
            + model_avg * ((batch_idx_train - average_period) / batch_idx_train)

    Args:
      params:
        Training configuration read via attribute access (e.g. an
        ``AttributeDict``). It must provide ``average_period`` and a
        non-zero ``batch_idx_train``.
      model_cur:
        The current model.
      model_avg:
        The averaged model to be updated. Its parameters are modified in
        place through ``average_state_dict`` on its ``state_dict()``.
    """
    weight_cur = params.average_period / params.batch_idx_train
    weight_avg = 1 - weight_cur

    cur = model_cur.state_dict()
    avg = model_avg.state_dict()

    average_state_dict(
        state_dict_1=avg,
        state_dict_2=cur,
        weight_1=weight_avg,
        weight_2=weight_cur,
    )
def cleanup_dist():
    """Destroy the default torch.distributed process group.

    This is the counterpart of ``setup_dist``, which calls
    ``dist.init_process_group``.
    """
    dist.destroy_process_group()
def setup_dist(
    rank,
    world_size,
    master_port=None,
    use_ddp_launch=False,
    master_addr=None,
    backend="nccl",
):
    """Initialize the default torch.distributed process group.

    ``MASTER_ADDR`` and ``MASTER_PORT`` are set only when they are not
    already present in the environment, so values exported by a launcher
    take precedence. Their fallbacks are ``localhost`` and ``12354``.

    Args:
      rank:
        Rank of this process; also used as the CUDA device index. Used only
        when ``use_ddp_launch`` is false.
      world_size:
        Total number of processes. Used only when ``use_ddp_launch`` is
        false.
      master_port:
        Port used if ``MASTER_PORT`` is not already set.
      use_ddp_launch:
        True when processes are started by an external launcher.
        ``init_process_group`` is then called without rank/world size and
        takes them from the environment (env:// initialization).
      master_addr:
        Address used if ``MASTER_ADDR`` is not already set.
      backend:
        torch.distributed backend name. Defaults to ``"nccl"`` (the
        previous hard-coded value); e.g. ``"gloo"`` works on CPU-only hosts.
    """
    if "MASTER_ADDR" not in os.environ:
        os.environ["MASTER_ADDR"] = (
            "localhost" if master_addr is None else str(master_addr)
        )

    if "MASTER_PORT" not in os.environ:
        os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port)

    if not use_ddp_launch:
        dist.init_process_group(backend, rank=rank, world_size=world_size)
        # Pin this process to its GPU. This is skipped on hosts without CUDA,
        # which can only happen with a non-NCCL backend.
        if torch.cuda.is_available():
            torch.cuda.set_device(rank)
    else:
        dist.init_process_group(backend)
def register_inf_check_hooks(model: nn.Module) -> None:
    """Register hooks that report non-finite values flowing through a model.

    For every submodule (the root module is labelled ``<top-level>``):

      * a forward hook raises ``ValueError`` if the float32 sum of the
        module's output is not finite. For tuple outputs this is checked per
        tensor element, and a nested tuple contributes its first element.
      * a backward hook logs a warning if the float32 sum of any output
        gradient is not finite.

    For every parameter that requires grad, a tensor hook logs a warning if
    the float32 sum of its gradient is not finite. Parameters with
    ``requires_grad=False`` (e.g. frozen layers) are skipped, because
    ``Tensor.register_hook`` raises ``RuntimeError`` on them.

    Args:
      model:
        The model to be analyzed. Hooks are attached in place and are not
        removed by this function.
    """
    for name, module in model.named_modules():
        if name == "":
            name = "<top-level>"

        # The default argument `_name=name` captures the current value of
        # `name`; a plain closure would only see the last loop value.
        def forward_hook(_module, _input, _output, _name=name):
            if isinstance(_output, Tensor):
                if not torch.isfinite(_output.to(torch.float32).sum()):
                    raise ValueError(
                        f"The sum of {_name}.output is not finite: {_output}"
                    )
            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
                    if isinstance(o, tuple):
                        o = o[0]
                    if not isinstance(o, Tensor):
                        continue
                    if not torch.isfinite(o.to(torch.float32).sum()):
                        raise ValueError(
                            f"The sum of {_name}.output[{i}] is not finite: {_output}"
                        )

        # Same `_name=name` capture as above; `_output` is the output gradient.
        def backward_hook(_module, _input, _output, _name=name):
            if isinstance(_output, Tensor):
                if not torch.isfinite(_output.to(torch.float32).sum()):
                    logging.warning(f"The sum of {_name}.grad is not finite")
            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
                    if isinstance(o, tuple):
                        o = o[0]
                    if not isinstance(o, Tensor):
                        continue
                    if not torch.isfinite(o.to(torch.float32).sum()):
                        logging.warning(f"The sum of {_name}.grad[{i}] is not finite")

        module.register_forward_hook(forward_hook)
        # NOTE(review): register_backward_hook is deprecated in favour of
        # register_full_backward_hook. It is kept because the full variant
        # reports different gradients and, per the PyTorch docs, errors if
        # inputs/outputs are modified in place.
        module.register_backward_hook(backward_hook)

    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue

        def param_backward_hook(grad, _name=name):
            if not torch.isfinite(grad.to(torch.float32).sum()):
                logging.warning(f"The sum of {_name}.param_grad is not finite")

        parameter.register_hook(param_backward_hook)
class AttributeDict(dict):
    """A ``dict`` whose items can also be read, set and deleted as attributes.

    ``d.key`` is equivalent to ``d["key"]``. Accessing or deleting a missing
    key as an attribute raises ``AttributeError`` (not ``KeyError``), so
    ``getattr``/``hasattr`` work as usual.
    """

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            pass
        raise AttributeError(f"No such attribute '{key}'")

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        if key not in self:
            raise AttributeError(f"No such attribute '{key}'")
        del self[key]
class MetricsTracker(collections.defaultdict):
    """Accumulates named metrics (losses, counts, ...) during training.

    Missing keys default to ``0``. The special keys ``"frames"`` and
    ``"utterances"`` act as normalisers in ``norm_items``: keys containing
    ``"utt_"`` are divided by ``"utterances"``, every other key by
    ``"frames"``.
    """

    def __init__(self):
        # Passing `int` as the default factory makes undefined items
        # default to int(), which is zero.
        super().__init__(int)

    def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
        """Return a new tracker whose values are summed key-wise."""
        ans = MetricsTracker()
        for k, v in self.items():
            ans[k] = v
        for k, v in other.items():
            ans[k] = ans[k] + v
        return ans

    def __mul__(self, alpha: float) -> "MetricsTracker":
        """Return a new tracker with every value multiplied by ``alpha``."""
        ans = MetricsTracker()
        for k, v in self.items():
            ans[k] = v * alpha
        return ans

    def __str__(self) -> str:
        """Format normalised metrics, e.g. ``"loss=0.5, over 10.00 frames. "``.

        Raises:
          ValueError: for a ``utt_`` key other than ``utt_duration`` or
            ``utt_pad_proportion``.
        """
        ans_frames = ""
        ans_utterances = ""
        for k, v in self.norm_items():
            norm_value = "%.4g" % v
            if "utt_" not in k:
                ans_frames += str(k) + "=" + str(norm_value) + ", "
            else:
                ans_utterances += str(k) + "=" + str(norm_value)
                if k == "utt_duration":
                    ans_utterances += " frames, "
                elif k == "utt_pad_proportion":
                    ans_utterances += ", "
                else:
                    raise ValueError(f"Unexpected key: {k}")

        # Use .get() instead of self[...]: indexing a missing key on a
        # defaultdict inserts 0, and a stored 0 normaliser would make the
        # next norm_items() call divide by zero.
        frames = "%.2f" % self.get("frames", 0)
        ans_frames += "over " + str(frames) + " frames. "
        if ans_utterances != "":
            utterances = "%.2f" % self.get("utterances", 0)
            ans_utterances += "over " + str(utterances) + " utterances."

        return ans_frames + ans_utterances

    def norm_items(self) -> List[Tuple[str, float]]:
        """
        Returns a list of (name, normalised value) pairs, like:
          [('ctc_loss', 0.1), ('att_loss', 0.07)]
        The ``"frames"`` and ``"utterances"`` keys themselves are omitted;
        a normaliser that is absent counts as 1.
        """
        num_frames = self["frames"] if "frames" in self else 1
        num_utterances = self["utterances"] if "utterances" in self else 1
        ans = []
        for k, v in self.items():
            if k == "frames" or k == "utterances":
                continue
            norm_value = (
                float(v) / num_frames if "utt_" not in k else float(v) / num_utterances
            )
            ans.append((k, norm_value))
        return ans

    def write_summary(
        self,
        tb_writer: "SummaryWriter",
        prefix: str,
        batch_idx: int,
    ) -> None:
        """Add logging information to a TensorBoard writer.

        Args:
            tb_writer: a TensorBoard writer (anything with ``add_scalar``)
            prefix: a prefix for the name of the loss, e.g. "train/valid_",
                or "train/current_"
            batch_idx: The current batch index, used as the x-axis of the plot.
        """
        for k, v in self.norm_items():
            tb_writer.add_scalar(prefix + k, v, batch_idx)
def setup_logger(
    log_filename: "Pathlike",
    log_level: str = "info",
    use_console: bool = True,
) -> None:
    """Configure the root logger to write to a timestamped log file.

    The file actually written is ``{log_filename}-{YYYY-mm-dd-HH-MM-SS}``.
    When torch.distributed is initialized, ``-{rank}`` is appended and each
    record also shows ``(rank/world_size)``. Missing parent directories are
    created.

    Args:
      log_filename:
        Path prefix of the log file. A bare filename (no directory part) is
        created in the current working directory.
      log_level:
        The log level to use, e.g., "debug", "info", "warning", "error",
        "critical". Any other value falls back to "error".
      use_console:
        True to also print logs to the console (a ``StreamHandler``, which
        writes to stderr).

    NOTE(review): ``logging.basicConfig`` does nothing if the root logger
    already has handlers, so calling this after logging was configured
    elsewhere will not redirect output to the file.
    """
    date_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s"  # noqa
        log_filename = f"{log_filename}-{date_time}-{rank}"
    else:
        formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
        log_filename = f"{log_filename}-{date_time}"

    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually has one.
    log_dir = os.path.dirname(log_filename)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    level = {
        "debug": logging.DEBUG,
        "info": logging.INFO,
        "warning": logging.WARNING,
        "critical": logging.CRITICAL,
    }.get(log_level, logging.ERROR)

    logging.basicConfig(
        filename=log_filename,
        format=formatter,
        level=level,
        filemode="w",
    )
    if use_console:
        console = logging.StreamHandler()
        console.setLevel(level)
        console.setFormatter(logging.Formatter(formatter))
        logging.getLogger("").addHandler(console)
def str2bool(v):
    """Convert a command-line string to a bool (for ``argparse`` ``type=``).

    Accepted (case-insensitive) spellings:
      - yes, true, t, y, 1  -> True
      - no, false, f, n, 0  -> False
    A value that is already a ``bool`` is returned unchanged.
    See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa

    Raises:
      argparse.ArgumentTypeError: for any other string.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in {"yes", "true", "t", "y", "1"}:
        return True
    if lowered in {"no", "false", "f", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")