Commit
·
d79da9a
1
Parent(s):
3b164a1
alpha normalize ademamix | mamba norms and gate | VWN | wnorm (nemotron-flash) | MG equivalence | fix IDM config saving | CCAv2 | MoBA | reduce lm head
Browse files- compute_loss.py +422 -0
- configuration_dragon.py +22 -0
- modeling_dragon.py +934 -42
- optimizers/Ademamix.py +4 -1
- training_dragon.py +151 -42
compute_loss.py
ADDED
|
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import glob
|
| 3 |
+
import pickle
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Optional
|
| 6 |
+
import tyro
|
| 7 |
+
from tqdm.auto import tqdm
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from .configuration_dragon import DragonConfig
|
| 13 |
+
from .modeling_dragon import DragonForCausalLM
|
| 14 |
+
|
| 15 |
+
@dataclass
class Args:
    """CLI arguments for the standalone validation-loss script."""
    load_dir: str  # checkpoint directory to load the model from; args.pkl is expected next to its parent dir
    val_bin: str   # path to a single validation .bin shard (not a glob pattern)
|
| 19 |
+
|
| 20 |
+
@dataclass
class NanoArgs:
    """Training-time configuration, restored here by unpickling args.pkl.

    NOTE(review): instances of this class are loaded with pickle, so field
    names and defaults must stay in sync with the training script that
    produced the checkpoint — do not rename or reorder fields casually.
    """
    resume_from: Optional[str] = None
    run_name : str = ""

    # arch - general
    d_model : int = 768
    n_heads : int = 6 # head dim 128 suggested by @Grad62304977
    head_dim: Optional[int] = None
    layers_config : str = 4*"lrdlr"  # per-layer mixer types, one char per layer
    expand_factor : int = 2 # expand factor for Mamba/Dragon
    rope_type_local: str = "" #p-rope
    rope_type_global: str = "" #p-rope
    rope_theta_local: float = 10000.0
    rope_theta_global: float = 0.0
    eps_rmsnorm: float = 1e-6
    mlp_expand: int = 4 # expand factor for MLP
    fused_loss_computation : bool = True # whether to use fused linear + cross entropy loss
    use_uscaling: bool = False
    uscaling_tau: float = 0.2
    zero_centered_gamma: bool = False
    zero_centered_gate: bool = False
    zero_centered_gate_type: int = 1 # 1, 2, 3, 4
    gate_attn: bool = False
    gate_gdn: bool = True
    gate_type: str = "elementwise" # elementwise (one per dim), headwise (one per head), kimi (lora)
    gate_act: str = "silu" # silu, sigmoid
    scalar_proj_as_hidden_matrix: bool = True
    normalization_type: str = "rmsnorm" # rmsnorm, seednorm
    seednorm_wd: bool = True
    seednorm_type: int = 1
    seednorm_rank: int = 1
    mixer_gn: bool = True
    mlp_linking : bool = False
    final_norm: bool = True
    layer_norm_scaling: bool = False # not read when using muP
    mlp_type: str = "simple" # simple, gated
    tie_lm_head: bool = False

    # MoE
    moe: bool = False
    moe_num_routed_experts: int = 2
    moe_routed_scaling_factor: float = 2.5
    moe_routed_intermediate_size: int = 768
    moe_shared_intermediate_size: int = 768

    # attention related
    n_kv_heads : int = 0  # 0 means "same as n_heads" (resolved when building DragonConfig)
    swa_window_size : int = 1024
    slw_warmup_iters: float = 0
    slw_start: int = 8 # window size at the start of training
    slw_increment: int = 64 # window size increment at each step
    softcap_local_attn: float = 0.0 # logit soft-capping for local attn logits, as per Gemma2 (0.0 = no soft-capping)
    softcap_global_attn: float = 0.0
    qk_norm: bool = True
    scalable_softmax: bool = True
    resformer : bool = False # Works only on f layers (DiffAttention)
    token_shift_attn: bool = False
    token_shift_gdn: bool = False
    token_conv1d_attn: bool = False
    token_conv1d_gdn: bool = True
    num_attention_heads_indexer: int = 8
    head_dim_indexer: int = 32
    dsa_q_lora_rank: int = 128
    dsa_topk: int = 512
    cca_seq_kernel_size: int = 4
    nsa_topk: int = 16
    nsa_block_size: int = 64
    nsa_window_size: int = 512
    num_signal_heads_diff: Optional[int] = None
    tpa_rank: int = 2
    shrink_qk_da: int = 2
    mla_kv_rank: int = 128

    # GDN related
    rope_gdn: Optional[str] = None # None, rope, (srope)
    head_dim_gdn: Optional[int] = None
    n_heads_gdn: int = 0
    n_kv_heads_gdn: int = 0
    shrink_qk_gdn: int = 2
    kda_allow_neg_eigval: bool = False
    kda_num_v_heads: Optional[int] = None
    mamba_mimo_dim: Optional[int] = 2
    mamba_ngroups: Optional[int] = 1
    mamba_d_state: int = 128
    mamba_headdim: int = 64
    mamba3_rope: bool = True
    mamba3_remove_BC_bias: bool = False
    mamba3_is_id_rms: bool = True
    mamba3_remove_conv: bool = True
    mamba3_is_A_dd: bool = True
    mamba3_add_trapezoid: bool = True

    # optim (unused during evaluation, kept so pickled training args load cleanly)
    optim: str = "adamw" # adamw, spam, stable-spam, muon, muon_moonlight, splus
    second_order_optim : Optional[str] = None # snoo
    batch_size: int = 8*64 # batch size, in sequences, across all devices
    device_batch_size: int = 64 # batch size, in sequences, per device
    total_iterations: int = 1000 # number of iterations to run
    learning_rate: float = 1e-4
    weight_decay: float = 0.
    adam_beta1: float = 0.9
    adam_beta2: float = 0.95
    adam_eps: float = 1e-8
    warmup_iters: int = 200
    warmdown_iters: int = 3000
    warmdown_type: str = "linear" # linear, cosine
    grad_norm_clip: float = 1.0
    uscaling_mult_embed: float = 0
    uscaling_mult_scalar: float = 0
    uscaling_mult_head: float = 0
    init_std: float = 0.006
    patch_level_training: bool = False
    patch_level_training_size: int = 4
    second_order_lr: float = 0.68
    second_order_momentum: float = 0.37
    second_order_interval: int = 25

    # data
    vocab_size: int = 50304
    bos_id: int = 50256
    sequence_length: int = 1024
    intra_doc_masking: bool = False
    input_bin: Optional[str] = None
    input_val_bin: Optional[str] = None

    # evaluation and logging
    val_loss_every: int = 125
    val_iterations: int = 50 # 1 step = global bs * T tokens
    inspect_every: int = 0
    save_every: int = 1000
    log_dir: str = "logs/"
    wandb_project: str = "dragon_v1.5"
    wandb_name: Optional[str] = None
    log_wandb: bool = False

    load_arg_from_config: bool = True
    load_optim: bool = True
    load_sched: bool = True
    compile: bool = True
    compile_dynamic: bool = False

    # used during training
    slw_window: int = 0
|
| 164 |
+
|
| 165 |
+
def _peek_data_shard(filename):
|
| 166 |
+
with open(filename, "rb") as f:
|
| 167 |
+
header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
|
| 168 |
+
if header[0] != 20240520:
|
| 169 |
+
print("ERROR: magic number mismatch in the data .bin file!")
|
| 170 |
+
print("---> HINT: Are you passing in a correct file with --input_bin?")
|
| 171 |
+
print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README")
|
| 172 |
+
exit(1)
|
| 173 |
+
assert header[1] == 1, "unsupported version"
|
| 174 |
+
ntok = int(header[2])
|
| 175 |
+
return ntok
|
| 176 |
+
|
| 177 |
+
def _load_data_shard(filename):
|
| 178 |
+
with open(filename, "rb") as f:
|
| 179 |
+
header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
|
| 180 |
+
assert header[0] == 20240520, "magic number mismatch in the data .bin file"
|
| 181 |
+
assert header[1] == 1, "unsupported version"
|
| 182 |
+
ntok = int(header[2])
|
| 183 |
+
# memmap the token payload directly (uint16) after the 256*4B header
|
| 184 |
+
tokens = np.memmap(filename, dtype=np.uint16, mode="r", offset=256 * 4, shape=(ntok,))
|
| 185 |
+
assert tokens.size == ntok, "number of tokens read does not match header?"
|
| 186 |
+
return tokens
|
| 187 |
+
|
| 188 |
+
class DistributedDataLoader:
    """Shard-cycling token loader for multi-process training/eval.

    Each process reads disjoint (B, T) slices of a shared stream of .bin
    shards. With stop_on_end=True (single shard only) the loader raises
    StopIteration at the end of the shard instead of wrapping around, which
    makes it suitable for one-pass validation.
    """
    def __init__(self, filename_pattern, intra_doc_masking,B, T, process_rank, num_processes, bos_id, stop_on_end=False):
        self.process_rank = process_rank
        self.num_processes = num_processes
        self.intra_doc_masking = intra_doc_masking
        self.bos_id = bos_id  # token id marking document starts (used only when intra_doc_masking)
        self.B = B # micro batch size
        self.T = T  # sequence length
        self.stop_on_end = stop_on_end

        # glob files that match the pattern
        self.files = sorted(glob.glob(filename_pattern))
        assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"
        if self.stop_on_end:
            assert len(self.files) == 1, "Pass a single .bin path (not a pattern) when stop_on_end=True."

        # load and validate all data shards, count number of tokens in total
        ntok_total = 0
        self.shard_ntoks = []
        for fname in self.files:
            shard_ntok = _peek_data_shard(fname)
            #print(f"shard {fname} has {shard_ntok} tokens")
            # every shard must hold at least one full global batch (+1 token headroom)
            assert shard_ntok >= num_processes * B * T + 1
            self.shard_ntoks.append(shard_ntok)
            ntok_total += int(shard_ntok)
        self.ntok_total = ntok_total

        # kick things off
        self.reset()

    def reset(self, shard=0):
        """Rewind to the given shard and this process's starting offset in it."""
        self.current_shard = shard
        self.current_position = self.process_rank * self.B * self.T
        self.tokens = _load_data_shard(self.files[self.current_shard])

    def advance(self): # advance to next data shard
        # Wraps around modulo the shard count; only rank 0 logs progress.
        self.current_shard = (self.current_shard + 1) % len(self.files)
        self.current_position = self.process_rank * self.B * self.T
        self.tokens = _load_data_shard(self.files[self.current_shard])

        if self.process_rank == 0:
            shard_tokens = self.shard_ntoks[self.current_shard]
            cum_tokens = sum(self.shard_ntoks[: self.current_shard + 1])

            def _fmt(n):
                # human-readable token counts: 1.23B / 4.56M / raw int
                return f"{n/1e9:.2f}B" if n >= 1_000_000_000 else (
                    f"{n/1e6:.2f}M" if n >= 1_000_000 else str(n))

            print(
                f"Advancing to shard {self.current_shard}/{len(self.files)-1} "
                f"(this={_fmt(shard_tokens)} tok, cum={_fmt(cum_tokens)}/{_fmt(self.ntok_total)})"
            )

    def next_batch(self):
        """Return (inputs, targets, cu_seqlens, max_seqlen, position_ids), all on CUDA.

        inputs and targets are the same (B, T) token block; presumably the
        model shifts labels internally for next-token prediction — TODO
        confirm against the model's loss computation.
        """
        B = self.B
        T = self.T
        buf = self.tokens[self.current_position : self.current_position+B*T]
        buf = np.asarray(buf, dtype=np.int64)  # copy out of the memmap and widen for embedding lookup
        x = torch.from_numpy(buf.reshape(B, T)) # inputs
        y = torch.from_numpy(buf.reshape(B, T)) # targets

        # compute cumulative document positions for intra-document masking
        cu = None
        maxlen = None
        position_ids = None
        if self.intra_doc_masking:
            assert self.B == 1
            # document start indices = positions of the BOS token in the flat sequence
            starts = (x == self.bos_id).nonzero(as_tuple=True)[1].to(torch.long)
            # ensure the window itself starts a document, even mid-document
            if starts.numel() == 0 or starts[0] != 0:
                starts = torch.cat([torch.zeros(1, dtype=torch.long), starts])
            ends = torch.cat([starts[1:], torch.tensor([x.numel()])])
            seqlens = (ends - starts).to(torch.int32)
            # cu_seqlens, max_seqlen.
            cu = torch.cat([torch.zeros(1, dtype=torch.int32), seqlens.cumsum(0)]).cuda().to(torch.int32)
            maxlen = int(seqlens.max())
            # position_ids.
            lengths = seqlens.to(torch.long)
            # broadcast each document's start index over its tokens, then
            # subtract from absolute index to get the within-document position
            starts_per_token = torch.repeat_interleave(starts.to(torch.long), lengths)
            idx = torch.arange(T, device=x.device, dtype=torch.long)
            position_ids = (idx - starts_per_token).unsqueeze(0)

        # advance current position and load next shard if necessary
        self.current_position += B * T * self.num_processes
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            if self.stop_on_end:
                raise StopIteration
            else:
                self.advance()

        return x.cuda(), y.cuda(), cu, maxlen, position_ids
|
| 278 |
+
|
| 279 |
+
# Parse CLI args for this evaluation run (checkpoint dir + validation shard).
run_args = tyro.cli(Args)

# Training args are pickled next to the checkpoint's parent directory at save time.
saved_args_path = os.path.join(os.path.dirname(run_args.load_dir), "args.pkl")
print(f"Loading args from {saved_args_path}")
if os.path.exists(saved_args_path):
    with open(saved_args_path, "rb") as f:
        # NOTE: pickle.load executes arbitrary code — only load checkpoints you trust.
        saved_args = pickle.load(f)
    args: NanoArgs = saved_args
else:
    # Previously this fell through and crashed later with `NameError: args`;
    # fail fast with an actionable message instead.
    raise FileNotFoundError(f"Could not find saved training args at {saved_args_path}")

print(args)

B, T = args.device_batch_size, args.sequence_length
# Kept for parity with the training script; not used during evaluation.
accumulation_steps = args.batch_size // (B * 1)

