Commit
·
bc8288b
1
Parent(s):
959cbe5
MLA | KDA | TPA | GDA | ResFormer | Mamba3 | DragonMimo (WIP) | tokenshift | SeeDNorm | shrink DA/GDN | gate shared across all block types |
Browse files- configuration_dragon.py +55 -14
- inspecting_dragon.py +302 -0
- modeling_dragon.py +0 -0
- optimizers/Ademamix.py +165 -0
- optimizers/Snoo.py +67 -0
- optimizers/__init__.py +2 -0
- training_dragon.py +114 -36
configuration_dragon.py
CHANGED
|
@@ -3,6 +3,7 @@
|
|
| 3 |
# TODO : TP (cf qwen)
|
| 4 |
# TODO : init
|
| 5 |
|
|
|
|
| 6 |
import re
|
| 7 |
|
| 8 |
from transformers.configuration_utils import PretrainedConfig
|
|
@@ -89,29 +90,40 @@ class DragonConfig(PretrainedConfig):
|
|
| 89 |
model_type = "dragon"
|
| 90 |
keys_to_ignore_at_inference = ["past_key_values"]
|
| 91 |
|
| 92 |
-
"""
|
| 93 |
-
config.num_attention_heads_indexer
|
| 94 |
-
self.indexer_head_dim = config.head_dim_indexer
|
| 95 |
-
self.q_lora_rank = config.dsa_q_lora_rank
|
| 96 |
-
self.topk = config.dsa_topk
|
| 97 |
-
"""
|
| 98 |
-
|
| 99 |
def __init__(
|
| 100 |
self,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
patch_level_training: bool = False,
|
| 102 |
patch_level_training_size: int = 4,
|
| 103 |
-
nsa_head_dim: int = 128,
|
| 104 |
nsa_topk: int = 16,
|
| 105 |
nsa_block_size: int = 64,
|
| 106 |
nsa_window_size: int = 512,
|
| 107 |
-
cca_head_dim: int = 128,
|
| 108 |
cca_seq_kernel_size: int = 4,
|
| 109 |
rope_gdn: str = None,
|
| 110 |
zero_centered_gate: bool = False,
|
| 111 |
zero_centered_gate_type: int = 1,
|
| 112 |
scalable_softmax: bool = True,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
gate_attn: bool = False,
|
| 114 |
gate_gdn: bool = True,
|
|
|
|
| 115 |
num_attention_heads_gdn: int = 32,
|
| 116 |
num_key_value_heads_gdn: int = None,
|
| 117 |
fused_loss_computation=False,
|
|
@@ -129,6 +141,7 @@ class DragonConfig(PretrainedConfig):
|
|
| 129 |
intermediate_size=8192,
|
| 130 |
expand_factor=2,
|
| 131 |
layers_config=4*"lrdlr",
|
|
|
|
| 132 |
num_attention_heads=32,
|
| 133 |
num_key_value_heads=8,
|
| 134 |
mlp_hidden_act="relu2",
|
|
@@ -147,7 +160,10 @@ class DragonConfig(PretrainedConfig):
|
|
| 147 |
eos_token_id=2,
|
| 148 |
sliding_window_size=1024,
|
| 149 |
slw_wsize=-1,
|
|
|
|
|
|
|
| 150 |
rope_theta_local=163.,
|
|
|
|
| 151 |
uscaling_tau=0.2,
|
| 152 |
attention_dropout=0.,
|
| 153 |
hidden_dropout=0.,
|
|
@@ -157,21 +173,39 @@ class DragonConfig(PretrainedConfig):
|
|
| 157 |
gdn_dt_init_floor=1e-4,
|
| 158 |
gdn_A_init_range=(1, 16),
|
| 159 |
old_lns=False,
|
|
|
|
| 160 |
**kwargs,
|
| 161 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
self.patch_level_training = patch_level_training
|
| 163 |
self.patch_level_training_size = patch_level_training_size
|
| 164 |
-
self.nsa_head_dim = nsa_head_dim
|
| 165 |
self.nsa_topk = nsa_topk
|
| 166 |
self.nsa_block_size = nsa_block_size
|
| 167 |
self.nsa_window_size = nsa_window_size
|
| 168 |
-
self.cca_head_dim = cca_head_dim
|
| 169 |
self.cca_seq_kernel_size = cca_seq_kernel_size
|
| 170 |
self.rope_gdn = rope_gdn
|
| 171 |
self.zero_centered_gate = zero_centered_gate
|
| 172 |
self.zero_centered_gate_type = zero_centered_gate_type
|
|
|
|
|
|
|
| 173 |
self.gate_attn = gate_attn
|
| 174 |
self.gate_gdn = gate_gdn
|
|
|
|
|
|
|
| 175 |
self.num_attention_heads_gdn = num_attention_heads_gdn
|
| 176 |
if num_key_value_heads_gdn is None:
|
| 177 |
num_key_value_heads_gdn = num_attention_heads_gdn
|
|
@@ -182,13 +216,18 @@ class DragonConfig(PretrainedConfig):
|
|
| 182 |
self.dsa_q_lora_rank = dsa_q_lora_rank
|
| 183 |
self.dsa_topk = dsa_topk
|
| 184 |
self.zero_centered_gamma = zero_centered_gamma
|
| 185 |
-
self.
|
|
|
|
|
|
|
|
|
|
| 186 |
self.qk_norm = qk_norm
|
| 187 |
self.softcap_local_attn=softcap_local_attn
|
| 188 |
self.softcap_global_attn=softcap_global_attn
|
| 189 |
self.use_uscaling = use_uscaling
|
| 190 |
self.uscaling_tau = uscaling_tau
|
| 191 |
self.scalable_softmax = scalable_softmax
|
|
|
|
|
|
|
| 192 |
|
| 193 |
self.vocab_size = vocab_size
|
| 194 |
self.tie_word_embeddings = tie_word_embeddings
|
|
@@ -226,9 +265,11 @@ class DragonConfig(PretrainedConfig):
|
|
| 226 |
self.A_init_range = gdn_A_init_range
|
| 227 |
|
| 228 |
self.old_lns = old_lns
|
|
|
|
|
|
|
| 229 |
|
| 230 |
-
assert self.hidden_size % self.num_attention_heads == 0
|
| 231 |
-
assert self.num_attention_heads % self.num_key_value_heads == 0
|
| 232 |
#assert self.num_attention_heads % 2 == 0, "Number of attention heads must be even for differential attention."
|
| 233 |
#assert self.num_key_value_heads % 2 == 0, "Number of kv heads must be even for differential attention."
|
| 234 |
|
|
|
|
| 3 |
# TODO : TP (cf qwen)
|
| 4 |
# TODO : init
|
| 5 |
|
| 6 |
+
from typing import Optional
|
| 7 |
import re
|
| 8 |
|
| 9 |
from transformers.configuration_utils import PretrainedConfig
|
|
|
|
| 90 |
model_type = "dragon"
|
| 91 |
keys_to_ignore_at_inference = ["past_key_values"]
|
| 92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
def __init__(
|
| 94 |
self,
|
| 95 |
+
mla_kv_rank: int = 128,
|
| 96 |
+
shrink_qk_da: int = 2,
|
| 97 |
+
shrink_qk_gdn: int = 2,
|
| 98 |
+
mixer_gn: bool = True,
|
| 99 |
+
kda_allow_neg_eigval: bool = False,
|
| 100 |
+
kda_num_v_heads: Optional[int] = None,
|
| 101 |
+
seednorm_wd: bool = True,
|
| 102 |
+
normalization_type: str = "rmsnorm",
|
| 103 |
+
tpa_rank: int = 2,
|
| 104 |
+
num_signal_heads_diff: Optional[int] = None,
|
| 105 |
+
scalar_proj_as_hidden_matrix: bool = True,
|
| 106 |
+
token_shift_attn: bool = False,
|
| 107 |
+
token_shift_gdn: bool = False,
|
| 108 |
+
token_conv1d_attn: bool = False,
|
| 109 |
+
token_conv1d_gdn: bool = True,
|
| 110 |
patch_level_training: bool = False,
|
| 111 |
patch_level_training_size: int = 4,
|
|
|
|
| 112 |
nsa_topk: int = 16,
|
| 113 |
nsa_block_size: int = 64,
|
| 114 |
nsa_window_size: int = 512,
|
|
|
|
| 115 |
cca_seq_kernel_size: int = 4,
|
| 116 |
rope_gdn: str = None,
|
| 117 |
zero_centered_gate: bool = False,
|
| 118 |
zero_centered_gate_type: int = 1,
|
| 119 |
scalable_softmax: bool = True,
|
| 120 |
+
resformer: bool = False,
|
| 121 |
+
mamba_mimo_dim : int = 4,
|
| 122 |
+
gate_type: str = "elementwise",
|
| 123 |
+
gate_act: str = "silu",
|
| 124 |
gate_attn: bool = False,
|
| 125 |
gate_gdn: bool = True,
|
| 126 |
+
head_dim_gdn: Optional[int] = None,
|
| 127 |
num_attention_heads_gdn: int = 32,
|
| 128 |
num_key_value_heads_gdn: int = None,
|
| 129 |
fused_loss_computation=False,
|
|
|
|
| 141 |
intermediate_size=8192,
|
| 142 |
expand_factor=2,
|
| 143 |
layers_config=4*"lrdlr",
|
| 144 |
+
head_dim=128,
|
| 145 |
num_attention_heads=32,
|
| 146 |
num_key_value_heads=8,
|
| 147 |
mlp_hidden_act="relu2",
|
|
|
|
| 160 |
eos_token_id=2,
|
| 161 |
sliding_window_size=1024,
|
| 162 |
slw_wsize=-1,
|
| 163 |
+
rope_type_local="rope",
|
| 164 |
+
rope_type_global="",
|
| 165 |
rope_theta_local=163.,
|
| 166 |
+
rope_theta_global=10000.,
|
| 167 |
uscaling_tau=0.2,
|
| 168 |
attention_dropout=0.,
|
| 169 |
hidden_dropout=0.,
|
|
|
|
| 173 |
gdn_dt_init_floor=1e-4,
|
| 174 |
gdn_A_init_range=(1, 16),
|
| 175 |
old_lns=False,
|
| 176 |
+
mlp_linking=False,
|
| 177 |
**kwargs,
|
| 178 |
):
|
| 179 |
+
self.mla_kv_rank = mla_kv_rank
|
| 180 |
+
self.shrink_qk_da = shrink_qk_da
|
| 181 |
+
self.shrink_qk_gdn = shrink_qk_gdn
|
| 182 |
+
self.mixer_gn = mixer_gn
|
| 183 |
+
self.kda_allow_neg_eigval = kda_allow_neg_eigval
|
| 184 |
+
self.kda_num_v_heads = kda_num_v_heads
|
| 185 |
+
self.seednorm_wd = seednorm_wd
|
| 186 |
+
self.normalization_type = normalization_type
|
| 187 |
+
self.tpa_rank = tpa_rank
|
| 188 |
+
self.num_signal_heads_diff = num_signal_heads_diff
|
| 189 |
+
self.scalar_proj_as_hidden_matrix = scalar_proj_as_hidden_matrix
|
| 190 |
+
self.token_shift_attn = token_shift_attn
|
| 191 |
+
self.token_shift_gdn = token_shift_gdn
|
| 192 |
+
self.token_conv1d_attn = token_conv1d_attn
|
| 193 |
+
self.token_conv1d_gdn = token_conv1d_gdn
|
| 194 |
self.patch_level_training = patch_level_training
|
| 195 |
self.patch_level_training_size = patch_level_training_size
|
|
|
|
| 196 |
self.nsa_topk = nsa_topk
|
| 197 |
self.nsa_block_size = nsa_block_size
|
| 198 |
self.nsa_window_size = nsa_window_size
|
|
|
|
| 199 |
self.cca_seq_kernel_size = cca_seq_kernel_size
|
| 200 |
self.rope_gdn = rope_gdn
|
| 201 |
self.zero_centered_gate = zero_centered_gate
|
| 202 |
self.zero_centered_gate_type = zero_centered_gate_type
|
| 203 |
+
self.gate_type = gate_type
|
| 204 |
+
self.gate_act = gate_act
|
| 205 |
self.gate_attn = gate_attn
|
| 206 |
self.gate_gdn = gate_gdn
|
| 207 |
+
self.head_dim = head_dim
|
| 208 |
+
self.head_dim_gdn = head_dim_gdn
|
| 209 |
self.num_attention_heads_gdn = num_attention_heads_gdn
|
| 210 |
if num_key_value_heads_gdn is None:
|
| 211 |
num_key_value_heads_gdn = num_attention_heads_gdn
|
|
|
|
| 216 |
self.dsa_q_lora_rank = dsa_q_lora_rank
|
| 217 |
self.dsa_topk = dsa_topk
|
| 218 |
self.zero_centered_gamma = zero_centered_gamma
|
| 219 |
+
self.rope_type_local = rope_type_local
|
| 220 |
+
self.rope_type_global = rope_type_global
|
| 221 |
+
self.rope_theta_local = rope_theta_local
|
| 222 |
+
self.rope_theta_global = rope_theta_global
|
| 223 |
self.qk_norm = qk_norm
|
| 224 |
self.softcap_local_attn=softcap_local_attn
|
| 225 |
self.softcap_global_attn=softcap_global_attn
|
| 226 |
self.use_uscaling = use_uscaling
|
| 227 |
self.uscaling_tau = uscaling_tau
|
| 228 |
self.scalable_softmax = scalable_softmax
|
| 229 |
+
self.resformer = resformer
|
| 230 |
+
self.mamba_mimo_dim = mamba_mimo_dim
|
| 231 |
|
| 232 |
self.vocab_size = vocab_size
|
| 233 |
self.tie_word_embeddings = tie_word_embeddings
|
|
|
|
| 265 |
self.A_init_range = gdn_A_init_range
|
| 266 |
|
| 267 |
self.old_lns = old_lns
|
| 268 |
+
|
| 269 |
+
self.mlp_linking = mlp_linking
|
| 270 |
|
| 271 |
+
#assert self.hidden_size % self.num_attention_heads == 0
|
| 272 |
+
#assert self.num_attention_heads % self.num_key_value_heads == 0
|
| 273 |
#assert self.num_attention_heads % 2 == 0, "Number of attention heads must be even for differential attention."
|
| 274 |
#assert self.num_key_value_heads % 2 == 0, "Number of kv heads must be even for differential attention."
|
| 275 |
|
inspecting_dragon.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Dict, Optional
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
import json
|
| 4 |
+
import re
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from functools import partial
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
import tyro
|
| 10 |
+
|
| 11 |
+
from .configuration_dragon import DragonConfig
|
| 12 |
+
from .modeling_dragon import DragonForCausalLM
|
| 13 |
+
|
| 14 |
+
@dataclass
class NanoArgs:
    """Command-line hyper-parameters for Dragon training/inspection (parsed via tyro).

    NOTE(review): field order is part of the interface (positional construction and
    the tyro CLI); do not reorder fields.
    """
    resume_from: Optional[str] = None  # checkpoint path to resume from (None = fresh run)
    run_name : str = ""

    # arch - general
    d_model : int = 768
    n_heads : int = 6 # head dim 128 suggested by @Grad62304977
    layers_config : str = 4*"lrdlr"  # one char per block (layer-type code), length = #layers
    expand_factor : int = 1 # expand factor for Mamba/Dragon
    rope_theta_local: float = 10000.0
    eps_rmsnorm: float = 1e-6
    mlp_expand: int = 4 # expand factor for MLP
    fused_loss_computation : bool = True # whether to use fused linear + cross entropy loss
    use_uscaling: bool = False
    uscaling_tau: float = 0.2
    zero_centered_gamma: bool = False
    zero_centered_gate: bool = False
    zero_centered_gate_type: int = 1 # 1, 2, 3, 4
    gate_attn: bool = False
    gate_gdn: bool = True
    gate_type: str = "elementwise" # elementwise (one per dim), headwise (one per head)
    gate_act: str = "silu" # silu, sigmoid
    scalar_proj_as_hidden_matrix: bool = True

    # attention related
    n_kv_heads : int = 0  # 0 means "same as n_heads" (resolved when building the config)
    swa_window_size : int = 1024
    slw_warmup_iters: float = 0
    slw_start: int = 8 # window size at the start of training
    slw_increment: int = 64 # window size increment at each step
    softcap_local_attn: float = 0.0 # logit soft-capping for local attn logits, as per Gemma2 (0.0 = no soft-capping)
    softcap_global_attn: float = 0.0
    qk_norm: bool = True
    scalable_softmax: bool = True
    token_shift: bool = False
    num_attention_heads_indexer: int = 8
    head_dim_indexer: int = 32
    dsa_q_lora_rank: int = 128
    dsa_topk: int = 512
    cca_head_dim: int = 128
    cca_seq_kernel_size: int = 4
    nsa_head_dim: int = 128
    nsa_topk: int = 16
    nsa_block_size: int = 64
    nsa_window_size: int = 512

    # GDN related
    rope_gdn: Optional[str] = None # None, rope, (srope)
    n_heads_gdn: int = 0
    n_kv_heads_gdn: int = 0

    # optim
    optim: str = "adamw" # adamw, spam, stable-spam, muon, muon_moonlight, splus
    second_order_optim : Optional[str] = None #Snoo
    batch_size: int = 8*64 # batch size, in sequences, across all devices
    device_batch_size: int = 64 # batch size, in sequences, per device
    total_iterations: int = 1000 # number of iterations to run
    learning_rate: float = 1e-4
    weight_decay: float = 0.
    adam_beta1: float = 0.9
    adam_beta2: float = 0.95
    adam_eps: float = 1e-8
    warmup_iters: int = 200
    warmdown_iters: int = 3000
    grad_norm_clip: float = 1.0
    uscaling_mult_embed: float = 0
    uscaling_mult_scalar: float = 0
    uscaling_mult_head: float = 0
    init_std: float = 0.006
    patch_level_training: bool = False
    patch_level_training_size: int = 4
    patch_level_training_mode: str = "reduced" # reduced = ask L tokens, treat L//K. full = ask K*L tokens, treat L.

    # data
    vocab_size: int = 50304
    sequence_length: int = 1024
    use_patch_level_training: bool = False
    patch_size: int = 4
    patch_training_fraction: float = 0.67
    input_bin: Optional[str] = None
    input_val_bin: Optional[str] = None

    # evaluation and logging
    val_loss_every: int = 125
    val_iterations: int = 50 # 1 step = global bs * T tokens
    inspect_every: int = 0
    save_every: int = 1000
    log_dir: str = "logs/"
    wandb_project: str = "dragon_v1.5"
    wandb_name: Optional[str] = None
    log_wandb: bool = False

    load_arg_from_config: bool = True
    load_optim: bool = True
    load_sched: bool = True
    compile: bool = True

    # used during training
    slw_window: int = 0
|
| 114 |
+
|
| 115 |
+
args = tyro.cli(NanoArgs)

# load model.
# Build the HF-style config from the CLI args; names on the left are DragonConfig
# parameters, names on the right are the NanoArgs CLI fields.
config_hf = DragonConfig(
    scalar_proj_as_hidden_matrix=args.scalar_proj_as_hidden_matrix,
    token_shift=args.token_shift,
    patch_level_training=args.patch_level_training,
    patch_level_training_size=args.patch_level_training_size,
    nsa_head_dim=args.nsa_head_dim,
    nsa_topk=args.nsa_topk,
    nsa_block_size=args.nsa_block_size,
    nsa_window_size=args.nsa_window_size,
    cca_head_dim=args.cca_head_dim,
    cca_seq_kernel_size=args.cca_seq_kernel_size,
    num_attention_heads_gdn=args.n_heads_gdn,
    num_key_value_heads_gdn=args.n_kv_heads_gdn,
    zero_centered_gate=args.zero_centered_gate,
    zero_centered_gate_type=args.zero_centered_gate_type,
    scalable_softmax=args.scalable_softmax,
    gate_type=args.gate_type,
    gate_act=args.gate_act,
    gate_attn=args.gate_attn,
    gate_gdn=args.gate_gdn,
    fused_loss_computation=args.fused_loss_computation,
    qk_norm=args.qk_norm,
    num_attention_heads_indexer=args.num_attention_heads_indexer,
    head_dim_indexer=args.head_dim_indexer,
    dsa_q_lora_rank=args.dsa_q_lora_rank,
    dsa_topk=args.dsa_topk,
    zero_centered_gamma=args.zero_centered_gamma,
    vocab_size=args.vocab_size,
    max_position_embeddings=args.sequence_length,
    use_uscaling=args.use_uscaling,
    hidden_size=args.d_model,
    intermediate_size=args.d_model * args.mlp_expand,
    expand_factor=args.expand_factor,
    layers_config=args.layers_config,
    num_attention_heads=args.n_heads,
    # n_kv_heads == 0 is the CLI sentinel for "no GQA, use full MHA".
    num_key_value_heads=args.n_kv_heads if args.n_kv_heads > 0 else args.n_heads,
    initializer_range=args.init_std,
    softcap_local_attn=args.softcap_local_attn,
    softcap_global_attn=args.softcap_global_attn,
    norm_epsilon=args.eps_rmsnorm,
    use_cache=False,
    sliding_window_size=args.swa_window_size,
    rope_theta_local=args.rope_theta_local,
    uscaling_tau=args.uscaling_tau,
)

model = DragonForCausalLM(config_hf)
model = model.cuda()

# Batch size and sequence length used for the inspection forward/backward pass below.
B, L = 2, 2048
|
| 168 |
+
|
| 169 |
+
# ---------- helpers ---------- #
|
| 170 |
+
def l1(x: torch.Tensor) -> float:
    """Mean absolute value of *x*, returned as a plain Python float."""
    return float(x.abs().mean())
|
| 172 |
+
|
| 173 |
+
def _capture(name: str, store: Dict[str, torch.Tensor], _m, _inp, out):
|
| 174 |
+
"""Save every tensor produced by a module so that we can measure activations."""
|
| 175 |
+
def walk(x, suf=""):
|
| 176 |
+
if torch.is_tensor(x):
|
| 177 |
+
store[f"{name}{suf}"] = x.detach()
|
| 178 |
+
elif isinstance(x, (list, tuple)):
|
| 179 |
+
for i, xi in enumerate(x):
|
| 180 |
+
walk(xi, suf + f"[{i}]")
|
| 181 |
+
walk(out)
|
| 182 |
+
|
| 183 |
+
_stat_pat = re.compile(r"(\.grad\.(?:std|mean|l1)|\.act\.(?:std|mean|l1)|\.(?:std|mean|l1))$")
|
| 184 |
+
|
| 185 |
+
# Support multiple model naming schemes
|
| 186 |
+
_LAYER_PATTERNS = [
|
| 187 |
+
re.compile(r"\.h\.(\d+)\."), # transformer.h.<i>.
|
| 188 |
+
re.compile(r"\.layers\.(\d+)\."), # model.layers.<i>.
|
| 189 |
+
re.compile(r"\.decoder\.layers\.(\d+)\."), # decoder.layers.<i>.
|
| 190 |
+
re.compile(r"\.block\.(\d+)\."), # ...block.<i>.
|
| 191 |
+
]
|
| 192 |
+
|
| 193 |
+
def _find_layer_span_and_idx(key: str):
|
| 194 |
+
for pat in _LAYer_PATTERNS if False else _LAYER_PATTERNS: # keep exact name
|
| 195 |
+
m = pat.search(key)
|
| 196 |
+
if m:
|
| 197 |
+
return m.span(0), int(m.group(1)) # span of ".layers.<i>." and the idx
|
| 198 |
+
return None, -1
|
| 199 |
+
|
| 200 |
+
def _layer_idx(key: str) -> int:
    """Layer index embedded in *key*, or -1 when the key is not layer-scoped."""
    return _find_layer_span_and_idx(key)[1]
|
| 203 |
+
|
| 204 |
+
def _base_key(key: str) -> str:
    """Return <parameter-suffix>.<stat> without the layer index, e.g. mixer.linear_qkv.weight.std"""
    span, _ = _find_layer_span_and_idx(key)
    # Collapse the ".layers.<i>." segment (if any) to a single dot.
    reduced = key if span is None else key[: span[0]] + "." + key[span[1]:]
    # Drop common top-level prefixes.
    for prefix in ("transformer.", "model.", "module."):
        if reduced.startswith(prefix):
            reduced = reduced[len(prefix):]
    stat_match = _stat_pat.search(reduced)
    assert stat_match, f"No stat suffix in key {key}"
    stat_suffix = stat_match.group(1)
    return reduced[: -len(stat_suffix)] + stat_suffix
|
| 220 |
+
|
| 221 |
+
# ---------- main routine ---------- #
|
| 222 |
+
|
| 223 |
+
def show_layer_stats(model: nn.Module) -> Dict[str, object]:
    """Run one forward/backward pass and return aggregated weight/grad/activation stats.

    The returned mapping (JSON-serializable, e.g. for json.dump or wandb.log) is:
    {
        "inspect/attn.linear_qkv.weight.std": [layer0, layer1, ..., layerN],
        "inspect/attn.linear_qkv.grad.std" : [...],
        "inspect/attn.linear_qkv.act.std"  : [...],
        ...
    }
    Layers that do not have a value for a given statistic are represented with None.
    Non-layer parameters (e.g., embeddings) are kept flat as a single key-value pair.

    Relies on module globals ``config_hf``, ``B`` and ``L`` for vocab size, layer
    count and dummy-batch shape.
    """
    # ----- collect activations via forward hooks ----- #
    acts, hooks = {}, []
    for n, m in model.named_modules():
        if m is model:
            continue  # skip root
        hooks.append(m.register_forward_hook(partial(_capture, n, acts)))

    # One dummy step on random tokens so that gradients and activations exist.
    x = torch.randint(0, config_hf.vocab_size, (B, L), device="cuda")
    y = torch.randint(0, config_hf.vocab_size, (B, L), device="cuda")
    loss = model(input_ids=x, labels=y).loss
    loss.backward()

    # Fixed: remove the hooks once the pass is done — previously they were left
    # registered, so calling this function twice double-captured every module and
    # leaked the hook closures.
    for h in hooks:
        h.remove()

    # ----- collect stats (weight / grad / act) ----- #
    raw_stats = {}
    for n, p in model.named_parameters():
        raw_stats[f"{n}.std"] = p.std().item()
        raw_stats[f"{n}.l1"] = l1(p)
        if p.grad is not None:
            raw_stats[f"{n}.grad.std"] = p.grad.std().item()
            raw_stats[f"{n}.grad.l1"] = l1(p.grad)
    for n, a in acts.items():
        raw_stats[f"{n}.act.std"] = a.std().item()
        raw_stats[f"{n}.act.l1"] = l1(a)

    # ----- aggregate across layers ----- #
    # One slot per layer; layers_config has one character per layer.
    agg: Dict[str, List] = defaultdict(lambda: [None] * len(config_hf.layers_config))
    flat: Dict[str, float] = {}

    for key, val in raw_stats.items():
        layer = _layer_idx(key)
        if layer == -1:
            # params without layer index stay flat
            flat[key] = val
            continue
        base = _base_key(key)
        if layer < len(config_hf.layers_config):
            agg[base][layer] = val
        else:
            # unexpected layer index; fall back to flat
            flat[key] = val

    # ----- merge flat & aggregated with custom sorting ----- #
    stats = {}

    # First: per-quantity arrays over layers
    for base_key in sorted(agg.keys()):
        stats[f"inspect/{base_key}"] = agg[base_key]  # list of length = #layers (None where absent)

    # Then: non-layer ("flat") stats
    for k, v in sorted(flat.items()):
        stats[f"inspect/{k}"] = v

    return stats
|
| 295 |
+
|
| 296 |
+
filename = "layer_stats.json"

# Compute the stats once and persist them next to the other logs.
out_path = args.log_dir + filename
stats_blob = show_layer_stats(model)
with open(out_path, "w") as f:
    if stats_blob:
        json.dump(stats_blob, f, indent=2)
print(f"✅ Saved layer stats to {out_path} ✅")
|
modeling_dragon.py
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
optimizers/Ademamix.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adapted from: https://pytorch.org/docs/1.6.0/_modules/torch/optim/adam.html
|
| 3 |
+
"""
|
| 4 |
+
import math
|
| 5 |
+
import torch
|
| 6 |
+
from torch.optim import Optimizer
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def linear_warmup_scheduler(step, alpha_end, alpha_start=0, warmup=1):
    """Linearly ramp a coefficient from *alpha_start* to *alpha_end* over *warmup* steps.

    After the warmup window the final value *alpha_end* is returned unchanged.
    """
    if step >= warmup:
        return alpha_end
    frac = step / float(warmup)
    return (1.0 - frac) * alpha_start + frac * alpha_end
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, warmup=1):
    """Warm up an EMA coefficient from *beta_start* to *beta_end*.

    The interpolation is linear in half-life space rather than in beta space,
    which gives a much smoother effective-memory ramp for betas close to 1.
    """

    def beta_to_hl(beta, eps=1e-8):
        # Half-life (in steps) of an EMA with coefficient `beta`; eps guards log(0).
        return math.log(0.5) / math.log(beta + eps) - 1

    def hl_to_beta(t):
        # Inverse of beta_to_hl (ignoring eps).
        return math.pow(0.5, 1 / (t + 1))

    if step >= warmup:
        return beta_end
    a = step / float(warmup)
    return hl_to_beta((1.0 - a) * beta_to_hl(beta_start) + a * beta_to_hl(beta_end))
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class AdEMAMix(Optimizer):
    r"""Implements the AdEMAMix algorithm.

    AdEMAMix keeps two EMAs of the gradient — a fast one (beta1) and a slow one
    (beta3) — and steps along ``fast_hat + alpha * slow`` normalized by the usual
    Adam second-moment estimate.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999, 0.9999))
            corresponding to beta_1, beta_2, beta_3 in AdEMAMix
        alpha (float): AdEMAMix alpha coeficient mixing the slow and fast EMAs (default: 2)
        beta3_warmup (int, optional): number of warmup steps used to increase beta3 (default: None)
        alpha_warmup: (int, optional): number of warmup steps used to increase alpha (default: None)
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay as in AdamW (default: 0)

    NOTE(review): the signature defaults below are betas=(0.9, 0.999, 0.999) and
    alpha=8.0, which disagree with the defaults documented above — confirm which
    values are intended.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999, 0.999), alpha=8.0,
                 beta3_warmup=None, alpha_warmup=None, eps=1e-8,
                 weight_decay=0):
        # Fail fast on out-of-range hyper-parameters.
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= betas[2] < 1.0:
            raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= alpha:
            raise ValueError("Invalid alpha value: {}".format(alpha))
        defaults = dict(lr=lr, betas=betas, eps=eps, alpha=alpha, beta3_warmup=beta3_warmup,
                        alpha_warmup=alpha_warmup, weight_decay=weight_decay)
        super(AdEMAMix, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdEMAMix, self).__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:

            lr = group["lr"]
            lmbda = group["weight_decay"]
            eps = group["eps"]
            beta1, beta2, beta3_final = group["betas"]
            beta3_warmup = group["beta3_warmup"]
            alpha_final = group["alpha"]
            alpha_warmup = group["alpha_warmup"]

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('AdEMAMix does not support sparse gradients.')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    if beta1 != 0.0: # save memory in case beta1 is 0.0
                        state['exp_avg_fast'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    else:
                        state['exp_avg_fast'] = None
                    state['exp_avg_slow'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avg_fast, exp_avg_slow, exp_avg_sq = state['exp_avg_fast'], state['exp_avg_slow'], state['exp_avg_sq']

                state['step'] += 1
                # Adam-style bias corrections; only the fast EMA and the second
                # moment are bias-corrected (the slow EMA is mixed in raw).
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Compute the effective alpha and beta3 in case warmup is used
                if alpha_warmup is not None:
                    alpha = linear_warmup_scheduler(state["step"], alpha_end=alpha_final, alpha_start=0, warmup=alpha_warmup)
                else:
                    alpha = alpha_final

                if beta3_warmup is not None:
                    # beta3 ramps from beta1 to beta3_final, linearly in half-life space.
                    beta3 = linear_hl_warmup_scheduler(state["step"], beta_end=beta3_final, beta_start=beta1, warmup=beta3_warmup)
                else:
                    beta3 = beta3_final

                # Decay the first and second moment running average coefficient
                if beta1 != 0.0:
                    exp_avg_fast.mul_(beta1).add_(grad, alpha=1 - beta1)
                else:
                    # No fast EMA is kept; the raw gradient stands in for it.
                    exp_avg_fast = grad
                exp_avg_slow.mul_(beta3).add_(grad, alpha=1 - beta3)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

                # Mix bias-corrected fast EMA with the (uncorrected) slow EMA.
                update = (exp_avg_fast.div(bias_correction1) + alpha * exp_avg_slow) / denom

                # decay (AdamW-style weight decay folded into the update, scaled by lr below)
                update.add_(p, alpha=lmbda)

                p.add_(-lr * update)

        return loss
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
if __name__ == "__main__":
    # Small smoke test: 50 AdEMAMix steps on a tiny linear model; the printed
    # weights before/after should differ if the optimizer is doing anything.
    inputs = torch.randn((10, 7))
    net = torch.nn.Linear(7, 1, bias=False)
    optimizer = AdEMAMix(
        params=net.parameters(),
        lr=1e-2,
        betas=(0.9, 0.999, 0.9999),
        alpha=2.0,
        beta3_warmup=45,
        alpha_warmup=45,
        weight_decay=0.1,
    )
    print(net.weight)
    for _ in range(50):
        objective = net(inputs).mean()
        optimizer.zero_grad()
        objective.backward()
        optimizer.step()

    print(net.weight)
|
optimizers/Snoo.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Snoo:
    """
    @DominikKallusky, @vishal9-team, @vinaysrao
    Sparse Nesterov Outer Optimizer (Snoo) is a momentum-based wrapper to any optimizer that can
    improve the stability and smoothness of the optimization process and thus the quality
    of large language models (LLM) and other models. Snoo implicitly adds temporal regularization
    to the parameters, thus smoothing the training trajectory and instilling a bias towards flatter
    minima and lower parameter norms. Snoo is computationally efficient, incurring minimal overhead
    in compute and moderate memory usage.
    """

    @torch.no_grad()
    def __init__(self, model: nn.Module, lr: float, momentum: float, k: int) -> None:
        # model: module whose parameters the outer loop smooths.
        # lr / momentum: hyper-parameters of the outer Nesterov-SGD step.
        # k: outer-update interval, counted in calls to `step()`.
        self.model = model
        self.lr = lr
        self.momentum = momentum
        self.k = k
        self.current_step = 0
        # Snapshot of the parameters at the last outer update (the "slow" weights).
        self.outer_buf = [p.clone() for p in model.parameters()]
        self.model_params = list(self.model.parameters())
        # Outer optimizer: Nesterov SGD driven by the pseudo-gradient
        # (slow weights - fast weights) computed every k steps.
        self.optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=lr,
            momentum=momentum,
            nesterov=True,
            fused=True,  # NOTE(review): fused SGD requires CUDA float params — confirm deployment target
        )

    @torch.no_grad()
    def step(
        self,
    ) -> None:
        # Every k-th call: perform one outer update; all other calls only
        # advance the counter (the inner optimizer runs elsewhere).
        if self.current_step % self.k == 0:
            # Pseudo-gradient = displacement from the slow weights to the
            # current fast weights, then rewind the parameters to the slow
            # weights so the outer SGD step is taken from that point.
            for p_new, p_old in zip(self.model_params, self.outer_buf):
                p_new.grad = p_old.data - p_new.data
                p_new.copy_(p_old, non_blocking=True)

            self.optimizer.step()

            # The stepped parameters become the new slow weights.
            for p_new, p_old in zip(self.model_params, self.outer_buf):
                p_old.copy_(p_new, non_blocking=True)
        self.current_step += 1

    def state_dict(self):
        # Serialize counters, hyper-parameters, the slow-weight buffer, and
        # the inner SGD state so training can resume bit-exactly.
        state_dict = {
            "current_step": self.current_step,
            "lr": self.lr,
            "momentum": self.momentum,
            "k": self.k,
            "outer_buf": [p.clone() for p in self.outer_buf],
            "optimizer_state_dict": self.optimizer.state_dict(),
        }
        return state_dict

    def load_state_dict(self, state_dict):
        # Restore state produced by `state_dict`; the slow weights are
        # copied in place so existing tensor storage (and device) is kept.
        self.current_step = state_dict["current_step"]
        self.lr = state_dict["lr"]
        self.momentum = state_dict["momentum"]
        self.k = state_dict["k"]
        for p_src, p_dst in zip(state_dict["outer_buf"], self.outer_buf):
            p_dst.copy_(p_src)
        self.optimizer.load_state_dict(state_dict["optimizer_state_dict"])
|
optimizers/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .Ademamix import AdEMAMix
|
| 2 |
+
from .Snoo import Snoo
|
training_dragon.py
CHANGED
|
@@ -32,9 +32,13 @@ class NanoArgs:
|
|
| 32 |
# arch - general
|
| 33 |
d_model : int = 768
|
| 34 |
n_heads : int = 6 # head dim 128 suggested by @Grad62304977
|
|
|
|
| 35 |
layers_config : str = 4*"lrdlr"
|
| 36 |
-
expand_factor : int =
|
|
|
|
|
|
|
| 37 |
rope_theta_local: float = 10000.0
|
|
|
|
| 38 |
eps_rmsnorm: float = 1e-6
|
| 39 |
mlp_expand: int = 4 # expand factor for MLP
|
| 40 |
fused_loss_computation : bool = True # whether to use fused linear + cross entropy loss
|
|
@@ -42,9 +46,16 @@ class NanoArgs:
|
|
| 42 |
uscaling_tau: float = 0.2
|
| 43 |
zero_centered_gamma: bool = False
|
| 44 |
zero_centered_gate: bool = False
|
| 45 |
-
zero_centered_gate_type: int = 1 # 1, 2, 3
|
| 46 |
gate_attn: bool = False
|
| 47 |
gate_gdn: bool = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
# attention related
|
| 50 |
n_kv_heads : int = 0
|
|
@@ -56,24 +67,36 @@ class NanoArgs:
|
|
| 56 |
softcap_global_attn: float = 0.0
|
| 57 |
qk_norm: bool = True
|
| 58 |
scalable_softmax: bool = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
num_attention_heads_indexer: int = 8
|
| 60 |
head_dim_indexer: int = 32
|
| 61 |
dsa_q_lora_rank: int = 128
|
| 62 |
dsa_topk: int = 512
|
| 63 |
-
cca_head_dim: int = 128
|
| 64 |
cca_seq_kernel_size: int = 4
|
| 65 |
-
nsa_head_dim: int = 128
|
| 66 |
nsa_topk: int = 16
|
| 67 |
nsa_block_size: int = 64
|
| 68 |
nsa_window_size: int = 512
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
# GDN related
|
| 71 |
rope_gdn: Optional[str] = None # None, rope, (srope)
|
|
|
|
| 72 |
n_heads_gdn: int = 0
|
| 73 |
n_kv_heads_gdn: int = 0
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
# optim
|
| 76 |
optim: str = "adamw" # adamw, spam, stable-spam, muon, muon_moonlight, splus
|
|
|
|
| 77 |
batch_size: int = 8*64 # batch size, in sequences, across all devices
|
| 78 |
device_batch_size: int = 64 # batch size, in sequences, per device
|
| 79 |
total_iterations: int = 1000 # number of iterations to run
|
|
@@ -91,14 +114,13 @@ class NanoArgs:
|
|
| 91 |
init_std: float = 0.006
|
| 92 |
patch_level_training: bool = False
|
| 93 |
patch_level_training_size: int = 4
|
| 94 |
-
|
|
|
|
|
|
|
| 95 |
|
| 96 |
# data
|
| 97 |
vocab_size: int = 50304
|
| 98 |
sequence_length: int = 1024
|
| 99 |
-
use_patch_level_training: bool = False
|
| 100 |
-
patch_size: int = 4
|
| 101 |
-
patch_training_fraction: float = 0.67
|
| 102 |
input_bin: Optional[str] = None
|
| 103 |
input_val_bin: Optional[str] = None
|
| 104 |
|
|
@@ -213,11 +235,15 @@ def param_groups_mup(model, base_lr_hidden, base_lr_scalar, base_lr_embed, base_
|
|
| 213 |
for mod in model.modules():
|
| 214 |
if isinstance(mod, nn.Linear):
|
| 215 |
pname = id2name.get(id(mod.weight), "")
|
|
|
|
| 216 |
fan_in = mod.weight.shape[1]
|
| 217 |
scale = 1 / math.sqrt(fan_in)
|
| 218 |
if "lm_head" in pname:
|
| 219 |
lr_scaled = base_lr_head
|
| 220 |
wd_scaled = 0.0
|
|
|
|
|
|
|
|
|
|
| 221 |
else:
|
| 222 |
lr_scaled = base_lr_hidden * scale
|
| 223 |
wd_scaled = wd / lr_scaled
|
|
@@ -226,7 +252,7 @@ def param_groups_mup(model, base_lr_hidden, base_lr_scalar, base_lr_embed, base_
|
|
| 226 |
seen.add(mod.weight)
|
| 227 |
|
| 228 |
if mod.bias is not None:
|
| 229 |
-
groups.append({"params": [mod.bias], "lr":
|
| 230 |
seen.add(mod.bias)
|
| 231 |
|
| 232 |
for p in model.parameters():
|
|
@@ -235,13 +261,17 @@ def param_groups_mup(model, base_lr_hidden, base_lr_scalar, base_lr_embed, base_
|
|
| 235 |
pname = id2name.get(id(p), "<unnamed>")
|
| 236 |
|
| 237 |
if "embedding" in pname:
|
| 238 |
-
fan_out = p.shape[1] # nn.Embedding is transposed
|
| 239 |
#lr_scaled = base_lr / math.sqrt(fan_out) # u-muP
|
| 240 |
lr_scaled = base_lr_embed
|
| 241 |
else:
|
| 242 |
lr_scaled = base_lr_scalar
|
| 243 |
|
| 244 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
|
| 246 |
return groups
|
| 247 |
|
|
@@ -299,11 +329,13 @@ if master_process:
|
|
| 299 |
with open(f'{logdir}/args.json', 'w') as f: json.dump(vars(args), f)
|
| 300 |
with open(f'{logdir}/args.pkl', 'wb') as f: pickle.dump(args, f)
|
| 301 |
def print0(s, console=True):
|
| 302 |
-
if master_process:
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
|
|
|
|
|
|
| 307 |
if resume_dir is not None and args.load_arg_from_config:
|
| 308 |
saved_args_path = os.path.join(os.path.dirname(resume_dir), "args.pkl")
|
| 309 |
print0(f"Loading args from {saved_args_path}")
|
|
@@ -326,16 +358,14 @@ np.random.seed(seed)
|
|
| 326 |
|
| 327 |
# define convenience variables.
|
| 328 |
B, T = args.device_batch_size, args.sequence_length
|
|
|
|
|
|
|
| 329 |
assert args.batch_size % (B * ddp_world_size) == 0
|
| 330 |
accumulation_steps = args.batch_size // (B * ddp_world_size)
|
| 331 |
|
| 332 |
# load dataloaders.
|
| 333 |
-
if args.patch_level_training:
|
| 334 |
-
|
| 335 |
-
assert T % args.patch_level_training_size == 0, "sequence length must be divisible by patch level training size in reduced mode"
|
| 336 |
-
T = T
|
| 337 |
-
elif args.patch_level_training_mode == "full":
|
| 338 |
-
T = T * args.patch_level_training_size
|
| 339 |
train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size)
|
| 340 |
val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size)
|
| 341 |
print0(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
|
|
@@ -343,19 +373,38 @@ print0(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total}
|
|
| 343 |
|
| 344 |
# load model.
|
| 345 |
config_hf = DragonConfig(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
patch_level_training=args.patch_level_training,
|
| 347 |
patch_level_training_size=args.patch_level_training_size,
|
| 348 |
-
nsa_head_dim=args.nsa_head_dim,
|
| 349 |
nsa_topk=args.nsa_topk,
|
| 350 |
nsa_block_size=args.nsa_block_size,
|
| 351 |
nsa_window_size=args.nsa_window_size,
|
| 352 |
-
cca_head_dim=args.cca_head_dim,
|
| 353 |
cca_seq_kernel_size=args.cca_seq_kernel_size,
|
|
|
|
|
|
|
| 354 |
num_attention_heads_gdn=args.n_heads_gdn,
|
| 355 |
num_key_value_heads_gdn=args.n_kv_heads_gdn,
|
| 356 |
zero_centered_gate=args.zero_centered_gate,
|
| 357 |
zero_centered_gate_type=args.zero_centered_gate_type,
|
| 358 |
scalable_softmax=args.scalable_softmax,
|
|
|
|
|
|
|
|
|
|
| 359 |
gate_attn=args.gate_attn,
|
| 360 |
gate_gdn=args.gate_gdn,
|
| 361 |
fused_loss_computation=args.fused_loss_computation,
|
|
@@ -380,15 +429,19 @@ config_hf = DragonConfig(
|
|
| 380 |
norm_epsilon=args.eps_rmsnorm,
|
| 381 |
use_cache=False,
|
| 382 |
sliding_window_size=args.swa_window_size,
|
|
|
|
|
|
|
|
|
|
| 383 |
rope_theta_local=args.rope_theta_local,
|
| 384 |
uscaling_tau=args.uscaling_tau,
|
|
|
|
| 385 |
)
|
| 386 |
|
| 387 |
if resume_dir is None:
|
| 388 |
model = DragonForCausalLM(config_hf)
|
| 389 |
model = model.cuda()
|
| 390 |
else:
|
| 391 |
-
model = DragonForCausalLM.from_pretrained(resume_dir, torch_dtype=torch.bfloat16)
|
| 392 |
model = model.cuda()
|
| 393 |
print0(model)
|
| 394 |
|
|
@@ -421,12 +474,13 @@ print0(f"number of total parameters: {num_params}")
|
|
| 421 |
uncompiled_model = model
|
| 422 |
model = torch.compile(model, dynamic=True) if args.compile else model
|
| 423 |
model.train()
|
| 424 |
-
model = DDP(model, device_ids=[ddp_local_rank])
|
| 425 |
raw_model = model.module
|
| 426 |
ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)
|
| 427 |
|
| 428 |
# load optimizers & schedulers.
|
| 429 |
if args.use_uscaling:
|
|
|
|
| 430 |
param_list = param_groups_mup(
|
| 431 |
raw_model,
|
| 432 |
base_lr_hidden=args.learning_rate,
|
|
@@ -435,9 +489,30 @@ if args.use_uscaling:
|
|
| 435 |
base_lr_head=args.uscaling_mult_head*args.learning_rate if args.uscaling_mult_head > 0 else args.learning_rate,
|
| 436 |
wd=args.weight_decay,
|
| 437 |
)
|
| 438 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
else:
|
| 440 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 441 |
optimizers = [optimizer]
|
| 442 |
|
| 443 |
def get_lr_wsd(num_iterations, warmup_iters, warmdown_iters, it):
|
|
@@ -478,12 +553,13 @@ WARMUP_SKIP = 10
|
|
| 478 |
|
| 479 |
# begin training.
|
| 480 |
train_loader.reset()
|
| 481 |
-
tokenizer = transformers.AutoTokenizer.from_pretrained("openai-community/gpt2", use_fast=True) # for saving
|
|
|
|
| 482 |
x, y = train_loader.next_batch()
|
| 483 |
|
| 484 |
-
for iter_ in range(start_iter, args.total_iterations+1):
|
| 485 |
-
last_iter = (iter_ == args.total_iterations)
|
| 486 |
-
if iter_ == WARMUP_SKIP:
|
| 487 |
training_time_ms = 0
|
| 488 |
t0 = time.perf_counter()
|
| 489 |
to_log = {}
|
|
@@ -521,7 +597,7 @@ for iter_ in range(start_iter, args.total_iterations+1):
|
|
| 521 |
model.train()
|
| 522 |
|
| 523 |
# log.
|
| 524 |
-
print0(f'iteration:{iter_:0{len(str(args.total_iterations))}d}/{args.total_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms')
|
| 525 |
if master_process:
|
| 526 |
wandb.log({"val_loss": val_loss}, step=iter_)
|
| 527 |
|
|
@@ -530,7 +606,7 @@ for iter_ in range(start_iter, args.total_iterations+1):
|
|
| 530 |
t0 = time.perf_counter()
|
| 531 |
|
| 532 |
# ----------- SAVING SECTION -----------
|
| 533 |
-
if master_process and
|
| 534 |
# stop the clock.
|
| 535 |
torch.cuda.synchronize()
|
| 536 |
training_time_ms += 1000 * (time.perf_counter() - t0)
|
|
@@ -584,14 +660,16 @@ for iter_ in range(start_iter, args.total_iterations+1):
|
|
| 584 |
for opt, sched in zip(optimizers, schedulers):
|
| 585 |
opt.step()
|
| 586 |
sched.step()
|
|
|
|
|
|
|
| 587 |
# null those gradients.
|
| 588 |
model.zero_grad(set_to_none=True)
|
| 589 |
|
| 590 |
# ----------- LOGGING SECTION -----------
|
| 591 |
approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
|
| 592 |
-
avg_step_time = approx_training_time_ms / (iter_ + 1 - WARMUP_SKIP) if iter_ >= WARMUP_SKIP else 0
|
| 593 |
extra = " ".join(f"{k}:{v}" for k, v in (to_log or {}).items())
|
| 594 |
-
print0(f"iteration:{iter_+1:0{len(str(args.total_iterations))}d}/{args.total_iterations} train_loss:{train_loss.item():.4f} lr: {schedulers[0].get_last_lr()[0]:.4f} train_time:{approx_training_time_ms:.0f}ms step_avg:{avg_step_time:.2f}ms {extra}")
|
| 595 |
if master_process:
|
| 596 |
wandb.log({'train_loss': train_loss.item(), 'step_avg_time': avg_step_time, **{f'lr_{i}': sched.get_last_lr()[0] for i, sched in enumerate(schedulers)}, 'grad_norm': grad_norm.item(), **to_log}, step=iter_)
|
| 597 |
|
|
|
|
| 32 |
# arch - general
|
| 33 |
d_model : int = 768
|
| 34 |
n_heads : int = 6 # head dim 128 suggested by @Grad62304977
|
| 35 |
+
head_dim: Optional[int] = None
|
| 36 |
layers_config : str = 4*"lrdlr"
|
| 37 |
+
expand_factor : int = 2 # expand factor for Mamba/Dragon
|
| 38 |
+
rope_type_local: str = "rope" #p-rope
|
| 39 |
+
rope_type_global: str = "rope" #p-rope
|
| 40 |
rope_theta_local: float = 10000.0
|
| 41 |
+
rope_theta_global: float = 0.0
|
| 42 |
eps_rmsnorm: float = 1e-6
|
| 43 |
mlp_expand: int = 4 # expand factor for MLP
|
| 44 |
fused_loss_computation : bool = True # whether to use fused linear + cross entropy loss
|
|
|
|
| 46 |
uscaling_tau: float = 0.2
|
| 47 |
zero_centered_gamma: bool = False
|
| 48 |
zero_centered_gate: bool = False
|
| 49 |
+
zero_centered_gate_type: int = 1 # 1, 2, 3, 4
|
| 50 |
gate_attn: bool = False
|
| 51 |
gate_gdn: bool = True
|
| 52 |
+
gate_type: str = "elementwise" # elementwise (one per dim), headwise (one per head), kimi (lora)
|
| 53 |
+
gate_act: str = "silu" # silu, sigmoid
|
| 54 |
+
scalar_proj_as_hidden_matrix: bool = True
|
| 55 |
+
normalization_type: str = "rmsnorm" # rmsnorm, seednorm
|
| 56 |
+
seednorm_wd: bool = True
|
| 57 |
+
mixer_gn: bool = True
|
| 58 |
+
mlp_linking : bool = False
|
| 59 |
|
| 60 |
# attention related
|
| 61 |
n_kv_heads : int = 0
|
|
|
|
| 67 |
softcap_global_attn: float = 0.0
|
| 68 |
qk_norm: bool = True
|
| 69 |
scalable_softmax: bool = True
|
| 70 |
+
resformer : bool = False # Works only on f layers (DiffAttention)
|
| 71 |
+
token_shift_attn: bool = False
|
| 72 |
+
token_shift_gdn: bool = False
|
| 73 |
+
token_conv1d_attn: bool = False
|
| 74 |
+
token_conv1d_gdn: bool = True
|
| 75 |
num_attention_heads_indexer: int = 8
|
| 76 |
head_dim_indexer: int = 32
|
| 77 |
dsa_q_lora_rank: int = 128
|
| 78 |
dsa_topk: int = 512
|
|
|
|
| 79 |
cca_seq_kernel_size: int = 4
|
|
|
|
| 80 |
nsa_topk: int = 16
|
| 81 |
nsa_block_size: int = 64
|
| 82 |
nsa_window_size: int = 512
|
| 83 |
+
num_signal_heads_diff: Optional[int] = None
|
| 84 |
+
tpa_rank: int = 2
|
| 85 |
+
shrink_qk_da: int = 2
|
| 86 |
+
mla_kv_rank: int = 128
|
| 87 |
|
| 88 |
# GDN related
|
| 89 |
rope_gdn: Optional[str] = None # None, rope, (srope)
|
| 90 |
+
head_dim_gdn: Optional[int] = None
|
| 91 |
n_heads_gdn: int = 0
|
| 92 |
n_kv_heads_gdn: int = 0
|
| 93 |
+
shrink_qk_gdn: int = 2
|
| 94 |
+
kda_allow_neg_eigval: bool = False
|
| 95 |
+
kda_num_v_heads: Optional[int] = None
|
| 96 |
|
| 97 |
# optim
|
| 98 |
optim: str = "adamw" # adamw, spam, stable-spam, muon, muon_moonlight, splus
|
| 99 |
+
second_order_optim : Optional[str] = None # snoo
|
| 100 |
batch_size: int = 8*64 # batch size, in sequences, across all devices
|
| 101 |
device_batch_size: int = 64 # batch size, in sequences, per device
|
| 102 |
total_iterations: int = 1000 # number of iterations to run
|
|
|
|
| 114 |
init_std: float = 0.006
|
| 115 |
patch_level_training: bool = False
|
| 116 |
patch_level_training_size: int = 4
|
| 117 |
+
second_order_lr: float = 0.68
|
| 118 |
+
second_order_momentum: float = 0.37
|
| 119 |
+
second_order_interval: int = 25
|
| 120 |
|
| 121 |
# data
|
| 122 |
vocab_size: int = 50304
|
| 123 |
sequence_length: int = 1024
|
|
|
|
|
|
|
|
|
|
| 124 |
input_bin: Optional[str] = None
|
| 125 |
input_val_bin: Optional[str] = None
|
| 126 |
|
|
|
|
| 235 |
for mod in model.modules():
|
| 236 |
if isinstance(mod, nn.Linear):
|
| 237 |
pname = id2name.get(id(mod.weight), "")
|
| 238 |
+
is_scalar = getattr(mod, "is_scalar_weight", False)
|
| 239 |
fan_in = mod.weight.shape[1]
|
| 240 |
scale = 1 / math.sqrt(fan_in)
|
| 241 |
if "lm_head" in pname:
|
| 242 |
lr_scaled = base_lr_head
|
| 243 |
wd_scaled = 0.0
|
| 244 |
+
elif is_scalar:
|
| 245 |
+
lr_scaled = base_lr_scalar
|
| 246 |
+
wd_scaled = 0.0
|
| 247 |
else:
|
| 248 |
lr_scaled = base_lr_hidden * scale
|
| 249 |
wd_scaled = wd / lr_scaled
|
|
|
|
| 252 |
seen.add(mod.weight)
|
| 253 |
|
| 254 |
if mod.bias is not None:
|
| 255 |
+
groups.append({"params": [mod.bias], "lr": base_lr_scalar, "weight_decay": 0.0})
|
| 256 |
seen.add(mod.bias)
|
| 257 |
|
| 258 |
for p in model.parameters():
|
|
|
|
| 261 |
pname = id2name.get(id(p), "<unnamed>")
|
| 262 |
|
| 263 |
if "embedding" in pname:
|
| 264 |
+
#fan_out = p.shape[1] # nn.Embedding is transposed
|
| 265 |
#lr_scaled = base_lr / math.sqrt(fan_out) # u-muP
|
| 266 |
lr_scaled = base_lr_embed
|
| 267 |
else:
|
| 268 |
lr_scaled = base_lr_scalar
|
| 269 |
|
| 270 |
+
wd_scaled = 0.
|
| 271 |
+
if getattr(p, "requires_weight_decay", False):
|
| 272 |
+
wd_scaled = wd / lr_scaled
|
| 273 |
+
|
| 274 |
+
groups.append({"params": [p], "lr": lr_scaled, "weight_decay": wd_scaled})
|
| 275 |
|
| 276 |
return groups
|
| 277 |
|
|
|
|
| 329 |
with open(f'{logdir}/args.json', 'w') as f: json.dump(vars(args), f)
|
| 330 |
with open(f'{logdir}/args.pkl', 'wb') as f: pickle.dump(args, f)
|
| 331 |
def print0(s, console=True):
    """Rank-0 logging helper: echo *s* to stdout (if `console`) and append it to `logfile`.

    No-op on non-master ranks. File logging is best-effort: filesystem
    errors are swallowed so a logging failure can never crash training.
    """
    if not master_process:
        return
    if console:
        print(s)
    try:
        log_dir = os.path.dirname(logfile)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        with open(logfile, "a", encoding="utf-8") as f:
            print(s, file=f)
    except OSError:
        # Deliberate best-effort swallow, but narrowed from the original
        # bare `except:` which also caught KeyboardInterrupt/SystemExit.
        pass
|
| 339 |
if resume_dir is not None and args.load_arg_from_config:
|
| 340 |
saved_args_path = os.path.join(os.path.dirname(resume_dir), "args.pkl")
|
| 341 |
print0(f"Loading args from {saved_args_path}")
|
|
|
|
| 358 |
|
| 359 |
# define convenience variables.
|
| 360 |
B, T = args.device_batch_size, args.sequence_length
|
| 361 |
+
if args.patch_level_training:
|
| 362 |
+
T = args.patch_level_training_size * T
|
| 363 |
assert args.batch_size % (B * ddp_world_size) == 0
|
| 364 |
accumulation_steps = args.batch_size // (B * ddp_world_size)
|
| 365 |
|
| 366 |
# load dataloaders.
|
| 367 |
+
#if args.patch_level_training:
|
| 368 |
+
# assert T % args.patch_level_training_size == 0, "sequence length must be divisible by patch level training size in reduced mode"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size)
|
| 370 |
val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size)
|
| 371 |
print0(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
|
|
|
|
| 373 |
|
| 374 |
# load model.
|
| 375 |
config_hf = DragonConfig(
|
| 376 |
+
mla_kv_rank=args.mla_kv_rank,
|
| 377 |
+
rope_gdn=args.rope_gdn,
|
| 378 |
+
shrink_qk_da=args.shrink_qk_da,
|
| 379 |
+
shrink_qk_gdn=args.shrink_qk_gdn,
|
| 380 |
+
mixer_gn=args.mixer_gn,
|
| 381 |
+
kda_allow_neg_eigval=args.kda_allow_neg_eigval,
|
| 382 |
+
kda_num_v_heads=args.kda_num_v_heads,
|
| 383 |
+
seednorm_wd=args.seednorm_wd,
|
| 384 |
+
normalization_type=args.normalization_type,
|
| 385 |
+
tpa_rank=args.tpa_rank,
|
| 386 |
+
num_signal_heads_diff=args.num_signal_heads_diff,
|
| 387 |
+
scalar_proj_as_hidden_matrix=args.scalar_proj_as_hidden_matrix,
|
| 388 |
+
token_shift_attn=args.token_shift_attn,
|
| 389 |
+
token_shift_gdn=args.token_shift_gdn,
|
| 390 |
+
token_conv1d_attn=args.token_conv1d_attn,
|
| 391 |
+
token_conv1d_gdn=args.token_conv1d_gdn,
|
| 392 |
patch_level_training=args.patch_level_training,
|
| 393 |
patch_level_training_size=args.patch_level_training_size,
|
|
|
|
| 394 |
nsa_topk=args.nsa_topk,
|
| 395 |
nsa_block_size=args.nsa_block_size,
|
| 396 |
nsa_window_size=args.nsa_window_size,
|
|
|
|
| 397 |
cca_seq_kernel_size=args.cca_seq_kernel_size,
|
| 398 |
+
head_dim=args.head_dim,
|
| 399 |
+
head_dim_gdn=args.head_dim_gdn,
|
| 400 |
num_attention_heads_gdn=args.n_heads_gdn,
|
| 401 |
num_key_value_heads_gdn=args.n_kv_heads_gdn,
|
| 402 |
zero_centered_gate=args.zero_centered_gate,
|
| 403 |
zero_centered_gate_type=args.zero_centered_gate_type,
|
| 404 |
scalable_softmax=args.scalable_softmax,
|
| 405 |
+
resformer=args.resformer,
|
| 406 |
+
gate_type=args.gate_type,
|
| 407 |
+
gate_act=args.gate_act,
|
| 408 |
gate_attn=args.gate_attn,
|
| 409 |
gate_gdn=args.gate_gdn,
|
| 410 |
fused_loss_computation=args.fused_loss_computation,
|
|
|
|
| 429 |
norm_epsilon=args.eps_rmsnorm,
|
| 430 |
use_cache=False,
|
| 431 |
sliding_window_size=args.swa_window_size,
|
| 432 |
+
rope_type_global=args.rope_type_global,
|
| 433 |
+
rope_type_local=args.rope_type_local,
|
| 434 |
+
rope_theta_global=args.rope_theta_global,
|
| 435 |
rope_theta_local=args.rope_theta_local,
|
| 436 |
uscaling_tau=args.uscaling_tau,
|
| 437 |
+
mlp_linking=args.mlp_linking
|
| 438 |
)
|
| 439 |
|
| 440 |
if resume_dir is None:
|
| 441 |
model = DragonForCausalLM(config_hf)
|
| 442 |
model = model.cuda()
|
| 443 |
else:
|
| 444 |
+
model = DragonForCausalLM.from_pretrained(resume_dir, config=config_hf, torch_dtype=torch.bfloat16)
|
| 445 |
model = model.cuda()
|
| 446 |
print0(model)
|
| 447 |
|
|
|
|
| 474 |
uncompiled_model = model
|
| 475 |
model = torch.compile(model, dynamic=True) if args.compile else model
|
| 476 |
model.train()
|
| 477 |
+
model = DDP(model, device_ids=[ddp_local_rank], find_unused_parameters=args.resformer)
|
| 478 |
raw_model = model.module
|
| 479 |
ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)
|
| 480 |
|
| 481 |
# load optimizers & schedulers.
|
| 482 |
if args.use_uscaling:
|
| 483 |
+
#assert args.optim == "adamw", "uscaling is only supported with AdamW optimizer currently"
|
| 484 |
param_list = param_groups_mup(
|
| 485 |
raw_model,
|
| 486 |
base_lr_hidden=args.learning_rate,
|
|
|
|
| 489 |
base_lr_head=args.uscaling_mult_head*args.learning_rate if args.uscaling_mult_head > 0 else args.learning_rate,
|
| 490 |
wd=args.weight_decay,
|
| 491 |
)
|
| 492 |
+
if args.optim == "adamw":
|
| 493 |
+
optimizer = torch.optim.AdamW(param_list, betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps)
|
| 494 |
+
elif args.optim == "ademamix":
|
| 495 |
+
from .optimizers.Ademamix import AdEMAMix
|
| 496 |
+
beta3_warmup = alpha_warmup = args.total_iterations
|
| 497 |
+
optimizer = AdEMAMix(param_list, beta3_warmup=beta3_warmup, alpha_warmup=alpha_warmup, weight_decay=args.weight_decay)
|
| 498 |
+
else:
|
| 499 |
+
raise ValueError(f"Unknown optimizer for unit scaling: {args.optim}")
|
| 500 |
else:
|
| 501 |
+
if args.optim == "adamw":
|
| 502 |
+
optimizer = torch.optim.AdamW(raw_model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps)
|
| 503 |
+
elif args.optim == "ademamix":
|
| 504 |
+
from .optimizers.Ademamix import AdEMAMix
|
| 505 |
+
|
| 506 |
+
beta3_warmup = alpha_warmup = args.total_iterations
|
| 507 |
+
optimizer = AdEMAMix(raw_model.parameters(), lr=args.learning_rate, beta3_warmup=beta3_warmup, alpha_warmup=alpha_warmup, weight_decay=args.weight_decay)
|
| 508 |
+
else:
|
| 509 |
+
raise ValueError(f"Unknown Optimizer: {args.optim}")
|
| 510 |
+
if args.second_order_optim == "snoo":
|
| 511 |
+
from .optimizers.Snoo import Snoo
|
| 512 |
+
second_order_optim = Snoo(raw_model, lr=args.second_order_lr, momentum=args.second_order_momentum, k=args.second_order_interval)
|
| 513 |
+
else:
|
| 514 |
+
second_order_optim = None
|
| 515 |
+
|
| 516 |
optimizers = [optimizer]
|
| 517 |
|
| 518 |
def get_lr_wsd(num_iterations, warmup_iters, warmdown_iters, it):
|
|
|
|
| 553 |
|
| 554 |
# begin training.
|
| 555 |
train_loader.reset()
|
| 556 |
+
#tokenizer = transformers.AutoTokenizer.from_pretrained("openai-community/gpt2", use_fast=True) # for saving
|
| 557 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained("/leonardo_work/BOOST_LCustodi/script/training/temp/hf_models/gpt2", use_fast=True)
|
| 558 |
x, y = train_loader.next_batch()
|
| 559 |
|
| 560 |
+
for iter_ in range(start_iter, start_iter+args.total_iterations+1):
|
| 561 |
+
last_iter = (iter_ == start_iter+args.total_iterations)
|
| 562 |
+
if iter_ == start_iter+WARMUP_SKIP:
|
| 563 |
training_time_ms = 0
|
| 564 |
t0 = time.perf_counter()
|
| 565 |
to_log = {}
|
|
|
|
| 597 |
model.train()
|
| 598 |
|
| 599 |
# log.
|
| 600 |
+
print0(f'iteration:{iter_:0{len(str(start_iter+args.total_iterations))}d}/{args.total_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms')
|
| 601 |
if master_process:
|
| 602 |
wandb.log({"val_loss": val_loss}, step=iter_)
|
| 603 |
|
|
|
|
| 606 |
t0 = time.perf_counter()
|
| 607 |
|
| 608 |
# ----------- SAVING SECTION -----------
|
| 609 |
+
if master_process and (last_iter or (args.save_every > 0 and iter_ % args.save_every == 0)):
|
| 610 |
# stop the clock.
|
| 611 |
torch.cuda.synchronize()
|
| 612 |
training_time_ms += 1000 * (time.perf_counter() - t0)
|
|
|
|
| 660 |
for opt, sched in zip(optimizers, schedulers):
|
| 661 |
opt.step()
|
| 662 |
sched.step()
|
| 663 |
+
if second_order_optim:
|
| 664 |
+
second_order_optim.step()
|
| 665 |
# null those gradients.
|
| 666 |
model.zero_grad(set_to_none=True)
|
| 667 |
|
| 668 |
# ----------- LOGGING SECTION -----------
|
| 669 |
approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
|
| 670 |
+
avg_step_time = approx_training_time_ms / (iter_ + 1 - WARMUP_SKIP) if iter_ >= start_iter+WARMUP_SKIP else 0
|
| 671 |
extra = " ".join(f"{k}:{v}" for k, v in (to_log or {}).items())
|
| 672 |
+
print0(f"iteration:{iter_+1:0{len(str(start_iter+args.total_iterations))}d}/{args.total_iterations} train_loss:{train_loss.item():.4f} lr: {schedulers[0].get_last_lr()[0]:.4f} train_time:{approx_training_time_ms:.0f}ms step_avg:{avg_step_time:.2f}ms {extra}")
|
| 673 |
if master_process:
|
| 674 |
wandb.log({'train_loss': train_loss.item(), 'step_avg_time': avg_step_time, **{f'lr_{i}': sched.get_last_lr()[0] for i, sched in enumerate(schedulers)}, 'grad_norm': grad_norm.item(), **to_log}, step=iter_)
|
| 675 |
|