| import time |
| from collections import deque |
| from contextlib import nullcontext |
| from typing import Any, Callable, Deque, Dict, Optional |
|
|
| import torch |
| from lightning import Callback, Fabric, LightningModule, Trainer |
| from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1 |
| from lightning.fabric.plugins import ( |
| BitsandbytesPrecision, |
| DoublePrecision, |
| FSDPPrecision, |
| HalfPrecision, |
| MixedPrecision, |
| Precision, |
| TransformerEnginePrecision, |
| XLAPrecision, |
| ) |
| from lightning.fabric.utilities.rank_zero import rank_zero_only as fabric_rank_zero_only |
| from lightning.pytorch.plugins import ( |
| DoublePrecisionPlugin, |
| FSDPPrecisionPlugin, |
| HalfPrecisionPlugin, |
| MixedPrecisionPlugin, |
| XLAPrecisionPlugin, |
| ) |
| from lightning.pytorch.utilities.rank_zero import rank_zero_only as trainer_rank_zero_only |
| from torch.utils.flop_counter import FlopCounterMode |
|
|
| from lit_gpt import GPT |
| from lit_gpt.utils import num_parameters |
|
|
# Peak throughput (FLOP/s) of a single CUDA device, keyed first by a normalized device name and
# then by compute dtype. ``SpeedMonitorBase`` divides measured per-device FLOP/s by these values to
# report MFU. A dtype absent from a device's inner dict has no entry, so looking it up raises KeyError.
GPU_AVAILABLE_FLOPS = {
    # The H100 half/int8 figures are stored as half of the quoted number (``/ 2``).
    # NOTE(review): presumably this strips a 2x structured-sparsity factor from the datasheet value — confirm.
    "h100-sxm": {
        torch.float64: 67e12,
        torch.float32: 67e12,
        torch.bfloat16: 1.979e15 / 2,
        torch.float16: 1.979e15 / 2,
        torch.int8: 3.958e15 / 2,
    },
    "h100-pcie": {
        torch.float64: 51e12,
        torch.float32: 51e12,
        torch.bfloat16: 1.513e15 / 2,
        torch.float16: 1.513e15 / 2,
        torch.int8: 3.026e15 / 2,
    },
    # A single "a100" entry is used for every A100 variant.
    "a100": {
        torch.float64: 19.5e12,
        torch.float32: 19.5e12,
        torch.bfloat16: 312e12,
        torch.float16: 312e12,
    },
    "a10g": {
        torch.float32: 31.2e12,
        torch.bfloat16: 125e12,
        torch.float16: 125e12,
    },
    # V100 variants list no bfloat16 figure.
    "v100-sxm": {
        torch.float64: 7.8e12,
        torch.float32: 15.7e12,
        torch.float16: 125e12,
    },
    "v100-pcie": {
        torch.float64: 7e12,
        torch.float32: 14e12,
        torch.float16: 112e12,
    },
    "v100s-pcie": {
        torch.float64: 8.2e12,
        torch.float32: 16.4e12,
        torch.float16: 130e12,
    },
    "t4": {
        torch.float32: 8.1e12,
        torch.float16: 65e12,
        torch.int8: 130e12,
    },
    "quadro rtx 5000": {
        torch.float32: 11.2e12,
        torch.float16: 89.2e12,
    },
}
|
|
# Peak throughput (FLOP/s) of a single TPU, keyed by the lowercased TPU generation name.
# Unlike the GPU table there is no per-dtype breakdown: one figure is used whatever the dtype.
TPU_AVAILABLE_FLOPS = dict(
    v2=45e12,
    v3=123e12,
    v4=275e12,
    v5litepod=197e12,
)
|
|
|
|
| def get_flops_available(device: torch.device, dtype: torch.dtype) -> Optional[float]: |
| if device.type == "cuda": |
| device_name = torch.cuda.get_device_name(device).lower() |
| if "h100" in device_name and "hbm3" in device_name: |
| device_name = "h100-sxm" |
| elif "h100" in device_name and ("pcie" in device_name or "hbm2e" in device_name): |
| device_name = "h100-pcie" |
| elif "a100" in device_name: |
| device_name = "a100" |
| elif "a10g" in device_name: |
| device_name = "a10g" |
| elif "v100-sxm" in device_name: |
| device_name = "v100-sxm" |
| elif "v100-pcie" in device_name: |
| device_name = "v100-pcie" |
| elif "t4" in device_name: |
| device_name = "t4" |
| elif "quadro rtx 5000" in device_name: |
| device_name = "quadro rtx 5000" |
| else: |
| device_name = None |
|
|
| if device_name is not None: |
| try: |
| return int(GPU_AVAILABLE_FLOPS[device_name][dtype]) |
| except KeyError: |
| raise KeyError( |
| f"flop count not found for {device_name} with dtype: {dtype}; " |
| "MFU cannot be calculated and reported." |
| ) |
| elif device.type == "xla": |
| if _XLA_GREATER_EQUAL_2_1: |
| from torch_xla._internal import tpu |
| else: |
| from torch_xla.experimental import tpu |
|
|
| device_name = tpu.get_tpu_env()["TYPE"].lower() |
| try: |
| return int(TPU_AVAILABLE_FLOPS[device_name]) |
| except KeyError: |
| raise KeyError( |
| f"flop count not found for {device_name} with dtype: {dtype}; MFU cannot be calculated and reported." |
| ) |
|
|
| return None |
|
|
|
|
| |
|
|
|
|
class SpeedMonitorBase:
    """Logs the training throughput and utilization.

    Throughput metrics are rolling averages over the ``window_size`` most recent batches, so they are only
    logged once ``window_size + 1`` calls to ``on_train_batch_end`` have been recorded. Logged keys:

    - ``throughput/batches_per_sec``: batches processed per second across all devices.
    - ``throughput/samples_per_sec``: samples processed per second across all devices.
    - ``throughput/tokens_per_sec``: tokens processed per second across all devices. Only logged when
      ``lengths`` is passed. This may include padding depending on the dataset.
    - ``throughput/flops_per_sec``: FLOPs over the window (``flops_per_batch * world_size`` per batch) divided
      by the window's elapsed time. Only logged when ``flops_per_batch`` is passed.
    - ``throughput/device/batches_per_sec``, ``throughput/device/samples_per_sec``,
      ``throughput/device/tokens_per_sec``, ``throughput/device/flops_per_sec``: the corresponding
      all-device value divided by world size.
    - ``throughput/device/mfu``: ``throughput/device/flops_per_sec`` divided by ``flops_available``. Only
      logged when ``flops_per_batch`` is passed and ``flops_available`` is neither ``None`` nor zero.
    - ``time/train``: total elapsed training time, in ``time_unit``.
    - ``time/val``: total elapsed validation time (sum of ``eval_end`` calls), in ``time_unit``.
    - ``time/total``: ``time/train + time/val``, in ``time_unit``.
    - ``samples``: the ``samples`` value passed to ``on_train_batch_end``.

    Notes:
        - The implementation assumes that devices are homogeneous as it normalizes by the world size.
        - Tokens/sec, flops/sec and MFU do not account for padding tokens if present. We suggest using samples/sec
          or batches/sec to measure throughput under this circumstance.
        - Be careful when comparing MFU numbers across projects, as this will highly depend on the
          ``flops_per_batch``. There is no widespread, realistic, and reliable implementation to compute them.
          We suggest using our ``measure_flops`` function, but many other works will use an analytical estimate
          such as ``estimate_flops``, which will almost always be an overestimate compared to the true value.

    Args:
        flops_available (float, optional): Peak FLOP/s of a single device, used as the MFU denominator.
            MFU is not logged when this is ``None`` or ``0``.
        log_dict (Callable[[Dict, int], None]): Called as ``log_dict(metrics, step)`` once per
            ``on_train_batch_end`` call.
        window_size (int, optional): Number of batches to use for a rolling average of throughput.
            Defaults to 100.
        time_unit (str, optional): Time unit to use for `time` logging. Can be one of
            'seconds', 'minutes', 'hours', or 'days'. Defaults to 'hours'.

    Raises:
        ValueError: If ``time_unit`` is not one of the supported units.
    """

    def __init__(
        self,
        flops_available: Optional[float],
        log_dict: Callable[[Dict, int], None],
        window_size: int = 100,
        time_unit: str = "hours",
    ):
        self.flops_available = flops_available
        self.log_dict = log_dict

        # Each deque keeps ``window_size + 1`` cumulative readings, so the difference between its newest
        # and oldest entry spans exactly ``window_size`` batches.
        self.history_samples: Deque[int] = deque(maxlen=window_size + 1)
        self.history_wct: Deque[float] = deque(maxlen=window_size + 1)
        self.history_lengths: Deque[int] = deque(maxlen=window_size + 1)
        self.history_flops: Deque[int] = deque(maxlen=window_size + 1)

        # Number of seconds per reporting unit; elapsed times arrive in seconds.
        dividers = {"seconds": 1, "minutes": 60, "hours": 60 * 60, "days": 60 * 60 * 24}
        if time_unit not in dividers:
            raise ValueError(
                f'Invalid time_unit: {time_unit}. Must be one of "seconds", "minutes", "hours", or "days".'
            )
        self.divider = dividers[time_unit]

        # Cumulative validation wall-clock time, in seconds, accumulated by ``eval_end``.
        self.total_eval_wct = 0.0
        # Incremented at the start of every ``on_train_batch_end``, so the first logged step is 0.
        self.step = -1

    def on_train_batch_end(
        self,
        samples: int,
        train_elapsed: float,
        world_size: int,
        flops_per_batch: Optional[int] = None,
        lengths: Optional[int] = None,
    ) -> None:
        """Record one batch and log the current metrics via ``log_dict``.

        Args:
            samples: Cumulative number of samples processed so far on this device. It is multiplied by
                ``world_size`` for the all-device rates.
            train_elapsed: Cumulative training wall-clock time in seconds.
            world_size: Number of devices. Rates are scaled by it, assuming homogeneous devices.
            flops_per_batch: FLOPs of one batch on one device. When given, flops/sec (and MFU) are logged.
            lengths: Cumulative running total. ``Δlengths / Δbatches`` over the window is multiplied by
                samples/sec, so it is expected to grow by the per-sample token count of each batch.
        """
        self.step += 1
        step = self.step
        metrics = {}

        self.history_samples.append(samples)
        if lengths is not None:
            self.history_lengths.append(lengths)
            # If lengths are passed, there should be as many values as samples.
            assert len(self.history_samples) == len(self.history_lengths)
        self.history_wct.append(train_elapsed)
        if len(self.history_wct) == self.history_wct.maxlen:
            elapsed_batches = len(self.history_samples) - 1
            elapsed_samples = self.history_samples[-1] - self.history_samples[0]
            elapsed_wct = self.history_wct[-1] - self.history_wct[0]
            samples_per_sec = elapsed_samples * world_size / elapsed_wct
            dev_samples_per_sec = elapsed_samples / elapsed_wct
            metrics.update(
                {
                    "throughput/batches_per_sec": elapsed_batches * world_size / elapsed_wct,
                    "throughput/samples_per_sec": samples_per_sec,
                    "throughput/device/batches_per_sec": elapsed_batches / elapsed_wct,
                    "throughput/device/samples_per_sec": dev_samples_per_sec,
                }
            )
            if lengths is not None:
                elapsed_lengths = int(self.history_lengths[-1]) - int(self.history_lengths[0])
                avg_length = elapsed_lengths / elapsed_batches
                metrics.update(
                    {
                        "throughput/tokens_per_sec": samples_per_sec * avg_length,
                        "throughput/device/tokens_per_sec": dev_samples_per_sec * avg_length,
                    }
                )

        if flops_per_batch is not None:
            # Sum of flops per batch across ranks.
            self.history_flops.append(flops_per_batch * world_size)
            # Only evaluated when this call supplied FLOPs, so stale FLOP history is never re-reported
            # against fresh wall-clock readings.
            if len(self.history_flops) == self.history_flops.maxlen:
                # The oldest entry is the window's starting point, so it is excluded from the window's FLOPs.
                elapsed_flops = sum(self.history_flops) - self.history_flops[0]
                elapsed_wct = self.history_wct[-1] - self.history_wct[0]
                flops_per_sec = elapsed_flops / elapsed_wct
                device_flops_per_sec = flops_per_sec / world_size
                metrics.update(
                    {"throughput/flops_per_sec": flops_per_sec, "throughput/device/flops_per_sec": device_flops_per_sec}
                )
                if self.flops_available:
                    metrics["throughput/device/mfu"] = device_flops_per_sec / self.flops_available

        metrics.update(
            {
                "time/train": train_elapsed / self.divider,
                "time/val": self.total_eval_wct / self.divider,
                "time/total": (train_elapsed + self.total_eval_wct) / self.divider,
                "samples": samples,
            }
        )

        self.log_dict(metrics, step)

    def eval_end(self, eval_elapsed: float) -> None:
        """Add ``eval_elapsed`` seconds of validation time to the running total reported as ``time/val``."""
        self.total_eval_wct += eval_elapsed
