|
|
""" |
|
|
Pico Language Model Trainer |
|
|
|
|
|
This Trainer implements a minimalistic end-to-end training pipeline of the Pico language model with |
|
|
distributed training support via Lightning Fabric. It provides a modular and configurable training |
|
|
pipeline with the following features:
|
|
|
|
|
- Configuration Management: YAML-based configuration for all aspects of training |
|
|
- Distributed Training: Multi-GPU support via Lightning Fabric |
|
|
- Checkpointing: Regular model saving and training state recovery |
|
|
- Evaluation: Periodic model evaluation on validation datasets |
|
|
- Logging: Comprehensive metric tracking and experiment monitoring |
|
|
- Optimization: Support for gradient accumulation, clipping, and LR scheduling |
|
|
""" |
|
|
|
|
|
import logging |
|
|
import os |
|
|
import platform |
|
|
from typing import Any, Dict |
|
|
|
|
|
import lightning as L |
|
|
import psutil |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import yaml |
|
|
from datasets import Dataset, load_dataset |
|
|
from lightning.fabric.utilities.rank_zero import rank_zero_only |
|
|
|
|
|
from src.checkpointing import ( |
|
|
compute_learning_dynamics_states, |
|
|
load_checkpoint, |
|
|
save_checkpoint, |
|
|
save_evaluation_results, |
|
|
save_learning_dynamics_states, |
|
|
) |
|
|
from src.evaluation import run_evaluation |
|
|
from src.training.utils import ( |
|
|
initialize_configuration, |
|
|
initialize_dataloader, |
|
|
initialize_dataset, |
|
|
initialize_fabric, |
|
|
initialize_hf_checkpointing, |
|
|
initialize_logging, |
|
|
initialize_lr_scheduler, |
|
|
initialize_model, |
|
|
initialize_optimizer, |
|
|
initialize_run_dir, |
|
|
initialize_tokenizer, |
|
|
initialize_wandb, |
|
|
) |
|
|
from src.training.utils.logging import pretty_print_yaml_config |
|
|
|
|
|
|
|
|
class Trainer: |
|
|
    def __init__(self, config_path: str):
        """
        Initializes the Trainer class. This Trainer class implements a `train` method, which is the
        main entry point for training the Pico model. Before calling `train`, the Trainer class
        initializes the following:

        - Configuration loading and validation
        - Model, optimizer, and dataset setup
        - Logging and experiment tracking setup
        - Checkpoint management

        Args:
            config_path (str): Path to the YAML configuration file containing any overrides.
        """

        # Load all sub-configs (training, model, data, checkpointing, monitoring,
        # evaluation) keyed by name, applying file overrides on top of defaults.
        self.configs = initialize_configuration(config_path)

        # Create the run directory where checkpoints/results for this run live.
        initialize_run_dir(checkpointing_config=self.configs["checkpointing"])

        # Weights & Biases tracking is optional; the resulting logger (or None)
        # is handed to Fabric below.
        if self.configs["monitoring"].save_to_wandb:
            wandb_logger = initialize_wandb(
                monitoring_config=self.configs["monitoring"],
                checkpointing_config=self.configs["checkpointing"],
            )
        else:
            wandb_logger = None

        # Lightning Fabric handles device placement and distributed setup.
        self.fabric = initialize_fabric(
            training_config=self.configs["training"],
            wandb_logger=wandb_logger,
        )
        # Fixed seed for reproducibility across processes.
        L.seed_everything(42, verbose=False)

        # On CUDA devices, allow TF32 matmuls for throughput.
        # NOTE: uses print() because self.logger is not initialized yet.
        if self.fabric.device.type == "cuda":
            torch.set_float32_matmul_precision(
                "high"
            )
            print(
                "Enabled Tensor Core optimization: torch.set_float32_matmul_precision('high')"
            )

        # File/console logger used by self.log throughout training.
        self.logger = initialize_logging(
            monitoring_config=self.configs["monitoring"],
            checkpointing_config=self.configs["checkpointing"],
            fabric=self.fabric,
        )

        # Core training objects: model, optimizer, LR scheduler.
        self.model = initialize_model(model_config=self.configs["model"])
        self.optimizer = initialize_optimizer(
            training_config=self.configs["training"], model=self.model
        )
        self.lr_scheduler = initialize_lr_scheduler(
            training_config=self.configs["training"], optimizer=self.optimizer
        )

        # Wrap model and optimizer for distributed execution / precision handling.
        self.model, self.optimizer = self.fabric.setup(self.model, self.optimizer)

        # Optionally mirror checkpoints to the HuggingFace Hub.
        if self.configs["checkpointing"].save_to_hf:
            initialize_hf_checkpointing(
                checkpointing_config=self.configs["checkpointing"], fabric=self.fabric
            )

        # Auto-resume: try to restore the latest checkpoint (model/optimizer/
        # scheduler states plus the batch step to resume from).
        self.should_load_checkpoint = self.configs["checkpointing"].training.auto_resume

        if self.should_load_checkpoint:
            resume_checkpoint = load_checkpoint(
                checkpointing_config=self.configs["checkpointing"],
                checkpoint_step="latest",
                fabric=self.fabric,
                model=self.model,
                optimizer=self.optimizer,
                lr_scheduler=self.lr_scheduler,
            )

            if resume_checkpoint:
                (
                    self.model,
                    self.optimizer,
                    self.lr_scheduler,
                    self.initial_batch_step,
                ) = resume_checkpoint
            else:
                # No checkpoint found on disk: start from scratch.
                self.initial_batch_step = 0
        else:
            self.initial_batch_step = 0

        # Dataset setup; fast_forward_steps tells us how many batch steps of data
        # to skip so a resumed run continues at the right position in the stream.
        self.train_dataset, fast_forward_steps = initialize_dataset(
            data_config=self.configs["data"],
            fabric=self.fabric,
            initial_batch_step=self.initial_batch_step,
            return_fast_forward_steps=True,
        )

        self.train_dataloader = initialize_dataloader(
            data_config=self.configs["data"],
            training_config=self.configs["training"],
            fabric=self.fabric,
            dataset=self.train_dataset,
        )
        # use_distributed_sampler=False: sharding is presumably handled by the
        # dataset/dataloader initialization itself — TODO confirm in utils.
        self.train_dataloader = self.fabric.setup_dataloaders(
            self.train_dataloader, use_distributed_sampler=False
        )

        self.tokenizer = initialize_tokenizer(data_config=self.configs["data"])

        # Fast-forward the iterator past already-consumed sub-batches: each batch
        # step corresponds to gradient_accumulation_steps sub-batches.
        train_iterator = iter(self.train_dataloader)
        if fast_forward_steps > 0:
            fast_forward_sub_steps = (
                fast_forward_steps
                * self.configs["training"].optimization.gradient_accumulation_steps
            )
            for _ in range(fast_forward_sub_steps):
                next(train_iterator)

        self.train_iterator = train_iterator

        # Synchronize all ranks before training starts.
        self.fabric.barrier()

        # Evaluation runs only when at least one metric is configured.
        self.should_evaluate = (
            self.configs["evaluation"].metrics is not None
            and len(self.configs["evaluation"].metrics) > 0
        )

        # Learning-dynamics checkpointing runs only when layer suffixes are set.
        self.should_compute_learning_dynamics = (
            self.configs["checkpointing"].learning_dynamics.layer_suffixes is not None
            and len(self.configs["checkpointing"].learning_dynamics.layer_suffixes) > 0
        )

        # NOTE: learning_dynamics_eval_dataset only exists when learning dynamics
        # are enabled; all readers gate on should_compute_learning_dynamics first.
        if self.should_compute_learning_dynamics:
            if self.configs["checkpointing"].learning_dynamics.eval_data is not None:
                self.learning_dynamics_eval_dataset = load_dataset(
                    self.configs["checkpointing"].learning_dynamics.eval_data,
                    split="val",
                )
            else:
                self.learning_dynamics_eval_dataset = None
