from dataclasses import dataclass


@dataclass
class TrainConfig:
    """Hyperparameter configuration for a weight-shared (ALBERT-style) MoE transformer.

    All fields have defaults; construct with keyword overrides, e.g.
    ``TrainConfig(batch_size=8, max_steps=100_000)``.
    """

    # --- Model core ---
    vocab_size: int = 128010      # tokenizer vocab size incl. special tokens (comment said Llama3 — TODO confirm exact count)
    embedding_dim: int = 1024     # factorized embedding dim (may be smaller than d_model)
    d_model: int = 1024
    num_heads: int = 32
    num_layers: int = 24
    max_seq_len: int = 1024

    # --- Parameter sharing (ALBERT-style) ---
    share_transformer_block: bool = True   # share attention/FFN weights across layers
    share_layernorms: bool = False         # each layer keeps its own RMSNorm

    # --- Positional encoding (RoPE) ---
    rope_base: float = 10000.0
    rope_scale: float = 1.0

    # --- Attention ---
    use_flash_attention: bool = True
    qk_norm: bool = True                   # L2-normalize q and k before the dot product

    # --- FFN / MoE ---
    ffn_intermediate_dim: int = 4096       # shared-expert intermediate dim
    moe_num_experts: int = 64
    moe_top_k: int = 2                     # experts routed per token
    moe_shared_expert: bool = True         # always include a shared (dense) expert
    moe_expert_hidden_dim: int = 1024      # per-expert intermediate dim
    moe_router_dropout: float = 0.1
    # NOTE(review): the original file declared this field twice (1.0, then 1.25);
    # Python keeps the last assignment, so the effective default was 1.25.
    # Declared once here, at the original first-annotation position, preserving
    # both the __init__ field order and the effective default.
    moe_capacity_factor: float = 1.25      # capacity factor for aux-free routing
    moe_router_temperature: float = 1.0

    # --- Auxiliary-loss-free load balancing ---
    aux_free_balance: bool = True          # Expert-Choice-style router, no aux loss
    # Aux coefficients are zeroed because aux_free_balance is enabled.
    moe_router_zloss_coef: float = 0.0
    moe_load_balance_coef: float = 0.0

    # --- MoQ for the query projection (mirrors the MoE settings) ---
    moq_num_experts: int = 64
    moq_top_k: int = 2                     # experts routed per query token
    moq_shared_expert: bool = True
    moq_expert_hidden_dim: int = 1024
    moq_router_temperature: float = 1.0

    # --- Training / optimizer ---
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    betas: tuple = (0.9, 0.95)
    eps: float = 1e-8
    batch_size: int = 4                    # small micro-batch at seq_len=1024 to avoid OOM
    grad_accum_steps: int = 8              # keeps the effective batch around 32
    max_steps: int = 50000
    warmup_steps: int = 10                 # NOTE(review): unusually short warmup — confirm intended
    clip_grad_norm: float = 1.0
    mixed_precision: bool = True           # enable for GPU training

    # --- Checkpointing / logging ---
    log_interval: int = 50
    eval_interval: int = 5000
    save_interval: int = 10000             # checkpoint every 10k steps
    output_dir: str = "./outputs"

    # --- Memory optimization ---
    gradient_checkpointing: bool = True
    cpu_offload: bool = False
    # NOTE(review): redundant with use_flash_attention above — kept for
    # backward compatibility; consolidate callers before removing either.
    flash_attention: bool = True

    # --- Expert loading / caching (RTX 4090 sizing) ---
    max_loaded_experts: int = 40           # tuned for RTX 4090 memory
    expert_cache_strategy: str = "lru"     # eviction policy for the expert cache
    expert_preload_threshold: int = 5      # hit count before an expert is preloaded

    # --- Early stopping ---
    early_stopping_patience: int = 100
    early_stopping_min_delta: float = 0.0001
    early_stopping_monitor: str = "loss"

    # --- Misc ---
    seed: int = 42
    device: str = "cuda"
    dtype: str = "bfloat16"                # autocast target: float16 | bfloat16 | float32