|
|
|
|
def plugin_to_compute_dtype(plugin: Precision) -> torch.dtype:
    """Map a Fabric/Trainer precision plugin to the dtype used to look up peak FLOP/s.

    The checks run in order and the first ``isinstance`` match wins. Any other ``Precision`` falls back to
    ``torch.float32``.

    Raises:
        NotImplementedError: If ``plugin`` is not a ``Precision`` instance.
    """
    resolvers = (
        (BitsandbytesPrecision, lambda p: p.dtype),
        ((HalfPrecision, MixedPrecision, HalfPrecisionPlugin), lambda p: p._desired_input_dtype),
        (MixedPrecisionPlugin, lambda p: torch.bfloat16 if p.precision == "bf16-mixed" else torch.half),
        ((DoublePrecision, DoublePrecisionPlugin), lambda p: torch.double),
        ((XLAPrecision, XLAPrecisionPlugin), lambda p: p._desired_dtype),
        # NOTE(review): int8 is used as the lookup dtype for Transformer Engine, presumably as a stand-in
        # for its FP8 throughput — confirm.
        (TransformerEnginePrecision, lambda p: torch.int8),
        # For FSDP, the mixed-precision config's reduce dtype stands in for the compute dtype.
        ((FSDPPrecision, FSDPPrecisionPlugin), lambda p: p.mixed_precision_config.reduce_dtype),
        (Precision, lambda p: torch.float32),
    )
    for plugin_types, resolve in resolvers:
        if isinstance(plugin, plugin_types):
            return resolve(plugin)
    raise NotImplementedError(plugin)
