Lexa
Converted .pt files to safetensors, then (dirtily) patched fairseq to enable loading of safetensor files
b5a0bec
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # | |
| from collections.abc import MutableMapping | |
| from dataclasses import dataclass, field | |
| from functools import partial | |
| from pathlib import Path | |
| from typing import ( | |
| Any, | |
| Callable, | |
| Dict, | |
| List, | |
| Mapping, | |
| Optional, | |
| Sequence, | |
| Set, | |
| Tuple, | |
| Union, | |
| ) | |
| import torch | |
| from fairseq2.gang import Gang | |
| from fairseq2.logging import get_log_writer | |
| from fairseq2.metrics import ( | |
| MetricBag, | |
| format_as_float, | |
| format_as_int, | |
| format_as_seconds, | |
| ) | |
| from fairseq2.metrics.recorder import ( | |
| MetricRecorder, | |
| _metric_formatters, | |
| register_metric_formatter, | |
| ) | |
| from fairseq2.typing import override | |
| from torch import Tensor | |
| from torch.cuda import _get_device_index | |
| from torcheval.metrics import Max, Mean, Sum, Throughput | |
# Module-level log writer obtained from fairseq2's `get_log_writer`, named after this module.
logger = get_log_writer(__name__)
# `format_as_int` with `postfix="%"` pre-bound; used for the RAM/CPU usage
# entries of `lcm_metric_formatters`.
format_as_percent = partial(format_as_int, postfix="%")
def flatten_dict(d: MutableMapping, parent_key: str = "", sep: str = ".") -> Dict:
    """
    Flatten a nested dictionary into a single-level dictionary.

    Nested keys are joined with ``sep``. For example, a training config like

        config = {
            'data': {
                'training': {'batch_size': 10},
                'validation': {'batch_size': 2}
            },
            'model': {'model_dim': 1024},
            'use_fsdp': True
        }

    is flattened into

        {
            'data.training.batch_size': 10,
            'data.validation.batch_size': 2,
            'model.model_dim': 1024,
            'use_fsdp': True
        }

    This helper is used to convert our nested training config into a flat
    dictionary for TensorBoard's HParams consumption.

    :param d: The (possibly nested) mapping to flatten. Only values that are
        ``MutableMapping`` instances are recursed into; everything else is
        kept as a leaf value.
    :param parent_key: Prefix prepended (with ``sep``) to every key of ``d``.
        When empty, top-level keys are used unchanged.
    :param sep: Separator placed between a parent key and a child key.
    :returns: A new flat dictionary; ``d`` is not modified.
    """
    flat: Dict = {}
    for key, value in d.items():
        # Join with an f-string rather than `+` so that non-string keys
        # (e.g. integer indices) nested under a prefix do not raise TypeError.
        new_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, MutableMapping):
            flat.update(flatten_dict(value, new_key, sep=sep))
        else:
            flat[new_key] = value
    return flat
def get_allocated_gpu_memory(device: torch.device) -> Tuple[float, float]:
    """
    Get the current and peak allocated memory, in GiB, of a CUDA device.

    :param device: A ``torch.device``. Any non-CUDA device (``cpu``, ``meta``,
        ``mps``, ...) is reported as using no GPU memory.
    :returns: A ``(current_usage, peak_usage)`` tuple of floats in GiB.
    """
    # Only CUDA devices have caching-allocator statistics; previously any
    # non-cpu, non-cuda device crashed inside `_get_device_index`.
    if device.type != "cuda":
        return 0.0, 0.0

    device_index = _get_device_index(device, optional=True)
    memory_stats = torch.cuda.memory_stats(device=device_index)

    # `memory_stats` can return an empty dict (e.g. before CUDA has been
    # initialized), so missing counters default to 0 instead of a KeyError.
    bytes_per_gib = 1024**3
    current_usage = memory_stats.get("allocated_bytes.all.current", 0) / bytes_per_gib
    peak_usage = memory_stats.get("allocated_bytes.all.peak", 0) / bytes_per_gib
    return current_usage, peak_usage
@dataclass
class LossTerm:
    """Dataclass for a batch loss term."""

    value: Tensor
    """The final loss to be optimized"""

    # Batch size of this term; `LCMMetricBag.update` sums it over all terms
    # of a step to get the effective batch size.
    batch_size: int

    # Number of target elements of this term; `LCMMetricBag.update` uses it
    # to normalize/weight the reported loss and, by default, the summands.
    num_target_elements: Union[int, float]

    summands: Dict[str, Tuple[Any, Any]] = field(default_factory=dict)
    """A dictionary of loss terms to record. Each term is a tuple of (loss, number of elements).
    The second term is optional; if None (or -1), we will use `num_target_elements` when aggregating"""