|
|
|
|
|
    def train(self) -> None:
        """Execute the main training pipeline.

        This method orchestrates the complete training process by:
        1. Creating an initial checkpoint to save the starting state and evaluate the model as a
            baseline
        2. Running the main training loop via `_training_loop`
        3. Handling final checkpointing and evaluation

        The training progress is tracked through checkpoints and evaluations
        at intervals specified in the configuration.
        """

        # Always checkpoint the starting state (step 0 for fresh runs, the
        # resumed step otherwise).
        save_checkpoint(
            configs=self.configs,
            checkpoint_step=self.initial_batch_step,
            fabric=self.fabric,
            model=self.model,
            optimizer=self.optimizer,
            lr_scheduler=self.lr_scheduler,
            tokenizer=self.tokenizer,
        )

        # Baseline evaluation before (further) training.
        if self.should_evaluate:
            if self.initial_batch_step == 0:
                evaluation_results = run_evaluation(
                    evaluation_config=self.configs["evaluation"],
                    checkpointing_config=self.configs["checkpointing"],
                    fabric=self.fabric,
                    model=self.model,
                )
                self._log_evaluation_results(
                    evaluation_results, self.initial_batch_step
                )
                save_evaluation_results(
                    checkpointing_config=self.configs["checkpointing"],
                    fabric=self.fabric,
                    evaluation_results=evaluation_results,
                    checkpoint_step=self.initial_batch_step,
                )
            else:
                # Resumed run: only evaluate if results for this step were not
                # already written by the run we are resuming.
                eval_results_path = os.path.join(
                    self.configs["checkpointing"].evaluation.eval_results_dir,
                    f"step_{self.initial_batch_step}.json",
                )
                if not os.path.exists(eval_results_path):
                    evaluation_results = run_evaluation(
                        evaluation_config=self.configs["evaluation"],
                        checkpointing_config=self.configs["checkpointing"],
                        fabric=self.fabric,
                        model=self.model,
                    )
                    self._log_evaluation_results(
                        evaluation_results, self.initial_batch_step
                    )
                    save_evaluation_results(
                        checkpointing_config=self.configs["checkpointing"],
                        fabric=self.fabric,
                        evaluation_results=evaluation_results,
                        checkpoint_step=self.initial_batch_step,
                    )

        # Run the training loop unless the resumed step already reached max_steps.
        if self.initial_batch_step < self.configs["training"].max_steps:
            self._log_training_configuration()
            final_step = self._training_loop()
        else:
            final_step = self.initial_batch_step

        # Final learning-dynamics states on the validation split, if configured.
        if self.should_compute_learning_dynamics:
            if self.learning_dynamics_eval_dataset is not None:
                self.log(f"Step {final_step} -- π Saving Learning Dynamics")
                learning_dynamics_val_states = compute_learning_dynamics_states(
                    checkpointing_config=self.configs["checkpointing"],
                    fabric=self.fabric,
                    model=self.model,
                    dataset=self.learning_dynamics_eval_dataset,
                    compute_gradients=True,
                )
                save_learning_dynamics_states(
                    checkpointing_config=self.configs["checkpointing"],
                    fabric=self.fabric,
                    learning_dynamics_states=learning_dynamics_val_states,
                    checkpoint_step=final_step,
                    prefix="val",
                )

        # Save a final checkpoint unless the loop already saved one at this step.
        if final_step % self.configs["checkpointing"].save_every_n_steps != 0:
            self.log(f"Step {final_step} -- πΎ Saving Final Checkpoint")
            save_checkpoint(
                configs=self.configs,
                checkpoint_step=final_step,
                fabric=self.fabric,
                model=self.model,
                optimizer=self.optimizer,
                lr_scheduler=self.lr_scheduler,
                tokenizer=self.tokenizer,
            )

        # Final evaluation.
        # NOTE(review): unlike the checkpoint above, this is not guarded against
        # having already run at final_step inside the loop — confirm intended.
        if self.should_evaluate:
            evaluation_results = run_evaluation(
                evaluation_config=self.configs["evaluation"],
                checkpointing_config=self.configs["checkpointing"],
                fabric=self.fabric,
                model=self.model,
            )
            self._log_evaluation_results(evaluation_results, final_step)
            save_evaluation_results(
                checkpointing_config=self.configs["checkpointing"],
                checkpoint_step=final_step,
                fabric=self.fabric,
                evaluation_results=evaluation_results,
            )

        self.log(f"π Training complete! Final step: {final_step}")

        # Warn when the loop exited before reaching the configured step budget.
        if final_step < self.configs["training"].max_steps:
            self.log(
                f"\t Note: Training stopped before max steps ({self.configs['training'].max_steps})",
                level=logging.WARNING,
            )

        # Cleanup: sync ranks, release cached GPU memory, and tear down the
        # distributed process group.
        self.fabric.barrier()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()

        # Drop dataloader references (and any worker processes it holds).
        del self.train_dataloader

        # NOTE(review): this barrier runs after destroy_process_group above —
        # confirm Fabric handles this gracefully in multi-process runs.
        self.fabric.barrier()
