File size: 11,671 Bytes
feba2ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 |
"""
Utilities for checkpointing training-related states (i.e. model, optimizer, lr_scheduler, etc.)
We save both a HuggingFace model and a Fabric-specific checkpoint. The HuggingFace model is
saved at the step-specific checkpoint directory, while the Fabric-specific checkpoint is saved
in a subdirectory. This is done to facilitate easier versioning of the HuggingFace model files
(which are what gets uploaded to the Hub).
"""
import os
from dataclasses import asdict
from typing import Any, Dict, Tuple, Union
import yaml
from huggingface_hub import upload_file, upload_folder
from lightning.fabric import Fabric
from lightning.fabric.strategies import DeepSpeedStrategy
from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from transformers import PreTrainedTokenizerBase
from src.config import CheckpointingConfig
from src.training.utils.io import use_backoff
@use_backoff()
def load_checkpoint(
    checkpointing_config: CheckpointingConfig,
    checkpoint_step: Union[str, int],
    fabric: Fabric,
    model: nn.Module,
    optimizer: Optimizer,
    lr_scheduler: LRScheduler,
) -> Union[Tuple[nn.Module, Optimizer, LRScheduler, int], None]:
    """Load model checkpoint and associated states from a given step.

    Args:
        checkpointing_config: Configuration object containing checkpoint settings
        checkpoint_step: The step at which to load the checkpoint; an int is
            normalized to the ``step_{n}`` directory naming convention.
        fabric: Lightning Fabric instance for distributed training support
        model: The model instance to load weights into
        optimizer: The optimizer instance to load states into
        lr_scheduler: The learning rate scheduler to load states into

    Returns:
        Tuple containing the model, optimizer, lr_scheduler, and checkpoint step,
        or None if no checkpoint directory exists for the requested step.
    """
    if isinstance(checkpoint_step, int):
        checkpoint_step = f"step_{checkpoint_step}"

    checkpoint_path = os.path.join(
        checkpointing_config.runs_dir,
        checkpointing_config.run_name,
        checkpointing_config.checkpoints_dir,
        checkpoint_step,
    )
    if not os.path.exists(checkpoint_path):
        return None

    # Fabric state lives in a subdirectory of the step checkpoint; the top-level
    # step directory holds the HuggingFace model files (see module docstring).
    fabric_checkpoint_path = os.path.join(
        checkpoint_path, checkpointing_config.fabric_checkpoint_dir
    )

    # Underscore-prefixed keys avoid clashes with states that third-party
    # libraries may register under the same names (mirrors save_checkpoint).
    checkpoint_state = {
        "_model": model,
        "_optimizer": optimizer,
        "_lr_scheduler": lr_scheduler,
    }

    if not isinstance(fabric.strategy, DeepSpeedStrategy):
        # Non-DeepSpeed strategies save a single checkpoint file
        fabric_load_file = os.path.join(
            fabric_checkpoint_path, checkpointing_config.fabric_checkpoint_filename
        )
    else:
        # Deepspeed checkpoints create sub-directory with distributed checkpoint file
        fabric_load_file = fabric_checkpoint_path

    # fabric.load restores the registered states in place and returns any extra
    # entries that were saved alongside them (e.g. "_checkpoint_step").
    extra_state = fabric.load(fabric_load_file, state=checkpoint_state)

    checkpoint_step = extra_state["_checkpoint_step"]

    # RNG states are only saved for non-DeepSpeed strategies (see save_checkpoint);
    # restore them when present so data ordering/dropout resume deterministically.
    if "_rng_states" in extra_state:
        _rng_states = extra_state["_rng_states"]
        _set_rng_states(_rng_states)

    return model, optimizer, lr_scheduler, checkpoint_step
