# lcm/train/metrics.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#
from collections.abc import MutableMapping
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
)
import torch
from fairseq2.gang import Gang
from fairseq2.logging import get_log_writer
from fairseq2.metrics import (
MetricBag,
format_as_float,
format_as_int,
format_as_seconds,
)
from fairseq2.metrics.recorder import (
MetricRecorder,
_metric_formatters,
register_metric_formatter,
)
from fairseq2.typing import override
from torch import Tensor
from torch.cuda import _get_device_index
from torcheval.metrics import Max, Mean, Sum, Throughput
logger = get_log_writer(__name__)
format_as_percent = partial(format_as_int, postfix="%")
def flatten_dict(d: MutableMapping, parent_key: str = "", sep: str = ".") -> Dict:
"""
A helper function to flatten nested dictionaries
    Example: with a training config like
config = {
'data': {
'training': {'batch_size': 10},
'validation': {'batch_size': 2}
},
'model': {'model_dim': 1024},
'use_fsdp': True
}
The flat config will be:
{
'data.training.batch_size': 10,
'data.validation.batch_size': 2,
'model.model_dim': 1024,
'use_fsdp': True
}
    This helper is used to convert our nested training config into a flat
    dictionary for TensorBoard's HParams consumption.
"""
    items: List[Tuple[str, Any]] = []
for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k
if isinstance(v, MutableMapping):
items.extend(flatten_dict(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)
def get_allocated_gpu_memory(device):
"""
Get allocated memory in GiB for GPU devices
"""
if device.type == "cpu":
return 0, 0
device = _get_device_index(device, optional=True)
memory_stats = torch.cuda.memory_stats(device=device)
current_usage = memory_stats["allocated_bytes.all.current"] / (1024**3)
peak_usage = memory_stats["allocated_bytes.all.peak"] / (1024**3)
return current_usage, peak_usage
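# For example (assuming a `gang` object as used elsewhere in this file):
#
#     current_gib, peak_gib = get_allocated_gpu_memory(gang.device)  # returns (0, 0) on CPU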
@dataclass
class LossTerm:
"""Dataclass for a batch loss term"""
value: Tensor
"""The final loss to be optimized"""
batch_size: int
num_target_elements: Union[int, float]
    summands: Dict[str, Tuple[Any, Any]] = field(default_factory=dict)
    """A dictionary of loss summands to record. Each entry is a tuple of (loss, number of elements).
    The second element is optional; if None, we will use `num_target_elements` when aggregating"""
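# Example (hypothetical tensors and counts): a LossTerm whose "mse_loss" summand is
# normalized by `num_target_elements` (denominator None), while "contrastive_loss"
# carries its own denominator, e.g. a number of contrastive pairs:
#
#     loss_term = LossTerm(
#         value=mse_loss + contrastive_loss,
#         batch_size=16,
#         num_target_elements=1024,
#         summands={
#             "mse_loss": (mse_loss.detach(), None),
#             "contrastive_loss": (contrastive_loss.detach(), 256),
#         },
#     )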
class LCMMetricBag(MetricBag):
"""Holds the common metrics of an LCM."""
loss: Mean
batch_size: Sum
elements_per_batch: Mean
elements_per_second: Throughput
num_target_elements: Sum
total_num_target_elements: Sum
    grad_norm: Mean
    gpu_memory_usage: Max
    gpu_peak_memory_usage: Max
def __init__(
self, gang: Gang, loss_summands: Sequence[str] = [], reduction: str = "sum"
) -> None:
"""
:param gang:
The gang to sync metrics across all processes.
"""
super().__init__(gang)
# temporary fix:
self.reduction = reduction
d = gang.device
        # A temporary solution to track however many loss terms we choose to explore
self.loss_summands = loss_summands
self.register_metric("loss", Mean(device=d), persistent=False)
# this is the effective batch size
self.register_metric("batch_size", Sum(device=d), persistent=False)
self.register_metric("elements_per_batch", Mean(device=d), persistent=False)
self.register_metric(
"elements_per_second", Throughput(device=d), persistent=False
)
self.register_metric("gpu_memory_usage", Max(device=d), persistent=False)
self.register_metric("gpu_peak_memory_usage", Max(device=d), persistent=False)
# self.register_metric("ram_percentage", Max(device=d), persistent=False)
# self.register_metric("cpu_percentage", Max(device=d), persistent=False)
for summand in self.loss_summands:
self.register_metric(summand, Mean(device=d), persistent=False)
# The number of target tokens in a parallel batch. Used for computing throughput
self.register_metric("num_target_elements", Sum(device=d), persistent=False)
# The total_num_target_elements is persistent and is supposed to track the
# total number of tokens consumed since training started
self.total_num_target_elements = Sum(device=d)
def register_adaln_metric(self, module_name: str):
for block in ["mha", "ffn"]:
for tensor in [
"shift",
"scale",
"gate",
]:
self.register_metric(
f"{module_name}_{block}_{tensor}_mean",
Mean(device=self._gang.device),
persistent=False,
)
self.register_metric(
f"{module_name}_{block}_{tensor}_std",
Mean(device=self._gang.device),
persistent=False,
)
# formatters
register_metric_formatter(
f"{module_name}_{block}_{tensor}_mean",
f"{module_name}_{block}_{tensor}_mean",
1000,
format_as_float,
)
register_metric_formatter(
f"{module_name}_{block}_{tensor}_std",
f"{module_name}_{block}_{tensor}_std",
1000,
format_as_float,
)
def register_module_metric(self, module_name: str):
for tensor in [
"input_gradient",
"output_gradient",
"input_activations",
"output_activations",
]:
self.register_metric(
f"{module_name}_{tensor}_mean",
Mean(device=self._gang.device),
persistent=False,
)
self.register_metric(
f"{module_name}_{tensor}_std",
Mean(device=self._gang.device),
persistent=False,
)
# formatters
register_metric_formatter(
f"{module_name}_{tensor}_mean",
f"{module_name}_{tensor}_mean",
1000,
format_as_float,
)
register_metric_formatter(
f"{module_name}_{tensor}_std",
f"{module_name}_{tensor}_std",
1000,
format_as_float,
)
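    # Illustrative sketch only (no hook is installed in this file): the per-module metrics
    # registered above are expected to be fed externally, e.g. a hypothetical forward hook on
    # a module registered under the (made-up) name "custom_decoder" could do
    #
    #     out = output.detach().float()
    #     getattr(metric_bag, "custom_decoder_output_activations_mean").update(out.mean())
    #     getattr(metric_bag, "custom_decoder_output_activations_std").update(out.std())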
@torch.inference_mode()
def update(
self,
losses: Sequence[LossTerm],
) -> None:
"""Update the metrics.
:param output:
The losses generated by the model for each batch
:param elapsed_time:
The total elapsed time to read and process batches
"""
loss = torch.zeros((), dtype=torch.float64)
loss_summands = {
s: torch.zeros((), dtype=torch.float64) for s in self.loss_summands
}
# Denominator to normalize the loss summands, if -1,
# we will default to normalizing with `num_target_elements`
loss_summands_numel = {
s: -torch.ones((), dtype=torch.long) for s in self.loss_summands
}
batch_size = torch.zeros((), dtype=torch.int64)
num_target_elements = torch.zeros((), dtype=torch.int64)
        # `losses` contains more than one element only when gradient accumulation is used
for batch_loss in losses:
loss += float(batch_loss.value)
for s in self.loss_summands:
loss_term = batch_loss.summands.get(s, (0.0, None))
loss_summands[s] += float(loss_term[0])
                if loss_term[1] is not None and loss_term[1] != -1:
if loss_summands_numel[s] == -1:
loss_summands_numel[s] = torch.zeros((), dtype=torch.int64)
loss_summands_numel[s] += loss_term[1]
batch_size += batch_loss.batch_size
num_target_elements += batch_loss.num_target_elements
# Misleading normalization in the metric bag with reduction == "mean"
# Kept here for backward compatibility
# Any normalization here is only for reporting and doesn't impact optimization
if self.reduction == "sum":
loss /= num_target_elements
keys = list(loss_summands)
for k in keys:
denom = loss_summands_numel[k]
if denom == -1:
denom = num_target_elements
loss_summands[k] /= denom + 1e-6
self.loss.update(loss, weight=num_target_elements)
for s in loss_summands:
weight = loss_summands_numel[s]
if weight == -1:
weight = num_target_elements
getattr(self, s).update(loss_summands[s], weight=weight)
self.batch_size.update(batch_size)
self.elements_per_batch.update(num_target_elements)
self.num_target_elements.update(num_target_elements)
# update the cumulative metric
self.total_num_target_elements.update(num_target_elements)
# Get GPU memory usage
gpu_memory_usage, gpu_peak_memory_usage = get_allocated_gpu_memory(
self._gang.device
)
self.gpu_memory_usage.update(torch.tensor(gpu_memory_usage))
self.gpu_peak_memory_usage.update(torch.tensor(gpu_peak_memory_usage))
def reset_batch_metrics(self) -> None:
"""Reset the batch metrics to their initial state."""
self.loss.reset()
for s in self.loss_summands:
getattr(self, s).reset()
self.batch_size.reset()
self.elements_per_batch.reset()
self.elements_per_second.reset()
self.grad_norm.reset()
self.gpu_memory_usage.reset()
self.gpu_peak_memory_usage.reset()
# self.ram_percentage.reset()
# self.cpu_percentage.reset()
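# Usage sketch (hypothetical trainer loop; `gang` and `accumulated_losses` are assumed to be
# provided by the caller): update the bag once per optimizer step with the LossTerms gathered
# over the gradient-accumulation steps, and reset the per-interval metrics after logging:
#
#     bag = LCMMetricBag(gang, loss_summands=["mse_loss", "contrastive_loss"])
#     ...
#     bag.update(accumulated_losses)  # accumulated_losses: Sequence[LossTerm]
#     ...
#     bag.reset_batch_metrics()       # after each logging interval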
## Weights & Biases recorder
try:
import wandb # type: ignore[import-not-found]
except ImportError:
has_wandb = False
else:
has_wandb = True
class LCMWandBRecorder(MetricRecorder):
"""Records metric values to Weights & Biases."""
defined_runs: Set[str] = set()
def __init__(
self,
project: Optional[str] = None,
name: Optional[str] = None,
output_dir: Optional[Path] = None,
config: Dict[str, Any] = {},
**kwargs,
) -> None:
"""
        :param project: A project to group this run with other experiments; if None, the run will go under `uncategorized`.
        :param name: A unique name for this run; if None, a random name will be generated.
        :param output_dir: The base directory under which to store the W&B files. Optional.
        :param config: A dictionary of key-value pairs stored as the experiment's config (akin to hparams in TensorBoard).
        :param kwargs: Additional arguments to pass to `wandb.init()`.
In order to use W&B, run `wandb login` from the command line and enter
the API key when prompted.
"""
        if not has_wandb:
            logger.warning("wandb not found. Please install it with `pip install wandb`.")  # fmt: skip
self._run = None
else:
if output_dir:
output_dir.mkdir(parents=True, exist_ok=True)
self._run = wandb.init( # type: ignore
project=project,
name=name,
dir=output_dir,
resume="allow",
config=config,
**kwargs,
)
    def _define_run(self, run: str) -> None:
        if run in self.defined_runs:
            return
        # https://docs.wandb.ai/guides/track/log/customize-logging-axes/
        wandb.define_metric(f"{run}/step")
        wandb.define_metric(f"{run}/*", step_metric=f"{run}/step")
        self.defined_runs.add(run)
@override
def record_metrics(
self,
run: str,
values: Mapping[str, Any],
step_nr: Optional[int] = None,
*,
flush: bool = True,
) -> None:
if self._run is None:
return
self._define_run(run)
for name, value in values.items():
formatter = _metric_formatters.get(name)
if formatter is None:
display_name = name
else:
display_name = formatter.display_name
self._run.log({f"{run}/{display_name}": value, f"{run}/step": step_nr})
@override
def close(self) -> None:
if self._run is not None:
self._run.finish()
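# Usage sketch (hypothetical project/run names; `training_config` stands for the nested training
# config dict): create the recorder once on the logging rank, passing the flattened config so it
# appears in the W&B run config, then log metrics per step:
#
#     recorder = LCMWandBRecorder(
#         project="lexa_lcm",
#         name="pretrain_run_001",
#         output_dir=Path("outputs/wandb"),
#         config=flatten_dict(training_config),
#     )
#     recorder.record_metrics("train", {"loss": 0.42, "lr": 3e-4}, step_nr=100)
#     recorder.close()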
lcm_metric_formatters: Dict[str, Tuple[str, int, Callable[[Any], str]]] = {
# fmt: off
"loss": ("Loss", 100, format_as_float),
"nll_loss": ("NLL Loss", 100, format_as_float),
"mse_loss": ("MSE Loss", 100, format_as_float),
"contrastive_loss": ("Contrastive Loss", 110, format_as_float),
"reconstruction_loss": ("Reconstruction loss", 110, format_as_float),
"unnormalized_reconstruction_loss": (
"Unnormalized Reconstruction Loss",
110,
format_as_float,
),
"kld": ("KLD loss", 110, format_as_float),
"encoder_mse_loss": ("Encoder MSE loss", 110, format_as_float),
"decoder_ce_loss": ("Decoder CE loss", 110, format_as_float),
"elapsed_time": ("Elapsed Time", 500, format_as_seconds),
"wall_time": ("Wall Time", 510, format_as_seconds),
"lr": ("Learning Rate", 800, format_as_float),
"loss_scale": ("Loss Scale", 810, format_as_float),
"grad_norm": ("Grad norm", 810, format_as_float),
"raw_grad_norm": ("Raw Grad norm", 815, format_as_float),
"encoder_mse_scale": ("Encoder MSE loss scale", 850, format_as_float),
"batch_size": ("Batch Size", 900, format_as_int),
"elements_per_batch": ("Elements per Batch", 900, format_as_int),
"elements_per_second": ("Elements per Second", 900, format_as_int),
"num_examples": ("Number of Examples", 900, format_as_int),
"num_source_elements": ("Number of Source Elements", 900, format_as_int),
"num_target_elements": ("Number of Target Elements", 900, format_as_int),
"total_num_target_elements": ("Accumulated Target Elements", 920, format_as_int),
"gpu_memory_usage": ("GPU memory usage (GiB)", 910, format_as_float),
"gpu_peak_memory_usage": ("GPU peak memory usage (GiB)", 910, format_as_float),
"ram_percentage": ("RAM usage", 920, format_as_percent),
"cpu_percentage": ("CPU usage", 920, format_as_percent),
"mean_predicted_embeddings": ("mean_predicted_embeddings", 920, format_as_float),
"std_predicted_embeddings": ("std_predicted_embeddings", 920, format_as_float),
# fmt: on
}
for key, formatter in lcm_metric_formatters.items():
    register_metric_formatter(key, *formatter, overwrite=True)
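# Note: each formatter entry above follows fairseq2's (display name, priority, format function)
# convention; an additional metric can be registered the same way (the "my_new_loss" key here is
# purely hypothetical):
#
#     register_metric_formatter("my_new_loss", "My New Loss", 110, format_as_float, overwrite=True)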