"""
Utilities for checkpointing training-related states (i.e. model, optimizer, lr_scheduler, etc.)

We save both a HuggingFace model and a Fabric-specific checkpoint. The HuggingFace model is
saved at the step-specific checkpoint directory, while the Fabric-specific checkpoint is saved
in a subdirectory. This is done to facilitate easier versioning of the HuggingFace model files
(which are what gets uploaded to the Hub).
"""

import os
from dataclasses import asdict
from typing import Any, Dict, Optional, Tuple, Union

import yaml
from huggingface_hub import upload_file, upload_folder
from lightning.fabric import Fabric
from lightning.fabric.strategies import DeepSpeedStrategy
from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from transformers import PreTrainedTokenizerBase

from src.config import CheckpointingConfig
from src.training.utils.io import use_backoff


@use_backoff()
def load_checkpoint(
    checkpointing_config: CheckpointingConfig,
    checkpoint_step: Union[str, int],
    fabric: Fabric,
    model: nn.Module,
    optimizer: Optimizer,
    lr_scheduler: LRScheduler,
) -> Optional[Tuple[nn.Module, Optimizer, LRScheduler, int]]:
    """Load model checkpoint and associated states from a given step.

    Args:
        checkpointing_config: Configuration object containing checkpoint settings
        checkpoint_step: The step at which to load the checkpoint; an int is resolved
            to the corresponding "step_{checkpoint_step}" directory, while a string
            (e.g. "latest") names a checkpoint directory directly
        fabric: Lightning Fabric instance for distributed training support
        model: The model instance to load weights into
        optimizer: The optimizer instance to load states into
        lr_scheduler: The learning rate scheduler to load states into

    Returns:
        Tuple containing the model, optimizer, lr_scheduler, and checkpoint step.
        Returns None if no checkpoint is found.
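
    Example (illustrative sketch; assumes the surrounding training script has
    already constructed these objects):

        resumed = load_checkpoint(
            checkpointing_config, "latest", fabric, model, optimizer, lr_scheduler
        )
        if resumed is not None:
            model, optimizer, lr_scheduler, start_step = resumed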
    """

    if isinstance(checkpoint_step, int):
        checkpoint_step = f"step_{checkpoint_step}"

    checkpoint_path = os.path.join(
        checkpointing_config.runs_dir,
        checkpointing_config.run_name,
        checkpointing_config.checkpoints_dir,
        checkpoint_step,
    )

    if not os.path.exists(checkpoint_path):
        return None

    fabric_checkpoint_path = os.path.join(
        checkpoint_path, checkpointing_config.fabric_checkpoint_dir
    )
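
    # fabric.load restores weights and optimizer/scheduler state in-place into
    # the objects below; entries it does not recognize are returned as a dict
    # of extra state.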
    checkpoint_state = {
        "_model": model,
        "_optimizer": optimizer,
        "_lr_scheduler": lr_scheduler,
    }

    if not isinstance(fabric.strategy, DeepSpeedStrategy):
        # Non-DeepSpeed strategies store the checkpoint as a single file
        fabric_load_file = os.path.join(
            fabric_checkpoint_path, checkpointing_config.fabric_checkpoint_filename
        )
    else:
        # DeepSpeed checkpoints are directories of sharded files
        fabric_load_file = fabric_checkpoint_path

    extra_state = fabric.load(fabric_load_file, state=checkpoint_state)

    checkpoint_step = extra_state["_checkpoint_step"]

    # Restore RNG states (saved only for non-DeepSpeed strategies) so that data
    # ordering and dropout remain reproducible after resuming
    if "_rng_states" in extra_state:
        _rng_states = extra_state["_rng_states"]
        _set_rng_states(_rng_states)

    return model, optimizer, lr_scheduler, checkpoint_step


@use_backoff()
def save_checkpoint(
    configs: Dict[str, Any],
    checkpoint_step: int,
    fabric: Fabric,
    model: nn.Module,
    optimizer: Optimizer,
    lr_scheduler: LRScheduler,
    tokenizer: PreTrainedTokenizerBase,
    upload_logs: bool = False,
) -> None:
"""Save training checkpoint and associated states to disk and optionally to HuggingFace Hub. |
|
|
|
|
|
We save the following files: |
|
|
- HuggingFace model files (config.json, pytorch_model.bin) |
|
|
- Tokenizer files (vocab.json, merges.txt) |
|
|
- Fabric-specific files - fabric state of the model, optimizer, and lr_scheduler. If using |
|
|
DeepSpeed, the checkpoint is saved in a subdirectory, otherwise it is saved in a single file. |
|
|
|
|
|
Note that the HuggingFace model files are saved at the step-specific checkpoint directory, while the |
|
|
Fabric-specific files are saved in a subdirectory. This is done to facilitate easier |
|
|
versioning of the HuggingFace model files (which are what gets uploaded to the Hub). |
|
|
|
|
|
NOTE: Why do we save a HF model at all? We do this because it makes it easier to load the model |
|
|
in a separate script for evaluation and to play nicely with the HuggingFace Hub. |
|
|
|
|
|
Creates a versioned checkpoint directory with the following structure: |
|
|
|
|
|
{checkpointing_config.runs_dir}/ |
|
|
βββ {checkpointing_config.run_name}/ |
|
|
βββ training_config.yaml # Training config |
|
|
βββ {checkpointing_config.checkpoints_dir}/ |
|
|
βββ step_{checkpoint_step}/ |
|
|
β βββ config.json # HuggingFace model config |
|
|
β βββ model.safetensors # HuggingFace model weights |
|
|
β βββ pico_{model_type}.py # HuggingFace custom model class |
|
|
β βββ tokenizer.json # Tokenizer vocab |
|
|
β βββ tokenizer_config.json # Tokenizer config |
|
|
β βββ {checkpointing_config.fabric_checkpoint_dir}/ # Fabric-specific files |
|
|
β βββ checkpoint/ # Distributed model checkpoint files (if using DeepSpeed) |
|
|
β OR |
|
|
β βββ checkpoint.pt # Single checkpoint file (if using other strategies) |
|
|
βββ latest -> step_{checkpoint_step}/ |
|
|
|
|
|
Args: |
|
|
configs: A dictionary containing the initialized configuration objects. |
|
|
checkpoint_step: The current training checkpoint step (i.e. number of learning steps taken) |
|
|
fabric: Lightning Fabric instance for distributed training support |
|
|
model: The model instance to save |
|
|
optimizer: The optimizer instance to save |
|
|
lr_scheduler: The learning rate scheduler to save |
|
|
tokenizer: The tokenizer to save |
|
|
upload_logs: Whether to upload training logs to HF Hub (default: False) |
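
    Example (illustrative sketch; assumes `configs` holds the initialized config
    objects described above, and the step number is arbitrary):

        save_checkpoint(
            configs,
            checkpoint_step=1000,
            fabric=fabric,
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            tokenizer=tokenizer,
        )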
    """

    checkpointing_config = configs["checkpointing"]

    runs_dir = checkpointing_config.runs_dir
    checkpoints_dir = checkpointing_config.checkpoints_dir
    fabric_checkpoint_dir = checkpointing_config.fabric_checkpoint_dir
    logs_dir = checkpointing_config.logs_dir

    run_path = os.path.join(runs_dir, checkpointing_config.run_name)
    root_checkpoint_path = os.path.join(run_path, checkpoints_dir)
    checkpoint_path = os.path.join(root_checkpoint_path, f"step_{checkpoint_step}")

    os.makedirs(checkpoint_path, exist_ok=True)
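
    # Only rank 0 writes the HuggingFace model and tokenizer files; the other
    # ranks only participate in the collective fabric.save call further below.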
    if fabric.global_rank == 0:
        hf_model = model.convert_to_hf_model()
        hf_model.save_pretrained(checkpoint_path)
        tokenizer.save_pretrained(checkpoint_path)

    fabric_checkpoint_path = os.path.join(checkpoint_path, fabric_checkpoint_dir)
    os.makedirs(fabric_checkpoint_path, exist_ok=True)
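
    # Everything in this dict is serialized by fabric.save; scalar entries such
    # as the checkpoint step are stored alongside the module/optimizer states.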
    checkpoint_state = {
        "_model": model,
        "_optimizer": optimizer,
        "_lr_scheduler": lr_scheduler,
        "_checkpoint_step": checkpoint_step,
    }

    if not isinstance(fabric.strategy, DeepSpeedStrategy):
        # RNG states are tracked only for non-DeepSpeed strategies; the
        # checkpoint is written to a single file
        checkpoint_state["_rng_states"] = _collect_rng_states()
        fabric_save_file = os.path.join(
            fabric_checkpoint_path, checkpointing_config.fabric_checkpoint_filename
        )
    else:
        # DeepSpeed writes a directory of sharded checkpoint files
        fabric_save_file = fabric_checkpoint_path

    fabric.save(fabric_save_file, checkpoint_state)
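
    # Rank 0 additionally writes the run-level training config (once per run)
    # and refreshes the "latest" symlink.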
    if fabric.global_rank == 0:
        config_path = os.path.join(run_path, "training_config.yaml")
        if not os.path.exists(config_path):
            _training_config = {}
            for config_name, config in configs.items():
                _training_config[config_name] = asdict(config)
            with open(config_path, "w") as f:
                yaml.dump(_training_config, f)

        latest_symlink_path = os.path.join(root_checkpoint_path, "latest")
        if os.path.lexists(latest_symlink_path):
            os.remove(latest_symlink_path)
        os.symlink(
            f"step_{checkpoint_step}", latest_symlink_path, target_is_directory=True
        )
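
    # Optionally mirror the checkpoint to the HuggingFace Hub. Each run is
    # pushed to its own branch of the repo (the `revision` argument), so
    # successive steps become commits on that branch.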
    if fabric.global_rank == 0:
        if checkpointing_config.save_to_hf:
            repo_id = checkpointing_config.hf_checkpoint.repo_id

            hf_model.push_to_hub(
                repo_id=repo_id,
                commit_message=f"Saving HF Model -- Step {checkpoint_step}",
                revision=checkpointing_config.run_name,
                token=os.getenv("HF_TOKEN"),
            )

            # The tokenizer does not change during training, so push it only at step 0
            if checkpoint_step == 0:
                tokenizer.push_to_hub(
                    repo_id=repo_id,
                    commit_message=f"Saving Tokenizer -- Step {checkpoint_step}",
                    revision=checkpointing_config.run_name,
                    token=os.getenv("HF_TOKEN"),
                )

            upload_file(
                path_or_fileobj=config_path,
                path_in_repo="training_config.yaml",
                repo_id=repo_id,
                commit_message=f"Saving Training Config -- Step {checkpoint_step}",
                revision=checkpointing_config.run_name,
                token=os.getenv("HF_TOKEN"),
            )

            upload_folder(
                folder_path=fabric_checkpoint_path,
                path_in_repo=fabric_checkpoint_dir,
                repo_id=repo_id,
                commit_message=f"Saving Fabric Checkpoint -- Step {checkpoint_step}",
                revision=checkpointing_config.run_name,
                token=os.getenv("HF_TOKEN"),
            )

            if upload_logs:
                logs_path = os.path.join(run_path, logs_dir)
                upload_folder(
                    folder_path=logs_path,
                    path_in_repo=logs_dir,
                    repo_id=repo_id,
                    commit_message=f"Saving Logs -- Step {checkpoint_step}",
                    revision=checkpointing_config.run_name,
                    token=os.getenv("HF_TOKEN"),
                )