# config/baseline_config.py
"""
Configuration for Baseline Transformer on enwik8.
Matches DTAT's training setup for fair comparison.
"""
class BaselineConfig:
    def __init__(self):
        # Model architecture (exactly matching DTAT)
        self.n_layer = 12
        self.n_head = 8    # Same as DTAT
        self.n_embd = 512  # Same as DTAT
        self.dropout = 0.1
        self.bias = True

        # Sequence parameters
        self.block_size = 1024  # Same as DTAT
        self.vocab_size = 256   # Byte-level vocabulary for the character-level model

        # Training parameters (matched with DTAT)
        self.learning_rate = 6e-4
        self.min_lr = 1e-5       # Lower minimum to allow fine-tuning
        self.warmup_iters = 367  # 5% of max_iters (0.05 * 7334 ≈ 367)
        self.max_iters = 7334    # Exactly 4 epochs with batch_size=24
        self.weight_decay = 0.1  # Same as DTAT
        self.beta1 = 0.9
        self.beta2 = 0.95
        self.grad_clip = 1.0

        # Learning rate schedule (see the get_lr sketch after this class
        # for one way these fields typically combine)
        self.decay_lr = True
        self.lr_decay_iters = 5000  # Same as DTAT

        # Early stopping (see the should_stop sketch at the end of the file)
        self.patience = 15        # Same as DTAT
        self.min_delta = 0.005    # Same as DTAT
        self.eval_interval = 250  # Same as DTAT
        self.eval_iters = 200     # Same as DTAT

        # Logging
        self.log_interval = 10

        # Mixed precision training (see the train_step sketch at the end of the file)
        self.mixed_precision = True
        self.dtype = 'bfloat16'

        # Memory optimization
        self.gradient_checkpointing = True
        self.batch_size = 24  # Same as DTAT

        # System
        self.device = 'cuda'
        self.compile = True

        # Performance optimization
        self.compile_model = True  # Duplicate of self.compile above; both names kept
        self.cudnn_benchmark = True

        # Git config for model versioning
        self.git_name = "Your Name"
        self.git_email = "your.email@example.com"

    def get_config(self):
        """Return this config instance."""
        return self
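
# --- Sketch (not part of the original file) ---------------------------------
# A minimal illustration, under assumptions, of how a trainer might combine
# learning_rate, min_lr, warmup_iters, lr_decay_iters, and decay_lr into the
# usual nanoGPT-style warmup + cosine schedule. `get_lr` is a hypothetical
# helper, not training code confirmed by this repo.
import math

def get_lr(it, cfg):
    """Warmup + cosine decay (assumed schedule, not confirmed by the repo)."""
    if not cfg.decay_lr:
        return cfg.learning_rate
    # Linear warmup from 0 to learning_rate over the first warmup_iters steps
    if it < cfg.warmup_iters:
        return cfg.learning_rate * (it + 1) / cfg.warmup_iters
    # Past lr_decay_iters, hold at the floor
    if it > cfg.lr_decay_iters:
        return cfg.min_lr
    # Cosine decay from learning_rate down to min_lr in between
    decay_ratio = (it - cfg.warmup_iters) / (cfg.lr_decay_iters - cfg.warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return cfg.min_lr + coeff * (cfg.learning_rate - cfg.min_lr)
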
def get_config():
    """Helper function to get a config instance."""
    return BaselineConfig()
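
# --- Sketch (not part of the original file) ---------------------------------
# A minimal sketch, under assumptions, of how the optimizer and mixed-precision
# fields might be consumed by a standard PyTorch loop. `model`, `x`, and `y`
# are hypothetical; this is not the repo's actual training code.
import torch

def make_optimizer(model, cfg):
    # AdamW with the betas and weight decay configured above
    return torch.optim.AdamW(
        model.parameters(),
        lr=cfg.learning_rate,
        betas=(cfg.beta1, cfg.beta2),
        weight_decay=cfg.weight_decay,
    )

def train_step(model, optimizer, x, y, cfg):
    # Forward pass under bfloat16 autocast when mixed precision is enabled;
    # bfloat16 needs no GradScaler, unlike float16.
    with torch.autocast(device_type=cfg.device,
                        dtype=getattr(torch, cfg.dtype),
                        enabled=cfg.mixed_precision):
        logits = model(x)
        loss = torch.nn.functional.cross_entropy(
            logits.view(-1, cfg.vocab_size), y.view(-1)
        )
    loss.backward()
    # Clip the global gradient norm to grad_clip before stepping
    torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    return loss.item()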
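
# --- Sketch (not part of the original file) ---------------------------------
# One plausible reading of the early-stopping fields: evaluate every
# eval_interval iterations (averaging over eval_iters batches) and stop once
# no evaluation has improved on the best loss by at least min_delta for
# `patience` consecutive evaluations. `val_losses` is the list of validation
# losses recorded so far; this helper is hypothetical.

def should_stop(val_losses, cfg):
    """Return True once the last `patience` evals show no real improvement."""
    if len(val_losses) <= cfg.patience:
        return False
    best_before = min(val_losses[:-cfg.patience])
    recent_best = min(val_losses[-cfg.patience:])
    # Stop if no recent eval beat the earlier best by at least min_delta
    return recent_best > best_before - cfg.min_delta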