from dataclasses import dataclass
@dataclass
class TrainConfig:
    """Flat collection of model, optimizer and runtime settings for a training run.

    Every field has a default, so ``TrainConfig()`` is a complete config.
    Override individual values with keyword arguments, e.g.
    ``TrainConfig(num_layers=12, learning_rate=1e-4)``.

    No values are validated or cross-checked here; any consistency
    requirements (e.g. between head count and model width) must be enforced
    by the code that consumes this config.

    NOTE(review): some fields look redundant with each other
    (``embedding_dim`` / ``d_model``, ``use_flash_attention`` /
    ``flash_attention``). They are independent fields in this class and are
    kept as-is for backward compatibility.
    """

    # Model size.
    vocab_size: int = 128010
    embedding_dim: int = 1024
    # NOTE(review): same default as embedding_dim; confirm whether consumers
    # require the two to be equal when one is overridden.
    d_model: int = 1024
    num_heads: int = 32
    num_layers: int = 24
    max_seq_len: int = 1024

    # Parameter sharing across layers.
    share_transformer_block: bool = True
    share_layernorms: bool = False

    # rope_* settings (presumably rotary position embeddings — confirm).
    rope_base: float = 10000.0
    rope_scale: float = 1.0

    # Attention options.
    use_flash_attention: bool = True
    qk_norm: bool = True

    # Feed-forward width and moe_* expert/router settings.
    ffn_intermediate_dim: int = 4096
    moe_num_experts: int = 64
    moe_top_k: int = 2
    moe_shared_expert: bool = True
    moe_expert_hidden_dim: int = 1024
    moe_router_dropout: float = 0.1
    # Declared exactly once. A second, later declaration always shadowed an
    # earlier "= 1.0", so 1.25 is the effective default. Keeping it at this
    # position preserves the dataclass field order (positional arguments).
    moe_capacity_factor: float = 1.25
    moe_router_temperature: float = 1.0
    aux_free_balance: bool = True
    # Both auxiliary-loss coefficients default to 0.0.
    # NOTE(review): presumably disabled because aux_free_balance is on —
    # confirm against the loss computation.
    moe_router_zloss_coef: float = 0.0
    moe_load_balance_coef: float = 0.0

    # moq_* settings: a parallel set of expert/router fields with the same
    # names and defaults as their moe_* counterparts.
    moq_num_experts: int = 64
    moq_top_k: int = 2
    moq_shared_expert: bool = True
    moq_expert_hidden_dim: int = 1024
    moq_router_temperature: float = 1.0

    # Optimizer and schedule.
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    # A tuple is immutable, so it is safe as a plain dataclass default
    # (no default_factory needed).
    betas: tuple = (0.9, 0.95)
    eps: float = 1e-8
    batch_size: int = 4
    grad_accum_steps: int = 8
    max_steps: int = 50000
    # NOTE(review): 10 warmup steps is very short relative to
    # max_steps=50000; confirm this is intentional.
    warmup_steps: int = 10
    clip_grad_norm: float = 1.0
    mixed_precision: bool = True

    # Logging, evaluation and checkpoint intervals; output location.
    log_interval: int = 50
    eval_interval: int = 5000
    save_interval: int = 10000
    output_dir: str = "./outputs"

    # Memory / runtime options.
    gradient_checkpointing: bool = True
    cpu_offload: bool = False
    # NOTE(review): a separate field from use_flash_attention (both default
    # True); confirm which one consumers actually read.
    flash_attention: bool = True

    # Expert loading / caching.
    # NOTE(review): default is below moe_num_experts (64); confirm intended.
    max_loaded_experts: int = 40
    expert_cache_strategy: str = "lru"
    expert_preload_threshold: int = 5

    # Early stopping.
    # NOTE(review): the unit of patience (steps vs. evaluations) is not
    # defined here — confirm against the training loop.
    early_stopping_patience: int = 100
    early_stopping_min_delta: float = 0.0001
    early_stopping_monitor: str = "loss"

    # Reproducibility and placement.
    seed: int = 42
    device: str = "cuda"
    dtype: str = "bfloat16"