from dataclasses import dataclass
| class TrainingConfig: | |
| """Configuration for training (as a dataclass).""" | |
| # Model parameters | |
| d_model: int = 512 | |
| n_heads: int = 8 | |
| n_encoder_layers: int = 6 | |
| n_decoder_layers: int = 6 | |
| vocab_in: int = 12 # digits 0-9 + padding token + start | |
| vocab_out: int = 11 # digits 0-9 + padding token + start | |
| block_size_in: int = 128 # max length of position n | |
| block_size_out: int = 129 # context_length + 1 (inclusive) | |
| pad_token: int = -1 | |
| start_token: int = 10 | |
| # Training parameters | |
| batch_size: int = 64 | |
| learning_rate: float = 1e-4 | |
| weight_decay: float = 0.01 | |
| num_epochs: int = 100 | |
| gradient_clip: float = 1.0 | |
| warmup_steps: int = 1000 | |
| # Dataset parameters | |
| context_length: int = 128 | |
| train_split: float = 0.9 | |
| num_workers: int = 4 | |
| device: str = "cpu" | |
| # Logging and checkpointing | |
| log_interval: int = 100 | |
| eval_interval: int = 1000 | |
| save_interval: int = 5000 | |
| checkpoint_dir: str = "checkpoints" | |
| log_dir: str = "runs" | |