# =============================================================================
# CONFIGURATION FILE FOR MODULAR ARITHMETIC LEARNING
# =============================================================================
# -----------------------------------------------------------------------------
# DATA CONFIGURATION
# -----------------------------------------------------------------------------
data:
  p: 97                  # Prime number used for modular arithmetic operations
  d_vocab: null          # Vocabulary size, automatically set to p if null
  fn_name: 'add'         # Function to learn ('add', 'subtract', 'x2xyy2', etc.)
  frac_train: 1          # Fraction of data used for training (rest for testing)
  batch_style: 'full'    # Batch processing style ('full' or mini-batch)
# -----------------------------------------------------------------------------
# MODEL ARCHITECTURE
# -----------------------------------------------------------------------------
model:
  d_model: null          # Dimensionality of model embeddings
  d_mlp: 1024            # Dimensionality of the MLP (feedforward) layers
  act_type: 'ReLU'       # Activation function ('ReLU', 'GeLU', 'Quad', 'Id')
  embed_type: 'one_hot'  # Embedding type ('one_hot' or 'learned')
  # Weight Initialization
  init_type: 'random'    # Initialization type ('random' or 'single-freq')
  init_scale: 0.1        # Scale factor for weight initialization
  freq_num: null         # Number of frequencies for single-freq init (defaults to (d_vocab-1)//2 if null)
# -----------------------------------------------------------------------------
# TRAINING CONFIGURATION
# -----------------------------------------------------------------------------
training:
  # Basic Training Parameters
  num_epochs: 5000       # Number of training epochs
  lr: 5.0e-5             # Learning rate (written as 5.0e-5 so YAML 1.1 parsers load a float, not the string "5e-5")
  weight_decay: 0        # Weight decay for regularization
  optimizer: 'AdamW'     # Optimizer ('AdamW' or 'SGD')
  # Early Stopping
  stopping_thresh: -1    # Training stops if test loss falls below this value (-1 to disable)
  # Checkpointing and Logging
  save_models: false     # Whether to save model checkpoints
  save_every: 200        # Frequency (in epochs) at which to save models
  # Reproducibility
  seed: 42               # Random seed for reproducibility