|
|
|
|
|
""" |
|
|
Module NN: Neural Network Training Wrapper |
|
|
========================================== |
|
|
|
|
|
A flexible wrapper for training neural networks on modular arithmetic tasks. |
|
|
Supports command-line parameter overrides for easy batch experimentation. |
|
|
|
|
|
Usage Examples: |
|
|
# Use default config |
|
|
python module_nn.py |
|
|
|
|
|
# Override specific parameters |
|
|
python module_nn.py --p 17 --lr 0.01 --num_epochs 10000 |
|
|
|
|
|
# Run batch experiments on init_type, optimizer, and act_type (16 total combinations) |
|
|
python module_nn.py --experiments |
|
|
|
|
|
# Run batch experiments with custom parameters |
|
|
python module_nn.py --experiments --p 17 --num_epochs 3000 |
|
|
|
|
|
# Dry run to see configuration |
|
|
python module_nn.py --dry_run --p 23 --lr 0.001 |
|
|
|
|
|
# Multiple parameters for single experiment |
|
|
python module_nn.py --p 23 --lr 0.001 --d_mlp 256 --act_type ReLU --seed 42 |
|
|
|
|
|
Bash Script Example: |
|
|
# Run experiments for different primes |
|
|
for p in 17 23 31; do |
|
|
python module_nn.py --experiments --p $p --num_epochs 3000 |
|
|
done |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import sys |
|
|
from collections import deque |
|
|
|
|
|
from utils import * |
|
|
from nnTrainer import Trainer |
|
|
|
|
|
|
|
|
def parse_arguments():
    """Parse command line arguments with support for config overrides.

    Every option defaults to None (flags default to False) so that
    override_config() can distinguish "not supplied" from an explicit value.

    Returns:
        argparse.Namespace: the parsed arguments.
    """

    def str_to_bool(value):
        # argparse's type=bool is a well-known pitfall: bool("False") is True
        # because any non-empty string is truthy. Parse the common spellings
        # explicitly so "--save_models False" behaves as expected.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', '1', 'yes', 'y'):
            return True
        if lowered in ('false', 'f', '0', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f"Boolean value expected, got '{value}'")

    parser = argparse.ArgumentParser(
        description='Neural Network Training for Modular Arithmetic',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__
    )

    # Data parameters
    parser.add_argument('--p', type=int, help='Prime number for modular arithmetic')
    parser.add_argument('--d_vocab', type=int, help='Vocabulary size (defaults to p)')
    parser.add_argument('--fn_name', type=str, choices=['add', 'subtract', 'x2xyy2'], help='Function to learn')
    parser.add_argument('--frac_train', type=float, help='Fraction of data for training')
    parser.add_argument('--batch_style', type=str, help='Batch processing style')

    # Model parameters
    parser.add_argument('--d_model', type=int, help='Model embedding dimensionality')
    parser.add_argument('--d_mlp', type=int, help='MLP layer dimensionality')
    parser.add_argument('--act_type', type=str, choices=['ReLU', 'GeLU', 'Quad', 'Id'], help='Activation function')
    parser.add_argument('--embed_type', type=str, choices=['one_hot', 'learned'], help='Embedding type')
    parser.add_argument('--init_type', type=str, choices=['random', 'single-freq'], help='Weight initialization')
    parser.add_argument('--init_scale', type=float, help='Scale factor for weight initialization')
    parser.add_argument('--freq_num', type=int, help='Number of frequencies for single-freq init')

    # Training parameters
    parser.add_argument('--num_epochs', type=int, help='Number of training epochs')
    parser.add_argument('--lr', type=float, help='Learning rate')
    parser.add_argument('--weight_decay', type=float, help='Weight decay')
    parser.add_argument('--optimizer', type=str, choices=['AdamW', 'SGD'], help='Optimizer')
    parser.add_argument('--stopping_thresh', type=float, help='Early stopping threshold')
    # type=str_to_bool (not type=bool) so "--save_models False" is really False
    parser.add_argument('--save_models', type=str_to_bool, help='Whether to save models (true/false)')
    parser.add_argument('--save_every', type=int, help='Save frequency (epochs)')
    parser.add_argument('--seed', type=int, help='Random seed')

    # Meta options
    parser.add_argument('--config', type=str, help='Path to custom config file')
    parser.add_argument('--dry_run', action='store_true', help='Print config and exit without training')
    parser.add_argument('--no_wandb', action='store_true', help='Disable wandb logging')
    # Batch-experiment mode advertised in the module docstring
    parser.add_argument('--experiments', action='store_true',
                        help='Run batch experiments on init_type, optimizer, and act_type')

    return parser.parse_args()
|
|
|
|
|
|
|
|
def override_config(config_dict, args):
    """Override config values with command line arguments.

    Args:
        config_dict: nested config with 'data', 'model' and 'training'
            sections (any extra nesting levels are flattened).
        args: parsed argparse namespace; attributes that are None are
            treated as "not supplied" and leave the config untouched.

    Returns:
        dict: a new nested config {'data': ..., 'model': ..., 'training': ...}
        with CLI overrides applied.
    """

    flat_config = {}

    def flatten_dict(d):
        # Collapse nested sections into a single flat namespace keyed by the
        # leaf parameter name. NOTE(review): if two sections ever share a
        # parameter name, the last one visited silently wins.
        for k, v in d.items():
            if isinstance(v, dict):
                flatten_dict(v)
            else:
                flat_config[k] = v

    flatten_dict(config_dict)

    # Apply command line overrides (only for args explicitly provided)
    for arg_name, arg_value in vars(args).items():
        if arg_name == 'no_wandb':
            # store_true flag: False just means "not passed", so only an
            # explicit --no_wandb counts as an override. Previously the flag
            # was silently ignored unless the config file already contained a
            # no_wandb entry (and an un-passed flag could clobber a
            # config-file no_wandb=True with False).
            if arg_value:
                flat_config['no_wandb'] = True
                print("Override: no_wandb = True")
        elif arg_value is not None and arg_name in flat_config:
            flat_config[arg_name] = arg_value
            print(f"Override: {arg_name} = {arg_value}")

    # Reconstruct the nested structure expected downstream
    result = {'data': {}, 'model': {}, 'training': {}}

    data_params = ['p', 'd_vocab', 'fn_name', 'frac_train', 'batch_style']
    for param in data_params:
        if param in flat_config:
            result['data'][param] = flat_config[param]

    model_params = ['d_model', 'd_mlp', 'act_type', 'embed_type', 'init_type', 'init_scale', 'freq_num']
    for param in model_params:
        if param in flat_config:
            result['model'][param] = flat_config[param]

    training_params = ['num_epochs', 'lr', 'weight_decay', 'optimizer', 'stopping_thresh',
                       'save_models', 'save_every', 'seed', 'no_wandb']
    for param in training_params:
        if param in flat_config:
            result['training'][param] = flat_config[param]

    return result
|
|
|
|
|
|
|
|
def run_experiment(config_dict):
    """Run a single training experiment with the given configuration.

    Args:
        config_dict: nested config dict understood by Config (sections
            'data', 'model', 'training').

    Returns:
        Trainer: the trainer after the run, with loss/accuracy histories
        populated.
    """
    print("="*80)
    print("MODULAR ARITHMETIC NEURAL NETWORK TRAINING")
    print("="*80)

    # Build the runtime config object and echo the key settings
    pipeline_config = Config(config_dict)
    print(f"Configuration loaded successfully")
    print(f"Device: {pipeline_config.device}")
    print(f"Prime p: {pipeline_config.p}")
    print(f"Vocabulary size: {pipeline_config.d_vocab}")
    print(f"Model dimensions: d_model={pipeline_config.d_model}, d_mlp={pipeline_config.d_mlp}")
    print(f"Function: {pipeline_config.fn_name}")
    print(f"Activation: {pipeline_config.act_type}")
    print(f"Seed: {pipeline_config.seed}")
    print(f"Init scale: {pipeline_config.init_scale}")
    print(f"Learning rate: {pipeline_config.lr}")
    print("-" * 80)

    # wandb is enabled unless the config explicitly carries no_wandb=True
    use_wandb = not getattr(pipeline_config, 'no_wandb', False)
    world = Trainer(config=pipeline_config, use_wandb=use_wandb)
    print(f'Run name: {world.run_name}')
    world.initial_save_if_appropriate()

    # (removed dead locals: recent_test_loss deque and save_point were
    # assigned but never read)

    print(f"Starting training for {pipeline_config.num_epochs} epochs...")
    print("-" * 80)

    for epoch in range(pipeline_config.num_epochs):
        train_loss, test_loss = world.do_a_training_step(epoch)

        # Early stopping once the test loss drops below the threshold
        if test_loss.item() < pipeline_config.stopping_thresh:
            print(f"Early stopping at epoch {epoch}: test loss {test_loss.item():.6f} < {pipeline_config.stopping_thresh}")
            break

        # Periodic checkpointing; the schedule is decided by the config
        if pipeline_config.is_it_time_to_save(epoch=epoch):
            world.save_epoch(epoch=epoch, local_save=True)

    print("-" * 80)
    print("Training completed! Saving final model...")
    world.post_training_save(save_optimizer_and_scheduler=True)
    print(f"Final train loss: {world.train_losses[-1]:.6f}")
    print(f"Final test loss: {world.test_losses[-1]:.6f}")
    print(f"Final train accuracy: {world.train_accs[-1]:.4f}")
    print(f"Final test accuracy: {world.test_accs[-1]:.4f}")
    print("="*80)

    return world
|
|
|
|
|
|
|
|
def run_batch_experiments(base_config):
    """Run batch experiments over init_type, optimizer, and act_type.

    Runs every combination of the three swept parameters (2 x 2 x 4 = 16
    experiments), collects final metrics, and prints a summary table.

    Args:
        base_config: nested config dict used as the template for every
            experiment. It is never mutated.

    Returns:
        list[dict]: one result record per experiment; failed runs carry an
        'error' key instead of metrics.
    """
    print("="*80)
    print("BATCH EXPERIMENTS: init_type, optimizer, act_type")
    print("="*80)

    results = []
    experiment_count = 0

    # Swept parameter grids
    init_types = ['random', 'single-freq']
    optimizers = ['AdamW', 'SGD']
    act_types = ['ReLU', 'GeLU', 'Quad', 'Id']

    total_experiments = len(init_types) * len(optimizers) * len(act_types)
    print(f"Running {total_experiments} experiments...")
    print("-" * 80)

    for init_type in init_types:
        for optimizer in optimizers:
            for act_type in act_types:
                experiment_count += 1
                print(f"\nExperiment {experiment_count}/{total_experiments}")
                print(f"Configuration: init_type={init_type}, optimizer={optimizer}, act_type={act_type}")
                print("-" * 50)

                # BUGFIX: dict.copy() is shallow, so the nested 'model' and
                # 'training' dicts were shared with base_config — every
                # per-experiment assignment below leaked into the base config
                # and into all subsequent experiments. deepcopy isolates them.
                exp_config = copy.deepcopy(base_config)
                exp_config['model']['init_type'] = init_type
                exp_config['training']['optimizer'] = optimizer
                exp_config['model']['act_type'] = act_type

                # Distinct deterministic seed per experiment
                exp_config['training']['seed'] = 1024 + experiment_count

                # Cap epochs so a full sweep stays tractable
                exp_config['training']['num_epochs'] = min(exp_config['training']['num_epochs'], 5000)

                try:
                    trainer = run_experiment(exp_config)

                    result = {
                        'experiment': experiment_count,
                        'init_type': init_type,
                        'optimizer': optimizer,
                        'act_type': act_type,
                        'seed': exp_config['training']['seed'],
                        'final_train_loss': trainer.train_losses[-1],
                        'final_test_loss': trainer.test_losses[-1],
                        'final_train_acc': trainer.train_accs[-1],
                        'final_test_acc': trainer.test_accs[-1],
                        'run_name': trainer.run_name
                    }
                    results.append(result)

                    print(f"✓ Experiment {experiment_count} completed successfully")
                    print(f"  Final test accuracy: {result['final_test_acc']:.4f}")

                except Exception as e:
                    # Record the failure but keep the sweep going
                    print(f"✗ Experiment {experiment_count} failed: {str(e)}")
                    results.append({
                        'experiment': experiment_count,
                        'init_type': init_type,
                        'optimizer': optimizer,
                        'act_type': act_type,
                        'seed': exp_config['training']['seed'],
                        'error': str(e)
                    })

                print("-" * 50)

    # Summary
    print("\n" + "="*80)
    print("BATCH EXPERIMENTS SUMMARY")
    print("="*80)

    successful_results = [r for r in results if 'error' not in r]
    failed_results = [r for r in results if 'error' in r]

    print(f"Total experiments: {total_experiments}")
    print(f"Successful: {len(successful_results)}")
    print(f"Failed: {len(failed_results)}")

    if successful_results:
        print("\nTop 5 Results by Test Accuracy:")
        print("-" * 50)
        sorted_results = sorted(successful_results, key=lambda x: x['final_test_acc'], reverse=True)
        for i, result in enumerate(sorted_results[:5]):
            print(f"{i+1}. Test Acc: {result['final_test_acc']:.4f} | "
                  f"init_type={result['init_type']}, optimizer={result['optimizer']}, "
                  f"act_type={result['act_type']}")

        print("\nDetailed Results:")
        print("-" * 80)
        print(f"{'Exp':<3} {'Init':<11} {'Opt':<5} {'Act':<4} {'Train Acc':<9} {'Test Acc':<8} {'Train Loss':<10} {'Test Loss':<9}")
        print("-" * 80)
        for result in sorted_results:
            print(f"{result['experiment']:<3} "
                  f"{result['init_type']:<11} "
                  f"{result['optimizer']:<5} "
                  f"{result['act_type']:<4} "
                  f"{result['final_train_acc']:<9.4f} "
                  f"{result['final_test_acc']:<8.4f} "
                  f"{result['final_train_loss']:<10.6f} "
                  f"{result['final_test_loss']:<9.6f}")

    if failed_results:
        print(f"\nFailed Experiments:")
        for result in failed_results:
            print(f"Exp {result['experiment']}: {result['init_type']}, {result['optimizer']}, {result['act_type']} - {result['error']}")

    print("="*80)
    return results
|
|
|
|
|
|
|
|
def main():
    """Main entry point: parse args, resolve the config, and dispatch.

    Dispatch order: --dry_run prints the resolved config and exits;
    --experiments runs the batch sweep; otherwise a single training run
    is launched.

    Returns:
        Trainer, list, or None depending on the mode.
    """
    args = parse_arguments()

    # Load base configuration
    if args.config:
        # Custom YAML config file supplied on the command line
        import yaml
        with open(args.config, 'r') as f:
            configs = yaml.safe_load(f)
    else:
        # Default project config
        configs = read_config()

    # Apply CLI overrides on top of the base config
    final_config = override_config(configs, args)

    if args.dry_run:
        print("DRY RUN - Configuration that would be used:")
        print("-" * 50)
        import yaml
        print(yaml.dump(final_config, default_flow_style=False, indent=2))
        return

    # BUGFIX: the --experiments batch mode advertised in the module
    # docstring was never dispatched — run_batch_experiments was dead code.
    # getattr keeps this safe even if the flag is not defined on args.
    if getattr(args, 'experiments', False):
        return run_batch_experiments(final_config)

    trainer = run_experiment(final_config)
    print(f"Experiment completed successfully!")
    print(f"Results saved to: {trainer.save_dir}/{trainer.run_name}")
    return trainer
|
|
|
|
|
|
|
|
# Script entry point: only run training when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()