---
# =============================================================================
# CONFIGURATION FILE FOR MODULAR ARITHMETIC LEARNING
# =============================================================================

# -----------------------------------------------------------------------------
# DATA CONFIGURATION
# -----------------------------------------------------------------------------
data:
  p: 97                  # Prime modulus used for the modular arithmetic operations
  d_vocab: null          # Vocabulary size; automatically set to p if null
  fn_name: 'add'         # Function to learn ('add', 'subtract', 'x2xyy2', etc.)
  # Fraction of data used for training; the rest is used for testing.
  # NOTE: a value of 1 leaves no held-out test split.
  frac_train: 1
  batch_style: 'full'    # Batch processing style ('full' or mini-batch)

# -----------------------------------------------------------------------------
# MODEL ARCHITECTURE
# -----------------------------------------------------------------------------
model:
  # Dimensionality of the model embeddings.
  # NOTE(review): behaviour when null is not documented here -- confirm the
  # default the loader substitutes.
  d_model: null
  d_mlp: 1024            # Dimensionality of the MLP (feedforward) layer
  act_type: 'ReLU'       # Activation function ('ReLU', 'GeLU', 'Quad', 'Id')
  embed_type: 'one_hot'  # Embedding type ('one_hot' or 'learned')

  # Weight initialization
  init_type: 'random'    # Initialization type ('random' or 'single-freq')
  init_scale: 0.1        # Scale factor for weight initialization
  # Number of frequencies for 'single-freq' init;
  # defaults to (d_vocab - 1) // 2 if null.
  freq_num: null

# -----------------------------------------------------------------------------
# TRAINING CONFIGURATION
# -----------------------------------------------------------------------------
training:
  # Basic training parameters
  num_epochs: 5000       # Number of training epochs
  # Learning rate for the optimizer. Written with a decimal point (5.0e-5)
  # so YAML 1.1 loaders such as PyYAML parse it as a float; a bare 5e-5
  # would be loaded as the string '5e-5'.
  lr: 5.0e-5
  weight_decay: 0        # Weight decay for regularization
  optimizer: 'AdamW'     # Optimizer ('AdamW' or 'SGD')

  # Early stopping
  stopping_thresh: -1    # Stop once test loss falls below this value (-1 disables)

  # Checkpointing and logging
  save_models: false     # Whether to save model checkpoints
  save_every: 200        # Frequency (in epochs) at which checkpoints are saved

  # Reproducibility
  seed: 42               # Random seed for reproducibility