# pi-digits / config.py
# h2di: "Upload config.py with huggingface_hub" (commit cd88447, verified)
from dataclasses import dataclass
@dataclass
class TrainingConfig:
    """Hyperparameters and runtime settings for a training run.

    Every field has a default, so ``TrainingConfig()`` yields a complete
    configuration; override individual values by keyword, e.g.
    ``TrainingConfig(batch_size=32, device="cuda")``.

    Fields are grouped into model, optimisation, dataset and
    logging/checkpointing sections. Field order is part of the generated
    ``__init__`` signature, so new fields should be appended rather than
    inserted.
    """

    # --- Model parameters -------------------------------------------------
    d_model: int = 512
    n_heads: int = 8
    n_encoder_layers: int = 6
    n_decoder_layers: int = 6
    # Input vocabulary: digits 0-9 + padding token + start token (12 symbols).
    vocab_in: int = 12
    # Output vocabulary size.
    # NOTE(review): the previous comment here was identical to vocab_in's
    # (digits 0-9 + padding + start = 12 symbols), but the value is 11.
    # Presumably one of those symbols is not represented on the output
    # side -- confirm against the model code.
    vocab_out: int = 11
    block_size_in: int = 128  # max length of position n
    # Equals context_length + 1 (inclusive). This is a literal, not derived
    # from context_length, so the two must be updated together.
    block_size_out: int = 129
    # NOTE(review): -1 lies outside both vocab ranges; presumably it is
    # handled as a mask / ignore index rather than looked up -- confirm.
    pad_token: int = -1
    start_token: int = 10

    # --- Training parameters ----------------------------------------------
    batch_size: int = 64
    learning_rate: float = 1e-4
    weight_decay: float = 0.01
    num_epochs: int = 100
    gradient_clip: float = 1.0
    warmup_steps: int = 1000

    # --- Dataset parameters -----------------------------------------------
    context_length: int = 128
    train_split: float = 0.9
    num_workers: int = 4
    device: str = "cpu"

    # --- Logging and checkpointing ----------------------------------------
    log_interval: int = 100
    eval_interval: int = 1000
    save_interval: int = 5000
    checkpoint_dir: str = "checkpoints"
    log_dir: str = "runs"