class LCMMetricBag(MetricBag):
    """Holds the common metrics of an LCM."""

    loss: Mean
    batch_size: Sum
    elements_per_batch: Mean
    elements_per_second: Throughput
    num_target_elements: Sum
    total_num_target_elements: Sum
    grad_norm: Mean
    gpu_memory_usage: Max
    gpu_peak_memory_usage: Max

    def __init__(
        self, gang: Gang, loss_summands: Sequence[str] = (), reduction: str = "sum"
    ) -> None:
        """
        :param gang:
            The gang to sync metrics across all processes.
        :param loss_summands:
            Names of additional loss terms to track. Each one gets its own
            non-persistent :class:`Mean` metric, fed from ``LossTerm.summands``.
        :param reduction:
            With ``"sum"``, :meth:`update` normalizes the reported loss and
            summands by their element counts; any other value reports them
            as accumulated.
        """
        super().__init__(gang)

        # temporary fix:
        self.reduction = reduction

        d = gang.device

        # A temporary solution to track as many loss terms as we explore
        self.loss_summands = loss_summands

        self.register_metric("loss", Mean(device=d), persistent=False)

        # this is the effective batch size
        self.register_metric("batch_size", Sum(device=d), persistent=False)

        self.register_metric("elements_per_batch", Mean(device=d), persistent=False)

        self.register_metric(
            "elements_per_second", Throughput(device=d), persistent=False
        )

        self.register_metric("gpu_memory_usage", Max(device=d), persistent=False)
        self.register_metric("gpu_peak_memory_usage", Max(device=d), persistent=False)

        for summand in self.loss_summands:
            self.register_metric(summand, Mean(device=d), persistent=False)

        # The number of target tokens in a parallel batch. Used for computing throughput
        self.register_metric("num_target_elements", Sum(device=d), persistent=False)

        # The total_num_target_elements is persistent and is supposed to track the
        # total number of tokens consumed since training started
        self.total_num_target_elements = Sum(device=d)

        # NOTE(review): `grad_norm` is declared above and reset in
        # `reset_batch_metrics`, but it is not registered here; presumably the
        # caller registers it — TODO confirm.

    def _register_mean_std_metrics(self, prefix: str) -> None:
        """Register non-persistent ``{prefix}_mean``/``{prefix}_std`` metrics and their formatters."""
        for stat in ("mean", "std"):
            name = f"{prefix}_{stat}"
            self.register_metric(name, Mean(device=self._gang.device), persistent=False)
            # The formatter registry is module-global and shared by every bag,
            # so registering the same module for a second bag (e.g. train and
            # valid) must overwrite rather than collide — same as the
            # module-level registration of `lcm_metric_formatters`.
            register_metric_formatter(name, name, 1000, format_as_float, overwrite=True)

    def register_adaln_metric(self, module_name: str) -> None:
        """Register mean/std metrics for the ``shift``/``scale``/``gate`` tensors
        of the ``mha`` and ``ffn`` blocks of ``module_name`` (presumably AdaLN
        modulation statistics)."""
        for block in ("mha", "ffn"):
            for tensor in ("shift", "scale", "gate"):
                self._register_mean_std_metrics(f"{module_name}_{block}_{tensor}")

    def register_module_metric(self, module_name: str) -> None:
        """Register mean/std metrics for the input/output gradients and
        activations of ``module_name``."""
        for tensor in (
            "input_gradient",
            "output_gradient",
            "input_activations",
            "output_activations",
        ):
            self._register_mean_std_metrics(f"{module_name}_{tensor}")

    def update(
        self,
        losses: Sequence[LossTerm],
    ) -> None:
        """Update the metrics with the loss terms of one optimization step.

        :param losses:
            The loss terms generated by the model, one per batch. There is more
            than one only when gradient accumulation is used.
        """
        loss = torch.zeros((), dtype=torch.float64)

        loss_summands = {
            s: torch.zeros((), dtype=torch.float64) for s in self.loss_summands
        }

        # Denominator to normalize each loss summand. `None` means no batch
        # reported an element count for it, so we fall back to `num_target_elements`.
        loss_summands_numel: Dict[str, Optional[Tensor]] = {
            s: None for s in self.loss_summands
        }

        batch_size = torch.zeros((), dtype=torch.int64)

        # float64 rather than int64: `LossTerm.num_target_elements` (and summand
        # counts) may be floats, and in-place adding a float into an int64 tensor raises.
        num_target_elements = torch.zeros((), dtype=torch.float64)

        for batch_loss in losses:
            loss += float(batch_loss.value)

            for s in self.loss_summands:
                loss_term = batch_loss.summands.get(s, (0.0, None))
                loss_summands[s] += float(loss_term[0])
                # A count of None or -1 means "use num_target_elements".
                term_numel = loss_term[1]
                if term_numel is not None and not term_numel == -1:
                    numel = loss_summands_numel[s]
                    if numel is None:
                        numel = torch.zeros((), dtype=torch.float64)
                    loss_summands_numel[s] = numel + float(term_numel)

            batch_size += batch_loss.batch_size
            num_target_elements += float(batch_loss.num_target_elements)

        # Per-summand weight: its own element count when reported, otherwise
        # the number of target elements.
        summand_weights = {
            s: num_target_elements if numel is None else numel
            for s, numel in loss_summands_numel.items()
        }

        # Misleading normalization in the metric bag with reduction == "mean"
        # Kept here for backward compatibility
        # Any normalization here is only for reporting and doesn't impact optimization
        if self.reduction == "sum":
            # Skip the division for a step without target elements: 0/0 gives
            # NaN, and NaN * weight 0 would still poison the running Mean.
            if num_target_elements > 0:
                loss /= num_target_elements
            for s, weight in summand_weights.items():
                loss_summands[s] /= weight + 1e-6

        self.loss.update(loss, weight=num_target_elements)

        for s, weight in summand_weights.items():
            getattr(self, s).update(loss_summands[s], weight=weight)

        self.batch_size.update(batch_size)

        self.elements_per_batch.update(num_target_elements)

        self.num_target_elements.update(num_target_elements)

        # update the cumulative metric
        self.total_num_target_elements.update(num_target_elements)

        # Get GPU memory usage
        gpu_memory_usage, gpu_peak_memory_usage = get_allocated_gpu_memory(
            self._gang.device
        )
        self.gpu_memory_usage.update(torch.tensor(gpu_memory_usage))
        self.gpu_peak_memory_usage.update(torch.tensor(gpu_peak_memory_usage))

    def reset_batch_metrics(self) -> None:
        """Reset the batch metrics to their initial state."""
        self.loss.reset()

        for s in self.loss_summands:
            getattr(self, s).reset()

        self.batch_size.reset()

        self.elements_per_batch.reset()

        self.elements_per_second.reset()

        # NOTE(review): raises AttributeError unless `grad_norm` was registered
        # on this bag by the caller — see the note in `__init__`.
        self.grad_norm.reset()

        self.gpu_memory_usage.reset()
        self.gpu_peak_memory_usage.reset()