|
|
|
|
class SpeedMonitorFabric(SpeedMonitorBase):
    """``SpeedMonitorBase`` wired to a ``Fabric`` instance.

    The peak FLOP/s is derived from ``fabric.device`` and the compute dtype of
    ``fabric.strategy.precision``, and metrics are logged through ``fabric.log_dict``.
    Remaining positional/keyword arguments (e.g. ``window_size``, ``time_unit``) go to ``SpeedMonitorBase``.
    """

    def __init__(self, fabric: Fabric, *args: Any, **kwargs: Any) -> None:
        compute_dtype = plugin_to_compute_dtype(fabric.strategy.precision)
        peak_flops = get_flops_available(fabric.device, compute_dtype)
        # Passed positionally so that any extra ``*args`` continue with ``window_size``, ``time_unit``.
        super().__init__(peak_flops, fabric.log_dict, *args, **kwargs)

    @fabric_rank_zero_only
    def on_train_batch_end(self, *args: Any, **kwargs: Any) -> None:
        # Wrapped with Lightning Fabric's ``rank_zero_only``, so only rank zero records and logs.
        super().on_train_batch_end(*args, **kwargs)
|
|
|
|
class SpeedMonitorCallback(Callback):
    """Lightning ``Callback`` that feeds training/validation timings into a ``SpeedMonitorBase``.

    Args:
        length_fn: Called on every training batch; its results are summed and passed to the monitor as
            ``lengths``.
        batch_size: Samples per training batch. The monitor receives ``(total_batch_idx + 1) * batch_size``
            as its cumulative ``samples``.
        **kwargs: Forwarded to ``SpeedMonitorBase`` (e.g. ``window_size``, ``time_unit``).
    """

    def __init__(self, length_fn: Callable[[Any], int], batch_size: int, **kwargs: Any) -> None:
        super().__init__()
        # Created lazily in ``setup`` once the trainer's device and precision are known.
        self.speed_monitor: Optional[SpeedMonitorBase] = None
        self.speed_monitor_kwargs = kwargs
        self.length_fn = length_fn
        self.batch_size = batch_size
        # ``time.perf_counter`` timestamps (seconds).
        self.eval_t0: float = 0.0
        self.train_t0: float = 0.0
        self.total_lengths: int = 0

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """Create the ``SpeedMonitorBase`` once, logging through ``trainer.logger.log_metrics``."""
        if self.speed_monitor is not None:
            return
        dtype = plugin_to_compute_dtype(trainer.precision_plugin)
        flops_available = get_flops_available(trainer.strategy.root_device, dtype)
        # NOTE(review): assumes a logger is configured; ``trainer.logger`` being None fails here.
        self.speed_monitor = SpeedMonitorBase(flops_available, trainer.logger.log_metrics, **self.speed_monitor_kwargs)

    @trainer_rank_zero_only
    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        # Always start the clock: ``on_train_batch_end`` unconditionally measures from ``train_t0``, so
        # skipping this would report elapsed time relative to perf_counter's arbitrary origin.
        self.train_t0 = time.perf_counter()

    @trainer_rank_zero_only
    def on_train_batch_end(
        self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int
    ) -> None:
        self.total_lengths += self.length_fn(batch)
        # NOTE(review): the monitor is only fed on non-accumulating batches, while samples and lengths grow
        # on every micro-batch; with gradient accumulation the per-batch rates count optimizer steps — confirm.
        if trainer.fit_loop._should_accumulate():
            return
        train_elapsed = time.perf_counter() - self.train_t0
        assert self.speed_monitor is not None
        iter_num = trainer.fit_loop.total_batch_idx
        # Bind before asserting: an assignment expression inside ``assert`` is skipped under ``python -O``,
        # which would leave ``measured_flops`` undefined below.
        measured_flops = pl_module.measured_flops
        assert measured_flops is not None
        self.speed_monitor.on_train_batch_end(
            (iter_num + 1) * self.batch_size,
            train_elapsed,
            # this assumes that device FLOPs are the same and that all devices have the same batch size
            trainer.world_size,
            flops_per_batch=measured_flops,
            lengths=self.total_lengths,
        )

    @trainer_rank_zero_only
    def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self.eval_t0 = time.perf_counter()

    @trainer_rank_zero_only
    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        eval_elapsed = time.perf_counter() - self.eval_t0
        assert self.speed_monitor is not None
        self.speed_monitor.eval_end(eval_elapsed)