# Single-process loader over one validation shard; stop_on_end makes next_batch
# raise StopIteration after a single pass instead of wrapping around.
val_loader = DistributedDataLoader(run_args.val_bin, False, B, T, 0, 1, args.bos_id, stop_on_end=True)
print(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
|
| 295 |
+
|
| 296 |
+
# load model.
|
| 297 |
+
config_hf = DragonConfig(
|
| 298 |
+
tie_lm_head=args.tie_lm_head,
|
| 299 |
+
mlp_type=args.mlp_type,
|
| 300 |
+
layer_norm_scaling=args.layer_norm_scaling,
|
| 301 |
+
mamba_d_state=args.mamba_d_state,
|
| 302 |
+
mamba_headdim=args.mamba_headdim,
|
| 303 |
+
mamba3_rope=args.mamba3_rope,
|
| 304 |
+
mamba3_remove_BC_bias=args.mamba3_remove_BC_bias,
|
| 305 |
+
mamba3_is_id_rms=args.mamba3_is_id_rms,
|
| 306 |
+
mamba3_remove_conv=args.mamba3_remove_conv,
|
| 307 |
+
mamba3_is_A_dd=args.mamba3_is_A_dd,
|
| 308 |
+
mamba3_add_trapezoid=args.mamba3_add_trapezoid,
|
| 309 |
+
moe=args.moe,
|
| 310 |
+
moe_num_routed_experts=args.moe_num_routed_experts,
|
| 311 |
+
moe_routed_scaling_factor=args.moe_routed_scaling_factor,
|
| 312 |
+
moe_routed_intermediate_size=args.moe_routed_intermediate_size,
|
| 313 |
+
moe_shared_intermediate_size=args.moe_shared_intermediate_size,
|
| 314 |
+
intra_doc_masking=args.intra_doc_masking,
|
| 315 |
+
seednorm_rank=args.seednorm_rank,
|
| 316 |
+
seednorm_type=args.seednorm_type,
|
| 317 |
+
final_norm=args.final_norm,
|
| 318 |
+
mla_kv_rank=args.mla_kv_rank,
|
| 319 |
+
rope_gdn=args.rope_gdn,
|
| 320 |
+
shrink_qk_da=args.shrink_qk_da,
|
| 321 |
+
shrink_qk_gdn=args.shrink_qk_gdn,
|
| 322 |
+
mixer_gn=args.mixer_gn,
|
| 323 |
+
kda_allow_neg_eigval=args.kda_allow_neg_eigval,
|
| 324 |
+
kda_num_v_heads=args.kda_num_v_heads,
|
| 325 |
+
seednorm_wd=args.seednorm_wd,
|
| 326 |
+
normalization_type=args.normalization_type,
|
| 327 |
+
tpa_rank=args.tpa_rank,
|
| 328 |
+
num_signal_heads_diff=args.num_signal_heads_diff,
|
| 329 |
+
scalar_proj_as_hidden_matrix=args.scalar_proj_as_hidden_matrix,
|
| 330 |
+
token_shift_attn=args.token_shift_attn,
|
| 331 |
+
token_shift_gdn=args.token_shift_gdn,
|
| 332 |
+
token_conv1d_attn=args.token_conv1d_attn,
|
| 333 |
+
token_conv1d_gdn=args.token_conv1d_gdn,
|
| 334 |
+
patch_level_training=args.patch_level_training,
|
| 335 |
+
patch_level_training_size=args.patch_level_training_size,
|
| 336 |
+
nsa_topk=args.nsa_topk,
|
| 337 |
+
nsa_block_size=args.nsa_block_size,
|
| 338 |
+
nsa_window_size=args.nsa_window_size,
|
| 339 |
+
cca_seq_kernel_size=args.cca_seq_kernel_size,
|
| 340 |
+
head_dim=args.head_dim,
|
| 341 |
+
head_dim_gdn=args.head_dim_gdn,
|
| 342 |
+
num_attention_heads_gdn=args.n_heads_gdn,
|
| 343 |
+
num_key_value_heads_gdn=args.n_kv_heads_gdn,
|
| 344 |
+
zero_centered_gate=args.zero_centered_gate,
|
| 345 |
+
zero_centered_gate_type=args.zero_centered_gate_type,
|
| 346 |
+
scalable_softmax=args.scalable_softmax,
|
| 347 |
+
mamba_mimo_dim=args.mamba_mimo_dim,
|
| 348 |
+
mamba_ngroups=args.mamba_ngroups,
|
| 349 |
+
resformer=args.resformer,
|
| 350 |
+
gate_type=args.gate_type,
|
| 351 |
+
gate_act=args.gate_act,
|
| 352 |
+
gate_attn=args.gate_attn,
|
| 353 |
+
gate_gdn=args.gate_gdn,
|
| 354 |
+
fused_loss_computation=args.fused_loss_computation,
|
| 355 |
+
qk_norm=args.qk_norm,
|
| 356 |
+
num_attention_heads_indexer=args.num_attention_heads_indexer,
|
| 357 |
+
head_dim_indexer=args.head_dim_indexer,
|
| 358 |
+
dsa_q_lora_rank=args.dsa_q_lora_rank,
|
| 359 |
+
dsa_topk=args.dsa_topk,
|
| 360 |
+
zero_centered_gamma=args.zero_centered_gamma,
|
| 361 |
+
vocab_size=args.vocab_size,
|
| 362 |
+
max_position_embeddings=args.sequence_length,
|
| 363 |
+
use_uscaling=args.use_uscaling,
|
| 364 |
+
hidden_size=args.d_model,
|
| 365 |
+
intermediate_size=args.d_model * args.mlp_expand,
|
| 366 |
+
expand_factor=args.expand_factor,
|
| 367 |
+
layers_config=args.layers_config,
|
| 368 |
+
num_attention_heads=args.n_heads,
|
| 369 |
+
num_key_value_heads=args.n_kv_heads if args.n_kv_heads > 0 else args.n_heads,
|
| 370 |
+
initializer_range=args.init_std,
|
| 371 |
+
softcap_local_attn=args.softcap_local_attn,
|
| 372 |
+
softcap_global_attn=args.softcap_global_attn,
|
| 373 |
+
norm_epsilon=args.eps_rmsnorm,
|
| 374 |
+
use_cache=False,
|
| 375 |
+
sliding_window_size=args.swa_window_size,
|
| 376 |
+
rope_type_global=args.rope_type_global,
|
| 377 |
+
rope_type_local=args.rope_type_local,
|
| 378 |
+
rope_theta_global=args.rope_theta_global,
|
| 379 |
+
rope_theta_local=args.rope_theta_local,
|
| 380 |
+
uscaling_tau=args.uscaling_tau,
|
| 381 |
+
mlp_linking=args.mlp_linking
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
model = DragonForCausalLM.from_pretrained(run_args.load_dir, config=config_hf, torch_dtype=torch.bfloat16)
|
| 385 |
+
model = model.cuda()
|
| 386 |
+
|
| 387 |
+
model = torch.compile(model, dynamic=args.compile_dynamic) if args.compile else model
|
| 388 |
+
model.eval()
|
| 389 |
+
ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)
|
| 390 |
+
|
| 391 |
+
# One full pass over the validation shard, accumulating mean per-batch loss.
val_loader.reset()
# Expected number of (B, T) steps in the shard; pbar total may be off by one
# near the shard boundary (the loader keeps +1 token of headroom).
total_steps = (val_loader.shard_ntoks[val_loader.current_shard] - 1) // (B * T * val_loader.num_processes)
pbar = tqdm(total=total_steps, desc="Validating", unit="step")
val_loss_sum = torch.zeros((), device="cuda", dtype=torch.float32)
n_steps = 0
tok_per_step = B * T

try:
    with torch.no_grad():
        while True:
            try:
                inputs, targets, cu, maxlen, position_ids = val_loader.next_batch()
            except StopIteration:
                # stop_on_end loader signals the end of the single shard
                break
            with ctx:
                step_loss = model(
                    input_ids=inputs,
                    labels=targets,
                    just_loss=True,
                    cu_seqlens=cu,
                    max_seqlen=maxlen,
                    position_ids=position_ids,
                ).loss.detach()
            val_loss_sum += step_loss
            n_steps += 1
            # .item() forces a GPU sync each step; acceptable here since it
            # feeds the live progress-bar readout.
            avg = (val_loss_sum / n_steps).item()
            pbar.update(1)
            pbar.set_postfix(avg_loss=f"{avg:.4f}", ppl=f"{np.exp(avg):.2f}")
finally:
    # Close the bar even if the forward pass raises, so the terminal isn't left mangled.
    pbar.close()

# Explicit check (not `assert`, which is stripped under `python -O`).
if n_steps == 0:
    raise RuntimeError("No batches read from the file; check B/T vs file size.")
val_loss = (val_loss_sum / n_steps).item()
print(f"Validation Loss: {val_loss:.6f}. Perplexity: {np.exp(val_loss):.6f} (steps={n_steps}, tokens={n_steps*tok_per_step})")
|
configuration_dragon.py
CHANGED
|
@@ -92,6 +92,15 @@ class DragonConfig(PretrainedConfig):
|
|
| 92 |
|
| 93 |
def __init__(
|
| 94 |
self,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
tie_lm_head: bool = False,
|
| 96 |
mlp_type: str = "simple",
|
| 97 |
layer_norm_scaling: bool = False,
|
|
@@ -103,6 +112,7 @@ class DragonConfig(PretrainedConfig):
|
|
| 103 |
mamba3_remove_conv: bool = True,
|
| 104 |
mamba3_is_A_dd: bool = True,
|
| 105 |
mamba3_add_trapezoid: bool = True,
|
|
|
|
| 106 |
moe: bool = False,
|
| 107 |
moe_num_routed_experts: int = 2,
|
| 108 |
moe_routed_scaling_factor: float = 2.5,
|
|
@@ -116,6 +126,7 @@ class DragonConfig(PretrainedConfig):
|
|
| 116 |
shrink_qk_da: int = 2,
|
| 117 |
shrink_qk_gdn: int = 2,
|
| 118 |
mixer_gn: bool = True,
|
|
|
|
| 119 |
kda_allow_neg_eigval: bool = False,
|
| 120 |
kda_num_v_heads: Optional[int] = None,
|
| 121 |
seednorm_wd: bool = True,
|
|
@@ -197,6 +208,15 @@ class DragonConfig(PretrainedConfig):
|
|
| 197 |
mlp_linking=False,
|
| 198 |
**kwargs,
|
| 199 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
self.tie_lm_head = tie_lm_head
|
| 201 |
self.mlp_type = mlp_type
|
| 202 |
self.layer_norm_scaling = layer_norm_scaling
|
|
@@ -208,6 +228,7 @@ class DragonConfig(PretrainedConfig):
|
|
| 208 |
self.mamba3_remove_conv = mamba3_remove_conv
|
| 209 |
self.mamba3_is_A_dd = mamba3_is_A_dd
|
| 210 |
self.mamba3_add_trapezoid = mamba3_add_trapezoid
|
|
|
|
| 211 |
self.moe = moe
|
| 212 |
self.moe_num_routed_experts = moe_num_routed_experts
|
| 213 |
self.moe_routed_scaling_factor = moe_routed_scaling_factor
|
|
@@ -221,6 +242,7 @@ class DragonConfig(PretrainedConfig):
|
|
| 221 |
self.shrink_qk_da = shrink_qk_da
|
| 222 |
self.shrink_qk_gdn = shrink_qk_gdn
|
| 223 |
self.mixer_gn = mixer_gn
|
|
|
|
| 224 |
self.kda_allow_neg_eigval = kda_allow_neg_eigval
|
| 225 |
self.kda_num_v_heads = kda_num_v_heads
|
| 226 |
self.seednorm_wd = seednorm_wd
|
|
|
|
| 92 |
|
| 93 |
def __init__(
|
| 94 |
self,
|
| 95 |
+
reduce_lm_head: int = 0,
|
| 96 |
+
dataset_type: str = "hf",
|
| 97 |
+
vwn: bool = False,
|
| 98 |
+
vwn_m: int = 2,
|
| 99 |
+
vwn_n: int = 3,
|
| 100 |
+
vwn_wd_alpha_beta: bool = False,
|
| 101 |
+
vwn_dynamic: bool = True,
|
| 102 |
+
legacy_gate: bool = False,
|
| 103 |
+
init_gpt2: bool = False,
|
| 104 |
tie_lm_head: bool = False,
|
| 105 |
mlp_type: str = "simple",
|
| 106 |
layer_norm_scaling: bool = False,
|
|
|
|
| 112 |
mamba3_remove_conv: bool = True,
|
| 113 |
mamba3_is_A_dd: bool = True,
|
| 114 |
mamba3_add_trapezoid: bool = True,
|
| 115 |
+
mamba3_postgate_norm: bool = False,
|
| 116 |
moe: bool = False,
|
| 117 |
moe_num_routed_experts: int = 2,
|
| 118 |
moe_routed_scaling_factor: float = 2.5,
|
|
|
|
| 126 |
shrink_qk_da: int = 2,
|
| 127 |
shrink_qk_gdn: int = 2,
|
| 128 |
mixer_gn: bool = True,
|
| 129 |
+
gate_before_norm: bool = True,
|
| 130 |
kda_allow_neg_eigval: bool = False,
|
| 131 |
kda_num_v_heads: Optional[int] = None,
|
| 132 |
seednorm_wd: bool = True,
|
|
|
|
| 208 |
mlp_linking=False,
|
| 209 |
**kwargs,
|
| 210 |
):
|
| 211 |
+
self.reduce_lm_head = reduce_lm_head
|
| 212 |
+
self.dataset_type = dataset_type
|
| 213 |
+
self.vwn = vwn
|
| 214 |
+
self.vwn_m = vwn_m
|
| 215 |
+
self.vwn_n = vwn_n
|
| 216 |
+
self.vwn_wd_alpha_beta = vwn_wd_alpha_beta
|
| 217 |
+
self.vwn_dynamic = vwn_dynamic
|
| 218 |
+
self.legacy_gate = legacy_gate
|
| 219 |
+
self.init_gpt2 = init_gpt2
|
| 220 |
self.tie_lm_head = tie_lm_head
|
| 221 |
self.mlp_type = mlp_type
|
| 222 |
self.layer_norm_scaling = layer_norm_scaling
|
|
|
|
| 228 |
self.mamba3_remove_conv = mamba3_remove_conv
|
| 229 |
self.mamba3_is_A_dd = mamba3_is_A_dd
|
| 230 |
self.mamba3_add_trapezoid = mamba3_add_trapezoid
|
| 231 |
+
self.mamba3_postgate_norm = mamba3_postgate_norm
|
| 232 |
self.moe = moe
|
| 233 |
self.moe_num_routed_experts = moe_num_routed_experts
|
| 234 |
self.moe_routed_scaling_factor = moe_routed_scaling_factor
|
|
|
|
| 242 |
self.shrink_qk_da = shrink_qk_da
|
| 243 |
self.shrink_qk_gdn = shrink_qk_gdn
|
| 244 |
self.mixer_gn = mixer_gn
|
| 245 |
+
self.gate_before_norm = gate_before_norm
|
| 246 |
self.kda_allow_neg_eigval = kda_allow_neg_eigval
|
| 247 |
self.kda_num_v_heads = kda_num_v_heads
|
| 248 |
self.seednorm_wd = seednorm_wd
|
modeling_dragon.py
CHANGED
|
@@ -21,6 +21,11 @@ from fla.ops.nsa.parallel import parallel_nsa
|
|
| 21 |
|
| 22 |
from flash_attn.modules.mlp import GatedMlp
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
try:
|
| 25 |
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
|
| 26 |
except ImportError:
|
|
@@ -54,6 +59,8 @@ try:
|
|
| 54 |
except ImportError:
|
| 55 |
chunk_kda, fused_recurrent_kda, fused_kda_gate, prepare_sequence_ids = None, None, None, None
|
| 56 |
|
|
|
|
|
|
|
| 57 |
from torch.compiler import disable
|
| 58 |
|
| 59 |
logger = logging.get_logger(__name__)
|
|
@@ -268,6 +275,13 @@ class DragonLinear(nn.Linear):
|
|
| 268 |
out = super().forward(x)
|
| 269 |
return ScaledGrad.apply(out, self.alpha_fwd, self.alpha_bwd)
|
| 270 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
class HybridDragonDynamicCache(DynamicCache):
|
| 272 |
"""
|
| 273 |
A dynamic cache that handle both the attention cache (which has a seq_len dimension) and the GDN cache
|
|
@@ -299,6 +313,10 @@ class HybridDragonDynamicCache(DynamicCache):
|
|
| 299 |
self.q_conv_caches = []
|
| 300 |
self.k_conv_caches = []
|
| 301 |
self.v_conv_caches = []
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
|
| 303 |
for idx, layer_type in enumerate(config.layers_config):
|
| 304 |
if not layer_type == "r":
|
|
@@ -313,6 +331,8 @@ class HybridDragonDynamicCache(DynamicCache):
|
|
| 313 |
self.q_conv_caches.append(None)
|
| 314 |
self.k_conv_caches.append(None)
|
| 315 |
self.v_conv_caches.append(None)
|
|
|
|
|
|
|
| 316 |
|
| 317 |
self.window_size = config.sliding_window_size
|
| 318 |
self.layers_config = config.layers_config
|
|
@@ -359,6 +379,15 @@ class HybridDragonDynamicCache(DynamicCache):
|
|
| 359 |
|
| 360 |
def set_prev_hidden(self, layer_idx, h):
|
| 361 |
self.cca_prev_hidden[layer_idx] = h
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
|
| 363 |
# kv shift
|
| 364 |
def get_last_kv(self, layer_idx):
|
|
@@ -568,6 +597,7 @@ class DragonAttention(nn.Module):
|
|
| 568 |
|
| 569 |
projection_dim = self.head_dim * (self.num_attention_heads + 2 * (0 if reuse_kv else self.num_key_value_heads))
|
| 570 |
self.linear_qkv = DragonLinear(config, config.hidden_size, projection_dim, bias=False)
|
|
|
|
| 571 |
|
| 572 |
if self.config.token_shift_attn:
|
| 573 |
if self.config.scalar_proj_as_hidden_matrix:
|
|
@@ -755,6 +785,187 @@ class DragonAttention(nn.Module):
|
|
| 755 |
|
| 756 |
return attn_output, last_key_states, last_value_states
|
| 757 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 758 |
class DragonTensorProductAttention(nn.Module):
|
| 759 |
"""
|
| 760 |
Multi-headed attention from 'Attention Is All You Need' paper.
|
|
@@ -785,6 +996,8 @@ class DragonTensorProductAttention(nn.Module):
|
|
| 785 |
self.W_A_v = DragonLinear(config, self.hidden_size, self.num_attention_heads * self.rank, bias=False)
|
| 786 |
self.W_B_k = DragonLinear(config, self.hidden_size, self.rank * self.head_dim, bias=False)
|
| 787 |
self.W_B_v = DragonLinear(config, self.hidden_size, self.rank * self.head_dim, bias=False)
|
|
|
|
|
|
|
| 788 |
|
| 789 |
if self.config.token_shift_attn:
|
| 790 |
if self.config.scalar_proj_as_hidden_matrix:
|
|
@@ -1156,6 +1369,246 @@ class DragonCompressedConvolutionalAttention(nn.Module):
|
|
| 1156 |
|
| 1157 |
return attn_output, None, None
|
| 1158 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1159 |
class DragonNativeSparseAttention(nn.Module):
|
| 1160 |
"""
|
| 1161 |
Multi-headed attention from 'Attention Is All You Need' paper.
|
|
@@ -1696,6 +2149,7 @@ class DragonDifferentialAttention(nn.Module):
|
|
| 1696 |
|
| 1697 |
projection_dim = self.head_qk_dim * self.num_attention_heads + self.head_qk_dim * self.num_key_value_heads + (self.head_v_dim * self.num_noise_heads//2)
|
| 1698 |
self.linear_qkv = DragonLinear(config, config.hidden_size, projection_dim, bias=False)
|
|
|
|
| 1699 |
|
| 1700 |
if self.config.token_shift_attn:
|
| 1701 |
if self.config.scalar_proj_as_hidden_matrix:
|
|
@@ -2373,6 +2827,8 @@ class DragonDifferentialTensorProductAttention(nn.Module):
|
|
| 2373 |
self.W_A_v = DragonLinear(config, self.hidden_size, self.num_noise_heads * self.rank, bias=False)
|
| 2374 |
self.W_B_k = DragonLinear(config, self.hidden_size, self.rank * self.head_qk_dim, bias=False)
|
| 2375 |
self.W_B_v = DragonLinear(config, self.hidden_size, self.rank * self.head_v_dim, bias=False)
|
|
|
|
|
|
|
| 2376 |
|
| 2377 |
if self.config.token_shift_attn:
|
| 2378 |
if self.config.scalar_proj_as_hidden_matrix:
|
|
@@ -3161,12 +3617,29 @@ class DragonGatedDeltaNet(nn.Module):
|
|
| 3161 |
self.num_attention_heads*self.dk + self.n_kv_heads*self.dk + self.n_kv_heads*self.dv,
|
| 3162 |
bias=False
|
| 3163 |
)
|
|
|
|
| 3164 |
self.linear_ba = DragonLinear(
|
| 3165 |
config, config.hidden_size,
|
| 3166 |
self.num_attention_heads + self.num_attention_heads, #+ self.num_attention_heads*self.dv, # b(H), a(H), g(H*dv)
|
| 3167 |
bias=False
|
| 3168 |
)
|
| 3169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3170 |
dt_min = config.time_step_min
|
| 3171 |
dt_max = config.time_step_max
|
| 3172 |
dt_init_floor = config.time_step_floor
|
|
@@ -3181,11 +3654,13 @@ class DragonGatedDeltaNet(nn.Module):
|
|
| 3181 |
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
| 3182 |
with torch.no_grad():
|
| 3183 |
self.dt_bias = nn.Parameter(inv_dt)
|
|
|
|
| 3184 |
|
| 3185 |
assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
|
| 3186 |
A = torch.empty(self.n_heads_local, dtype=torch.float32).uniform_(*A_init_range)
|
| 3187 |
A_log = torch.log(A) # Keep A_log in fp32
|
| 3188 |
self.A_log = nn.Parameter(A_log)
|
|
|
|
| 3189 |
|
| 3190 |
if self.config.rope_gdn == "rope":
|
| 3191 |
self.rope_proj = DragonLinear(config, config.hidden_size, self.dk//4, bias=False)
|
|
@@ -3348,6 +3823,11 @@ class DragonGatedDeltaNet(nn.Module):
|
|
| 3348 |
use_qk_l2norm_in_kernel=True
|
| 3349 |
) # (B L H dv)
|
| 3350 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3351 |
# update GDN cache
|
| 3352 |
if cache_params is not None:
|
| 3353 |
cache_params.ssm_caches[self.layer_idx] = ssm_cache
|
|
@@ -3381,6 +3861,9 @@ class DragonKimiDeltaAttention(nn.Module):
|
|
| 3381 |
self.q_proj = DragonLinear(config, config.hidden_size, self.key_dim, bias=False)
|
| 3382 |
self.k_proj = DragonLinear(config, config.hidden_size, self.key_dim, bias=False)
|
| 3383 |
self.v_proj = DragonLinear(config, config.hidden_size, self.value_dim, bias=False)
|
|
|
|
|
|
|
|
|
|
| 3384 |
|
| 3385 |
self.q_conv1d = ShortConvolution(
|
| 3386 |
hidden_size=self.key_dim,
|
|
@@ -3413,10 +3896,21 @@ class DragonKimiDeltaAttention(nn.Module):
|
|
| 3413 |
self.A_log = nn.Parameter(torch.log(torch.empty(self.num_q_heads, dtype=torch.float32).uniform_(1, 16)))
|
| 3414 |
self.dt_bias = nn.Parameter(torch.zeros(self.key_dim, dtype=torch.float32))
|
| 3415 |
|
| 3416 |
-
|
| 3417 |
-
|
| 3418 |
-
|
| 3419 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3420 |
|
| 3421 |
@disable
|
| 3422 |
def _kda_gate_call(self, g, A_log, head_k_dim, g_bias):
|
|
@@ -3427,6 +3921,7 @@ class DragonKimiDeltaAttention(nn.Module):
|
|
| 3427 |
hidden_states: torch.Tensor,
|
| 3428 |
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 3429 |
cache_params: Optional[HybridDragonDynamicCache] = None,
|
|
|
|
| 3430 |
**kwargs,
|
| 3431 |
):
|
| 3432 |
_, q_len, _ = hidden_states.shape
|
|
@@ -3443,20 +3938,26 @@ class DragonKimiDeltaAttention(nn.Module):
|
|
| 3443 |
conv_state_k = cache_params.k_conv_caches[self.layer_idx]
|
| 3444 |
conv_state_v = cache_params.v_conv_caches[self.layer_idx]
|
| 3445 |
|
|
|
|
|
|
|
|
|
|
| 3446 |
q, conv_state_q = self.q_conv1d(
|
| 3447 |
x=self.q_proj(hidden_states),
|
| 3448 |
cache=conv_state_q,
|
| 3449 |
output_final_state=cache_params is not None,
|
|
|
|
| 3450 |
)
|
| 3451 |
k, conv_state_k = self.k_conv1d(
|
| 3452 |
x=self.k_proj(hidden_states),
|
| 3453 |
cache=conv_state_k,
|
| 3454 |
output_final_state=cache_params is not None,
|
|
|
|
| 3455 |
)
|
| 3456 |
v, conv_state_v = self.v_conv1d(
|
| 3457 |
x=self.v_proj(hidden_states),
|
| 3458 |
cache=conv_state_v,
|
| 3459 |
output_final_state=cache_params is not None,
|
|
|
|
| 3460 |
)
|
| 3461 |
|
| 3462 |
g = self.f_proj(hidden_states)
|
|
@@ -3482,6 +3983,7 @@ class DragonKimiDeltaAttention(nn.Module):
|
|
| 3482 |
initial_state=None,
|
| 3483 |
output_final_state=cache_params is not None,
|
| 3484 |
use_qk_l2norm_in_kernel=True,
|
|
|
|
| 3485 |
)
|
| 3486 |
elif mode == 'fused_recurrent':
|
| 3487 |
o, ssm_cache = fused_recurrent_kda(
|
|
@@ -3500,6 +4002,11 @@ class DragonKimiDeltaAttention(nn.Module):
|
|
| 3500 |
#o = o * F.silu(rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim))
|
| 3501 |
# TODO: other types of gates? as well as ZCG?
|
| 3502 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3503 |
if cache_params is not None:
|
| 3504 |
cache_params.ssm_caches[self.layer_idx] = ssm_cache
|
| 3505 |
cache_params.q_conv_caches[self.layer_idx] = conv_state_q
|
|
@@ -3549,8 +4056,8 @@ class DragonMamba3(nn.Module):
|
|
| 3549 |
if config.mamba3_rope:
|
| 3550 |
self.rope_proj = DragonLinear(config, self.d_model, self.num_rope_angles, bias=False)
|
| 3551 |
|
| 3552 |
-
# Order: [
|
| 3553 |
-
d_in_proj =
|
| 3554 |
|
| 3555 |
if self.config.mamba3_is_A_dd:
|
| 3556 |
self.A_proj = DragonLinear(config, self.d_model, self.nheads, bias=False, dtype=torch.float32)
|
|
@@ -3575,6 +4082,7 @@ class DragonMamba3(nn.Module):
|
|
| 3575 |
self.dt_bias._no_weight_decay = True
|
| 3576 |
|
| 3577 |
self.in_proj = DragonLinear(config, self.d_model, d_in_proj, bias=self.bias)
|
|
|
|
| 3578 |
|
| 3579 |
self.B_bias, self.C_bias = None, None
|
| 3580 |
if not config.mamba3_remove_BC_bias:
|
|
@@ -3604,18 +4112,36 @@ class DragonMamba3(nn.Module):
|
|
| 3604 |
self.D = nn.Parameter(torch.ones(self.nheads))
|
| 3605 |
self.D._no_weight_decay = True
|
| 3606 |
|
| 3607 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3608 |
self,
|
| 3609 |
hidden_states: torch.Tensor,
|
| 3610 |
cache_params: Optional[HybridDragonDynamicCache] = None,
|
|
|
|
| 3611 |
**kwargs
|
| 3612 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3613 |
# Apply in_proj
|
| 3614 |
-
|
| 3615 |
-
|
| 3616 |
-
|
| 3617 |
[
|
| 3618 |
-
self.d_inner,
|
| 3619 |
self.d_inner + 2 * self.d_state * self.ngroups,
|
| 3620 |
self.nheads,
|
| 3621 |
],
|
|
@@ -3628,12 +4154,17 @@ class DragonMamba3(nn.Module):
|
|
| 3628 |
_A = -torch.exp(self.A_log).unsqueeze(0).unsqueeze(0)
|
| 3629 |
dt = F.softplus(dd_dt + self.dt_bias) # (B, L, N)
|
| 3630 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3631 |
if not self.config.mamba3_remove_conv:
|
| 3632 |
xBC = causal_conv1d_fn(
|
| 3633 |
x=xBC.transpose(1, 2),
|
| 3634 |
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
| 3635 |
bias=self.conv1d.bias,
|
| 3636 |
activation=self.activation,
|
|
|
|
| 3637 |
).transpose(1, 2) # (B, L, self.d_inner + 2 * ngroups * d_state)
|
| 3638 |
|
| 3639 |
x, B, C = torch.split(
|
|
@@ -3699,10 +4230,6 @@ class DragonMamba3(nn.Module):
|
|
| 3699 |
|
| 3700 |
x_scalar = (gamma_arr*_alpha_arr).to(torch.bfloat16)
|
| 3701 |
|
| 3702 |
-
ssm_cache = None
|
| 3703 |
-
if cache_params is not None:
|
| 3704 |
-
ssm_cache = cache_params.ssm_caches[self.layer_idx]
|
| 3705 |
-
|
| 3706 |
out = mamba_chunk_scan_discretized_combined(
|
| 3707 |
x=x.bfloat16(),
|
| 3708 |
A=A,
|
|
@@ -3714,19 +4241,26 @@ class DragonMamba3(nn.Module):
|
|
| 3714 |
CB_sum=CB_sum,
|
| 3715 |
D=self.D,
|
| 3716 |
z=None,
|
| 3717 |
-
initial_states=ssm_cache,
|
| 3718 |
-
return_final_states=cache_params is not None,
|
|
|
|
| 3719 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3720 |
|
| 3721 |
-
if
|
| 3722 |
-
y
|
| 3723 |
-
cache_params.ssm_caches[self.layer_idx] = ssm_cache
|
| 3724 |
-
else:
|
| 3725 |
-
y = out
|
| 3726 |
-
|
| 3727 |
-
y = rearrange(y, "b l h p -> b l (h p)")
|
| 3728 |
-
y = y*self.act(z)
|
| 3729 |
-
y = rearrange(y, "b l (h p) -> b l h p", h=self.nheads).to(x.dtype)
|
| 3730 |
|
| 3731 |
return y, None, None
|
| 3732 |
|
|
@@ -3747,6 +4281,7 @@ class DragonMamba2(nn.Module):
|
|
| 3747 |
# Order: [x, B, C, dt]
|
| 3748 |
d_in_proj = self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
|
| 3749 |
self.in_proj = DragonLinear(config, self.d_model, d_in_proj, bias=False)
|
|
|
|
| 3750 |
|
| 3751 |
if not self.config.mamba3_remove_conv:
|
| 3752 |
conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
|
|
@@ -3784,6 +4319,15 @@ class DragonMamba2(nn.Module):
|
|
| 3784 |
self.D = nn.Parameter(torch.ones(self.nheads))
|
| 3785 |
self.D._no_weight_decay = True
|
| 3786 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3787 |
def forward(self, hidden_states, **kwargs):
|
| 3788 |
"""
|
| 3789 |
u: (B, L, D)
|
|
@@ -3830,6 +4374,12 @@ class DragonMamba2(nn.Module):
|
|
| 3830 |
initial_states=None,
|
| 3831 |
)
|
| 3832 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3833 |
return y, None, None
|
| 3834 |
|
| 3835 |
class DragonMamba3Mimo(nn.Module):
|
|
@@ -3844,11 +4394,13 @@ class DragonMamba3Mimo(nn.Module):
|
|
| 3844 |
"when creating this class."
|
| 3845 |
)
|
| 3846 |
|
|
|
|
|
|
|
| 3847 |
self.d_model = config.hidden_size
|
| 3848 |
-
self.d_state =
|
| 3849 |
self.conv_init = None
|
| 3850 |
self.expand = 2
|
| 3851 |
-
self.headdim =
|
| 3852 |
self.ngroups = config.mamba_ngroups
|
| 3853 |
self.activation = "swish"
|
| 3854 |
self.bias = False
|
|
@@ -3863,14 +4415,12 @@ class DragonMamba3Mimo(nn.Module):
|
|
| 3863 |
self.dt_init_floor = 1e-4
|
| 3864 |
self.mimo_dim = config.mamba_mimo_dim
|
| 3865 |
self.mimo_proj_block_order = 1
|
| 3866 |
-
|
| 3867 |
|
| 3868 |
self.d_inner = int(self.expand * self.d_model)
|
| 3869 |
assert self.d_inner % self.headdim == 0
|
| 3870 |
self.nheads = self.d_inner // self.headdim
|
| 3871 |
self.dr_out_dim = self.d_inner // self.mimo_proj_block_order
|
| 3872 |
|
| 3873 |
-
|
| 3874 |
self.split_tensor_size = int(self.d_state * self.rope_fraction)
|
| 3875 |
if self.split_tensor_size % 2 != 0:
|
| 3876 |
self.split_tensor_size -= 1
|
|
@@ -3896,6 +4446,7 @@ class DragonMamba3Mimo(nn.Module):
|
|
| 3896 |
self.dt_bias._no_weight_decay = True
|
| 3897 |
|
| 3898 |
self.in_proj = DragonLinear(config, self.d_model, d_in_proj, bias=self.bias)
|
|
|
|
| 3899 |
|
| 3900 |
self.B_bias = nn.Parameter(torch.ones((self.mimo_dim, self.nheads, self.d_state)), requires_grad=True)
|
| 3901 |
self.C_bias = nn.Parameter(torch.ones((self.mimo_dim, self.nheads, self.d_state)), requires_grad=True)
|
|
@@ -3927,11 +4478,14 @@ class DragonMamba3Mimo(nn.Module):
|
|
| 3927 |
self.in_proj_mimo_z = nn.Parameter(in_proj_mimo_z_init_weights, requires_grad=True)
|
| 3928 |
self.out_proj_mimo = nn.Parameter(out_proj_mimo_init_weights, requires_grad=True)
|
| 3929 |
|
| 3930 |
-
|
| 3931 |
# D "skip" parameter
|
| 3932 |
self.D = nn.Parameter(torch.ones(self.nheads))
|
| 3933 |
self.D._no_weight_decay = True
|
| 3934 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3935 |
def forward(self, hidden_states, **kwargs):
|
| 3936 |
# Apply in_proj
|
| 3937 |
zxBCdt = self.in_proj(hidden_states)
|
|
@@ -4024,7 +4578,7 @@ class DragonMamba3Mimo(nn.Module):
|
|
| 4024 |
_beta_arr = torch.roll(beta_arr, shifts=-1, dims=1)
|
| 4025 |
|
| 4026 |
x_scalar = (gamma_arr*_alpha_arr + _beta_arr).to(torch.bfloat16)
|
| 4027 |
-
|
| 4028 |
z = rearrange(z, "b l r (h p) -> b l r h p", p=self.headdim)
|
| 4029 |
|
| 4030 |
y = mamba_mimo_chunk_scan_discretized_fused_combined(
|
|
@@ -4037,10 +4591,15 @@ class DragonMamba3Mimo(nn.Module):
|
|
| 4037 |
gamma=gamma_arr,
|
| 4038 |
CB_sum=CB_sum,
|
| 4039 |
D=self.D,
|
| 4040 |
-
z=z,
|
| 4041 |
)
|
| 4042 |
|
| 4043 |
y = rearrange(y, "b l r h p -> b l r (h p)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4044 |
#if seqlen_og is not None:
|
| 4045 |
# y = rearrange(y, "b l r d -> (b l) r d")
|
| 4046 |
|
|
@@ -4067,7 +4626,9 @@ class DragonMLP(nn.Module):
|
|
| 4067 |
self.lambda1 = nn.Parameter(torch.zeros(self.link_size)) # sigmoid->0.5
|
| 4068 |
else :
|
| 4069 |
self.fc_1 = DragonLinear(config, config.hidden_size, intermediate_size, bias=False)
|
|
|
|
| 4070 |
self.fc_2 = DragonLinear(config, intermediate_size, config.hidden_size, bias=False)
|
|
|
|
| 4071 |
self.register_buffer("_2_sqrt_5", torch.tensor(2/math.sqrt(5)) if config.use_uscaling else torch.tensor(1.), persistent=False)
|
| 4072 |
|
| 4073 |
def forward(self, hidden_states):
|
|
@@ -4096,7 +4657,9 @@ class DragonGatedMLP(nn.Module):
|
|
| 4096 |
self.intermediate_size = intermediate_size
|
| 4097 |
|
| 4098 |
self.fc_1 = DragonLinear(config, config.hidden_size, num_active_experts*self.intermediate_size, bias=False)
|
|
|
|
| 4099 |
self.fc_2 = DragonLinear(config, num_active_experts*self.intermediate_size, config.hidden_size, bias=False)
|
|
|
|
| 4100 |
self.register_buffer("_2_sqrt_5", torch.tensor(2/math.sqrt(5)) if config.use_uscaling else torch.tensor(1.), persistent=False)
|
| 4101 |
|
| 4102 |
def forward(self, hidden_states, gates):
|
|
@@ -4174,6 +4737,11 @@ class DragonMonoBlock(GradientCheckpointingLayer):
|
|
| 4174 |
head_dim = self.mixer.head_dim
|
| 4175 |
num_attention_heads = self.mixer.num_q_heads
|
| 4176 |
use_gate = config.gate_attn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4177 |
elif layer_type == 'n':
|
| 4178 |
self.mixer = DragonNativeSparseAttention(config, reuse_kv=False, layer_idx=layer_idx)
|
| 4179 |
head_dim = self.mixer.head_dim
|
|
@@ -4203,7 +4771,7 @@ class DragonMonoBlock(GradientCheckpointingLayer):
|
|
| 4203 |
self.mixer = DragonMamba3(config, layer_idx=layer_idx)
|
| 4204 |
head_dim = self.mixer.headdim
|
| 4205 |
num_attention_heads = self.mixer.nheads
|
| 4206 |
-
use_gate =
|
| 4207 |
elif layer_type == '2':
|
| 4208 |
self.mixer = DragonMamba2(config, layer_idx=layer_idx)
|
| 4209 |
head_dim = self.mixer.headdim
|
|
@@ -4214,6 +4782,11 @@ class DragonMonoBlock(GradientCheckpointingLayer):
|
|
| 4214 |
head_dim = self.mixer.headdim
|
| 4215 |
num_attention_heads = self.mixer.nheads
|
| 4216 |
use_gate = False # inside Mamba3Mimo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4217 |
else:
|
| 4218 |
raise ValueError(f"Unknown layer type: {layer_type}")
|
| 4219 |
|
|
@@ -4233,6 +4806,7 @@ class DragonMonoBlock(GradientCheckpointingLayer):
|
|
| 4233 |
self.gate_proj.is_scalar_weight = True
|
| 4234 |
else:
|
| 4235 |
raise ValueError(f"Unknown gate_type: {self.config.gate_type}")
|
|
|
|
| 4236 |
if self.config.zero_centered_gate:
|
| 4237 |
val = 1.
|
| 4238 |
if self.config.zero_centered_gate_type==3:
|
|
@@ -4253,6 +4827,7 @@ class DragonMonoBlock(GradientCheckpointingLayer):
|
|
| 4253 |
self.use_gate = use_gate
|
| 4254 |
|
| 4255 |
self.mixer_proj = DragonLinear(config, head_dim*num_attention_heads, config.hidden_size, bias=False)
|
|
|
|
| 4256 |
if config.mixer_gn:
|
| 4257 |
self.mixer_group_norm = DragonHeadWiseRMSNorm(n_heads=num_attention_heads, d_head=head_dim, eps=config.norm_epsilon, zero_centered_gamma=config.zero_centered_gamma)
|
| 4258 |
|
|
@@ -4299,6 +4874,8 @@ class DragonMonoBlock(GradientCheckpointingLayer):
|
|
| 4299 |
cu_seqlens=cu_seqlens,
|
| 4300 |
max_seqlen=max_seqlen,
|
| 4301 |
) # (B, L, E*D)
|
|
|
|
|
|
|
| 4302 |
if self.use_gate:
|
| 4303 |
if self.config.gate_type == "elementwise" or self.config.gate_type == "kimi":
|
| 4304 |
g_proj = self.gate_proj(hidden_states).view(hidden_states.size(0), hidden_states.size(1), self.num_attention_heads, self.head_dim).to(y_mixer.dtype)
|
|
@@ -4313,7 +4890,7 @@ class DragonMonoBlock(GradientCheckpointingLayer):
|
|
| 4313 |
y_mixer = y_mixer * (self.gate_act(g_proj) + self.gate_bias)
|
| 4314 |
elif self.config.zero_centered_gate_type == 3 or self.config.zero_centered_gate_type == 4:
|
| 4315 |
y_mixer = y_mixer * self.gate_act(g_proj + self.gate_bias)
|
| 4316 |
-
if self.config.mixer_gn:
|
| 4317 |
y_mixer = self.mixer_group_norm(y_mixer)
|
| 4318 |
y_mixer = y_mixer.view(y_mixer.size(0), y_mixer.size(1), -1)
|
| 4319 |
y_mixer = self.mixer_proj(y_mixer)
|
|
@@ -4327,6 +4904,282 @@ class DragonMonoBlock(GradientCheckpointingLayer):
|
|
| 4327 |
|
| 4328 |
return hidden_states, last_key_states, last_value_states
|
| 4329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4330 |
class DragonBlock(GradientCheckpointingLayer):
|
| 4331 |
def __init__(self, config: DragonConfig, layer_idx: int, layer_type: str):
|
| 4332 |
super().__init__()
|
|
@@ -4412,13 +5265,13 @@ class DragonPreTrainedModel(PreTrainedModel):
|
|
| 4412 |
"attentions": DragonBlock,
|
| 4413 |
}
|
| 4414 |
|
| 4415 |
-
def _init_weights(self, module):
|
| 4416 |
if isinstance(module, (DragonLinear, nn.Conv1d)):
|
| 4417 |
if module.bias is not None:
|
| 4418 |
nn.init.zeros_(module.bias)
|
| 4419 |
-
nn.init.normal_(module.weight, mean=0., std=
|
| 4420 |
elif isinstance(module, nn.Embedding):
|
| 4421 |
-
nn.init.normal_(module.weight, mean=0., std=
|
| 4422 |
|
| 4423 |
@dataclass
|
| 4424 |
class DragonOutput(ModelOutput):
|
|
@@ -4473,19 +5326,31 @@ class DragonModel(DragonPreTrainedModel):
|
|
| 4473 |
self.vocab_size = config.vocab_size
|
| 4474 |
|
| 4475 |
self.embedding = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 4476 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4477 |
|
| 4478 |
if self.config.rope_type_global != '' or self.config.rope_type_local != '':
|
| 4479 |
self.rotary_emb = DragonRotaryEmbedding(config, head_dim=config.head_dim if config.head_dim else (config.expand_factor*config.hidden_size)//config.num_attention_heads, theta=config.rope_theta_local) # only for SWA
|
| 4480 |
else:
|
| 4481 |
self.rotary_emb = None
|
| 4482 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4483 |
if self.config.final_norm:
|
| 4484 |
self.final_norm = DragonNorm(config, config.hidden_size)
|
| 4485 |
|
| 4486 |
self.gradient_checkpointing = False
|
| 4487 |
self.post_init()
|
| 4488 |
-
|
| 4489 |
def get_input_embeddings(self):
|
| 4490 |
return self.embedding
|
| 4491 |
|
|
@@ -4514,6 +5379,8 @@ class DragonModel(DragonPreTrainedModel):
|
|
| 4514 |
|
| 4515 |
if inputs_embeds is None:
|
| 4516 |
inputs_embeds = self.embedding(input_ids)
|
|
|
|
|
|
|
| 4517 |
|
| 4518 |
if self.config.patch_level_training:
|
| 4519 |
# (B, KL, D) => (B, L, D) OR (B, L, D) ==> (B, L//K, D)
|
|
@@ -4570,12 +5437,21 @@ class DragonModel(DragonPreTrainedModel):
|
|
| 4570 |
)
|
| 4571 |
shared_kv = (last_k, last_v)
|
| 4572 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4573 |
if self.config.final_norm:
|
| 4574 |
hidden_states = self.final_norm(hidden_states)
|
| 4575 |
|
| 4576 |
if output_hidden_states:
|
| 4577 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 4578 |
|
|
|
|
|
|
|
|
|
|
| 4579 |
return DragonOutput(
|
| 4580 |
last_hidden_state=hidden_states,
|
| 4581 |
past_key_values=past_key_values if use_cache else None,
|
|
@@ -4589,11 +5465,23 @@ class DragonForCausalLM(DragonPreTrainedModel, GenerationMixin):
|
|
| 4589 |
self.config = config
|
| 4590 |
self.model = DragonModel(config)
|
| 4591 |
self.vocab_size = config.vocab_size
|
| 4592 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4593 |
self.post_init()
|
| 4594 |
if config.tie_lm_head:
|
| 4595 |
self.lm_head.weight = self.model.embedding.weight
|
| 4596 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4597 |
def forward(
|
| 4598 |
self,
|
| 4599 |
input_ids: Optional[torch.LongTensor] = None,
|
|
@@ -4639,7 +5527,10 @@ class DragonForCausalLM(DragonPreTrainedModel, GenerationMixin):
|
|
| 4639 |
labels = labels.to(hidden_states.device)
|
| 4640 |
|
| 4641 |
if linear_cross_entropy is None or not self.config.fused_loss_computation:
|
| 4642 |
-
|
|
|
|
|
|
|
|
|
|
| 4643 |
if not self.config.patch_level_training:
|
| 4644 |
shift_logits = logits[..., :-1, :].contiguous()
|
| 4645 |
shift_labels = labels[..., 1:].contiguous()
|
|
@@ -4653,6 +5544,7 @@ class DragonForCausalLM(DragonPreTrainedModel, GenerationMixin):
|
|
| 4653 |
loss = loss + F.nll_loss(log_probs, shift_labels[:, i])
|
| 4654 |
loss = loss / self.config.patch_level_training_size
|
| 4655 |
else:
|
|
|
|
| 4656 |
assert not self.config.patch_level_training, "Fused loss computation is not supported with patch-level training."
|
| 4657 |
loss = linear_cross_entropy(
|
| 4658 |
hidden_states[:, slice_indices, :].view(-1, hidden_states.size(-1)),
|
|
|
|
| 21 |
|
| 22 |
from flash_attn.modules.mlp import GatedMlp
|
| 23 |
|
| 24 |
+
try:
|
| 25 |
+
import flash_moba
|
| 26 |
+
except ImportError:
|
| 27 |
+
flash_moba = None
|
| 28 |
+
|
| 29 |
try:
|
| 30 |
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
|
| 31 |
except ImportError:
|
|
|
|
| 59 |
except ImportError:
|
| 60 |
chunk_kda, fused_recurrent_kda, fused_kda_gate, prepare_sequence_ids = None, None, None, None
|
| 61 |
|
| 62 |
+
from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
|
| 63 |
+
|
| 64 |
from torch.compiler import disable
|
| 65 |
|
| 66 |
logger = logging.get_logger(__name__)
|
|
|
|
| 275 |
out = super().forward(x)
|
| 276 |
return ScaledGrad.apply(out, self.alpha_fwd, self.alpha_bwd)
|
| 277 |
|
| 278 |
+
class DragonScale(nn.Module):
|
| 279 |
+
def __init__(self, s: float):
|
| 280 |
+
super().__init__()
|
| 281 |
+
self.s = s
|
| 282 |
+
def forward(self, x):
|
| 283 |
+
return x * self.s
|
| 284 |
+
|
| 285 |
class HybridDragonDynamicCache(DynamicCache):
|
| 286 |
"""
|
| 287 |
A dynamic cache that handle both the attention cache (which has a seq_len dimension) and the GDN cache
|
|
|
|
| 313 |
self.q_conv_caches = []
|
| 314 |
self.k_conv_caches = []
|
| 315 |
self.v_conv_caches = []
|
| 316 |
+
# cca v2
|
| 317 |
+
self.conv_states = []
|
| 318 |
+
self.prev_hs = []
|
| 319 |
+
self.has_previous_state = False
|
| 320 |
|
| 321 |
for idx, layer_type in enumerate(config.layers_config):
|
| 322 |
if not layer_type == "r":
|
|
|
|
| 331 |
self.q_conv_caches.append(None)
|
| 332 |
self.k_conv_caches.append(None)
|
| 333 |
self.v_conv_caches.append(None)
|
| 334 |
+
self.conv_states.append(None)
|
| 335 |
+
self.prev_hs.append(None)
|
| 336 |
|
| 337 |
self.window_size = config.sliding_window_size
|
| 338 |
self.layers_config = config.layers_config
|
|
|
|
| 379 |
|
| 380 |
def set_prev_hidden(self, layer_idx, h):
|
| 381 |
self.cca_prev_hidden[layer_idx] = h
|
| 382 |
+
|
| 383 |
+
# cca v2
|
| 384 |
+
def update_conv_state(self, layer_idx: int, new_conv_state: torch.Tensor) -> torch.Tensor:
|
| 385 |
+
if not self.has_previous_state:
|
| 386 |
+
self.conv_states[layer_idx] = new_conv_state#.to(self.conv_states.device)
|
| 387 |
+
else:
|
| 388 |
+
self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1)
|
| 389 |
+
self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :]#.to(self.conv_states.device)
|
| 390 |
+
return self.conv_states[layer_idx]
|
| 391 |
|
| 392 |
# kv shift
|
| 393 |
def get_last_kv(self, layer_idx):
|
|
|
|
| 597 |
|
| 598 |
projection_dim = self.head_dim * (self.num_attention_heads + 2 * (0 if reuse_kv else self.num_key_value_heads))
|
| 599 |
self.linear_qkv = DragonLinear(config, config.hidden_size, projection_dim, bias=False)
|
| 600 |
+
self.linear_qkv.norm_case_1 = True
|
| 601 |
|
| 602 |
if self.config.token_shift_attn:
|
| 603 |
if self.config.scalar_proj_as_hidden_matrix:
|
|
|
|
| 785 |
|
| 786 |
return attn_output, last_key_states, last_value_states
|
| 787 |
|
| 788 |
+
class DragonMoBAttention(nn.Module):
|
| 789 |
+
def __init__(self, config: DragonConfig, reuse_kv: bool, layer_idx: Optional[int], **kwargs):
|
| 790 |
+
super().__init__()
|
| 791 |
+
self.config = config
|
| 792 |
+
self.layer_idx = layer_idx
|
| 793 |
+
if layer_idx is None:
|
| 794 |
+
logger.warning_once(
|
| 795 |
+
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
| 796 |
+
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
| 797 |
+
"when creating this class."
|
| 798 |
+
)
|
| 799 |
+
self.num_attention_heads = config.num_attention_heads
|
| 800 |
+
self.num_key_value_heads = config.num_key_value_heads
|
| 801 |
+
self.hidden_size = config.hidden_size
|
| 802 |
+
self.head_dim = config.head_dim # if config.head_dim else config.hidden_size * config.expand_factor // self.num_attention_heads
|
| 803 |
+
self.qk_norm = config.qk_norm
|
| 804 |
+
self.window_size = config.sliding_window_size
|
| 805 |
+
self.block_size = config.nsa_block_size
|
| 806 |
+
self.topk = config.nsa_topk
|
| 807 |
+
self.reuse_kv = reuse_kv
|
| 808 |
+
|
| 809 |
+
projection_dim = self.head_dim * (self.num_attention_heads + 2 * (0 if reuse_kv else self.num_key_value_heads))
|
| 810 |
+
self.linear_qkv = DragonLinear(config, config.hidden_size, projection_dim, bias=False)
|
| 811 |
+
self.linear_qkv.norm_case_1 = True
|
| 812 |
+
|
| 813 |
+
if self.config.token_shift_attn:
|
| 814 |
+
if self.config.scalar_proj_as_hidden_matrix:
|
| 815 |
+
self.shift_proj_k = DragonLinear(config, self.hidden_size, self.num_key_value_heads, bias=False)
|
| 816 |
+
self.shift_proj_v = DragonLinear(config, self.hidden_size, self.num_key_value_heads, bias=False)
|
| 817 |
+
else:
|
| 818 |
+
self.shift_proj_k = DragonLinear(config, self.hidden_size, self.num_key_value_heads, bias=False, alpha_bwd=1., alpha_fwd=1.)
|
| 819 |
+
self.shift_proj_v = DragonLinear(config, self.hidden_size, self.num_key_value_heads, bias=False, alpha_bwd=1., alpha_fwd=1.)
|
| 820 |
+
self.shift_proj_k.is_scalar_weight = True
|
| 821 |
+
self.shift_proj_v.is_scalar_weight = True
|
| 822 |
+
|
| 823 |
+
if self.config.token_conv1d_attn:
|
| 824 |
+
self.conv_size = config.conv_kernel
|
| 825 |
+
self.conv_dim = self.num_attention_heads * self.head_dim + self.num_key_value_heads * self.head_dim + self.num_key_value_heads * self.head_dim
|
| 826 |
+
self.qkv_conv1d = nn.Conv1d(in_channels=self.conv_dim, out_channels=self.conv_dim, bias=False, kernel_size=self.conv_size, groups=self.conv_dim, padding=self.conv_size-1)
|
| 827 |
+
self.causal_conv1d_fn = causal_conv1d_fn
|
| 828 |
+
self.causal_conv1d_update = causal_conv1d_update or torch_causal_conv1d_update
|
| 829 |
+
|
| 830 |
+
if self.qk_norm:
|
| 831 |
+
self.q_norm = DragonNorm(config, self.head_dim)
|
| 832 |
+
if not reuse_kv:
|
| 833 |
+
self.k_norm = DragonNorm(config, self.head_dim)
|
| 834 |
+
|
| 835 |
+
def forward(
|
| 836 |
+
self,
|
| 837 |
+
hidden_states: torch.Tensor,
|
| 838 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 839 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 840 |
+
cache_params: Optional[HybridDragonDynamicCache] = None,
|
| 841 |
+
key_value_last_layer: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 842 |
+
**kwargs,
|
| 843 |
+
):
|
| 844 |
+
_, q_len, _ = hidden_states.shape
|
| 845 |
+
use_precomputed_states = (cache_params is not None and q_len == 1)
|
| 846 |
+
|
| 847 |
+
# Q, K, V projections.
|
| 848 |
+
if not self.reuse_kv:
|
| 849 |
+
query_states, key_states, value_states = get_query_key_value_tensors(self, hidden_states)
|
| 850 |
+
else:
|
| 851 |
+
query_states = get_query_key_value_tensors(self, hidden_states)
|
| 852 |
+
key_states, value_states = key_value_last_layer
|
| 853 |
+
last_key_states, last_value_states = None, None
|
| 854 |
+
|
| 855 |
+
# token-shift.
|
| 856 |
+
if self.config.token_shift_attn and not self.reuse_kv:
|
| 857 |
+
alpha_k = torch.sigmoid(self.shift_proj_k(hidden_states).float()).float().to(key_states.dtype).unsqueeze(-1) # (B, L, Hkv, 1)
|
| 858 |
+
alpha_v = torch.sigmoid(self.shift_proj_v(hidden_states).float()).float().to(value_states.dtype).unsqueeze(-1) # (B, L, Hkv, 1)
|
| 859 |
+
|
| 860 |
+
if cache_params is not None:
|
| 861 |
+
k_prev, v_prev = cache_params.get_last_kv(self.layer_idx)
|
| 862 |
+
if k_prev is None:
|
| 863 |
+
k_prev, v_prev = torch.zeros_like(key_states[:, :1]), torch.zeros_like(value_states[:, :1])
|
| 864 |
+
cache_params.set_last_kv(self.layer_idx, key_states[:, -1:], value_states[:, -1:])
|
| 865 |
+
else:
|
| 866 |
+
k_prev = F.pad(key_states, (0, 0, 0, 0, 1, 0))[:, :-1] # (B, L, H, D)
|
| 867 |
+
v_prev = F.pad(value_states, (0, 0, 0, 0, 1, 0))[:, :-1] # (B, L, H, D)
|
| 868 |
+
|
| 869 |
+
key_states = alpha_k * k_prev + (1 - alpha_k) * key_states
|
| 870 |
+
value_states = alpha_v * v_prev + (1 - alpha_v) * value_states
|
| 871 |
+
|
| 872 |
+
# conv.
|
| 873 |
+
if self.config.token_conv1d_attn:
|
| 874 |
+
assert not self.reuse_kv, "not supported"
|
| 875 |
+
# --- pack for conv ---
|
| 876 |
+
q_proj = rearrange(query_states, "b l h d -> b l (h d)")
|
| 877 |
+
k_proj = rearrange(key_states, "b l g d -> b l (g d)")
|
| 878 |
+
v_proj = rearrange(value_states, "b l g d -> b l (g d)")
|
| 879 |
+
mixed_qkv = torch.cat([q_proj, k_proj, v_proj], dim=-1).transpose(1, 2) # (B,C,L)
|
| 880 |
+
|
| 881 |
+
if cache_params is not None:
|
| 882 |
+
conv_cache = cache_params.conv_caches[self.layer_idx]
|
| 883 |
+
|
| 884 |
+
if use_precomputed_states:
|
| 885 |
+
mixed_qkv = self.causal_conv1d_update(
|
| 886 |
+
mixed_qkv,
|
| 887 |
+
conv_cache,
|
| 888 |
+
self.qkv_conv1d.weight.squeeze(1),
|
| 889 |
+
self.qkv_conv1d.bias,
|
| 890 |
+
'silu',
|
| 891 |
+
) # conv_cache is updated in-place here
|
| 892 |
+
else:
|
| 893 |
+
if cache_params is not None:
|
| 894 |
+
conv_cache = F.pad(mixed_qkv, (self.conv_size - mixed_qkv.shape[-1], 0))
|
| 895 |
+
cache_params.conv_caches[self.layer_idx] = conv_cache
|
| 896 |
+
if self.causal_conv1d_fn is not None:
|
| 897 |
+
mixed_qkv = self.causal_conv1d_fn(
|
| 898 |
+
x=mixed_qkv,
|
| 899 |
+
weight=self.qkv_conv1d.weight.squeeze(1),
|
| 900 |
+
bias=self.qkv_conv1d.bias,
|
| 901 |
+
activation='silu',
|
| 902 |
+
seq_idx=None,
|
| 903 |
+
)
|
| 904 |
+
else:
|
| 905 |
+
mixed_qkv = F.silu(self.qkv_conv1d(mixed_qkv)[:, :, :q_len])
|
| 906 |
+
|
| 907 |
+
# split back
|
| 908 |
+
mixed_qkv = mixed_qkv.transpose(1, 2)
|
| 909 |
+
q_proj, k_proj, v_proj = torch.split(
|
| 910 |
+
mixed_qkv,
|
| 911 |
+
[self.num_attention_heads*self.head_dim, self.num_key_value_heads*self.head_dim, self.num_key_value_heads*self.head_dim],
|
| 912 |
+
dim=-1,
|
| 913 |
+
)
|
| 914 |
+
query_states = rearrange(q_proj, "b l (h d) -> b l h d", h=self.num_attention_heads)
|
| 915 |
+
key_states = rearrange(k_proj, "b l (g d) -> b l g d", g=self.num_key_value_heads)
|
| 916 |
+
value_states = rearrange(v_proj, "b l (g d) -> b l g d", g=self.num_key_value_heads)
|
| 917 |
+
|
| 918 |
+
# QK-norm.
|
| 919 |
+
if self.qk_norm:
|
| 920 |
+
query_states = self.q_norm(query_states)
|
| 921 |
+
if not self.reuse_kv:
|
| 922 |
+
key_states = self.k_norm(key_states)
|
| 923 |
+
|
| 924 |
+
# RoPE.
|
| 925 |
+
if self.config.rope_theta_local > 0.0:
|
| 926 |
+
cos, sin = position_embeddings
|
| 927 |
+
if self.config.rope_type_local == "rope":
|
| 928 |
+
query_states = apply_rotary_emb(query_states, cos, sin)
|
| 929 |
+
if not self.reuse_kv:
|
| 930 |
+
key_states = apply_rotary_emb(key_states, cos, sin)
|
| 931 |
+
elif self.config.rope_type_local == "p-rope":
|
| 932 |
+
query_states = apply_p_rotary_emb(query_states, cos, sin, p=0.5)
|
| 933 |
+
if not self.reuse_kv:
|
| 934 |
+
key_states = apply_p_rotary_emb(key_states, cos, sin)
|
| 935 |
+
else:
|
| 936 |
+
raise ValueError(f"Unknow rope type : {self.config.rope_type_local}")
|
| 937 |
+
|
| 938 |
+
# KV-cache.
|
| 939 |
+
if not self.reuse_kv and cache_params is not None:
|
| 940 |
+
key_states, value_states = cache_params.update(key_states, value_states, self.layer_idx)
|
| 941 |
+
|
| 942 |
+
# save k,v for next layer (*after* norm and RoPE and kv-cache update)
|
| 943 |
+
if not self.reuse_kv:
|
| 944 |
+
last_key_states, last_value_states = key_states, value_states
|
| 945 |
+
|
| 946 |
+
# attention computation.
|
| 947 |
+
B, L, _, _ = query_states.shape
|
| 948 |
+
cu_seqlens = torch.arange(0, (B + 1) * L, step=L, dtype=torch.int32, device=query_states.device)
|
| 949 |
+
attn_output = flash_moba.flash_moba_varlen_func(
|
| 950 |
+
q=query_states.bfloat16().view(B*L, self.num_attention_heads, self.head_dim),
|
| 951 |
+
k=key_states.bfloat16().view(B*L, self.num_key_value_heads, self.head_dim),
|
| 952 |
+
v=value_states.bfloat16().view(B*L, self.num_key_value_heads, self.head_dim),
|
| 953 |
+
cu_seqlens_q=cu_seqlens,
|
| 954 |
+
cu_seqlens_k=cu_seqlens,
|
| 955 |
+
max_seqlen_q=L,
|
| 956 |
+
max_seqlen_k=L,
|
| 957 |
+
moba_chunk_size=self.block_size,
|
| 958 |
+
moba_topk=self.topk,
|
| 959 |
+
causal=True,
|
| 960 |
+
).view(B, L, self.num_attention_heads, self.head_dim)
|
| 961 |
+
# softmax scale...
|
| 962 |
+
# softcap...
|
| 963 |
+
|
| 964 |
+
#if cache_params is not None and not self.reuse_kv:
|
| 965 |
+
# cache_params.trim(self.layer_idx)
|
| 966 |
+
|
| 967 |
+
return attn_output, last_key_states, last_value_states
|
| 968 |
+
|
| 969 |
class DragonTensorProductAttention(nn.Module):
|
| 970 |
"""
|
| 971 |
Multi-headed attention from 'Attention Is All You Need' paper.
|
|
|
|
| 996 |
self.W_A_v = DragonLinear(config, self.hidden_size, self.num_attention_heads * self.rank, bias=False)
|
| 997 |
self.W_B_k = DragonLinear(config, self.hidden_size, self.rank * self.head_dim, bias=False)
|
| 998 |
self.W_B_v = DragonLinear(config, self.hidden_size, self.rank * self.head_dim, bias=False)
|
| 999 |
+
self.c_q.norm_case_1 = True
|
| 1000 |
+
# todo : norm others?
|
| 1001 |
|
| 1002 |
if self.config.token_shift_attn:
|
| 1003 |
if self.config.scalar_proj_as_hidden_matrix:
|
|
|
|
| 1369 |
|
| 1370 |
return attn_output, None, None
|
| 1371 |
|
| 1372 |
+
class DragonCompressedConvolutionalAttention2(nn.Module):
    """Compressed Convolutional Attention v2 (CCAv2) local-attention mixer.

    Q/K are produced by linear projections followed by a two-stage causal
    depthwise/grouped Conv1d over the sequence axis, mixed with a Q/K mean
    residual; V is the concatenation of two half-width projections of the
    current and one-step-shifted hidden streams. Q/K are L2-normalized per
    head before a sliding-window attention kernel is applied.

    forward() returns ``(attn_output, None, None)`` where ``attn_output`` is
    [B, S, num_q_heads, head_dim] (the two ``None``s keep the mixer interface
    uniform with layers that forward K/V to the next layer).
    """

    def __init__(self, config: DragonConfig, layer_idx: Optional[int], **kwargs):
        super().__init__()
        self.config = config
        assert layer_idx is not None
        self.layer_idx = layer_idx
        if layer_idx is None:
            # Unreachable given the assert above; kept for parity with the
            # other mixer classes in this file.
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.window_size = config.sliding_window_size

        # Kernel sizes of the two stacked causal convs; each needs (k-1)
        # left-padding, so the stack needs the sum of both paddings.
        self.cca_time0 = 2
        self.cca_time1 = 2
        self.padding0 = self.cca_time0 - 1
        self.padding1 = self.cca_time1 - 1
        self.total_padding = self.padding0 + self.padding1

        # NOTE(review): head counts are hard-coded for experimentation; the
        # commented-out expressions show the intended config-driven values.
        self.num_kv_heads = 5  # config.num_key_value_heads
        self.num_q_heads = 10  # config.num_attention_heads
        self.num_heads = config.num_attention_heads

        # Geometry
        self.head_dim = config.head_dim
        self.latent_k_dim = self.num_kv_heads * self.head_dim
        self.latent_q_dim = self.num_q_heads * self.head_dim
        self.sqrt_head_dim = float(math.sqrt(self.head_dim))
        self.gqa_groups = self.num_q_heads // self.num_kv_heads
        assert self.num_q_heads % self.num_kv_heads == 0, "q_heads must be a multiple of k_heads"
        assert (self.latent_k_dim + self.latent_q_dim) == (self.num_kv_heads + self.num_q_heads) * self.head_dim

        # Projections
        self.linear_q = nn.Linear(self.hidden_size, self.latent_q_dim, bias=self.config.attention_bias)
        self.linear_k = nn.Linear(self.hidden_size, self.latent_k_dim, bias=self.config.attention_bias)
        # Values come from two half-width streams (current + time-shifted).
        self.val_proj1 = nn.Linear(self.hidden_size, self.latent_k_dim // 2, bias=self.config.attention_bias)
        self.val_proj2 = nn.Linear(self.hidden_size, self.latent_k_dim // 2, bias=self.config.attention_bias)

        # Depthwise + grouped conv along sequence: stage 0 is fully depthwise
        # (groups == channels), stage 1 groups channels per head.
        in_out_ch = self.latent_k_dim + self.latent_q_dim
        self.conv_qk = nn.Sequential(
            nn.Conv1d(
                in_channels=in_out_ch,
                out_channels=in_out_ch,
                kernel_size=self.cca_time0,
                groups=in_out_ch,
                padding=0,
                stride=1,
            ),
            nn.Conv1d(
                in_channels=in_out_ch,
                out_channels=in_out_ch,
                kernel_size=self.cca_time1,
                groups=(self.num_kv_heads + self.num_q_heads),
                padding=0,
                stride=1,
            ),
        )

        # Per-k-head temperature applied multiplicatively to the normalized
        # keys. NOTE(review): zero init makes all keys exactly zero at step 0
        # (uniform attention until `temp` learns away from 0) — confirm this
        # is the intended warm-start behavior.
        self.temp = nn.Parameter(torch.zeros(self.num_kv_heads))

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        cache_params: Optional[HybridDragonDynamicCache],
        **kwargs,
    ):
        """
        hidden_states: [B, S, E] (HF layout)
        position_embeddings: (cos, sin) tensors for rotary embedding
        cache_params: optional hybrid cache; used for the conv state, the
            one-step-shifted value stream (`prev_hs`), and the KV cache.
        returns:
            attn_output: [B, S, num_q_heads, head_dim]
            None, None (no K/V reuse for the next layer)
        """

        past_key_values = cache_params
        batch_size, seq_length, _ = hidden_states.shape

        # ---- Switch to [S, B, H] ----
        hs = hidden_states.transpose(0, 1).contiguous()  # [S, B, H]
        # Time-shifted stream for v2 (pad one at the front along sequence)
        hs_d = F.pad(hs[:-1], pad=(0, 0, 0, 0, 1, 0))  # [S, B, H]

        # Q/K in the full space
        q = self.linear_q(hs)  # [S, B, latent_q_dim]
        k = self.linear_k(hs)  # [S, B, latent_k_dim]
        qk_packed0 = torch.cat([q, k], dim=-1)  # [S, B, latent_q + latent_k]

        # Pre-conv tensors in head form, used to build the residual means.
        query_pre = qk_packed0[..., : self.latent_q_dim].view(
            *qk_packed0.shape[:2], self.num_q_heads, self.head_dim
        )  # [S, B, qh, dh]

        key_pre = qk_packed0[..., self.latent_q_dim :].view(
            *qk_packed0.shape[:2], self.num_kv_heads, self.head_dim
        )  # [S, B, kh, dh]
        # Broadcast each kv head across its GQA group so it aligns with Q.
        key_pre = (
            key_pre.unsqueeze(-2)
            .repeat(1, 1, 1, self.gqa_groups, 1)
            .view(*qk_packed0.shape[:2], self.num_q_heads, self.head_dim)
        )  # [S, B, qh, dh]

        # Means for residual mixing (k-side mean averages back over groups).
        qk_mean_q = (query_pre + key_pre) / 2
        qk_mean_k = qk_mean_q.view(*qk_mean_q.shape[:2], self.num_kv_heads, self.gqa_groups, -1).mean(dim=-2)

        if past_key_values is not None:
            if past_key_values.has_previous_state:
                # Generation: prepend cached conv tail, run the conv stack,
                # then push the new token into the conv cache.
                qk_packed0 = qk_packed0.transpose(0, 1)  # [B, 1, H]
                qk_packed0_cached = past_key_values.conv_states[self.layer_idx]  # [B, H, 2]
                qk_packed0_cat = torch.cat([qk_packed0_cached, qk_packed0.transpose(1, 2)], dim=-1)  # [B, H, 3]
                qk_packed3 = self.conv_qk(qk_packed0_cat).permute(2, 0, 1)  # [S, B, E]
                qk_packed0_cache = past_key_values.update_conv_state(
                    layer_idx=self.layer_idx, new_conv_state=qk_packed0
                )  # [B, H, 2]

            else:
                # Prefill: store the left-padded tail of the sequence as the
                # conv state for later generation steps.
                qk_packed0_transposed = qk_packed0.permute(1, 2, 0)  # [S, B, H] -> [B, H, S]
                conv_states = nn.functional.pad(
                    qk_packed0_transposed,
                    (
                        self.cca_time0 - qk_packed0_transposed.shape[-1],
                        0,
                    ),
                )
                qk_packed0_cache = past_key_values.update_conv_state(
                    layer_idx=self.layer_idx, new_conv_state=conv_states
                )
                # Convs over sequence: [S, B, E] -> [B, E, S] -> pad -> conv ->
                # [S, B, E]
                qk_packed1 = qk_packed0.permute(1, 2, 0)  # [B, E, S]
                qk_packed2 = F.pad(qk_packed1, (self.total_padding, 0))
                qk_packed3 = self.conv_qk(qk_packed2).permute(2, 0, 1)  # [S, B, E]

        else:
            # Convs over sequence: [S, B, E] -> [B, E, S] -> pad -> conv -> [S,
            # B, E]
            qk_packed1 = qk_packed0.permute(1, 2, 0)  # [B, E, S]
            qk_packed2 = F.pad(qk_packed1, (self.total_padding, 0))
            qk_packed3 = self.conv_qk(qk_packed2).permute(2, 0, 1)  # [S, B, E]

        # Build queries/keys from conv output + means
        query = (
            qk_packed3[..., : self.latent_q_dim].view(*qk_packed3.shape[:2], self.num_q_heads, self.head_dim)
            + qk_mean_q
        )  # [S, B, qh, dh]

        key = (
            qk_packed3[..., self.latent_q_dim :].view(*qk_packed3.shape[:2], self.num_kv_heads, self.head_dim)
            + qk_mean_k
        )  # [S, B, kh, dh]

        # Values from the two time streams
        v1 = self.val_proj1(hs)  # [S, B, latent_k_dim/2]
        if past_key_values is not None:
            if past_key_values.has_previous_state:
                # Generation: the shifted stream is the last prefill token.
                # NOTE(review): `prev_hs` is read here but never refreshed with
                # the current token, so every decode step reuses the same
                # prefill tail — confirm this is intended.
                hs_d = past_key_values.prev_hs[self.layer_idx].clone()  # [B, H]
                hs_d = hs_d.unsqueeze(0)  # [1, B, H]
            else:
                past_key_values.prev_hs[self.layer_idx] = torch.zeros(batch_size, self.hidden_size, device=hs.device, dtype=hs.dtype)
                past_key_values.prev_hs[self.layer_idx].copy_(hs[-1, :, :])

        v2 = self.val_proj2(hs_d)  # [S, B, latent_k_dim/2]
        value = (
            torch.cat([v1, v2], dim=-1).contiguous().view(*hs.shape[:2], self.num_kv_heads, self.head_dim)
        )  # [S, B, kh, dh]

        # L2-normalize per head, then scale (keys also get the learned
        # per-head temperature).
        query_norm = query.norm(p=2, dim=-1, keepdim=True)
        key_norm = key.norm(p=2, dim=-1, keepdim=True)

        key = (key * (self.sqrt_head_dim / key_norm)) * self.temp[None, None].unsqueeze(-1)
        query = query * (self.sqrt_head_dim / query_norm)

        # Flatten head axis, then return to HF layout [B, S, ...]
        query = query.view(*query.shape[:2], self.num_q_heads * self.head_dim).transpose(0, 1).contiguous()
        key = key.view(*key.shape[:2], self.num_kv_heads * self.head_dim).transpose(0, 1).contiguous()
        value = value.view(*value.shape[:2], self.num_kv_heads * self.head_dim).transpose(0, 1).contiguous()

        query_states = query
        key_states = key
        value_states = value

        query_states = query_states.view(batch_size, seq_length, self.num_q_heads, self.head_dim)
        key_states = key_states.view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
        value_states = value_states.view(batch_size, seq_length, self.num_kv_heads, self.head_dim)

        # RoPE.
        if self.config.rope_theta_local > 0.0:
            cos, sin = position_embeddings
            if self.config.rope_type_local == "rope":
                query_states = apply_rotary_emb(query_states, cos, sin)
                key_states = apply_rotary_emb(key_states, cos, sin)
            elif self.config.rope_type_local == "p-rope":
                query_states = apply_p_rotary_emb(query_states, cos, sin, p=0.5)
                # FIX: pass p=0.5 explicitly so keys use the same rotary
                # fraction as queries (previously omitted, falling back to
                # the helper's default).
                key_states = apply_p_rotary_emb(key_states, cos, sin, p=0.5)
            else:
                raise ValueError(f"Unknown rope type : {self.config.rope_type_local}")

        # KV-cache.
        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        # attention computation over a (possibly clamped) sliding window.
        wsize = min(self.window_size, self.config.slw_wsize) if self.config.slw_wsize > 0 else self.window_size

        if ATTN_IMPL == "eager":
            attention_interface = lambda q, k, v, wsize, **kw: eager_attention_forward(q, k, v, window_size=(wsize, 0), **kw)
        elif ATTN_IMPL == "flex":
            # NOTE(review): this branch references self.last_wsize /
            # self.build_mask / self.attn_mask / self.score_mod /
            # self.num_attention_heads / self.num_key_value_heads, none of
            # which are defined in this class, and stores build_mask()'s
            # return into last_wsize — it will raise AttributeError if
            # ATTN_IMPL == "flex". Looks copy-pasted from another mixer;
            # needs fixing before the flex path is usable.
            if wsize != self.last_wsize:
                self.last_wsize = self.build_mask(wsize)
            attention_interface = lambda q, k, v, softmax_scale, **kw: flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=create_block_mask(self.attn_mask, B=None, H=None, Q_LEN=q.size(1), KV_LEN=k.size(1)), score_mod=self.score_mod, scale=softmax_scale, enable_gqa=self.num_attention_heads > self.num_key_value_heads).transpose(1, 2)
        elif ATTN_IMPL == "fa2":
            attention_interface = lambda q, k, v, wsize, **kw: flash_attn_func(q, k, v, window_size=(wsize, 0), **kw)
        elif ATTN_IMPL == "fa3":
            # FA3 returns (out, lse); keep only the output.
            attention_interface = lambda q, k, v, wsize, **kw: flash_attn_func(q, k, v, window_size=(wsize, 0), **kw)[0]
        else:
            raise ValueError(f"Unknown ATTN_IMPL: {ATTN_IMPL}")

        attn_output = attention_interface(
            query_states.bfloat16(),
            key_states.bfloat16(),
            value_states.bfloat16(),
            causal=True,
            wsize=wsize,
            softcap=self.config.softcap_local_attn,
            softmax_scale=None if not self.config.use_uscaling else 1/self.head_dim,
        )

        return attn_output, None, None
| 1612 |
class DragonNativeSparseAttention(nn.Module):
|
| 1613 |
"""
|
| 1614 |
Multi-headed attention from 'Attention Is All You Need' paper.
|
|
|
|
| 2149 |
|
| 2150 |
projection_dim = self.head_qk_dim * self.num_attention_heads + self.head_qk_dim * self.num_key_value_heads + (self.head_v_dim * self.num_noise_heads//2)
|
| 2151 |
self.linear_qkv = DragonLinear(config, config.hidden_size, projection_dim, bias=False)
|
| 2152 |
+
self.linear_qkv.norm_case_1 = True
|
| 2153 |
|
| 2154 |
if self.config.token_shift_attn:
|
| 2155 |
if self.config.scalar_proj_as_hidden_matrix:
|
|
|
|
| 2827 |
self.W_A_v = DragonLinear(config, self.hidden_size, self.num_noise_heads * self.rank, bias=False)
|
| 2828 |
self.W_B_k = DragonLinear(config, self.hidden_size, self.rank * self.head_qk_dim, bias=False)
|
| 2829 |
self.W_B_v = DragonLinear(config, self.hidden_size, self.rank * self.head_v_dim, bias=False)
|
| 2830 |
+
self.c_q.norm_case_1 = True
|
| 2831 |
+
# todo: norm others?
|
| 2832 |
|
| 2833 |
if self.config.token_shift_attn:
|
| 2834 |
if self.config.scalar_proj_as_hidden_matrix:
|
|
|
|
| 3617 |
self.num_attention_heads*self.dk + self.n_kv_heads*self.dk + self.n_kv_heads*self.dv,
|
| 3618 |
bias=False
|
| 3619 |
)
|
| 3620 |
+
self.linear_qkv.norm_case_1 = True
|
| 3621 |
self.linear_ba = DragonLinear(
|
| 3622 |
config, config.hidden_size,
|
| 3623 |
self.num_attention_heads + self.num_attention_heads, #+ self.num_attention_heads*self.dv, # b(H), a(H), g(H*dv)
|
| 3624 |
bias=False
|
| 3625 |
)
|
| 3626 |
|
| 3627 |
+
if config.legacy_gate:
|
| 3628 |
+
if config.gate_type == 'kimi':
|
| 3629 |
+
self.linear_g = nn.Sequential(
|
| 3630 |
+
DragonLinear(config, config.hidden_size, self.dv, bias=False),
|
| 3631 |
+
DragonLinear(config, self.dv, self.n_kv_heads*self.dv, bias=True),
|
| 3632 |
+
)
|
| 3633 |
+
self.output_norm = FusedRMSNormGated(hidden_size=self.dv, eps=config.norm_epsilon, activation='sigmoid')
|
| 3634 |
+
else:
|
| 3635 |
+
self.linear_g = DragonLinear(
|
| 3636 |
+
config, config.hidden_size,
|
| 3637 |
+
self.n_kv_heads * self.dv,
|
| 3638 |
+
bias=False
|
| 3639 |
+
)
|
| 3640 |
+
self.output_norm = FusedRMSNormGated(hidden_size=self.dv, eps=config.norm_epsilon)
|
| 3641 |
+
self.linear_g.norm_case_1 = True
|
| 3642 |
+
|
| 3643 |
dt_min = config.time_step_min
|
| 3644 |
dt_max = config.time_step_max
|
| 3645 |
dt_init_floor = config.time_step_floor
|
|
|
|
| 3654 |
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
| 3655 |
with torch.no_grad():
|
| 3656 |
self.dt_bias = nn.Parameter(inv_dt)
|
| 3657 |
+
self.dt_bias._no_weight_decay = True
|
| 3658 |
|
| 3659 |
assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
|
| 3660 |
A = torch.empty(self.n_heads_local, dtype=torch.float32).uniform_(*A_init_range)
|
| 3661 |
A_log = torch.log(A) # Keep A_log in fp32
|
| 3662 |
self.A_log = nn.Parameter(A_log)
|
| 3663 |
+
self.A_log._no_weight_decay = True
|
| 3664 |
|
| 3665 |
if self.config.rope_gdn == "rope":
|
| 3666 |
self.rope_proj = DragonLinear(config, config.hidden_size, self.dk//4, bias=False)
|
|
|
|
| 3823 |
use_qk_l2norm_in_kernel=True
|
| 3824 |
) # (B L H dv)
|
| 3825 |
|
| 3826 |
+
if self.config.legacy_gate:
|
| 3827 |
+
g = self.linear_g(hidden_states) # (B, L, H*dv)
|
| 3828 |
+
g = rearrange(g, "b l (h d) -> b l h d", h=self.n_kv_heads)
|
| 3829 |
+
o = self.output_norm(o, g)
|
| 3830 |
+
|
| 3831 |
# update GDN cache
|
| 3832 |
if cache_params is not None:
|
| 3833 |
cache_params.ssm_caches[self.layer_idx] = ssm_cache
|
|
|
|
| 3861 |
self.q_proj = DragonLinear(config, config.hidden_size, self.key_dim, bias=False)
|
| 3862 |
self.k_proj = DragonLinear(config, config.hidden_size, self.key_dim, bias=False)
|
| 3863 |
self.v_proj = DragonLinear(config, config.hidden_size, self.value_dim, bias=False)
|
| 3864 |
+
self.q_proj.norm_case_1 = True
|
| 3865 |
+
self.k_proj.norm_case_1 = True
|
| 3866 |
+
self.v_proj.norm_case_1 = True
|
| 3867 |
|
| 3868 |
self.q_conv1d = ShortConvolution(
|
| 3869 |
hidden_size=self.key_dim,
|
|
|
|
| 3896 |
self.A_log = nn.Parameter(torch.log(torch.empty(self.num_q_heads, dtype=torch.float32).uniform_(1, 16)))
|
| 3897 |
self.dt_bias = nn.Parameter(torch.zeros(self.key_dim, dtype=torch.float32))
|
| 3898 |
|
| 3899 |
+
if config.legacy_gate:
|
| 3900 |
+
if config.gate_type == 'kimi':
|
| 3901 |
+
self.linear_g = nn.Sequential(
|
| 3902 |
+
DragonLinear(config, config.hidden_size, self.head_v_dim, bias=False),
|
| 3903 |
+
DragonLinear(config, self.head_v_dim, self.num_attention_heads*self.head_v_dim, bias=True),
|
| 3904 |
+
)
|
| 3905 |
+
self.output_norm = FusedRMSNormGated(hidden_size=self.head_v_dim, eps=config.norm_epsilon, activation='sigmoid')
|
| 3906 |
+
else:
|
| 3907 |
+
self.linear_g = DragonLinear(
|
| 3908 |
+
config, config.hidden_size,
|
| 3909 |
+
self.num_attention_heads * self.head_v_dim,
|
| 3910 |
+
bias=False
|
| 3911 |
+
)
|
| 3912 |
+
self.output_norm = FusedRMSNormGated(hidden_size=self.head_v_dim, eps=config.norm_epsilon)
|
| 3913 |
+
self.linear_g.norm_case_1 = True
|
| 3914 |
|
| 3915 |
@disable
|
| 3916 |
def _kda_gate_call(self, g, A_log, head_k_dim, g_bias):
|
|
|
|
| 3921 |
hidden_states: torch.Tensor,
|
| 3922 |
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 3923 |
cache_params: Optional[HybridDragonDynamicCache] = None,
|
| 3924 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
| 3925 |
**kwargs,
|
| 3926 |
):
|
| 3927 |
_, q_len, _ = hidden_states.shape
|
|
|
|
| 3938 |
conv_state_k = cache_params.k_conv_caches[self.layer_idx]
|
| 3939 |
conv_state_v = cache_params.v_conv_caches[self.layer_idx]
|
| 3940 |
|
| 3941 |
+
seq_idx = None
|
| 3942 |
+
if cu_seqlens is not None:
|
| 3943 |
+
seq_idx = prepare_sequence_ids(cu_seqlens).to(torch.int32).unsqueeze(0)
|
| 3944 |
q, conv_state_q = self.q_conv1d(
|
| 3945 |
x=self.q_proj(hidden_states),
|
| 3946 |
cache=conv_state_q,
|
| 3947 |
output_final_state=cache_params is not None,
|
| 3948 |
+
seq_idx=seq_idx,
|
| 3949 |
)
|
| 3950 |
k, conv_state_k = self.k_conv1d(
|
| 3951 |
x=self.k_proj(hidden_states),
|
| 3952 |
cache=conv_state_k,
|
| 3953 |
output_final_state=cache_params is not None,
|
| 3954 |
+
seq_idx=seq_idx,
|
| 3955 |
)
|
| 3956 |
v, conv_state_v = self.v_conv1d(
|
| 3957 |
x=self.v_proj(hidden_states),
|
| 3958 |
cache=conv_state_v,
|
| 3959 |
output_final_state=cache_params is not None,
|
| 3960 |
+
seq_idx=seq_idx,
|
| 3961 |
)
|
| 3962 |
|
| 3963 |
g = self.f_proj(hidden_states)
|
|
|
|
| 3983 |
initial_state=None,
|
| 3984 |
output_final_state=cache_params is not None,
|
| 3985 |
use_qk_l2norm_in_kernel=True,
|
| 3986 |
+
cu_seqlens=cu_seqlens,
|
| 3987 |
)
|
| 3988 |
elif mode == 'fused_recurrent':
|
| 3989 |
o, ssm_cache = fused_recurrent_kda(
|
|
|
|
| 4002 |
#o = o * F.silu(rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim))
|
| 4003 |
# TODO: other types of gates? as well as ZCG?
|
| 4004 |
|
| 4005 |
+
if self.config.legacy_gate:
|
| 4006 |
+
g = self.linear_g(hidden_states) # (B, L, H*dv)
|
| 4007 |
+
g = rearrange(g, "b l (h d) -> b l h d", h=self.num_attention_heads)
|
| 4008 |
+
o = self.output_norm(o, g)
|
| 4009 |
+
|
| 4010 |
if cache_params is not None:
|
| 4011 |
cache_params.ssm_caches[self.layer_idx] = ssm_cache
|
| 4012 |
cache_params.q_conv_caches[self.layer_idx] = conv_state_q
|
|
|
|
| 4056 |
if config.mamba3_rope:
|
| 4057 |
self.rope_proj = DragonLinear(config, self.d_model, self.num_rope_angles, bias=False)
|
| 4058 |
|
| 4059 |
+
# Order: [x, B, C, dt]
|
| 4060 |
+
d_in_proj = self.d_inner + 2 * self.d_state * self.ngroups + self.nheads
|
| 4061 |
|
| 4062 |
if self.config.mamba3_is_A_dd:
|
| 4063 |
self.A_proj = DragonLinear(config, self.d_model, self.nheads, bias=False, dtype=torch.float32)
|
|
|
|
| 4082 |
self.dt_bias._no_weight_decay = True
|
| 4083 |
|
| 4084 |
self.in_proj = DragonLinear(config, self.d_model, d_in_proj, bias=self.bias)
|
| 4085 |
+
self.in_proj.norm_case_1 = True
|
| 4086 |
|
| 4087 |
self.B_bias, self.C_bias = None, None
|
| 4088 |
if not config.mamba3_remove_BC_bias:
|
|
|
|
| 4112 |
self.D = nn.Parameter(torch.ones(self.nheads))
|
| 4113 |
self.D._no_weight_decay = True
|
| 4114 |
|
| 4115 |
+
if config.legacy_gate:
|
| 4116 |
+
self.linear_g = DragonLinear(
|
| 4117 |
+
config, config.hidden_size,
|
| 4118 |
+
self.d_inner,
|
| 4119 |
+
bias=False,
|
| 4120 |
+
)
|
| 4121 |
+
self.linear_g.norm_case_1 = True
|
| 4122 |
+
if config.mamba3_postgate_norm:
|
| 4123 |
+
self.output_norm = RMSNormGated(self.d_inner, eps=config.norm_epsilon, norm_before_gate=False)
|
| 4124 |
+
|
| 4125 |
+
def forward(
|
| 4126 |
self,
|
| 4127 |
hidden_states: torch.Tensor,
|
| 4128 |
cache_params: Optional[HybridDragonDynamicCache] = None,
|
| 4129 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
| 4130 |
**kwargs
|
| 4131 |
):
|
| 4132 |
+
cached_len = None
|
| 4133 |
+
if cache_params is not None:
|
| 4134 |
+
hidden_states_cached = cache_params.ssm_caches[self.layer_idx] # (B, L, D)
|
| 4135 |
+
if hidden_states_cached is not None:
|
| 4136 |
+
cached_len = hidden_states_cached.shape[1]
|
| 4137 |
+
hidden_states = torch.cat([hidden_states_cached, hidden_states], dim=1) # (B, L+1, D)
|
| 4138 |
+
cache_params.ssm_caches[self.layer_idx] = hidden_states
|
| 4139 |
+
|
| 4140 |
# Apply in_proj
|
| 4141 |
+
xBCdt = self.in_proj(hidden_states) # (B, l, D), l=1 when decoding
|
| 4142 |
+
xBC, dd_dt = torch.split(
|
| 4143 |
+
xBCdt,
|
| 4144 |
[
|
|
|
|
| 4145 |
self.d_inner + 2 * self.d_state * self.ngroups,
|
| 4146 |
self.nheads,
|
| 4147 |
],
|
|
|
|
| 4154 |
_A = -torch.exp(self.A_log).unsqueeze(0).unsqueeze(0)
|
| 4155 |
dt = F.softplus(dd_dt + self.dt_bias) # (B, L, N)
|
| 4156 |
|
| 4157 |
+
seq_idx = None
|
| 4158 |
+
if cu_seqlens is not None:
|
| 4159 |
+
seq_idx = prepare_sequence_ids(cu_seqlens).to(torch.int32).unsqueeze(0)
|
| 4160 |
+
|
| 4161 |
if not self.config.mamba3_remove_conv:
|
| 4162 |
xBC = causal_conv1d_fn(
|
| 4163 |
x=xBC.transpose(1, 2),
|
| 4164 |
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
| 4165 |
bias=self.conv1d.bias,
|
| 4166 |
activation=self.activation,
|
| 4167 |
+
seq_idx=seq_idx,
|
| 4168 |
).transpose(1, 2) # (B, L, self.d_inner + 2 * ngroups * d_state)
|
| 4169 |
|
| 4170 |
x, B, C = torch.split(
|
|
|
|
| 4230 |
|
| 4231 |
x_scalar = (gamma_arr*_alpha_arr).to(torch.bfloat16)
|
| 4232 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4233 |
out = mamba_chunk_scan_discretized_combined(
|
| 4234 |
x=x.bfloat16(),
|
| 4235 |
A=A,
|
|
|
|
| 4241 |
CB_sum=CB_sum,
|
| 4242 |
D=self.D,
|
| 4243 |
z=None,
|
| 4244 |
+
initial_states=None, # ssm_cache,
|
| 4245 |
+
return_final_states=False, # cache_params is not None,
|
| 4246 |
+
seq_idx=seq_idx,
|
| 4247 |
)
|
| 4248 |
+
y = out
|
| 4249 |
+
|
| 4250 |
+
if self.config.legacy_gate:
|
| 4251 |
+
if not self.config.mamba3_postgate_norm:
|
| 4252 |
+
g = self.linear_g(hidden_states) # (B, L, d_inner)
|
| 4253 |
+
y = rearrange(y, "b l h p -> b l (h p)")
|
| 4254 |
+
y = y * F.silu(g)
|
| 4255 |
+
y = rearrange(y, "b l (h p) -> b l h p", h=self.nheads)
|
| 4256 |
+
else:
|
| 4257 |
+
g = self.linear_g(hidden_states) # (B, L, d_inner)
|
| 4258 |
+
y = rearrange(y, "b l h p -> b l (h p)")
|
| 4259 |
+
y = self.output_norm(y, g)
|
| 4260 |
+
y = rearrange(y, "b l (h p) -> b l h p", h=self.nheads)
|
| 4261 |
|
| 4262 |
+
if cached_len and cached_len > 0:
|
| 4263 |
+
y = y[:, cached_len:, :] # keep only the new Ln steps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4264 |
|
| 4265 |
return y, None, None
|
| 4266 |
|
|
|
|
| 4281 |
# Order: [x, B, C, dt]
|
| 4282 |
d_in_proj = self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
|
| 4283 |
self.in_proj = DragonLinear(config, self.d_model, d_in_proj, bias=False)
|
| 4284 |
+
self.in_proj.norm_case_1 = True
|
| 4285 |
|
| 4286 |
if not self.config.mamba3_remove_conv:
|
| 4287 |
conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
|
|
|
|
| 4319 |
self.D = nn.Parameter(torch.ones(self.nheads))
|
| 4320 |
self.D._no_weight_decay = True
|
| 4321 |
|
| 4322 |
+
if config.legacy_gate:
|
| 4323 |
+
self.linear_g = DragonLinear(
|
| 4324 |
+
config, config.hidden_size,
|
| 4325 |
+
self.d_inner,
|
| 4326 |
+
bias=False,
|
| 4327 |
+
)
|
| 4328 |
+
self.linear_g.norm_case_1 = True
|
| 4329 |
+
self.output_norm = RMSNormGated(self.d_inner, eps=config.norm_epsilon, norm_before_gate=False)
|
| 4330 |
+
|
| 4331 |
def forward(self, hidden_states, **kwargs):
|
| 4332 |
"""
|
| 4333 |
u: (B, L, D)
|
|
|
|
| 4374 |
initial_states=None,
|
| 4375 |
)
|
| 4376 |
|
| 4377 |
+
if self.config.legacy_gate:
|
| 4378 |
+
g = self.linear_g(hidden_states) # (B, L, d_inner)
|
| 4379 |
+
y = rearrange(y, "b l h p -> b l (h p)")
|
| 4380 |
+
y = self.output_norm(y, g)
|
| 4381 |
+
y = rearrange(y, "b l (h p) -> b l h p", h=self.nheads)
|
| 4382 |
+
|
| 4383 |
return y, None, None
|
| 4384 |
|
| 4385 |
class DragonMamba3Mimo(nn.Module):
|
|
|
|
| 4394 |
"when creating this class."
|
| 4395 |
)
|
| 4396 |
|
| 4397 |
+
assert not self.config.gate_gdn, "gate must done inside the mimo mamba3 block."
|
| 4398 |
+
|
| 4399 |
self.d_model = config.hidden_size
|
| 4400 |
+
self.d_state = config.mamba_d_state
|
| 4401 |
self.conv_init = None
|
| 4402 |
self.expand = 2
|
| 4403 |
+
self.headdim = config.mamba_headdim
|
| 4404 |
self.ngroups = config.mamba_ngroups
|
| 4405 |
self.activation = "swish"
|
| 4406 |
self.bias = False
|
|
|
|
| 4415 |
self.dt_init_floor = 1e-4
|
| 4416 |
self.mimo_dim = config.mamba_mimo_dim
|
| 4417 |
self.mimo_proj_block_order = 1
|
|
|
|
| 4418 |
|
| 4419 |
self.d_inner = int(self.expand * self.d_model)
|
| 4420 |
assert self.d_inner % self.headdim == 0
|
| 4421 |
self.nheads = self.d_inner // self.headdim
|
| 4422 |
self.dr_out_dim = self.d_inner // self.mimo_proj_block_order
|
| 4423 |
|
|
|
|
| 4424 |
self.split_tensor_size = int(self.d_state * self.rope_fraction)
|
| 4425 |
if self.split_tensor_size % 2 != 0:
|
| 4426 |
self.split_tensor_size -= 1
|
|
|
|
| 4446 |
self.dt_bias._no_weight_decay = True
|
| 4447 |
|
| 4448 |
self.in_proj = DragonLinear(config, self.d_model, d_in_proj, bias=self.bias)
|
| 4449 |
+
self.in_proj.norm_case_1 = True
|
| 4450 |
|
| 4451 |
self.B_bias = nn.Parameter(torch.ones((self.mimo_dim, self.nheads, self.d_state)), requires_grad=True)
|
| 4452 |
self.C_bias = nn.Parameter(torch.ones((self.mimo_dim, self.nheads, self.d_state)), requires_grad=True)
|
|
|
|
| 4478 |
self.in_proj_mimo_z = nn.Parameter(in_proj_mimo_z_init_weights, requires_grad=True)
|
| 4479 |
self.out_proj_mimo = nn.Parameter(out_proj_mimo_init_weights, requires_grad=True)
|
| 4480 |
|
|
|
|
| 4481 |
# D "skip" parameter
|
| 4482 |
self.D = nn.Parameter(torch.ones(self.nheads))
|
| 4483 |
self.D._no_weight_decay = True
|
| 4484 |
|
| 4485 |
+
if config.legacy_gate:
|
| 4486 |
+
if config.mamba3_postgate_norm:
|
| 4487 |
+
self.output_norm = RMSNormGated(self.d_inner, eps=config.norm_epsilon, norm_before_gate=False)
|
| 4488 |
+
|
| 4489 |
def forward(self, hidden_states, **kwargs):
|
| 4490 |
# Apply in_proj
|
| 4491 |
zxBCdt = self.in_proj(hidden_states)
|
|
|
|
| 4578 |
_beta_arr = torch.roll(beta_arr, shifts=-1, dims=1)
|
| 4579 |
|
| 4580 |
x_scalar = (gamma_arr*_alpha_arr + _beta_arr).to(torch.bfloat16)
|
| 4581 |
+
|
| 4582 |
z = rearrange(z, "b l r (h p) -> b l r h p", p=self.headdim)
|
| 4583 |
|
| 4584 |
y = mamba_mimo_chunk_scan_discretized_fused_combined(
|
|
|
|
| 4591 |
gamma=gamma_arr,
|
| 4592 |
CB_sum=CB_sum,
|
| 4593 |
D=self.D,
|
| 4594 |
+
z=z if not (self.config.legacy_gate and self.config.mamba3_postgate_norm) else None,
|
| 4595 |
)
|
| 4596 |
|
| 4597 |
y = rearrange(y, "b l r h p -> b l r (h p)")
|
| 4598 |
+
|
| 4599 |
+
if self.config.legacy_gate and self.config.mamba3_postgate_norm:
|
| 4600 |
+
z = rearrange(z, "b l r h p -> b l r (h p)")
|
| 4601 |
+
y = self.output_norm(y, z)
|
| 4602 |
+
|
| 4603 |
#if seqlen_og is not None:
|
| 4604 |
# y = rearrange(y, "b l r d -> (b l) r d")
|
| 4605 |
|
|
|
|
| 4626 |
self.lambda1 = nn.Parameter(torch.zeros(self.link_size)) # sigmoid->0.5
|
| 4627 |
else :
|
| 4628 |
self.fc_1 = DragonLinear(config, config.hidden_size, intermediate_size, bias=False)
|
| 4629 |
+
self.fc_1.norm_case_1 = True
|
| 4630 |
self.fc_2 = DragonLinear(config, intermediate_size, config.hidden_size, bias=False)
|
| 4631 |
+
self.fc_2.norm_case_2 = True
|
| 4632 |
self.register_buffer("_2_sqrt_5", torch.tensor(2/math.sqrt(5)) if config.use_uscaling else torch.tensor(1.), persistent=False)
|
| 4633 |
|
| 4634 |
def forward(self, hidden_states):
|
|
|
|
| 4657 |
self.intermediate_size = intermediate_size
|
| 4658 |
|
| 4659 |
self.fc_1 = DragonLinear(config, config.hidden_size, num_active_experts*self.intermediate_size, bias=False)
|
| 4660 |
+
self.fc_1.norm_case_1 = True
|
| 4661 |
self.fc_2 = DragonLinear(config, num_active_experts*self.intermediate_size, config.hidden_size, bias=False)
|
| 4662 |
+
self.fc_2.norm_case_2 = True
|
| 4663 |
self.register_buffer("_2_sqrt_5", torch.tensor(2/math.sqrt(5)) if config.use_uscaling else torch.tensor(1.), persistent=False)
|
| 4664 |
|
| 4665 |
def forward(self, hidden_states, gates):
|
|
|
|
| 4737 |
head_dim = self.mixer.head_dim
|
| 4738 |
num_attention_heads = self.mixer.num_q_heads
|
| 4739 |
use_gate = config.gate_attn
|
| 4740 |
+
elif layer_type == 'C':
|
| 4741 |
+
self.mixer = DragonCompressedConvolutionalAttention2(config, layer_idx=layer_idx)
|
| 4742 |
+
head_dim = self.mixer.head_dim
|
| 4743 |
+
num_attention_heads = self.mixer.num_q_heads
|
| 4744 |
+
use_gate = config.gate_attn
|
| 4745 |
elif layer_type == 'n':
|
| 4746 |
self.mixer = DragonNativeSparseAttention(config, reuse_kv=False, layer_idx=layer_idx)
|
| 4747 |
head_dim = self.mixer.head_dim
|
|
|
|
| 4771 |
self.mixer = DragonMamba3(config, layer_idx=layer_idx)
|
| 4772 |
head_dim = self.mixer.headdim
|
| 4773 |
num_attention_heads = self.mixer.nheads
|
| 4774 |
+
use_gate = config.gate_gdn
|
| 4775 |
elif layer_type == '2':
|
| 4776 |
self.mixer = DragonMamba2(config, layer_idx=layer_idx)
|
| 4777 |
head_dim = self.mixer.headdim
|
|
|
|
| 4782 |
head_dim = self.mixer.headdim
|
| 4783 |
num_attention_heads = self.mixer.nheads
|
| 4784 |
use_gate = False # inside Mamba3Mimo
|
| 4785 |
+
elif layer_type == 'b':
|
| 4786 |
+
self.mixer = DragonMoBAttention(config, reuse_kv=False, layer_idx=layer_idx)
|
| 4787 |
+
head_dim = self.mixer.head_dim
|
| 4788 |
+
num_attention_heads = self.mixer.num_attention_heads
|
| 4789 |
+
use_gate = config.gate_attn
|
| 4790 |
else:
|
| 4791 |
raise ValueError(f"Unknown layer type: {layer_type}")
|
| 4792 |
|
|
|
|
| 4806 |
self.gate_proj.is_scalar_weight = True
|
| 4807 |
else:
|
| 4808 |
raise ValueError(f"Unknown gate_type: {self.config.gate_type}")
|
| 4809 |
+
self.gate_proj.norm_case_1 = True
|
| 4810 |
if self.config.zero_centered_gate:
|
| 4811 |
val = 1.
|
| 4812 |
if self.config.zero_centered_gate_type==3:
|
|
|
|
| 4827 |
self.use_gate = use_gate
|
| 4828 |
|
| 4829 |
self.mixer_proj = DragonLinear(config, head_dim*num_attention_heads, config.hidden_size, bias=False)
|
| 4830 |
+
self.mixer_proj.norm_case_2 = True
|
| 4831 |
if config.mixer_gn:
|
| 4832 |
self.mixer_group_norm = DragonHeadWiseRMSNorm(n_heads=num_attention_heads, d_head=head_dim, eps=config.norm_epsilon, zero_centered_gamma=config.zero_centered_gamma)
|
| 4833 |
|
|
|
|
| 4874 |
cu_seqlens=cu_seqlens,
|
| 4875 |
max_seqlen=max_seqlen,
|
| 4876 |
) # (B, L, E*D)
|
| 4877 |
+
if self.config.mixer_gn and not self.config.gate_before_norm:
|
| 4878 |
+
y_mixer = self.mixer_group_norm(y_mixer)
|
| 4879 |
if self.use_gate:
|
| 4880 |
if self.config.gate_type == "elementwise" or self.config.gate_type == "kimi":
|
| 4881 |
g_proj = self.gate_proj(hidden_states).view(hidden_states.size(0), hidden_states.size(1), self.num_attention_heads, self.head_dim).to(y_mixer.dtype)
|
|
|
|
| 4890 |
y_mixer = y_mixer * (self.gate_act(g_proj) + self.gate_bias)
|
| 4891 |
elif self.config.zero_centered_gate_type == 3 or self.config.zero_centered_gate_type == 4:
|
| 4892 |
y_mixer = y_mixer * self.gate_act(g_proj + self.gate_bias)
|
| 4893 |
+
if self.config.mixer_gn and self.config.gate_before_norm:
|
| 4894 |
y_mixer = self.mixer_group_norm(y_mixer)
|
| 4895 |
y_mixer = y_mixer.view(y_mixer.size(0), y_mixer.size(1), -1)
|
| 4896 |
y_mixer = self.mixer_proj(y_mixer)
|
|
|
|
| 4904 |
|
| 4905 |
return hidden_states, last_key_states, last_value_states
|
| 4906 |
|
| 4907 |
+
class DragonGHyperConnection(nn.Module):
|
| 4908 |
+
def __init__(self, config: DragonConfig, m, n_in=3):
|
| 4909 |
+
super().__init__()
|
| 4910 |
+
self.config = config
|
| 4911 |
+
self.m, self.n_in = m, n_in
|
| 4912 |
+
dim = self.config.hidden_size
|
| 4913 |
+
self.factor = 1.0 / math.sqrt(dim // self.m)
|
| 4914 |
+
|
| 4915 |
+
# Initialize static beta: cyclic pattern
|
| 4916 |
+
static_beta_tensor = torch.zeros(self.m, n_in)
|
| 4917 |
+
for j in range(n_in):
|
| 4918 |
+
static_beta_tensor[j % self.m, j] = 1.0
|
| 4919 |
+
self.static_beta = nn.Parameter(static_beta_tensor.T.contiguous())
|
| 4920 |
+
|
| 4921 |
+
# Initialize static alpha: block matrix
|
| 4922 |
+
init_alpha = torch.cat([torch.eye(self.m), torch.eye(self.m), torch.zeros((self.m, self.n_in - self.m))], dim=1)
|
| 4923 |
+
if self.n_in > self.m:
|
| 4924 |
+
part2 = torch.cat([torch.zeros((self.n_in - self.m, self.m * 2)), torch.eye(self.n_in - self.m)], dim=1)
|
| 4925 |
+
init_alpha = torch.cat([init_alpha, part2], dim=0)
|
| 4926 |
+
self.static_alpha = nn.Parameter(init_alpha.contiguous())
|
| 4927 |
+
|
| 4928 |
+
# Dynamic parameters
|
| 4929 |
+
self.dynamic_alpha_fn = nn.Parameter(torch.zeros((dim // self.m, self.m + self.n_in)))
|
| 4930 |
+
self.dynamic_beta_fn = nn.Parameter(torch.zeros((dim // self.m, self.m)))
|
| 4931 |
+
self.dynamic_alpha_fn.requires_weight_decay = True
|
| 4932 |
+
self.dynamic_beta_fn.requires_weight_decay = True
|
| 4933 |
+
if self.config.vwn_dynamic:
|
| 4934 |
+
self.dynamic_alpha_scale = nn.Parameter(torch.ones_like(self.static_alpha))
|
| 4935 |
+
self.dynamic_beta_scale = nn.Parameter(torch.ones_like(self.static_beta))
|
| 4936 |
+
if config.vwn_wd_alpha_beta:
|
| 4937 |
+
self.dynamic_alpha_scale.requires_weight_decay = True
|
| 4938 |
+
self.dynamic_beta_scale.requires_weight_decay = True
|
| 4939 |
+
else:
|
| 4940 |
+
self.register_buffer("dynamic_alpha_scale", torch.zeros_like(self.static_alpha), persistent=False)
|
| 4941 |
+
self.register_buffer("dynamic_beta_scale", torch.zeros_like(self.static_beta), persistent=False)
|
| 4942 |
+
|
| 4943 |
+
self.layer_norm = DragonNorm(config, dim//self.m)
|
| 4944 |
+
|
| 4945 |
+
def _base_width_connection(self, h, dynamic_fn, dynamic_scale, static_scale):
|
| 4946 |
+
h_shape = h.shape
|
| 4947 |
+
N, NMM = static_scale.shape
|
| 4948 |
+
M = (NMM - N) // 2
|
| 4949 |
+
h_reshape = h.reshape((h_shape[:-1].numel(),) + (N, h_shape[-1] // N))
|
| 4950 |
+
norm_h = self.layer_norm(h_reshape)
|
| 4951 |
+
alpha_beta = (F.tanh(norm_h @ dynamic_fn.T.to(dtype=norm_h.dtype) * self.factor) * dynamic_scale[None, ...] + static_scale[None, ...])
|
| 4952 |
+
alpha, beta = torch.split(alpha_beta, (M + N, M), dim=-1)
|
| 4953 |
+
mix_h = (h_reshape.transpose(1, 2) @ alpha.to(dtype=h_reshape.dtype)).transpose(1, 2)
|
| 4954 |
+
return mix_h.reshape(h_shape[:-1] + mix_h.shape[1:]), beta
|
| 4955 |
+
|
| 4956 |
+
def width_connection(self, h):
|
| 4957 |
+
dynamic_fn = torch.concat([self.dynamic_alpha_fn.T, self.dynamic_beta_fn.T], dim=0)
|
| 4958 |
+
dynamic_scale = torch.concat([self.dynamic_alpha_scale, self.dynamic_beta_scale], dim=-1).contiguous()
|
| 4959 |
+
static_scale = torch.concat([self.static_alpha, self.static_beta], dim=-1)
|
| 4960 |
+
return self._base_width_connection(h, dynamic_fn.to(dtype=h.dtype), dynamic_scale.to(dtype=h.dtype), static_scale.to(dtype=h.dtype))
|
| 4961 |
+
|
| 4962 |
+
def depth_connection(self, mix_h, h_o, beta, sqrt_one_minus_tau, sqrt_tau):
|
| 4963 |
+
h_o_shape = h_o.shape
|
| 4964 |
+
h_o = h_o.reshape(h_o_shape[:-1] + (self.m, h_o_shape[-1] // self.m))
|
| 4965 |
+
h_i = beta.view(h_o.shape[:2] + beta.shape[1:]).to(dtype=h_o.dtype) @ h_o
|
| 4966 |
+
h = sqrt_tau * h_i + sqrt_one_minus_tau * mix_h[..., self.m:, :]
|
| 4967 |
+
h_shape = h.shape
|
| 4968 |
+
return h.reshape(h_shape[:-2] + (h_shape[-2] * h_shape[-1],)).contiguous()
|
| 4969 |
+
|
| 4970 |
+
class DragonMonoVirtualBlock(GradientCheckpointingLayer):
|
| 4971 |
+
def __init__(self, config: DragonConfig, layer_idx: int, layer_type: str):
|
| 4972 |
+
super().__init__()
|
| 4973 |
+
self.config = config
|
| 4974 |
+
self.layer_idx = layer_idx
|
| 4975 |
+
|
| 4976 |
+
assert self.config.vwn
|
| 4977 |
+
|
| 4978 |
+
if layer_type == 'g':
|
| 4979 |
+
self.mixer = DragonGatedDeltaNet(config, layer_idx=layer_idx)
|
| 4980 |
+
head_dim = self.mixer.head_dim
|
| 4981 |
+
num_attention_heads = self.mixer.num_attention_heads
|
| 4982 |
+
use_gate = config.gate_gdn
|
| 4983 |
+
elif layer_type == 'f':
|
| 4984 |
+
self.mixer = DragonDifferentialAttention(config, layer_idx=layer_idx)
|
| 4985 |
+
head_dim = self.mixer.head_dim
|
| 4986 |
+
num_attention_heads = self.mixer.num_signal_heads
|
| 4987 |
+
use_gate = config.gate_attn
|
| 4988 |
+
elif layer_type == 's':
|
| 4989 |
+
self.mixer = DragonDeepSeekSparseAttention(config, reuse_kv=False, layer_idx=layer_idx)
|
| 4990 |
+
head_dim = self.mixer.head_dim
|
| 4991 |
+
num_attention_heads = self.mixer.num_attention_heads
|
| 4992 |
+
use_gate = config.gate_attn
|
| 4993 |
+
elif layer_type == 'm':
|
| 4994 |
+
self.mixer = DragonDynamicMaskAttention(config, reuse_kv=False, layer_idx=layer_idx)
|
| 4995 |
+
head_dim = self.mixer.head_dim
|
| 4996 |
+
num_attention_heads = self.mixer.num_attention_heads
|
| 4997 |
+
use_gate = config.gate_attn
|
| 4998 |
+
elif layer_type == 'w':
|
| 4999 |
+
self.mixer = DragonAttention(config, reuse_kv=False, layer_idx=layer_idx)
|
| 5000 |
+
head_dim = self.mixer.head_dim
|
| 5001 |
+
num_attention_heads = self.mixer.num_attention_heads
|
| 5002 |
+
use_gate = config.gate_attn
|
| 5003 |
+
elif layer_type == 'p':
|
| 5004 |
+
self.mixer = DragonSlidingWindowRecurrenceAttention(config)
|
| 5005 |
+
head_dim = self.mixer.head_dim
|
| 5006 |
+
num_attention_heads = self.mixer.num_attention_heads
|
| 5007 |
+
use_gate = config.gate_attn
|
| 5008 |
+
elif layer_type == 'c':
|
| 5009 |
+
self.mixer = DragonCompressedConvolutionalAttention(config, layer_idx=layer_idx)
|
| 5010 |
+
head_dim = self.mixer.head_dim
|
| 5011 |
+
num_attention_heads = self.mixer.num_q_heads
|
| 5012 |
+
use_gate = config.gate_attn
|
| 5013 |
+
elif layer_type == 'C':
|
| 5014 |
+
self.mixer = DragonCompressedConvolutionalAttention2(config, layer_idx=layer_idx)
|
| 5015 |
+
head_dim = self.mixer.head_dim
|
| 5016 |
+
num_attention_heads = self.mixer.num_q_heads
|
| 5017 |
+
use_gate = config.gate_attn
|
| 5018 |
+
elif layer_type == 'n':
|
| 5019 |
+
self.mixer = DragonNativeSparseAttention(config, reuse_kv=False, layer_idx=layer_idx)
|
| 5020 |
+
head_dim = self.mixer.head_dim
|
| 5021 |
+
num_attention_heads = self.mixer.num_attention_heads
|
| 5022 |
+
use_gate = config.gate_attn
|
| 5023 |
+
elif layer_type == 't':
|
| 5024 |
+
self.mixer = DragonTensorProductAttention(config, reuse_kv=False, layer_idx=layer_idx)
|
| 5025 |
+
head_dim = self.mixer.head_dim
|
| 5026 |
+
num_attention_heads = self.mixer.num_attention_heads
|
| 5027 |
+
use_gate = config.gate_attn
|
| 5028 |
+
elif layer_type == 'T':
|
| 5029 |
+
self.mixer = DragonDifferentialTensorProductAttention(config, layer_idx=layer_idx)
|
| 5030 |
+
head_dim = self.mixer.head_dim
|
| 5031 |
+
num_attention_heads = self.mixer.num_signal_heads
|
| 5032 |
+
use_gate = config.gate_attn
|
| 5033 |
+
elif layer_type == 'A':
|
| 5034 |
+
self.mixer = DragonDifferentialMultiLatentAttention(config, layer_idx=layer_idx)
|
| 5035 |
+
head_dim = self.mixer.head_dim
|
| 5036 |
+
num_attention_heads = self.mixer.num_signal_heads
|
| 5037 |
+
use_gate = config.gate_attn
|
| 5038 |
+
elif layer_type == 'k':
|
| 5039 |
+
self.mixer = DragonKimiDeltaAttention(config, layer_idx=layer_idx)
|
| 5040 |
+
head_dim = self.mixer.head_dim
|
| 5041 |
+
num_attention_heads = self.mixer.num_attention_heads
|
| 5042 |
+
use_gate = config.gate_gdn
|
| 5043 |
+
elif layer_type == '3':
|
| 5044 |
+
self.mixer = DragonMamba3(config, layer_idx=layer_idx)
|
| 5045 |
+
head_dim = self.mixer.headdim
|
| 5046 |
+
num_attention_heads = self.mixer.nheads
|
| 5047 |
+
use_gate = config.gate_gdn
|
| 5048 |
+
elif layer_type == '2':
|
| 5049 |
+
self.mixer = DragonMamba2(config, layer_idx=layer_idx)
|
| 5050 |
+
head_dim = self.mixer.headdim
|
| 5051 |
+
num_attention_heads = self.mixer.nheads
|
| 5052 |
+
use_gate = config.gate_gdn
|
| 5053 |
+
elif layer_type == 'M':
|
| 5054 |
+
self.mixer = DragonMamba3Mimo(config, layer_idx=layer_idx)
|
| 5055 |
+
head_dim = self.mixer.headdim
|
| 5056 |
+
num_attention_heads = self.mixer.nheads
|
| 5057 |
+
use_gate = False # inside Mamba3Mimo
|
| 5058 |
+
else:
|
| 5059 |
+
raise ValueError(f"Unknown layer type: {layer_type}")
|
| 5060 |
+
|
| 5061 |
+
if use_gate:
|
| 5062 |
+
if self.config.gate_type == "elementwise":
|
| 5063 |
+
self.gate_proj = DragonLinear(self.config, config.hidden_size, num_attention_heads*head_dim, bias=False)
|
| 5064 |
+
elif self.config.gate_type == "kimi":
|
| 5065 |
+
self.gate_proj = nn.Sequential(
|
| 5066 |
+
DragonLinear(config, config.hidden_size, head_dim, bias=False),
|
| 5067 |
+
DragonLinear(config, head_dim, num_attention_heads*head_dim, bias=True),
|
| 5068 |
+
)
|
| 5069 |
+
elif self.config.gate_type == "headwise":
|
| 5070 |
+
if self.config.scalar_proj_as_hidden_matrix:
|
| 5071 |
+
self.gate_proj = DragonLinear(self.config, config.hidden_size, num_attention_heads, bias=False)
|
| 5072 |
+
else:
|
| 5073 |
+
self.gate_proj = DragonLinear(self.config, config.hidden_size, num_attention_heads, bias=False, alpha_fwd=1., alpha_bwd=1.)
|
| 5074 |
+
self.gate_proj.is_scalar_weight = True
|
| 5075 |
+
else:
|
| 5076 |
+
raise ValueError(f"Unknown gate_type: {self.config.gate_type}")
|
| 5077 |
+
self.gate_proj.norm_case_1 = True
|
| 5078 |
+
if self.config.zero_centered_gate:
|
| 5079 |
+
val = 1.
|
| 5080 |
+
if self.config.zero_centered_gate_type==3:
|
| 5081 |
+
val = 1.28 # F.silu(E(g) + 1.28) = 1
|
| 5082 |
+
elif self.config.zero_centered_gate_type==4:
|
| 5083 |
+
val = 1.15 # E(silu(g + 1.15)) = 1
|
| 5084 |
+
self.register_buffer("gate_bias", torch.tensor(val), persistent=False)
|
| 5085 |
+
else:
|
| 5086 |
+
self.register_buffer("gate_bias", torch.tensor(0.), persistent=False)
|
| 5087 |
+
if self.config.gate_act == "silu":
|
| 5088 |
+
self.gate_act = F.silu
|
| 5089 |
+
elif self.config.gate_act == "sigmoid":
|
| 5090 |
+
self.gate_act = F.sigmoid
|
| 5091 |
+
else:
|
| 5092 |
+
raise ValueError(f"Unknown gate_act: {self.config.gate_act}")
|
| 5093 |
+
self.num_attention_heads = num_attention_heads
|
| 5094 |
+
self.head_dim = head_dim
|
| 5095 |
+
self.use_gate = use_gate
|
| 5096 |
+
|
| 5097 |
+
self.mixer_proj = DragonLinear(config, head_dim*num_attention_heads, config.hidden_size, bias=False)
|
| 5098 |
+
self.mixer_proj.norm_case_2 = True
|
| 5099 |
+
if config.mixer_gn:
|
| 5100 |
+
self.mixer_group_norm = DragonHeadWiseRMSNorm(n_heads=num_attention_heads, d_head=head_dim, eps=config.norm_epsilon, zero_centered_gamma=config.zero_centered_gamma)
|
| 5101 |
+
|
| 5102 |
+
self.input_norm = DragonNorm(config, config.hidden_size)
|
| 5103 |
+
self.postmixer_norm = DragonNorm(config, config.hidden_size)
|
| 5104 |
+
self.mixer_ghyper_connection = DragonGHyperConnection(config, m=config.vwn_m, n_in=config.vwn_n)
|
| 5105 |
+
self.mlp_ghyper_connection = DragonGHyperConnection(config, m=config.vwn_m, n_in=config.vwn_n)
|
| 5106 |
+
if not config.moe:
|
| 5107 |
+
if config.mlp_type == "simple":
|
| 5108 |
+
self.mlp = DragonMLP(config)
|
| 5109 |
+
elif config.mlp_type == "gated":
|
| 5110 |
+
self.mlp = GatedMlp(in_features=config.hidden_size, hidden_features=config.intermediate_size, out_features=config.hidden_size, activation=F.silu, bias1=False, bias2=False)
|
| 5111 |
+
else:
|
| 5112 |
+
self.mlp = DragonMoE(config)
|
| 5113 |
+
|
| 5114 |
+
if config.use_uscaling or not config.layer_norm_scaling:
|
| 5115 |
+
self.register_buffer("lns", torch.tensor(1.0), persistent=False)
|
| 5116 |
+
else:
|
| 5117 |
+
self.register_buffer("lns", torch.tensor(1. / math.sqrt(layer_idx + (2 if config.old_lns else 1))), persistent=False)
|
| 5118 |
+
self.register_buffer("sqrt_tau", torch.sqrt(torch.tensor(self.config.uscaling_tau)) if config.use_uscaling else torch.tensor(1.0), persistent=False)
|
| 5119 |
+
self.register_buffer("sqrt_one_minus_tau", torch.sqrt(torch.tensor(1.0 - self.config.uscaling_tau)) if config.use_uscaling else torch.tensor(1.0), persistent=False)
|
| 5120 |
+
|
| 5121 |
+
def forward(
|
| 5122 |
+
self,
|
| 5123 |
+
hidden_states: torch.Tensor,
|
| 5124 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 5125 |
+
cache_params: Optional[HybridDragonDynamicCache] = None,
|
| 5126 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 5127 |
+
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
| 5128 |
+
key_value_last_layer: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 5129 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
| 5130 |
+
max_seqlen: Optional[int] = None,
|
| 5131 |
+
**kwargs,
|
| 5132 |
+
):
|
| 5133 |
+
# hidden_states : (B, L, D'). D' = n/m D (expanded width)
|
| 5134 |
+
|
| 5135 |
+
# MIXER.
|
| 5136 |
+
mix_h, beta = self.mixer_ghyper_connection.width_connection(hidden_states)
|
| 5137 |
+
mix_h_shape = mix_h.shape
|
| 5138 |
+
h = mix_h[..., :self.config.vwn_m, :].reshape(mix_h_shape[:-2] + (self.config.vwn_m * mix_h_shape[-1],))
|
| 5139 |
+
# h is (B, L, D)
|
| 5140 |
+
h = self.lns * self.input_norm(h)
|
| 5141 |
+
y_mixer, last_key_states, last_value_states = self.mixer(
|
| 5142 |
+
hidden_states=h,
|
| 5143 |
+
position_embeddings=position_embeddings,
|
| 5144 |
+
position_ids=position_ids,
|
| 5145 |
+
cache_params=cache_params,
|
| 5146 |
+
key_value_last_layer=key_value_last_layer,
|
| 5147 |
+
cu_seqlens=cu_seqlens,
|
| 5148 |
+
max_seqlen=max_seqlen,
|
| 5149 |
+
) # (B, L, E*D)
|
| 5150 |
+
if self.config.mixer_gn and not self.config.gate_before_norm:
|
| 5151 |
+
y_mixer = self.mixer_group_norm(y_mixer)
|
| 5152 |
+
if self.use_gate:
|
| 5153 |
+
if self.config.gate_type == "elementwise" or self.config.gate_type == "kimi":
|
| 5154 |
+
g_proj = self.gate_proj(h).view(h.size(0), h.size(1), self.num_attention_heads, self.head_dim).to(y_mixer.dtype)
|
| 5155 |
+
elif self.config.gate_type == "headwise":
|
| 5156 |
+
g_proj = self.gate_proj(h).unsqueeze(-1).to(y_mixer.dtype)
|
| 5157 |
+
else:
|
| 5158 |
+
raise ValueError(f"Unknown gate_type: {self.config.gate_type}")
|
| 5159 |
+
if self.config.zero_centered_gate_type == 1:
|
| 5160 |
+
y_mixer = y_mixer * self.gate_act(g_proj)
|
| 5161 |
+
y_mixer = y_mixer + self.gate_bias
|
| 5162 |
+
elif self.config.zero_centered_gate_type == 2:
|
| 5163 |
+
y_mixer = y_mixer * (self.gate_act(g_proj) + self.gate_bias)
|
| 5164 |
+
elif self.config.zero_centered_gate_type == 3 or self.config.zero_centered_gate_type == 4:
|
| 5165 |
+
y_mixer = y_mixer * self.gate_act(g_proj + self.gate_bias)
|
| 5166 |
+
if self.config.mixer_gn and self.config.gate_before_norm:
|
| 5167 |
+
y_mixer = self.mixer_group_norm(y_mixer)
|
| 5168 |
+
y_mixer = y_mixer.view(y_mixer.size(0), y_mixer.size(1), -1)
|
| 5169 |
+
y_mixer = self.mixer_proj(y_mixer) # (B, L, D)
|
| 5170 |
+
h = self.mixer_ghyper_connection.depth_connection(mix_h, y_mixer, beta, self.sqrt_one_minus_tau, self.sqrt_tau) # (B, L, D')
|
| 5171 |
+
|
| 5172 |
+
# MLP.
|
| 5173 |
+
mix_h, beta = self.mlp_ghyper_connection.width_connection(h)
|
| 5174 |
+
mix_h_shape = mix_h.shape
|
| 5175 |
+
h = mix_h[..., :self.config.vwn_m, :].reshape(mix_h_shape[:-2] + (self.config.vwn_m * mix_h_shape[-1],))
|
| 5176 |
+
# h is (B, L, D)
|
| 5177 |
+
h = self.lns * self.postmixer_norm(h)
|
| 5178 |
+
y_mlp = self.mlp(h) # (B, L, D)
|
| 5179 |
+
h = self.mlp_ghyper_connection.depth_connection(mix_h, y_mlp, beta, self.sqrt_one_minus_tau, self.sqrt_tau) # (B, L, D')
|
| 5180 |
+
|
| 5181 |
+
return h, 0, 0
|
| 5182 |
+
|
| 5183 |
class DragonBlock(GradientCheckpointingLayer):
|
| 5184 |
def __init__(self, config: DragonConfig, layer_idx: int, layer_type: str):
|
| 5185 |
super().__init__()
|
|
|
|
| 5265 |
"attentions": DragonBlock,
|
| 5266 |
}
|
| 5267 |
|
| 5268 |
+
def _init_weights(self, module):
|
| 5269 |
if isinstance(module, (DragonLinear, nn.Conv1d)):
|
| 5270 |
if module.bias is not None:
|
| 5271 |
nn.init.zeros_(module.bias)
|
| 5272 |
+
nn.init.normal_(module.weight, mean=0., std=self.config.initializer_range)
|
| 5273 |
elif isinstance(module, nn.Embedding):
|
| 5274 |
+
nn.init.normal_(module.weight, mean=0., std=self.config.initializer_range)
|
| 5275 |
|
| 5276 |
@dataclass
|
| 5277 |
class DragonOutput(ModelOutput):
|
|
|
|
| 5326 |
self.vocab_size = config.vocab_size
|
| 5327 |
|
| 5328 |
self.embedding = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 5329 |
+
if self.config.vwn:
|
| 5330 |
+
self.hidden_size_expanded = int(config.vwn_n/config.vwn_m * config.hidden_size)
|
| 5331 |
+
self.expand_embedding = DragonLinear(config, config.hidden_size, self.hidden_size_expanded, bias=False)
|
| 5332 |
+
|
| 5333 |
+
if not self.config.vwn:
|
| 5334 |
+
self.layers = nn.ModuleList([DragonBlock(config, layer_idx=i, layer_type=layer) if layer in ['l', 'r', 'd'] else DragonMonoBlock(config, layer_idx=i, layer_type=layer) for i, layer in enumerate(config.layers_config)])
|
| 5335 |
+
else:
|
| 5336 |
+
self.layers = nn.ModuleList([DragonBlock(config, layer_idx=i, layer_type=layer) if layer in ['l', 'r', 'd'] else DragonMonoVirtualBlock(config, layer_idx=i, layer_type=layer) for i, layer in enumerate(config.layers_config)])
|
| 5337 |
|
| 5338 |
if self.config.rope_type_global != '' or self.config.rope_type_local != '':
|
| 5339 |
self.rotary_emb = DragonRotaryEmbedding(config, head_dim=config.head_dim if config.head_dim else (config.expand_factor*config.hidden_size)//config.num_attention_heads, theta=config.rope_theta_local) # only for SWA
|
| 5340 |
else:
|
| 5341 |
self.rotary_emb = None
|
| 5342 |
|
| 5343 |
+
if self.config.vwn:
|
| 5344 |
+
if int(self.config.vwn_n/self.config.vwn_m) == 8:
|
| 5345 |
+
self.gn = torch.nn.GroupNorm(num_groups=self.hidden_size_expanded//config.hidden_size, num_channels=self.hidden_size_expanded, eps=config.norm_epsilon, affine=False) # todo : zcg ?
|
| 5346 |
+
self.reduce_h = DragonLinear(config, self.hidden_size_expanded, config.hidden_size, bias=False)
|
| 5347 |
+
|
| 5348 |
if self.config.final_norm:
|
| 5349 |
self.final_norm = DragonNorm(config, config.hidden_size)
|
| 5350 |
|
| 5351 |
self.gradient_checkpointing = False
|
| 5352 |
self.post_init()
|
| 5353 |
+
|
| 5354 |
def get_input_embeddings(self):
|
| 5355 |
return self.embedding
|
| 5356 |
|
|
|
|
| 5379 |
|
| 5380 |
if inputs_embeds is None:
|
| 5381 |
inputs_embeds = self.embedding(input_ids)
|
| 5382 |
+
if self.config.vwn:
|
| 5383 |
+
inputs_embeds = self.expand_embedding(inputs_embeds) # (B, L, D')
|
| 5384 |
|
| 5385 |
if self.config.patch_level_training:
|
| 5386 |
# (B, KL, D) => (B, L, D) OR (B, L, D) ==> (B, L//K, D)
|
|
|
|
| 5437 |
)
|
| 5438 |
shared_kv = (last_k, last_v)
|
| 5439 |
|
| 5440 |
+
if self.config.vwn:
|
| 5441 |
+
if int(self.config.vwn_n/self.config.vwn_m) == 8:
|
| 5442 |
+
B, L, D = hidden_states.shape
|
| 5443 |
+
hidden_states = self.gn(hidden_states.reshape(-1, D)).view(B, L, D)
|
| 5444 |
+
hidden_states = self.reduce_h(hidden_states) # back to (B, L, D)
|
| 5445 |
+
|
| 5446 |
if self.config.final_norm:
|
| 5447 |
hidden_states = self.final_norm(hidden_states)
|
| 5448 |
|
| 5449 |
if output_hidden_states:
|
| 5450 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 5451 |
|
| 5452 |
+
if past_key_values and not past_key_values.has_previous_state:
|
| 5453 |
+
past_key_values.has_previous_state = True
|
| 5454 |
+
|
| 5455 |
return DragonOutput(
|
| 5456 |
last_hidden_state=hidden_states,
|
| 5457 |
past_key_values=past_key_values if use_cache else None,
|
|
|
|
| 5465 |
self.config = config
|
| 5466 |
self.model = DragonModel(config)
|
| 5467 |
self.vocab_size = config.vocab_size
|
| 5468 |
+
bwd = 1/math.sqrt(config.hidden_size) if config.dataset_type == "hf" else 1/config.hidden_size
|
| 5469 |
+
if config.reduce_lm_head == 0:
|
| 5470 |
+
self.lm_head = DragonLinear(config, config.hidden_size, config.vocab_size, bias=False, alpha_fwd=1/config.hidden_size, alpha_bwd=bwd)
|
| 5471 |
+
else:
|
| 5472 |
+
self.lm_head = nn.Sequential(
|
| 5473 |
+
DragonLinear(config, config.hidden_size, config.reduce_lm_head, bias=False, alpha_fwd=1./math.sqrt(config.reduce_lm_head)),
|
| 5474 |
+
DragonLinear(config, config.reduce_lm_head, config.vocab_size, bias=False, alpha_fwd=1/config.hidden_size, alpha_bwd=bwd),
|
| 5475 |
+
)
|
| 5476 |
self.post_init()
|
| 5477 |
if config.tie_lm_head:
|
| 5478 |
self.lm_head.weight = self.model.embedding.weight
|
| 5479 |
|
| 5480 |
+
if config.init_gpt2:
|
| 5481 |
+
for pn, p in self.named_parameters():
|
| 5482 |
+
if pn.endswith('fc2.weight') or pn.endswith('mixer_proj.weight'):
|
| 5483 |
+
torch.nn.init.normal_(p, mean=0.0, std=config.initializer_range/math.sqrt(2 * len(config.layers_config)))
|
| 5484 |
+
|
| 5485 |
def forward(
|
| 5486 |
self,
|
| 5487 |
input_ids: Optional[torch.LongTensor] = None,
|
|
|
|
| 5527 |
labels = labels.to(hidden_states.device)
|
| 5528 |
|
| 5529 |
if linear_cross_entropy is None or not self.config.fused_loss_computation:
|
| 5530 |
+
if not self.config.reduce_lm_head:
|
| 5531 |
+
logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)[:, slice_indices, :]).float()
|
| 5532 |
+
else:
|
| 5533 |
+
logits = self.lm_head(hidden_states.to(self.lm_head[0].weight.dtype)[:, slice_indices, :]).float()
|
| 5534 |
if not self.config.patch_level_training:
|
| 5535 |
shift_logits = logits[..., :-1, :].contiguous()
|
| 5536 |
shift_labels = labels[..., 1:].contiguous()
|
|
|
|
| 5544 |
loss = loss + F.nll_loss(log_probs, shift_labels[:, i])
|
| 5545 |
loss = loss / self.config.patch_level_training_size
|
| 5546 |
else:
|
| 5547 |
+
assert not self.config.reduce_lm_head
|
| 5548 |
assert not self.config.patch_level_training, "Fused loss computation is not supported with patch-level training."
|
| 5549 |
loss = linear_cross_entropy(
|
| 5550 |
hidden_states[:, slice_indices, :].view(-1, hidden_states.size(-1)),
|
optimizers/Ademamix.py
CHANGED
|
@@ -46,7 +46,7 @@ class AdEMAMix(Optimizer):
|
|
| 46 |
"""
|
| 47 |
|
| 48 |
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999, 0.999), alpha=8.0,
|
| 49 |
-
beta3_warmup=None, alpha_warmup=None,
|
| 50 |
weight_decay=0):
|
| 51 |
if not 0.0 <= lr:
|
| 52 |
raise ValueError("Invalid learning rate: {}".format(lr))
|
|
@@ -62,6 +62,7 @@ class AdEMAMix(Optimizer):
|
|
| 62 |
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
|
| 63 |
if not 0.0 <= alpha:
|
| 64 |
raise ValueError("Invalid alpha value: {}".format(alpha))
|
|
|
|
| 65 |
defaults = dict(lr=lr, betas=betas, eps=eps, alpha=alpha, beta3_warmup=beta3_warmup,
|
| 66 |
alpha_warmup=alpha_warmup, weight_decay=weight_decay)
|
| 67 |
super(AdEMAMix, self).__init__(params, defaults)
|
|
@@ -139,6 +140,8 @@ class AdEMAMix(Optimizer):
|
|
| 139 |
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
| 140 |
|
| 141 |
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
|
|
|
|
|
|
|
| 142 |
|
| 143 |
update = (exp_avg_fast.div(bias_correction1) + alpha * exp_avg_slow) / denom
|
| 144 |
|
|
|
|
| 46 |
"""
|
| 47 |
|
| 48 |
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999, 0.999), alpha=8.0,
|
| 49 |
+
beta3_warmup=None, alpha_warmup=None, eps=1e-8, normalize_alpha=False,
|
| 50 |
weight_decay=0):
|
| 51 |
if not 0.0 <= lr:
|
| 52 |
raise ValueError("Invalid learning rate: {}".format(lr))
|
|
|
|
| 62 |
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
|
| 63 |
if not 0.0 <= alpha:
|
| 64 |
raise ValueError("Invalid alpha value: {}".format(alpha))
|
| 65 |
+
self.normalize_alpha = normalize_alpha
|
| 66 |
defaults = dict(lr=lr, betas=betas, eps=eps, alpha=alpha, beta3_warmup=beta3_warmup,
|
| 67 |
alpha_warmup=alpha_warmup, weight_decay=weight_decay)
|
| 68 |
super(AdEMAMix, self).__init__(params, defaults)
|
|
|
|
| 140 |
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
| 141 |
|
| 142 |
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
|
| 143 |
+
if self.normalize_alpha:
|
| 144 |
+
denom = denom * (1.0 + alpha)
|
| 145 |
|
| 146 |
update = (exp_avg_fast.div(bias_correction1) + alpha * exp_avg_slow) / denom
|
| 147 |
|
training_dragon.py
CHANGED
|
@@ -41,7 +41,8 @@ class NanoArgs:
|
|
| 41 |
rope_theta_local: float = 10000.0
|
| 42 |
rope_theta_global: float = 0.0
|
| 43 |
eps_rmsnorm: float = 1e-6
|
| 44 |
-
mlp_expand:
|
|
|
|
| 45 |
fused_loss_computation : bool = True # whether to use fused linear + cross entropy loss
|
| 46 |
use_uscaling: bool = False
|
| 47 |
uscaling_tau: float = 0.2
|
|
@@ -58,11 +59,19 @@ class NanoArgs:
|
|
| 58 |
seednorm_type: int = 1
|
| 59 |
seednorm_rank: int = 1
|
| 60 |
mixer_gn: bool = True
|
|
|
|
| 61 |
mlp_linking : bool = False
|
| 62 |
final_norm: bool = True
|
| 63 |
layer_norm_scaling: bool = False # not read when using muP
|
| 64 |
mlp_type: str = "simple" # simple, gated
|
| 65 |
tie_lm_head: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
# MoE
|
| 68 |
moe: bool = False
|
|
@@ -117,6 +126,7 @@ class NanoArgs:
|
|
| 117 |
mamba3_remove_conv: bool = True
|
| 118 |
mamba3_is_A_dd: bool = True
|
| 119 |
mamba3_add_trapezoid: bool = True
|
|
|
|
| 120 |
|
| 121 |
# optim
|
| 122 |
optim: str = "adamw" # adamw, spam, stable-spam, muon, muon_moonlight, splus
|
|
@@ -129,6 +139,8 @@ class NanoArgs:
|
|
| 129 |
adam_beta1: float = 0.9
|
| 130 |
adam_beta2: float = 0.95
|
| 131 |
adam_eps: float = 1e-8
|
|
|
|
|
|
|
| 132 |
warmup_iters: int = 200
|
| 133 |
warmdown_iters: int = 3000
|
| 134 |
warmdown_type: str = "linear" # linear, cosine
|
|
@@ -142,6 +154,8 @@ class NanoArgs:
|
|
| 142 |
second_order_lr: float = 0.68
|
| 143 |
second_order_momentum: float = 0.37
|
| 144 |
second_order_interval: int = 25
|
|
|
|
|
|
|
| 145 |
|
| 146 |
# data
|
| 147 |
vocab_size: int = 50304
|
|
@@ -150,6 +164,7 @@ class NanoArgs:
|
|
| 150 |
intra_doc_masking: bool = False
|
| 151 |
input_bin: Optional[str] = None
|
| 152 |
input_val_bin: Optional[str] = None
|
|
|
|
| 153 |
|
| 154 |
# evaluation and logging
|
| 155 |
val_loss_every: int = 125
|
|
@@ -170,7 +185,34 @@ class NanoArgs:
|
|
| 170 |
# used during training
|
| 171 |
slw_window: int = 0
|
| 172 |
|
| 173 |
-
def _peek_data_shard(filename):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
with open(filename, "rb") as f:
|
| 175 |
header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
|
| 176 |
if header[0] != 20240520:
|
|
@@ -182,25 +224,22 @@ def _peek_data_shard(filename):
|
|
| 182 |
ntok = int(header[2])
|
| 183 |
return ntok
|
| 184 |
|
| 185 |
-
def
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
# memmap the token payload directly (uint16) after the 256*4B header
|
| 192 |
-
tokens = np.memmap(filename, dtype=np.uint16, mode="r", offset=256 * 4, shape=(ntok,))
|
| 193 |
-
assert tokens.size == ntok, "number of tokens read does not match header?"
|
| 194 |
-
return tokens
|
| 195 |
|
| 196 |
class DistributedDataLoader:
|
| 197 |
-
def __init__(self, filename_pattern, intra_doc_masking,B, T, process_rank, num_processes, bos_id):
|
| 198 |
self.process_rank = process_rank
|
| 199 |
self.num_processes = num_processes
|
| 200 |
self.intra_doc_masking = intra_doc_masking
|
| 201 |
self.bos_id = bos_id
|
| 202 |
self.B = B # micro batch size
|
| 203 |
self.T = T
|
|
|
|
| 204 |
|
| 205 |
# glob files that match the pattern
|
| 206 |
self.files = sorted(glob.glob(filename_pattern))
|
|
@@ -210,7 +249,7 @@ class DistributedDataLoader:
|
|
| 210 |
ntok_total = 0
|
| 211 |
self.shard_ntoks = []
|
| 212 |
for fname in self.files:
|
| 213 |
-
shard_ntok = _peek_data_shard(fname)
|
| 214 |
#print(f"shard {fname} has {shard_ntok} tokens")
|
| 215 |
assert shard_ntok >= num_processes * B * T + 1
|
| 216 |
self.shard_ntoks.append(shard_ntok)
|
|
@@ -223,12 +262,12 @@ class DistributedDataLoader:
|
|
| 223 |
def reset(self, shard=0):
|
| 224 |
self.current_shard = shard
|
| 225 |
self.current_position = self.process_rank * self.B * self.T
|
| 226 |
-
self.tokens = _load_data_shard(self.files[self.current_shard])
|
| 227 |
|
| 228 |
def advance(self): # advance to next data shard
|
| 229 |
self.current_shard = (self.current_shard + 1) % len(self.files)
|
| 230 |
self.current_position = self.process_rank * self.B * self.T
|
| 231 |
-
self.tokens = _load_data_shard(self.files[self.current_shard])
|
| 232 |
|
| 233 |
if self.process_rank == 0:
|
| 234 |
shard_tokens = self.shard_ntoks[self.current_shard]
|
|
@@ -282,30 +321,38 @@ def param_groups_mup(model, base_lr_hidden, base_lr_scalar, base_lr_embed, base_
|
|
| 282 |
groups, seen = [], set()
|
| 283 |
id2name = {id(p): n for n, p in model.named_parameters()}
|
| 284 |
|
| 285 |
-
for mod in model.
|
| 286 |
if isinstance(mod, nn.Linear):
|
| 287 |
pname = id2name.get(id(mod.weight), "")
|
| 288 |
is_scalar = getattr(mod, "is_scalar_weight", False)
|
| 289 |
fan_in = mod.weight.shape[1]
|
| 290 |
-
scale = 1 / math.sqrt(fan_in)
|
| 291 |
if "lm_head" in pname:
|
|
|
|
| 292 |
lr_scaled = base_lr_head
|
| 293 |
wd_scaled = 0.0
|
|
|
|
| 294 |
elif is_scalar:
|
|
|
|
| 295 |
lr_scaled = base_lr_scalar
|
| 296 |
wd_scaled = 0.0
|
|
|
|
| 297 |
else:
|
|
|
|
| 298 |
lr_scaled = base_lr_hidden * scale
|
| 299 |
wd_scaled = wd / lr_scaled
|
|
|
|
| 300 |
|
| 301 |
groups.append({"params": [mod.weight], "lr": lr_scaled, "weight_decay": wd_scaled})
|
| 302 |
seen.add(mod.weight)
|
| 303 |
|
|
|
|
|
|
|
| 304 |
if mod.bias is not None:
|
|
|
|
| 305 |
groups.append({"params": [mod.bias], "lr": base_lr_scalar, "weight_decay": 0.0})
|
| 306 |
seen.add(mod.bias)
|
| 307 |
|
| 308 |
-
for p in model.
|
| 309 |
if p in seen:
|
| 310 |
continue
|
| 311 |
pname = id2name.get(id(p), "<unnamed>")
|
|
@@ -318,11 +365,15 @@ def param_groups_mup(model, base_lr_hidden, base_lr_scalar, base_lr_embed, base_
|
|
| 318 |
lr_scaled = base_lr_scalar
|
| 319 |
|
| 320 |
wd_scaled = 0.
|
|
|
|
| 321 |
if getattr(p, "requires_weight_decay", False):
|
| 322 |
wd_scaled = wd / lr_scaled
|
|
|
|
| 323 |
|
| 324 |
groups.append({"params": [p], "lr": lr_scaled, "weight_decay": wd_scaled})
|
| 325 |
|
|
|
|
|
|
|
| 326 |
return groups
|
| 327 |
|
| 328 |
args = tyro.cli(NanoArgs)
|
|
@@ -341,6 +392,9 @@ if args.mlp_type == "gated":
|
|
| 341 |
print("problem: gated MLP with MoE is not supported, because we use FA backend")
|
| 342 |
exit(0)
|
| 343 |
|
|
|
|
|
|
|
|
|
|
| 344 |
# set up DDP (distributed data parallel).
|
| 345 |
assert torch.cuda.is_available()
|
| 346 |
dist.init_process_group(
|
|
@@ -434,13 +488,22 @@ tokenizer = transformers.AutoTokenizer.from_pretrained("/leonardo_work/BOOST_LCu
|
|
| 434 |
# load dataloaders.
|
| 435 |
#if args.patch_level_training:
|
| 436 |
# assert T % args.patch_level_training_size == 0, "sequence length must be divisible by patch level training size in reduced mode"
|
| 437 |
-
train_loader = DistributedDataLoader(args.input_bin, args.intra_doc_masking, B, T, ddp_rank, ddp_world_size, args.bos_id)
|
| 438 |
-
val_loader = DistributedDataLoader(args.input_val_bin, args.intra_doc_masking, B, T, ddp_rank, ddp_world_size, args.bos_id)
|
| 439 |
print0(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
|
| 440 |
print0(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
|
| 441 |
|
| 442 |
# load model.
|
| 443 |
config_hf = DragonConfig(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 444 |
tie_lm_head=args.tie_lm_head,
|
| 445 |
mlp_type=args.mlp_type,
|
| 446 |
layer_norm_scaling=args.layer_norm_scaling,
|
|
@@ -452,6 +515,7 @@ config_hf = DragonConfig(
|
|
| 452 |
mamba3_remove_conv=args.mamba3_remove_conv,
|
| 453 |
mamba3_is_A_dd=args.mamba3_is_A_dd,
|
| 454 |
mamba3_add_trapezoid=args.mamba3_add_trapezoid,
|
|
|
|
| 455 |
moe=args.moe,
|
| 456 |
moe_num_routed_experts=args.moe_num_routed_experts,
|
| 457 |
moe_routed_scaling_factor=args.moe_routed_scaling_factor,
|
|
@@ -466,6 +530,7 @@ config_hf = DragonConfig(
|
|
| 466 |
shrink_qk_da=args.shrink_qk_da,
|
| 467 |
shrink_qk_gdn=args.shrink_qk_gdn,
|
| 468 |
mixer_gn=args.mixer_gn,
|
|
|
|
| 469 |
kda_allow_neg_eigval=args.kda_allow_neg_eigval,
|
| 470 |
kda_num_v_heads=args.kda_num_v_heads,
|
| 471 |
seednorm_wd=args.seednorm_wd,
|
|
@@ -508,7 +573,7 @@ config_hf = DragonConfig(
|
|
| 508 |
max_position_embeddings=args.sequence_length,
|
| 509 |
use_uscaling=args.use_uscaling,
|
| 510 |
hidden_size=args.d_model,
|
| 511 |
-
intermediate_size=args.d_model * args.mlp_expand,
|
| 512 |
expand_factor=args.expand_factor,
|
| 513 |
layers_config=args.layers_config,
|
| 514 |
num_attention_heads=args.n_heads,
|
|
@@ -535,18 +600,14 @@ else:
|
|
| 535 |
model = model.cuda()
|
| 536 |
print0(model)
|
| 537 |
|
| 538 |
-
"""# check here that the init std is as expected: # TODO TEMPORARY
|
| 539 |
with torch.no_grad():
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
lstd = model.model.layers[0].lin_attn.qkv_conv1d.weight.std().item()
|
| 549 |
-
print0(f"Model first layer conv QKV weight init std: {lstd:.6f} (expected {args.init_std})")"""
|
| 550 |
|
| 551 |
# count params. (total & active)
|
| 552 |
num_params = sum(p.numel() for p in model.parameters())
|
|
@@ -570,7 +631,7 @@ ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)
|
|
| 570 |
|
| 571 |
if args.intra_doc_masking:
|
| 572 |
print0("!!! Using intra-document masking !!!")
|
| 573 |
-
print0("It is only compatible with GDN (conv+chunk), DA and GDTPA layers. For DA/GDTPA, kv shift is also compatible. All other config will not have intra-doc masking support!!")
|
| 574 |
|
| 575 |
# load optimizers & schedulers.
|
| 576 |
if args.use_uscaling:
|
|
@@ -587,18 +648,38 @@ if args.use_uscaling:
|
|
| 587 |
optimizer = torch.optim.AdamW(param_list, betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps)
|
| 588 |
elif args.optim == "ademamix":
|
| 589 |
from .optimizers.Ademamix import AdEMAMix
|
| 590 |
-
beta3_warmup =
|
| 591 |
-
|
|
|
|
| 592 |
else:
|
| 593 |
raise ValueError(f"Unknown optimizer for unit scaling: {args.optim}")
|
| 594 |
else:
|
| 595 |
if args.optim == "adamw":
|
| 596 |
-
optimizer = torch.optim.AdamW(raw_model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 597 |
elif args.optim == "ademamix":
|
| 598 |
from .optimizers.Ademamix import AdEMAMix
|
| 599 |
|
| 600 |
-
beta3_warmup =
|
| 601 |
-
|
|
|
|
| 602 |
else:
|
| 603 |
raise ValueError(f"Unknown Optimizer: {args.optim}")
|
| 604 |
if args.second_order_optim == "snoo":
|
|
@@ -624,7 +705,7 @@ def get_lr_wsd(num_iterations, warmup_iters, warmdown_iters, it):
|
|
| 624 |
if args.warmdown_type == "linear":
|
| 625 |
sched_func = partial(get_lr_wsd, args.total_iterations, args.warmup_iters, args.warmdown_iters)
|
| 626 |
schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, sched_func) for opt in optimizers]
|
| 627 |
-
elif args.warmdown_type == "cosine":
|
| 628 |
sched = get_wsd_schedule(
|
| 629 |
optimizers[0],
|
| 630 |
num_warmup_steps=args.warmup_iters,
|
|
@@ -632,7 +713,7 @@ elif args.warmdown_type == "cosine":
|
|
| 632 |
num_training_steps=args.total_iterations,
|
| 633 |
min_lr_ratio=0.,
|
| 634 |
warmup_type='linear',
|
| 635 |
-
decay_type=
|
| 636 |
)
|
| 637 |
schedulers = [sched]
|
| 638 |
else:
|
|
@@ -721,8 +802,11 @@ for iter_ in range(start_iter, start_iter+args.total_iterations+1):
|
|
| 721 |
# save model & tokenizer to make evaluation easier.
|
| 722 |
tokenizer.save_pretrained(save_dir)
|
| 723 |
state_dict_bf16 = {k: v.detach().to(torch.bfloat16).cpu() for k, v in uncompiled_model.state_dict().items()}
|
|
|
|
|
|
|
| 724 |
uncompiled_model.config.torch_dtype = torch.bfloat16
|
| 725 |
uncompiled_model.save_pretrained(save_dir, safe_serialization=True, state_dict=state_dict_bf16)
|
|
|
|
| 726 |
# save training state.
|
| 727 |
train_state = dict(
|
| 728 |
iteration=iter_,
|
|
@@ -757,6 +841,18 @@ for iter_ in range(start_iter, start_iter+args.total_iterations+1):
|
|
| 757 |
(loss / accumulation_steps).backward()
|
| 758 |
else:
|
| 759 |
(loss / accumulation_steps).backward() # just sync on the last step
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 760 |
# clip those gradients.
|
| 761 |
if args.grad_norm_clip is not None:
|
| 762 |
grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.grad_norm_clip, foreach=True)
|
|
@@ -771,13 +867,26 @@ for iter_ in range(start_iter, start_iter+args.total_iterations+1):
|
|
| 771 |
# null those gradients.
|
| 772 |
model.zero_grad(set_to_none=True)
|
| 773 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 774 |
# ----------- LOGGING SECTION -----------
|
| 775 |
approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
|
| 776 |
avg_step_time = approx_training_time_ms / (iter_ + 1 - WARMUP_SKIP) if iter_ >= start_iter+WARMUP_SKIP else 0
|
| 777 |
extra = " ".join(f"{k}:{v}" for k, v in (to_log or {}).items())
|
| 778 |
print0(f"iteration:{iter_+1:0{len(str(start_iter+args.total_iterations))}d}/{args.total_iterations} train_loss:{train_loss.item():.4f} lr: {schedulers[0].get_last_lr()[0]:.4f} train_time:{approx_training_time_ms:.0f}ms step_avg:{avg_step_time:.2f}ms {extra}")
|
| 779 |
if master_process:
|
| 780 |
-
wandb.log({'train_loss': train_loss.item(), 'step_avg_time': avg_step_time, **{f'lr_{i}': sched.get_last_lr()[0] for i, sched in enumerate(schedulers)}, 'grad_norm': grad_norm.item(), **to_log}, step=iter_)
|
| 781 |
|
| 782 |
print0(f"peak memory consumption during training: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")
|
| 783 |
print0("Training complete.")
|
|
|
|
| 41 |
rope_theta_local: float = 10000.0
|
| 42 |
rope_theta_global: float = 0.0
|
| 43 |
eps_rmsnorm: float = 1e-6
|
| 44 |
+
mlp_expand: float = 4. # expand factor for MLP
|
| 45 |
+
intermediate_size: Optional[int] = None
|
| 46 |
fused_loss_computation : bool = True # whether to use fused linear + cross entropy loss
|
| 47 |
use_uscaling: bool = False
|
| 48 |
uscaling_tau: float = 0.2
|
|
|
|
| 59 |
seednorm_type: int = 1
|
| 60 |
seednorm_rank: int = 1
|
| 61 |
mixer_gn: bool = True
|
| 62 |
+
gate_before_norm: bool = True
|
| 63 |
mlp_linking : bool = False
|
| 64 |
final_norm: bool = True
|
| 65 |
layer_norm_scaling: bool = False # not read when using muP
|
| 66 |
mlp_type: str = "simple" # simple, gated
|
| 67 |
tie_lm_head: bool = False
|
| 68 |
+
legacy_gate: bool = False
|
| 69 |
+
vwn: bool = False
|
| 70 |
+
vwn_m: int = 2
|
| 71 |
+
vwn_n: int = 3
|
| 72 |
+
vwn_wd_alpha_beta: bool = False
|
| 73 |
+
vwn_dynamic: bool = True
|
| 74 |
+
reduce_lm_head: int = 0
|
| 75 |
|
| 76 |
# MoE
|
| 77 |
moe: bool = False
|
|
|
|
| 126 |
mamba3_remove_conv: bool = True
|
| 127 |
mamba3_is_A_dd: bool = True
|
| 128 |
mamba3_add_trapezoid: bool = True
|
| 129 |
+
mamba3_postgate_norm: bool = False # only works if legacy_gate is True!!
|
| 130 |
|
| 131 |
# optim
|
| 132 |
optim: str = "adamw" # adamw, spam, stable-spam, muon, muon_moonlight, splus
|
|
|
|
| 139 |
adam_beta1: float = 0.9
|
| 140 |
adam_beta2: float = 0.95
|
| 141 |
adam_eps: float = 1e-8
|
| 142 |
+
alpha_normalize: bool = False # whether to normalize update by (1+alpha) in AdEMAMix
|
| 143 |
+
alpha_ademamix: float = 8.0
|
| 144 |
warmup_iters: int = 200
|
| 145 |
warmdown_iters: int = 3000
|
| 146 |
warmdown_type: str = "linear" # linear, cosine
|
|
|
|
| 154 |
second_order_lr: float = 0.68
|
| 155 |
second_order_momentum: float = 0.37
|
| 156 |
second_order_interval: int = 25
|
| 157 |
+
init_gpt2: bool = False
|
| 158 |
+
wnorm: bool = False # as in nemotron-flash (2511.18890)
|
| 159 |
|
| 160 |
# data
|
| 161 |
vocab_size: int = 50304
|
|
|
|
| 164 |
intra_doc_masking: bool = False
|
| 165 |
input_bin: Optional[str] = None
|
| 166 |
input_val_bin: Optional[str] = None
|
| 167 |
+
dataset_type: str = "hf" # hf, mg
|
| 168 |
|
| 169 |
# evaluation and logging
|
| 170 |
val_loss_every: int = 125
|
|
|
|
| 185 |
# used during training
|
| 186 |
slw_window: int = 0
|
| 187 |
|
| 188 |
+
def _peek_data_shard(filename, dataset_type='hf'):
|
| 189 |
+
if dataset_type == 'hf':
|
| 190 |
+
return _peek_hf_shard(filename)
|
| 191 |
+
elif dataset_type == 'mg':
|
| 192 |
+
return _peek_mg_shard(filename)
|
| 193 |
+
else:
|
| 194 |
+
raise ValueError(f"unknown dataset type: {dataset_type}")
|
| 195 |
+
|
| 196 |
+
def _load_data_shard(filename, dataset_type='hf'):
|
| 197 |
+
if dataset_type == 'hf':
|
| 198 |
+
return _load_hf_shard(filename)
|
| 199 |
+
elif dataset_type == 'mg':
|
| 200 |
+
return _load_mg_shard(filename)
|
| 201 |
+
else:
|
| 202 |
+
raise ValueError(f"unknown dataset type: {dataset_type}")
|
| 203 |
+
|
| 204 |
+
def _load_hf_shard(filename):
|
| 205 |
+
with open(filename, "rb") as f:
|
| 206 |
+
header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
|
| 207 |
+
assert header[0] == 20240520, "magic number mismatch in the data .bin file"
|
| 208 |
+
assert header[1] == 1, "unsupported version"
|
| 209 |
+
ntok = int(header[2])
|
| 210 |
+
# memmap the token payload directly (uint16) after the 256*4B header
|
| 211 |
+
tokens = np.memmap(filename, dtype=np.uint16, mode="r", offset=256 * 4, shape=(ntok,))
|
| 212 |
+
assert tokens.size == ntok, "number of tokens read does not match header?"
|
| 213 |
+
return tokens
|
| 214 |
+
|
| 215 |
+
def _peek_hf_shard(filename):
|
| 216 |
with open(filename, "rb") as f:
|
| 217 |
header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
|
| 218 |
if header[0] != 20240520:
|
|
|
|
| 224 |
ntok = int(header[2])
|
| 225 |
return ntok
|
| 226 |
|
| 227 |
+
def _peek_mg_shard(filename):
|
| 228 |
+
tokens = np.memmap(filename, dtype=np.uint16, mode="r")
|
| 229 |
+
return int(tokens.size)
|
| 230 |
+
|
| 231 |
+
def _load_mg_shard(filename):
|
| 232 |
+
return np.memmap(filename, dtype=np.uint16, mode="r")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
|
| 234 |
class DistributedDataLoader:
|
| 235 |
+
def __init__(self, filename_pattern, intra_doc_masking,B, T, process_rank, num_processes, bos_id, dataset_type='hf'):
|
| 236 |
self.process_rank = process_rank
|
| 237 |
self.num_processes = num_processes
|
| 238 |
self.intra_doc_masking = intra_doc_masking
|
| 239 |
self.bos_id = bos_id
|
| 240 |
self.B = B # micro batch size
|
| 241 |
self.T = T
|
| 242 |
+
self.dataset_type = dataset_type
|
| 243 |
|
| 244 |
# glob files that match the pattern
|
| 245 |
self.files = sorted(glob.glob(filename_pattern))
|
|
|
|
| 249 |
ntok_total = 0
|
| 250 |
self.shard_ntoks = []
|
| 251 |
for fname in self.files:
|
| 252 |
+
shard_ntok = _peek_data_shard(fname, dataset_type=self.dataset_type)
|
| 253 |
#print(f"shard {fname} has {shard_ntok} tokens")
|
| 254 |
assert shard_ntok >= num_processes * B * T + 1
|
| 255 |
self.shard_ntoks.append(shard_ntok)
|
|
|
|
| 262 |
def reset(self, shard=0):
|
| 263 |
self.current_shard = shard
|
| 264 |
self.current_position = self.process_rank * self.B * self.T
|
| 265 |
+
self.tokens = _load_data_shard(self.files[self.current_shard], dataset_type=self.dataset_type)
|
| 266 |
|
| 267 |
def advance(self): # advance to next data shard
|
| 268 |
self.current_shard = (self.current_shard + 1) % len(self.files)
|
| 269 |
self.current_position = self.process_rank * self.B * self.T
|
| 270 |
+
self.tokens = _load_data_shard(self.files[self.current_shard], dataset_type=self.dataset_type)
|
| 271 |
|
| 272 |
if self.process_rank == 0:
|
| 273 |
shard_tokens = self.shard_ntoks[self.current_shard]
|
|
|
|
| 321 |
groups, seen = [], set()
|
| 322 |
id2name = {id(p): n for n, p in model.named_parameters()}
|
| 323 |
|
| 324 |
+
for name, mod in model.named_modules():
|
| 325 |
if isinstance(mod, nn.Linear):
|
| 326 |
pname = id2name.get(id(mod.weight), "")
|
| 327 |
is_scalar = getattr(mod, "is_scalar_weight", False)
|
| 328 |
fan_in = mod.weight.shape[1]
|
|
|
|
| 329 |
if "lm_head" in pname:
|
| 330 |
+
scale = 1
|
| 331 |
lr_scaled = base_lr_head
|
| 332 |
wd_scaled = 0.0
|
| 333 |
+
wd_mult = 0.0
|
| 334 |
elif is_scalar:
|
| 335 |
+
scale = 1
|
| 336 |
lr_scaled = base_lr_scalar
|
| 337 |
wd_scaled = 0.0
|
| 338 |
+
wd_mult = 0.0
|
| 339 |
else:
|
| 340 |
+
scale = 1 / math.sqrt(fan_in)
|
| 341 |
lr_scaled = base_lr_hidden * scale
|
| 342 |
wd_scaled = wd / lr_scaled
|
| 343 |
+
wd_mult = 1/lr_scaled
|
| 344 |
|
| 345 |
groups.append({"params": [mod.weight], "lr": lr_scaled, "weight_decay": wd_scaled})
|
| 346 |
seen.add(mod.weight)
|
| 347 |
|
| 348 |
+
print(f"param {name}.weight | shape {mod.weight.shape} | scale {scale} | wd_mult={wd_mult:.3e}")
|
| 349 |
+
|
| 350 |
if mod.bias is not None:
|
| 351 |
+
assert False
|
| 352 |
groups.append({"params": [mod.bias], "lr": base_lr_scalar, "weight_decay": 0.0})
|
| 353 |
seen.add(mod.bias)
|
| 354 |
|
| 355 |
+
for name, p in model.named_parameters():
|
| 356 |
if p in seen:
|
| 357 |
continue
|
| 358 |
pname = id2name.get(id(p), "<unnamed>")
|
|
|
|
| 365 |
lr_scaled = base_lr_scalar
|
| 366 |
|
| 367 |
wd_scaled = 0.
|
| 368 |
+
wd_mult = 0.
|
| 369 |
if getattr(p, "requires_weight_decay", False):
|
| 370 |
wd_scaled = wd / lr_scaled
|
| 371 |
+
wd_mult = 1/lr_scaled
|
| 372 |
|
| 373 |
groups.append({"params": [p], "lr": lr_scaled, "weight_decay": wd_scaled})
|
| 374 |
|
| 375 |
+
print(f"param {name} | shape {p.shape} | scale {1.} | wd_mult={wd_mult:.3e}")
|
| 376 |
+
|
| 377 |
return groups
|
| 378 |
|
| 379 |
args = tyro.cli(NanoArgs)
|
|
|
|
| 392 |
print("problem: gated MLP with MoE is not supported, because we use FA backend")
|
| 393 |
exit(0)
|
| 394 |
|
| 395 |
+
if args.legacy_gate:
|
| 396 |
+
assert not args.gate_gdn, "legacy_gate is not compatible with gate_gdn."
|
| 397 |
+
|
| 398 |
# set up DDP (distributed data parallel).
|
| 399 |
assert torch.cuda.is_available()
|
| 400 |
dist.init_process_group(
|
|
|
|
| 488 |
# load dataloaders.
|
| 489 |
#if args.patch_level_training:
|
| 490 |
# assert T % args.patch_level_training_size == 0, "sequence length must be divisible by patch level training size in reduced mode"
|
| 491 |
+
train_loader = DistributedDataLoader(args.input_bin, args.intra_doc_masking, B, T, ddp_rank, ddp_world_size, args.bos_id, args.dataset_type)
|
| 492 |
+
val_loader = DistributedDataLoader(args.input_val_bin, args.intra_doc_masking, B, T, ddp_rank, ddp_world_size, args.bos_id, args.dataset_type)
|
| 493 |
print0(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
|
| 494 |
print0(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
|
| 495 |
|
| 496 |
# load model.
|
| 497 |
config_hf = DragonConfig(
|
| 498 |
+
reduce_lm_head=args.reduce_lm_head,
|
| 499 |
+
dataset_type=args.dataset_type,
|
| 500 |
+
vwn=args.vwn,
|
| 501 |
+
vwn_m=args.vwn_m,
|
| 502 |
+
vwn_n=args.vwn_n,
|
| 503 |
+
vwn_wd_alpha_beta=args.vwn_wd_alpha_beta,
|
| 504 |
+
vwn_dynamic=args.vwn_dynamic,
|
| 505 |
+
legacy_gate=args.legacy_gate,
|
| 506 |
+
init_gpt2=args.init_gpt2,
|
| 507 |
tie_lm_head=args.tie_lm_head,
|
| 508 |
mlp_type=args.mlp_type,
|
| 509 |
layer_norm_scaling=args.layer_norm_scaling,
|
|
|
|
| 515 |
mamba3_remove_conv=args.mamba3_remove_conv,
|
| 516 |
mamba3_is_A_dd=args.mamba3_is_A_dd,
|
| 517 |
mamba3_add_trapezoid=args.mamba3_add_trapezoid,
|
| 518 |
+
mamba3_postgate_norm=args.mamba3_postgate_norm,
|
| 519 |
moe=args.moe,
|
| 520 |
moe_num_routed_experts=args.moe_num_routed_experts,
|
| 521 |
moe_routed_scaling_factor=args.moe_routed_scaling_factor,
|
|
|
|
| 530 |
shrink_qk_da=args.shrink_qk_da,
|
| 531 |
shrink_qk_gdn=args.shrink_qk_gdn,
|
| 532 |
mixer_gn=args.mixer_gn,
|
| 533 |
+
gate_before_norm=args.gate_before_norm,
|
| 534 |
kda_allow_neg_eigval=args.kda_allow_neg_eigval,
|
| 535 |
kda_num_v_heads=args.kda_num_v_heads,
|
| 536 |
seednorm_wd=args.seednorm_wd,
|
|
|
|
| 573 |
max_position_embeddings=args.sequence_length,
|
| 574 |
use_uscaling=args.use_uscaling,
|
| 575 |
hidden_size=args.d_model,
|
| 576 |
+
intermediate_size=int(args.d_model * args.mlp_expand) if args.intermediate_size is None else args.intermediate_size,
|
| 577 |
expand_factor=args.expand_factor,
|
| 578 |
layers_config=args.layers_config,
|
| 579 |
num_attention_heads=args.n_heads,
|
|
|
|
| 600 |
model = model.cuda()
|
| 601 |
print0(model)
|
| 602 |
|
|
|
|
| 603 |
with torch.no_grad():
|
| 604 |
+
for name, p in model.named_parameters():
|
| 605 |
+
if p is None or p.numel() == 0:
|
| 606 |
+
continue
|
| 607 |
+
t = p.detach().float()
|
| 608 |
+
mean = t.mean().item()
|
| 609 |
+
std = t.std(unbiased=False).item()
|
| 610 |
+
print0(f"{name:60s} shape={tuple(p.shape)} mean={mean:+.4e} std={std:.4e}")
|
|
|
|
|
|
|
|
|
|
| 611 |
|
| 612 |
# count params. (total & active)
|
| 613 |
num_params = sum(p.numel() for p in model.parameters())
|
|
|
|
| 631 |
|
| 632 |
if args.intra_doc_masking:
|
| 633 |
print0("!!! Using intra-document masking !!!")
|
| 634 |
+
print0("It is only compatible with GDN (conv+chunk), KDA (conv+chunk), DA and GDTPA layers. For DA/GDTPA, kv shift is also compatible. All other config will not have intra-doc masking support!!")
|
| 635 |
|
| 636 |
# load optimizers & schedulers.
|
| 637 |
if args.use_uscaling:
|
|
|
|
| 648 |
optimizer = torch.optim.AdamW(param_list, betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps)
|
| 649 |
elif args.optim == "ademamix":
|
| 650 |
from .optimizers.Ademamix import AdEMAMix
|
| 651 |
+
beta3_warmup = args.total_iterations
|
| 652 |
+
alpha_warmup = args.total_iterations
|
| 653 |
+
optimizer = AdEMAMix(param_list, beta3_warmup=beta3_warmup, alpha_warmup=alpha_warmup, normalize_alpha=args.alpha_normalize, alpha=args.alpha_ademamix, weight_decay=args.weight_decay)
|
| 654 |
else:
|
| 655 |
raise ValueError(f"Unknown optimizer for unit scaling: {args.optim}")
|
| 656 |
else:
|
| 657 |
if args.optim == "adamw":
|
| 658 |
+
#optimizer = torch.optim.AdamW(raw_model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps)
|
| 659 |
+
decay_params = []
|
| 660 |
+
no_decay_params = []
|
| 661 |
+
for name, p in raw_model.named_parameters():
|
| 662 |
+
if not p.requires_grad:
|
| 663 |
+
continue
|
| 664 |
+
if getattr(p, "_no_weight_decay", False):
|
| 665 |
+
no_decay_params.append(p)
|
| 666 |
+
else:
|
| 667 |
+
decay_params.append(p)
|
| 668 |
+
optimizer = torch.optim.AdamW(
|
| 669 |
+
[
|
| 670 |
+
{"params": decay_params, "weight_decay": args.weight_decay},
|
| 671 |
+
{"params": no_decay_params, "weight_decay": 0.0},
|
| 672 |
+
],
|
| 673 |
+
lr=args.learning_rate,
|
| 674 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 675 |
+
eps=args.adam_eps,
|
| 676 |
+
)
|
| 677 |
elif args.optim == "ademamix":
|
| 678 |
from .optimizers.Ademamix import AdEMAMix
|
| 679 |
|
| 680 |
+
beta3_warmup = args.total_iterations
|
| 681 |
+
alpha_warmup = args.total_iterations
|
| 682 |
+
optimizer = AdEMAMix(raw_model.parameters(), lr=args.learning_rate, beta3_warmup=beta3_warmup, alpha_warmup=alpha_warmup, normalize_alpha=args.alpha_normalize, alpha=args.alpha_ademamix, weight_decay=args.weight_decay)
|
| 683 |
else:
|
| 684 |
raise ValueError(f"Unknown Optimizer: {args.optim}")
|
| 685 |
if args.second_order_optim == "snoo":
|
|
|
|
| 705 |
if args.warmdown_type == "linear":
|
| 706 |
sched_func = partial(get_lr_wsd, args.total_iterations, args.warmup_iters, args.warmdown_iters)
|
| 707 |
schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, sched_func) for opt in optimizers]
|
| 708 |
+
elif args.warmdown_type == "cosine" or args.warmdown_type == "1-sqrt":
|
| 709 |
sched = get_wsd_schedule(
|
| 710 |
optimizers[0],
|
| 711 |
num_warmup_steps=args.warmup_iters,
|
|
|
|
| 713 |
num_training_steps=args.total_iterations,
|
| 714 |
min_lr_ratio=0.,
|
| 715 |
warmup_type='linear',
|
| 716 |
+
decay_type=args.warmdown_type,
|
| 717 |
)
|
| 718 |
schedulers = [sched]
|
| 719 |
else:
|
|
|
|
| 802 |
# save model & tokenizer to make evaluation easier.
|
| 803 |
tokenizer.save_pretrained(save_dir)
|
| 804 |
state_dict_bf16 = {k: v.detach().to(torch.bfloat16).cpu() for k, v in uncompiled_model.state_dict().items()}
|
| 805 |
+
idm_og = uncompiled_model.config.intra_doc_masking
|
| 806 |
+
uncompiled_model.config.intra_doc_masking = False
|
| 807 |
uncompiled_model.config.torch_dtype = torch.bfloat16
|
| 808 |
uncompiled_model.save_pretrained(save_dir, safe_serialization=True, state_dict=state_dict_bf16)
|
| 809 |
+
uncompiled_model.config.intra_doc_masking = idm_og
|
| 810 |
# save training state.
|
| 811 |
train_state = dict(
|
| 812 |
iteration=iter_,
|
|
|
|
| 841 |
(loss / accumulation_steps).backward()
|
| 842 |
else:
|
| 843 |
(loss / accumulation_steps).backward() # just sync on the last step
|
| 844 |
+
individual_grad_norms = {}
|
| 845 |
+
"""# Calculate individual param norms
|
| 846 |
+
# We use 'raw_model' to avoid 'module.' or '_orig_mod.' prefixes in wandb
|
| 847 |
+
individual_grad_norms = {}
|
| 848 |
+
# Only calculate on master process to save time, and maybe throttle frequency (e.g., every 10 steps)
|
| 849 |
+
# If you want it every step, remove the (iter_ % 10 == 0) check.
|
| 850 |
+
if master_process and (iter_ % 50 == 0):
|
| 851 |
+
for name, p in raw_model.named_parameters():
|
| 852 |
+
if p.grad is not None:
|
| 853 |
+
# Calculate L2 norm of the gradient
|
| 854 |
+
param_norm = p.grad.detach().data.norm(2).item()
|
| 855 |
+
individual_grad_norms[f"grad_norm/{name}"] = param_norm"""
|
| 856 |
# clip those gradients.
|
| 857 |
if args.grad_norm_clip is not None:
|
| 858 |
grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.grad_norm_clip, foreach=True)
|
|
|
|
| 867 |
# null those gradients.
|
| 868 |
model.zero_grad(set_to_none=True)
|
| 869 |
|
| 870 |
+
# Wnorm
|
| 871 |
+
if args.wnorm:
|
| 872 |
+
with torch.no_grad():
|
| 873 |
+
for m in model.modules():
|
| 874 |
+
if getattr(m, "norm_case_1", False):
|
| 875 |
+
W = getattr(m, "weight", None)
|
| 876 |
+
denom = W.float().norm(p=2, dim=1, keepdim=True).clamp_min(1e-8).to(W.dtype)
|
| 877 |
+
W.div_(denom)
|
| 878 |
+
elif getattr(m, "norm_case_2", False):
|
| 879 |
+
W = getattr(m, "weight", None)
|
| 880 |
+
denom = W.float().norm(p=2, dim=0, keepdim=True).clamp_min(1e-8).to(W.dtype)
|
| 881 |
+
W.div_(denom)
|
| 882 |
+
|
| 883 |
# ----------- LOGGING SECTION -----------
|
| 884 |
approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
|
| 885 |
avg_step_time = approx_training_time_ms / (iter_ + 1 - WARMUP_SKIP) if iter_ >= start_iter+WARMUP_SKIP else 0
|
| 886 |
extra = " ".join(f"{k}:{v}" for k, v in (to_log or {}).items())
|
| 887 |
print0(f"iteration:{iter_+1:0{len(str(start_iter+args.total_iterations))}d}/{args.total_iterations} train_loss:{train_loss.item():.4f} lr: {schedulers[0].get_last_lr()[0]:.4f} train_time:{approx_training_time_ms:.0f}ms step_avg:{avg_step_time:.2f}ms {extra}")
|
| 888 |
if master_process:
|
| 889 |
+
wandb.log({'train_loss': train_loss.item(), 'step_avg_time': avg_step_time, **{f'lr_{i}': sched.get_last_lr()[0] for i, sched in enumerate(schedulers)}, 'grad_norm': grad_norm.item(), **to_log, **individual_grad_norms}, step=iter_)
|
| 890 |
|
| 891 |
print0(f"peak memory consumption during training: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")
|
| 892 |
print0("Training complete.")
|