@use_backoff()
def save_checkpoint(
    configs: Dict[str, Any],
    checkpoint_step: int,
    fabric: Fabric,
    model: nn.Module,
    optimizer: Optimizer,
    lr_scheduler: LRScheduler,
    tokenizer: PreTrainedTokenizerBase,
    upload_logs: bool = False,
) -> None:
    """Save training checkpoint and associated states to disk and optionally to HuggingFace Hub.

    We save the following files:
        - HuggingFace model files (config.json, pytorch_model.bin)
        - Tokenizer files (vocab.json, merges.txt)
        - Fabric-specific files - fabric state of the model, optimizer, and lr_scheduler. If using
            DeepSpeed, the checkpoint is saved in a subdirectory, otherwise it is saved in a single file.

    Note that the HuggingFace model files are saved at the step-specific checkpoint directory, while the
    Fabric-specific files are saved in a subdirectory. This is done to facilitate easier
    versioning of the HuggingFace model files (which are what gets uploaded to the Hub).

    NOTE: Why do we save a HF model at all? We do this because it makes it easier to load the model
    in a separate script for evaluation and to play nicely with the HuggingFace Hub.

    Creates a versioned checkpoint directory with the following structure:

    {checkpointing_config.runs_dir}/
    └── {checkpointing_config.run_name}/
        ├── training_config.yaml      # Training config
        └── {checkpointing_config.checkpoints_dir}/
            ├── step_{checkpoint_step}/
            │   ├── config.json             # HuggingFace model config
            │   ├── model.safetensors       # HuggingFace model weights
            │   ├── pico_{model_type}.py    # HuggingFace custom model class
            │   ├── tokenizer.json          # Tokenizer vocab
            │   ├── tokenizer_config.json   # Tokenizer config
            │   └── {checkpointing_config.fabric_checkpoint_dir}/ # Fabric-specific files
            │       ├── checkpoint/         # Distributed model checkpoint files (if using DeepSpeed)
            │       OR
            │       └── checkpoint.pt       # Single checkpoint file (if using other strategies)
            └── latest -> step_{checkpoint_step}/

    NOTE(review): this function must be called on *all* ranks — ``fabric.save`` is a
    collective operation; only the HF export, symlink update, and Hub uploads are
    rank-0-only.

    Args:
        configs: A dictionary containing the initialized configuration objects.
        checkpoint_step: The current training checkpoint step (i.e. number of learning steps taken)
        fabric: Lightning Fabric instance for distributed training support
        model: The model instance to save
        optimizer: The optimizer instance to save
        lr_scheduler: The learning rate scheduler to save
        tokenizer: The tokenizer to save
        upload_logs: Whether to upload training logs to HF Hub (default: False)
    """
    checkpointing_config = configs["checkpointing"]

    # Get the directories from the training config
    runs_dir = checkpointing_config.runs_dir
    checkpoints_dir = checkpointing_config.checkpoints_dir
    fabric_checkpoint_dir = checkpointing_config.fabric_checkpoint_dir
    logs_dir = checkpointing_config.logs_dir

    run_path = os.path.join(runs_dir, checkpointing_config.run_name)
    root_checkpoint_path = os.path.join(run_path, checkpoints_dir)
    checkpoint_path = os.path.join(root_checkpoint_path, f"step_{checkpoint_step}")

    # Create directories (idempotent — safe if the step dir already exists)
    os.makedirs(checkpoint_path, exist_ok=True)

    ########################################################
    #
    # Save HuggingFace files
    #
    ########################################################

    # NOTE: we convert the Pico model to a HuggingFace model before saving it. See `model.py`
    # for more details.
    # Only rank 0 writes the HF export; other ranks would race on the same files.
    if fabric.global_rank == 0:
        hf_model = model.convert_to_hf_model()
        hf_model.save_pretrained(checkpoint_path)
        tokenizer.save_pretrained(checkpoint_path)

    ########################################################
    #
    # Save Fabric-specific files
    #
    ########################################################

    # Create fabric-specific subdirectory
    fabric_checkpoint_path = os.path.join(checkpoint_path, fabric_checkpoint_dir)
    os.makedirs(fabric_checkpoint_path, exist_ok=True)

    # Save model states (use underscore to avoid conflicts with third-party libraries)
    checkpoint_state = {
        "_model": model,
        "_optimizer": optimizer,
        "_lr_scheduler": lr_scheduler,
        "_checkpoint_step": checkpoint_step,
    }

    if not isinstance(fabric.strategy, DeepSpeedStrategy):
        # RNG states captured here so load_checkpoint can resume deterministically;
        # skipped for DeepSpeed, which manages its own checkpoint contents.
        checkpoint_state["_rng_states"] = _collect_rng_states()
        fabric_save_file = os.path.join(
            fabric_checkpoint_path, checkpointing_config.fabric_checkpoint_filename
        )
    else:
        # Deepspeed checkpoints create sub-directory with distributed checkpoint file
        fabric_save_file = fabric_checkpoint_path

    # Collective call — every rank participates, regardless of strategy.
    fabric.save(fabric_save_file, checkpoint_state)

    if fabric.global_rank == 0:
        # Save config in fabric directory
        config_path = os.path.join(run_path, "training_config.yaml")
        # Written once per run — the training config does not change across steps.
        if not os.path.exists(config_path):
            # Converting dataclasses to joined dicts and saving to file
            _training_config = {}
            for config_name, config in configs.items():
                _training_config[config_name] = asdict(config)
            with open(config_path, "w") as f:
                yaml.dump(_training_config, f)

        # Update latest symlink (remove-then-create; os.symlink fails if target exists)
        latest_symlink_path = os.path.join(root_checkpoint_path, "latest")
        if os.path.lexists(latest_symlink_path):
            os.remove(latest_symlink_path)
        os.symlink(
            f"step_{checkpoint_step}", latest_symlink_path, target_is_directory=True
        )

    ########################################################
    #
    # Push to HuggingFace Hub (if configured)
    #
    ########################################################

    if fabric.global_rank == 0:
        # Push only on rank zero thread
        if checkpointing_config.save_to_hf:
            repo_id = checkpointing_config.hf_checkpoint.repo_id

            # Upload the HF model; each step is committed to the run-name branch.
            hf_model.push_to_hub(
                repo_id=repo_id,
                commit_message=f"Saving HF Model -- Step {checkpoint_step}",
                revision=checkpointing_config.run_name,
                token=os.getenv("HF_TOKEN"),
            )

            if checkpoint_step == 0:
                # Uploading Tokenizer during first step since it never changes
                tokenizer.push_to_hub(
                    repo_id=repo_id,
                    commit_message=f"Saving Tokenizer -- Step {checkpoint_step}",
                    revision=checkpointing_config.run_name,
                    token=os.getenv("HF_TOKEN"),
                )

                # Upload training config, also only in first step
                upload_file(
                    path_or_fileobj=config_path,
                    path_in_repo="training_config.yaml",
                    repo_id=repo_id,
                    commit_message=f"Saving Training Config -- Step {checkpoint_step}",
                    revision=checkpointing_config.run_name,
                    token=os.getenv("HF_TOKEN"),
                )

            # Upload the fabric checkpoint directory (changes every step)
            upload_folder(
                folder_path=fabric_checkpoint_path,
                path_in_repo=fabric_checkpoint_dir,
                repo_id=repo_id,
                commit_message=f"Saving Fabric Checkpoint -- Step {checkpoint_step}",
                revision=checkpointing_config.run_name,
                token=os.getenv("HF_TOKEN"),
            )

            # Upload logs if requested
            if upload_logs:
                logs_path = os.path.join(run_path, logs_dir)
                upload_folder(
                    folder_path=logs_path,
                    path_in_repo=logs_dir,
                    repo_id=repo_id,
                    commit_message=f"Saving Logs -- Step {checkpoint_step}",
                    revision=checkpointing_config.run_name,
                    token=os.getenv("HF_TOKEN"),
                )
|