|
|
|
|
|
    def _training_loop(self) -> int:
        """Execute the main training loop.

        This method orchestrates the core training loop and includes the following features:
        - Gradient accumulation
        - Gradient clipping
        - Periodic model evaluation and checkpointing
        - Learning Dynamics Checkpointing
        - Learning rate scheduling
        - Logging of training metrics including loss and learning rate
        - Handling of infinite/NaN losses

        Returns:
            int: The final step count reached during training.
            NOTE: A complete training run should match the configured max_steps.
        """

        # Resume counting from wherever the (possibly restored) run left off.
        batch_step = self.initial_batch_step

        # Per-logging-interval accumulators; kept as tensors on the local device
        # so they can be all-reduced across ranks in _log_training_metrics.
        interval_loss = torch.tensor(0.0, device=self.fabric.device)
        interval_steps = torch.tensor(0, device=self.fabric.device)
        interval_inf_or_nan_count = torch.tensor(0, device=self.fabric.device)

        if self.should_compute_learning_dynamics:
            # Buffer of token ids seen at checkpoint steps, used to compute
            # learning-dynamics states over the exact training data.
            training_batch = {"input_ids": []}

        # Each batch step spans gradient_accumulation_steps sub-batches, so the
        # sub-batch counter starts at the matching offset.
        initial_sub_batch_step = (
            batch_step
            * self.configs["training"].optimization.gradient_accumulation_steps
        )

        for sub_batch_step, sub_batch in enumerate(
            self.train_iterator, start=initial_sub_batch_step
        ):
            # Only capture training data on steps where a checkpoint will be saved.
            should_store_training_batch = self.should_compute_learning_dynamics and (
                batch_step % self.configs["checkpointing"].save_every_n_steps == 0
            )

            # Shifted next-token prediction: inputs are tokens [0..n-1],
            # labels are tokens [1..n].
            _input_ids = torch.tensor(sub_batch["input_ids"], device=self.fabric.device)
            input_ids = _input_ids[:, :-1]
            labels = _input_ids[:, 1:]

            if should_store_training_batch:
                gathered_input_ids = self.fabric.all_gather(_input_ids)

                # all_gather prepends a world-size dimension in distributed runs;
                # flatten it back into the batch dimension.
                if self.fabric.world_size > 1:
                    gathered_input_ids = gathered_input_ids.reshape(
                        -1, *gathered_input_ids.shape[2:]
                    )

                training_batch["input_ids"].extend(gathered_input_ids.tolist())

            # Forward pass; logits are transposed to put the vocab dimension
            # second, as F.cross_entropy expects.
            model_output, _ = self.model(input_ids)
            model_output = model_output.transpose(1, 2)

            # Keep accumulating until a full batch step's worth of sub-batches
            # has been processed.
            should_accumulate_gradients = (sub_batch_step + 1) % self.configs[
                "training"
            ].optimization.gradient_accumulation_steps != 0

            # Skip cross-rank gradient sync on accumulation-only sub-steps to
            # avoid redundant communication.
            with self.fabric.no_backward_sync(
                self.model, enabled=should_accumulate_gradients
            ):
                loss = F.cross_entropy(model_output, labels)
                self.fabric.backward(
                    loss
                    / self.configs["training"].optimization.gradient_accumulation_steps,
                    model=self.model,
                )

            # Track inf/NaN losses separately so they don't poison the average.
            if torch.isnan(loss) or torch.isinf(loss):
                interval_inf_or_nan_count += 1
            else:
                interval_loss += loss.item()
                interval_steps += 1

            # Mid-accumulation: fetch the next sub-batch without stepping.
            if should_accumulate_gradients:
                continue

            # Periodic metric logging (before this step's optimizer update).
            if batch_step % self.configs["monitoring"].logging.log_every_n_steps == 0:
                self._log_training_metrics(
                    interval_loss=interval_loss,
                    interval_steps=interval_steps,
                    interval_inf_or_nan_count=interval_inf_or_nan_count,
                    batch_step=batch_step,
                )
                # Reset interval accumulators after each report.
                interval_loss = torch.tensor(0.0, device=self.fabric.device)
                interval_steps = torch.tensor(0, device=self.fabric.device)
                interval_inf_or_nan_count = torch.tensor(0, device=self.fabric.device)

            # Learning-dynamics checkpointing, computed from the pre-update model
            # (the optimizer step for this batch happens below).
            if batch_step % self.configs["checkpointing"].save_every_n_steps == 0:
                if self.should_compute_learning_dynamics:
                    self.log(f"Step {batch_step} -- π Saving Learning Dynamics")

                    # States over the training tokens gathered this step.
                    training_batch_dataset = Dataset.from_dict(training_batch)

                    learning_dynamics_train_states = compute_learning_dynamics_states(
                        checkpointing_config=self.configs["checkpointing"],
                        fabric=self.fabric,
                        model=self.model,
                        dataset=training_batch_dataset,
                        compute_gradients=True,
                    )

                    save_learning_dynamics_states(
                        checkpointing_config=self.configs["checkpointing"],
                        checkpoint_step=batch_step,
                        prefix="train",
                        fabric=self.fabric,
                        learning_dynamics_states=learning_dynamics_train_states,
                        learning_dynamics_dataset=training_batch_dataset,
                        tokenizer=self.tokenizer,
                    )
                    # Start a fresh buffer for the next checkpoint window.
                    training_batch = {
                        "input_ids": []
                    }

                    # States over the held-out learning-dynamics eval split.
                    if self.learning_dynamics_eval_dataset is not None:
                        learning_dynamics_val_states = compute_learning_dynamics_states(
                            checkpointing_config=self.configs["checkpointing"],
                            fabric=self.fabric,
                            model=self.model,
                            dataset=self.learning_dynamics_eval_dataset,
                            compute_gradients=True,
                        )
                        save_learning_dynamics_states(
                            checkpointing_config=self.configs["checkpointing"],
                            checkpoint_step=batch_step,
                            prefix="val",
                            fabric=self.fabric,
                            learning_dynamics_states=learning_dynamics_val_states,
                        )

            # Optimizer update for the completed batch step.
            # NOTE(review): no explicit gradient clipping is visible here despite
            # the docstring mentioning it — confirm whether clipping is applied
            # elsewhere (e.g. inside the fabric/optimizer setup).
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.lr_scheduler.step()

            batch_step += 1

            # Checkpoint + evaluation on the post-update model.
            if batch_step % self.configs["checkpointing"].save_every_n_steps == 0:
                self.log(f"Step {batch_step} -- πΎ Saving Checkpoint")
                save_checkpoint(
                    configs=self.configs,
                    checkpoint_step=batch_step,
                    fabric=self.fabric,
                    model=self.model,
                    optimizer=self.optimizer,
                    lr_scheduler=self.lr_scheduler,
                    tokenizer=self.tokenizer,
                )

                if self.should_evaluate:
                    evaluation_results = run_evaluation(
                        evaluation_config=self.configs["evaluation"],
                        checkpointing_config=self.configs["checkpointing"],
                        fabric=self.fabric,
                        model=self.model,
                    )
                    if evaluation_results is not None:
                        self._log_evaluation_results(evaluation_results, batch_step)
                        save_evaluation_results(
                            checkpointing_config=self.configs["checkpointing"],
                            fabric=self.fabric,
                            evaluation_results=evaluation_results,
                            checkpoint_step=batch_step,
                        )

            # Stop once the configured number of steps has been reached.
            if batch_step >= self.configs["training"].max_steps:
                break

        return batch_step
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _log_training_metrics( |
|
|
self, |
|
|
interval_loss: torch.Tensor, |
|
|
interval_steps: torch.Tensor, |
|
|
interval_inf_or_nan_count: torch.Tensor, |
|
|
batch_step: int, |
|
|
): |
|
|
""" |
|
|
Gathers together the training metrics computed across all processes in distributed training |
|
|
and logs them in a tree-style format. |
|
|
""" |
|
|
gathered_interval_loss = self.fabric.all_reduce( |
|
|
interval_loss, reduce_op="sum" |
|
|
).item() |
|
|
gathered_interval_inf_or_nan_count = self.fabric.all_reduce( |
|
|
interval_inf_or_nan_count, reduce_op="sum" |
|
|
).item() |
|
|
gathered_interval_steps = self.fabric.all_reduce( |
|
|
interval_steps, reduce_op="sum" |
|
|
).item() |
|
|
|
|
|
avg_loss = ( |
|
|
gathered_interval_loss / gathered_interval_steps |
|
|
if gathered_interval_steps > 0 |
|
|
else float("inf") |
|
|
) |
|
|
|
|
|
self.fabric.log("train/loss", avg_loss, step=batch_step) |
|
|
self.fabric.log( |
|
|
"trainer/inf_or_nan_count", |
|
|
gathered_interval_inf_or_nan_count, |
|
|
step=batch_step, |
|
|
) |
|
|
self.fabric.log( |
|
|
"trainer/learning_rate", |
|
|
self.lr_scheduler.get_last_lr()[0], |
|
|
step=batch_step, |
|
|
) |
|
|
|
|
|
|
|
|
self.log(f"Step {batch_step} -- π Training Metrics") |
|
|
self.log(f"βββ Loss: {avg_loss:.4f}") |
|
|
self.log(f"βββ Learning Rate: {self.lr_scheduler.get_last_lr()[0]:.2e}") |
|
|
self.log(f"βββ Inf/NaN count: {gathered_interval_inf_or_nan_count}") |
|
|
|
|
|
def _log_evaluation_results( |
|
|
self, evaluation_results: Dict[str, Any], batch_step: int |
|
|
): |
|
|
"""Log model evaluation metrics to experiment tracking system and console.""" |
|
|
self.log(f"Step {batch_step} -- π Evaluation Results") |
|
|
for i, (metric, result) in enumerate(evaluation_results.items()): |
|
|
prefix = "βββ" if i == len(evaluation_results) - 1 else "βββ" |
|
|
self.log(f"{prefix} {metric}: {result}") |
|
|
self.fabric.log(f"eval/{metric}", result, step=batch_step) |
|
|
|
|
|
def _log_training_configuration(self): |
|
|
""" |
|
|
Log training configuration details as well as runtime information about the hardware, |
|
|
software, and batch settings. |
|
|
|
|
|
This function is called at the beginning of the training loop to provide a summary of the |
|
|
training configuration. |
|
|
""" |
|
|
|
|
|
total_params = sum(p.numel() for p in self.model.parameters()) |
|
|
trainable_params = sum( |
|
|
p.numel() for p in self.model.parameters() if p.requires_grad |
|
|
) |
|
|
global_batch_size = self.configs["data"].dataloader.batch_size |
|
|
per_device_batch_size = self.train_dataloader.batch_size |
|
|
gradient_accumulation_steps = self.configs[ |
|
|
"training" |
|
|
].optimization.gradient_accumulation_steps |
|
|
|
|
|
device_type = "" |
|
|
fabric_device = str(self.fabric.device) |
|
|
if torch.cuda.is_available() and "cuda" in fabric_device: |
|
|
device_type = torch.cuda.get_device_name(self.fabric.device) |
|
|
elif torch.backends.mps.is_available() and "mps" in fabric_device: |
|
|
device_type = "MPS (Apple Silicon)" |
|
|
else: |
|
|
device_type = "CPU" |
|
|
|
|
|
training_config_path = os.path.join( |
|
|
self.configs["checkpointing"].runs_dir, |
|
|
self.configs["checkpointing"].run_name, |
|
|
"training_config.yaml", |
|
|
) |
|
|
if os.path.exists(training_config_path): |
|
|
self.log("=" * 50) |
|
|
self.log("β¨ Training Configuration") |
|
|
self.log("=" * 50) |
|
|
training_config = yaml.safe_load(open(training_config_path, "r")) |
|
|
pretty_print_yaml_config(self.logger, training_config) |
|
|
|
|
|
self.log("=" * 50) |
|
|
self.log("β Runtime Summary:") |
|
|
self.log("=" * 50) |
|
|
self.log(f"Starting from step: {self.initial_batch_step}") |
|
|
|
|
|
self.log("Model Setup:") |
|
|
self.log(f"ββ Total Parameters: {total_params:,}") |
|
|
self.log(f"ββ Trainable Parameters: {trainable_params:,}") |
|
|
|
|
|
self.log("Distributed Setup:") |
|
|
self.log(f"ββ Number of Devices: {self.fabric.world_size}") |
|
|
self.log(f"ββ Device Type: {device_type}") |
|
|
self.log( |
|
|
f"ββ Available Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB" |
|
|
if torch.cuda.is_available() |
|
|
else f"ββ Available Memory: {psutil.virtual_memory().total / 1e9:.2f} GB" |
|
|
) |
|
|
|
|
|
self.log("Software Setup:") |
|
|
self.log(f"ββ Python Version: {platform.python_version()}") |
|
|
self.log(f"ββ PyTorch Version: {torch.__version__}") |
|
|
self.log( |
|
|
f"ββ CUDA Version: {torch.version.cuda if torch.cuda.is_available() else 'N/A'}" |
|
|
) |
|
|
self.log(f"ββ Operating System: {platform.system()} {platform.release()}") |
|
|
|
|
|
self.log("Batch Size Configuration:") |
|
|
self.log(f"ββ Global Batch Size: {global_batch_size}") |
|
|
self.log(f"ββ Per Device Batch Size: {per_device_batch_size}") |
|
|
self.log(f"ββ Gradient Accumulation Steps: {gradient_accumulation_steps}") |
|
|
self.log("=" * 50) |
|
|
|
|
|
    @rank_zero_only
    def log(self, msg: str, level: int = logging.INFO) -> None:
        """Emit *msg* to the run logger at *level* (defaults to INFO).

        NOTE: Log messages only from rank zero process — the `rank_zero_only`
        decorator turns this call into a no-op on every other rank.
        """
        self.logger.log(level, msg)
|
|
|