Commit ·
58b82e2
1
Parent(s): bc8288b
mamba3 flags | mamba3 default state size to 128, headdim to 64 | mamba2 | fix mamba3 mimo (JG) | (fake) moe | intra doc maskiiiing (with SS) | seednorm tests | coord checks
Browse files- configuration_dragon.py +33 -1
- coordcheck_utils.py +472 -0
- coordchecking_dragon.py +154 -0
- inspecting_dragon.py +55 -13
- modeling_dragon.py +425 -74
- training_dragon.py +86 -15
configuration_dragon.py
CHANGED
|
@@ -92,6 +92,21 @@ class DragonConfig(PretrainedConfig):
|
|
| 92 |
|
| 93 |
def __init__(
|
| 94 |
self,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
mla_kv_rank: int = 128,
|
| 96 |
shrink_qk_da: int = 2,
|
| 97 |
shrink_qk_gdn: int = 2,
|
|
@@ -119,6 +134,7 @@ class DragonConfig(PretrainedConfig):
|
|
| 119 |
scalable_softmax: bool = True,
|
| 120 |
resformer: bool = False,
|
| 121 |
mamba_mimo_dim : int = 4,
|
|
|
|
| 122 |
gate_type: str = "elementwise",
|
| 123 |
gate_act: str = "silu",
|
| 124 |
gate_attn: bool = False,
|
|
@@ -163,7 +179,7 @@ class DragonConfig(PretrainedConfig):
|
|
| 163 |
rope_type_local="rope",
|
| 164 |
rope_type_global="",
|
| 165 |
rope_theta_local=163.,
|
| 166 |
-
rope_theta_global=
|
| 167 |
uscaling_tau=0.2,
|
| 168 |
attention_dropout=0.,
|
| 169 |
hidden_dropout=0.,
|
|
@@ -176,6 +192,21 @@ class DragonConfig(PretrainedConfig):
|
|
| 176 |
mlp_linking=False,
|
| 177 |
**kwargs,
|
| 178 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
self.mla_kv_rank = mla_kv_rank
|
| 180 |
self.shrink_qk_da = shrink_qk_da
|
| 181 |
self.shrink_qk_gdn = shrink_qk_gdn
|
|
@@ -228,6 +259,7 @@ class DragonConfig(PretrainedConfig):
|
|
| 228 |
self.scalable_softmax = scalable_softmax
|
| 229 |
self.resformer = resformer
|
| 230 |
self.mamba_mimo_dim = mamba_mimo_dim
|
|
|
|
| 231 |
|
| 232 |
self.vocab_size = vocab_size
|
| 233 |
self.tie_word_embeddings = tie_word_embeddings
|
|
|
|
| 92 |
|
| 93 |
def __init__(
|
| 94 |
self,
|
| 95 |
+
mamba3_rope: bool = True,
|
| 96 |
+
mamba3_remove_BC_bias: bool = False,
|
| 97 |
+
mamba3_is_id_rms: bool = True,
|
| 98 |
+
mamba3_remove_conv: bool = True,
|
| 99 |
+
mamba3_is_A_dd: bool = True,
|
| 100 |
+
mamba3_add_trapezoid: bool = True,
|
| 101 |
+
moe: bool = False,
|
| 102 |
+
moe_num_routed_experts: int = 2,
|
| 103 |
+
moe_routed_scaling_factor: float = 2.5,
|
| 104 |
+
moe_routed_intermediate_size: int = 768,
|
| 105 |
+
moe_shared_intermediate_size: int = 768,
|
| 106 |
+
intra_doc_masking: bool = False,
|
| 107 |
+
seednorm_rank: int = 1,
|
| 108 |
+
seednorm_type: int = 1,
|
| 109 |
+
final_norm: bool = True,
|
| 110 |
mla_kv_rank: int = 128,
|
| 111 |
shrink_qk_da: int = 2,
|
| 112 |
shrink_qk_gdn: int = 2,
|
|
|
|
| 134 |
scalable_softmax: bool = True,
|
| 135 |
resformer: bool = False,
|
| 136 |
mamba_mimo_dim : int = 4,
|
| 137 |
+
mamba_ngroups : int = 1,
|
| 138 |
gate_type: str = "elementwise",
|
| 139 |
gate_act: str = "silu",
|
| 140 |
gate_attn: bool = False,
|
|
|
|
| 179 |
rope_type_local="rope",
|
| 180 |
rope_type_global="",
|
| 181 |
rope_theta_local=163.,
|
| 182 |
+
rope_theta_global=0.,
|
| 183 |
uscaling_tau=0.2,
|
| 184 |
attention_dropout=0.,
|
| 185 |
hidden_dropout=0.,
|
|
|
|
| 192 |
mlp_linking=False,
|
| 193 |
**kwargs,
|
| 194 |
):
|
| 195 |
+
self.mamba3_rope = mamba3_rope
|
| 196 |
+
self.mamba3_remove_BC_bias = mamba3_remove_BC_bias
|
| 197 |
+
self.mamba3_is_id_rms = mamba3_is_id_rms
|
| 198 |
+
self.mamba3_remove_conv = mamba3_remove_conv
|
| 199 |
+
self.mamba3_is_A_dd = mamba3_is_A_dd
|
| 200 |
+
self.mamba3_add_trapezoid = mamba3_add_trapezoid
|
| 201 |
+
self.moe = moe
|
| 202 |
+
self.moe_num_routed_experts = moe_num_routed_experts
|
| 203 |
+
self.moe_routed_scaling_factor = moe_routed_scaling_factor
|
| 204 |
+
self.moe_routed_intermediate_size = moe_routed_intermediate_size
|
| 205 |
+
self.moe_shared_intermediate_size = moe_shared_intermediate_size
|
| 206 |
+
self.intra_doc_masking = intra_doc_masking
|
| 207 |
+
self.seednorm_rank = seednorm_rank
|
| 208 |
+
self.seednorm_type = seednorm_type
|
| 209 |
+
self.final_norm = final_norm
|
| 210 |
self.mla_kv_rank = mla_kv_rank
|
| 211 |
self.shrink_qk_da = shrink_qk_da
|
| 212 |
self.shrink_qk_gdn = shrink_qk_gdn
|
|
|
|
| 259 |
self.scalable_softmax = scalable_softmax
|
| 260 |
self.resformer = resformer
|
| 261 |
self.mamba_mimo_dim = mamba_mimo_dim
|
| 262 |
+
self.mamba_ngroups = mamba_ngroups
|
| 263 |
|
| 264 |
self.vocab_size = vocab_size
|
| 265 |
self.tie_word_embeddings = tie_word_embeddings
|
coordcheck_utils.py
ADDED
|
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 Microsoft Corporation.
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
Adapted from https://github.com/microsoft/mup
|
| 5 |
+
In short, it has been largely simplified.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
from copy import copy
|
| 10 |
+
from itertools import product
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import pandas as pd
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
import matplotlib.pyplot as plt
|
| 19 |
+
import seaborn as sns
|
| 20 |
+
|
| 21 |
+
FDICT = {'l1': lambda x: torch.abs(x).mean(dtype=torch.float32)}
|
| 22 |
+
|
| 23 |
+
def convert_fdict(d):
|
| 24 |
+
'''convert a dict `d` with string values to function values.
|
| 25 |
+
Input:
|
| 26 |
+
d: a dict whose values are either strings or functions
|
| 27 |
+
Output:
|
| 28 |
+
a new dict, with the same keys as `d`, but the string values are
|
| 29 |
+
converted to functions using `FDICT`.
|
| 30 |
+
'''
|
| 31 |
+
return dict([
|
| 32 |
+
((k, FDICT[v]) if isinstance(v, str) else (k, v))
|
| 33 |
+
for k, v in d.items()])
|
| 34 |
+
|
| 35 |
+
def _record_coords(records, width, modulename, t,
|
| 36 |
+
output_fdict=None, input_fdict=None, param_fdict=None):
|
| 37 |
+
'''Returns a forward hook that records coordinate statistics.
|
| 38 |
+
|
| 39 |
+
Returns a forward hook that records statistics regarding the output, input,
|
| 40 |
+
and/or parameters of a `nn.Module`. This hook is intended to run only once,
|
| 41 |
+
on the timestep specified by `t`.
|
| 42 |
+
|
| 43 |
+
On forward pass, the returned hook calculates statistics specified in
|
| 44 |
+
`output_fdict`, `input_fdict`, and `param_fdict`, such as the normalized l1
|
| 45 |
+
norm, of output, input, and/or parameters of the module. The statistics are
|
| 46 |
+
recorded along with the `width`, `modulename`, and `t` (the time step) as a
|
| 47 |
+
dict and inserted into `records` (which should be a list). More precisely,
|
| 48 |
+
for each output, input, and/or parameter, the inserted dict is of the form
|
| 49 |
+
|
| 50 |
+
{
|
| 51 |
+
'width': width, 'module': modified_modulename, 't': t,
|
| 52 |
+
# keys are keys in fdict
|
| 53 |
+
'l1': 0.241, 'l2': 0.420, 'mean': 0.0, ...
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
where `modified_modulename` is a string that combines the `modulename` with
|
| 57 |
+
an indicator of which output, input, or parameter tensor is the statistics
|
| 58 |
+
computed over.
|
| 59 |
+
|
| 60 |
+
The `*_fdict` inputs should be dictionaries with string keys and whose
|
| 61 |
+
values can either be functions or strings. The string values are converted
|
| 62 |
+
to functions via `convert_fdict`. The default values of `*_dict` inputs are
|
| 63 |
+
converted to `output_fdict = dict(l1=FDICT['l1'])`, `input_fdict = {}`,
|
| 64 |
+
`param_fdict = {}`, i.e., only the average coordinate size (`l1`) of the
|
| 65 |
+
output activations are recorded.
|
| 66 |
+
|
| 67 |
+
Inputs:
|
| 68 |
+
records:
|
| 69 |
+
list to append coordinate data to
|
| 70 |
+
width:
|
| 71 |
+
width of the model. This is used only for plotting coord check later
|
| 72 |
+
on, so it can be any notion of width.
|
| 73 |
+
modulename:
|
| 74 |
+
string name of the module. This is used only for plotting coord check.
|
| 75 |
+
t:
|
| 76 |
+
timestep of training. This is used only for plotting coord check.
|
| 77 |
+
output_fdict, input_fdict, param_fdict:
|
| 78 |
+
dicts with string keys and whose values can either be functions or
|
| 79 |
+
strings. The string values are converted to functions via
|
| 80 |
+
`convert_fdict`
|
| 81 |
+
Output:
|
| 82 |
+
a forward hook that records statistics regarding the output, input,
|
| 83 |
+
and/or parameters of a `nn.Module`, as discussed above.
|
| 84 |
+
'''
|
| 85 |
+
if output_fdict is None:
|
| 86 |
+
output_fdict = dict(l1=FDICT['l1'])
|
| 87 |
+
else:
|
| 88 |
+
output_fdict = convert_fdict(output_fdict)
|
| 89 |
+
if input_fdict is None:
|
| 90 |
+
input_fdict = {}
|
| 91 |
+
else:
|
| 92 |
+
input_fdict = convert_fdict(input_fdict)
|
| 93 |
+
if param_fdict is None:
|
| 94 |
+
param_fdict = {}
|
| 95 |
+
else:
|
| 96 |
+
param_fdict = convert_fdict(param_fdict)
|
| 97 |
+
def f(module, input, output):
|
| 98 |
+
def get_stat(d, x, fdict):
|
| 99 |
+
if isinstance(x, (tuple, list)):
|
| 100 |
+
for i, _x in enumerate(x):
|
| 101 |
+
_d = copy(d)
|
| 102 |
+
_d['module'] += f'[{i}]'
|
| 103 |
+
get_stat(_d, _x, fdict)
|
| 104 |
+
elif isinstance(x, dict):
|
| 105 |
+
for name, _x in x.items():
|
| 106 |
+
_d = copy(d)
|
| 107 |
+
_d['module'] += f'[{name}]'
|
| 108 |
+
get_stat(_d, _x, fdict)
|
| 109 |
+
elif isinstance(x, torch.Tensor):
|
| 110 |
+
_d = copy(d)
|
| 111 |
+
for fname, f in fdict.items():
|
| 112 |
+
_d[fname] = f(x).item()
|
| 113 |
+
records.append(_d)
|
| 114 |
+
elif x is None:
|
| 115 |
+
pass
|
| 116 |
+
else:
|
| 117 |
+
raise NotImplementedError(f'Unexpected output type: {type(x)}')
|
| 118 |
+
with torch.no_grad():
|
| 119 |
+
ret = {
|
| 120 |
+
'width': width,
|
| 121 |
+
'module': modulename,
|
| 122 |
+
't': t
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
# output stats
|
| 126 |
+
if isinstance(output, (tuple, list)):
|
| 127 |
+
for i, out in enumerate(output):
|
| 128 |
+
_ret = copy(ret)
|
| 129 |
+
_ret['module'] += f':out[{i}]'
|
| 130 |
+
get_stat(_ret, out, output_fdict)
|
| 131 |
+
elif isinstance(output, dict):
|
| 132 |
+
for name, out in output.items():
|
| 133 |
+
_ret = copy(ret)
|
| 134 |
+
_ret['module'] += f':out[{name}]'
|
| 135 |
+
get_stat(_ret, out, output_fdict)
|
| 136 |
+
elif isinstance(output, torch.Tensor):
|
| 137 |
+
_ret = copy(ret)
|
| 138 |
+
for fname, f in output_fdict.items():
|
| 139 |
+
_ret[fname] = f(output).item()
|
| 140 |
+
records.append(_ret)
|
| 141 |
+
else:
|
| 142 |
+
raise NotImplementedError(f'Unexpected output type: {type(output)}')
|
| 143 |
+
|
| 144 |
+
# input stats
|
| 145 |
+
if input_fdict:
|
| 146 |
+
if isinstance(input, (tuple, list)):
|
| 147 |
+
for i, out in enumerate(input):
|
| 148 |
+
_ret = copy(ret)
|
| 149 |
+
_ret['module'] += f':in[{i}]'
|
| 150 |
+
get_stat(_ret, out, input_fdict)
|
| 151 |
+
elif isinstance(input, dict):
|
| 152 |
+
for name, out in input.items():
|
| 153 |
+
_ret = copy(ret)
|
| 154 |
+
_ret['module'] += f':in[{name}]'
|
| 155 |
+
get_stat(_ret, out, input_fdict)
|
| 156 |
+
elif isinstance(input, torch.Tensor):
|
| 157 |
+
_ret = copy(ret)
|
| 158 |
+
for fname, f in input_fdict.items():
|
| 159 |
+
_ret[fname] = f(input).item()
|
| 160 |
+
records.append(_ret)
|
| 161 |
+
else:
|
| 162 |
+
raise NotImplementedError(f'Unexpected output type: {type(input)}')
|
| 163 |
+
|
| 164 |
+
# param stats
|
| 165 |
+
if param_fdict:
|
| 166 |
+
for name, p in module.named_parameters():
|
| 167 |
+
_ret = copy(ret)
|
| 168 |
+
_ret['module'] += f':param[{name}]'
|
| 169 |
+
for fname, f in param_fdict.items():
|
| 170 |
+
_ret[fname] = f(p).item()
|
| 171 |
+
records.append(_ret)
|
| 172 |
+
|
| 173 |
+
return f
|
| 174 |
+
|
| 175 |
+
def _get_coord_data(models, dataloader, optcls, nsteps=5,
|
| 176 |
+
dict_in_out=False, flatten_input=False, flatten_output=False,
|
| 177 |
+
output_name='loss', lossfn='xent', filter_module_by_name=None,
|
| 178 |
+
fix_data=True, cuda=True, nseeds=1,
|
| 179 |
+
output_fdict=None, input_fdict=None, param_fdict=None,
|
| 180 |
+
show_progress=True, one_hot_target=False):
|
| 181 |
+
'''Inner method for `get_coord_data`.
|
| 182 |
+
|
| 183 |
+
Train the models in `models` with optimizer given by `optcls` and data from
|
| 184 |
+
`dataloader` for `nsteps` steps, and record coordinate statistics specified
|
| 185 |
+
by `output_fdict`, `input_fdict`, `param_fdict`. By default, only `l1` is
|
| 186 |
+
computed for output activations of each module.
|
| 187 |
+
|
| 188 |
+
Inputs:
|
| 189 |
+
models:
|
| 190 |
+
a dict of lazy models, where the keys are numbers indicating width.
|
| 191 |
+
Each entry of `models` is a function that instantiates a model given
|
| 192 |
+
nothing.
|
| 193 |
+
dataloader:
|
| 194 |
+
an iterator whose elements are either Huggingface style dicts, if
|
| 195 |
+
`dict_in_out` is True, or (input, label). If `fix_data` is True
|
| 196 |
+
(which is the default), then only the first element of `dataloader`
|
| 197 |
+
is used in a loop and the rest of `dataloder` is ignored.
|
| 198 |
+
optcls:
|
| 199 |
+
a function so that `optcls(model)` gives an optimizer used to train
|
| 200 |
+
the model.
|
| 201 |
+
nsteps:
|
| 202 |
+
number of steps to train the model
|
| 203 |
+
dict_in_out:
|
| 204 |
+
whether the data loader contains Huggingface-style dict input and
|
| 205 |
+
output. Default: False
|
| 206 |
+
flatten_input:
|
| 207 |
+
if not `dict_in_out`, reshape the input to be
|
| 208 |
+
`input.view(input.shape[0], -1)`. Typically used for testing MLPs.
|
| 209 |
+
flatten_output:
|
| 210 |
+
if not `dict_in_out`, reshape the label to be `label.view(-1,
|
| 211 |
+
input.shape[-1])`.
|
| 212 |
+
output_name:
|
| 213 |
+
if `dict_in_out`, this is the key for the loss value if the output
|
| 214 |
+
is a dict. If the output is not a dict, then we assume the first
|
| 215 |
+
element of the output is the loss.
|
| 216 |
+
lossfn:
|
| 217 |
+
loss function to use if not `dict_in_out`. Can be either a string from
|
| 218 |
+
[`xent`, 'mse', 'nll', 'l1'] or a python `callable` such that
|
| 219 |
+
`lossfn(output, target)` returns the loss value. Examples of valid
|
| 220 |
+
`callable`s are `F.cross_entropy`, `F.mse_loss`, etc, where `F` is
|
| 221 |
+
`torch.nn.functional`. Default: 'xent'
|
| 222 |
+
filter_module_by_name:
|
| 223 |
+
a function that returns a bool given module names (from
|
| 224 |
+
`model.named_modules()`), or None. If not None, then only modules
|
| 225 |
+
whose name yields True will be recorded.
|
| 226 |
+
cuda:
|
| 227 |
+
whether to use cuda or not. Default: True
|
| 228 |
+
nseeds:
|
| 229 |
+
number of times to repeat the training, each with different seeds.
|
| 230 |
+
output_fdict, input_fdict, param_fdict:
|
| 231 |
+
function dicts to be used in `_record_coords`. By default, only `l1`
|
| 232 |
+
is computed for output activations of each module.
|
| 233 |
+
show_progress:
|
| 234 |
+
show progress using tqdm. Default: True
|
| 235 |
+
one_hot_target:
|
| 236 |
+
convert target label into a one-hot vector. This typically is only
|
| 237 |
+
used for `'mse'` or `'l1'` losses in classification tasks.
|
| 238 |
+
Default: False
|
| 239 |
+
Output:
|
| 240 |
+
a pandas DataFrame containing recorded results. The column names are
|
| 241 |
+
`'width', 'module', 't'` as well as names of statistics recorded, such
|
| 242 |
+
as `'l1'` (see `FDICT` for other premade statistics that can be
|
| 243 |
+
collected).
|
| 244 |
+
|
| 245 |
+
Breaking Changes:
|
| 246 |
+
In v1.0.0, when `lossfn=='mse'`, the target is automatically converted
|
| 247 |
+
to a one hot vector before loss computation. Starting in v1.1.0, this
|
| 248 |
+
behavior is turned off, and the user needs to explicitly turn on this
|
| 249 |
+
behavior by setting `one_hot_target=True`.
|
| 250 |
+
|
| 251 |
+
'''
|
| 252 |
+
df = []
|
| 253 |
+
if fix_data:
|
| 254 |
+
batch = next(iter(dataloader))
|
| 255 |
+
dataloader = [batch] * nsteps
|
| 256 |
+
if show_progress:
|
| 257 |
+
pbar = tqdm(total=nseeds * len(models))
|
| 258 |
+
|
| 259 |
+
for i in range(nseeds):
|
| 260 |
+
torch.manual_seed(i)
|
| 261 |
+
for width, model in models.items():
|
| 262 |
+
model = model()
|
| 263 |
+
model = model.train()
|
| 264 |
+
if cuda:
|
| 265 |
+
model = model.cuda()
|
| 266 |
+
optimizer = optcls(model)
|
| 267 |
+
for batch_idx, batch in enumerate(dataloader, 1):
|
| 268 |
+
remove_hooks = []
|
| 269 |
+
# add hooks
|
| 270 |
+
for name, module in model.named_modules():
|
| 271 |
+
if filter_module_by_name and not filter_module_by_name(name):
|
| 272 |
+
continue
|
| 273 |
+
remove_hooks.append(module.register_forward_hook(
|
| 274 |
+
_record_coords(df, width, name, batch_idx,
|
| 275 |
+
output_fdict=output_fdict,
|
| 276 |
+
input_fdict=input_fdict,
|
| 277 |
+
param_fdict=param_fdict)))
|
| 278 |
+
if dict_in_out:
|
| 279 |
+
(data, target) = batch
|
| 280 |
+
loss = model(input_ids=data, labels=target).loss
|
| 281 |
+
else:
|
| 282 |
+
assert False, "Not implemented for non-dict input/output."
|
| 283 |
+
optimizer.zero_grad()
|
| 284 |
+
loss.backward()
|
| 285 |
+
optimizer.step()
|
| 286 |
+
|
| 287 |
+
# remove hooks
|
| 288 |
+
for handle in remove_hooks:
|
| 289 |
+
handle.remove()
|
| 290 |
+
|
| 291 |
+
if batch_idx == nsteps: break
|
| 292 |
+
if show_progress:
|
| 293 |
+
pbar.update(1)
|
| 294 |
+
if show_progress:
|
| 295 |
+
pbar.close()
|
| 296 |
+
return pd.DataFrame(df)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def get_coord_data(models, dataloader, optcls, nsteps, **kwargs):
|
| 300 |
+
'''Get coord data for coord check.
|
| 301 |
+
|
| 302 |
+
Train the models in `models` with data from `dataloader` and optimizer
|
| 303 |
+
specified by `optimizer` and `lr` for `nsteps` steps, and record coordinate
|
| 304 |
+
statistics specified by `output_fdict`, `input_fdict`, `param_fdict`. By
|
| 305 |
+
default, only `l1` is computed for output activations of each module.
|
| 306 |
+
|
| 307 |
+
This function wraps around `_get_coord_data`, with the main difference being
|
| 308 |
+
user can specify common optimizers via a more convenient interface.
|
| 309 |
+
|
| 310 |
+
Inputs:
|
| 311 |
+
models:
|
| 312 |
+
a dict of lazy models, where the keys are numbers indicating width.
|
| 313 |
+
Each entry of `models` is a function that instantiates a model given
|
| 314 |
+
nothing.
|
| 315 |
+
dataloader:
|
| 316 |
+
an iterator whose elements are either Huggingface style dicts, if
|
| 317 |
+
`dict_in_out` is True, or (input, label). If `fix_data` is True
|
| 318 |
+
(which is the default), then only the first element of `dataloader`
|
| 319 |
+
is used in a loop and the rest of `dataloder` is ignored.
|
| 320 |
+
optimizer:
|
| 321 |
+
a string in `['sgd', 'adam', 'adamw']`, with default being `'sgd'`.
|
| 322 |
+
lr:
|
| 323 |
+
learning rate. By default is 0.1 for `'sgd'` and 1e-3 for others.
|
| 324 |
+
mup:
|
| 325 |
+
If True, then use the optimizer from `mup.optim`; otherwise, use the
|
| 326 |
+
one from `torch.optim`.
|
| 327 |
+
filter_trainable_by_name:
|
| 328 |
+
a function that returns a bool given module names (from
|
| 329 |
+
`model.named_modules()`), or None. If not None, then only modules
|
| 330 |
+
whose name yields True will be trained.
|
| 331 |
+
nsteps:
|
| 332 |
+
number of steps to train the model
|
| 333 |
+
dict_in_out:
|
| 334 |
+
whether the data loader contains Huggingface-style dict input and
|
| 335 |
+
output. Default: False
|
| 336 |
+
flatten_input:
|
| 337 |
+
if not `dict_in_out`, reshape the input to be
|
| 338 |
+
`input.view(input.shape[0], -1)`. Typically used for testing MLPs.
|
| 339 |
+
flatten_output:
|
| 340 |
+
if not `dict_in_out`, reshape the label to be `label.view(-1,
|
| 341 |
+
input.shape[-1])`.
|
| 342 |
+
output_name:
|
| 343 |
+
if `dict_in_out`, this is the key for the loss value if the output
|
| 344 |
+
is a dict. If the output is not a dict, then we assume the first
|
| 345 |
+
element of the output is the loss.
|
| 346 |
+
lossfn:
|
| 347 |
+
loss function to use if not `dict_in_out`. Can be either a string from
|
| 348 |
+
[`xent`, 'mse', 'nll', 'l1'] or a python `callable` such that
|
| 349 |
+
`lossfn(output, target)` returns the loss value. Examples of valid
|
| 350 |
+
`callable`s are `F.cross_entropy`, `F.mse_loss`, etc, where `F` is
|
| 351 |
+
`torch.nn.functional`. Default: 'xent'
|
| 352 |
+
filter_module_by_name:
|
| 353 |
+
a function that returns a bool given module names (from
|
| 354 |
+
`model.named_modules()`), or None. If not None, then only modules
|
| 355 |
+
whose name yields True will be recorded.
|
| 356 |
+
cuda:
|
| 357 |
+
whether to use cuda or not. Default: True
|
| 358 |
+
nseeds:
|
| 359 |
+
number of times to repeat the training, each with different seeds.
|
| 360 |
+
output_fdict, input_fdict, param_fdict:
|
| 361 |
+
function dicts to be used in `_record_coords`. By default, only `l1`
|
| 362 |
+
is computed for output activations of each module.
|
| 363 |
+
show_progress:
|
| 364 |
+
show progress using tqdm. Default: True
|
| 365 |
+
one_hot_target:
|
| 366 |
+
convert target label into a one-hot vector. This typically is only
|
| 367 |
+
used for `'mse'` or `'l1'` losses in classification tasks.
|
| 368 |
+
Default: False
|
| 369 |
+
Output:
|
| 370 |
+
a pandas DataFrame containing recorded results. The column names are
|
| 371 |
+
`'width', 'module', 't'` as well as names of statistics recorded, such
|
| 372 |
+
as `'l1'` (see `FDICT` for other premade statistics that can be
|
| 373 |
+
collected).
|
| 374 |
+
|
| 375 |
+
Breaking Changes:
|
| 376 |
+
In v1.0.0, when `lossfn=='mse'`, the target is automatically converted
|
| 377 |
+
to a one hot vector before loss computation. Starting in v1.1.0, this
|
| 378 |
+
behavior is turned off, and the user needs to explicitly turn on this
|
| 379 |
+
behavior by setting `one_hot_target=True`.
|
| 380 |
+
'''
|
| 381 |
+
|
| 382 |
+
data = _get_coord_data(models, dataloader, optcls, nsteps, dict_in_out=True, **kwargs)
|
| 383 |
+
return data
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='module',
|
| 387 |
+
legend='full', name_contains=None, name_not_contains=None, module_list=None,
|
| 388 |
+
loglog=True, logbase=2, face_color=None, subplot_width=5,
|
| 389 |
+
subplot_height=4):
|
| 390 |
+
'''Plot coord check data `df` obtained from `get_coord_data`.
|
| 391 |
+
|
| 392 |
+
Input:
|
| 393 |
+
df:
|
| 394 |
+
a pandas DataFrame obtained from `get_coord_data`
|
| 395 |
+
y:
|
| 396 |
+
the column of `df` to plot on the y-axis. Default: `'l1'`
|
| 397 |
+
save_to:
|
| 398 |
+
path to save the resulting figure, or None. Default: None.
|
| 399 |
+
suptitle:
|
| 400 |
+
The title of the entire figure.
|
| 401 |
+
x:
|
| 402 |
+
the column of `df` to plot on the x-axis. Default: `'width'`
|
| 403 |
+
hue:
|
| 404 |
+
the column of `df` to represent as color. Default: `'module'`
|
| 405 |
+
legend:
|
| 406 |
+
'auto', 'brief', 'full', or False. This is passed to `seaborn.lineplot`.
|
| 407 |
+
name_contains, name_not_contains:
|
| 408 |
+
only plot modules whose name contains `name_contains` and does not contain `name_not_contains`
|
| 409 |
+
module_list:
|
| 410 |
+
only plot modules that are given in the list, overrides `name_contains` and `name_not_contains`
|
| 411 |
+
loglog:
|
| 412 |
+
whether to use loglog scale. Default: True
|
| 413 |
+
logbase:
|
| 414 |
+
the log base, if using loglog scale. Default: 2
|
| 415 |
+
face_color:
|
| 416 |
+
background color of the plot. Default: None (which means white)
|
| 417 |
+
subplot_width, subplot_height:
|
| 418 |
+
The width and height for each timestep's subplot. More precisely,
|
| 419 |
+
the figure size will be
|
| 420 |
+
`(subplot_width*number_of_time_steps, subplot_height)`.
|
| 421 |
+
Default: 5, 4
|
| 422 |
+
|
| 423 |
+
Output:
|
| 424 |
+
the `matplotlib` figure object
|
| 425 |
+
'''
|
| 426 |
+
### preprocessing
|
| 427 |
+
df = copy(df)
|
| 428 |
+
df = df[df.module != ''] # nn.Sequential has name '', which duplicates the output layer
|
| 429 |
+
if module_list is not None:
|
| 430 |
+
df = df[df['module'].isin(module_list)]
|
| 431 |
+
else:
|
| 432 |
+
if name_contains is not None:
|
| 433 |
+
df = df[df['module'].str.contains(name_contains)]
|
| 434 |
+
if name_not_contains is not None:
|
| 435 |
+
df = df[~(df['module'].str.contains(name_not_contains))]
|
| 436 |
+
try:
|
| 437 |
+
df['module'] = pd.to_numeric(df['module']) # for nn.Sequential, module names are numerical
|
| 438 |
+
except ValueError:
|
| 439 |
+
pass
|
| 440 |
+
|
| 441 |
+
ts = df.t.unique()
|
| 442 |
+
|
| 443 |
+
sns.set()
|
| 444 |
+
|
| 445 |
+
def tight_layout(plt):
|
| 446 |
+
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
|
| 447 |
+
|
| 448 |
+
### plot
|
| 449 |
+
fig = plt.figure(figsize=(subplot_width * len(ts), subplot_height))
|
| 450 |
+
hue_order = sorted(set(df['module']))
|
| 451 |
+
if face_color is not None:
|
| 452 |
+
fig.patch.set_facecolor(face_color)
|
| 453 |
+
ymin, ymax = min(df[y]), max(df[y])
|
| 454 |
+
for t in ts:
|
| 455 |
+
t = int(t)
|
| 456 |
+
plt.subplot(1, len(ts), t)
|
| 457 |
+
sns.lineplot(x=x, y=y, data=df[df.t == t], hue=hue, hue_order=hue_order, legend=None) # to show legend, set legend if t == 1 else None
|
| 458 |
+
plt.title(f't={t}')
|
| 459 |
+
if t != 1:
|
| 460 |
+
plt.ylabel('')
|
| 461 |
+
if loglog:
|
| 462 |
+
plt.loglog(base=logbase)
|
| 463 |
+
ax = plt.gca()
|
| 464 |
+
ax.set_ylim([ymin, ymax])
|
| 465 |
+
if suptitle:
|
| 466 |
+
plt.suptitle(suptitle)
|
| 467 |
+
tight_layout(plt)
|
| 468 |
+
if save_to is not None:
|
| 469 |
+
plt.savefig(save_to)
|
| 470 |
+
print(f'coord check plot saved to {save_to}')
|
| 471 |
+
|
| 472 |
+
return fig
|
coordchecking_dragon.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
import tyro
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.utils.data import Dataset, DataLoader
|
| 9 |
+
|
| 10 |
+
from .configuration_dragon import DragonConfig
|
| 11 |
+
from .modeling_dragon import DragonForCausalLM
|
| 12 |
+
from .coordcheck_utils import get_coord_data, plot_coord_data
|
| 13 |
+
|
| 14 |
+
# TRITON_HOME="/p/project1/jureap140/temp" python make_coord_check.py
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class Args:
|
| 18 |
+
save_dir: Path
|
| 19 |
+
mup: bool = False
|
| 20 |
+
learning_rate: float = 1e-2
|
| 21 |
+
layers_config: str = "gggTgggTgggTggg"
|
| 22 |
+
args = tyro.cli(Args)
|
| 23 |
+
|
| 24 |
+
batch_size = 8
|
| 25 |
+
batch_len = 1024
|
| 26 |
+
max_value = 100
|
| 27 |
+
|
| 28 |
+
widths = [128, 512, 1024, 2048]
|
| 29 |
+
n_heads = [4, 8, 16, 32]
|
| 30 |
+
d_head = 64
|
| 31 |
+
|
| 32 |
+
class RandomDataset(Dataset):
|
| 33 |
+
def __len__(self):
|
| 34 |
+
return 9999999
|
| 35 |
+
|
| 36 |
+
def __getitem__(self, _):
|
| 37 |
+
data = torch.randint(low=0, high=max_value, size=(batch_size, batch_len))
|
| 38 |
+
return data.cuda(), data.cuda()
|
| 39 |
+
|
| 40 |
+
def lazy_model(width):
|
| 41 |
+
config_hf = DragonConfig(
|
| 42 |
+
layers_config=args.layers_config,
|
| 43 |
+
hidden_size=width,
|
| 44 |
+
intermediate_size=4*width,
|
| 45 |
+
tpa_rank=4,
|
| 46 |
+
token_shift_attn=True,
|
| 47 |
+
head_dim=d_head,
|
| 48 |
+
shrink_qk_da=1,
|
| 49 |
+
num_attention_heads=n_heads[widths.index(width)],
|
| 50 |
+
num_signal_heads_diff=n_heads[widths.index(width)]-n_heads[widths.index(width)]//4,
|
| 51 |
+
num_key_value_heads=n_heads[widths.index(width)],
|
| 52 |
+
head_dim_gdn=d_head,
|
| 53 |
+
shrink_qk_gdn=2,
|
| 54 |
+
num_attention_heads_gdn=n_heads[widths.index(width)],
|
| 55 |
+
zero_centered_gate=True,
|
| 56 |
+
zero_centered_gate_type=4,
|
| 57 |
+
mamba_mimo_dim=4,
|
| 58 |
+
mamba_ngroups=1,
|
| 59 |
+
gate_attn=True,
|
| 60 |
+
zero_centered_gamma=True,
|
| 61 |
+
vocab_size=max_value,
|
| 62 |
+
max_position_embeddings=1024,
|
| 63 |
+
use_uscaling=True,
|
| 64 |
+
uscaling_tau=0.2,
|
| 65 |
+
initializer_range=1.,
|
| 66 |
+
use_cache=False,
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
if args.mup:
|
| 70 |
+
config_hf.use_uscaling = True
|
| 71 |
+
config_hf.initializer_range = 1.0
|
| 72 |
+
else:
|
| 73 |
+
config_hf.use_uscaling = False
|
| 74 |
+
config_hf.initializer_range = 0.006
|
| 75 |
+
|
| 76 |
+
return lambda: DragonForCausalLM(config_hf).to("cuda")
|
| 77 |
+
|
| 78 |
+
def param_groups_mup(model, base_lr_hidden, base_lr_scalar, base_lr_embed, base_lr_head, wd):
|
| 79 |
+
groups, seen = [], set()
|
| 80 |
+
id2name = {id(p): n for n, p in model.named_parameters()}
|
| 81 |
+
|
| 82 |
+
for mod in model.modules():
|
| 83 |
+
if isinstance(mod, nn.Linear):
|
| 84 |
+
pname = id2name.get(id(mod.weight), "")
|
| 85 |
+
is_scalar = getattr(mod, "is_scalar_weight", False)
|
| 86 |
+
fan_in = mod.weight.shape[1]
|
| 87 |
+
scale = 1 / math.sqrt(fan_in)
|
| 88 |
+
if "lm_head" in pname:
|
| 89 |
+
lr_scaled = base_lr_head
|
| 90 |
+
wd_scaled = 0.0
|
| 91 |
+
elif is_scalar:
|
| 92 |
+
lr_scaled = base_lr_scalar
|
| 93 |
+
wd_scaled = 0.0
|
| 94 |
+
else:
|
| 95 |
+
lr_scaled = base_lr_hidden * scale
|
| 96 |
+
wd_scaled = wd / lr_scaled
|
| 97 |
+
|
| 98 |
+
groups.append({"params": [mod.weight], "lr": lr_scaled, "weight_decay": wd_scaled})
|
| 99 |
+
seen.add(mod.weight)
|
| 100 |
+
|
| 101 |
+
if mod.bias is not None:
|
| 102 |
+
groups.append({"params": [mod.bias], "lr": base_lr_scalar, "weight_decay": 0.0})
|
| 103 |
+
seen.add(mod.bias)
|
| 104 |
+
|
| 105 |
+
for p in model.parameters():
|
| 106 |
+
if p in seen:
|
| 107 |
+
continue
|
| 108 |
+
pname = id2name.get(id(p), "<unnamed>")
|
| 109 |
+
|
| 110 |
+
if "embedding" in pname:
|
| 111 |
+
#fan_out = p.shape[1] # nn.Embedding is transposed
|
| 112 |
+
#lr_scaled = base_lr / math.sqrt(fan_out) # u-muP
|
| 113 |
+
lr_scaled = base_lr_embed
|
| 114 |
+
else:
|
| 115 |
+
lr_scaled = base_lr_scalar
|
| 116 |
+
|
| 117 |
+
wd_scaled = 0.
|
| 118 |
+
if getattr(p, "requires_weight_decay", False):
|
| 119 |
+
wd_scaled = wd / lr_scaled
|
| 120 |
+
|
| 121 |
+
groups.append({"params": [p], "lr": lr_scaled, "weight_decay": wd_scaled})
|
| 122 |
+
|
| 123 |
+
return groups
|
| 124 |
+
|
| 125 |
+
models = {width: lazy_model(width) for width in widths}
|
| 126 |
+
|
| 127 |
+
dataset = RandomDataset()
|
| 128 |
+
loader = DataLoader(dataset, batch_size=None, shuffle=True)
|
| 129 |
+
iter_ = iter(loader)
|
| 130 |
+
|
| 131 |
+
def get_optim(model):
|
| 132 |
+
if args.mup:
|
| 133 |
+
param_list = param_groups_mup(
|
| 134 |
+
model,
|
| 135 |
+
base_lr_hidden=args.learning_rate,
|
| 136 |
+
base_lr_scalar=2**-6,
|
| 137 |
+
base_lr_embed=2**-4,
|
| 138 |
+
base_lr_head=2**-6,
|
| 139 |
+
wd=0.,
|
| 140 |
+
)
|
| 141 |
+
optimizer = torch.optim.AdamW(param_list, betas=(0.9, 0.95), eps=1e-8)
|
| 142 |
+
else:
|
| 143 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=0., betas=(0.9, 0.95), eps=1e-8)
|
| 144 |
+
return optimizer
|
| 145 |
+
optcls = lambda model: get_optim(model)
|
| 146 |
+
|
| 147 |
+
df = get_coord_data(models, iter_, optcls, nsteps=10)
|
| 148 |
+
|
| 149 |
+
if args.mup:
|
| 150 |
+
name = f"mup_{args.learning_rate}_{args.layers_config}.png"
|
| 151 |
+
else:
|
| 152 |
+
name = f"sp_{args.learning_rate}_{args.layers_config}.png"
|
| 153 |
+
|
| 154 |
+
plot_coord_data(df, legend="full", save_to=args.save_dir / name)
|
inspecting_dragon.py
CHANGED
|
@@ -19,9 +19,13 @@ class NanoArgs:
|
|
| 19 |
# arch - general
|
| 20 |
d_model : int = 768
|
| 21 |
n_heads : int = 6 # head dim 128 suggested by @Grad62304977
|
|
|
|
| 22 |
layers_config : str = 4*"lrdlr"
|
| 23 |
-
expand_factor : int =
|
|
|
|
|
|
|
| 24 |
rope_theta_local: float = 10000.0
|
|
|
|
| 25 |
eps_rmsnorm: float = 1e-6
|
| 26 |
mlp_expand: int = 4 # expand factor for MLP
|
| 27 |
fused_loss_computation : bool = True # whether to use fused linear + cross entropy loss
|
|
@@ -32,9 +36,14 @@ class NanoArgs:
|
|
| 32 |
zero_centered_gate_type: int = 1 # 1, 2, 3, 4
|
| 33 |
gate_attn: bool = False
|
| 34 |
gate_gdn: bool = True
|
| 35 |
-
gate_type: str = "elementwise" # elementwise (one per dim), headwise (one per head)
|
| 36 |
gate_act: str = "silu" # silu, sigmoid
|
| 37 |
scalar_proj_as_hidden_matrix: bool = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
# attention related
|
| 40 |
n_kv_heads : int = 0
|
|
@@ -46,26 +55,38 @@ class NanoArgs:
|
|
| 46 |
softcap_global_attn: float = 0.0
|
| 47 |
qk_norm: bool = True
|
| 48 |
scalable_softmax: bool = True
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
num_attention_heads_indexer: int = 8
|
| 51 |
head_dim_indexer: int = 32
|
| 52 |
dsa_q_lora_rank: int = 128
|
| 53 |
dsa_topk: int = 512
|
| 54 |
-
cca_head_dim: int = 128
|
| 55 |
cca_seq_kernel_size: int = 4
|
| 56 |
-
nsa_head_dim: int = 128
|
| 57 |
nsa_topk: int = 16
|
| 58 |
nsa_block_size: int = 64
|
| 59 |
nsa_window_size: int = 512
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
# GDN related
|
| 62 |
rope_gdn: Optional[str] = None # None, rope, (srope)
|
|
|
|
| 63 |
n_heads_gdn: int = 0
|
| 64 |
n_kv_heads_gdn: int = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
# optim
|
| 67 |
optim: str = "adamw" # adamw, spam, stable-spam, muon, muon_moonlight, splus
|
| 68 |
-
second_order_optim : Optional[str] = None #
|
| 69 |
batch_size: int = 8*64 # batch size, in sequences, across all devices
|
| 70 |
device_batch_size: int = 64 # batch size, in sequences, per device
|
| 71 |
total_iterations: int = 1000 # number of iterations to run
|
|
@@ -83,14 +104,13 @@ class NanoArgs:
|
|
| 83 |
init_std: float = 0.006
|
| 84 |
patch_level_training: bool = False
|
| 85 |
patch_level_training_size: int = 4
|
| 86 |
-
|
|
|
|
|
|
|
| 87 |
|
| 88 |
# data
|
| 89 |
vocab_size: int = 50304
|
| 90 |
sequence_length: int = 1024
|
| 91 |
-
use_patch_level_training: bool = False
|
| 92 |
-
patch_size: int = 4
|
| 93 |
-
patch_training_fraction: float = 0.67
|
| 94 |
input_bin: Optional[str] = None
|
| 95 |
input_val_bin: Optional[str] = None
|
| 96 |
|
|
@@ -116,21 +136,39 @@ args = tyro.cli(NanoArgs)
|
|
| 116 |
|
| 117 |
# load model.
|
| 118 |
config_hf = DragonConfig(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
scalar_proj_as_hidden_matrix=args.scalar_proj_as_hidden_matrix,
|
| 120 |
-
|
|
|
|
|
|
|
|
|
|
| 121 |
patch_level_training=args.patch_level_training,
|
| 122 |
patch_level_training_size=args.patch_level_training_size,
|
| 123 |
-
nsa_head_dim=args.nsa_head_dim,
|
| 124 |
nsa_topk=args.nsa_topk,
|
| 125 |
nsa_block_size=args.nsa_block_size,
|
| 126 |
nsa_window_size=args.nsa_window_size,
|
| 127 |
-
cca_head_dim=args.cca_head_dim,
|
| 128 |
cca_seq_kernel_size=args.cca_seq_kernel_size,
|
|
|
|
|
|
|
| 129 |
num_attention_heads_gdn=args.n_heads_gdn,
|
| 130 |
num_key_value_heads_gdn=args.n_kv_heads_gdn,
|
| 131 |
zero_centered_gate=args.zero_centered_gate,
|
| 132 |
zero_centered_gate_type=args.zero_centered_gate_type,
|
| 133 |
scalable_softmax=args.scalable_softmax,
|
|
|
|
|
|
|
|
|
|
| 134 |
gate_type=args.gate_type,
|
| 135 |
gate_act=args.gate_act,
|
| 136 |
gate_attn=args.gate_attn,
|
|
@@ -157,8 +195,12 @@ config_hf = DragonConfig(
|
|
| 157 |
norm_epsilon=args.eps_rmsnorm,
|
| 158 |
use_cache=False,
|
| 159 |
sliding_window_size=args.swa_window_size,
|
|
|
|
|
|
|
|
|
|
| 160 |
rope_theta_local=args.rope_theta_local,
|
| 161 |
uscaling_tau=args.uscaling_tau,
|
|
|
|
| 162 |
)
|
| 163 |
|
| 164 |
model = DragonForCausalLM(config_hf)
|
|
|
|
| 19 |
# arch - general
|
| 20 |
d_model : int = 768
|
| 21 |
n_heads : int = 6 # head dim 128 suggested by @Grad62304977
|
| 22 |
+
head_dim: Optional[int] = None
|
| 23 |
layers_config : str = 4*"lrdlr"
|
| 24 |
+
expand_factor : int = 2 # expand factor for Mamba/Dragon
|
| 25 |
+
rope_type_local: str = "" #p-rope
|
| 26 |
+
rope_type_global: str = "" #p-rope
|
| 27 |
rope_theta_local: float = 10000.0
|
| 28 |
+
rope_theta_global: float = 0.0
|
| 29 |
eps_rmsnorm: float = 1e-6
|
| 30 |
mlp_expand: int = 4 # expand factor for MLP
|
| 31 |
fused_loss_computation : bool = True # whether to use fused linear + cross entropy loss
|
|
|
|
| 36 |
zero_centered_gate_type: int = 1 # 1, 2, 3, 4
|
| 37 |
gate_attn: bool = False
|
| 38 |
gate_gdn: bool = True
|
| 39 |
+
gate_type: str = "elementwise" # elementwise (one per dim), headwise (one per head), kimi (lora)
|
| 40 |
gate_act: str = "silu" # silu, sigmoid
|
| 41 |
scalar_proj_as_hidden_matrix: bool = True
|
| 42 |
+
normalization_type: str = "rmsnorm" # rmsnorm, seednorm
|
| 43 |
+
seednorm_wd: bool = True
|
| 44 |
+
mixer_gn: bool = True
|
| 45 |
+
mlp_linking : bool = False
|
| 46 |
+
final_norm: bool = True
|
| 47 |
|
| 48 |
# attention related
|
| 49 |
n_kv_heads : int = 0
|
|
|
|
| 55 |
softcap_global_attn: float = 0.0
|
| 56 |
qk_norm: bool = True
|
| 57 |
scalable_softmax: bool = True
|
| 58 |
+
resformer : bool = False # Works only on f layers (DiffAttention)
|
| 59 |
+
token_shift_attn: bool = False
|
| 60 |
+
token_shift_gdn: bool = False
|
| 61 |
+
token_conv1d_attn: bool = False
|
| 62 |
+
token_conv1d_gdn: bool = True
|
| 63 |
num_attention_heads_indexer: int = 8
|
| 64 |
head_dim_indexer: int = 32
|
| 65 |
dsa_q_lora_rank: int = 128
|
| 66 |
dsa_topk: int = 512
|
|
|
|
| 67 |
cca_seq_kernel_size: int = 4
|
|
|
|
| 68 |
nsa_topk: int = 16
|
| 69 |
nsa_block_size: int = 64
|
| 70 |
nsa_window_size: int = 512
|
| 71 |
+
num_signal_heads_diff: Optional[int] = None
|
| 72 |
+
tpa_rank: int = 2
|
| 73 |
+
shrink_qk_da: int = 2
|
| 74 |
+
mla_kv_rank: int = 128
|
| 75 |
|
| 76 |
# GDN related
|
| 77 |
rope_gdn: Optional[str] = None # None, rope, (srope)
|
| 78 |
+
head_dim_gdn: Optional[int] = None
|
| 79 |
n_heads_gdn: int = 0
|
| 80 |
n_kv_heads_gdn: int = 0
|
| 81 |
+
shrink_qk_gdn: int = 2
|
| 82 |
+
kda_allow_neg_eigval: bool = False
|
| 83 |
+
kda_num_v_heads: Optional[int] = None
|
| 84 |
+
mamba_mimo_dim: Optional[int] = 2
|
| 85 |
+
mamba_ngroups: Optional[int] = 1
|
| 86 |
|
| 87 |
# optim
|
| 88 |
optim: str = "adamw" # adamw, spam, stable-spam, muon, muon_moonlight, splus
|
| 89 |
+
second_order_optim : Optional[str] = None # snoo
|
| 90 |
batch_size: int = 8*64 # batch size, in sequences, across all devices
|
| 91 |
device_batch_size: int = 64 # batch size, in sequences, per device
|
| 92 |
total_iterations: int = 1000 # number of iterations to run
|
|
|
|
| 104 |
init_std: float = 0.006
|
| 105 |
patch_level_training: bool = False
|
| 106 |
patch_level_training_size: int = 4
|
| 107 |
+
second_order_lr: float = 0.68
|
| 108 |
+
second_order_momentum: float = 0.37
|
| 109 |
+
second_order_interval: int = 25
|
| 110 |
|
| 111 |
# data
|
| 112 |
vocab_size: int = 50304
|
| 113 |
sequence_length: int = 1024
|
|
|
|
|
|
|
|
|
|
| 114 |
input_bin: Optional[str] = None
|
| 115 |
input_val_bin: Optional[str] = None
|
| 116 |
|
|
|
|
| 136 |
|
| 137 |
# load model.
|
| 138 |
config_hf = DragonConfig(
|
| 139 |
+
final_norm=args.final_norm,
|
| 140 |
+
mla_kv_rank=args.mla_kv_rank,
|
| 141 |
+
rope_gdn=args.rope_gdn,
|
| 142 |
+
shrink_qk_da=args.shrink_qk_da,
|
| 143 |
+
shrink_qk_gdn=args.shrink_qk_gdn,
|
| 144 |
+
mixer_gn=args.mixer_gn,
|
| 145 |
+
kda_allow_neg_eigval=args.kda_allow_neg_eigval,
|
| 146 |
+
kda_num_v_heads=args.kda_num_v_heads,
|
| 147 |
+
seednorm_wd=args.seednorm_wd,
|
| 148 |
+
normalization_type=args.normalization_type,
|
| 149 |
+
tpa_rank=args.tpa_rank,
|
| 150 |
+
num_signal_heads_diff=args.num_signal_heads_diff,
|
| 151 |
scalar_proj_as_hidden_matrix=args.scalar_proj_as_hidden_matrix,
|
| 152 |
+
token_shift_attn=args.token_shift_attn,
|
| 153 |
+
token_shift_gdn=args.token_shift_gdn,
|
| 154 |
+
token_conv1d_attn=args.token_conv1d_attn,
|
| 155 |
+
token_conv1d_gdn=args.token_conv1d_gdn,
|
| 156 |
patch_level_training=args.patch_level_training,
|
| 157 |
patch_level_training_size=args.patch_level_training_size,
|
|
|
|
| 158 |
nsa_topk=args.nsa_topk,
|
| 159 |
nsa_block_size=args.nsa_block_size,
|
| 160 |
nsa_window_size=args.nsa_window_size,
|
|
|
|
| 161 |
cca_seq_kernel_size=args.cca_seq_kernel_size,
|
| 162 |
+
head_dim=args.head_dim,
|
| 163 |
+
head_dim_gdn=args.head_dim_gdn,
|
| 164 |
num_attention_heads_gdn=args.n_heads_gdn,
|
| 165 |
num_key_value_heads_gdn=args.n_kv_heads_gdn,
|
| 166 |
zero_centered_gate=args.zero_centered_gate,
|
| 167 |
zero_centered_gate_type=args.zero_centered_gate_type,
|
| 168 |
scalable_softmax=args.scalable_softmax,
|
| 169 |
+
mamba_mimo_dim=args.mamba_mimo_dim,
|
| 170 |
+
mamba_ngroups=args.mamba_ngroups,
|
| 171 |
+
resformer=args.resformer,
|
| 172 |
gate_type=args.gate_type,
|
| 173 |
gate_act=args.gate_act,
|
| 174 |
gate_attn=args.gate_attn,
|
|
|
|
| 195 |
norm_epsilon=args.eps_rmsnorm,
|
| 196 |
use_cache=False,
|
| 197 |
sliding_window_size=args.swa_window_size,
|
| 198 |
+
rope_type_global=args.rope_type_global,
|
| 199 |
+
rope_type_local=args.rope_type_local,
|
| 200 |
+
rope_theta_global=args.rope_theta_global,
|
| 201 |
rope_theta_local=args.rope_theta_local,
|
| 202 |
uscaling_tau=args.uscaling_tau,
|
| 203 |
+
mlp_linking=args.mlp_linking
|
| 204 |
)
|
| 205 |
|
| 206 |
model = DragonForCausalLM(config_hf)
|
modeling_dragon.py
CHANGED
|
@@ -19,11 +19,20 @@ from transformers.utils import ModelOutput, logging
|
|
| 19 |
|
| 20 |
from fla.ops.nsa.parallel import parallel_nsa
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
try:
|
| 23 |
from dragon_mamba3_ops.siso_variant.ssd_combined_fused import mamba_chunk_scan_discretized_combined
|
|
|
|
| 24 |
from dragon_mamba3_ops.angle_cumsum import angle_dt
|
| 25 |
from dragon_mamba3_ops.rotary_mamba import rotary_qk
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
| 27 |
mamba_chunk_scan_discretized_combined, angle_dt, rotary_qk = None, None, None
|
| 28 |
|
| 29 |
try:
|
|
@@ -39,8 +48,9 @@ try:
|
|
| 39 |
from fla.ops.kda import chunk_kda, fused_recurrent_kda
|
| 40 |
from fla.ops.kda.gate import fused_kda_gate
|
| 41 |
from fla.modules import FusedRMSNormGated, ShortConvolution
|
|
|
|
| 42 |
except ImportError:
|
| 43 |
-
chunk_kda, fused_recurrent_kda, fused_kda_gate = None, None, None
|
| 44 |
|
| 45 |
from torch.compiler import disable
|
| 46 |
|
|
@@ -56,13 +66,14 @@ ATTN_IMPL = "eager"
|
|
| 56 |
try:
|
| 57 |
import flash_attn_interface # FA3
|
| 58 |
flash_attn_func = flash_attn_interface.flash_attn_func
|
|
|
|
| 59 |
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
| 60 |
if not _flash_supports_window_size:
|
| 61 |
raise ImportError("flash_attn_func does not support window_size parameter. Please update to more recent flash_attn version")
|
| 62 |
ATTN_IMPL = "fa3"
|
| 63 |
except ImportError:
|
| 64 |
try:
|
| 65 |
-
from flash_attn import flash_attn_func # FA2
|
| 66 |
ATTN_IMPL = "fa2"
|
| 67 |
except ImportError:
|
| 68 |
try:
|
|
@@ -123,7 +134,16 @@ class DragonNorm(nn.Module):
|
|
| 123 |
if config.normalization_type == "rmsnorm":
|
| 124 |
self.norm = DragonRMSNorm(hidden_size, eps=config.norm_epsilon, zero_centered_gamma=config.zero_centered_gamma)
|
| 125 |
elif config.normalization_type == "seednorm":
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
else:
|
| 128 |
raise ValueError(f"Unknown normalization_type: {config.normalization_type}")
|
| 129 |
|
|
@@ -159,6 +179,54 @@ class DragonSeeDNorm(nn.Module):
|
|
| 159 |
dynamic_scale = rescale.unsqueeze(-1) * self.alpha # (B, L, D)
|
| 160 |
return (dynamic_scale + self.gamma) * self.rms(hidden_states)
|
| 161 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
class DragonLayerNorm(nn.Module):
|
| 163 |
def __init__(self, hidden_size, eps=1e-6): # TODO: ZCG ?
|
| 164 |
super().__init__()
|
|
@@ -1696,6 +1764,8 @@ class DragonDifferentialAttention(nn.Module):
|
|
| 1696 |
hidden_states: torch.Tensor,
|
| 1697 |
position_ids: Optional[torch.LongTensor] = None,
|
| 1698 |
cache_params: Optional[HybridDragonDynamicCache] = None,
|
|
|
|
|
|
|
| 1699 |
**kwargs,
|
| 1700 |
):
|
| 1701 |
_, q_len, _ = hidden_states.shape
|
|
@@ -1747,6 +1817,17 @@ class DragonDifferentialAttention(nn.Module):
|
|
| 1747 |
k_prev = F.pad(key_states, (0, 0, 0, 0, 1, 0))[:, :-1] # (B, L, H, D)
|
| 1748 |
v_prev = F.pad(value_states, (0, 0, 0, 0, 1, 0))[:, :-1] # (B, L, H, D)
|
| 1749 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1750 |
key_states = alpha_k * k_prev + (1 - alpha_k) * key_states
|
| 1751 |
value_states = alpha_v * v_prev + (1 - alpha_v) * value_states
|
| 1752 |
|
|
@@ -1859,18 +1940,28 @@ class DragonDifferentialAttention(nn.Module):
|
|
| 1859 |
elif DIFF_ATTN_IMPL == "fa2":
|
| 1860 |
def diff_attention_interface(q, k, v, wsize, **kw):
|
| 1861 |
if self.head_qk_dim == self.head_v_dim:
|
| 1862 |
-
|
|
|
|
|
|
|
|
|
|
| 1863 |
D = v.size(3)
|
| 1864 |
v1 = v[:, :, :, :D//2]
|
| 1865 |
v2 = v[:, :, :, D//2:]
|
| 1866 |
-
|
| 1867 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1868 |
o = torch.cat([o1, o2], dim=-1)
|
| 1869 |
return o
|
| 1870 |
elif DIFF_ATTN_IMPL == "fa3":
|
| 1871 |
def diff_attention_interface(q, k, v, wsize, **kw):
|
| 1872 |
if self.head_qk_dim == self.head_v_dim:
|
| 1873 |
-
|
|
|
|
|
|
|
|
|
|
| 1874 |
D = v.size(3)
|
| 1875 |
v1 = v[:, :, :, :D//2]
|
| 1876 |
v2 = v[:, :, :, D//2:]
|
|
@@ -2350,6 +2441,8 @@ class DragonDifferentialTensorProductAttention(nn.Module):
|
|
| 2350 |
hidden_states: torch.Tensor,
|
| 2351 |
position_ids: Optional[torch.LongTensor] = None,
|
| 2352 |
cache_params: Optional[HybridDragonDynamicCache] = None,
|
|
|
|
|
|
|
| 2353 |
**kwargs,
|
| 2354 |
):
|
| 2355 |
b, q_len, _ = hidden_states.shape
|
|
@@ -2398,6 +2491,17 @@ class DragonDifferentialTensorProductAttention(nn.Module):
|
|
| 2398 |
k_prev = F.pad(key_states, (0, 0, 0, 0, 1, 0))[:, :-1] # (B, L, H, D)
|
| 2399 |
v_prev = F.pad(value_states, (0, 0, 0, 0, 1, 0))[:, :-1] # (B, L, H, D)
|
| 2400 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2401 |
key_states = alpha_k * k_prev + (1 - alpha_k) * key_states
|
| 2402 |
value_states = alpha_v * v_prev + (1 - alpha_v) * value_states
|
| 2403 |
|
|
@@ -2510,7 +2614,10 @@ class DragonDifferentialTensorProductAttention(nn.Module):
|
|
| 2510 |
elif DIFF_ATTN_IMPL == "fa2":
|
| 2511 |
def diff_attention_interface(q, k, v, wsize, **kw):
|
| 2512 |
if self.head_qk_dim == self.head_v_dim:
|
| 2513 |
-
|
|
|
|
|
|
|
|
|
|
| 2514 |
D = v.size(3)
|
| 2515 |
v1 = v[:, :, :, :D//2]
|
| 2516 |
v2 = v[:, :, :, D//2:]
|
|
@@ -2521,7 +2628,10 @@ class DragonDifferentialTensorProductAttention(nn.Module):
|
|
| 2521 |
elif DIFF_ATTN_IMPL == "fa3":
|
| 2522 |
def diff_attention_interface(q, k, v, wsize, **kw):
|
| 2523 |
if self.head_qk_dim == self.head_v_dim:
|
| 2524 |
-
|
|
|
|
|
|
|
|
|
|
| 2525 |
D = v.size(3)
|
| 2526 |
v1 = v[:, :, :, :D//2]
|
| 2527 |
v2 = v[:, :, :, D//2:]
|
|
@@ -3102,6 +3212,7 @@ class DragonGatedDeltaNet(nn.Module):
|
|
| 3102 |
hidden_states: torch.Tensor,
|
| 3103 |
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 3104 |
cache_params: Optional[HybridDragonDynamicCache] = None,
|
|
|
|
| 3105 |
**kwargs,
|
| 3106 |
):
|
| 3107 |
_, q_len, _ = hidden_states.shape
|
|
@@ -3164,12 +3275,15 @@ class DragonGatedDeltaNet(nn.Module):
|
|
| 3164 |
conv_cache = F.pad(mixed_qkv, (self.conv_size - mixed_qkv.shape[-1], 0))
|
| 3165 |
cache_params.conv_caches[self.layer_idx] = conv_cache
|
| 3166 |
if self.causal_conv1d_fn is not None:
|
|
|
|
|
|
|
|
|
|
| 3167 |
mixed_qkv = self.causal_conv1d_fn(
|
| 3168 |
x=mixed_qkv,
|
| 3169 |
weight=self.qkv_conv1d.weight.squeeze(1),
|
| 3170 |
bias=self.qkv_conv1d.bias,
|
| 3171 |
activation='silu',
|
| 3172 |
-
seq_idx=
|
| 3173 |
)
|
| 3174 |
else:
|
| 3175 |
mixed_qkv = F.silu(self.qkv_conv1d(mixed_qkv)[:, :, :q_len])
|
|
@@ -3216,7 +3330,8 @@ class DragonGatedDeltaNet(nn.Module):
|
|
| 3216 |
scale=None if not self.config.use_uscaling else 1/self.dk,
|
| 3217 |
initial_state=None,
|
| 3218 |
output_final_state=cache_params is not None,
|
| 3219 |
-
use_qk_l2norm_in_kernel=True
|
|
|
|
| 3220 |
) # (B L H dv)
|
| 3221 |
else:
|
| 3222 |
o, ssm_cache = self.recurrent_gated_delta_rule(
|
|
@@ -3404,19 +3519,16 @@ class DragonMamba3(nn.Module):
|
|
| 3404 |
)
|
| 3405 |
|
| 3406 |
self.d_model = config.hidden_size
|
| 3407 |
-
self.d_state =
|
| 3408 |
self.conv_init = None
|
| 3409 |
self.expand = 2
|
| 3410 |
-
self.headdim =
|
| 3411 |
-
self.ngroups =
|
| 3412 |
self.activation = "swish"
|
| 3413 |
self.bias = False
|
| 3414 |
-
self.conv_bias = True
|
| 3415 |
self.chunk_size = 128
|
| 3416 |
self.A_floor = 1e-4
|
| 3417 |
self.rope_fraction = 0.5
|
| 3418 |
-
self.remove_conv = True
|
| 3419 |
-
self.add_conv_activation = False
|
| 3420 |
self.dt_min = 0.001
|
| 3421 |
self.dt_max = 0.1
|
| 3422 |
self.dt_init_floor = 1e-4
|
|
@@ -3432,13 +3544,24 @@ class DragonMamba3(nn.Module):
|
|
| 3432 |
if self.split_tensor_size == 0:
|
| 3433 |
return
|
| 3434 |
|
| 3435 |
-
|
|
|
|
| 3436 |
|
| 3437 |
# Order: [x, B, C, dt]
|
| 3438 |
d_in_proj = self.d_inner + 2 * self.d_state * self.ngroups + self.nheads
|
| 3439 |
|
| 3440 |
-
self.
|
| 3441 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3442 |
|
| 3443 |
_dt = torch.exp(
|
| 3444 |
torch.rand(self.nheads) * (math.log(self.dt_max) - math.log(self.dt_min))
|
|
@@ -3447,21 +3570,25 @@ class DragonMamba3(nn.Module):
|
|
| 3447 |
    _dt = torch.clamp(_dt, min=self.dt_init_floor)
    _dt_bias = _dt + torch.log(-torch.expm1(-_dt))
    self.dt_bias = nn.Parameter(_dt_bias, requires_grad=True)

    self.in_proj = DragonLinear(config, self.d_model, d_in_proj, bias=self.bias)

-   self.B_bias
-
-
-
-   if not
    conv_dim = self.d_inner + 2 * self.d_state * self.ngroups
    self.conv1d = nn.Conv1d(
        in_channels=conv_dim,
        out_channels=conv_dim,
-       bias=
        kernel_size=4,
        groups=conv_dim,
    )
@@ -3473,8 +3600,14 @@ class DragonMamba3(nn.Module):
    # D "skip" parameter
    self.D = nn.Parameter(torch.ones(self.nheads))
-   def forward(
    # Apply in_proj
    xBCdt = self.in_proj(hidden_states)
    xBC, dd_dt = torch.split(
@@ -3485,16 +3618,19 @@
        ],
        dim=-1)
-
-
    dt = F.softplus(dd_dt + self.dt_bias) # (B, L, N)
-   if not self.
    xBC = causal_conv1d_fn(
        x=xBC.transpose(1, 2),
        weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
        bias=self.conv1d.bias,
-       activation=self.activation
    ).transpose(1, 2) # (B, L, self.d_inner + 2 * ngroups * d_state)
    x, B, C = torch.split(
@@ -3507,37 +3643,64 @@
    B = rearrange(B, "b l (g n) -> b l g n", g=self.ngroups)
    C = rearrange(C, "b l (g n) -> b l g n", g=self.ngroups)
-
-
    if self.ngroups != self.nheads:
        B = B.expand(-1, -1, self.nheads, -1) # (B, L, N, S)
        C = C.expand(-1, -1, self.nheads, -1) # (B, L, N, S)
-
-
-
-
    x = rearrange(x, "b l (h p) -> b l h p", p=self.headdim)
    A = _A * dt
    gating_factor = dt # B, L, N
-
-
-
-
-
-
-
-
-
        x=x.bfloat16(),
        A=A,
        B=B.bfloat16(),
@@ -3547,11 +3710,117 @@
        gamma=gamma_arr,
        CB_sum=CB_sum,
        D=self.D,
-       z=None
    )
    return y, None, None

class DragonMamba3Mimo(nn.Module):
    def __init__(self, config: DragonConfig, layer_idx: Optional[int]):
@@ -3570,7 +3839,7 @@ class DragonMamba3Mimo(nn.Module):
    self.conv_init = None
    self.expand = 2
    self.headdim = 128
-   self.ngroups =
    self.activation = "swish"
    self.bias = False
    self.conv_bias = True
@@ -3604,7 +3873,7 @@ class DragonMamba3Mimo(nn.Module):
    # Order: [z, x, B, C, dt]
    d_in_proj = 2 * self.d_inner + 2 * self.d_state * self.ngroups * self.mimo_dim + self.nheads
-   self.A_proj = DragonLinear(config, self.d_model, self.nheads, bias=False
    self.trapezoid_proj = DragonLinear(config, self.d_model, self.nheads, bias=False)
    _dt = torch.exp(
@@ -3618,9 +3887,9 @@ class DragonMamba3Mimo(nn.Module):
    self.in_proj = DragonLinear(config, self.d_model, d_in_proj, bias=self.bias)
-   self.B_bias = nn.Parameter(torch.ones((self.nheads, self.d_state)), requires_grad=True)
-   self.C_bias = nn.Parameter(torch.ones((self.nheads, self.d_state)), requires_grad=True)
-
    self.B_norm = DragonNorm(config, self.d_state)
    self.C_norm = DragonNorm(config, self.d_state)
@@ -3655,9 +3924,9 @@ class DragonMamba3Mimo(nn.Module):
    def forward(self, hidden_states, **kwargs):
        # Apply in_proj
-
        z, xBC, dd_dt = torch.split(
-
            [
                self.d_inner,
                self.d_inner + 2 * self.d_state * self.ngroups * self.mimo_dim,
@@ -3719,14 +3988,15 @@
    C = self.C_norm(C)
    if self.ngroups != self.nheads:
-       B = B.expand(-1, -1, self.nheads, -1) # (B, L, R, N, S)
-       C = C.expand(-1, -1, self.nheads, -1) # (B, L, R, N, S)
    angle = self.rope_proj(hidden_states) # (B, L, S)
    angle = angle.unsqueeze(-2).expand(-1, -1, self.nheads, -1) # (B, L, G, S)
    angle = angle_dt(angle, dt)
-   C, B, CB_sum =
    x = rearrange(x, "b l r (h p) -> b l r h p", p=self.headdim)
@@ -3747,7 +4017,7 @@
    z = rearrange(z, "b l r (h p) -> b l r h p", p=self.headdim)
-   y =
        x=x.bfloat16(),
        A=A.bfloat16(),
        B=B.bfloat16(),
@@ -3761,31 +4031,33 @@
    )
    y = rearrange(y, "b l r h p -> b l r (h p)")
-   if seqlen_og is not None:
-
    # Perform MIMO down projection (mimo_rank*d_inner -> d_inner)
    y = rearrange(y, "b l r d -> b l (r d)")
    y = rearrange(y, "b l (g d) -> b l g d", g=self.mimo_dim*self.mimo_proj_block_order)
    y = torch.einsum("blgd,drg->bldr", y, self.out_proj_mimo)
    y = rearrange(y, "b l d r -> b l (d r)")
    return y, None, None

class DragonMLP(nn.Module):
-   def __init__(self, config: DragonConfig):
        super().__init__()
        self.config = config
        #print("previous MLP : ", PREVIOUS_MLP)
        self.link_size = 16
        self.mlp_linking = config.mlp_linking and PREVIOUS_MLP is not None
        if self.mlp_linking:
            self.previous_mlp = PREVIOUS_MLP
-           self.fc_1 = DragonLinear(config, config.hidden_size,
            self.lambda1 = nn.Parameter(torch.zeros(self.link_size)) # sigmoid->0.5
        else :
-           self.fc_1 = DragonLinear(config, config.hidden_size,
-           self.fc_2 = DragonLinear(config,
        self.register_buffer("_2_sqrt_5", torch.tensor(2/math.sqrt(5)) if config.use_uscaling else torch.tensor(1.), persistent=False)

    def forward(self, hidden_states):
@@ -3803,7 +4075,51 @@ class DragonMLP(nn.Module):
        return hidden_states

    def get_mlp_link(self):
-

PREVIOUS_MLP = None
class DragonMonoBlock(GradientCheckpointingLayer):
@@ -3878,6 +4194,16 @@ class DragonMonoBlock(GradientCheckpointingLayer):
            head_dim = self.mixer.headdim
            num_attention_heads = self.mixer.nheads
            use_gate = config.gate_gdn
        else:
            raise ValueError(f"Unknown layer type: {layer_type}")
@@ -3922,7 +4248,10 @@ class DragonMonoBlock(GradientCheckpointingLayer):
        self.input_norm = DragonNorm(config, config.hidden_size)
        self.postmixer_norm = DragonNorm(config, config.hidden_size)
-
        global PREVIOUS_MLP
        PREVIOUS_MLP = self.mlp
@@ -3938,6 +4267,8 @@ class DragonMonoBlock(GradientCheckpointingLayer):
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        key_value_last_layer: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ):
        # MIXER.
@@ -3949,6 +4280,8 @@ class DragonMonoBlock(GradientCheckpointingLayer):
            position_ids=position_ids,
            cache_params=cache_params,
            key_value_last_layer=key_value_last_layer,
        ) # (B, L, E*D)
        if self.use_gate:
            if self.config.gate_type == "elementwise" or self.config.gate_type == "kimi":
@@ -4126,8 +4459,13 @@ class DragonModel(DragonPreTrainedModel):
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([DragonBlock(config, layer_idx=i, layer_type=layer) if layer in ['l', 'r', 'd'] else DragonMonoBlock(config, layer_idx=i, layer_type=layer) for i, layer in enumerate(config.layers_config)])
-       self.
-
        self.gradient_checkpointing = False
        self.post_init()
@@ -4148,6 +4486,8 @@ class DragonModel(DragonPreTrainedModel):
        cache_position: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs
    ) -> DragonOutput:
        B, L = input_ids.shape if input_ids is not None else inputs_embeds.shape[:2]
@@ -4191,7 +4531,10 @@ class DragonModel(DragonPreTrainedModel):
        all_hidden_states = () if output_hidden_states else None
-
        shared_kv = (None, None)
        for block in self.layers:
@@ -4205,11 +4548,14 @@ class DragonModel(DragonPreTrainedModel):
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                key_value_last_layer=shared_kv,
                **kwargs,
            )
            shared_kv = (last_k, last_v)
-
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
@@ -4242,6 +4588,9 @@ class DragonForCausalLM(DragonPreTrainedModel, GenerationMixin):
        cache_position: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids=None,
        **kwargs,
    ) -> DragonCausalLMOutput:
@@ -4256,6 +4605,8 @@ class DragonForCausalLM(DragonPreTrainedModel, GenerationMixin):
            cache_position=cache_position,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
@@ -4299,9 +4650,9 @@ class DragonForCausalLM(DragonPreTrainedModel, GenerationMixin):
        return DragonCausalLMOutput(
            loss=loss,
-           logits=logits,
-           past_key_values=outputs.past_key_values,
-           hidden_states=outputs.hidden_states,
        )
DragonForCausalLM.register_for_auto_class("AutoModelForCausalLM")

| 19 |
|
| 20 |
from fla.ops.nsa.parallel import parallel_nsa
|
| 21 |
|
| 22 |
+
try:
|
| 23 |
+
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
|
| 24 |
+
except ImportError:
|
| 25 |
+
mamba_chunk_scan_combined = None
|
| 26 |
+
|
| 27 |
try:
|
| 28 |
from dragon_mamba3_ops.siso_variant.ssd_combined_fused import mamba_chunk_scan_discretized_combined
|
| 29 |
+
from dragon_mamba3_ops.mimo_variant.ssd_mimo import mamba_chunk_scan_discretized_fused_combined as mamba_mimo_chunk_scan_discretized_fused_combined
|
| 30 |
from dragon_mamba3_ops.angle_cumsum import angle_dt
|
| 31 |
from dragon_mamba3_ops.rotary_mamba import rotary_qk
|
| 32 |
+
from dragon_mamba3_ops.rotary_mamba_mimo import rotary_qk as mimo_rotary_qk
|
| 33 |
+
except ImportError as exc:
|
| 34 |
+
print("Warning: No Mamba-3 found !")
|
| 35 |
+
print(exc)
|
| 36 |
mamba_chunk_scan_discretized_combined, mamba_mimo_chunk_scan_discretized_fused_combined, angle_dt, rotary_qk, mimo_rotary_qk = None, None, None, None, None
|
| 37 |
|
| 38 |
try:
|
|
|
|
| 48 |
from fla.ops.kda import chunk_kda, fused_recurrent_kda
|
| 49 |
from fla.ops.kda.gate import fused_kda_gate
|
| 50 |
from fla.modules import FusedRMSNormGated, ShortConvolution
|
| 51 |
+
from fla.ops.utils import prepare_sequence_ids
|
| 52 |
except ImportError:
|
| 53 |
+
chunk_kda, fused_recurrent_kda, fused_kda_gate, prepare_sequence_ids = None, None, None, None
|
| 54 |
|
| 55 |
from torch.compiler import disable
|
| 56 |
|
|
|
|
| 66 |
try:
|
| 67 |
import flash_attn_interface # FA3
|
| 68 |
flash_attn_func = flash_attn_interface.flash_attn_func
|
| 69 |
+
flash_attn_varlen_func = flash_attn_interface.flash_attn_varlen_func
|
| 70 |
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
| 71 |
if not _flash_supports_window_size:
|
| 72 |
raise ImportError("flash_attn_func does not support window_size parameter. Please update to more recent flash_attn version")
|
| 73 |
ATTN_IMPL = "fa3"
|
| 74 |
except ImportError:
|
| 75 |
try:
|
| 76 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func # FA2
|
| 77 |
ATTN_IMPL = "fa2"
|
| 78 |
except ImportError:
|
| 79 |
try:
|
|
|
|
| 134 |
if config.normalization_type == "rmsnorm":
|
| 135 |
self.norm = DragonRMSNorm(hidden_size, eps=config.norm_epsilon, zero_centered_gamma=config.zero_centered_gamma)
|
| 136 |
elif config.normalization_type == "seednorm":
|
| 137 |
+
if config.seednorm_type == 1:
|
| 138 |
+
self.norm = DragonSeeDNorm(config, hidden_size, eps=config.norm_epsilon)
|
| 139 |
+
elif config.seednorm_type == 2:
|
| 140 |
+
self.norm = DragonSeeDNormType2(config, hidden_size, eps=config.norm_epsilon)
|
| 141 |
+
elif config.seednorm_type == 3:
|
| 142 |
+
self.norm = DragonSeeDNormType3(config, hidden_size, eps=config.norm_epsilon)
|
| 143 |
+
elif config.seednorm_type == 4:
|
| 144 |
+
self.norm = DragonSeeDNormType4(config, hidden_size, eps=config.norm_epsilon)
|
| 145 |
+
else:
|
| 146 |
+
raise ValueError(f"Unknown seednorm_type: {config.seednorm_type}")
|
| 147 |
else:
|
| 148 |
raise ValueError(f"Unknown normalization_type: {config.normalization_type}")
|
| 149 |
|
|
|
|
| 179 |
dynamic_scale = rescale.unsqueeze(-1) * self.alpha # (B, L, D)
|
| 180 |
return (dynamic_scale + self.gamma) * self.rms(hidden_states)
|
| 181 |
|
| 182 |
+
class DragonSeeDNormType2(nn.Module):
|
| 183 |
+
def __init__(self, config: DragonConfig, hidden_size, eps=1e-6):
|
| 184 |
+
super().__init__()
|
| 185 |
+
self.hidden_size = hidden_size
|
| 186 |
+
|
| 187 |
+
self.beta = DragonLinear(config, hidden_size, 1, bias=False)
|
| 188 |
+
self.alpha = nn.Parameter(torch.ones(hidden_size) * 1.)
|
| 189 |
+
if config.seednorm_wd:
|
| 190 |
+
self.alpha.requires_weight_decay = True
|
| 191 |
+
self.gamma = nn.Parameter(torch.ones(hidden_size))
|
| 192 |
+
self.rms = nn.RMSNorm(hidden_size, eps=eps, elementwise_affine=False)
|
| 193 |
+
|
| 194 |
+
def forward(self, hidden_states):
|
| 195 |
+
rescale = F.tanh(self.beta(hidden_states)) # (B, L, 1)
|
| 196 |
+
dynamic_scale = rescale * self.alpha # (B, L, D)
|
| 197 |
+
return (dynamic_scale + self.gamma) * self.rms(hidden_states)
|
| 198 |
+
|
| 199 |
+
class DragonSeeDNormType3(nn.Module):
|
| 200 |
+
def __init__(self, config: DragonConfig, hidden_size, eps=1e-6):
|
| 201 |
+
super().__init__()
|
| 202 |
+
self.hidden_size = hidden_size
|
| 203 |
+
|
| 204 |
+
self.beta = nn.Sequential(
|
| 205 |
+
DragonLinear(config, hidden_size, config.seednorm_rank, bias=False),
|
| 206 |
+
DragonLinear(config, config.seednorm_rank, hidden_size, bias=False),
|
| 207 |
+
)
|
| 208 |
+
self.gamma = nn.Parameter(torch.ones(hidden_size))
|
| 209 |
+
self.rms = nn.RMSNorm(hidden_size, eps=eps, elementwise_affine=False)
|
| 210 |
+
|
| 211 |
+
def forward(self, hidden_states):
|
| 212 |
+
dynamic_rescale = F.tanh(self.beta(hidden_states)) # (B, L, D)
|
| 213 |
+
return (dynamic_rescale + self.gamma) * self.rms(hidden_states)
|
| 214 |
+
|
| 215 |
+
class DragonSeeDNormType4(nn.Module):
|
| 216 |
+
def __init__(self, config: DragonConfig, hidden_size, eps=1e-6):
|
| 217 |
+
super().__init__()
|
| 218 |
+
self.hidden_size = hidden_size
|
| 219 |
+
|
| 220 |
+
self.beta = nn.Sequential(
|
| 221 |
+
DragonLinear(config, hidden_size, config.seednorm_rank, bias=False),
|
| 222 |
+
DragonLinear(config, config.seednorm_rank, hidden_size, bias=False),
|
| 223 |
+
)
|
| 224 |
+
self.rms = nn.RMSNorm(hidden_size, eps=eps, elementwise_affine=False)
|
| 225 |
+
|
| 226 |
+
def forward(self, hidden_states):
|
| 227 |
+
dynamic_rescale = F.silu(self.beta(hidden_states) + 1.15) # (B, L, D)
|
| 228 |
+
return dynamic_rescale * self.rms(hidden_states)
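# Informal summary of the dynamic rescale in the three fully-specified variants above
# (notation: rms(x) = affine-free RMSNorm, beta = the learned projection(s)):
#   type 2: y = (tanh(beta(x)) * alpha + gamma) * rms(x)   with beta: D -> 1, a per-token scalar gate
#   type 3: y = (tanh(beta2(beta1(x))) + gamma) * rms(x)   rank-`seednorm_rank` bottleneck, per-channel rescale
#   type 4: y = silu(beta2(beta1(x)) + 1.15) * rms(x)      same bottleneck, no gamma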
|
| 229 |
+
|
| 230 |
class DragonLayerNorm(nn.Module):
|
| 231 |
def __init__(self, hidden_size, eps=1e-6): # TODO: ZCG ?
|
| 232 |
super().__init__()
|
|
|
|
| 1764 |
hidden_states: torch.Tensor,
|
| 1765 |
position_ids: Optional[torch.LongTensor] = None,
|
| 1766 |
cache_params: Optional[HybridDragonDynamicCache] = None,
|
| 1767 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
| 1768 |
+
max_seqlen: Optional[int] = None,
|
| 1769 |
**kwargs,
|
| 1770 |
):
|
| 1771 |
_, q_len, _ = hidden_states.shape
|
|
|
|
| 1817 |
k_prev = F.pad(key_states, (0, 0, 0, 0, 1, 0))[:, :-1] # (B, L, H, D)
|
| 1818 |
v_prev = F.pad(value_states, (0, 0, 0, 0, 1, 0))[:, :-1] # (B, L, H, D)
|
| 1819 |
|
| 1820 |
+
if position_ids is not None:
|
| 1821 |
+
# first token of each doc has pos==0
|
| 1822 |
+
doc_start = (position_ids == 0) # (B, L) bool
|
| 1823 |
+
m = doc_start.unsqueeze(-1).unsqueeze(-1) # (B, L, 1, 1) bool
|
| 1824 |
+
|
| 1825 |
+
# zero the previous contribution at boundaries
|
| 1826 |
+
k_prev = k_prev.masked_fill(m, 0)
|
| 1827 |
+
v_prev = v_prev.masked_fill(m, 0)
|
| 1828 |
+
alpha_k = alpha_k.masked_fill(m, 0)
|
| 1829 |
+
alpha_v = alpha_v.masked_fill(m, 0)
|
| 1830 |
+
|
| 1831 |
key_states = alpha_k * k_prev + (1 - alpha_k) * key_states
|
| 1832 |
value_states = alpha_v * v_prev + (1 - alpha_v) * value_states
|
| 1833 |
|
|
|
|
| 1940 |
elif DIFF_ATTN_IMPL == "fa2":
|
| 1941 |
def diff_attention_interface(q, k, v, wsize, **kw):
|
| 1942 |
if self.head_qk_dim == self.head_v_dim:
|
| 1943 |
+
if not self.config.intra_doc_masking:
|
| 1944 |
+
return flash_attn_func(q, k, v, window_size=(wsize, 0), **kw)
|
| 1945 |
+
else:
|
| 1946 |
+
return flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, window_size=(wsize, 0), **kw).unsqueeze(0)
|
| 1947 |
D = v.size(3)
|
| 1948 |
v1 = v[:, :, :, :D//2]
|
| 1949 |
v2 = v[:, :, :, D//2:]
|
| 1950 |
+
if not self.config.intra_doc_masking:
|
| 1951 |
+
o1 = flash_attn_func(q, k, v1, window_size=(wsize, 0), **kw)
|
| 1952 |
+
o2 = flash_attn_func(q, k, v2, window_size=(wsize, 0), **kw)
|
| 1953 |
+
else:
|
| 1954 |
+
o1 = flash_attn_varlen_func(q[0], k[0], v1[0], cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, window_size=(wsize, 0), **kw).unsqueeze(0)
|
| 1955 |
+
o2 = flash_attn_varlen_func(q[0], k[0], v2[0], cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, window_size=(wsize, 0), **kw).unsqueeze(0)
|
| 1956 |
o = torch.cat([o1, o2], dim=-1)
|
| 1957 |
return o
|
| 1958 |
elif DIFF_ATTN_IMPL == "fa3":
|
| 1959 |
def diff_attention_interface(q, k, v, wsize, **kw):
|
| 1960 |
if self.head_qk_dim == self.head_v_dim:
|
| 1961 |
+
if not self.config.intra_doc_masking:
|
| 1962 |
+
return flash_attn_func(q, k, v, window_size=(wsize, 0), **kw)[0]
|
| 1963 |
+
else:
|
| 1964 |
+
return flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, window_size=(wsize, 0), **kw)[0].unsqueeze(0)
|
| 1965 |
D = v.size(3)
|
| 1966 |
v1 = v[:, :, :, :D//2]
|
| 1967 |
v2 = v[:, :, :, D//2:]
|
|
|
|
| 2441 |
hidden_states: torch.Tensor,
|
| 2442 |
position_ids: Optional[torch.LongTensor] = None,
|
| 2443 |
cache_params: Optional[HybridDragonDynamicCache] = None,
|
| 2444 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
| 2445 |
+
max_seqlen: Optional[int] = None,
|
| 2446 |
**kwargs,
|
| 2447 |
):
|
| 2448 |
b, q_len, _ = hidden_states.shape
|
|
|
|
| 2491 |
k_prev = F.pad(key_states, (0, 0, 0, 0, 1, 0))[:, :-1] # (B, L, H, D)
|
| 2492 |
v_prev = F.pad(value_states, (0, 0, 0, 0, 1, 0))[:, :-1] # (B, L, H, D)
|
| 2493 |
|
| 2494 |
+
if position_ids is not None:
|
| 2495 |
+
# first token of each doc has pos==0
|
| 2496 |
+
doc_start = (position_ids == 0) # (B, L) bool
|
| 2497 |
+
m = doc_start.unsqueeze(-1).unsqueeze(-1) # (B, L, 1, 1) bool
|
| 2498 |
+
|
| 2499 |
+
# zero the previous contribution at boundaries
|
| 2500 |
+
k_prev = k_prev.masked_fill(m, 0)
|
| 2501 |
+
v_prev = v_prev.masked_fill(m, 0)
|
| 2502 |
+
alpha_k = alpha_k.masked_fill(m, 0)
|
| 2503 |
+
alpha_v = alpha_v.masked_fill(m, 0)
|
| 2504 |
+
|
| 2505 |
key_states = alpha_k * k_prev + (1 - alpha_k) * key_states
|
| 2506 |
value_states = alpha_v * v_prev + (1 - alpha_v) * value_states
|
| 2507 |
|
|
|
|
| 2614 |
elif DIFF_ATTN_IMPL == "fa2":
|
| 2615 |
def diff_attention_interface(q, k, v, wsize, **kw):
|
| 2616 |
if self.head_qk_dim == self.head_v_dim:
|
| 2617 |
+
if not self.config.intra_doc_masking:
|
| 2618 |
+
return flash_attn_func(q, k, v, window_size=(wsize, 0), **kw)
|
| 2619 |
+
else:
|
| 2620 |
+
return flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, window_size=(wsize, 0), **kw).unsqueeze(0)
|
| 2621 |
D = v.size(3)
|
| 2622 |
v1 = v[:, :, :, :D//2]
|
| 2623 |
v2 = v[:, :, :, D//2:]
|
|
|
|
| 2628 |
elif DIFF_ATTN_IMPL == "fa3":
|
| 2629 |
def diff_attention_interface(q, k, v, wsize, **kw):
|
| 2630 |
if self.head_qk_dim == self.head_v_dim:
|
| 2631 |
+
if not self.config.intra_doc_masking:
|
| 2632 |
+
return flash_attn_func(q, k, v, window_size=(wsize, 0), **kw)[0]
|
| 2633 |
+
else:
|
| 2634 |
+
return flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, window_size=(wsize, 0), **kw)[0].unsqueeze(0)
|
| 2635 |
D = v.size(3)
|
| 2636 |
v1 = v[:, :, :, :D//2]
|
| 2637 |
v2 = v[:, :, :, D//2:]
|
|
|
|
| 3212 |
hidden_states: torch.Tensor,
|
| 3213 |
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 3214 |
cache_params: Optional[HybridDragonDynamicCache] = None,
|
| 3215 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
| 3216 |
**kwargs,
|
| 3217 |
):
|
| 3218 |
_, q_len, _ = hidden_states.shape
|
|
|
|
| 3275 |
conv_cache = F.pad(mixed_qkv, (self.conv_size - mixed_qkv.shape[-1], 0))
|
| 3276 |
cache_params.conv_caches[self.layer_idx] = conv_cache
|
| 3277 |
if self.causal_conv1d_fn is not None:
|
| 3278 |
+
seq_idx = None
|
| 3279 |
+
if cu_seqlens is not None:
|
| 3280 |
+
seq_idx = prepare_sequence_ids(cu_seqlens).to(torch.int32).unsqueeze(0)
|
| 3281 |
mixed_qkv = self.causal_conv1d_fn(
|
| 3282 |
x=mixed_qkv,
|
| 3283 |
weight=self.qkv_conv1d.weight.squeeze(1),
|
| 3284 |
bias=self.qkv_conv1d.bias,
|
| 3285 |
activation='silu',
|
| 3286 |
+
seq_idx=seq_idx,
|
| 3287 |
)
|
| 3288 |
else:
|
| 3289 |
mixed_qkv = F.silu(self.qkv_conv1d(mixed_qkv)[:, :, :q_len])
|
|
|
|
| 3330 |
scale=None if not self.config.use_uscaling else 1/self.dk,
|
| 3331 |
initial_state=None,
|
| 3332 |
output_final_state=cache_params is not None,
|
| 3333 |
+
use_qk_l2norm_in_kernel=True,
|
| 3334 |
+
cu_seqlens=cu_seqlens,
|
| 3335 |
) # (B L H dv)
|
| 3336 |
else:
|
| 3337 |
o, ssm_cache = self.recurrent_gated_delta_rule(
|
|
|
|
| 3519 |
)
|
| 3520 |
|
| 3521 |
self.d_model = config.hidden_size
|
| 3522 |
+
self.d_state = 128
|
| 3523 |
self.conv_init = None
|
| 3524 |
self.expand = 2
|
| 3525 |
+
self.headdim = 64
|
| 3526 |
+
self.ngroups = config.mamba_ngroups
|
| 3527 |
self.activation = "swish"
|
| 3528 |
self.bias = False
|
|
|
|
| 3529 |
self.chunk_size = 128
|
| 3530 |
self.A_floor = 1e-4
|
| 3531 |
self.rope_fraction = 0.5
|
|
|
|
|
|
|
| 3532 |
self.dt_min = 0.001
|
| 3533 |
self.dt_max = 0.1
|
| 3534 |
self.dt_init_floor = 1e-4
|
|
|
|
| 3544 |
if self.split_tensor_size == 0:
|
| 3545 |
return
|
| 3546 |
|
| 3547 |
+
if config.mamba3_rope:
|
| 3548 |
+
self.rope_proj = DragonLinear(config, self.d_model, self.num_rope_angles, bias=False)
|
| 3549 |
|
| 3550 |
# Order: [x, B, C, dt]
|
| 3551 |
d_in_proj = self.d_inner + 2 * self.d_state * self.ngroups + self.nheads
|
| 3552 |
|
| 3553 |
+
if self.config.mamba3_is_A_dd:
|
| 3554 |
+
self.A_proj = DragonLinear(config, self.d_model, self.nheads, bias=False, dtype=torch.float32)
|
| 3555 |
+
else:
|
| 3556 |
+
A_init_range = (1, 16)
|
| 3557 |
+
assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
|
| 3558 |
+
A = torch.empty(self.nheads, dtype=torch.float32).uniform_(*A_init_range)
|
| 3559 |
+
A_log = torch.log(A).to(dtype=torch.float32)
|
| 3560 |
+
self.A_log = nn.Parameter(A_log)
|
| 3561 |
+
self.A_log._no_weight_decay = True
|
| 3562 |
+
|
| 3563 |
+
if config.mamba3_add_trapezoid:
|
| 3564 |
+
self.trapezoid_proj = DragonLinear(config, self.d_model, self.nheads, bias=False)
|
| 3565 |
|
| 3566 |
_dt = torch.exp(
|
| 3567 |
torch.rand(self.nheads) * (math.log(self.dt_max) - math.log(self.dt_min))
|
|
|
|
| 3570 |
_dt = torch.clamp(_dt, min=self.dt_init_floor)
|
| 3571 |
_dt_bias = _dt + torch.log(-torch.expm1(-_dt))
|
| 3572 |
self.dt_bias = nn.Parameter(_dt_bias, requires_grad=True)
|
| 3573 |
+
self.dt_bias._no_weight_decay = True
|
| 3574 |
|
| 3575 |
self.in_proj = DragonLinear(config, self.d_model, d_in_proj, bias=self.bias)
|
| 3576 |
|
| 3577 |
+
self.B_bias, self.C_bias = None, None
|
| 3578 |
+
if not config.mamba3_remove_BC_bias:
|
| 3579 |
+
self.B_bias = nn.Parameter(torch.ones((self.nheads, self.d_state)), requires_grad=True)
|
| 3580 |
+
self.C_bias = nn.Parameter(torch.ones((self.nheads, self.d_state)), requires_grad=True)
|
| 3581 |
|
| 3582 |
+
if config.mamba3_is_id_rms:
|
| 3583 |
+
self.B_norm = DragonNorm(config, self.d_state)
|
| 3584 |
+
self.C_norm = DragonNorm(config, self.d_state)
|
| 3585 |
|
| 3586 |
+
if not config.mamba3_remove_conv:
|
| 3587 |
conv_dim = self.d_inner + 2 * self.d_state * self.ngroups
|
| 3588 |
self.conv1d = nn.Conv1d(
|
| 3589 |
in_channels=conv_dim,
|
| 3590 |
out_channels=conv_dim,
|
| 3591 |
+
bias=False,
|
| 3592 |
kernel_size=4,
|
| 3593 |
groups=conv_dim,
|
| 3594 |
)
|
|
|
|
| 3600 |
|
| 3601 |
# D "skip" parameter
|
| 3602 |
self.D = nn.Parameter(torch.ones(self.nheads))
|
| 3603 |
+
self.D._no_weight_decay = True
|
| 3604 |
|
| 3605 |
+
def forward(
|
| 3606 |
+
self,
|
| 3607 |
+
hidden_states: torch.Tensor,
|
| 3608 |
+
cache_params: Optional[HybridDragonDynamicCache] = None,
|
| 3609 |
+
**kwargs
|
| 3610 |
+
):
|
| 3611 |
# Apply in_proj
|
| 3612 |
xBCdt = self.in_proj(hidden_states)
|
| 3613 |
xBC, dd_dt = torch.split(
|
|
|
|
| 3618 |
],
|
| 3619 |
dim=-1)
|
| 3620 |
|
| 3621 |
+
if self.config.mamba3_is_A_dd:
|
| 3622 |
+
_A = -F.softplus((self.A_proj(hidden_states.to(torch.float32))).to(torch.float32)) # (B, L, N)
|
| 3623 |
+
_A = torch.clamp(_A, max=-self.A_floor)
|
| 3624 |
+
else:
|
| 3625 |
+
_A = -torch.exp(self.A_log).unsqueeze(0).unsqueeze(0)
|
| 3626 |
dt = F.softplus(dd_dt + self.dt_bias) # (B, L, N)
|
| 3627 |
|
| 3628 |
+
if not self.config.mamba3_remove_conv:
|
| 3629 |
xBC = causal_conv1d_fn(
|
| 3630 |
x=xBC.transpose(1, 2),
|
| 3631 |
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
| 3632 |
bias=self.conv1d.bias,
|
| 3633 |
+
activation=self.activation,
|
| 3634 |
).transpose(1, 2) # (B, L, self.d_inner + 2 * ngroups * d_state)
|
| 3635 |
|
| 3636 |
x, B, C = torch.split(
|
|
|
|
| 3643 |
B = rearrange(B, "b l (g n) -> b l g n", g=self.ngroups)
|
| 3644 |
C = rearrange(C, "b l (g n) -> b l g n", g=self.ngroups)
|
| 3645 |
|
| 3646 |
+
if self.config.mamba3_is_id_rms:
|
| 3647 |
+
B = self.B_norm(B)
|
| 3648 |
+
C = self.C_norm(C)
|
| 3649 |
|
| 3650 |
if self.ngroups != self.nheads:
|
| 3651 |
B = B.expand(-1, -1, self.nheads, -1) # (B, L, N, S)
|
| 3652 |
C = C.expand(-1, -1, self.nheads, -1) # (B, L, N, S)
|
| 3653 |
|
| 3654 |
+
if self.config.mamba3_rope:
|
| 3655 |
+
angle = self.rope_proj(hidden_states) # (B, L, S)
|
| 3656 |
+
angle = angle.unsqueeze(-2).expand(-1, -1, self.nheads, -1) # (B, L, G, S)
|
| 3657 |
+
angle = angle_dt(angle, dt)
|
| 3658 |
|
| 3659 |
+
C, B, CB_sum = rotary_qk(q=C, k=B, angle=angle, bias_q=self.C_bias, bias_k=self.B_bias, conjugate=False, inplace=False)
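# Note: rotary_qk rotates C (query-like) and B (key-like) by the dt-accumulated angle and folds
# in the optional C_bias/B_bias; CB_sum is the per-token sum over the state dimension of C*B,
# i.e. the same quantity computed explicitly by torch.sum(B*C, dim=-1) in the no-rope branch
# below (the exact bias/rotation ordering is fixed inside the fused kernel).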
|
| 3660 |
+
else:
|
| 3661 |
+
if not self.config.mamba3_remove_BC_bias:
|
| 3662 |
+
og_dtype = B.dtype
|
| 3663 |
+
B = (B + self.B_bias).to(og_dtype)
|
| 3664 |
+
C = (C + self.C_bias).to(og_dtype)
|
| 3665 |
+
|
| 3666 |
+
CB_sum = torch.sum(
|
| 3667 |
+
B.to(torch.float32)*C.to(torch.float32),
|
| 3668 |
+
dim=-1,
|
| 3669 |
+
keepdim=False
|
| 3670 |
+
)
|
| 3671 |
|
| 3672 |
x = rearrange(x, "b l (h p) -> b l h p", p=self.headdim)
|
| 3673 |
|
| 3674 |
A = _A * dt
|
| 3675 |
gating_factor = dt # B, L, N
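# The coefficients below feed the fused discretized scan (hedged reading; the exact convention
# is fixed by mamba_chunk_scan_discretized_combined):
#   alpha_t = exp(a_t * dt_t) decays the previous state;
#   gamma_t scales the current input B_t x_t;
#   with mamba3_add_trapezoid, beta_t = (1 - theta_t) * dt_t * alpha_t additionally weights the
#   previous input B_{t-1} x_{t-1}, where theta_t = sigmoid(trapezoid_proj(u_t)).
#   theta_t -> 1 recovers the plain h_t = alpha_t * h_{t-1} + dt_t * B_t x_t update of the else-branch.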
|
| 3676 |
|
| 3677 |
+
if self.config.mamba3_add_trapezoid:
|
| 3678 |
+
trap = F.sigmoid(self.trapezoid_proj(hidden_states)) # (B, L, N)
|
| 3679 |
|
| 3680 |
+
alpha_arr = torch.exp(A)
|
| 3681 |
+
beta_arr = (1-trap)*gating_factor*alpha_arr
|
| 3682 |
+
gamma_arr = trap*gating_factor
|
| 3683 |
|
| 3684 |
+
# roll alpha and beta to the left by 1
|
| 3685 |
+
_alpha_arr = torch.roll(alpha_arr, shifts=-1, dims=1)
|
| 3686 |
+
_beta_arr = torch.roll(beta_arr, shifts=-1, dims=1)
|
| 3687 |
|
| 3688 |
+
x_scalar = (gamma_arr*_alpha_arr + _beta_arr).to(torch.bfloat16)
|
| 3689 |
+
else:
|
| 3690 |
+
alpha_arr = torch.exp(A)
|
| 3691 |
+
beta_arr = torch.zeros_like(alpha_arr)
|
| 3692 |
+
gamma_arr = gating_factor
|
| 3693 |
+
|
| 3694 |
+
# roll alpha to the left by 1
|
| 3695 |
+
_alpha_arr = torch.roll(alpha_arr, shifts=-1, dims=1)
|
| 3696 |
+
|
| 3697 |
+
x_scalar = (gamma_arr*_alpha_arr).to(torch.bfloat16)
|
| 3698 |
+
|
| 3699 |
+
ssm_cache = None
|
| 3700 |
+
if cache_params is not None:
|
| 3701 |
+
ssm_cache = cache_params.ssm_caches[self.layer_idx]
|
| 3702 |
|
| 3703 |
+
out = mamba_chunk_scan_discretized_combined(
|
| 3704 |
x=x.bfloat16(),
|
| 3705 |
A=A,
|
| 3706 |
B=B.bfloat16(),
|
|
|
|
| 3710 |
gamma=gamma_arr,
|
| 3711 |
CB_sum=CB_sum,
|
| 3712 |
D=self.D,
|
| 3713 |
+
z=None,
|
| 3714 |
+
initial_states=ssm_cache,
|
| 3715 |
+
return_final_states=cache_params is not None,
|
| 3716 |
)
|
| 3717 |
|
| 3718 |
+
if cache_params is not None:
|
| 3719 |
+
y, ssm_cache = out
|
| 3720 |
+
cache_params.ssm_caches[self.layer_idx] = ssm_cache
|
| 3721 |
+
else:
|
| 3722 |
+
y = out
|
| 3723 |
+
|
| 3724 |
return y, None, None
|
| 3725 |
|
| 3726 |
+
class DragonMamba2(nn.Module):
|
| 3727 |
+
def __init__(self, config: DragonConfig, layer_idx: Optional[int]):
|
| 3728 |
+
super().__init__()
|
| 3729 |
+
self.d_model = config.hidden_size
|
| 3730 |
+
self.d_state = 128
|
| 3731 |
+
self.expand = 2
|
| 3732 |
+
self.d_inner = self.expand * self.d_model
|
| 3733 |
+
self.headdim = 64
|
| 3734 |
+
self.ngroups = config.mamba_ngroups
|
| 3735 |
+
assert self.d_inner % self.headdim == 0
|
| 3736 |
+
self.nheads = self.d_inner // self.headdim
|
| 3737 |
+
self.layer_idx = layer_idx
|
| 3738 |
+
|
| 3739 |
+
# Order: [x, B, C, dt]
|
| 3740 |
+
d_in_proj = self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
|
| 3741 |
+
self.in_proj = DragonLinear(config, self.d_model, d_in_proj, bias=False)
|
| 3742 |
+
|
| 3743 |
+
conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
|
| 3744 |
+
self.conv1d = nn.Conv1d(
|
| 3745 |
+
in_channels=conv_dim,
|
| 3746 |
+
out_channels=conv_dim,
|
| 3747 |
+
bias=False,
|
| 3748 |
+
kernel_size=4,
|
| 3749 |
+
groups=conv_dim,
|
| 3750 |
+
padding=4-1,
|
| 3751 |
+
)
|
| 3752 |
+
self.act = nn.SiLU()
|
| 3753 |
+
|
| 3754 |
+
# Initialize log dt bias
|
| 3755 |
+
dt_min=0.001
|
| 3756 |
+
dt_max=0.1
|
| 3757 |
+
dt_init_floor=1e-4
|
| 3758 |
+
dt_limit=(0.0, float("inf"))
|
| 3759 |
+
dt = torch.exp(torch.rand(self.nheads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
|
| 3760 |
+
dt = torch.clamp(dt, min=dt_init_floor)
|
| 3761 |
+
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
| 3762 |
+
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
| 3763 |
+
self.dt_bias = nn.Parameter(inv_dt)
|
| 3764 |
+
self.dt_bias._no_weight_decay = True
|
| 3765 |
+
|
| 3766 |
+
# A parameter
|
| 3767 |
+
A_init_range=(1, 16)
|
| 3768 |
+
assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
|
| 3769 |
+
A = torch.empty(self.nheads, dtype=torch.float32).uniform_(*A_init_range)
|
| 3770 |
+
A_log = torch.log(A)
|
| 3771 |
+
self.A_log = nn.Parameter(A_log)
|
| 3772 |
+
self.A_log._no_weight_decay = True
|
| 3773 |
+
|
| 3774 |
+
# D "skip" parameter
|
| 3775 |
+
self.D = nn.Parameter(torch.ones(self.nheads))
|
| 3776 |
+
self.D._no_weight_decay = True
|
| 3777 |
+
|
| 3778 |
+
def forward(self, hidden_states, **kwargs):
|
| 3779 |
+
"""
|
| 3780 |
+
u: (B, L, D)
|
| 3781 |
+
Returns: (y, None, None); y has shape (B, L, nheads, headdim) from mamba_chunk_scan_combined
|
| 3782 |
+
"""
|
| 3783 |
+
_, seqlen, _ = hidden_states.shape
|
| 3784 |
+
|
| 3785 |
+
xBCdt = self.in_proj(hidden_states) # (B, L, d_in_proj)
|
| 3786 |
+
A = -torch.exp(self.A_log) # (nheads) or (d_inner, d_state)
|
| 3787 |
+
|
| 3788 |
+
xBC, dt = torch.split(
|
| 3789 |
+
xBCdt, [self.d_inner + 2 * self.ngroups * self.d_state, self.nheads], dim=-1
|
| 3790 |
+
)
|
| 3791 |
+
dt = F.softplus(dt + self.dt_bias) # (B, L, nheads)
|
| 3792 |
+
|
| 3793 |
+
# 1D Convolution
|
| 3794 |
+
if causal_conv1d_fn is None:
|
| 3795 |
+
xBC = self.act(
|
| 3796 |
+
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
|
| 3797 |
+
) # (B, L, self.d_inner + 2 * ngroups * d_state)
|
| 3798 |
+
xBC = xBC[:, :seqlen, :]
|
| 3799 |
+
else:
|
| 3800 |
+
xBC = causal_conv1d_fn(
|
| 3801 |
+
x=xBC.transpose(1, 2),
|
| 3802 |
+
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
| 3803 |
+
bias=self.conv1d.bias,
|
| 3804 |
+
activation="swish",
|
| 3805 |
+
).transpose(1, 2)
|
| 3806 |
+
|
| 3807 |
+
# Split into 3 main branches: X, B, C
|
| 3808 |
+
# These correspond to V, K, Q respectively in the SSM/attention duality
|
| 3809 |
+
x, B, C = torch.split(xBC, [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
|
| 3810 |
+
y = mamba_chunk_scan_combined(
|
| 3811 |
+
rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
|
| 3812 |
+
dt,
|
| 3813 |
+
A,
|
| 3814 |
+
rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
|
| 3815 |
+
rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
|
| 3816 |
+
chunk_size=256,
|
| 3817 |
+
D=self.D,
|
| 3818 |
+
z=None,
|
| 3819 |
+
seq_idx=None,
|
| 3820 |
+
initial_states=None,
|
| 3821 |
+
)
|
| 3822 |
+
|
| 3823 |
+
return y, None, None
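# Usage sketch (hypothetical; layer type '2' in config.layers_config selects this mixer, see the
# DragonMonoBlock dispatch below; assumes a CUDA build of mamba_ssm and bf16 autocast as in training):
#   mixer = DragonMamba2(config, layer_idx=0).cuda()
#   u = torch.randn(1, 256, config.hidden_size, device="cuda")
#   with torch.autocast("cuda", dtype=torch.bfloat16):
#       y, _, _ = mixer(u)   # y: (B, L, nheads, headdim) from mamba_chunk_scan_combined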
|
| 3824 |
|
| 3825 |
class DragonMamba3Mimo(nn.Module):
|
| 3826 |
def __init__(self, config: DragonConfig, layer_idx: Optional[int]):
|
|
|
|
| 3839 |
self.conv_init = None
|
| 3840 |
self.expand = 2
|
| 3841 |
self.headdim = 128
|
| 3842 |
+
self.ngroups = config.mamba_ngroups
|
| 3843 |
self.activation = "swish"
|
| 3844 |
self.bias = False
|
| 3845 |
self.conv_bias = True
|
|
|
|
| 3873 |
# Order: [z, x, B, C, dt]
|
| 3874 |
d_in_proj = 2 * self.d_inner + 2 * self.d_state * self.ngroups * self.mimo_dim + self.nheads
|
| 3875 |
|
| 3876 |
+
self.A_proj = DragonLinear(config, self.d_model, self.nheads, bias=False, dtype=torch.float32)
|
| 3877 |
self.trapezoid_proj = DragonLinear(config, self.d_model, self.nheads, bias=False)
|
| 3878 |
|
| 3879 |
_dt = torch.exp(
|
|
|
|
| 3887 |
|
| 3888 |
self.in_proj = DragonLinear(config, self.d_model, d_in_proj, bias=self.bias)
|
| 3889 |
|
| 3890 |
+
self.B_bias = nn.Parameter(torch.ones((self.mimo_dim, self.nheads, self.d_state)), requires_grad=True)
|
| 3891 |
+
self.C_bias = nn.Parameter(torch.ones((self.mimo_dim, self.nheads, self.d_state)), requires_grad=True)
|
| 3892 |
+
|
| 3893 |
self.B_norm = DragonNorm(config, self.d_state)
|
| 3894 |
self.C_norm = DragonNorm(config, self.d_state)
|
| 3895 |
|
|
|
|
| 3924 |
|
| 3925 |
def forward(self, hidden_states, **kwargs):
|
| 3926 |
# Apply in_proj
|
| 3927 |
+
zxBCdt = self.in_proj(hidden_states)
|
| 3928 |
z, xBC, dd_dt = torch.split(
|
| 3929 |
+
zxBCdt,
|
| 3930 |
[
|
| 3931 |
self.d_inner,
|
| 3932 |
self.d_inner + 2 * self.d_state * self.ngroups * self.mimo_dim,
|
|
|
|
| 3988 |
C = self.C_norm(C)
|
| 3989 |
|
| 3990 |
if self.ngroups != self.nheads:
|
| 3991 |
+
B = B.expand(-1, -1, -1, self.nheads, -1) # (B, L, R, N, S)
|
| 3992 |
+
C = C.expand(-1, -1, -1, self.nheads, -1) # (B, L, R, N, S)
|
| 3993 |
+
|
| 3994 |
|
| 3995 |
angle = self.rope_proj(hidden_states) # (B, L, S)
|
| 3996 |
angle = angle.unsqueeze(-2).expand(-1, -1, self.nheads, -1) # (B, L, G, S)
|
| 3997 |
angle = angle_dt(angle, dt)
|
| 3998 |
|
| 3999 |
+
C, B, CB_sum = mimo_rotary_qk(q=C, k=B, angle=angle, bias_q=self.C_bias, bias_k=self.B_bias, conjugate=False, inplace=False)
|
| 4000 |
|
| 4001 |
x = rearrange(x, "b l r (h p) -> b l r h p", p=self.headdim)
|
| 4002 |
|
|
|
|
| 4017 |
|
| 4018 |
z = rearrange(z, "b l r (h p) -> b l r h p", p=self.headdim)
|
| 4019 |
|
| 4020 |
+
y = mamba_mimo_chunk_scan_discretized_fused_combined(
|
| 4021 |
x=x.bfloat16(),
|
| 4022 |
A=A.bfloat16(),
|
| 4023 |
B=B.bfloat16(),
|
|
|
|
| 4031 |
)
|
| 4032 |
|
| 4033 |
y = rearrange(y, "b l r h p -> b l r (h p)")
|
| 4034 |
+
#if seqlen_og is not None:
|
| 4035 |
+
# y = rearrange(y, "b l r d -> (b l) r d")
|
| 4036 |
|
| 4037 |
# Perform MIMO down projection (mimo_rank*d_inner -> d_inner)
|
| 4038 |
y = rearrange(y, "b l r d -> b l (r d)")
|
| 4039 |
y = rearrange(y, "b l (g d) -> b l g d", g=self.mimo_dim*self.mimo_proj_block_order)
|
| 4040 |
y = torch.einsum("blgd,drg->bldr", y, self.out_proj_mimo)
|
| 4041 |
y = rearrange(y, "b l d r -> b l (d r)")
|
| 4042 |
+
y = rearrange(y, "b l (h d) -> b l h d", d=self.headdim)
|
| 4043 |
|
| 4044 |
return y, None, None
|
| 4045 |
|
| 4046 |
class DragonMLP(nn.Module):
|
| 4047 |
+
def __init__(self, config: DragonConfig, intermediate_size: Optional[int] = None):
|
| 4048 |
super().__init__()
|
| 4049 |
self.config = config
|
| 4050 |
+
intermediate_size = intermediate_size or config.intermediate_size
|
| 4051 |
#print("previous MLP : ", PREVIOUS_MLP)
|
| 4052 |
self.link_size = 16
|
| 4053 |
self.mlp_linking = config.mlp_linking and PREVIOUS_MLP is not None
|
| 4054 |
if self.mlp_linking:
|
| 4055 |
self.previous_mlp = PREVIOUS_MLP
|
| 4056 |
+
self.fc_1 = DragonLinear(config, config.hidden_size, intermediate_size, bias=False)
|
| 4057 |
self.lambda1 = nn.Parameter(torch.zeros(self.link_size)) # sigmoid->0.5
|
| 4058 |
else :
|
| 4059 |
+
self.fc_1 = DragonLinear(config, config.hidden_size, intermediate_size, bias=False)
|
| 4060 |
+
self.fc_2 = DragonLinear(config, intermediate_size, config.hidden_size, bias=False)
|
| 4061 |
self.register_buffer("_2_sqrt_5", torch.tensor(2/math.sqrt(5)) if config.use_uscaling else torch.tensor(1.), persistent=False)
|
| 4062 |
|
| 4063 |
def forward(self, hidden_states):
|
|
|
|
| 4075 |
return hidden_states
|
| 4076 |
|
| 4077 |
def get_mlp_link(self):
|
| 4078 |
+
mlp_link = self.mlp_link
|
| 4079 |
+
self.mlp_link = None
|
| 4080 |
+
return mlp_link
|
| 4081 |
+
|
| 4082 |
+
class DragonGatedMLP(nn.Module):
|
| 4083 |
+
def __init__(self, config: DragonConfig, intermediate_size: Optional[int] = None, num_active_experts: int = 1):
|
| 4084 |
+
super().__init__()
|
| 4085 |
+
self.config = config
|
| 4086 |
+
self.intermediate_size = intermediate_size
|
| 4087 |
+
|
| 4088 |
+
self.fc_1 = DragonLinear(config, config.hidden_size, num_active_experts*self.intermediate_size, bias=False)
|
| 4089 |
+
self.fc_2 = DragonLinear(config, num_active_experts*self.intermediate_size, config.hidden_size, bias=False)
|
| 4090 |
+
self.register_buffer("_2_sqrt_5", torch.tensor(2/math.sqrt(5)) if config.use_uscaling else torch.tensor(1.), persistent=False)
|
| 4091 |
+
|
| 4092 |
+
def forward(self, hidden_states, gates):
|
| 4093 |
+
B, L, D = hidden_states.size()
|
| 4094 |
+
hidden_states = self.fc_1(hidden_states) # (B, L, E*D)
|
| 4095 |
+
hidden_states = self._2_sqrt_5 * F.relu(hidden_states).square().view(B, L, -1, self.intermediate_size) # (B, L, E, D)
|
| 4096 |
+
hidden_states = hidden_states * gates.unsqueeze(-1) # (B, L, E, D)
|
| 4097 |
+
hidden_states = self.fc_2(hidden_states.view(B, L, -1)) # (B, L, D)
|
| 4098 |
+
return hidden_states
|
| 4099 |
+
|
| 4100 |
+
class DragonMoE(nn.Module):
|
| 4101 |
+
def __init__(self, config: DragonConfig):
|
| 4102 |
+
super().__init__()
|
| 4103 |
+
self.config = config
|
| 4104 |
+
self.num_experts = config.moe_num_routed_experts
|
| 4105 |
+
self.routed_scaling_factor = config.moe_routed_scaling_factor
|
| 4106 |
+
|
| 4107 |
+
self.router = DragonLinear(config, config.hidden_size, self.num_experts, bias=False, dtype=torch.float32)
|
| 4108 |
+
self.experts = DragonGatedMLP(config, config.moe_routed_intermediate_size, self.num_experts)
|
| 4109 |
+
if config.moe_shared_intermediate_size > 0:
|
| 4110 |
+
self.shared_expert = DragonMLP(config, config.moe_shared_intermediate_size)
|
| 4111 |
+
|
| 4112 |
+
def forward(self, hidden_states):
|
| 4113 |
+
# compute gating score.
|
| 4114 |
+
weights = F.sigmoid(self.router(hidden_states.to(torch.float32))) # (B, L, experts)
|
| 4115 |
+
weights = weights / weights.sum(dim=-1, keepdim=True) # (B, L, experts)
|
| 4116 |
+
weights = (weights * self.routed_scaling_factor).to(hidden_states.dtype)
|
| 4117 |
+
# forward through (routed) experts.
|
| 4118 |
+
y = self.experts(hidden_states, weights) # (B, L, E, D)
|
| 4119 |
+
# forward through shared expert.
|
| 4120 |
+
if self.config.moe_shared_intermediate_size > 0:
|
| 4121 |
+
y = y + self.shared_expert(hidden_states)
|
| 4122 |
+
return y
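# Usage sketch (hypothetical shapes): all routed experts run densely and are only reweighted by
# the sigmoid router -- there is no top-k selection or token dropping.
#   moe = DragonMoE(config)                        # config.moe_num_routed_experts routed experts
#   x   = torch.randn(2, 16, config.hidden_size)   # (B, L, D)
#   y   = moe(x)                                   # (B, L, D), plus shared expert if moe_shared_intermediate_size > 0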
|
| 4123 |
|
| 4124 |
PREVIOUS_MLP = None
|
| 4125 |
class DragonMonoBlock(GradientCheckpointingLayer):
|
|
|
|
| 4194 |
head_dim = self.mixer.headdim
|
| 4195 |
num_attention_heads = self.mixer.nheads
|
| 4196 |
use_gate = config.gate_gdn
|
| 4197 |
+
elif layer_type == '2':
|
| 4198 |
+
self.mixer = DragonMamba2(config, layer_idx=layer_idx)
|
| 4199 |
+
head_dim = self.mixer.headdim
|
| 4200 |
+
num_attention_heads = self.mixer.nheads
|
| 4201 |
+
use_gate = config.gate_gdn
|
| 4202 |
+
elif layer_type == 'M':
|
| 4203 |
+
self.mixer = DragonMamba3Mimo(config, layer_idx=layer_idx)
|
| 4204 |
+
head_dim = self.mixer.headdim
|
| 4205 |
+
num_attention_heads = self.mixer.nheads
|
| 4206 |
+
use_gate = False # inside Mamba3Mimo
|
| 4207 |
else:
|
| 4208 |
raise ValueError(f"Unknown layer type: {layer_type}")
|
| 4209 |
|
|
|
|
| 4248 |
|
| 4249 |
self.input_norm = DragonNorm(config, config.hidden_size)
|
| 4250 |
self.postmixer_norm = DragonNorm(config, config.hidden_size)
|
| 4251 |
+
if not config.moe:
|
| 4252 |
+
self.mlp = DragonMLP(config)
|
| 4253 |
+
else:
|
| 4254 |
+
self.mlp = DragonMoE(config)
|
| 4255 |
global PREVIOUS_MLP
|
| 4256 |
PREVIOUS_MLP = self.mlp
|
| 4257 |
|
|
|
|
| 4267 |
cache_position: Optional[torch.LongTensor] = None,
|
| 4268 |
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
| 4269 |
key_value_last_layer: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 4270 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
| 4271 |
+
max_seqlen: Optional[int] = None,
|
| 4272 |
**kwargs,
|
| 4273 |
):
|
| 4274 |
# MIXER.
|
|
|
|
| 4280 |
position_ids=position_ids,
|
| 4281 |
cache_params=cache_params,
|
| 4282 |
key_value_last_layer=key_value_last_layer,
|
| 4283 |
+
cu_seqlens=cu_seqlens,
|
| 4284 |
+
max_seqlen=max_seqlen,
|
| 4285 |
) # (B, L, E*D)
|
| 4286 |
if self.use_gate:
|
| 4287 |
if self.config.gate_type == "elementwise" or self.config.gate_type == "kimi":
|
|
|
|
| 4459 |
self.embedding = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 4460 |
self.layers = nn.ModuleList([DragonBlock(config, layer_idx=i, layer_type=layer) if layer in ['l', 'r', 'd'] else DragonMonoBlock(config, layer_idx=i, layer_type=layer) for i, layer in enumerate(config.layers_config)])
|
| 4461 |
|
| 4462 |
+
if self.config.rope_type_global != '' or self.config.rope_type_local != '':
|
| 4463 |
+
self.rotary_emb = DragonRotaryEmbedding(config, head_dim=config.head_dim if config.head_dim else (config.expand_factor*config.hidden_size)//config.num_attention_heads, theta=config.rope_theta_local) # only for SWA
|
| 4464 |
+
else:
|
| 4465 |
+
self.rotary_emb = None
|
| 4466 |
+
|
| 4467 |
+
if self.config.final_norm:
|
| 4468 |
+
self.final_norm = DragonNorm(config, config.hidden_size)
|
| 4469 |
|
| 4470 |
self.gradient_checkpointing = False
|
| 4471 |
self.post_init()
|
|
|
|
| 4486 |
cache_position: Optional[torch.LongTensor] = None,
|
| 4487 |
output_hidden_states: Optional[bool] = None,
|
| 4488 |
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 4489 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
| 4490 |
+
max_seqlen: Optional[int] = None,
|
| 4491 |
**kwargs
|
| 4492 |
) -> DragonOutput:
|
| 4493 |
B, L = input_ids.shape if input_ids is not None else inputs_embeds.shape[:2]
|
|
|
|
| 4531 |
|
| 4532 |
all_hidden_states = () if output_hidden_states else None
|
| 4533 |
|
| 4534 |
+
if self.rotary_emb is not None:
|
| 4535 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 4536 |
+
else:
|
| 4537 |
+
position_embeddings = None
|
| 4538 |
|
| 4539 |
shared_kv = (None, None)
|
| 4540 |
for block in self.layers:
|
|
|
|
| 4548 |
cache_position=cache_position,
|
| 4549 |
position_embeddings=position_embeddings,
|
| 4550 |
key_value_last_layer=shared_kv,
|
| 4551 |
+
cu_seqlens=cu_seqlens,
|
| 4552 |
+
max_seqlen=max_seqlen,
|
| 4553 |
**kwargs,
|
| 4554 |
)
|
| 4555 |
shared_kv = (last_k, last_v)
|
| 4556 |
|
| 4557 |
+
if self.config.final_norm:
|
| 4558 |
+
hidden_states = self.final_norm(hidden_states)
|
| 4559 |
|
| 4560 |
if output_hidden_states:
|
| 4561 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
|
|
| 4588 |
cache_position: Optional[torch.Tensor] = None,
|
| 4589 |
output_hidden_states: Optional[bool] = None,
|
| 4590 |
attention_mask: Optional[torch.Tensor] = None,
|
| 4591 |
+
just_loss: bool = False,
|
| 4592 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
| 4593 |
+
max_seqlen: Optional[int] = None,
|
| 4594 |
token_type_ids=None,
|
| 4595 |
**kwargs,
|
| 4596 |
) -> DragonCausalLMOutput:
|
|
|
|
| 4605 |
cache_position=cache_position,
|
| 4606 |
inputs_embeds=inputs_embeds,
|
| 4607 |
output_hidden_states=output_hidden_states,
|
| 4608 |
+
cu_seqlens=cu_seqlens,
|
| 4609 |
+
max_seqlen=max_seqlen,
|
| 4610 |
**kwargs,
|
| 4611 |
)
|
| 4612 |
|
|
|
|
| 4650 |
|
| 4651 |
return DragonCausalLMOutput(
|
| 4652 |
loss=loss,
|
| 4653 |
+
logits=logits if not just_loss else None,
|
| 4654 |
+
past_key_values=outputs.past_key_values if not just_loss else None,
|
| 4655 |
+
hidden_states=outputs.hidden_states if not just_loss else None,
|
| 4656 |
)
|
| 4657 |
DragonForCausalLM.register_for_auto_class("AutoModelForCausalLM")
|
| 4658 |
|
training_dragon.py
CHANGED
|
@@ -35,8 +35,8 @@ class NanoArgs:
    head_dim: Optional[int] = None
    layers_config : str = 4*"lrdlr"
    expand_factor : int = 2 # expand factor for Mamba/Dragon
-   rope_type_local: str = "
-   rope_type_global: str = "
    rope_theta_local: float = 10000.0
    rope_theta_global: float = 0.0
    eps_rmsnorm: float = 1e-6
@@ -54,8 +54,18 @@ class NanoArgs:
    scalar_proj_as_hidden_matrix: bool = True
    normalization_type: str = "rmsnorm" # rmsnorm, seednorm
    seednorm_wd: bool = True
    mixer_gn: bool = True
    mlp_linking : bool = False

    # attention related
    n_kv_heads : int = 0
@@ -93,6 +103,14 @@ class NanoArgs:
    shrink_qk_gdn: int = 2
    kda_allow_neg_eigval: bool = False
    kda_num_v_heads: Optional[int] = None

    # optim
    optim: str = "adamw" # adamw, spam, stable-spam, muon, muon_moonlight, splus
@@ -120,7 +138,9 @@ class NanoArgs:
    # data
    vocab_size: int = 50304
    sequence_length: int = 1024
    input_bin: Optional[str] = None
    input_val_bin: Optional[str] = None
@@ -138,6 +158,7 @@ class NanoArgs:
    load_optim: bool = True
    load_sched: bool = True
    compile: bool = True

    # used during training
    slw_window: int = 0
@@ -166,9 +187,11 @@ def _load_data_shard(filename):
    return tokens

class DistributedDataLoader:
-   def __init__(self, filename_pattern, B, T, process_rank, num_processes):
        self.process_rank = process_rank
        self.num_processes = num_processes
        self.B = B # micro batch size
        self.T = T
@@ -221,12 +244,32 @@ class DistributedDataLoader:
        x = torch.from_numpy(buf.reshape(B, T)) # inputs
        y = torch.from_numpy(buf.reshape(B, T)) # targets
        # advance current position and load next shard if necessary
        self.current_position += B * T * self.num_processes
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            self.advance()
-       return x.cuda(), y.cuda()

def param_groups_mup(model, base_lr_hidden, base_lr_scalar, base_lr_embed, base_lr_head, wd):
    groups, seen = [], set()
@@ -277,6 +320,11 @@ def param_groups_mup(model, base_lr_hidden, base_lr_scalar, base_lr_embed, base_
args = tyro.cli(NanoArgs)

# set up DDP (distributed data parallel).
assert torch.cuda.is_available()
dist.init_process_group(
@@ -293,6 +341,8 @@ torch.cuda.set_device(device)
print(f"using device: {device}")
master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc.
torch._dynamo.config.optimize_ddp=False

# setup logging.
resume_dir = None
@@ -363,16 +413,33 @@ if args.patch_level_training:
assert args.batch_size % (B * ddp_world_size) == 0
accumulation_steps = args.batch_size // (B * ddp_world_size)

# load dataloaders.
#if args.patch_level_training:
#    assert T % args.patch_level_training_size == 0, "sequence length must be divisible by patch level training size in reduced mode"
-train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size)
-val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size)
print0(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
print0(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")

# load model.
config_hf = DragonConfig(
    mla_kv_rank=args.mla_kv_rank,
    rope_gdn=args.rope_gdn,
    shrink_qk_da=args.shrink_qk_da,
@@ -402,6 +469,8 @@ config_hf = DragonConfig(
    zero_centered_gate=args.zero_centered_gate,
    zero_centered_gate_type=args.zero_centered_gate_type,
    scalable_softmax=args.scalable_softmax,
    resformer=args.resformer,
    gate_type=args.gate_type,
    gate_act=args.gate_act,
@@ -461,7 +530,7 @@ with torch.no_grad():
# count params. (total & active)
num_params = sum(p.numel() for p in model.parameters())
"""model.eval()
-x, y = train_loader.next_batch()
with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
    model(input_ids=x[[0], [0]].unsqueeze(0)).logits.sum().backward()
num_active = sum(p.grad.count_nonzero() for p in model.parameters() if p.grad is not None)
@@ -472,12 +541,16 @@ print0(f"number of total parameters: {num_params}")
# DDP & compile.
uncompiled_model = model
-model = torch.compile(model, dynamic=
model.train()
model = DDP(model, device_ids=[ddp_local_rank], find_unused_parameters=args.resformer)
raw_model = model.module
ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)

# load optimizers & schedulers.
if args.use_uscaling:
    #assert args.optim == "adamw", "uscaling is only supported with AdamW optimizer currently"
@@ -553,9 +626,7 @@ WARMUP_SKIP = 10
# begin training.
train_loader.reset()
-
-tokenizer = transformers.AutoTokenizer.from_pretrained("/leonardo_work/BOOST_LCustodi/script/training/temp/hf_models/gpt2", use_fast=True)
-x, y = train_loader.next_batch()

for iter_ in range(start_iter, start_iter+args.total_iterations+1):
    last_iter = (iter_ == start_iter+args.total_iterations)
@@ -588,9 +659,9 @@ for iter_ in range(start_iter, start_iter+args.total_iterations+1):
    val_loss = torch.zeros((), device=device, dtype=torch.float32)
    for _ in range(args.val_iterations):
        for _ in range(accumulation_steps):
-           inputs, targets = val_loader.next_batch()
            with ctx:
-               val_loss += model(input_ids=inputs, labels=targets).loss.detach()
    val_loss /= args.val_iterations * accumulation_steps
    dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
    val_loss = val_loss.item()
@@ -641,10 +712,10 @@ for iter_ in range(start_iter, start_iter+args.total_iterations+1):
    for i in range(1, accumulation_steps+1):
        # forward pass.
        with ctx:
-           loss = model(input_ids=x, labels=y).loss
        train_loss = loss.detach()
        # prepare next batch.
-       x, y = train_loader.next_batch()
        # backward pass.
        if i < accumulation_steps:
            with model.no_sync():
|
|
| 35 |
head_dim: Optional[int] = None
|
| 36 |
layers_config : str = 4*"lrdlr"
|
| 37 |
expand_factor : int = 2 # expand factor for Mamba/Dragon
|
| 38 |
+
rope_type_local: str = "" #p-rope
|
| 39 |
+
rope_type_global: str = "" #p-rope
|
| 40 |
rope_theta_local: float = 10000.0
|
| 41 |
rope_theta_global: float = 0.0
|
| 42 |
eps_rmsnorm: float = 1e-6
|
|
|
|
| 54 |
scalar_proj_as_hidden_matrix: bool = True
|
| 55 |
normalization_type: str = "rmsnorm" # rmsnorm, seednorm
|
| 56 |
seednorm_wd: bool = True
|
| 57 |
+
seednorm_type: int = 1
|
| 58 |
+
seednorm_rank: int = 1
|
| 59 |
mixer_gn: bool = True
|
| 60 |
mlp_linking : bool = False
|
| 61 |
+
final_norm: bool = True
|
| 62 |
+
|
| 63 |
+
# MoE
|
| 64 |
+
moe: bool = False
|
| 65 |
+
moe_num_routed_experts: int = 2
|
| 66 |
+
moe_routed_scaling_factor: float = 2.5
|
| 67 |
+
moe_routed_intermediate_size: int = 768
|
| 68 |
+
moe_shared_intermediate_size: int = 768
|
| 69 |
|
| 70 |
# attention related
|
| 71 |
n_kv_heads : int = 0
|
|
|
|
| 103 |
shrink_qk_gdn: int = 2
|
| 104 |
kda_allow_neg_eigval: bool = False
|
| 105 |
kda_num_v_heads: Optional[int] = None
|
| 106 |
+
mamba_mimo_dim: Optional[int] = 2
|
| 107 |
+
mamba_ngroups: Optional[int] = 1
|
| 108 |
+
mamba3_rope: bool = True
|
| 109 |
+
mamba3_remove_BC_bias: bool = False
|
| 110 |
+
mamba3_is_id_rms: bool = True
|
| 111 |
+
mamba3_remove_conv: bool = True
|
| 112 |
+
mamba3_is_A_dd: bool = True
|
| 113 |
+
mamba3_add_trapezoid: bool = True
|
| 114 |
|
| 115 |
# optim
|
| 116 |
optim: str = "adamw" # adamw, spam, stable-spam, muon, muon_moonlight, splus
|
|
|
|
| 138 |
|
| 139 |
# data
|
| 140 |
vocab_size: int = 50304
|
| 141 |
+
bos_id: int = 50256
|
| 142 |
sequence_length: int = 1024
|
| 143 |
+
intra_doc_masking: bool = False
|
| 144 |
input_bin: Optional[str] = None
|
| 145 |
input_val_bin: Optional[str] = None
|
| 146 |
|
|
|
|
| 158 |
load_optim: bool = True
|
| 159 |
load_sched: bool = True
|
| 160 |
compile: bool = True
|
| 161 |
+
compile_dynamic: bool = False
|
| 162 |
|
| 163 |
# used during training
|
| 164 |
slw_window: int = 0
|
|
|
|
| 187 |
return tokens
|
| 188 |
|
| 189 |
class DistributedDataLoader:
|
| 190 |
+
def __init__(self, filename_pattern, intra_doc_masking, B, T, process_rank, num_processes, bos_id):
|
| 191 |
self.process_rank = process_rank
|
| 192 |
self.num_processes = num_processes
|
| 193 |
+
self.intra_doc_masking = intra_doc_masking
|
| 194 |
+
self.bos_id = bos_id
|
| 195 |
self.B = B # micro batch size
|
| 196 |
self.T = T
|
| 197 |
|
|
|
|
| 244 |
x = torch.from_numpy(buf.reshape(B, T)) # inputs
|
| 245 |
y = torch.from_numpy(buf.reshape(B, T)) # targets
|
| 246 |
|
| 247 |
+
# compute cumulative document positions for intra-document masking
|
| 248 |
+
cu = None
|
| 249 |
+
maxlen = None
|
| 250 |
+
position_ids = None
|
| 251 |
+
if self.intra_doc_masking:
|
| 252 |
+
assert self.B == 1
|
| 253 |
+
starts = (x == self.bos_id).nonzero(as_tuple=True)[1].to(torch.long)
|
| 254 |
+
if starts.numel() == 0 or starts[0] != 0:
|
| 255 |
+
starts = torch.cat([torch.zeros(1, dtype=torch.long), starts])
|
| 256 |
+
ends = torch.cat([starts[1:], torch.tensor([x.numel()])])
|
| 257 |
+
seqlens = (ends - starts).to(torch.int32)
|
| 258 |
+
# cu_seqlens, max_seqlen.
|
| 259 |
+
cu = torch.cat([torch.zeros(1, dtype=torch.int32), seqlens.cumsum(0)]).cuda().to(torch.int32)
|
| 260 |
+
maxlen = int(seqlens.max())
|
| 261 |
+
# position_ids.
|
| 262 |
+
lengths = seqlens.to(torch.long)
|
| 263 |
+
starts_per_token = torch.repeat_interleave(starts.to(torch.long), lengths)
|
| 264 |
+
idx = torch.arange(T, device=x.device, dtype=torch.long)
|
| 265 |
+
position_ids = (idx - starts_per_token).unsqueeze(0)
|
| 266 |
+
|
| 267 |
# advance current position and load next shard if necessary
|
| 268 |
self.current_position += B * T * self.num_processes
|
| 269 |
if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
|
| 270 |
self.advance()
|
| 271 |
|
| 272 |
+
return x.cuda(), y.cuda(), cu, maxlen, position_ids
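# Toy illustration (hypothetical tokens) of the metadata returned above when intra_doc_masking
# is on and bos_id marks the start of each packed document:
#   x            = [BOS, a, b, BOS, c, d, e]    (B=1, T=7)
#   starts       = [0, 3]        seqlens = [3, 4]
#   cu_seqlens   = [0, 3, 7]     max_seqlen = 4
#   position_ids = [[0, 1, 2, 0, 1, 2, 3]]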
|
| 273 |
|
| 274 |
def param_groups_mup(model, base_lr_hidden, base_lr_scalar, base_lr_embed, base_lr_head, wd):
|
| 275 |
groups, seen = [], set()
|
|
|
|
| 320 |
|
| 321 |
args = tyro.cli(NanoArgs)
|
| 322 |
|
| 323 |
+
if args.intra_doc_masking:
|
| 324 |
+
if args.device_batch_size != 1:
|
| 325 |
+
args.device_batch_size = 1
|
| 326 |
+
print("!!! Forcing device_batch_size to 1 for intra-document masking !!!")
|
| 327 |
+
|
| 328 |
# set up DDP (distributed data parallel).
|
| 329 |
assert torch.cuda.is_available()
|
| 330 |
dist.init_process_group(
|
|
|
|
| 341 |
print(f"using device: {device}")
|
| 342 |
master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc.
|
| 343 |
torch._dynamo.config.optimize_ddp=False
|
| 344 |
+
if args.compile_dynamic:
|
| 345 |
+
torch._dynamo.config.allow_unspec_int_on_nn_module=True
|
| 346 |
|
| 347 |
# setup logging.
|
| 348 |
resume_dir = None
|
|
|
|
| 413 |
assert args.batch_size % (B * ddp_world_size) == 0
|
| 414 |
accumulation_steps = args.batch_size // (B * ddp_world_size)
|
| 415 |
|
| 416 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained("/leonardo_work/BOOST_LCustodi/script/training/temp/hf_models/gpt2", use_fast=True)
|
| 417 |
+
|
| 418 |
# load dataloaders.
|
| 419 |
#if args.patch_level_training:
|
| 420 |
# assert T % args.patch_level_training_size == 0, "sequence length must be divisible by patch level training size in reduced mode"
|
| 421 |
+
train_loader = DistributedDataLoader(args.input_bin, args.intra_doc_masking, B, T, ddp_rank, ddp_world_size, args.bos_id)
|
| 422 |
+
val_loader = DistributedDataLoader(args.input_val_bin, args.intra_doc_masking, B, T, ddp_rank, ddp_world_size, args.bos_id)
|
| 423 |
print0(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
|
| 424 |
print0(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
|
| 425 |
|
| 426 |
# load model.
|
| 427 |
config_hf = DragonConfig(
|
| 428 |
+
mamba3_rope=args.mamba3_rope,
|
| 429 |
+
mamba3_remove_BC_bias=args.mamba3_remove_BC_bias,
|
| 430 |
+
mamba3_is_id_rms=args.mamba3_is_id_rms,
|
| 431 |
+
mamba3_remove_conv=args.mamba3_remove_conv,
|
| 432 |
+
mamba3_is_A_dd=args.mamba3_is_A_dd,
|
| 433 |
+
mamba3_add_trapezoid=args.mamba3_add_trapezoid,
|
| 434 |
+
moe=args.moe,
|
| 435 |
+
moe_num_routed_experts=args.moe_num_routed_experts,
|
| 436 |
+
moe_routed_scaling_factor=args.moe_routed_scaling_factor,
|
| 437 |
+
moe_routed_intermediate_size=args.moe_routed_intermediate_size,
|
| 438 |
+
moe_shared_intermediate_size=args.moe_shared_intermediate_size,
|
| 439 |
+
intra_doc_masking=args.intra_doc_masking,
|
| 440 |
+
seednorm_rank=args.seednorm_rank,
|
| 441 |
+
seednorm_type=args.seednorm_type,
|
| 442 |
+
final_norm=args.final_norm,
|
| 443 |
mla_kv_rank=args.mla_kv_rank,
|
| 444 |
rope_gdn=args.rope_gdn,
|
| 445 |
shrink_qk_da=args.shrink_qk_da,
|
|
|
|
| 469 |
zero_centered_gate=args.zero_centered_gate,
|
| 470 |
zero_centered_gate_type=args.zero_centered_gate_type,
|
| 471 |
scalable_softmax=args.scalable_softmax,
|
| 472 |
+
mamba_mimo_dim=args.mamba_mimo_dim,
|
| 473 |
+
mamba_ngroups=args.mamba_ngroups,
|
| 474 |
resformer=args.resformer,
|
| 475 |
gate_type=args.gate_type,
|
| 476 |
gate_act=args.gate_act,
|
|
|
|
| 530 |
# count params. (total & active)
|
| 531 |
num_params = sum(p.numel() for p in model.parameters())
|
| 532 |
"""model.eval()
|
| 533 |
+
x, y, _, _, _ = train_loader.next_batch()
|
| 534 |
with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
|
| 535 |
model(input_ids=x[[0], [0]].unsqueeze(0)).logits.sum().backward()
|
| 536 |
num_active = sum(p.grad.count_nonzero() for p in model.parameters() if p.grad is not None)
|
|
|
|
| 541 |
|
| 542 |
# DDP & compile.
|
| 543 |
uncompiled_model = model
|
| 544 |
+
model = torch.compile(model, dynamic=args.compile_dynamic) if args.compile else model
|
| 545 |
model.train()
|
| 546 |
model = DDP(model, device_ids=[ddp_local_rank], find_unused_parameters=args.resformer)
|
| 547 |
raw_model = model.module
|
| 548 |
ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)
|
| 549 |
|
| 550 |
+
if args.intra_doc_masking:
|
| 551 |
+
print0("!!! Using intra-document masking !!!")
|
| 552 |
+
print0("It is only compatible with GDN (conv+chunk), DA and GDTPA layers. For DA/GDTPA, kv shift is also compatible. All other config will not have intra-doc masking support!!")
|
| 553 |
+
|
| 554 |
# load optimizers & schedulers.
|
| 555 |
if args.use_uscaling:
|
| 556 |
#assert args.optim == "adamw", "uscaling is only supported with AdamW optimizer currently"
|
|
|
|
| 626 |
|
| 627 |
# begin training.
|
| 628 |
train_loader.reset()
|
| 629 |
+
x, y, cu, maxlen, position_ids = train_loader.next_batch()
|
|
|
|
|
|
|
| 630 |
|
| 631 |
for iter_ in range(start_iter, start_iter+args.total_iterations+1):
|
| 632 |
last_iter = (iter_ == start_iter+args.total_iterations)
|
|
|
|
| 659 |
val_loss = torch.zeros((), device=device, dtype=torch.float32)
|
| 660 |
for _ in range(args.val_iterations):
|
| 661 |
for _ in range(accumulation_steps):
|
| 662 |
+
inputs, targets, val_cu, val_maxlen, val_position_ids = val_loader.next_batch()  # separate names so the prefetched train-batch cu/maxlen/position_ids are not clobbered
|
| 663 |
with ctx:
|
| 664 |
+
val_loss += model(input_ids=inputs, labels=targets, just_loss=True, cu_seqlens=val_cu, max_seqlen=val_maxlen, position_ids=val_position_ids).loss.detach()
|
| 665 |
val_loss /= args.val_iterations * accumulation_steps
|
| 666 |
dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
|
| 667 |
val_loss = val_loss.item()
|
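# Note on the averaging above: each rank accumulates val_iterations * accumulation_steps
# micro-batch losses, divides by that count, and the all_reduce with ReduceOp.AVG then
# averages the per-rank means across all data-parallel ranks.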
|
|
|
| 712 |
for i in range(1, accumulation_steps+1):
|
| 713 |
# forward pass.
|
| 714 |
with ctx:
|
| 715 |
+
loss = model(input_ids=x, labels=y, just_loss=True, cu_seqlens=cu, max_seqlen=maxlen, position_ids=position_ids).loss
|
| 716 |
train_loss = loss.detach()
|
| 717 |
# prepare next batch.
|
| 718 |
+
x, y, cu, maxlen, position_ids = train_loader.next_batch()
|
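# Note: x, y and cu/maxlen/position_ids are refetched together from the same next_batch()
# call, so the intra-document metadata stays paired with the tokens it was computed from;
# fetching here also lets the host prepare the next batch while the backward pass runs.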
| 719 |
# backward pass.
|
| 720 |
if i < accumulation_steps:
|
| 721 |
with model.no_sync():
|