File size: 2,992 Bytes
feba2ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""
Checkpointing Config

Specifies the hyperparameters for the checkpointing process; checkpointing is used to save
the model and optimizer states, as well as the learning dynamics metrics.
"""

from dataclasses import dataclass, field
from typing import List, Optional

from ._constants import (
    CHECKPOINTS_DIR,
    EVAL_RESULTS_DIR,
    FABRIC_CHECKPOINT_DIR,
    FABRIC_CHECKPOINT_FILENAME,
    LEARNING_DYNAMICS_DIR,
    LOGS_DIR,
    RUNS_DIR,
)


@dataclass
class TrainingCheckpointingConfig:
    """Checkpointing behavior specific to the training loop."""

    # Automatically resume training from the most recent checkpoint
    auto_resume: bool = True


@dataclass
class EvaluationCheckpointingConfig:
    """Checkpointing behavior specific to evaluation runs."""

    # Directory in which evaluation results are saved
    eval_results_dir: str = EVAL_RESULTS_DIR


@dataclass
class LearningDynamicsCheckpointingConfig:
    """Settings for extracting and saving learning-dynamics states at checkpoints."""

    # Suffixes of the layers to compute learning dynamics for
    # NOTE(review): the defaults name attention/SwiGLU projections — presumably
    # matched against module names by suffix; confirm against the extraction code.
    layer_suffixes: List[str] = field(
        default_factory=lambda: [
            "attention.v_proj",
            "attention.o_proj",
            "swiglu.w_2",
        ]
    )

    # Sequence index at which to extract hidden states; by default, we extract the hidden states
    # at the last token of the sequence (-1)
    sequence_idx: int = -1

    # size of the sub-batch used for extracting learning dynamics states
    batch_size: int = 8

    # Path to evaluation dataset - used across learning dynamics checkpointing for consistency
    # NOTE: set to None to disable extracting learning dynamics states for an eval_batch
    # NOTE: this dataset should be small, ideally just a batch of additional data
    eval_data: Optional[str] = "pico-lm/pretokenized-paloma-tinsy"


@dataclass
class HuggingFaceCheckpointingConfig:
    """Target repository details for pushing checkpoints to the HuggingFace Hub."""

    # Should be in the format of <(username or organization name)>/<repo_name>, e.g. pico-lm/demo
    repo_id: str = ""

    # HuggingFace Collection Slug (specifies a tag for the run)
    collection_slug: Optional[str] = None


@dataclass
class CheckpointingConfig:
    """Top-level checkpointing config aggregating training, evaluation,
    learning-dynamics, and HuggingFace sub-configs, plus directory layout.
    """

    # Assign a name to the run
    run_name: Optional[str] = None

    # Defining checkpointing directories
    # (defaults come from ._constants so all tooling agrees on the layout)
    runs_dir: str = RUNS_DIR
    checkpoints_dir: str = CHECKPOINTS_DIR
    logs_dir: str = LOGS_DIR
    fabric_checkpoint_dir: str = FABRIC_CHECKPOINT_DIR
    fabric_checkpoint_filename: str = FABRIC_CHECKPOINT_FILENAME
    learning_dynamics_dir: str = LEARNING_DYNAMICS_DIR

    # How often to save checkpoints
    save_every_n_steps: int = 1000

    # Whether to save checkpoints to HuggingFace
    # NOTE(review): annotated Optional[bool] but defaulted to False — unclear
    # whether None is a meaningful third state; confirm before narrowing to bool.
    save_to_hf: Optional[bool] = False
    # Hub repo details; only relevant when save_to_hf is enabled
    hf_checkpoint: HuggingFaceCheckpointingConfig = field(
        default_factory=HuggingFaceCheckpointingConfig
    )

    # Per-domain sub-configs (default_factory avoids shared mutable defaults)
    training: TrainingCheckpointingConfig = field(
        default_factory=TrainingCheckpointingConfig
    )
    evaluation: EvaluationCheckpointingConfig = field(
        default_factory=EvaluationCheckpointingConfig
    )
    learning_dynamics: LearningDynamicsCheckpointingConfig = field(
        default_factory=LearningDynamicsCheckpointingConfig
    )