"""
Utilities for initializing components of the training process.

Here, we initialize all of the components that are part of the learning process. From logging
and checkpointing, to the optimizer, to the dataset and the dataloader, this file contains the
logic for setting up the classes and functions that are used in the training loop.

As always, this code is meant to be basic. We hard-code the obvious defaults, and leave the
more experimental stuff to you.
"""

import logging
import math
import os
import warnings
from dataclasses import fields, is_dataclass
from datetime import datetime
from typing import Dict, Optional, Union

import lightning as L
import torch
import yaml
from datasets import Dataset, DownloadConfig, load_dataset
from datasets import config as datasets_config
from huggingface_hub import add_collection_item, create_branch, create_repo
from lightning.fabric.loggers import Logger as FabricLogger
from lightning.fabric.utilities.rank_zero import rank_zero_only
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

import wandb
from src.config import (
    CheckpointingConfig,
    DataConfig,
    EvaluationConfig,
    ModelConfig,
    MonitoringConfig,
    TrainingConfig,
)
from src.model import PicoDecoder
from src.training.utils.io import use_backoff
from wandb.integration.lightning.fabric import WandbLogger
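
# Silence known-noisy warnings raised when using the wandb logger with Lightning Fabric.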
warnings.filterwarnings(
    "ignore",
    message=".*This integration is tested and supported for lightning Fabric.*",
)
warnings.filterwarnings(
    "ignore",
    message=".*Please report any issues to.*",
)


def _apply_config_overrides(config, overrides: dict):
    """Recursively apply configuration overrides to a dataclass config object.

    Args:
        config: Base configuration object (must be a dataclass)
        overrides: Dictionary of override values matching config structure

    Returns:
        Modified config object with the overrides applied.
    """
    for field in fields(config):
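        # Recurse into nested dataclass configs; only overwrite leaf values that are
        # explicitly present in the overrides dict.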
        field_value = getattr(config, field.name)
        if is_dataclass(field_value):
            _apply_config_overrides(field_value, overrides.get(field.name, {}))
        else:
            if field.name in overrides:
                setattr(config, field.name, overrides[field.name])
    return config


def initialize_configuration(
    config_path: Optional[str] = None,
) -> Dict[
    str,
    Union[
        DataConfig,
        ModelConfig,
        TrainingConfig,
        EvaluationConfig,
        MonitoringConfig,
        CheckpointingConfig,
    ],
]:
    """Initialize configuration objects with optional overrides from a YAML file.

    This function initializes all of the configuration objects, and then applies
    any overrides from the config_path file. If no config_path is provided,
    the function will use the default configuration objects.

    Args:
        config_path: Path to a YAML file containing configuration overrides.

    Returns:
        A dictionary containing the initialized configuration objects.
    """
    data_config = DataConfig()
    model_config = ModelConfig()
    training_config = TrainingConfig()
    evaluation_config = EvaluationConfig()
    monitoring_config = MonitoringConfig()
    checkpointing_config = CheckpointingConfig()
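
    # Layer any YAML overrides on top of these defaults, section by section.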
    if config_path:
        with open(config_path, "r") as config_file:
            overrides = yaml.safe_load(config_file)
        data_config = _apply_config_overrides(data_config, overrides.get("data", {}))
        model_config = _apply_config_overrides(model_config, overrides.get("model", {}))
        training_config = _apply_config_overrides(
            training_config, overrides.get("training", {})
        )
        evaluation_config = _apply_config_overrides(
            evaluation_config, overrides.get("evaluation", {})
        )
        monitoring_config = _apply_config_overrides(
            monitoring_config, overrides.get("monitoring", {})
        )
        checkpointing_config = _apply_config_overrides(
            checkpointing_config, overrides.get("checkpointing", {})
        )

    configs = {
        "data": data_config,
        "model": model_config,
        "training": training_config,
        "evaluation": evaluation_config,
        "monitoring": monitoring_config,
        "checkpointing": checkpointing_config,
    }

    return configs


def initialize_run_dir(checkpointing_config: CheckpointingConfig) -> str:
    """Initialize a directory for the current training run.

    Creates a unique directory for storing training, evaluation, and logging artifacts.
    If no run name is specified in the config, generates a timestamp-based name.

    Args:
        checkpointing_config: Configuration object containing run settings.
            NOTE: Must have a 'run_name' attribute that can be None, in which case
            a timestamp-based name will be generated.

    Returns:
        str: The path to the run directory.
    """
    run_name = checkpointing_config.run_name
    if run_name is None:
        run_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        checkpointing_config.run_name = run_name

    run_dir = os.path.join(checkpointing_config.runs_dir, run_name)

    os.makedirs(run_dir, exist_ok=True)
    return run_dir


def initialize_fabric(
    training_config: TrainingConfig, wandb_logger: Optional[FabricLogger] = None
):
    """Initialize Lightning Fabric for distributed training.

    Sets up a Lightning Fabric instance with the specified configuration for
    handling distributed training, mixed precision, and logging.

    Args:
        training_config: Configuration object containing fabric settings
            (accelerator, precision, devices, etc.).
        wandb_logger: Optional Weights & Biases logger instance for experiment tracking.

    Returns:
        L.Fabric: Initialized Lightning Fabric instance.

    Example:
        >>> fabric = initialize_fabric(training_config, wandb_logger)
    """

    total_devices = (
        training_config.fabric.num_devices * training_config.fabric.num_nodes
    )
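
    # With more than one device (across one or more nodes), shard optimizer state and
    # gradients with DeepSpeed ZeRO Stage 2; otherwise let Fabric pick its default strategy.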
    if total_devices > 1:
        strategy = "deepspeed_stage_2"
    else:
        strategy = "auto"

    fabric = L.Fabric(
        accelerator=training_config.fabric.accelerator,
        precision=training_config.fabric.precision,
        devices=training_config.fabric.num_devices,
        num_nodes=training_config.fabric.num_nodes,
        loggers=[wandb_logger] if wandb_logger is not None else None,
        strategy=strategy,
    )

    fabric.launch()

    return fabric


@use_backoff(max_retries=20)
def initialize_dataset(
    data_config: DataConfig,
    fabric: L.Fabric,
    initial_batch_step: Optional[int] = 0,
    return_fast_forward_steps: bool = False,
):
    """Initialize dataset based on the given config.

    This function will return a dataset object, and optionally a fast_forward_steps value.

    The fast_forward_steps value is the number of steps that we need to fast-forward an
    iterator by, so that we can continue from a certain batch of data we would have seen
    had training not previously stopped. Depending on how the dataset is loaded, the number
    of steps to fast-forward may differ from the initial_batch_step value.

    NOTE: This functionality is primarily useful for streaming datasets (which for large
    datasets is most of the time).

    Args:
        data_config: Configuration object containing dataset settings.
        fabric: A Lightning Fabric instance.
        initial_batch_step: The initial batch step to fast-forward to.
        return_fast_forward_steps: Whether to return the fast-forward steps value.

    Returns:
        Dataset: Initialized dataset object.
        Optional[int]: Number of steps to fast-forward the iterator by, if
            return_fast_forward_steps is True.
    """

    datasets_config.STREAMING_READ_MAX_RETRIES = 40
    datasets_config.STREAMING_READ_RETRY_INTERVAL = 10
    download_config = DownloadConfig(
        max_retries=20,
    )

    fast_forward_steps = 0

    if data_config.dataset.name == "pico-lm/pretokenized-dolma":
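        # The pretokenized-dolma dataset is stored as 10,000 parquet shards of
        # 20,480 examples each. To resume mid-training, skip the shards that have
        # already been consumed (via data_files) and fast-forward the remaining
        # batches within the first shard that is loaded.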
        if initial_batch_step is not None:
            examples_per_shard = 20_480
            total_shards = 10_000
            batches_per_shard = examples_per_shard // data_config.dataloader.batch_size
            shard_idx = initial_batch_step // batches_per_shard

            data_files = [
                f"data/train-{str(_shard_idx).zfill(5)}-of-{total_shards}.parquet"
                for _shard_idx in range(shard_idx, total_shards)
            ]

            fast_forward_steps = initial_batch_step % batches_per_shard
        else:
            data_files = None

        base_dataset = load_dataset(
            data_config.dataset.name,
            split="train",
            streaming=True,
            data_files=data_files,
            download_config=download_config,
        )
    else:
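        # Any other dataset is streamed from the beginning of its training split.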
        base_dataset = load_dataset(
            data_config.dataset.name,
            split="train",
            streaming=True,
            download_config=download_config,
        )

    if data_config.dataset.name == "pico-lm/pretokenized-dolma":
        from .data import ShardedIterableDataset
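
        # Split the shard stream across processes so that each rank iterates over a
        # disjoint subset of the data rather than identical copies.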
        dataset = ShardedIterableDataset(
            base_dataset, fabric.global_rank, fabric.world_size
        )
    else:
        dataset = base_dataset

    if return_fast_forward_steps:
        return dataset, fast_forward_steps
    else:
        return dataset


def initialize_tokenizer(data_config: DataConfig):
    """Initialize the tokenizer for text processing.

    This function can be extended to include custom tokenization logic.

    Args:
        data_config: Configuration object containing tokenizer settings.

    Returns:
        AutoTokenizer: A HuggingFace tokenizer instance.
    """

    return AutoTokenizer.from_pretrained(data_config.tokenizer.name)


def initialize_dataloader(
    data_config: DataConfig,
    training_config: TrainingConfig,
    fabric: L.Fabric,
    dataset: Dataset,
):
    """Initialize the DataLoader for efficient batch processing.

    Creates a PyTorch DataLoader that handles batching and data loading for training.
    Configured specifically for streaming tokenized text datasets.

    You might also want to extend this function to add a sampler, or some sort of custom
    collate function. For the default dataset, we don't need any of this, because the data
    are pre-shuffled and pre-tokenized.

    Args:
        data_config: Configuration object containing dataloader settings.
        training_config: Configuration object containing training settings.
        fabric: A Lightning Fabric instance.
        dataset: A HuggingFace Dataset object containing tokenized text data.
            Expected to have 'input_ids' field in its items.

    Returns:
        DataLoader: PyTorch DataLoader instance configured for the dataset.
    """

    def _collate_fn(batch):
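        # Gather each example's pre-tokenized input_ids into a single batch entry.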
        return {"input_ids": [entry["input_ids"] for entry in batch]}
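
    # The configured batch_size is the global (effective) batch size; each rank's
    # dataloader therefore yields micro-batches of
    # batch_size / (world_size * gradient_accumulation_steps) examples.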
    sub_batch_size = data_config.dataloader.batch_size // (
        fabric.world_size * training_config.optimization.gradient_accumulation_steps
    )

    return DataLoader(
        dataset,
        batch_size=sub_batch_size,
        shuffle=False,
        pin_memory=True,
        collate_fn=_collate_fn,
    )


def initialize_model(model_config: ModelConfig):
    """Initialize the model for training.

    Loads in a given model implemented in the `src.model` package and returns it.

    NOTE: out of the box we currently only support the PicoDecoder model (a causal
    transformer language model). If you'd like to implement your own model, you can do so
    by adding a new model class in the `src.model` package, and then adding a new entry here.

    Args:
        model_config: Configuration object containing model settings.

    Returns:
        PyTorch model instance.
    """
    if model_config.model_type == "pico_decoder":
        return PicoDecoder(model_config)
    else:
        raise ValueError(f"Invalid model type: {model_config.model_type}")


def initialize_optimizer(training_config: TrainingConfig, model: torch.nn.Module):
    """Initialize the optimizer for model training.

    Creates an optimizer instance based on the configuration settings.

    Add whatever other optimizers you want here.

    Args:
        training_config: Configuration object containing optimizer settings.
            Must have:
            - optimization.optimizer (str): Name of the optimizer ("adamw")
            - optimization.lr (float): Learning rate for the optimizer
        model: PyTorch model whose parameters will be optimized.

    Returns:
        torch.optim.Optimizer: Configured optimizer instance.
    """

    if training_config.optimization.optimizer == "adamw":
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=training_config.optimization.lr
        )
    else:
        raise ValueError(f"Invalid optimizer: {training_config.optimization.optimizer}")

    return optimizer


def initialize_lr_scheduler(
    training_config: TrainingConfig, optimizer: torch.optim.Optimizer
):
    """Initialize a learning rate scheduler with warmup and decay.

    The default is a learning rate scheduler that implements a linear warmup followed by
    linear decay. The learning rate increases linearly from 0 to the initial lr
    during warmup, then decreases linearly to 0 during the remaining steps.

    Add other types of learning rate schedulers here.

    Args:
        training_config: Configuration object containing optimizer and scheduler settings.
        optimizer: PyTorch optimizer whose learning rate will be scheduled.

    Returns:
        torch.optim.lr_scheduler.LambdaLR: Learning rate scheduler instance.
    """

    if training_config.optimization.lr_scheduler == "linear_with_warmup":
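        # Linear warmup from 0 up to the peak learning rate over lr_warmup_steps,
        # followed by a linear decay back down to 0 at max_steps.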
        def _lr_lambda(curr_step, num_warmup_steps, max_steps):
            if curr_step < num_warmup_steps:
                return float(curr_step) / float(max(1, num_warmup_steps))
            else:
                return max(
                    0.0,
                    float(max_steps - curr_step)
                    / float(max(1, max_steps - num_warmup_steps)),
                )

        lr_lambda = lambda step: _lr_lambda(
            step,
            training_config.optimization.lr_warmup_steps,
            training_config.max_steps,
        )
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda,
        )
    elif training_config.optimization.lr_scheduler == "cosine":
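        # Same linear warmup, followed by a cosine decay; the multiplier is floored at
        # 0.1, so the learning rate never drops below 10% of its peak value.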
        def _cosine_lr_lambda(curr_step, num_warmup_steps, max_steps):
            if curr_step < num_warmup_steps:
                return float(curr_step) / float(max(1, num_warmup_steps))
            else:
                progress = float(curr_step - num_warmup_steps) / float(
                    max(1, max_steps - num_warmup_steps)
                )
                return max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))

        lr_lambda = lambda step: _cosine_lr_lambda(
            step,
            training_config.optimization.lr_warmup_steps,
            training_config.max_steps,
        )
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda,
        )
    else:
        raise ValueError(
            f"Invalid learning rate scheduler: {training_config.optimization.lr_scheduler}"
        )

    return lr_scheduler


def _initialize_log_file(checkpointing_config: CheckpointingConfig) -> str:
    """Create and initialize a timestamped log file in the run's log directory.

    Sets up a log file with a unique timestamp in the run's logging directory.
    Creates the necessary directory structure if it doesn't exist.

    Directory Structure:
        {checkpointing_config.runs_dir}/
        └── {checkpointing_config.run_name}/
            └── {checkpointing_config.logs_dir}/
                └── log_YYYYMMDD_HHMMSS.log

    Args:
        checkpointing_config: Configuration object containing checkpointing settings.

    Returns:
        str: Path to the created log file.
    """

    run_dir = os.path.join(checkpointing_config.runs_dir, checkpointing_config.run_name)
    logs_dir = os.path.join(run_dir, checkpointing_config.logs_dir)
    os.makedirs(logs_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file_name = f"log_{timestamp}.log"
    log_file_path = os.path.join(logs_dir, log_file_name)
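
    # Touch the file so it exists before any logging handler attaches to it.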
    open(log_file_path, "w").close()

    return log_file_path


@use_backoff()
def initialize_wandb(
    monitoring_config: MonitoringConfig, checkpointing_config: CheckpointingConfig
):
    """Initialize Weights and Biases.

    This function initializes Weights and Biases based on the configuration settings.

    Args:
        monitoring_config: Configuration object containing monitoring settings.
        checkpointing_config: Configuration object containing checkpointing settings.

    Returns:
        Optional[WandbLogger]: An experiment tracker instance.
    """

    assert (
        monitoring_config.wandb.project is not None
        and monitoring_config.wandb.project != ""
    ), "Wandb project must be provided if wandb is to be used."
    assert (
        monitoring_config.wandb.entity is not None
        and monitoring_config.wandb.entity != ""
    ), "Wandb entity must be provided if wandb is to be used."

    _run_id = None
    if checkpointing_config.training.auto_resume:
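        # When auto-resuming, look up an existing wandb run with the same display name
        # so that logging continues in that run instead of starting a new one.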
        previous_runs = wandb.Api().runs(
            path=f"{monitoring_config.wandb.entity}/{monitoring_config.wandb.project}",
            filters={"display_name": checkpointing_config.run_name},
        )
        try:
            if len(previous_runs) == 1:
                _run_id = previous_runs[0].id
        except ValueError:
            pass

    wandb_logger = WandbLogger(
        project=monitoring_config.wandb.project,
        entity=monitoring_config.wandb.entity,
        id=_run_id,
        name=checkpointing_config.run_name,
    )

    return wandb_logger


@rank_zero_only
def initialize_logging(
    monitoring_config: MonitoringConfig,
    checkpointing_config: CheckpointingConfig,
    fabric: L.Fabric,
):
    """Initialize the logging system, with default logging to file and console.

    The default logging system uses a file handler and a stream handler.

    NOTE: this function is only called on rank 0.

    Args:
        monitoring_config: Configuration object containing monitoring settings.
        checkpointing_config: Configuration object containing checkpointing settings.
        fabric: A Lightning Fabric instance.

    Returns:
        logger: Standard Python logger configured for file and console output.
    """

    logger = logging.getLogger("pico-train")
    logger.setLevel(logging.INFO)
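
    # Write logs both to a timestamped file inside the run directory and to the console.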
    log_file_path = _initialize_log_file(checkpointing_config)
    file_handler = logging.FileHandler(log_file_path, encoding="utf-8")
    file_handler.setLevel(monitoring_config.logging.log_level)

    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    file_handler.setFormatter(formatter)

    logger.addHandler(file_handler)

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(monitoring_config.logging.log_level)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    return logger


@rank_zero_only
@use_backoff()
def initialize_hf_checkpointing(
    checkpointing_config: CheckpointingConfig, fabric: L.Fabric
):
    """Initialize HuggingFace Checkpointing.

    Creates a HuggingFace repository if it doesn't exist, and creates a branch named
    after the run.

    NOTE: this function is only called on rank 0.

    Args:
        checkpointing_config: Configuration object containing checkpointing settings; must
            have a 'hf_checkpoint' attribute that specifies the HuggingFace repository id
            and collection slug (if applicable) to save the checkpoint to.
        fabric: A Lightning Fabric instance.

    Raises:
        RuntimeError: If unable to create the HuggingFace repository after multiple attempts.
    """

    huggingface_repo_id = checkpointing_config.hf_checkpoint.repo_id
    assert (
        huggingface_repo_id is not None and huggingface_repo_id != ""
    ), "hf_checkpoint.repo_id must be provided."

    repo = create_repo(huggingface_repo_id, exist_ok=True)
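
    # create_repo resolves the fully-qualified repo id (including the user/org namespace),
    # so store the resolved id back into the config for downstream use.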
    checkpointing_config.hf_checkpoint.repo_id = repo.repo_id
    huggingface_repo_id = repo.repo_id

    if checkpointing_config.hf_checkpoint.collection_slug:
        add_collection_item(
            checkpointing_config.hf_checkpoint.collection_slug,
            huggingface_repo_id,
            repo.repo_type,
            exists_ok=True,
        )
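
    # Each training run checkpoints to its own branch of the repo, named after the run.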
    create_branch(
        repo_id=huggingface_repo_id,
        branch=checkpointing_config.run_name,
        exist_ok=True,
    )