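# NOTE: the next three lines appear to be page/commit metadata captured with
# the file rather than configuration; kept here as comments:
#   zhuoranyang's picture
#   Deploy app with precomputed results for p=15,23,29,31
#   b753304 verified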
# =============================================================================
# CONFIGURATION FILE FOR MODULAR ARITHMETIC LEARNING
# =============================================================================
# -----------------------------------------------------------------------------
# DATA CONFIGURATION
# -----------------------------------------------------------------------------
data:
  # All keys below are nested under `data`.
  p: 97  # Prime number used for modular arithmetic operations
  d_vocab: null  # Vocabulary size, automatically set to p if null
  fn_name: 'add'  # Function to learn ('add', 'subtract', 'x2xyy2', etc.)
  # Fraction of data used for training (rest for testing); at 1 the whole
  # dataset is used for training and nothing is left over for testing.
  frac_train: 1
  batch_style: 'full'  # Batch processing style ('full' or mini-batch)
# -----------------------------------------------------------------------------
# MODEL ARCHITECTURE
# -----------------------------------------------------------------------------
model:
  # All keys below are nested under `model`.
  # Dimensionality of model embeddings.
  # NOTE(review): null presumably means the value is derived by the loader
  # (cf. d_vocab / freq_num, which document their null defaults) -- confirm.
  d_model: null
  d_mlp: 1024  # Dimensionality of the MLP (feedforward) layers
  act_type: 'ReLU'  # Activation function ('ReLU', 'GeLU', 'Quad', 'Id')
  embed_type: 'one_hot'  # Embedding type ('one_hot' or 'learned')

  # Weight Initialization
  init_type: 'random'  # Initialization type ('random' or 'single-freq')
  init_scale: 0.1  # Scale factor for weight initialization
  # Number of frequencies for single-freq init
  # (defaults to (d_vocab-1)//2 if null)
  freq_num: null
# -----------------------------------------------------------------------------
# TRAINING CONFIGURATION
# -----------------------------------------------------------------------------
training:
  # All keys below are nested under `training`.

  # Basic Training Parameters
  num_epochs: 5000  # Number of training epochs
  # Learning rate for the optimizer. Written with an explicit decimal point:
  # YAML 1.1 loaders (e.g. PyYAML) resolve a bare `5e-5` as a string, whereas
  # `5.0e-5` is a float under both YAML 1.1 and 1.2.
  lr: 5.0e-5
  weight_decay: 0  # Weight decay for regularization
  optimizer: 'AdamW'  # Optimizer ('AdamW' or 'SGD')

  # Early Stopping
  # Training stops if test loss falls below this value (-1 to disable)
  stopping_thresh: -1

  # Checkpointing and Logging
  save_models: false  # Whether to save model checkpoints
  save_every: 200  # Frequency (in epochs) at which to save models

  # Reproducibility
  seed: 42  # Random seed for reproducibility