# Weights & Biases recorder support (optional dependency).
# `has_wandb` tells the recorder below whether the `wandb` package imported.
try:
    import wandb  # type: ignore[import-not-found]

    has_wandb = True
except ImportError:
    has_wandb = False
class LCMWandBRecorder(MetricRecorder):
    """Records metric values to Weights & Biases."""

    # Kept as a class attribute for backward compatibility; every instance
    # shadows it with its own set in ``__init__`` so runs defined for one
    # recorder (and its W&B run) are not mistaken as defined for another.
    defined_runs: Set[str] = set()

    def __init__(
        self,
        project: Optional[str] = None,
        name: Optional[str] = None,
        output_dir: Optional[Path] = None,
        config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        """
        :param project: A project to organise this run with other experiments, if none, the run will go under `uncategorized`.
        :param name: A unique name for your run, if none is given, a random name will be generated
        :param output_dir: The base directory under which to store the W&B files. You don't have to provide this.
        :param config: A dictionary of key-value pairs to be stored as the experiment's config. (akin to hparams in tb).
            Defaults to an empty dictionary.
        :param kwargs: Additional arguments to pass to wandb.init()

        In order to use W&B, run `wandb login` from the command line and enter
        the API key when prompted.
        """
        self.defined_runs = set()

        if not has_wandb:
            log = get_log_writer(__name__)
            log.warning("wandb not found. Please install it with `pip install wandb`.")  # fmt: skip
            self._run = None
            return

        if output_dir:
            output_dir.mkdir(parents=True, exist_ok=True)

        self._run = wandb.init(  # type: ignore
            project=project,
            name=name,
            dir=output_dir,
            resume="allow",
            # A fresh dict per call instead of a shared mutable default.
            config=config if config is not None else {},
            **kwargs,
        )

    def _define_run(self, run: str) -> None:
        """Make ``{run}/step`` the x-axis of every ``{run}/*`` metric, once per run."""
        if self._run is None or run in self.defined_runs:
            return

        # https://docs.wandb.ai/guides/track/log/customize-logging-axes/
        self._run.define_metric(f"{run}/step")
        self._run.define_metric(f"{run}/*", step_metric=f"{run}/step")

        # Remember the run so the metrics are not re-defined on every call.
        self.defined_runs.add(run)

    def record_metrics(
        self,
        run: str,
        values: Mapping[str, Any],
        step_nr: Optional[int] = None,
        *,
        flush: bool = True,
    ) -> None:
        """Log ``values`` under ``{run}/<display name>`` together with ``{run}/step``.

        ``flush`` is accepted for interface compatibility and is not used.
        """
        if self._run is None:
            return

        self._define_run(run)

        if not values:
            return

        payload: Dict[str, Any] = {}
        for name, value in values.items():
            formatter = _metric_formatters.get(name)
            display_name = name if formatter is None else formatter.display_name
            payload[f"{run}/{display_name}"] = value
        payload[f"{run}/step"] = step_nr

        # A single `log` call keeps all metrics of one step in the same W&B
        # history row; each `log` call commits a new row by default.
        self._run.log(payload)

    def close(self) -> None:
        """Finish the W&B run, if any; later calls become no-ops."""
        if self._run is not None:
            self._run.finish()
            # Drop the handle so later record_metrics/close calls do not touch
            # an already finished run.
            self._run = None
# Display name, priority and formatter of each LCM metric. Every entry is
# registered with fairseq2's formatter registry right after this table,
# replacing any formatter already registered under the same metric name.
lcm_metric_formatters: Dict[str, Tuple[str, int, Callable[[Any], str]]] = {
    # fmt: off
    # Losses
    "loss":                             ("Loss", 100, format_as_float),
    "nll_loss":                         ("NLL Loss", 100, format_as_float),
    "mse_loss":                         ("MSE Loss", 100, format_as_float),
    "contrastive_loss":                 ("Contrastive Loss", 110, format_as_float),
    "reconstruction_loss":              ("Reconstruction loss", 110, format_as_float),
    "unnormalized_reconstruction_loss": ("Unnormalized Reconstruction Loss", 110, format_as_float),
    "kld":                              ("KLD loss", 110, format_as_float),
    "encoder_mse_loss":                 ("Encoder MSE loss", 110, format_as_float),
    "decoder_ce_loss":                  ("Decoder CE loss", 110, format_as_float),
    # Timing
    "elapsed_time":                     ("Elapsed Time", 500, format_as_seconds),
    "wall_time":                        ("Wall Time", 510, format_as_seconds),
    # Optimization
    "lr":                               ("Learning Rate", 800, format_as_float),
    "loss_scale":                       ("Loss Scale", 810, format_as_float),
    "grad_norm":                        ("Grad norm", 810, format_as_float),
    "raw_grad_norm":                    ("Raw Grad norm", 815, format_as_float),
    "encoder_mse_scale":                ("Encoder MSE loss scale", 850, format_as_float),
    # Batch sizes and throughput
    "batch_size":                       ("Batch Size", 900, format_as_int),
    "elements_per_batch":               ("Elements per Batch", 900, format_as_int),
    "elements_per_second":              ("Elements per Second", 900, format_as_int),
    "num_examples":                     ("Number of Examples", 900, format_as_int),
    "num_source_elements":              ("Number of Source Elements", 900, format_as_int),
    "num_target_elements":              ("Number of Target Elements", 900, format_as_int),
    "total_num_target_elements":        ("Accumulated Target Elements", 920, format_as_int),
    # Resource usage
    "gpu_memory_usage":                 ("GPU memory usage (GiB)", 910, format_as_float),
    "gpu_peak_memory_usage":            ("GPU peak memory usage (GiB)", 910, format_as_float),
    "ram_percentage":                   ("RAM usage", 920, format_as_percent),
    "cpu_percentage":                   ("CPU usage", 920, format_as_percent),
    # Embedding statistics
    "mean_predicted_embeddings":        ("mean_predicted_embeddings", 920, format_as_float),
    "std_predicted_embeddings":         ("std_predicted_embeddings", 920, format_as_float),
    # fmt: on
}

for key, formatter_spec in lcm_metric_formatters.items():
    register_metric_formatter(key, *formatter_spec, overwrite=True)