# ChenZuoLM / config.py
# Initial model upload (commit 2f59567)
from dataclasses import dataclass
@dataclass
class TrainConfig:
    """Training configuration for the ChenZuoLM MoE transformer.

    All values are defaults and can be overridden per-instance, e.g.
    ``TrainConfig(batch_size=8)``. Fields are grouped by concern: model
    core, parameter sharing, positional encoding, attention, FFN/MoE,
    MoQ, optimizer/training, checkpointing, memory, expert caching,
    early stopping, and misc runtime settings.
    """

    # --- Model core ---
    vocab_size: int = 128010        # Llama3 tokenizer vocabulary size (incl. all special tokens)
    embedding_dim: int = 1024       # factorized embedding dim (can be < d_model)
    d_model: int = 1024
    num_heads: int = 32
    num_layers: int = 24
    max_seq_len: int = 1024

    # --- Parameter sharing (ALBERT-style) ---
    share_transformer_block: bool = True   # share attention/FFN weights across layers
    share_layernorms: bool = False         # each layer keeps its own RMSNorm

    # --- Positional encoding (RoPE) ---
    rope_base: float = 10000.0
    rope_scale: float = 1.0

    # --- Attention specifics ---
    use_flash_attention: bool = True
    qk_norm: bool = True            # L2-normalize q and k before the dot product

    # --- FFN / MoE ---
    ffn_intermediate_dim: int = 4096       # shared expert intermediate dim
    moe_num_experts: int = 64
    moe_top_k: int = 2                     # number of routed experts per token
    moe_shared_expert: bool = True         # always include a shared (always-on) expert
    moe_expert_hidden_dim: int = 1024      # routed expert intermediate dim
    moe_router_dropout: float = 0.1
    moe_router_temperature: float = 1.0

    # --- Aux-free load balancing ---
    aux_free_balance: bool = True          # auxiliary-loss-free router (Expert-Choice style)
    # NOTE: this field was previously declared twice (1.0 then 1.25); the
    # later declaration won, so the effective default has always been 1.25.
    moe_capacity_factor: float = 1.25      # capacity factor for aux-free routing
    # Aux-loss coefficients are zeroed because aux_free_balance is enabled.
    moe_router_zloss_coef: float = 0.0
    moe_load_balance_coef: float = 0.0

    # --- MoQ for query projection (mirrors MoE settings) ---
    moq_num_experts: int = 64
    moq_top_k: int = 2                     # number of routed query experts per token
    moq_shared_expert: bool = True
    moq_expert_hidden_dim: int = 1024      # MoQ expert hidden dimension
    moq_router_temperature: float = 1.0

    # --- Training / optimizer ---
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    betas: tuple = (0.9, 0.95)             # AdamW betas
    eps: float = 1e-8
    batch_size: int = 4                    # reduced for max_seq_len=1024 to prevent OOM
    grad_accum_steps: int = 8              # keeps effective batch ~32 sequences per step
    max_steps: int = 50000
    warmup_steps: int = 10
    clip_grad_norm: float = 1.0
    mixed_precision: bool = True           # enable for GPU training

    # --- Checkpointing / logging ---
    log_interval: int = 50
    eval_interval: int = 5000
    save_interval: int = 10000             # save model every 10,000 steps
    output_dir: str = "./outputs"

    # --- Memory optimization ---
    gradient_checkpointing: bool = True
    cpu_offload: bool = False
    flash_attention: bool = True

    # --- Expert loading optimization (RTX 4090) ---
    max_loaded_experts: int = 40           # tuned for RTX 4090 memory budget
    expert_cache_strategy: str = "lru"     # expert eviction policy
    expert_preload_threshold: int = 5      # preload threshold for smart caching

    # --- Early stopping ---
    early_stopping_patience: int = 100
    early_stopping_min_delta: float = 0.0001
    early_stopping_monitor: str = "loss"

    # --- Misc ---
    seed: int = 42
    device: str = "cuda"
    dtype: str = "bfloat16"                # autocast target: float16 | bfloat16 | float32