| | """ |
| | Pico Language Model Trainer |
| | |
| | This Trainer implements a minimalistic end-to-end training pipeline of the Pico language model with |
| | distributed training support via Lightning Fabric. It provides a modular and configurable training |
pipeline with the following features:
| | |
| | - Configuration Management: YAML-based configuration for all aspects of training |
| | - Distributed Training: Multi-GPU support via Lightning Fabric |
| | - Checkpointing: Regular model saving and training state recovery |
| | - Evaluation: Periodic model evaluation on validation datasets |
| | - Logging: Comprehensive metric tracking and experiment monitoring |
| | - Optimization: Support for gradient accumulation, clipping, and LR scheduling |
| | """ |
| |
|
| | import logging |
| | import os |
| | import platform |
| | from typing import Any, Dict |
| |
|
| | import lightning as L |
| | import psutil |
| | import torch |
| | import torch.nn.functional as F |
| | import yaml |
| | from datasets import Dataset, load_dataset |
| | from lightning.fabric.utilities.rank_zero import rank_zero_only |
| |
|
| | from src.checkpointing import ( |
| | compute_learning_dynamics_states, |
| | load_checkpoint, |
| | save_checkpoint, |
| | save_evaluation_results, |
| | save_learning_dynamics_states, |
| | ) |
| | from src.evaluation import run_evaluation |
| | from src.training.utils import ( |
| | initialize_configuration, |
| | initialize_dataloader, |
| | initialize_dataset, |
| | initialize_fabric, |
| | initialize_hf_checkpointing, |
| | initialize_logging, |
| | initialize_lr_scheduler, |
| | initialize_model, |
| | initialize_optimizer, |
| | initialize_run_dir, |
| | initialize_tokenizer, |
| | initialize_wandb, |
| | ) |
| | from src.training.utils.logging import pretty_print_yaml_config |
| |
|
| |
|
class Trainer:
    def __init__(self, config_path: str):
        """
        Initializes the Trainer class. This Trainer class implements a `train` method, which is the
        main entry point for training the Pico model. Before calling `train`, the Trainer class
        initializes the following:

        - Configuration loading and validation
        - Model, optimizer, and dataset setup
        - Logging and experiment tracking setup
        - Checkpoint management

        Args:
            config_path (str): Path to the YAML configuration file containing any overrides.
        """

        ########################################################
        #
        # Basic Initialization of Configs, Fabric, Model, Optimizer, etc.
        #
        ########################################################

        # Load and validate all sub-configs (model, training, data, checkpointing,
        # evaluation, monitoring) from the YAML file, applying any overrides.
        self.configs = initialize_configuration(config_path)

        # Create the run directory (where checkpoints, logs, etc. are stored).
        initialize_run_dir(checkpointing_config=self.configs["checkpointing"])

        # Set up Weights & Biases logging only when enabled; the logger is handed
        # to Fabric below so `fabric.log(...)` routes metrics to wandb.
        if self.configs["monitoring"].save_to_wandb:
            wandb_logger = initialize_wandb(
                monitoring_config=self.configs["monitoring"],
                checkpointing_config=self.configs["checkpointing"],
            )
        else:
            wandb_logger = None

        # Set up Lightning Fabric (device placement, precision, distributed strategy).
        self.fabric = initialize_fabric(
            training_config=self.configs["training"],
            wandb_logger=wandb_logger,
        )
        # Fixed seed for reproducibility across all processes.
        L.seed_everything(42, verbose=False)

        # On CUDA devices, allow TF32 matmuls for a speedup on Ampere+ GPUs
        # (trades a small amount of float32 matmul precision for throughput).
        if self.fabric.device.type == "cuda":
            torch.set_float32_matmul_precision(
                "high"
            )
            print(
                "Enabled Tensor Core optimization: torch.set_float32_matmul_precision('high')"
            )

        # File/console logger; used by `self.log` (rank-zero only).
        self.logger = initialize_logging(
            monitoring_config=self.configs["monitoring"],
            checkpointing_config=self.configs["checkpointing"],
            fabric=self.fabric,
        )

        # Model, optimizer, and LR scheduler construction (still un-wrapped).
        self.model = initialize_model(model_config=self.configs["model"])
        self.optimizer = initialize_optimizer(
            training_config=self.configs["training"], model=self.model
        )
        self.lr_scheduler = initialize_lr_scheduler(
            training_config=self.configs["training"], optimizer=self.optimizer
        )

        # Wrap model and optimizer with Fabric (moves to device, applies strategy).
        # NOTE: this must happen before checkpoint loading below so that resumed
        # state is loaded into the wrapped objects.
        self.model, self.optimizer = self.fabric.setup(self.model, self.optimizer)

        # Optionally set up pushing checkpoints to the HuggingFace Hub.
        if self.configs["checkpointing"].save_to_hf:
            initialize_hf_checkpointing(
                checkpointing_config=self.configs["checkpointing"], fabric=self.fabric
            )

        ########################################################
        #
        # Checkpoint resumption: restore model/optimizer/scheduler and the
        # step counter from the latest checkpoint when auto_resume is set.
        #
        ########################################################

        self.should_load_checkpoint = self.configs["checkpointing"].training.auto_resume

        # `initial_batch_step` is the optimizer-step count we resume from
        # (0 for a fresh run or when no checkpoint exists yet).
        if self.should_load_checkpoint:
            resume_checkpoint = load_checkpoint(
                checkpointing_config=self.configs["checkpointing"],
                checkpoint_step="latest",
                fabric=self.fabric,
                model=self.model,
                optimizer=self.optimizer,
                lr_scheduler=self.lr_scheduler,
            )

            if resume_checkpoint:
                (
                    self.model,
                    self.optimizer,
                    self.lr_scheduler,
                    self.initial_batch_step,
                ) = resume_checkpoint
            else:
                self.initial_batch_step = 0
        else:
            self.initial_batch_step = 0

        ########################################################
        #
        # Dataset, DataLoader, and Tokenizer initialization
        #
        ########################################################

        # `fast_forward_steps` tells us how many batch steps the dataloader must
        # be advanced by hand to line up with `initial_batch_step` on resume.
        self.train_dataset, fast_forward_steps = initialize_dataset(
            data_config=self.configs["data"],
            fabric=self.fabric,
            initial_batch_step=self.initial_batch_step,
            return_fast_forward_steps=True,
        )

        self.train_dataloader = initialize_dataloader(
            data_config=self.configs["data"],
            training_config=self.configs["training"],
            fabric=self.fabric,
            dataset=self.train_dataset,
        )
        # Sharding is handled by the dataset/dataloader itself, so Fabric's
        # distributed sampler is explicitly disabled.
        self.train_dataloader = self.fabric.setup_dataloaders(
            self.train_dataloader, use_distributed_sampler=False
        )

        self.tokenizer = initialize_tokenizer(data_config=self.configs["data"])

        # Fast-forward the iterator on resume. One batch step corresponds to
        # `gradient_accumulation_steps` sub-batches, so we skip
        # steps * accumulation sub-batches to land on the right position.
        train_iterator = iter(self.train_dataloader)
        if fast_forward_steps > 0:
            fast_forward_sub_steps = (
                fast_forward_steps
                * self.configs["training"].optimization.gradient_accumulation_steps
            )
            for _ in range(fast_forward_sub_steps):
                next(train_iterator)

        self.train_iterator = train_iterator

        # Synchronize all ranks before entering training.
        self.fabric.barrier()

        ########################################################
        #
        # Helper flags derived from the configs
        #
        ########################################################

        # Evaluate only when at least one evaluation metric is configured.
        self.should_evaluate = (
            self.configs["evaluation"].metrics is not None
            and len(self.configs["evaluation"].metrics) > 0
        )

        # Learning dynamics (activation/gradient states) are tracked only when
        # layer suffixes to hook are configured.
        self.should_compute_learning_dynamics = (
            self.configs["checkpointing"].learning_dynamics.layer_suffixes is not None
            and len(self.configs["checkpointing"].learning_dynamics.layer_suffixes) > 0
        )

        # Optional held-out dataset for validation learning dynamics.
        # NOTE: `learning_dynamics_eval_dataset` is only defined when
        # `should_compute_learning_dynamics` is True; all later accesses are
        # guarded by that flag.
        if self.should_compute_learning_dynamics:
            if self.configs["checkpointing"].learning_dynamics.eval_data is not None:
                self.learning_dynamics_eval_dataset = load_dataset(
                    self.configs["checkpointing"].learning_dynamics.eval_data,
                    split="val",
                )
            else:
                self.learning_dynamics_eval_dataset = None

    def train(self) -> None:
        """Execute the main training pipeline.

        This method orchestrates the complete training process by:
        1. Creating an initial checkpoint to save the starting state and evaluate the model as a
            baseline
        2. Running the main training loop via `_training_loop`
        3. Handling final checkpointing and evaluation

        The training progress is tracked through checkpoints and evaluations
        at intervals specified in the configuration.
        """

        ########################################################
        #
        # Initial checkpoint + baseline evaluation
        #
        ########################################################

        # Always snapshot the starting state (step 0 for fresh runs, or the
        # resumed step) so the run directory contains a usable checkpoint.
        save_checkpoint(
            configs=self.configs,
            checkpoint_step=self.initial_batch_step,
            fabric=self.fabric,
            model=self.model,
            optimizer=self.optimizer,
            lr_scheduler=self.lr_scheduler,
            tokenizer=self.tokenizer,
        )

        # Baseline evaluation: on a fresh run always evaluate; on resume only
        # evaluate if results for this step were not already saved.
        if self.should_evaluate:
            if self.initial_batch_step == 0:
                evaluation_results = run_evaluation(
                    evaluation_config=self.configs["evaluation"],
                    checkpointing_config=self.configs["checkpointing"],
                    fabric=self.fabric,
                    model=self.model,
                )
                self._log_evaluation_results(
                    evaluation_results, self.initial_batch_step
                )
                save_evaluation_results(
                    checkpointing_config=self.configs["checkpointing"],
                    fabric=self.fabric,
                    evaluation_results=evaluation_results,
                    checkpoint_step=self.initial_batch_step,
                )
            else:
                # Resumed run: skip re-evaluation if results already exist on disk.
                eval_results_path = os.path.join(
                    self.configs["checkpointing"].evaluation.eval_results_dir,
                    f"step_{self.initial_batch_step}.json",
                )
                if not os.path.exists(eval_results_path):
                    evaluation_results = run_evaluation(
                        evaluation_config=self.configs["evaluation"],
                        checkpointing_config=self.configs["checkpointing"],
                        fabric=self.fabric,
                        model=self.model,
                    )
                    self._log_evaluation_results(
                        evaluation_results, self.initial_batch_step
                    )
                    save_evaluation_results(
                        checkpointing_config=self.configs["checkpointing"],
                        fabric=self.fabric,
                        evaluation_results=evaluation_results,
                        checkpoint_step=self.initial_batch_step,
                    )

        ########################################################
        #
        # Main training loop
        #
        ########################################################

        # Only enter the loop if there is remaining work; otherwise keep the
        # resumed step as the final step.
        if self.initial_batch_step < self.configs["training"].max_steps:
            self._log_training_configuration()
            final_step = self._training_loop()
        else:
            final_step = self.initial_batch_step

        ########################################################
        #
        # Final learning dynamics, checkpoint, and evaluation
        #
        ########################################################

        # Final validation learning dynamics (only when configured).
        if self.should_compute_learning_dynamics:
            if self.learning_dynamics_eval_dataset is not None:
                self.log(f"Step {final_step} -- π Saving Learning Dynamics")
                learning_dynamics_val_states = compute_learning_dynamics_states(
                    checkpointing_config=self.configs["checkpointing"],
                    fabric=self.fabric,
                    model=self.model,
                    dataset=self.learning_dynamics_eval_dataset,
                    compute_gradients=True,
                )
                save_learning_dynamics_states(
                    checkpointing_config=self.configs["checkpointing"],
                    fabric=self.fabric,
                    learning_dynamics_states=learning_dynamics_val_states,
                    checkpoint_step=final_step,
                    prefix="val",
                )

        # Save a final checkpoint unless the loop already saved one at this step
        # (the loop checkpoints whenever step % save_every_n_steps == 0).
        if final_step % self.configs["checkpointing"].save_every_n_steps != 0:
            self.log(f"Step {final_step} -- πΎ Saving Final Checkpoint")
            save_checkpoint(
                configs=self.configs,
                checkpoint_step=final_step,
                fabric=self.fabric,
                model=self.model,
                optimizer=self.optimizer,
                lr_scheduler=self.lr_scheduler,
                tokenizer=self.tokenizer,
            )

        # Final evaluation of the trained model.
        if self.should_evaluate:
            evaluation_results = run_evaluation(
                evaluation_config=self.configs["evaluation"],
                checkpointing_config=self.configs["checkpointing"],
                fabric=self.fabric,
                model=self.model,
            )
            self._log_evaluation_results(evaluation_results, final_step)
            save_evaluation_results(
                checkpointing_config=self.configs["checkpointing"],
                checkpoint_step=final_step,
                fabric=self.fabric,
                evaluation_results=evaluation_results,
            )

        self.log(f"π Training complete! Final step: {final_step}")

        # Warn when training ended early (e.g. the dataset was exhausted).
        if final_step < self.configs["training"].max_steps:
            self.log(
                f"\t Note: Training stopped before max steps ({self.configs['training'].max_steps})",
                level=logging.WARNING,
            )

        # Teardown: synchronize ranks, free CUDA cache, and shut down the
        # process group when distributed training was initialized.
        self.fabric.barrier()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()

        # Release dataloader resources (e.g. worker processes).
        del self.train_dataloader

        self.fabric.barrier()

    def _training_loop(self) -> int:
        """Execute the main training loop.

        This method orchestrates the core training loop and includes the following features:
        - Gradient accumulation
        - Gradient clipping
        - Periodic model evaluation and checkpointing
        - Learning Dynamics Checkpointing
        - Learning rate scheduling
        - Logging of training metrics including loss and learning rate
        - Handling of infinite/NaN losses

        Returns:
            int: The final step count reached during training.
                NOTE: A complete training run should match the configured max_steps.
        """
        # `batch_step` counts optimizer steps (not sub-batches).
        batch_step = self.initial_batch_step

        # Per-logging-interval accumulators, kept as tensors on the fabric
        # device so they can be all-reduced across ranks when logging.
        interval_loss = torch.tensor(0.0, device=self.fabric.device)
        interval_steps = torch.tensor(0, device=self.fabric.device)
        interval_inf_or_nan_count = torch.tensor(0, device=self.fabric.device)

        if self.should_compute_learning_dynamics:
            # Buffer of (gathered) input ids seen since the last learning
            # dynamics save; converted to a Dataset at checkpoint time.
            training_batch = {"input_ids": []}

        # Offset for `enumerate` so that sub-batch indices stay aligned with
        # gradient-accumulation boundaries after a checkpoint resume.
        initial_sub_batch_step = (
            batch_step
            * self.configs["training"].optimization.gradient_accumulation_steps
        )

        ########################################################
        #
        # Iterate over sub-batches; every
        # `gradient_accumulation_steps` sub-batches make one optimizer step.
        #
        ########################################################

        for sub_batch_step, sub_batch in enumerate(
            self.train_iterator, start=initial_sub_batch_step
        ):
            # Store this batch's inputs only when learning dynamics are enabled
            # and the current step is a checkpointing step.
            should_store_training_batch = self.should_compute_learning_dynamics and (
                batch_step % self.configs["checkpointing"].save_every_n_steps == 0
            )

            ########################################################
            #
            # Forward pass (next-token prediction)
            #
            ########################################################

            # Shift by one token: positions [0, T-1) predict positions [1, T).
            _input_ids = torch.tensor(sub_batch["input_ids"], device=self.fabric.device)
            input_ids = _input_ids[:, :-1]
            labels = _input_ids[:, 1:]

            if should_store_training_batch:
                gathered_input_ids = self.fabric.all_gather(_input_ids)

                # all_gather prepends a world-size dimension in multi-process
                # runs; flatten it so rows are just (batch * world_size, seq).
                if self.fabric.world_size > 1:
                    gathered_input_ids = gathered_input_ids.reshape(
                        -1, *gathered_input_ids.shape[2:]
                    )

                training_batch["input_ids"].extend(gathered_input_ids.tolist())

            # Model returns (logits, ...); transpose logits to put the vocab
            # dimension at index 1, as F.cross_entropy expects (N, C, ...).
            model_output, _ = self.model(input_ids)
            model_output = model_output.transpose(1, 2)

            ########################################################
            #
            # Backward pass with gradient accumulation
            #
            ########################################################

            # Accumulate (i.e. do NOT step) until the last sub-batch of the
            # accumulation window.
            should_accumulate_gradients = (sub_batch_step + 1) % self.configs[
                "training"
            ].optimization.gradient_accumulation_steps != 0

            # Skip cross-rank gradient sync on accumulation-only sub-steps;
            # sync happens only on the final sub-batch before optimizer.step().
            with self.fabric.no_backward_sync(
                self.model, enabled=should_accumulate_gradients
            ):
                loss = F.cross_entropy(model_output, labels)
                # Scale the loss so accumulated gradients average (rather than
                # sum) over the accumulation window.
                self.fabric.backward(
                    loss
                    / self.configs["training"].optimization.gradient_accumulation_steps,
                    model=self.model,
                )

            # Track inf/NaN losses separately so they don't poison the
            # interval average.
            if torch.isnan(loss) or torch.isinf(loss):
                interval_inf_or_nan_count += 1
            else:
                interval_loss += loss.item()
                interval_steps += 1

            # Mid-accumulation: nothing else to do for this sub-batch.
            if should_accumulate_gradients:
                continue

            ########################################################
            #
            # Periodic logging of training metrics
            #
            ########################################################

            if batch_step % self.configs["monitoring"].logging.log_every_n_steps == 0:
                self._log_training_metrics(
                    interval_loss=interval_loss,
                    interval_steps=interval_steps,
                    interval_inf_or_nan_count=interval_inf_or_nan_count,
                    batch_step=batch_step,
                )
                # Reset the interval accumulators after logging.
                interval_loss = torch.tensor(0.0, device=self.fabric.device)
                interval_steps = torch.tensor(0, device=self.fabric.device)
                interval_inf_or_nan_count = torch.tensor(0, device=self.fabric.device)

            ########################################################
            #
            # Learning dynamics checkpointing (pre-optimizer-step, so states
            # reflect the model that produced the stored batch)
            #
            ########################################################

            if batch_step % self.configs["checkpointing"].save_every_n_steps == 0:
                if self.should_compute_learning_dynamics:
                    self.log(f"Step {batch_step} -- π Saving Learning Dynamics")

                    # Convert the buffered inputs into a HF Dataset.
                    training_batch_dataset = Dataset.from_dict(training_batch)

                    learning_dynamics_train_states = compute_learning_dynamics_states(
                        checkpointing_config=self.configs["checkpointing"],
                        fabric=self.fabric,
                        model=self.model,
                        dataset=training_batch_dataset,
                        compute_gradients=True,
                    )

                    save_learning_dynamics_states(
                        checkpointing_config=self.configs["checkpointing"],
                        checkpoint_step=batch_step,
                        prefix="train",
                        fabric=self.fabric,
                        learning_dynamics_states=learning_dynamics_train_states,
                        learning_dynamics_dataset=training_batch_dataset,
                        tokenizer=self.tokenizer,
                    )
                    # Start a fresh buffer for the next interval.
                    training_batch = {
                        "input_ids": []
                    }

                    # Also compute states on the held-out eval data, if any.
                    if self.learning_dynamics_eval_dataset is not None:
                        learning_dynamics_val_states = compute_learning_dynamics_states(
                            checkpointing_config=self.configs["checkpointing"],
                            fabric=self.fabric,
                            model=self.model,
                            dataset=self.learning_dynamics_eval_dataset,
                            compute_gradients=True,
                        )
                        save_learning_dynamics_states(
                            checkpointing_config=self.configs["checkpointing"],
                            checkpoint_step=batch_step,
                            prefix="val",
                            fabric=self.fabric,
                            learning_dynamics_states=learning_dynamics_val_states,
                        )

            ########################################################
            #
            # Optimizer and LR scheduler step
            #
            ########################################################

            self.optimizer.step()
            self.optimizer.zero_grad()
            self.lr_scheduler.step()

            batch_step += 1

            ########################################################
            #
            # Periodic checkpointing and evaluation (post-optimizer-step)
            #
            ########################################################

            if batch_step % self.configs["checkpointing"].save_every_n_steps == 0:
                self.log(f"Step {batch_step} -- πΎ Saving Checkpoint")
                save_checkpoint(
                    configs=self.configs,
                    checkpoint_step=batch_step,
                    fabric=self.fabric,
                    model=self.model,
                    optimizer=self.optimizer,
                    lr_scheduler=self.lr_scheduler,
                    tokenizer=self.tokenizer,
                )

                if self.should_evaluate:
                    evaluation_results = run_evaluation(
                        evaluation_config=self.configs["evaluation"],
                        checkpointing_config=self.configs["checkpointing"],
                        fabric=self.fabric,
                        model=self.model,
                    )
                    # run_evaluation may return None (e.g. on non-zero ranks);
                    # only log/save when results exist.
                    if evaluation_results is not None:
                        self._log_evaluation_results(evaluation_results, batch_step)
                        save_evaluation_results(
                            checkpointing_config=self.configs["checkpointing"],
                            fabric=self.fabric,
                            evaluation_results=evaluation_results,
                            checkpoint_step=batch_step,
                        )

            # Stop once the configured number of optimizer steps is reached.
            if batch_step >= self.configs["training"].max_steps:
                break

        return batch_step

    ########################################################
    #
    # Logging helpers
    #
    ########################################################

    def _log_training_metrics(
        self,
        interval_loss: torch.Tensor,
        interval_steps: torch.Tensor,
        interval_inf_or_nan_count: torch.Tensor,
        batch_step: int,
    ):
        """
        Gathers together the training metrics computed across all processes in distributed training
        and logs them in a tree-style format.
        """
        # Sum-reduce the interval accumulators over all ranks before averaging.
        gathered_interval_loss = self.fabric.all_reduce(
            interval_loss, reduce_op="sum"
        ).item()
        gathered_interval_inf_or_nan_count = self.fabric.all_reduce(
            interval_inf_or_nan_count, reduce_op="sum"
        ).item()
        gathered_interval_steps = self.fabric.all_reduce(
            interval_steps, reduce_op="sum"
        ).item()

        # Average over finite-loss sub-steps only; inf when every sub-step in
        # the interval produced a non-finite loss.
        avg_loss = (
            gathered_interval_loss / gathered_interval_steps
            if gathered_interval_steps > 0
            else float("inf")
        )

        self.fabric.log("train/loss", avg_loss, step=batch_step)
        self.fabric.log(
            "trainer/inf_or_nan_count",
            gathered_interval_inf_or_nan_count,
            step=batch_step,
        )
        self.fabric.log(
            "trainer/learning_rate",
            self.lr_scheduler.get_last_lr()[0],
            step=batch_step,
        )

        # Console summary (rank zero only, via self.log).
        self.log(f"Step {batch_step} -- π Training Metrics")
        self.log(f"βββ Loss: {avg_loss:.4f}")
        self.log(f"βββ Learning Rate: {self.lr_scheduler.get_last_lr()[0]:.2e}")
        self.log(f"βββ Inf/NaN count: {gathered_interval_inf_or_nan_count}")

    def _log_evaluation_results(
        self, evaluation_results: Dict[str, Any], batch_step: int
    ):
        """Log model evaluation metrics to experiment tracking system and console."""
        self.log(f"Step {batch_step} -- π Evaluation Results")
        for i, (metric, result) in enumerate(evaluation_results.items()):
            # Last entry gets the closing tree branch prefix.
            prefix = "βββ" if i == len(evaluation_results) - 1 else "βββ"
            self.log(f"{prefix} {metric}: {result}")
            self.fabric.log(f"eval/{metric}", result, step=batch_step)

    def _log_training_configuration(self):
        """
        Log training configuration details as well as runtime information about the hardware,
        software, and batch settings.

        This function is called at the beginning of the training loop to provide a summary of the
        training configuration.
        """
        # Parameter counts of the (Fabric-wrapped) model.
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad
        )
        global_batch_size = self.configs["data"].dataloader.batch_size
        per_device_batch_size = self.train_dataloader.batch_size
        gradient_accumulation_steps = self.configs[
            "training"
        ].optimization.gradient_accumulation_steps

        # Human-readable device description (CUDA name / MPS / CPU fallback).
        device_type = ""
        fabric_device = str(self.fabric.device)
        if torch.cuda.is_available() and "cuda" in fabric_device:
            device_type = torch.cuda.get_device_name(self.fabric.device)
        elif torch.backends.mps.is_available() and "mps" in fabric_device:
            device_type = "MPS (Apple Silicon)"
        else:
            device_type = "CPU"

        # Pretty-print the saved training config if it exists in the run dir.
        training_config_path = os.path.join(
            self.configs["checkpointing"].runs_dir,
            self.configs["checkpointing"].run_name,
            "training_config.yaml",
        )
        if os.path.exists(training_config_path):
            self.log("=" * 50)
            self.log("β¨ Training Configuration")
            self.log("=" * 50)
            training_config = yaml.safe_load(open(training_config_path, "r"))
            pretty_print_yaml_config(self.logger, training_config)

        self.log("=" * 50)
        self.log("β Runtime Summary:")
        self.log("=" * 50)
        self.log(f"Starting from step: {self.initial_batch_step}")

        self.log("Model Setup:")
        self.log(f"ββ Total Parameters: {total_params:,}")
        self.log(f"ββ Trainable Parameters: {trainable_params:,}")

        self.log("Distributed Setup:")
        self.log(f"ββ Number of Devices: {self.fabric.world_size}")
        self.log(f"ββ Device Type: {device_type}")
        self.log(
            f"ββ Available Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB"
            if torch.cuda.is_available()
            else f"ββ Available Memory: {psutil.virtual_memory().total / 1e9:.2f} GB"
        )

        self.log("Software Setup:")
        self.log(f"ββ Python Version: {platform.python_version()}")
        self.log(f"ββ PyTorch Version: {torch.__version__}")
        self.log(
            f"ββ CUDA Version: {torch.version.cuda if torch.cuda.is_available() else 'N/A'}"
        )
        self.log(f"ββ Operating System: {platform.system()} {platform.release()}")

        self.log("Batch Size Configuration:")
        self.log(f"ββ Global Batch Size: {global_batch_size}")
        self.log(f"ββ Per Device Batch Size: {per_device_batch_size}")
        self.log(f"ββ Gradient Accumulation Steps: {gradient_accumulation_steps}")
        self.log("=" * 50)

    @rank_zero_only
    def log(self, msg: str, level: int = logging.INFO) -> None:
        """NOTE: Log messages only from rank zero process."""
        self.logger.log(level, msg)
| |
|