from dataclasses import dataclass


@dataclass
class TrainingConfig:
    """Configuration for training (as a dataclass).

    Groups model, optimisation, dataset and logging hyper-parameters. Field
    order is part of the interface: dataclass fields may be passed
    positionally, so fields must not be reordered.
    """

    # Model parameters
    d_model: int = 512
    n_heads: int = 8
    n_encoder_layers: int = 6
    n_decoder_layers: int = 6
    # Input vocabulary size: 12 entries. The original comment describes them
    # as digits 0-9 + padding token + start token.
    vocab_in: int = 12
    # Output vocabulary size: 11 entries.
    # NOTE(review): the original comment repeated "digits 0-9 + padding token
    # + start", which would be 12 entries, not 11. 11 matches digits 0-9 plus
    # start_token (10). Confirm which special token the output side omits.
    vocab_out: int = 11
    block_size_in: int = 128  # max length of position n
    block_size_out: int = 129  # context_length + 1 (inclusive)
    # Negative, so it cannot index an embedding table directly. Presumably it
    # is remapped or used as a loss ignore-index elsewhere — TODO confirm.
    pad_token: int = -1
    start_token: int = 10

    # Training parameters
    batch_size: int = 64
    learning_rate: float = 1e-4
    weight_decay: float = 0.01
    num_epochs: int = 100
    gradient_clip: float = 1.0
    warmup_steps: int = 1000

    # Dataset parameters
    context_length: int = 128
    train_split: float = 0.9
    num_workers: int = 4
    # Not a dataset setting. It stays here because moving it would change
    # the positional field order.
    device: str = "cpu"

    # Logging and checkpointing
    log_interval: int = 100
    eval_interval: int = 1000
    save_interval: int = 5000
    checkpoint_dir: str = "checkpoints"
    log_dir: str = "runs"