"""
Checkpointing Config

Specifies the hyperparameters for the checkpointing process; checkpointing is used to save
the model and optimizer states, as well as the learning dynamics metrics.
"""
| |
|
from dataclasses import dataclass, field
from typing import List, Optional

from ._constants import (
    CHECKPOINTS_DIR,
    EVAL_RESULTS_DIR,
    FABRIC_CHECKPOINT_DIR,
    FABRIC_CHECKPOINT_FILENAME,
    LEARNING_DYNAMICS_DIR,
    LOGS_DIR,
    RUNS_DIR,
)
| |
|
| |
|
@dataclass
class TrainingCheckpointingConfig:
    """Training-specific checkpointing settings.

    Used as the ``training`` section of ``CheckpointingConfig``.
    """

    # Enabled by default. NOTE(review): presumably lets a run continue from its most
    # recent checkpoint instead of starting over -- confirm against the code reading it.
    auto_resume: bool = True
| |
|
| |
|
@dataclass
class EvaluationCheckpointingConfig:
    """Evaluation-specific checkpointing settings.

    Used as the ``evaluation`` section of ``CheckpointingConfig``.
    """

    # Location for evaluation results; the default comes from the package-level
    # constant EVAL_RESULTS_DIR. NOTE(review): whether this is resolved relative to a
    # run directory is not visible here -- confirm against the caller.
    eval_results_dir: str = EVAL_RESULTS_DIR
| |
|
| |
|
@dataclass
class LearningDynamicsCheckpointingConfig:
    """Settings for the learning-dynamics part of checkpointing.

    Used as the ``learning_dynamics`` section of ``CheckpointingConfig``.
    """

    # Layer-name suffixes selecting the layers of interest. Built through a factory so
    # that every config instance owns a separate list (mutating one never leaks into
    # another). NOTE(review): presumably matched against module names -- confirm.
    layer_suffixes: List[str] = field(
        default_factory=lambda: [
            "attention.v_proj",
            "attention.o_proj",
            "swiglu.w_2",
        ]
    )

    # Index into the sequence dimension; the default of -1 is Python's "last element"
    # index. NOTE(review): presumably the token position whose states are extracted --
    # confirm against the consumer.
    sequence_idx: int = -1

    # Batch size used by the learning-dynamics step (default 8).
    batch_size: int = 8

    # Identifier of the evaluation data to use; the annotation allows None.
    # NOTE(review): the default looks like a Hugging Face dataset id, and None
    # presumably disables the eval-data pass -- confirm both against the consumer.
    eval_data: Optional[str] = "pico-lm/pretokenized-paloma-tinsy"
| |
|
| |
|
@dataclass
class HuggingFaceCheckpointingConfig:
    """Hugging Face related checkpointing settings.

    Used as the ``hf_checkpoint`` section of ``CheckpointingConfig``.
    """

    # Target repository identifier; empty string by default.
    # NOTE(review): presumably "<org-or-user>/<name>" on the Hugging Face Hub -- confirm.
    repo_id: str = ""

    # Optional collection slug; None by default.
    # NOTE(review): presumably the Hub collection the repo is added to -- confirm.
    collection_slug: Optional[str] = None
| |
|
| |
|
@dataclass
class CheckpointingConfig:
    """Top-level checkpointing configuration.

    Groups the run identity, the directory/file names used for checkpoint artifacts,
    the save cadence, the Hugging Face settings, and the nested training, evaluation
    and learning-dynamics sections.
    """

    # Name of the run; None by default.
    # NOTE(review): how a missing name is filled in is not visible here -- confirm.
    run_name: Optional[str] = None

    # Directory and file names for checkpoint artifacts. Every default comes from the
    # package-level constants module, so the layout is defined in one place.
    runs_dir: str = RUNS_DIR
    checkpoints_dir: str = CHECKPOINTS_DIR
    logs_dir: str = LOGS_DIR
    fabric_checkpoint_dir: str = FABRIC_CHECKPOINT_DIR
    fabric_checkpoint_filename: str = FABRIC_CHECKPOINT_FILENAME
    learning_dynamics_dir: str = LEARNING_DYNAMICS_DIR

    # Checkpoint cadence, counted in steps (default: every 1000 steps).
    save_every_n_steps: int = 1000

    # Hugging Face settings. ``save_to_hf`` is off by default; the nested config is
    # created per instance through its own dataclass defaults.
    # NOTE(review): presumably toggles uploading checkpoints to the Hub -- confirm.
    save_to_hf: Optional[bool] = False
    hf_checkpoint: HuggingFaceCheckpointingConfig = field(
        default_factory=HuggingFaceCheckpointingConfig
    )

    # Nested sections; each is built fresh per instance via default_factory so that
    # instances never share a mutable sub-config.
    training: TrainingCheckpointingConfig = field(
        default_factory=TrainingCheckpointingConfig
    )
    evaluation: EvaluationCheckpointingConfig = field(
        default_factory=EvaluationCheckpointingConfig
    )
    learning_dynamics: LearningDynamicsCheckpointingConfig = field(
        default_factory=LearningDynamicsCheckpointingConfig
    )
| |
|