|
|
|
|
def flops_per_param(max_seq_length: int, n_layer: int, n_embd: int, n_params: int) -> int:
    """Estimate forward FLOPs for one sequence of ``max_seq_length`` tokens.

    The result is a dense term, ``2 * n_params`` FLOPs per token (one multiply-add per parameter), plus
    an attention term ``4 * n_layer * n_embd * max_seq_length**2`` that does not depend on ``n_params``.
    Every sample is assumed to be exactly ``max_seq_length`` tokens long.
    """
    per_token = 2 * n_params
    dense = per_token * max_seq_length
    attention = n_layer * 4 * (n_embd * max_seq_length**2)
    return dense + attention
|
|
|
|
def estimate_flops(model: GPT) -> int:
    """Estimate FLOPs of one step of ``model`` for MFU.

    In training mode, trainable parameters are counted 3x, frozen parameters 2x and the attention term 3x.
    In eval mode, everything is counted once. The attention term is counted exactly once per pass.

    Refs:
        * https://ar5iv.labs.arxiv.org/html/2205.05198#A1
        * https://ar5iv.labs.arxiv.org/html/2204.02311#A2
    """
    # Every parameter is treated as taking part in a matmul (including e.g. embeddings and norms), so this is
    # an over-estimate compared to measured FLOPs.
    seq_len, n_layer, n_embd = model.max_seq_length, model.config.n_layer, model.config.n_embd
    # ``flops_per_param`` always adds the attention term; with zero parameters only that term remains.
    attn_flops = flops_per_param(seq_len, n_layer, n_embd, 0)

    n_trainable_params = num_parameters(model, requires_grad=True)
    # Dense FLOPs of the trainable parameters plus the (single) attention term.
    trainable_flops = flops_per_param(seq_len, n_layer, n_embd, n_trainable_params)
    # forward + backward + gradients (assumes no gradient accumulation)
    ops_per_step = 3 if model.training else 1

    n_frozen_params = num_parameters(model, requires_grad=False)
    # Dense FLOPs of the frozen parameters only: attention is already included in ``trainable_flops``.
    frozen_flops = flops_per_param(seq_len, n_layer, n_embd, n_frozen_params) - attn_flops
    # forward + backward
    frozen_ops_per_step = 2 if model.training else 1
    return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops
|
|
|
|
| def measure_flops(model: GPT, x: torch.Tensor) -> int: |
| """Measures real FLOPs for HFU""" |
| flop_counter = FlopCounterMode(model, display=False) |
| ctx = nullcontext() if model.training else torch.no_grad() |
| with ctx, flop_counter: |
| y = model(x) |
| if model.training: |
| y.sum().backward() |
| return flop_counter.get_total_flops() |
|
|