import os
import time
from collections import namedtuple
from datetime import datetime
from typing import Any

import torch
from torch.utils.tensorboard import SummaryWriter

from torchtitan.components.lr_scheduler import LRSchedulersContainer
from torchtitan.components.optimizer import OptimizersContainer
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.tools import utils
from torchtitan.tools.logging import logger
from torchtitan.tools.utils import Color, device_module, device_type

DeviceMemStats = namedtuple(
    "DeviceMemStats",
    [
        "max_active_gib",
        "max_active_pct",
        "max_reserved_gib",
        "max_reserved_pct",
        "num_alloc_retries",
        "num_ooms",
    ],
)

class DeviceMemoryMonitor:
    def __init__(self, device: str = f"{device_type}:0"):
        self.device = torch.device(device)
        self.device_name = device_module.get_device_name(self.device)
        self.device_index = device_module.current_device()
        self.device_capacity = device_module.get_device_properties(
            self.device
        ).total_memory
        self.device_capacity_gib = self._to_gib(self.device_capacity)

        device_module.reset_peak_memory_stats()
        device_module.empty_cache()
    def _to_gib(self, memory_in_bytes):
        # NOTE: GiB (gibibyte) is 1024**3 bytes, as opposed to GB (gigabyte),
        # which is 1000**3 bytes.
        _gib_in_bytes = 1024 * 1024 * 1024
        memory_in_gib = memory_in_bytes / _gib_in_bytes
        return memory_in_gib

    def _to_pct(self, memory):
        return 100 * memory / self.device_capacity
    def get_peak_stats(self):
        device_info = device_module.memory_stats(self.device)

        max_active = device_info.get("active_bytes.all.peak", -1)
        max_active_gib = self._to_gib(max_active)
        max_active_pct = self._to_pct(max_active)

        max_reserved = device_info.get("reserved_bytes.all.peak", -1)
        max_reserved_gib = self._to_gib(max_reserved)
        max_reserved_pct = self._to_pct(max_reserved)

        num_retries = device_info.get("num_alloc_retries", -1)
        num_ooms = device_info.get("num_ooms", -1)

        if num_retries > 0:
            logger.warning(
                f"{num_retries} {device_type.upper()} memory allocation retries."
            )
        if num_ooms > 0:
            logger.warning(f"{num_ooms} {device_type.upper()} OOM errors thrown.")

        return DeviceMemStats(
            max_active_gib,
            max_active_pct,
            max_reserved_gib,
            max_reserved_pct,
            num_retries,
            num_ooms,
        )

    def reset_peak_stats(self):
        device_module.reset_peak_memory_stats()

def build_device_memory_monitor():
    device_memory_monitor = DeviceMemoryMonitor(device_type)
    logger.info(
        f"{device_type.upper()} capacity: {device_memory_monitor.device_name} "
        f"with {device_memory_monitor.device_capacity_gib:.2f}GiB memory"
    )
    return device_memory_monitor
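
# A minimal usage sketch (hypothetical call site; in practice MetricsProcessor
# below drives the monitor): peak stats are read and then reset once per
# logging interval.
#
#     monitor = build_device_memory_monitor()
#     # ... run one or more training steps ...
#     stats = monitor.get_peak_stats()
#     print(f"peak reserved: {stats.max_reserved_gib:.2f}GiB "
#           f"({stats.max_reserved_pct:.2f}%)")
#     monitor.reset_peak_stats()
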
class BaseLogger:
    """Logger that does nothing, used when logging is disabled."""

    def log(self, metrics: dict[str, Any], step: int) -> None:
        pass

    def close(self) -> None:
        pass

class TensorBoardLogger(BaseLogger):
    """Logger implementation for TensorBoard."""

    def __init__(self, log_dir: str, tag: str | None = None):
        self.tag = tag
        self.writer = SummaryWriter(log_dir, max_queue=1000)
        logger.info(f"TensorBoard logging enabled. Logs will be saved at {log_dir}")

    def log(self, metrics: dict[str, Any], step: int) -> None:
        for k, v in metrics.items():
            tag = k if self.tag is None else f"{self.tag}/{k}"
            self.writer.add_scalar(tag, v, step)

    def close(self) -> None:
        self.writer.close()

class WandBLogger(BaseLogger):
    """Logger implementation for Weights & Biases."""

    def __init__(self, log_dir: str, tag: str | None = None):
        # Import wandb locally so the dependency stays optional; the caller
        # in _build_metric_logger catches the ImportError if it is missing.
        import wandb

        self.wandb = wandb
        self.tag = tag

        # Make sure the logging directory exists before wandb writes to it.
        os.makedirs(log_dir, exist_ok=True)

        self.wandb.init(
            project=os.getenv("WANDB_PROJECT", "torchtitan"),
            dir=log_dir,
        )
        logger.info("WandB logging enabled")

    def log(self, metrics: dict[str, Any], step: int) -> None:
        wandb_metrics = {
            (k if self.tag is None else f"{self.tag}/{k}"): v
            for k, v in metrics.items()
        }
        self.wandb.log(wandb_metrics, step=step)

    def close(self) -> None:
        if self.wandb.run is not None:
            self.wandb.finish()
def ensure_pp_loss_visible(
    parallel_dims: ParallelDims, job_config: JobConfig, color: Color
) -> None:
    """
    Ensures that the loss is visible on the console for pipeline-parallel training.

    For pipeline-parallel training, the loss is only visible on the last pipeline
    stage. This function checks whether the rank holding that stage is included in
    the LOG_RANK environment variable and warns if it is not.
    """
    # ZBVZeroBubble's V-shaped schedule places the last stage back on rank 0,
    # which is logged by default.
    if job_config.parallelism.pipeline_parallel_schedule == "ZBVZeroBubble":
        return

    # The loss is computed on the first rank of the last pipeline stage.
    world_size = parallel_dims.world_size
    pp_size = parallel_dims.pp
    loss_visible_rank = (world_size // pp_size) * (pp_size - 1)

    # Check whether LOG_RANK (a comma-separated list of ranks) includes it.
    env_logged_ranks = os.environ.get("LOG_RANK", "").split(",")
    if env_logged_ranks == [""]:
        env_logged_ranks = []

    if str(loss_visible_rank) not in env_logged_ranks:
        logger.warning(
            f"{color.red}Pipeline Parallel loss is not visible. "
            f"Please add {color.yellow}rank {loss_visible_rank}{color.red} "
            f"to LOG_RANK environment variable in run_train.sh.{color.reset}"
        )
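
# Example (hypothetical sizes): with world_size=8 and pp=2, the loss is
# computed on rank (8 // 2) * (2 - 1) = 4, so the launch script would need
# something like:
#
#     LOG_RANK=0,4 ./run_train.sh
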
def _get_metrics_rank(
    parallel_dims: ParallelDims,
    job_config: JobConfig,
) -> int:
    """
    Determines which rank should log metrics.

    Returns:
        int: The rank responsible for logging metrics:
            - Rank 0 for non-pipeline-parallel configs
            - Rank 0 for the pipeline-parallel 'ZBVZeroBubble' schedule
            - The first rank of the last pipeline stage for other
              pipeline-parallel schedules
    """
    if not parallel_dims.pp_enabled:
        return 0

    # ZBVZeroBubble computes the loss on rank 0 (see ensure_pp_loss_visible).
    if job_config.parallelism.pipeline_parallel_schedule == "ZBVZeroBubble":
        return 0

    # First rank of the last pipeline stage.
    world_size = parallel_dims.world_size
    pp_size = parallel_dims.pp
    return (world_size // pp_size) * (pp_size - 1)
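
# Worked example (hypothetical sizes): world_size=16, pp=4 splits the ranks
# into 4 pipeline stages of 4 ranks each; (16 // 4) * (4 - 1) = 12 is the
# first rank of the last stage, so rank 12 logs the metrics.
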
def _build_metric_logger(
    job_config: JobConfig, parallel_dims: ParallelDims, tag: str | None = None
) -> BaseLogger:
    """
    Build an appropriate metric logger based on configuration.
    """
    metrics_config = job_config.metrics

    logger.debug(
        f"Building logger with config: wandb={metrics_config.enable_wandb}, "
        f"tensorboard={metrics_config.enable_tensorboard}"
    )

    has_logging_enabled = (
        metrics_config.enable_tensorboard or metrics_config.enable_wandb
    )

    # Unless metrics are saved on all ranks, only the metrics rank should log.
    should_log = has_logging_enabled
    if (not metrics_config.save_for_all_ranks) and should_log:
        metrics_rank = _get_metrics_rank(parallel_dims, job_config)
        should_log = torch.distributed.get_rank() == metrics_rank

    logger.debug(
        f"Logging decision: has_logging_enabled={has_logging_enabled}, "
        f"should_log={should_log}"
    )

    if not should_log:
        logger.debug("Returning BaseLogger due to should_log=False")
        return BaseLogger()

    # Build the log directory, suffixed with a timestamp (and the rank, if
    # saving on all ranks).
    dump_dir = job_config.job.dump_folder
    base_log_dir = os.path.join(
        dump_dir, metrics_config.save_tb_folder, datetime.now().strftime("%Y%m%d-%H%M")
    )

    if metrics_config.save_for_all_ranks:
        base_log_dir = os.path.join(
            base_log_dir, f"rank_{torch.distributed.get_rank()}"
        )

    # Prefer WandB if enabled; fall back to TensorBoard, then to the no-op logger.
    if metrics_config.enable_wandb:
        logger.debug("Attempting to create WandB logger")
        try:
            return WandBLogger(base_log_dir, tag)
        except Exception as e:
            if "No module named 'wandb'" in str(e):
                logger.error(
                    "Failed to create WandB logger: No module named 'wandb'. "
                    "Please install it using 'pip install wandb'."
                )
            else:
                logger.error(f"Failed to create WandB logger: {e}")

    if metrics_config.enable_tensorboard:
        logger.debug("Creating TensorBoard logger")
        return TensorBoardLogger(base_log_dir, tag)

    logger.debug("No loggers enabled, returning BaseLogger")
    return BaseLogger()
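
# A minimal config sketch (keys assumed from the metrics_config fields read
# above; the exact TOML layout depends on JobConfig):
#
#     [metrics]
#     log_freq = 10
#     enable_tensorboard = true
#     save_tb_folder = "tb"
#     enable_wandb = false
#     save_for_all_ranks = false
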
class MetricsProcessor:
    """Metrics processor that processes and logs metrics.

    The current MetricsProcessor logs some metrics to STDOUT and some metrics
    to TensorBoard or WandB.

    Args:
        job_config (JobConfig): Job configuration.
        parallel_dims (ParallelDims): Parallel dimensions.
        tag (Optional[str]): Tag to use for TensorBoard or WandB. Defaults to None.
    """

    logger: BaseLogger
    parallel_dims: ParallelDims
    job_config: JobConfig
    device_memory_monitor: DeviceMemoryMonitor
    color: utils.NoColor | utils.Color

    gpu_peak_flops: int
    ntokens_since_last_log: int
    data_loading_times: list[float]
    time_last_log: float

    num_flops_per_token: int
    optimizers: OptimizersContainer | None
    lr_schedulers: LRSchedulersContainer | None

    def __init__(
        self,
        job_config: JobConfig,
        parallel_dims: ParallelDims,
        tag: str | None = None,
    ):
        self.logger = _build_metric_logger(job_config, parallel_dims, tag)
        self.parallel_dims = parallel_dims
        self.job_config = job_config
        self.device_memory_monitor = build_device_memory_monitor()
        # Use colored console output unless it is disabled in the config.
        self.color = (
            utils.NoColor()
            if job_config.metrics.disable_color_printing
            else utils.Color()
        )

        self.gpu_peak_flops = utils.get_peak_flops(
            self.device_memory_monitor.device_name
        )
        self.ntokens_since_last_log = 0
        self.data_loading_times = []
        self.time_last_log = time.perf_counter()
        self.device_memory_monitor.reset_peak_stats()

        # These attributes must be set by the training loop before log() is called.
        self.num_flops_per_token = -1
        self.optimizers = None
        self.lr_schedulers = None

    def should_log(self, step: int) -> bool:
        # Always log the first step; afterwards, log every log_freq steps
        # (e.g. with log_freq=10: steps 1, 10, 20, ...).
        return step == 1 or step % self.job_config.metrics.log_freq == 0
    def log(
        self,
        step: int,
        global_avg_loss: float,
        global_max_loss: float,
        extra_metrics: dict[str, Any] | None = None,
    ):
        assert self.num_flops_per_token > 0, "num_flops_per_token must be set"

        time_delta = time.perf_counter() - self.time_last_log

        # Tokens per second per device, abbreviated as tps.
        tps = self.ntokens_since_last_log / (
            time_delta * self.parallel_dims.non_data_parallel_size
        )

        # Model FLOPS utilization (MFU): achieved FLOPS as a percentage of the
        # device's peak FLOPS. For its definition and calculation, please refer
        # to the PaLM paper: https://arxiv.org/abs/2204.02311
        mfu = 100 * self.num_flops_per_token * tps / self.gpu_peak_flops
        tflops = self.num_flops_per_token * tps / 1e12

        time_end_to_end = time_delta / self.job_config.metrics.log_freq
        time_data_loading = sum(self.data_loading_times) / len(self.data_loading_times)
        time_data_loading_pct = 100 * sum(self.data_loading_times) / time_delta

        device_mem_stats = self.device_memory_monitor.get_peak_stats()

        metrics = {
            "loss_metrics/global_avg_loss": global_avg_loss,
            "loss_metrics/global_max_loss": global_max_loss,
            "throughput(tps)": tps,
            "tflops": tflops,
            "mfu(%)": mfu,
            "time_metrics/end_to_end(s)": time_end_to_end,
            "time_metrics/data_loading(s)": time_data_loading,
            "time_metrics/data_loading(%)": time_data_loading_pct,
            "memory/max_active(GiB)": device_mem_stats.max_active_gib,
            "memory/max_active(%)": device_mem_stats.max_active_pct,
            "memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib,
            "memory/max_reserved(%)": device_mem_stats.max_reserved_pct,
            "memory/num_alloc_retries": device_mem_stats.num_alloc_retries,
            "memory/num_ooms": device_mem_stats.num_ooms,
        }

        if extra_metrics:
            metrics.update(extra_metrics)

        self.logger.log(metrics, step)

        color = self.color
        log_line = (
            f"{color.red}step: {step:2} "
            f"{color.green}loss: {global_avg_loss:7.4f} "
            f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB"
            f"({device_mem_stats.max_reserved_pct:.2f}%) "
            f"{color.blue}tps: {round(tps):,} "
            f"{color.cyan}tflops: {tflops:,.2f} "
            f"{color.magenta}mfu: {mfu:.2f}%{color.reset}"
        )

        if extra_metrics:
            for k, v in extra_metrics.items():
                if "loss" in k:
                    # Use removeprefix, not lstrip: lstrip strips any leading
                    # characters from the given set, not the literal prefix.
                    log_line += (
                        f" {color.white}{k.removeprefix('loss_metrics/')}: {v:7.4f}"
                    )
        logger.info(log_line)

        # Reset the accumulation window for the next logging interval.
        self.ntokens_since_last_log = 0
        self.data_loading_times.clear()
        self.time_last_log = time.perf_counter()
        self.device_memory_monitor.reset_peak_stats()
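
    # Worked MFU example (hypothetical numbers): with num_flops_per_token = 7.0e9,
    # tps = 6_000 tokens/s per device, and a peak of 312e12 FLOPS (A100 bf16),
    #     mfu = 100 * 7.0e9 * 6_000 / 312e12 ≈ 13.5%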

    def close(self):
        self.logger.close()

def build_metrics_processor(
    job_config: JobConfig, parallel_dims: ParallelDims, tag: str | None = None
) -> MetricsProcessor:
    """Create a metrics processor.

    Args:
        job_config (JobConfig): Job configuration.
        parallel_dims (ParallelDims): Parallel dimensions.
        tag (Optional[str]): Tag to use for TensorBoard or WandB. Defaults to None.

    Returns:
        MetricsProcessor: A metrics processor.
    """
    return MetricsProcessor(job_config, parallel_dims, tag)
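
# A minimal training-loop sketch (hypothetical loop variables; the real wiring
# lives in the torchtitan trainer): shows the intended call sequence.
#
#     metrics_processor = build_metrics_processor(job_config, parallel_dims)
#     metrics_processor.num_flops_per_token = num_flops_per_token  # set before log()
#     for step, batch in enumerate(dataloader, start=1):
#         loss = train_step(batch)
#         metrics_processor.ntokens_since_last_log += batch.numel()
#         if metrics_processor.should_log(step):
#             metrics_processor.log(step, loss.item(), loss.item())
#     metrics_processor.close()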