alexandretl committed
Commit 58b82e2 · 1 Parent(s): bc8288b

mamba3 flags | mamba3 default state size to 128, headdim to 64 | mamba2 | fix mamba3 mimo (JG) | (fake) moe | intra doc maskiiiing (with SS) | seednorm tests | coord checks

configuration_dragon.py CHANGED
@@ -92,6 +92,21 @@ class DragonConfig(PretrainedConfig):
 
     def __init__(
         self,
+        mamba3_rope: bool = True,
+        mamba3_remove_BC_bias: bool = False,
+        mamba3_is_id_rms: bool = True,
+        mamba3_remove_conv: bool = True,
+        mamba3_is_A_dd: bool = True,
+        mamba3_add_trapezoid: bool = True,
+        moe: bool = False,
+        moe_num_routed_experts: int = 2,
+        moe_routed_scaling_factor: float = 2.5,
+        moe_routed_intermediate_size: int = 768,
+        moe_shared_intermediate_size: int = 768,
+        intra_doc_masking: bool = False,
+        seednorm_rank: int = 1,
+        seednorm_type: int = 1,
+        final_norm: bool = True,
         mla_kv_rank: int = 128,
         shrink_qk_da: int = 2,
         shrink_qk_gdn: int = 2,
@@ -119,6 +134,7 @@ class DragonConfig(PretrainedConfig):
         scalable_softmax: bool = True,
         resformer: bool = False,
         mamba_mimo_dim : int = 4,
+        mamba_ngroups : int = 1,
         gate_type: str = "elementwise",
         gate_act: str = "silu",
         gate_attn: bool = False,
@@ -163,7 +179,7 @@ class DragonConfig(PretrainedConfig):
         rope_type_local="rope",
         rope_type_global="",
         rope_theta_local=163.,
-        rope_theta_global=10000.,
+        rope_theta_global=0.,
         uscaling_tau=0.2,
         attention_dropout=0.,
         hidden_dropout=0.,
@@ -176,6 +192,21 @@ class DragonConfig(PretrainedConfig):
         mlp_linking=False,
         **kwargs,
     ):
+        self.mamba3_rope = mamba3_rope
+        self.mamba3_remove_BC_bias = mamba3_remove_BC_bias
+        self.mamba3_is_id_rms = mamba3_is_id_rms
+        self.mamba3_remove_conv = mamba3_remove_conv
+        self.mamba3_is_A_dd = mamba3_is_A_dd
+        self.mamba3_add_trapezoid = mamba3_add_trapezoid
+        self.moe = moe
+        self.moe_num_routed_experts = moe_num_routed_experts
+        self.moe_routed_scaling_factor = moe_routed_scaling_factor
+        self.moe_routed_intermediate_size = moe_routed_intermediate_size
+        self.moe_shared_intermediate_size = moe_shared_intermediate_size
+        self.intra_doc_masking = intra_doc_masking
+        self.seednorm_rank = seednorm_rank
+        self.seednorm_type = seednorm_type
+        self.final_norm = final_norm
         self.mla_kv_rank = mla_kv_rank
         self.shrink_qk_da = shrink_qk_da
         self.shrink_qk_gdn = shrink_qk_gdn
@@ -228,6 +259,7 @@ class DragonConfig(PretrainedConfig):
         self.scalable_softmax = scalable_softmax
         self.resformer = resformer
         self.mamba_mimo_dim = mamba_mimo_dim
+        self.mamba_ngroups = mamba_ngroups
 
         self.vocab_size = vocab_size
         self.tie_word_embeddings = tie_word_embeddings
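
For reference, a minimal sketch of exercising the new options when building a config. The values simply restate the defaults introduced in the hunks above; the flat import path is an assumption (in this repo the file is configuration_dragon.py).

from configuration_dragon import DragonConfig  # import path assumed for a local checkout

# Illustrative only: the values below restate the defaults added in this commit.
config = DragonConfig(
    mamba3_rope=True,               # Mamba-3 ablation flags added above
    mamba3_remove_BC_bias=False,
    mamba3_remove_conv=True,
    mamba3_add_trapezoid=True,
    mamba_ngroups=1,                # new knob next to mamba_mimo_dim
    moe=False,                      # (fake) MoE path from the commit message
    intra_doc_masking=False,
    seednorm_type=1,                # selects DragonSeeDNorm / ...Type2/3/4 (see modeling diff)
    rope_theta_global=0.,           # default changed from 10000. to 0. in this commit
)
print(config.mamba3_rope, config.seednorm_type)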
coordcheck_utils.py ADDED
@@ -0,0 +1,472 @@
+# Copyright 2022 Microsoft Corporation.
+
+"""
+Adapted from https://github.com/microsoft/mup
+In short, it has been largely simplified.
+"""
+
+import os
+from copy import copy
+from itertools import product
+
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn.functional as F
+
+from tqdm import tqdm
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+FDICT = {'l1': lambda x: torch.abs(x).mean(dtype=torch.float32)}
+
+def convert_fdict(d):
+    '''convert a dict `d` with string values to function values.
+    Input:
+        d: a dict whose values are either strings or functions
+    Output:
+        a new dict, with the same keys as `d`, but the string values are
+        converted to functions using `FDICT`.
+    '''
+    return dict([
+        ((k, FDICT[v]) if isinstance(v, str) else (k, v))
+        for k, v in d.items()])
+
+def _record_coords(records, width, modulename, t,
+                   output_fdict=None, input_fdict=None, param_fdict=None):
+    '''Returns a forward hook that records coordinate statistics.
+
+    Returns a forward hook that records statistics regarding the output, input,
+    and/or parameters of a `nn.Module`. This hook is intended to run only once,
+    on the timestep specified by `t`.
+
+    On forward pass, the returned hook calculates statistics specified in
+    `output_fdict`, `input_fdict`, and `param_fdict`, such as the normalized l1
+    norm, of output, input, and/or parameters of the module. The statistics are
+    recorded along with the `width`, `modulename`, and `t` (the time step) as a
+    dict and inserted into `records` (which should be a list). More precisely,
+    for each output, input, and/or parameter, the inserted dict is of the form
+
+    {
+        'width': width, 'module': modified_modulename, 't': t,
+        # keys are keys in fdict
+        'l1': 0.241, 'l2': 0.420, 'mean': 0.0, ...
+    }
+
+    where `modified_modulename` is a string that combines the `modulename` with
+    an indicator of which output, input, or parameter tensor is the statistics
+    computed over.
+
+    The `*_fdict` inputs should be dictionaries with string keys and whose
+    values can either be functions or strings. The string values are converted
+    to functions via `convert_fdict`. The default values of `*_dict` inputs are
+    converted to `output_fdict = dict(l1=FDICT['l1'])`, `input_fdict = {}`,
+    `param_fdict = {}`, i.e., only the average coordinate size (`l1`) of the
+    output activations are recorded.
+
+    Inputs:
+        records:
+            list to append coordinate data to
+        width:
+            width of the model. This is used only for plotting coord check later
+            on, so it can be any notion of width.
+        modulename:
+            string name of the module. This is used only for plotting coord check.
+        t:
+            timestep of training. This is used only for plotting coord check.
+        output_fdict, input_fdict, param_fdict:
+            dicts with string keys and whose values can either be functions or
+            strings. The string values are converted to functions via
+            `convert_fdict`
+    Output:
+        a forward hook that records statistics regarding the output, input,
+        and/or parameters of a `nn.Module`, as discussed above.
+    '''
+    if output_fdict is None:
+        output_fdict = dict(l1=FDICT['l1'])
+    else:
+        output_fdict = convert_fdict(output_fdict)
+    if input_fdict is None:
+        input_fdict = {}
+    else:
+        input_fdict = convert_fdict(input_fdict)
+    if param_fdict is None:
+        param_fdict = {}
+    else:
+        param_fdict = convert_fdict(param_fdict)
+    def f(module, input, output):
+        def get_stat(d, x, fdict):
+            if isinstance(x, (tuple, list)):
+                for i, _x in enumerate(x):
+                    _d = copy(d)
+                    _d['module'] += f'[{i}]'
+                    get_stat(_d, _x, fdict)
+            elif isinstance(x, dict):
+                for name, _x in x.items():
+                    _d = copy(d)
+                    _d['module'] += f'[{name}]'
+                    get_stat(_d, _x, fdict)
+            elif isinstance(x, torch.Tensor):
+                _d = copy(d)
+                for fname, f in fdict.items():
+                    _d[fname] = f(x).item()
+                records.append(_d)
+            elif x is None:
+                pass
+            else:
+                raise NotImplementedError(f'Unexpected output type: {type(x)}')
+        with torch.no_grad():
+            ret = {
+                'width': width,
+                'module': modulename,
+                't': t
+            }
+
+            # output stats
+            if isinstance(output, (tuple, list)):
+                for i, out in enumerate(output):
+                    _ret = copy(ret)
+                    _ret['module'] += f':out[{i}]'
+                    get_stat(_ret, out, output_fdict)
+            elif isinstance(output, dict):
+                for name, out in output.items():
+                    _ret = copy(ret)
+                    _ret['module'] += f':out[{name}]'
+                    get_stat(_ret, out, output_fdict)
+            elif isinstance(output, torch.Tensor):
+                _ret = copy(ret)
+                for fname, f in output_fdict.items():
+                    _ret[fname] = f(output).item()
+                records.append(_ret)
+            else:
+                raise NotImplementedError(f'Unexpected output type: {type(output)}')
+
+            # input stats
+            if input_fdict:
+                if isinstance(input, (tuple, list)):
+                    for i, out in enumerate(input):
+                        _ret = copy(ret)
+                        _ret['module'] += f':in[{i}]'
+                        get_stat(_ret, out, input_fdict)
+                elif isinstance(input, dict):
+                    for name, out in input.items():
+                        _ret = copy(ret)
+                        _ret['module'] += f':in[{name}]'
+                        get_stat(_ret, out, input_fdict)
+                elif isinstance(input, torch.Tensor):
+                    _ret = copy(ret)
+                    for fname, f in input_fdict.items():
+                        _ret[fname] = f(input).item()
+                    records.append(_ret)
+                else:
+                    raise NotImplementedError(f'Unexpected output type: {type(input)}')
+
+            # param stats
+            if param_fdict:
+                for name, p in module.named_parameters():
+                    _ret = copy(ret)
+                    _ret['module'] += f':param[{name}]'
+                    for fname, f in param_fdict.items():
+                        _ret[fname] = f(p).item()
+                    records.append(_ret)
+
+    return f
+
+def _get_coord_data(models, dataloader, optcls, nsteps=5,
+                    dict_in_out=False, flatten_input=False, flatten_output=False,
+                    output_name='loss', lossfn='xent', filter_module_by_name=None,
+                    fix_data=True, cuda=True, nseeds=1,
+                    output_fdict=None, input_fdict=None, param_fdict=None,
+                    show_progress=True, one_hot_target=False):
+    '''Inner method for `get_coord_data`.
+
+    Train the models in `models` with optimizer given by `optcls` and data from
+    `dataloader` for `nsteps` steps, and record coordinate statistics specified
+    by `output_fdict`, `input_fdict`, `param_fdict`. By default, only `l1` is
+    computed for output activations of each module.
+
+    Inputs:
+        models:
+            a dict of lazy models, where the keys are numbers indicating width.
+            Each entry of `models` is a function that instantiates a model given
+            nothing.
+        dataloader:
+            an iterator whose elements are either Huggingface style dicts, if
+            `dict_in_out` is True, or (input, label). If `fix_data` is True
+            (which is the default), then only the first element of `dataloader`
+            is used in a loop and the rest of `dataloader` is ignored.
+        optcls:
+            a function so that `optcls(model)` gives an optimizer used to train
+            the model.
+        nsteps:
+            number of steps to train the model
+        dict_in_out:
+            whether the data loader contains Huggingface-style dict input and
+            output. Default: False
+        flatten_input:
+            if not `dict_in_out`, reshape the input to be
+            `input.view(input.shape[0], -1)`. Typically used for testing MLPs.
+        flatten_output:
+            if not `dict_in_out`, reshape the label to be `label.view(-1,
+            input.shape[-1])`.
+        output_name:
+            if `dict_in_out`, this is the key for the loss value if the output
+            is a dict. If the output is not a dict, then we assume the first
+            element of the output is the loss.
+        lossfn:
+            loss function to use if not `dict_in_out`. Can be either a string from
+            ['xent', 'mse', 'nll', 'l1'] or a python `callable` such that
+            `lossfn(output, target)` returns the loss value. Examples of valid
+            `callable`s are `F.cross_entropy`, `F.mse_loss`, etc, where `F` is
+            `torch.nn.functional`. Default: 'xent'
+        filter_module_by_name:
+            a function that returns a bool given module names (from
+            `model.named_modules()`), or None. If not None, then only modules
+            whose name yields True will be recorded.
+        cuda:
+            whether to use cuda or not. Default: True
+        nseeds:
+            number of times to repeat the training, each with different seeds.
+        output_fdict, input_fdict, param_fdict:
+            function dicts to be used in `_record_coords`. By default, only `l1`
+            is computed for output activations of each module.
+        show_progress:
+            show progress using tqdm. Default: True
+        one_hot_target:
+            convert target label into a one-hot vector. This typically is only
+            used for `'mse'` or `'l1'` losses in classification tasks.
+            Default: False
+    Output:
+        a pandas DataFrame containing recorded results. The column names are
+        `'width', 'module', 't'` as well as names of statistics recorded, such
+        as `'l1'` (see `FDICT` for other premade statistics that can be
+        collected).
+
+    Breaking Changes:
+        In v1.0.0, when `lossfn=='mse'`, the target is automatically converted
+        to a one hot vector before loss computation. Starting in v1.1.0, this
+        behavior is turned off, and the user needs to explicitly turn on this
+        behavior by setting `one_hot_target=True`.
+
+    '''
+    df = []
+    if fix_data:
+        batch = next(iter(dataloader))
+        dataloader = [batch] * nsteps
+    if show_progress:
+        pbar = tqdm(total=nseeds * len(models))
+
+    for i in range(nseeds):
+        torch.manual_seed(i)
+        for width, model in models.items():
+            model = model()
+            model = model.train()
+            if cuda:
+                model = model.cuda()
+            optimizer = optcls(model)
+            for batch_idx, batch in enumerate(dataloader, 1):
+                remove_hooks = []
+                # add hooks
+                for name, module in model.named_modules():
+                    if filter_module_by_name and not filter_module_by_name(name):
+                        continue
+                    remove_hooks.append(module.register_forward_hook(
+                        _record_coords(df, width, name, batch_idx,
+                            output_fdict=output_fdict,
+                            input_fdict=input_fdict,
+                            param_fdict=param_fdict)))
+                if dict_in_out:
+                    (data, target) = batch
+                    loss = model(input_ids=data, labels=target).loss
+                else:
+                    assert False, "Not implemented for non-dict input/output."
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+                # remove hooks
+                for handle in remove_hooks:
+                    handle.remove()
+
+                if batch_idx == nsteps: break
+            if show_progress:
+                pbar.update(1)
+    if show_progress:
+        pbar.close()
+    return pd.DataFrame(df)
+
+
+def get_coord_data(models, dataloader, optcls, nsteps, **kwargs):
+    '''Get coord data for coord check.
+
+    Train the models in `models` with data from `dataloader` and optimizer
+    specified by `optimizer` and `lr` for `nsteps` steps, and record coordinate
+    statistics specified by `output_fdict`, `input_fdict`, `param_fdict`. By
+    default, only `l1` is computed for output activations of each module.
+
+    This function wraps around `_get_coord_data`, with the main difference being
+    user can specify common optimizers via a more convenient interface.
+
+    Inputs:
+        models:
+            a dict of lazy models, where the keys are numbers indicating width.
+            Each entry of `models` is a function that instantiates a model given
+            nothing.
+        dataloader:
+            an iterator whose elements are either Huggingface style dicts, if
+            `dict_in_out` is True, or (input, label). If `fix_data` is True
+            (which is the default), then only the first element of `dataloader`
+            is used in a loop and the rest of `dataloader` is ignored.
+        optimizer:
+            a string in `['sgd', 'adam', 'adamw']`, with default being `'sgd'`.
+        lr:
+            learning rate. By default is 0.1 for `'sgd'` and 1e-3 for others.
+        mup:
+            If True, then use the optimizer from `mup.optim`; otherwise, use the
+            one from `torch.optim`.
+        filter_trainable_by_name:
+            a function that returns a bool given module names (from
+            `model.named_modules()`), or None. If not None, then only modules
+            whose name yields True will be trained.
+        nsteps:
+            number of steps to train the model
+        dict_in_out:
+            whether the data loader contains Huggingface-style dict input and
+            output. Default: False
+        flatten_input:
+            if not `dict_in_out`, reshape the input to be
+            `input.view(input.shape[0], -1)`. Typically used for testing MLPs.
+        flatten_output:
+            if not `dict_in_out`, reshape the label to be `label.view(-1,
+            input.shape[-1])`.
+        output_name:
+            if `dict_in_out`, this is the key for the loss value if the output
+            is a dict. If the output is not a dict, then we assume the first
+            element of the output is the loss.
+        lossfn:
+            loss function to use if not `dict_in_out`. Can be either a string from
+            ['xent', 'mse', 'nll', 'l1'] or a python `callable` such that
+            `lossfn(output, target)` returns the loss value. Examples of valid
+            `callable`s are `F.cross_entropy`, `F.mse_loss`, etc, where `F` is
+            `torch.nn.functional`. Default: 'xent'
+        filter_module_by_name:
+            a function that returns a bool given module names (from
+            `model.named_modules()`), or None. If not None, then only modules
+            whose name yields True will be recorded.
+        cuda:
+            whether to use cuda or not. Default: True
+        nseeds:
+            number of times to repeat the training, each with different seeds.
+        output_fdict, input_fdict, param_fdict:
+            function dicts to be used in `_record_coords`. By default, only `l1`
+            is computed for output activations of each module.
+        show_progress:
+            show progress using tqdm. Default: True
+        one_hot_target:
+            convert target label into a one-hot vector. This typically is only
+            used for `'mse'` or `'l1'` losses in classification tasks.
+            Default: False
+    Output:
+        a pandas DataFrame containing recorded results. The column names are
+        `'width', 'module', 't'` as well as names of statistics recorded, such
+        as `'l1'` (see `FDICT` for other premade statistics that can be
+        collected).
+
+    Breaking Changes:
+        In v1.0.0, when `lossfn=='mse'`, the target is automatically converted
+        to a one hot vector before loss computation. Starting in v1.1.0, this
+        behavior is turned off, and the user needs to explicitly turn on this
+        behavior by setting `one_hot_target=True`.
+    '''
+
+    data = _get_coord_data(models, dataloader, optcls, nsteps, dict_in_out=True, **kwargs)
+    return data
+
+
+def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='module',
+                    legend='full', name_contains=None, name_not_contains=None, module_list=None,
+                    loglog=True, logbase=2, face_color=None, subplot_width=5,
+                    subplot_height=4):
+    '''Plot coord check data `df` obtained from `get_coord_data`.
+
+    Input:
+        df:
+            a pandas DataFrame obtained from `get_coord_data`
+        y:
+            the column of `df` to plot on the y-axis. Default: `'l1'`
+        save_to:
+            path to save the resulting figure, or None. Default: None.
+        suptitle:
+            The title of the entire figure.
+        x:
+            the column of `df` to plot on the x-axis. Default: `'width'`
+        hue:
+            the column of `df` to represent as color. Default: `'module'`
+        legend:
+            'auto', 'brief', 'full', or False. This is passed to `seaborn.lineplot`.
+        name_contains, name_not_contains:
+            only plot modules whose name contains `name_contains` and does not contain `name_not_contains`
+        module_list:
+            only plot modules that are given in the list, overrides `name_contains` and `name_not_contains`
+        loglog:
+            whether to use loglog scale. Default: True
+        logbase:
+            the log base, if using loglog scale. Default: 2
+        face_color:
+            background color of the plot. Default: None (which means white)
+        subplot_width, subplot_height:
+            The width and height for each timestep's subplot. More precisely,
+            the figure size will be
+            `(subplot_width*number_of_time_steps, subplot_height)`.
+            Default: 5, 4
+
+    Output:
+        the `matplotlib` figure object
+    '''
+    ### preprocessing
+    df = copy(df)
+    df = df[df.module != '']  # nn.Sequential has name '', which duplicates the output layer
+    if module_list is not None:
+        df = df[df['module'].isin(module_list)]
+    else:
+        if name_contains is not None:
+            df = df[df['module'].str.contains(name_contains)]
+        if name_not_contains is not None:
+            df = df[~(df['module'].str.contains(name_not_contains))]
+    try:
+        df['module'] = pd.to_numeric(df['module'])  # for nn.Sequential, module names are numerical
+    except ValueError:
+        pass
+
+    ts = df.t.unique()
+
+    sns.set()
+
+    def tight_layout(plt):
+        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
+
+    ### plot
+    fig = plt.figure(figsize=(subplot_width * len(ts), subplot_height))
+    hue_order = sorted(set(df['module']))
+    if face_color is not None:
+        fig.patch.set_facecolor(face_color)
+    ymin, ymax = min(df[y]), max(df[y])
+    for t in ts:
+        t = int(t)
+        plt.subplot(1, len(ts), t)
+        sns.lineplot(x=x, y=y, data=df[df.t == t], hue=hue, hue_order=hue_order, legend=None)  # to show legend, set legend if t == 1 else None
+        plt.title(f't={t}')
+        if t != 1:
+            plt.ylabel('')
+        if loglog:
+            plt.loglog(base=logbase)
+        ax = plt.gca()
+        ax.set_ylim([ymin, ymax])
+    if suptitle:
+        plt.suptitle(suptitle)
+    tight_layout(plt)
+    if save_to is not None:
+        plt.savefig(save_to)
+        print(f'coord check plot saved to {save_to}')
+
+    return fig
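
The only statistic registered by default is `l1`: the mean absolute coordinate of a tensor, computed in float32. A tiny self-contained check of that definition (toy tensor, not taken from any training run):

import torch

l1 = lambda x: torch.abs(x).mean(dtype=torch.float32)  # same definition as FDICT['l1'] above
x = torch.tensor([[1.0, -3.0], [0.0, 2.0]])
print(l1(x))  # tensor(1.5000) = (1 + 3 + 0 + 2) / 4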
coordchecking_dragon.py ADDED
@@ -0,0 +1,154 @@
+from dataclasses import dataclass
+import tyro
+from pathlib import Path
+
+import math
+import torch
+import torch.nn as nn
+from torch.utils.data import Dataset, DataLoader
+
+from .configuration_dragon import DragonConfig
+from .modeling_dragon import DragonForCausalLM
+from .coordcheck_utils import get_coord_data, plot_coord_data
+
+# TRITON_HOME="/p/project1/jureap140/temp" python make_coord_check.py
+
+@dataclass
+class Args:
+    save_dir: Path
+    mup: bool = False
+    learning_rate: float = 1e-2
+    layers_config: str = "gggTgggTgggTggg"
+args = tyro.cli(Args)
+
+batch_size = 8
+batch_len = 1024
+max_value = 100
+
+widths = [128, 512, 1024, 2048]
+n_heads = [4, 8, 16, 32]
+d_head = 64
+
+class RandomDataset(Dataset):
+    def __len__(self):
+        return 9999999
+
+    def __getitem__(self, _):
+        data = torch.randint(low=0, high=max_value, size=(batch_size, batch_len))
+        return data.cuda(), data.cuda()
+
+def lazy_model(width):
+    config_hf = DragonConfig(
+        layers_config=args.layers_config,
+        hidden_size=width,
+        intermediate_size=4*width,
+        tpa_rank=4,
+        token_shift_attn=True,
+        head_dim=d_head,
+        shrink_qk_da=1,
+        num_attention_heads=n_heads[widths.index(width)],
+        num_signal_heads_diff=n_heads[widths.index(width)]-n_heads[widths.index(width)]//4,
+        num_key_value_heads=n_heads[widths.index(width)],
+        head_dim_gdn=d_head,
+        shrink_qk_gdn=2,
+        num_attention_heads_gdn=n_heads[widths.index(width)],
+        zero_centered_gate=True,
+        zero_centered_gate_type=4,
+        mamba_mimo_dim=4,
+        mamba_ngroups=1,
+        gate_attn=True,
+        zero_centered_gamma=True,
+        vocab_size=max_value,
+        max_position_embeddings=1024,
+        use_uscaling=True,
+        uscaling_tau=0.2,
+        initializer_range=1.,
+        use_cache=False,
+    )
+
+    if args.mup:
+        config_hf.use_uscaling = True
+        config_hf.initializer_range = 1.0
+    else:
+        config_hf.use_uscaling = False
+        config_hf.initializer_range = 0.006
+
+    return lambda: DragonForCausalLM(config_hf).to("cuda")
+
+def param_groups_mup(model, base_lr_hidden, base_lr_scalar, base_lr_embed, base_lr_head, wd):
+    groups, seen = [], set()
+    id2name = {id(p): n for n, p in model.named_parameters()}
+
+    for mod in model.modules():
+        if isinstance(mod, nn.Linear):
+            pname = id2name.get(id(mod.weight), "")
+            is_scalar = getattr(mod, "is_scalar_weight", False)
+            fan_in = mod.weight.shape[1]
+            scale = 1 / math.sqrt(fan_in)
+            if "lm_head" in pname:
+                lr_scaled = base_lr_head
+                wd_scaled = 0.0
+            elif is_scalar:
+                lr_scaled = base_lr_scalar
+                wd_scaled = 0.0
+            else:
+                lr_scaled = base_lr_hidden * scale
+                wd_scaled = wd / lr_scaled
+
+            groups.append({"params": [mod.weight], "lr": lr_scaled, "weight_decay": wd_scaled})
+            seen.add(mod.weight)
+
+            if mod.bias is not None:
+                groups.append({"params": [mod.bias], "lr": base_lr_scalar, "weight_decay": 0.0})
+                seen.add(mod.bias)
+
+    for p in model.parameters():
+        if p in seen:
+            continue
+        pname = id2name.get(id(p), "<unnamed>")
+
+        if "embedding" in pname:
+            #fan_out = p.shape[1] # nn.Embedding is transposed
+            #lr_scaled = base_lr / math.sqrt(fan_out) # u-muP
+            lr_scaled = base_lr_embed
+        else:
+            lr_scaled = base_lr_scalar
+
+        wd_scaled = 0.
+        if getattr(p, "requires_weight_decay", False):
+            wd_scaled = wd / lr_scaled
+
+        groups.append({"params": [p], "lr": lr_scaled, "weight_decay": wd_scaled})
+
+    return groups
+
+models = {width: lazy_model(width) for width in widths}
+
+dataset = RandomDataset()
+loader = DataLoader(dataset, batch_size=None, shuffle=True)
+iter_ = iter(loader)
+
+def get_optim(model):
+    if args.mup:
+        param_list = param_groups_mup(
+            model,
+            base_lr_hidden=args.learning_rate,
+            base_lr_scalar=2**-6,
+            base_lr_embed=2**-4,
+            base_lr_head=2**-6,
+            wd=0.,
+        )
+        optimizer = torch.optim.AdamW(param_list, betas=(0.9, 0.95), eps=1e-8)
+    else:
+        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=0., betas=(0.9, 0.95), eps=1e-8)
+    return optimizer
+optcls = lambda model: get_optim(model)
+
+df = get_coord_data(models, iter_, optcls, nsteps=10)
+
+if args.mup:
+    name = f"mup_{args.learning_rate}_{args.layers_config}.png"
+else:
+    name = f"sp_{args.learning_rate}_{args.layers_config}.png"
+
+plot_coord_data(df, legend="full", save_to=args.save_dir / name)
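
One design choice in `param_groups_mup` above is worth spelling out: hidden weights receive `weight_decay = wd / lr_scaled`. With AdamW's decoupled decay, the per-step shrinkage of a weight is `lr * weight_decay * w`, so dividing by the width-dependent learning rate keeps the effective decay at `wd * w` for every width (the script passes `wd=0.`, so this only matters when a nonzero decay is used). A toy check with made-up numbers:

import math

base_lr_hidden, wd = 1e-2, 0.1   # illustrative values only
for fan_in in (128, 2048):
    lr_scaled = base_lr_hidden / math.sqrt(fan_in)   # muP-style 1/sqrt(fan_in) scaling
    wd_scaled = wd / lr_scaled
    # AdamW decoupled decay per step: lr_scaled * wd_scaled == wd, independent of width
    print(fan_in, lr_scaled * wd_scaled)             # 0.1 for both widths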
inspecting_dragon.py CHANGED
@@ -19,9 +19,13 @@ class NanoArgs:
     # arch - general
     d_model : int = 768
     n_heads : int = 6 # head dim 128 suggested by @Grad62304977
+    head_dim: Optional[int] = None
     layers_config : str = 4*"lrdlr"
-    expand_factor : int = 1 # expand factor for Mamba/Dragon
+    expand_factor : int = 2 # expand factor for Mamba/Dragon
+    rope_type_local: str = "" #p-rope
+    rope_type_global: str = "" #p-rope
     rope_theta_local: float = 10000.0
+    rope_theta_global: float = 0.0
     eps_rmsnorm: float = 1e-6
     mlp_expand: int = 4 # expand factor for MLP
     fused_loss_computation : bool = True # whether to use fused linear + cross entropy loss
@@ -32,9 +36,14 @@ class NanoArgs:
     zero_centered_gate_type: int = 1 # 1, 2, 3, 4
     gate_attn: bool = False
     gate_gdn: bool = True
-    gate_type: str = "elementwise" # elementwise (one per dim), headwise (one per head)
+    gate_type: str = "elementwise" # elementwise (one per dim), headwise (one per head), kimi (lora)
     gate_act: str = "silu" # silu, sigmoid
     scalar_proj_as_hidden_matrix: bool = True
+    normalization_type: str = "rmsnorm" # rmsnorm, seednorm
+    seednorm_wd: bool = True
+    mixer_gn: bool = True
+    mlp_linking : bool = False
+    final_norm: bool = True
 
     # attention related
     n_kv_heads : int = 0
@@ -46,26 +55,38 @@ class NanoArgs:
     softcap_global_attn: float = 0.0
     qk_norm: bool = True
     scalable_softmax: bool = True
-    token_shift: bool = False
+    resformer : bool = False # Works only on f layers (DiffAttention)
+    token_shift_attn: bool = False
+    token_shift_gdn: bool = False
+    token_conv1d_attn: bool = False
+    token_conv1d_gdn: bool = True
     num_attention_heads_indexer: int = 8
    head_dim_indexer: int = 32
     dsa_q_lora_rank: int = 128
     dsa_topk: int = 512
-    cca_head_dim: int = 128
     cca_seq_kernel_size: int = 4
-    nsa_head_dim: int = 128
     nsa_topk: int = 16
     nsa_block_size: int = 64
     nsa_window_size: int = 512
+    num_signal_heads_diff: Optional[int] = None
+    tpa_rank: int = 2
+    shrink_qk_da: int = 2
+    mla_kv_rank: int = 128
 
     # GDN related
     rope_gdn: Optional[str] = None # None, rope, (srope)
+    head_dim_gdn: Optional[int] = None
     n_heads_gdn: int = 0
     n_kv_heads_gdn: int = 0
+    shrink_qk_gdn: int = 2
+    kda_allow_neg_eigval: bool = False
+    kda_num_v_heads: Optional[int] = None
+    mamba_mimo_dim: Optional[int] = 2
+    mamba_ngroups: Optional[int] = 1
 
     # optim
     optim: str = "adamw" # adamw, spam, stable-spam, muon, muon_moonlight, splus
-    second_order_optim : Optional[str] = None #Snoo
+    second_order_optim : Optional[str] = None # snoo
     batch_size: int = 8*64 # batch size, in sequences, across all devices
     device_batch_size: int = 64 # batch size, in sequences, per device
     total_iterations: int = 1000 # number of iterations to run
@@ -83,14 +104,13 @@ class NanoArgs:
     init_std: float = 0.006
     patch_level_training: bool = False
     patch_level_training_size: int = 4
-    patch_level_training_mode: str = "reduced" # reduced = ask L tokens, treat L//K. full = ask K*L tokens, treat L.
+    second_order_lr: float = 0.68
+    second_order_momentum: float = 0.37
+    second_order_interval: int = 25
 
     # data
     vocab_size: int = 50304
     sequence_length: int = 1024
-    use_patch_level_training: bool = False
-    patch_size: int = 4
-    patch_training_fraction: float = 0.67
     input_bin: Optional[str] = None
     input_val_bin: Optional[str] = None
 
@@ -116,21 +136,39 @@ args = tyro.cli(NanoArgs)
 
 # load model.
 config_hf = DragonConfig(
+    final_norm=args.final_norm,
+    mla_kv_rank=args.mla_kv_rank,
+    rope_gdn=args.rope_gdn,
+    shrink_qk_da=args.shrink_qk_da,
+    shrink_qk_gdn=args.shrink_qk_gdn,
+    mixer_gn=args.mixer_gn,
+    kda_allow_neg_eigval=args.kda_allow_neg_eigval,
+    kda_num_v_heads=args.kda_num_v_heads,
+    seednorm_wd=args.seednorm_wd,
+    normalization_type=args.normalization_type,
+    tpa_rank=args.tpa_rank,
+    num_signal_heads_diff=args.num_signal_heads_diff,
     scalar_proj_as_hidden_matrix=args.scalar_proj_as_hidden_matrix,
-    token_shift=args.token_shift,
+    token_shift_attn=args.token_shift_attn,
+    token_shift_gdn=args.token_shift_gdn,
+    token_conv1d_attn=args.token_conv1d_attn,
+    token_conv1d_gdn=args.token_conv1d_gdn,
     patch_level_training=args.patch_level_training,
     patch_level_training_size=args.patch_level_training_size,
-    nsa_head_dim=args.nsa_head_dim,
     nsa_topk=args.nsa_topk,
     nsa_block_size=args.nsa_block_size,
     nsa_window_size=args.nsa_window_size,
-    cca_head_dim=args.cca_head_dim,
     cca_seq_kernel_size=args.cca_seq_kernel_size,
+    head_dim=args.head_dim,
+    head_dim_gdn=args.head_dim_gdn,
    num_attention_heads_gdn=args.n_heads_gdn,
     num_key_value_heads_gdn=args.n_kv_heads_gdn,
     zero_centered_gate=args.zero_centered_gate,
     zero_centered_gate_type=args.zero_centered_gate_type,
     scalable_softmax=args.scalable_softmax,
+    mamba_mimo_dim=args.mamba_mimo_dim,
+    mamba_ngroups=args.mamba_ngroups,
+    resformer=args.resformer,
     gate_type=args.gate_type,
     gate_act=args.gate_act,
     gate_attn=args.gate_attn,
@@ -157,8 +195,12 @@ config_hf = DragonConfig(
     norm_epsilon=args.eps_rmsnorm,
     use_cache=False,
     sliding_window_size=args.swa_window_size,
+    rope_type_global=args.rope_type_global,
+    rope_type_local=args.rope_type_local,
+    rope_theta_global=args.rope_theta_global,
     rope_theta_local=args.rope_theta_local,
     uscaling_tau=args.uscaling_tau,
+    mlp_linking=args.mlp_linking
 )
 
 model = DragonForCausalLM(config_hf)
modeling_dragon.py CHANGED
@@ -19,11 +19,20 @@ from transformers.utils import ModelOutput, logging
19
 
20
  from fla.ops.nsa.parallel import parallel_nsa
21
 
 
 
 
 
 
22
  try:
23
  from dragon_mamba3_ops.siso_variant.ssd_combined_fused import mamba_chunk_scan_discretized_combined
 
24
  from dragon_mamba3_ops.angle_cumsum import angle_dt
25
  from dragon_mamba3_ops.rotary_mamba import rotary_qk
26
- except ImportError:
 
 
 
27
  mamba_chunk_scan_discretized_combined, angle_dt, rotary_qk = None, None, None
28
 
29
  try:
@@ -39,8 +48,9 @@ try:
39
  from fla.ops.kda import chunk_kda, fused_recurrent_kda
40
  from fla.ops.kda.gate import fused_kda_gate
41
  from fla.modules import FusedRMSNormGated, ShortConvolution
 
42
  except ImportError:
43
- chunk_kda, fused_recurrent_kda, fused_kda_gate = None, None, None
44
 
45
  from torch.compiler import disable
46
 
@@ -56,13 +66,14 @@ ATTN_IMPL = "eager"
56
  try:
57
  import flash_attn_interface # FA3
58
  flash_attn_func = flash_attn_interface.flash_attn_func
 
59
  _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
60
  if not _flash_supports_window_size:
61
  raise ImportError("flash_attn_func does not support window_size parameter. Please update to more recent flash_attn version")
62
  ATTN_IMPL = "fa3"
63
  except ImportError:
64
  try:
65
- from flash_attn import flash_attn_func # FA2
66
  ATTN_IMPL = "fa2"
67
  except ImportError:
68
  try:
@@ -123,7 +134,16 @@ class DragonNorm(nn.Module):
123
  if config.normalization_type == "rmsnorm":
124
  self.norm = DragonRMSNorm(hidden_size, eps=config.norm_epsilon, zero_centered_gamma=config.zero_centered_gamma)
125
  elif config.normalization_type == "seednorm":
126
- self.norm = DragonSeeDNorm(config, hidden_size, eps=config.norm_epsilon)
 
 
 
 
 
 
 
 
 
127
  else:
128
  raise ValueError(f"Unknown normalization_type: {config.normalization_type}")
129
 
@@ -159,6 +179,54 @@ class DragonSeeDNorm(nn.Module):
159
  dynamic_scale = rescale.unsqueeze(-1) * self.alpha # (B, L, D)
160
  return (dynamic_scale + self.gamma) * self.rms(hidden_states)
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  class DragonLayerNorm(nn.Module):
163
  def __init__(self, hidden_size, eps=1e-6): # TODO: ZCG ?
164
  super().__init__()
@@ -1696,6 +1764,8 @@ class DragonDifferentialAttention(nn.Module):
1696
  hidden_states: torch.Tensor,
1697
  position_ids: Optional[torch.LongTensor] = None,
1698
  cache_params: Optional[HybridDragonDynamicCache] = None,
 
 
1699
  **kwargs,
1700
  ):
1701
  _, q_len, _ = hidden_states.shape
@@ -1747,6 +1817,17 @@ class DragonDifferentialAttention(nn.Module):
1747
  k_prev = F.pad(key_states, (0, 0, 0, 0, 1, 0))[:, :-1] # (B, L, H, D)
1748
  v_prev = F.pad(value_states, (0, 0, 0, 0, 1, 0))[:, :-1] # (B, L, H, D)
1749
 
 
 
 
 
 
 
 
 
 
 
 
1750
  key_states = alpha_k * k_prev + (1 - alpha_k) * key_states
1751
  value_states = alpha_v * v_prev + (1 - alpha_v) * value_states
1752
 
@@ -1859,18 +1940,28 @@ class DragonDifferentialAttention(nn.Module):
1859
  elif DIFF_ATTN_IMPL == "fa2":
1860
  def diff_attention_interface(q, k, v, wsize, **kw):
1861
  if self.head_qk_dim == self.head_v_dim:
1862
- return flash_attn_func(q, k, v, window_size=(wsize, 0), **kw)
 
 
 
1863
  D = v.size(3)
1864
  v1 = v[:, :, :, :D//2]
1865
  v2 = v[:, :, :, D//2:]
1866
- o1 = flash_attn_func(q, k, v1, window_size=(wsize, 0), **kw)
1867
- o2 = flash_attn_func(q, k, v2, window_size=(wsize, 0), **kw)
 
 
 
 
1868
  o = torch.cat([o1, o2], dim=-1)
1869
  return o
1870
  elif DIFF_ATTN_IMPL == "fa3":
1871
  def diff_attention_interface(q, k, v, wsize, **kw):
1872
  if self.head_qk_dim == self.head_v_dim:
1873
- return flash_attn_func(q, k, v, window_size=(wsize, 0), **kw)[0]
 
 
 
1874
  D = v.size(3)
1875
  v1 = v[:, :, :, :D//2]
1876
  v2 = v[:, :, :, D//2:]
@@ -2350,6 +2441,8 @@ class DragonDifferentialTensorProductAttention(nn.Module):
2350
  hidden_states: torch.Tensor,
2351
  position_ids: Optional[torch.LongTensor] = None,
2352
  cache_params: Optional[HybridDragonDynamicCache] = None,
 
 
2353
  **kwargs,
2354
  ):
2355
  b, q_len, _ = hidden_states.shape
@@ -2398,6 +2491,17 @@ class DragonDifferentialTensorProductAttention(nn.Module):
2398
  k_prev = F.pad(key_states, (0, 0, 0, 0, 1, 0))[:, :-1] # (B, L, H, D)
2399
  v_prev = F.pad(value_states, (0, 0, 0, 0, 1, 0))[:, :-1] # (B, L, H, D)
2400
 
 
 
 
 
 
 
 
 
 
 
 
2401
  key_states = alpha_k * k_prev + (1 - alpha_k) * key_states
2402
  value_states = alpha_v * v_prev + (1 - alpha_v) * value_states
2403
 
@@ -2510,7 +2614,10 @@ class DragonDifferentialTensorProductAttention(nn.Module):
2510
  elif DIFF_ATTN_IMPL == "fa2":
2511
  def diff_attention_interface(q, k, v, wsize, **kw):
2512
  if self.head_qk_dim == self.head_v_dim:
2513
- return flash_attn_func(q, k, v, window_size=(wsize, 0), **kw)
 
 
 
2514
  D = v.size(3)
2515
  v1 = v[:, :, :, :D//2]
2516
  v2 = v[:, :, :, D//2:]
@@ -2521,7 +2628,10 @@ class DragonDifferentialTensorProductAttention(nn.Module):
2521
  elif DIFF_ATTN_IMPL == "fa3":
2522
  def diff_attention_interface(q, k, v, wsize, **kw):
2523
  if self.head_qk_dim == self.head_v_dim:
2524
- return flash_attn_func(q, k, v, window_size=(wsize, 0), **kw)[0]
 
 
 
2525
  D = v.size(3)
2526
  v1 = v[:, :, :, :D//2]
2527
  v2 = v[:, :, :, D//2:]
@@ -3102,6 +3212,7 @@ class DragonGatedDeltaNet(nn.Module):
3102
  hidden_states: torch.Tensor,
3103
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
3104
  cache_params: Optional[HybridDragonDynamicCache] = None,
 
3105
  **kwargs,
3106
  ):
3107
  _, q_len, _ = hidden_states.shape
@@ -3164,12 +3275,15 @@ class DragonGatedDeltaNet(nn.Module):
3164
  conv_cache = F.pad(mixed_qkv, (self.conv_size - mixed_qkv.shape[-1], 0))
3165
  cache_params.conv_caches[self.layer_idx] = conv_cache
3166
  if self.causal_conv1d_fn is not None:
 
 
 
3167
  mixed_qkv = self.causal_conv1d_fn(
3168
  x=mixed_qkv,
3169
  weight=self.qkv_conv1d.weight.squeeze(1),
3170
  bias=self.qkv_conv1d.bias,
3171
  activation='silu',
3172
- seq_idx=None,
3173
  )
3174
  else:
3175
  mixed_qkv = F.silu(self.qkv_conv1d(mixed_qkv)[:, :, :q_len])
@@ -3216,7 +3330,8 @@ class DragonGatedDeltaNet(nn.Module):
3216
  scale=None if not self.config.use_uscaling else 1/self.dk,
3217
  initial_state=None,
3218
  output_final_state=cache_params is not None,
3219
- use_qk_l2norm_in_kernel=True
 
3220
  ) # (B L H dv)
3221
  else:
3222
  o, ssm_cache = self.recurrent_gated_delta_rule(
@@ -3404,19 +3519,16 @@ class DragonMamba3(nn.Module):
3404
  )
3405
 
3406
  self.d_model = config.hidden_size
3407
- self.d_state = 64
3408
  self.conv_init = None
3409
  self.expand = 2
3410
- self.headdim = 128
3411
- self.ngroups = 20
3412
  self.activation = "swish"
3413
  self.bias = False
3414
- self.conv_bias = True
3415
  self.chunk_size = 128
3416
  self.A_floor = 1e-4
3417
  self.rope_fraction = 0.5
3418
- self.remove_conv = True
3419
- self.add_conv_activation = False
3420
  self.dt_min = 0.001
3421
  self.dt_max = 0.1
3422
  self.dt_init_floor = 1e-4
@@ -3432,13 +3544,24 @@ class DragonMamba3(nn.Module):
3432
  if self.split_tensor_size == 0:
3433
  return
3434
 
3435
- self.rope_proj = DragonLinear(config, self.d_model, self.num_rope_angles, bias=False)
 
3436
 
3437
  # Order: [x, B, C, dt]
3438
  d_in_proj = self.d_inner + 2 * self.d_state * self.ngroups + self.nheads
3439
 
3440
- self.A_proj = DragonLinear(config, self.d_model, self.nheads, bias=False, dtype=torch.float32)
3441
- self.trapezoid_proj = DragonLinear(config, self.d_model, self.nheads, bias=False)
 
 
 
 
 
 
 
 
 
 
3442
 
3443
  _dt = torch.exp(
3444
  torch.rand(self.nheads) * (math.log(self.dt_max) - math.log(self.dt_min))
@@ -3447,21 +3570,25 @@ class DragonMamba3(nn.Module):
3447
  _dt = torch.clamp(_dt, min=self.dt_init_floor)
3448
  _dt_bias = _dt + torch.log(-torch.expm1(-_dt))
3449
  self.dt_bias = nn.Parameter(_dt_bias, requires_grad=True)
 
3450
 
3451
  self.in_proj = DragonLinear(config, self.d_model, d_in_proj, bias=self.bias)
3452
 
3453
- self.B_bias = nn.Parameter(torch.ones((self.nheads, self.d_state)), requires_grad=True)
3454
- self.C_bias = nn.Parameter(torch.ones((self.nheads, self.d_state)), requires_grad=True)
 
 
3455
 
3456
- self.B_norm = DragonNorm(config, self.d_state)
3457
- self.C_norm = DragonNorm(config, self.d_state)
 
3458
 
3459
- if not self.remove_conv:
3460
  conv_dim = self.d_inner + 2 * self.d_state * self.ngroups
3461
  self.conv1d = nn.Conv1d(
3462
  in_channels=conv_dim,
3463
  out_channels=conv_dim,
3464
- bias=self.conv_bias,
3465
  kernel_size=4,
3466
  groups=conv_dim,
3467
  )
@@ -3473,8 +3600,14 @@ class DragonMamba3(nn.Module):
3473
 
3474
  # D "skip" parameter
3475
  self.D = nn.Parameter(torch.ones(self.nheads))
 
3476
 
3477
- def forward(self, hidden_states, **kwargs):
 
 
 
 
 
3478
  # Apply in_proj
3479
  xBCdt = self.in_proj(hidden_states)
3480
  xBC, dd_dt = torch.split(
@@ -3485,16 +3618,19 @@ class DragonMamba3(nn.Module):
3485
  ],
3486
  dim=-1)
3487
 
3488
- _A = -F.softplus((self.A_proj(hidden_states.to(torch.float32))).to(torch.float32)) # (B, L, N)
3489
- _A = torch.clamp(_A, max=-self.A_floor)
 
 
 
3490
  dt = F.softplus(dd_dt + self.dt_bias) # (B, L, N)
3491
 
3492
- if not self.remove_conv:
3493
  xBC = causal_conv1d_fn(
3494
  x=xBC.transpose(1, 2),
3495
  weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
3496
  bias=self.conv1d.bias,
3497
- activation=self.activation if self.add_conv_activation else None,
3498
  ).transpose(1, 2) # (B, L, self.d_inner + 2 * ngroups * d_state)
3499
 
3500
  x, B, C = torch.split(
@@ -3507,37 +3643,64 @@ class DragonMamba3(nn.Module):
3507
  B = rearrange(B, "b l (g n) -> b l g n", g=self.ngroups)
3508
  C = rearrange(C, "b l (g n) -> b l g n", g=self.ngroups)
3509
 
3510
- B = self.B_norm(B)
3511
- C = self.C_norm(C)
 
3512
 
3513
  if self.ngroups != self.nheads:
3514
  B = B.expand(-1, -1, self.nheads, -1) # (B, L, N, S)
3515
  C = C.expand(-1, -1, self.nheads, -1) # (B, L, N, S)
3516
 
3517
- angle = self.rope_proj(hidden_states) # (B, L, S)
3518
- angle = angle.unsqueeze(-2).expand(-1, -1, self.nheads, -1) # (B, L, G, S)
3519
- angle = angle_dt(angle, dt)
 
3520
 
3521
- C, B, CB_sum = rotary_qk(q=C, k=B, angle=angle, bias_q=self.C_bias, bias_k=self.B_bias, conjugate=False, inplace=False)
 
 
 
 
 
 
 
 
 
 
 
3522
 
3523
  x = rearrange(x, "b l (h p) -> b l h p", p=self.headdim)
3524
 
3525
  A = _A * dt
3526
  gating_factor = dt # B, L, N
3527
 
3528
- trap = F.sigmoid(self.trapezoid_proj(hidden_states)) # (B, L, N)
 
3529
 
3530
- alpha_arr = torch.exp(A)
3531
- beta_arr = (1-trap)*gating_factor*alpha_arr
3532
- gamma_arr = trap*gating_factor
3533
 
3534
- # roll alpha and beta to the left by 1
3535
- _alpha_arr = torch.roll(alpha_arr, shifts=-1, dims=1)
3536
- _beta_arr = torch.roll(beta_arr, shifts=-1, dims=1)
3537
 
3538
- x_scalar = (gamma_arr*_alpha_arr + _beta_arr).to(torch.bfloat16)
 
 
 
 
 
 
 
 
 
 
 
 
 
3539
 
3540
- y = mamba_chunk_scan_discretized_combined(
3541
  x=x.bfloat16(),
3542
  A=A,
3543
  B=B.bfloat16(),
@@ -3547,11 +3710,117 @@ class DragonMamba3(nn.Module):
3547
  gamma=gamma_arr,
3548
  CB_sum=CB_sum,
3549
  D=self.D,
3550
- z=None
 
 
3551
  )
3552
 
 
 
 
 
 
 
3553
  return y, None, None
3554
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3555
 
3556
  class DragonMamba3Mimo(nn.Module):
3557
  def __init__(self, config: DragonConfig, layer_idx: Optional[int]):
@@ -3570,7 +3839,7 @@ class DragonMamba3Mimo(nn.Module):
3570
  self.conv_init = None
3571
  self.expand = 2
3572
  self.headdim = 128
3573
- self.ngroups = 20
3574
  self.activation = "swish"
3575
  self.bias = False
3576
  self.conv_bias = True
@@ -3604,7 +3873,7 @@ class DragonMamba3Mimo(nn.Module):
3604
  # Order: [z, x, B, C, dt]
3605
  d_in_proj = 2 * self.d_inner + 2 * self.d_state * self.ngroups * self.mimo_dim + self.nheads
3606
 
3607
- self.A_proj = DragonLinear(config, self.d_model, self.nheads, bias=False) # dtype=float32
3608
  self.trapezoid_proj = DragonLinear(config, self.d_model, self.nheads, bias=False)
3609
 
3610
  _dt = torch.exp(
@@ -3618,9 +3887,9 @@ class DragonMamba3Mimo(nn.Module):
3618
 
3619
  self.in_proj = DragonLinear(config, self.d_model, d_in_proj, bias=self.bias)
3620
 
3621
- self.B_bias = nn.Parameter(torch.ones((self.nheads, self.d_state)), requires_grad=True)
3622
- self.C_bias = nn.Parameter(torch.ones((self.nheads, self.d_state)), requires_grad=True)
3623
-
3624
  self.B_norm = DragonNorm(config, self.d_state)
3625
  self.C_norm = DragonNorm(config, self.d_state)
3626
 
@@ -3655,9 +3924,9 @@ class DragonMamba3Mimo(nn.Module):
3655
 
3656
  def forward(self, hidden_states, **kwargs):
3657
  # Apply in_proj
3658
- xBCdt = self.in_proj(hidden_states)
3659
  z, xBC, dd_dt = torch.split(
3660
- xBCdt,
3661
  [
3662
  self.d_inner,
3663
  self.d_inner + 2 * self.d_state * self.ngroups * self.mimo_dim,
@@ -3719,14 +3988,15 @@ class DragonMamba3Mimo(nn.Module):
3719
  C = self.C_norm(C)
3720
 
3721
  if self.ngroups != self.nheads:
3722
- B = B.expand(-1, -1, self.nheads, -1) # (B, L, R, N, S)
3723
- C = C.expand(-1, -1, self.nheads, -1) # (B, L, R, N, S)
 
3724
 
3725
  angle = self.rope_proj(hidden_states) # (B, L, S)
3726
  angle = angle.unsqueeze(-2).expand(-1, -1, self.nheads, -1) # (B, L, G, S)
3727
  angle = angle_dt(angle, dt)
3728
 
3729
- C, B, CB_sum = rotary_qk(q=C, k=B, angle=angle, bias_q=self.C_bias, bias_k=self.B_bias, conjugate=False, inplace=False)
3730
 
3731
  x = rearrange(x, "b l r (h p) -> b l r h p", p=self.headdim)
3732
 
@@ -3747,7 +4017,7 @@ class DragonMamba3Mimo(nn.Module):
3747
 
3748
  z = rearrange(z, "b l r (h p) -> b l r h p", p=self.headdim)
3749
 
3750
- y = mamba_chunk_scan_discretized_combined(
3751
  x=x.bfloat16(),
3752
  A=A.bfloat16(),
3753
  B=B.bfloat16(),
@@ -3761,31 +4031,33 @@ class DragonMamba3Mimo(nn.Module):
3761
  )
3762
 
3763
  y = rearrange(y, "b l r h p -> b l r (h p)")
3764
- if seqlen_og is not None:
3765
- y = rearrange(y, "b l r d -> (b l) r d")
3766
 
3767
  # Perform MIMO down projection (mimo_rank*d_inner -> d_inner)
3768
  y = rearrange(y, "b l r d -> b l (r d)")
3769
  y = rearrange(y, "b l (g d) -> b l g d", g=self.mimo_dim*self.mimo_proj_block_order)
3770
  y = torch.einsum("blgd,drg->bldr", y, self.out_proj_mimo)
3771
  y = rearrange(y, "b l d r -> b l (d r)")
 
3772
 
3773
  return y, None, None
3774
 
3775
  class DragonMLP(nn.Module):
3776
- def __init__(self, config: DragonConfig):
3777
  super().__init__()
3778
  self.config = config
 
3779
  #print("previous MLP : ", PREVIOUS_MLP)
3780
  self.link_size = 16
3781
  self.mlp_linking = config.mlp_linking and PREVIOUS_MLP is not None
3782
  if self.mlp_linking:
3783
  self.previous_mlp = PREVIOUS_MLP
3784
- self.fc_1 = DragonLinear(config, config.hidden_size, config.intermediate_size, bias=False)
3785
  self.lambda1 = nn.Parameter(torch.zeros(self.link_size)) # sigmoid->0.5
3786
  else :
3787
- self.fc_1 = DragonLinear(config, config.hidden_size, config.intermediate_size, bias=False)
3788
- self.fc_2 = DragonLinear(config, config.intermediate_size, config.hidden_size, bias=False)
3789
  self.register_buffer("_2_sqrt_5", torch.tensor(2/math.sqrt(5)) if config.use_uscaling else torch.tensor(1.), persistent=False)
3790
 
3791
  def forward(self, hidden_states):
@@ -3803,7 +4075,51 @@ class DragonMLP(nn.Module):
3803
  return hidden_states
3804
 
3805
  def get_mlp_link(self):
3806
- return self.mlp_link
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3807
 
3808
  PREVIOUS_MLP = None
3809
  class DragonMonoBlock(GradientCheckpointingLayer):
@@ -3878,6 +4194,16 @@ class DragonMonoBlock(GradientCheckpointingLayer):
3878
  head_dim = self.mixer.headdim
3879
  num_attention_heads = self.mixer.nheads
3880
  use_gate = config.gate_gdn
 
 
 
 
 
 
 
 
 
 
3881
  else:
3882
  raise ValueError(f"Unknown layer type: {layer_type}")
3883
 
@@ -3922,7 +4248,10 @@ class DragonMonoBlock(GradientCheckpointingLayer):
3922
 
3923
  self.input_norm = DragonNorm(config, config.hidden_size)
3924
  self.postmixer_norm = DragonNorm(config, config.hidden_size)
3925
- self.mlp = DragonMLP(config)
 
 
 
3926
  global PREVIOUS_MLP
3927
  PREVIOUS_MLP = self.mlp
3928
 
@@ -3938,6 +4267,8 @@ class DragonMonoBlock(GradientCheckpointingLayer):
3938
  cache_position: Optional[torch.LongTensor] = None,
3939
  position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
3940
  key_value_last_layer: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
 
 
3941
  **kwargs,
3942
  ):
3943
  # MIXER.
@@ -3949,6 +4280,8 @@ class DragonMonoBlock(GradientCheckpointingLayer):
3949
  position_ids=position_ids,
3950
  cache_params=cache_params,
3951
  key_value_last_layer=key_value_last_layer,
 
 
3952
  ) # (B, L, E*D)
3953
  if self.use_gate:
3954
  if self.config.gate_type == "elementwise" or self.config.gate_type == "kimi":
@@ -4126,8 +4459,13 @@ class DragonModel(DragonPreTrainedModel):
4126
  self.embedding = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
4127
  self.layers = nn.ModuleList([DragonBlock(config, layer_idx=i, layer_type=layer) if layer in ['l', 'r', 'd'] else DragonMonoBlock(config, layer_idx=i, layer_type=layer) for i, layer in enumerate(config.layers_config)])
4128
 
4129
- self.rotary_emb = DragonRotaryEmbedding(config, head_dim=config.head_dim if config.head_dim else (config.expand_factor*config.hidden_size)//config.num_attention_heads, theta=config.rope_theta_local) # only for SWA
4130
- self.final_norm = DragonNorm(config, config.hidden_size)
 
 
 
 
 
4131
 
4132
  self.gradient_checkpointing = False
4133
  self.post_init()
@@ -4148,6 +4486,8 @@ class DragonModel(DragonPreTrainedModel):
4148
  cache_position: Optional[torch.LongTensor] = None,
4149
  output_hidden_states: Optional[bool] = None,
4150
  inputs_embeds: Optional[torch.FloatTensor] = None,
 
 
4151
  **kwargs
4152
  ) -> DragonOutput:
4153
  B, L = input_ids.shape if input_ids is not None else inputs_embeds.shape[:2]
@@ -4191,7 +4531,10 @@ class DragonModel(DragonPreTrainedModel):
4191
 
4192
  all_hidden_states = () if output_hidden_states else None
4193
 
4194
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
 
 
4195
 
4196
  shared_kv = (None, None)
4197
  for block in self.layers:
@@ -4205,11 +4548,14 @@ class DragonModel(DragonPreTrainedModel):
4205
  cache_position=cache_position,
4206
  position_embeddings=position_embeddings,
4207
  key_value_last_layer=shared_kv,
 
 
4208
  **kwargs,
4209
  )
4210
  shared_kv = (last_k, last_v)
4211
 
4212
- hidden_states = self.final_norm(hidden_states)
 
4213
 
4214
  if output_hidden_states:
4215
  all_hidden_states = all_hidden_states + (hidden_states,)
@@ -4242,6 +4588,9 @@ class DragonForCausalLM(DragonPreTrainedModel, GenerationMixin):
4242
  cache_position: Optional[torch.Tensor] = None,
4243
  output_hidden_states: Optional[bool] = None,
4244
  attention_mask: Optional[torch.Tensor] = None,
 
 
 
4245
  token_type_ids=None,
4246
  **kwargs,
4247
  ) -> DragonCausalLMOutput:
@@ -4256,6 +4605,8 @@ class DragonForCausalLM(DragonPreTrainedModel, GenerationMixin):
4256
  cache_position=cache_position,
4257
  inputs_embeds=inputs_embeds,
4258
  output_hidden_states=output_hidden_states,
 
 
4259
  **kwargs,
4260
  )
4261
 
@@ -4299,9 +4650,9 @@ class DragonForCausalLM(DragonPreTrainedModel, GenerationMixin):
4299
 
4300
  return DragonCausalLMOutput(
4301
  loss=loss,
4302
- logits=logits,
4303
- past_key_values=outputs.past_key_values,
4304
- hidden_states=outputs.hidden_states,
4305
  )
4306
  DragonForCausalLM.register_for_auto_class("AutoModelForCausalLM")
4307
 
 
19
 
20
  from fla.ops.nsa.parallel import parallel_nsa
21
 
22
+ try:
23
+ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
24
+ except ImportError:
25
+ mamba_chunk_scan_combined = None
26
+
27
  try:
28
  from dragon_mamba3_ops.siso_variant.ssd_combined_fused import mamba_chunk_scan_discretized_combined
29
+ from dragon_mamba3_ops.mimo_variant.ssd_mimo import mamba_chunk_scan_discretized_fused_combined as mamba_mimo_chunk_scan_discretized_fused_combined
30
  from dragon_mamba3_ops.angle_cumsum import angle_dt
31
  from dragon_mamba3_ops.rotary_mamba import rotary_qk
32
+ from dragon_mamba3_ops.rotary_mamba_mimo import rotary_qk as mimo_rotary_qk
33
+ except ImportError as exc:
34
+ print("Warning: No Mamba-3 found !")
35
+ print(exc)
36
  mamba_chunk_scan_discretized_combined, angle_dt, rotary_qk = None, None, None
37
 
38
  try:
 
48
  from fla.ops.kda import chunk_kda, fused_recurrent_kda
49
  from fla.ops.kda.gate import fused_kda_gate
50
  from fla.modules import FusedRMSNormGated, ShortConvolution
51
+ from fla.ops.utils import prepare_sequence_ids
52
  except ImportError:
53
+ chunk_kda, fused_recurrent_kda, fused_kda_gate, prepare_sequence_ids = None, None, None, None
54
 
55
  from torch.compiler import disable
56
 
 
66
  try:
67
  import flash_attn_interface # FA3
68
  flash_attn_func = flash_attn_interface.flash_attn_func
69
+ flash_attn_varlen_func = flash_attn_interface.flash_attn_varlen_func
70
  _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
71
  if not _flash_supports_window_size:
72
  raise ImportError("flash_attn_func does not support window_size parameter. Please update to more recent flash_attn version")
73
  ATTN_IMPL = "fa3"
74
  except ImportError:
75
  try:
76
+ from flash_attn import flash_attn_func, flash_attn_varlen_func # FA2
77
  ATTN_IMPL = "fa2"
78
  except ImportError:
79
  try:
 
134
  if config.normalization_type == "rmsnorm":
135
  self.norm = DragonRMSNorm(hidden_size, eps=config.norm_epsilon, zero_centered_gamma=config.zero_centered_gamma)
136
  elif config.normalization_type == "seednorm":
137
+ if config.seednorm_type == 1:
138
+ self.norm = DragonSeeDNorm(config, hidden_size, eps=config.norm_epsilon)
139
+ elif config.seednorm_type == 2:
140
+ self.norm = DragonSeeDNormType2(config, hidden_size, eps=config.norm_epsilon)
141
+ elif config.seednorm_type == 3:
142
+ self.norm = DragonSeeDNormType3(config, hidden_size, eps=config.norm_epsilon)
143
+ elif config.seednorm_type == 4:
144
+ self.norm = DragonSeeDNormType4(config, hidden_size, eps=config.norm_epsilon)
145
+ else:
146
+ raise ValueError(f"Unknown seednorm_type: {config.seednorm_type}")
147
  else:
148
  raise ValueError(f"Unknown normalization_type: {config.normalization_type}")
149
 
 
179
  dynamic_scale = rescale.unsqueeze(-1) * self.alpha # (B, L, D)
180
  return (dynamic_scale + self.gamma) * self.rms(hidden_states)
181
 
182
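+ # The variants below differ in how the dynamic rescale is built (description of the code that follows):
+ # Type2: scalar rescale tanh(beta(x)) of shape (B, L, 1), broadcast over channels through alpha and added to gamma.
+ # Type3: low-rank per-channel rescale tanh(beta(x)) with beta = (hidden_size -> seednorm_rank -> hidden_size), added to gamma.
+ # Type4: low-rank rescale passed through silu(. + 1.15) and used as a pure multiplicative gate (no gamma).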
+ class DragonSeeDNormType2(nn.Module):
183
+ def __init__(self, config: DragonConfig, hidden_size, eps=1e-6):
184
+ super().__init__()
185
+ self.hidden_size = hidden_size
186
+
187
+ self.beta = DragonLinear(config, hidden_size, 1, bias=False)
188
+ self.alpha = nn.Parameter(torch.ones(hidden_size) * 1.)
189
+ if config.seednorm_wd:
190
+ self.alpha.requires_weight_decay = True
191
+ self.gamma = nn.Parameter(torch.ones(hidden_size))
192
+ self.rms = nn.RMSNorm(hidden_size, eps=eps, elementwise_affine=False)
193
+
194
+ def forward(self, hidden_states):
195
+ rescale = F.tanh(self.beta(hidden_states)) # (B, L, 1)
196
+ dynamic_scale = rescale * self.alpha # (B, L, D)
197
+ return (dynamic_scale + self.gamma) * self.rms(hidden_states)
198
+
199
+ class DragonSeeDNormType3(nn.Module):
200
+ def __init__(self, config: DragonConfig, hidden_size, eps=1e-6):
201
+ super().__init__()
202
+ self.hidden_size = hidden_size
203
+
204
+ self.beta = nn.Sequential(
205
+ DragonLinear(config, hidden_size, config.seednorm_rank, bias=False),
206
+ DragonLinear(config, config.seednorm_rank, hidden_size, bias=False),
207
+ )
208
+ self.gamma = nn.Parameter(torch.ones(hidden_size))
209
+ self.rms = nn.RMSNorm(hidden_size, eps=eps, elementwise_affine=False)
210
+
211
+ def forward(self, hidden_states):
212
+ dynamic_rescale = F.tanh(self.beta(hidden_states)) # (B, L, D)
213
+ return (dynamic_rescale + self.gamma) * self.rms(hidden_states)
214
+
215
+ class DragonSeeDNormType4(nn.Module):
216
+ def __init__(self, config: DragonConfig, hidden_size, eps=1e-6):
217
+ super().__init__()
218
+ self.hidden_size = hidden_size
219
+
220
+ self.beta = nn.Sequential(
221
+ DragonLinear(config, hidden_size, config.seednorm_rank, bias=False),
222
+ DragonLinear(config, config.seednorm_rank, hidden_size, bias=False),
223
+ )
224
+ self.rms = nn.RMSNorm(hidden_size, eps=eps, elementwise_affine=False)
225
+
226
+ def forward(self, hidden_states):
227
+ dynamic_rescale = F.silu(self.beta(hidden_states) + 1.15) # (B, L, D)
228
+ return dynamic_rescale * self.rms(hidden_states)
229
+
230
  class DragonLayerNorm(nn.Module):
231
  def __init__(self, hidden_size, eps=1e-6): # TODO: ZCG ?
232
  super().__init__()
 
1764
  hidden_states: torch.Tensor,
1765
  position_ids: Optional[torch.LongTensor] = None,
1766
  cache_params: Optional[HybridDragonDynamicCache] = None,
1767
+ cu_seqlens: Optional[torch.Tensor] = None,
1768
+ max_seqlen: Optional[int] = None,
1769
  **kwargs,
1770
  ):
1771
  _, q_len, _ = hidden_states.shape
 
1817
  k_prev = F.pad(key_states, (0, 0, 0, 0, 1, 0))[:, :-1] # (B, L, H, D)
1818
  v_prev = F.pad(value_states, (0, 0, 0, 0, 1, 0))[:, :-1] # (B, L, H, D)
1819
 
1820
+ if position_ids is not None:
1821
+ # first token of each doc has pos==0
1822
+ doc_start = (position_ids == 0) # (B, L) bool
1823
+ m = doc_start.unsqueeze(-1).unsqueeze(-1) # (B, L, 1, 1) bool
1824
+
1825
+ # zero the previous contribution at boundaries
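+ # with alpha zeroed, the convex mix below reduces to the current token's k/v at document starts, so nothing leaks across documents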
1826
+ k_prev = k_prev.masked_fill(m, 0)
1827
+ v_prev = v_prev.masked_fill(m, 0)
1828
+ alpha_k = alpha_k.masked_fill(m, 0)
1829
+ alpha_v = alpha_v.masked_fill(m, 0)
1830
+
1831
  key_states = alpha_k * k_prev + (1 - alpha_k) * key_states
1832
  value_states = alpha_v * v_prev + (1 - alpha_v) * value_states
1833
 
 
1940
  elif DIFF_ATTN_IMPL == "fa2":
1941
  def diff_attention_interface(q, k, v, wsize, **kw):
1942
  if self.head_qk_dim == self.head_v_dim:
1943
+ if not self.config.intra_doc_masking:
1944
+ return flash_attn_func(q, k, v, window_size=(wsize, 0), **kw)
1945
+ else:
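+ # the varlen kernels take packed (total_tokens, heads, dim) tensors; intra-doc masking forces batch size 1, so q[0]/k[0]/v[0] are the packed sequences and the batch dim is re-added to the output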
1946
+ return flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, window_size=(wsize, 0), **kw).unsqueeze(0)
1947
  D = v.size(3)
1948
  v1 = v[:, :, :, :D//2]
1949
  v2 = v[:, :, :, D//2:]
1950
+ if not self.config.intra_doc_masking:
1951
+ o1 = flash_attn_func(q, k, v1, window_size=(wsize, 0), **kw)
1952
+ o2 = flash_attn_func(q, k, v2, window_size=(wsize, 0), **kw)
1953
+ else:
1954
+ o1 = flash_attn_varlen_func(q[0], k[0], v1[0], cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, window_size=(wsize, 0), **kw).unsqueeze(0)
1955
+ o2 = flash_attn_varlen_func(q[0], k[0], v2[0], cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, window_size=(wsize, 0), **kw).unsqueeze(0)
1956
  o = torch.cat([o1, o2], dim=-1)
1957
  return o
1958
  elif DIFF_ATTN_IMPL == "fa3":
1959
  def diff_attention_interface(q, k, v, wsize, **kw):
1960
  if self.head_qk_dim == self.head_v_dim:
1961
+ if not self.config.intra_doc_masking:
1962
+ return flash_attn_func(q, k, v, window_size=(wsize, 0), **kw)[0]
1963
+ else:
1964
+ return flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, window_size=(wsize, 0), **kw)[0].unsqueeze(0)
1965
  D = v.size(3)
1966
  v1 = v[:, :, :, :D//2]
1967
  v2 = v[:, :, :, D//2:]
 
2441
  hidden_states: torch.Tensor,
2442
  position_ids: Optional[torch.LongTensor] = None,
2443
  cache_params: Optional[HybridDragonDynamicCache] = None,
2444
+ cu_seqlens: Optional[torch.Tensor] = None,
2445
+ max_seqlen: Optional[int] = None,
2446
  **kwargs,
2447
  ):
2448
  b, q_len, _ = hidden_states.shape
 
2491
  k_prev = F.pad(key_states, (0, 0, 0, 0, 1, 0))[:, :-1] # (B, L, H, D)
2492
  v_prev = F.pad(value_states, (0, 0, 0, 0, 1, 0))[:, :-1] # (B, L, H, D)
2493
 
2494
+ if position_ids is not None:
2495
+ # first token of each doc has pos==0
2496
+ doc_start = (position_ids == 0) # (B, L) bool
2497
+ m = doc_start.unsqueeze(-1).unsqueeze(-1) # (B, L, 1, 1) bool
2498
+
2499
+ # zero the previous contribution at boundaries
2500
+ k_prev = k_prev.masked_fill(m, 0)
2501
+ v_prev = v_prev.masked_fill(m, 0)
2502
+ alpha_k = alpha_k.masked_fill(m, 0)
2503
+ alpha_v = alpha_v.masked_fill(m, 0)
2504
+
2505
  key_states = alpha_k * k_prev + (1 - alpha_k) * key_states
2506
  value_states = alpha_v * v_prev + (1 - alpha_v) * value_states
2507
 
 
2614
  elif DIFF_ATTN_IMPL == "fa2":
2615
  def diff_attention_interface(q, k, v, wsize, **kw):
2616
  if self.head_qk_dim == self.head_v_dim:
2617
+ if not self.config.intra_doc_masking:
2618
+ return flash_attn_func(q, k, v, window_size=(wsize, 0), **kw)
2619
+ else:
2620
+ return flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, window_size=(wsize, 0), **kw).unsqueeze(0)
2621
  D = v.size(3)
2622
  v1 = v[:, :, :, :D//2]
2623
  v2 = v[:, :, :, D//2:]
 
2628
  elif DIFF_ATTN_IMPL == "fa3":
2629
  def diff_attention_interface(q, k, v, wsize, **kw):
2630
  if self.head_qk_dim == self.head_v_dim:
2631
+ if not self.config.intra_doc_masking:
2632
+ return flash_attn_func(q, k, v, window_size=(wsize, 0), **kw)[0]
2633
+ else:
2634
+ return flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, window_size=(wsize, 0), **kw)[0].unsqueeze(0)
2635
  D = v.size(3)
2636
  v1 = v[:, :, :, :D//2]
2637
  v2 = v[:, :, :, D//2:]
 
3212
  hidden_states: torch.Tensor,
3213
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
3214
  cache_params: Optional[HybridDragonDynamicCache] = None,
3215
+ cu_seqlens: Optional[torch.Tensor] = None,
3216
  **kwargs,
3217
  ):
3218
  _, q_len, _ = hidden_states.shape
 
3275
  conv_cache = F.pad(mixed_qkv, (self.conv_size - mixed_qkv.shape[-1], 0))
3276
  cache_params.conv_caches[self.layer_idx] = conv_cache
3277
  if self.causal_conv1d_fn is not None:
3278
+ seq_idx = None
3279
+ if cu_seqlens is not None:
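+ # per-token sequence ids derived from cu_seqlens tell the causal conv where documents start, so conv state is not carried across document boundaries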
3280
+ seq_idx = prepare_sequence_ids(cu_seqlens).to(torch.int32).unsqueeze(0)
3281
  mixed_qkv = self.causal_conv1d_fn(
3282
  x=mixed_qkv,
3283
  weight=self.qkv_conv1d.weight.squeeze(1),
3284
  bias=self.qkv_conv1d.bias,
3285
  activation='silu',
3286
+ seq_idx=seq_idx,
3287
  )
3288
  else:
3289
  mixed_qkv = F.silu(self.qkv_conv1d(mixed_qkv)[:, :, :q_len])
 
3330
  scale=None if not self.config.use_uscaling else 1/self.dk,
3331
  initial_state=None,
3332
  output_final_state=cache_params is not None,
3333
+ use_qk_l2norm_in_kernel=True,
3334
+ cu_seqlens=cu_seqlens,
3335
  ) # (B L H dv)
3336
  else:
3337
  o, ssm_cache = self.recurrent_gated_delta_rule(
 
3519
  )
3520
 
3521
  self.d_model = config.hidden_size
3522
+ self.d_state = 128
3523
  self.conv_init = None
3524
  self.expand = 2
3525
+ self.headdim = 64
3526
+ self.ngroups = config.mamba_ngroups
3527
  self.activation = "swish"
3528
  self.bias = False
 
3529
  self.chunk_size = 128
3530
  self.A_floor = 1e-4
3531
  self.rope_fraction = 0.5
 
 
3532
  self.dt_min = 0.001
3533
  self.dt_max = 0.1
3534
  self.dt_init_floor = 1e-4
 
3544
  if self.split_tensor_size == 0:
3545
  return
3546
 
3547
+ if config.mamba3_rope:
3548
+ self.rope_proj = DragonLinear(config, self.d_model, self.num_rope_angles, bias=False)
3549
 
3550
  # Order: [x, B, C, dt]
3551
  d_in_proj = self.d_inner + 2 * self.d_state * self.ngroups + self.nheads
3552
 
3553
+ if self.config.mamba3_is_A_dd:
3554
+ self.A_proj = DragonLinear(config, self.d_model, self.nheads, bias=False, dtype=torch.float32)
3555
+ else:
3556
+ A_init_range = (1, 16)
3557
+ assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
3558
+ A = torch.empty(self.nheads, dtype=torch.float32).uniform_(*A_init_range)
3559
+ A_log = torch.log(A).to(dtype=torch.float32)
3560
+ self.A_log = nn.Parameter(A_log)
3561
+ self.A_log._no_weight_decay = True
3562
+
3563
+ if config.mamba3_add_trapezoid:
3564
+ self.trapezoid_proj = DragonLinear(config, self.d_model, self.nheads, bias=False)
3565
 
3566
  _dt = torch.exp(
3567
  torch.rand(self.nheads) * (math.log(self.dt_max) - math.log(self.dt_min))
 
3570
  _dt = torch.clamp(_dt, min=self.dt_init_floor)
3571
  _dt_bias = _dt + torch.log(-torch.expm1(-_dt))
3572
  self.dt_bias = nn.Parameter(_dt_bias, requires_grad=True)
3573
+ self.dt_bias._no_weight_decay = True
3574
 
3575
  self.in_proj = DragonLinear(config, self.d_model, d_in_proj, bias=self.bias)
3576
 
3577
+ self.B_bias, self.C_bias = None, None
3578
+ if not config.mamba3_remove_BC_bias:
3579
+ self.B_bias = nn.Parameter(torch.ones((self.nheads, self.d_state)), requires_grad=True)
3580
+ self.C_bias = nn.Parameter(torch.ones((self.nheads, self.d_state)), requires_grad=True)
3581
 
3582
+ if config.mamba3_is_id_rms:
3583
+ self.B_norm = DragonNorm(config, self.d_state)
3584
+ self.C_norm = DragonNorm(config, self.d_state)
3585
 
3586
+ if not config.mamba3_remove_conv:
3587
  conv_dim = self.d_inner + 2 * self.d_state * self.ngroups
3588
  self.conv1d = nn.Conv1d(
3589
  in_channels=conv_dim,
3590
  out_channels=conv_dim,
3591
+ bias=False,
3592
  kernel_size=4,
3593
  groups=conv_dim,
3594
  )
 
3600
 
3601
  # D "skip" parameter
3602
  self.D = nn.Parameter(torch.ones(self.nheads))
3603
+ self.D._no_weight_decay = True
3604
 
3605
+ def forward(
3606
+ self,
3607
+ hidden_states: torch.Tensor,
3608
+ cache_params: Optional[HybridDragonDynamicCache] = None,
3609
+ **kwargs
3610
+ ):
3611
  # Apply in_proj
3612
  xBCdt = self.in_proj(hidden_states)
3613
  xBC, dd_dt = torch.split(
 
3618
  ],
3619
  dim=-1)
3620
 
3621
+ if self.config.mamba3_is_A_dd:
3622
+ _A = -F.softplus((self.A_proj(hidden_states.to(torch.float32))).to(torch.float32)) # (B, L, N)
3623
+ _A = torch.clamp(_A, max=-self.A_floor)
3624
+ else:
3625
+ _A = -torch.exp(self.A_log).unsqueeze(0).unsqueeze(0)
3626
  dt = F.softplus(dd_dt + self.dt_bias) # (B, L, N)
3627
 
3628
+ if not self.config.mamba3_remove_conv:
3629
  xBC = causal_conv1d_fn(
3630
  x=xBC.transpose(1, 2),
3631
  weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
3632
  bias=self.conv1d.bias,
3633
+ activation=self.activation,
3634
  ).transpose(1, 2) # (B, L, self.d_inner + 2 * ngroups * d_state)
3635
 
3636
  x, B, C = torch.split(
 
3643
  B = rearrange(B, "b l (g n) -> b l g n", g=self.ngroups)
3644
  C = rearrange(C, "b l (g n) -> b l g n", g=self.ngroups)
3645
 
3646
+ if self.config.mamba3_is_id_rms:
3647
+ B = self.B_norm(B)
3648
+ C = self.C_norm(C)
3649
 
3650
  if self.ngroups != self.nheads:
3651
  B = B.expand(-1, -1, self.nheads, -1) # (B, L, N, S)
3652
  C = C.expand(-1, -1, self.nheads, -1) # (B, L, N, S)
3653
 
3654
+ if self.config.mamba3_rope:
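+ # data-dependent rotary for the SSM: per-token angles from rope_proj are accumulated through dt (angle_dt) and applied to C and B (the Q/K analogues); rotary_qk also returns sum(C*B), mirroring the CB_sum computed in the non-rope branch below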
3655
+ angle = self.rope_proj(hidden_states) # (B, L, S)
3656
+ angle = angle.unsqueeze(-2).expand(-1, -1, self.nheads, -1) # (B, L, G, S)
3657
+ angle = angle_dt(angle, dt)
3658
 
3659
+ C, B, CB_sum = rotary_qk(q=C, k=B, angle=angle, bias_q=self.C_bias, bias_k=self.B_bias, conjugate=False, inplace=False)
3660
+ else:
3661
+ if not self.config.mamba3_remove_BC_bias:
3662
+ og_dtpe = B.dtype
3663
+ B = (B + self.B_bias).to(og_dtpe)
3664
+ C = (C + self.C_bias).to(og_dtpe)
3665
+
3666
+ CB_sum = torch.sum(
3667
+ B.to(torch.float32)*C.to(torch.float32),
3668
+ dim=-1,
3669
+ keepdim=False
3670
+ )
3671
 
3672
  x = rearrange(x, "b l (h p) -> b l h p", p=self.headdim)
3673
 
3674
  A = _A * dt
3675
  gating_factor = dt # B, L, N
3676
 
3677
+ if self.config.mamba3_add_trapezoid:
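+ # trapezoid rule: dt is split between the current step (gamma = trap*dt) and the previous step carried through the decay (beta = (1-trap)*dt*alpha), with alpha = exp(A) the per-step decay; the rolls pair each position with the next step's coefficients for the fused scan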
3678
+ trap = F.sigmoid(self.trapezoid_proj(hidden_states)) # (B, L, N)
3679
 
3680
+ alpha_arr = torch.exp(A)
3681
+ beta_arr = (1-trap)*gating_factor*alpha_arr
3682
+ gamma_arr = trap*gating_factor
3683
 
3684
+ # roll alpha and beta to the left by 1
3685
+ _alpha_arr = torch.roll(alpha_arr, shifts=-1, dims=1)
3686
+ _beta_arr = torch.roll(beta_arr, shifts=-1, dims=1)
3687
 
3688
+ x_scalar = (gamma_arr*_alpha_arr + _beta_arr).to(torch.bfloat16)
3689
+ else:
3690
+ alpha_arr = torch.exp(A)
3691
+ beta_arr = torch.zeros_like(alpha_arr)
3692
+ gamma_arr = gating_factor
3693
+
3694
+ # roll alpha to the left by 1
3695
+ _alpha_arr = torch.roll(alpha_arr, shifts=-1, dims=1)
3696
+
3697
+ x_scalar = (gamma_arr*_alpha_arr).to(torch.bfloat16)
3698
+
3699
+ ssm_cache = None
3700
+ if cache_params is not None:
3701
+ ssm_cache = cache_params.ssm_caches[self.layer_idx]
3702
 
3703
+ out = mamba_chunk_scan_discretized_combined(
3704
  x=x.bfloat16(),
3705
  A=A,
3706
  B=B.bfloat16(),
 
3710
  gamma=gamma_arr,
3711
  CB_sum=CB_sum,
3712
  D=self.D,
3713
+ z=None,
3714
+ initial_states=ssm_cache,
3715
+ return_final_states=cache_params is not None,
3716
  )
3717
 
3718
+ if cache_params is not None:
3719
+ y, ssm_cache = out
3720
+ cache_params.ssm_caches[self.layer_idx] = ssm_cache
3721
+ else:
3722
+ y = out
3723
+
3724
  return y, None, None
3725
 
3726
+ class DragonMamba2(nn.Module):
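+ # plain Mamba-2 mixer (no z gating, no output norm), built directly on mamba_ssm's mamba_chunk_scan_combined kernel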
3727
+ def __init__(self, config: DragonConfig, layer_idx: Optional[int]):
3728
+ super().__init__()
3729
+ self.d_model = config.hidden_size
3730
+ self.d_state = 128
3731
+ self.expand = 2
3732
+ self.d_inner = self.expand * self.d_model
3733
+ self.headdim = 64
3734
+ self.ngroups = config.mamba_ngroups
3735
+ assert self.d_inner % self.headdim == 0
3736
+ self.nheads = self.d_inner // self.headdim
3737
+ self.layer_idx = layer_idx
3738
+
3739
+ # Order: [x, B, C, dt]
3740
+ d_in_proj = self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
3741
+ self.in_proj = DragonLinear(config, self.d_model, d_in_proj, bias=False)
3742
+
3743
+ conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
3744
+ self.conv1d = nn.Conv1d(
3745
+ in_channels=conv_dim,
3746
+ out_channels=conv_dim,
3747
+ bias=False,
3748
+ kernel_size=4,
3749
+ groups=conv_dim,
3750
+ padding=4-1,
3751
+ )
3752
+ self.act = nn.SiLU()
3753
+
3754
+ # Initialize log dt bias
3755
+ dt_min=0.001
3756
+ dt_max=0.1
3757
+ dt_init_floor=1e-4
3758
+ dt_limit=(0.0, float("inf"))
3759
+ dt = torch.exp(torch.rand(self.nheads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
3760
+ dt = torch.clamp(dt, min=dt_init_floor)
3761
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
3762
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
3763
+ self.dt_bias = nn.Parameter(inv_dt)
3764
+ self.dt_bias._no_weight_decay = True
3765
+
3766
+ # A parameter
3767
+ A_init_range=(1, 16)
3768
+ assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
3769
+ A = torch.empty(self.nheads, dtype=torch.float32).uniform_(*A_init_range)
3770
+ A_log = torch.log(A)
3771
+ self.A_log = nn.Parameter(A_log)
3772
+ self.A_log._no_weight_decay = True
3773
+
3774
+ # D "skip" parameter
3775
+ self.D = nn.Parameter(torch.ones(self.nheads))
3776
+ self.D._no_weight_decay = True
3777
+
3778
+ def forward(self, hidden_states, **kwargs):
3779
+ """
3780
+ hidden_states: (B, L, D)
3781
+ Returns: same shape as hidden_states
3782
+ """
3783
+ _, seqlen, _ = hidden_states.shape
3784
+
3785
+ zxbcdt = self.in_proj(hidden_states) # (B, L, d_in_proj)
3786
+ A = -torch.exp(self.A_log) # (nheads,)
3787
+
3788
+ xBC, dt = torch.split(
3789
+ zxbcdt, [self.d_inner + 2 * self.ngroups * self.d_state, self.nheads], dim=-1
3790
+ )
3791
+ dt = F.softplus(dt + self.dt_bias) # (B, L, nheads)
3792
+
3793
+ # 1D Convolution
3794
+ if causal_conv1d_fn is None:
3795
+ xBC = self.act(
3796
+ self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
3797
+ ) # (B, L, self.d_inner + 2 * ngroups * d_state)
3798
+ xBC = xBC[:, :seqlen, :]
3799
+ else:
3800
+ xBC = causal_conv1d_fn(
3801
+ x=xBC.transpose(1, 2),
3802
+ weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
3803
+ bias=self.conv1d.bias,
3804
+ activation="swish",
3805
+ ).transpose(1, 2)
3806
+
3807
+ # Split into 3 main branches: X, B, C
3808
+ # These correspond to V, K, Q respectively in the SSM/attention duality
3809
+ x, B, C = torch.split(xBC, [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
3810
+ y = mamba_chunk_scan_combined(
3811
+ rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
3812
+ dt,
3813
+ A,
3814
+ rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
3815
+ rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
3816
+ chunk_size=256,
3817
+ D=self.D,
3818
+ z=None,
3819
+ seq_idx=None,
3820
+ initial_states=None,
3821
+ )
3822
+
3823
+ return y, None, None
3824
 
3825
  class DragonMamba3Mimo(nn.Module):
3826
  def __init__(self, config: DragonConfig, layer_idx: Optional[int]):
 
3839
  self.conv_init = None
3840
  self.expand = 2
3841
  self.headdim = 128
3842
+ self.ngroups = config.mamba_ngroups
3843
  self.activation = "swish"
3844
  self.bias = False
3845
  self.conv_bias = True
 
3873
  # Order: [z, x, B, C, dt]
3874
  d_in_proj = 2 * self.d_inner + 2 * self.d_state * self.ngroups * self.mimo_dim + self.nheads
3875
 
3876
+ self.A_proj = DragonLinear(config, self.d_model, self.nheads, bias=False, dtype=torch.float32)
3877
  self.trapezoid_proj = DragonLinear(config, self.d_model, self.nheads, bias=False)
3878
 
3879
  _dt = torch.exp(
 
3887
 
3888
  self.in_proj = DragonLinear(config, self.d_model, d_in_proj, bias=self.bias)
3889
 
3890
+ self.B_bias = nn.Parameter(torch.ones((self.mimo_dim, self.nheads, self.d_state)), requires_grad=True)
3891
+ self.C_bias = nn.Parameter(torch.ones((self.mimo_dim, self.nheads, self.d_state)), requires_grad=True)
3892
+
3893
  self.B_norm = DragonNorm(config, self.d_state)
3894
  self.C_norm = DragonNorm(config, self.d_state)
3895
 
 
3924
 
3925
  def forward(self, hidden_states, **kwargs):
3926
  # Apply in_proj
3927
+ zxBCdt = self.in_proj(hidden_states)
3928
  z, xBC, dd_dt = torch.split(
3929
+ zxBCdt,
3930
  [
3931
  self.d_inner,
3932
  self.d_inner + 2 * self.d_state * self.ngroups * self.mimo_dim,
 
3988
  C = self.C_norm(C)
3989
 
3990
  if self.ngroups != self.nheads:
3991
+ B = B.expand(-1, -1, -1, self.nheads, -1) # (B, L, R, N, S)
3992
+ C = C.expand(-1, -1, -1, self.nheads, -1) # (B, L, R, N, S)
3993
+
3994
 
3995
  angle = self.rope_proj(hidden_states) # (B, L, S)
3996
  angle = angle.unsqueeze(-2).expand(-1, -1, self.nheads, -1) # (B, L, G, S)
3997
  angle = angle_dt(angle, dt)
3998
 
3999
+ C, B, CB_sum = mimo_rotary_qk(q=C, k=B, angle=angle, bias_q=self.C_bias, bias_k=self.B_bias, conjugate=False, inplace=False)
4000
 
4001
  x = rearrange(x, "b l r (h p) -> b l r h p", p=self.headdim)
4002
 
 
4017
 
4018
  z = rearrange(z, "b l r (h p) -> b l r h p", p=self.headdim)
4019
 
4020
+ y = mamba_mimo_chunk_scan_discretized_fused_combined(
4021
  x=x.bfloat16(),
4022
  A=A.bfloat16(),
4023
  B=B.bfloat16(),
 
4031
  )
4032
 
4033
  y = rearrange(y, "b l r h p -> b l r (h p)")
4034
+ #if seqlen_og is not None:
4035
+ # y = rearrange(y, "b l r d -> (b l) r d")
4036
 
4037
  # Perform MIMO down projection (mimo_rank*d_inner -> d_inner)
4038
  y = rearrange(y, "b l r d -> b l (r d)")
4039
  y = rearrange(y, "b l (g d) -> b l g d", g=self.mimo_dim*self.mimo_proj_block_order)
4040
  y = torch.einsum("blgd,drg->bldr", y, self.out_proj_mimo)
4041
  y = rearrange(y, "b l d r -> b l (d r)")
4042
+ y = rearrange(y, "b l (h d) -> b l h d", d=self.headdim)
4043
 
4044
  return y, None, None
4045
 
4046
  class DragonMLP(nn.Module):
4047
+ def __init__(self, config: DragonConfig, intermediate_size: Optional[int] = None):
4048
  super().__init__()
4049
  self.config = config
4050
+ intermediate_size = intermediate_size or config.intermediate_size
4051
  #print("previous MLP : ", PREVIOUS_MLP)
4052
  self.link_size = 16
4053
  self.mlp_linking = config.mlp_linking and PREVIOUS_MLP is not None
4054
  if self.mlp_linking:
4055
  self.previous_mlp = PREVIOUS_MLP
4056
+ self.fc_1 = DragonLinear(config, config.hidden_size, intermediate_size, bias=False)
4057
  self.lambda1 = nn.Parameter(torch.zeros(self.link_size)) # sigmoid->0.5
4058
  else :
4059
+ self.fc_1 = DragonLinear(config, config.hidden_size, intermediate_size, bias=False)
4060
+ self.fc_2 = DragonLinear(config, intermediate_size, config.hidden_size, bias=False)
4061
  self.register_buffer("_2_sqrt_5", torch.tensor(2/math.sqrt(5)) if config.use_uscaling else torch.tensor(1.), persistent=False)
4062
 
4063
  def forward(self, hidden_states):
 
4075
  return hidden_states
4076
 
4077
  def get_mlp_link(self):
4078
+ mlp_link = self.mlp_link
4079
+ self.mlp_link = None
4080
+ return mlp_link
4081
+
4082
+ class DragonGatedMLP(nn.Module):
4083
+ def __init__(self, config: DragonConfig, intermediate_size: Optional[int] = None, num_active_experts: int = 1):
4084
+ super().__init__()
4085
+ self.config = config
4086
+ self.intermediate_size = intermediate_size
4087
+
4088
+ self.fc_1 = DragonLinear(config, config.hidden_size, num_active_experts*self.intermediate_size, bias=False)
4089
+ self.fc_2 = DragonLinear(config, num_active_experts*self.intermediate_size, config.hidden_size, bias=False)
4090
+ self.register_buffer("_2_sqrt_5", torch.tensor(2/math.sqrt(5)) if config.use_uscaling else torch.tensor(1.), persistent=False)
4091
+
4092
+ def forward(self, hidden_states, gates):
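+ # gates: (B, L, E) routing weights; every expert slice of the fused fc_1 output is computed and scaled by its gate before the shared fc_2 down-projection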
4093
+ B, L, D = hidden_states.size()
4094
+ hidden_states = self.fc_1(hidden_states) # (B, L, E*D)
4095
+ hidden_states = self._2_sqrt_5 * F.relu(hidden_states).square().view(B, L, -1, self.intermediate_size) # (B, L, E, D)
4096
+ hidden_states = hidden_states * gates.unsqueeze(-1) # (B, L, E, D)
4097
+ hidden_states = self.fc_2(hidden_states.view(B, L, -1)) # (B, L, D)
4098
+ return hidden_states
4099
+
4100
+ class DragonMoE(nn.Module):
4101
+ def __init__(self, config: DragonConfig):
4102
+ super().__init__()
4103
+ self.config = config
4104
+ self.num_experts = config.moe_num_routed_experts
4105
+ self.routed_scaling_factor = config.moe_routed_scaling_factor
4106
+
4107
+ self.router = DragonLinear(config, config.hidden_size, self.num_experts, bias=False, dtype=torch.float32)
4108
+ self.experts = DragonGatedMLP(config, config.moe_routed_intermediate_size, self.num_experts)
4109
+ if config.moe_shared_intermediate_size > 0:
4110
+ self.shared_expert = DragonMLP(config, config.moe_shared_intermediate_size)
4111
+
4112
+ def forward(self, hidden_states):
4113
+ # compute gating score.
4114
+ weights = F.sigmoid(self.router(hidden_states.to(torch.float32))) # (B, L, experts)
4115
+ weights = weights / weights.sum(dim=-1, keepdim=True) # (B, L, experts)
4116
+ weights = (weights * self.routed_scaling_factor).to(hidden_states.dtype)
4117
+ # forward through (routed) experts.
4118
+ y = self.experts(hidden_states, weights) # (B, L, D)
4119
+ # forward through shared expert.
4120
+ if self.config.moe_shared_intermediate_size > 0:
4121
+ y = y + self.shared_expert(hidden_states)
4122
+ return y
4123
 
4124
  PREVIOUS_MLP = None
4125
  class DragonMonoBlock(GradientCheckpointingLayer):
 
4194
  head_dim = self.mixer.headdim
4195
  num_attention_heads = self.mixer.nheads
4196
  use_gate = config.gate_gdn
4197
+ elif layer_type == '2':
4198
+ self.mixer = DragonMamba2(config, layer_idx=layer_idx)
4199
+ head_dim = self.mixer.headdim
4200
+ num_attention_heads = self.mixer.nheads
4201
+ use_gate = config.gate_gdn
4202
+ elif layer_type == 'M':
4203
+ self.mixer = DragonMamba3Mimo(config, layer_idx=layer_idx)
4204
+ head_dim = self.mixer.headdim
4205
+ num_attention_heads = self.mixer.nheads
4206
+ use_gate = False # inside Mamba3Mimo
4207
  else:
4208
  raise ValueError(f"Unknown layer type: {layer_type}")
4209
 
 
4248
 
4249
  self.input_norm = DragonNorm(config, config.hidden_size)
4250
  self.postmixer_norm = DragonNorm(config, config.hidden_size)
4251
+ if not config.moe:
4252
+ self.mlp = DragonMLP(config)
4253
+ else:
4254
+ self.mlp = DragonMoE(config)
4255
  global PREVIOUS_MLP
4256
  PREVIOUS_MLP = self.mlp
4257
 
 
4267
  cache_position: Optional[torch.LongTensor] = None,
4268
  position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
4269
  key_value_last_layer: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
4270
+ cu_seqlens: Optional[torch.Tensor] = None,
4271
+ max_seqlen: Optional[int] = None,
4272
  **kwargs,
4273
  ):
4274
  # MIXER.
 
4280
  position_ids=position_ids,
4281
  cache_params=cache_params,
4282
  key_value_last_layer=key_value_last_layer,
4283
+ cu_seqlens=cu_seqlens,
4284
+ max_seqlen=max_seqlen,
4285
  ) # (B, L, E*D)
4286
  if self.use_gate:
4287
  if self.config.gate_type == "elementwise" or self.config.gate_type == "kimi":
 
4459
  self.embedding = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
4460
  self.layers = nn.ModuleList([DragonBlock(config, layer_idx=i, layer_type=layer) if layer in ['l', 'r', 'd'] else DragonMonoBlock(config, layer_idx=i, layer_type=layer) for i, layer in enumerate(config.layers_config)])
4461
 
4462
+ if self.config.rope_type_global != '' or self.config.rope_type_local != '':
4463
+ self.rotary_emb = DragonRotaryEmbedding(config, head_dim=config.head_dim if config.head_dim else (config.expand_factor*config.hidden_size)//config.num_attention_heads, theta=config.rope_theta_local) # only for SWA
4464
+ else:
4465
+ self.rotary_emb = None
4466
+
4467
+ if self.config.final_norm:
4468
+ self.final_norm = DragonNorm(config, config.hidden_size)
4469
 
4470
  self.gradient_checkpointing = False
4471
  self.post_init()
 
4486
  cache_position: Optional[torch.LongTensor] = None,
4487
  output_hidden_states: Optional[bool] = None,
4488
  inputs_embeds: Optional[torch.FloatTensor] = None,
4489
+ cu_seqlens: Optional[torch.Tensor] = None,
4490
+ max_seqlen: Optional[int] = None,
4491
  **kwargs
4492
  ) -> DragonOutput:
4493
  B, L = input_ids.shape if input_ids is not None else inputs_embeds.shape[:2]
 
4531
 
4532
  all_hidden_states = () if output_hidden_states else None
4533
 
4534
+ if self.rotary_emb is not None:
4535
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
4536
+ else:
4537
+ position_embeddings = None
4538
 
4539
  shared_kv = (None, None)
4540
  for block in self.layers:
 
4548
  cache_position=cache_position,
4549
  position_embeddings=position_embeddings,
4550
  key_value_last_layer=shared_kv,
4551
+ cu_seqlens=cu_seqlens,
4552
+ max_seqlen=max_seqlen,
4553
  **kwargs,
4554
  )
4555
  shared_kv = (last_k, last_v)
4556
 
4557
+ if self.config.final_norm:
4558
+ hidden_states = self.final_norm(hidden_states)
4559
 
4560
  if output_hidden_states:
4561
  all_hidden_states = all_hidden_states + (hidden_states,)
 
4588
  cache_position: Optional[torch.Tensor] = None,
4589
  output_hidden_states: Optional[bool] = None,
4590
  attention_mask: Optional[torch.Tensor] = None,
4591
+ just_loss: Optional[bool] = False,
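+ # when just_loss is True, only the loss is populated in the returned DragonCausalLMOutput (logits / cache / hidden states come back as None)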
4592
+ cu_seqlens: Optional[torch.Tensor] = None,
4593
+ max_seqlen: Optional[int] = None,
4594
  token_type_ids=None,
4595
  **kwargs,
4596
  ) -> DragonCausalLMOutput:
 
4605
  cache_position=cache_position,
4606
  inputs_embeds=inputs_embeds,
4607
  output_hidden_states=output_hidden_states,
4608
+ cu_seqlens=cu_seqlens,
4609
+ max_seqlen=max_seqlen,
4610
  **kwargs,
4611
  )
4612
 
 
4650
 
4651
  return DragonCausalLMOutput(
4652
  loss=loss,
4653
+ logits=logits if not just_loss else None,
4654
+ past_key_values=outputs.past_key_values if not just_loss else None,
4655
+ hidden_states=outputs.hidden_states if not just_loss else None,
4656
  )
4657
  DragonForCausalLM.register_for_auto_class("AutoModelForCausalLM")
4658
 
training_dragon.py CHANGED
@@ -35,8 +35,8 @@ class NanoArgs:
35
  head_dim: Optional[int] = None
36
  layers_config : str = 4*"lrdlr"
37
  expand_factor : int = 2 # expand factor for Mamba/Dragon
38
- rope_type_local: str = "rope" #p-rope
39
- rope_type_global: str = "rope" #p-rope
40
  rope_theta_local: float = 10000.0
41
  rope_theta_global: float = 0.0
42
  eps_rmsnorm: float = 1e-6
@@ -54,8 +54,18 @@ class NanoArgs:
54
  scalar_proj_as_hidden_matrix: bool = True
55
  normalization_type: str = "rmsnorm" # rmsnorm, seednorm
56
  seednorm_wd: bool = True
 
 
57
  mixer_gn: bool = True
58
  mlp_linking : bool = False
 
 
 
 
 
 
 
 
59
 
60
  # attention related
61
  n_kv_heads : int = 0
@@ -93,6 +103,14 @@ class NanoArgs:
93
  shrink_qk_gdn: int = 2
94
  kda_allow_neg_eigval: bool = False
95
  kda_num_v_heads: Optional[int] = None
 
 
 
 
 
 
 
 
96
 
97
  # optim
98
  optim: str = "adamw" # adamw, spam, stable-spam, muon, muon_moonlight, splus
@@ -120,7 +138,9 @@ class NanoArgs:
120
 
121
  # data
122
  vocab_size: int = 50304
 
123
  sequence_length: int = 1024
 
124
  input_bin: Optional[str] = None
125
  input_val_bin: Optional[str] = None
126
 
@@ -138,6 +158,7 @@ class NanoArgs:
138
  load_optim: bool = True
139
  load_sched: bool = True
140
  compile: bool = True
 
141
 
142
  # used during training
143
  slw_window: int = 0
@@ -166,9 +187,11 @@ def _load_data_shard(filename):
166
  return tokens
167
 
168
  class DistributedDataLoader:
169
- def __init__(self, filename_pattern, B, T, process_rank, num_processes):
170
  self.process_rank = process_rank
171
  self.num_processes = num_processes
 
 
172
  self.B = B # micro batch size
173
  self.T = T
174
 
@@ -221,12 +244,32 @@ class DistributedDataLoader:
221
  x = torch.from_numpy(buf.reshape(B, T)) # inputs
222
  y = torch.from_numpy(buf.reshape(B, T)) # targets
223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  # advance current position and load next shard if necessary
225
  self.current_position += B * T * self.num_processes
226
  if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
227
  self.advance()
228
 
229
- return x.cuda(), y.cuda()
230
 
231
  def param_groups_mup(model, base_lr_hidden, base_lr_scalar, base_lr_embed, base_lr_head, wd):
232
  groups, seen = [], set()
@@ -277,6 +320,11 @@ def param_groups_mup(model, base_lr_hidden, base_lr_scalar, base_lr_embed, base_
277
 
278
  args = tyro.cli(NanoArgs)
279
 
 
 
 
 
 
280
  # set up DDP (distributed data parallel).
281
  assert torch.cuda.is_available()
282
  dist.init_process_group(
@@ -293,6 +341,8 @@ torch.cuda.set_device(device)
293
  print(f"using device: {device}")
294
  master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc.
295
  torch._dynamo.config.optimize_ddp=False
 
 
296
 
297
  # setup logging.
298
  resume_dir = None
@@ -363,16 +413,33 @@ if args.patch_level_training:
363
  assert args.batch_size % (B * ddp_world_size) == 0
364
  accumulation_steps = args.batch_size // (B * ddp_world_size)
365
 
 
 
366
  # load dataloaders.
367
  #if args.patch_level_training:
368
  # assert T % args.patch_level_training_size == 0, "sequence length must be divisible by patch level training size in reduced mode"
369
- train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size)
370
- val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size)
371
  print0(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
372
  print0(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
373
 
374
  # load model.
375
  config_hf = DragonConfig(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
  mla_kv_rank=args.mla_kv_rank,
377
  rope_gdn=args.rope_gdn,
378
  shrink_qk_da=args.shrink_qk_da,
@@ -402,6 +469,8 @@ config_hf = DragonConfig(
402
  zero_centered_gate=args.zero_centered_gate,
403
  zero_centered_gate_type=args.zero_centered_gate_type,
404
  scalable_softmax=args.scalable_softmax,
 
 
405
  resformer=args.resformer,
406
  gate_type=args.gate_type,
407
  gate_act=args.gate_act,
@@ -461,7 +530,7 @@ with torch.no_grad():
461
  # count params. (total & active)
462
  num_params = sum(p.numel() for p in model.parameters())
463
  """model.eval()
464
- x, y = train_loader.next_batch()
465
  with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
466
  model(input_ids=x[[0], [0]].unsqueeze(0)).logits.sum().backward()
467
  num_active = sum(p.grad.count_nonzero() for p in model.parameters() if p.grad is not None)
@@ -472,12 +541,16 @@ print0(f"number of total parameters: {num_params}")
472
 
473
  # DDP & compile.
474
  uncompiled_model = model
475
- model = torch.compile(model, dynamic=True) if args.compile else model
476
  model.train()
477
  model = DDP(model, device_ids=[ddp_local_rank], find_unused_parameters=args.resformer)
478
  raw_model = model.module
479
  ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)
480
 
 
 
 
 
481
  # load optimizers & schedulers.
482
  if args.use_uscaling:
483
  #assert args.optim == "adamw", "uscaling is only supported with AdamW optimizer currently"
@@ -553,9 +626,7 @@ WARMUP_SKIP = 10
553
 
554
  # begin training.
555
  train_loader.reset()
556
- #tokenizer = transformers.AutoTokenizer.from_pretrained("openai-community/gpt2", use_fast=True) # for saving
557
- tokenizer = transformers.AutoTokenizer.from_pretrained("/leonardo_work/BOOST_LCustodi/script/training/temp/hf_models/gpt2", use_fast=True)
558
- x, y = train_loader.next_batch()
559
 
560
  for iter_ in range(start_iter, start_iter+args.total_iterations+1):
561
  last_iter = (iter_ == start_iter+args.total_iterations)
@@ -588,9 +659,9 @@ for iter_ in range(start_iter, start_iter+args.total_iterations+1):
588
  val_loss = torch.zeros((), device=device, dtype=torch.float32)
589
  for _ in range(args.val_iterations):
590
  for _ in range(accumulation_steps):
591
- inputs, targets = val_loader.next_batch()
592
  with ctx:
593
- val_loss += model(input_ids=inputs, labels=targets).loss.detach()
594
  val_loss /= args.val_iterations * accumulation_steps
595
  dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
596
  val_loss = val_loss.item()
@@ -641,10 +712,10 @@ for iter_ in range(start_iter, start_iter+args.total_iterations+1):
641
  for i in range(1, accumulation_steps+1):
642
  # forward pass.
643
  with ctx:
644
- loss = model(input_ids=x, labels=y).loss
645
  train_loss = loss.detach()
646
  # prepare next batch.
647
- x, y = train_loader.next_batch()
648
  # backward pass.
649
  if i < accumulation_steps:
650
  with model.no_sync():
 
35
  head_dim: Optional[int] = None
36
  layers_config : str = 4*"lrdlr"
37
  expand_factor : int = 2 # expand factor for Mamba/Dragon
38
+ rope_type_local: str = "" #p-rope
39
+ rope_type_global: str = "" #p-rope
40
  rope_theta_local: float = 10000.0
41
  rope_theta_global: float = 0.0
42
  eps_rmsnorm: float = 1e-6
 
54
  scalar_proj_as_hidden_matrix: bool = True
55
  normalization_type: str = "rmsnorm" # rmsnorm, seednorm
56
  seednorm_wd: bool = True
57
+ seednorm_type: int = 1
58
+ seednorm_rank: int = 1
59
  mixer_gn: bool = True
60
  mlp_linking : bool = False
61
+ final_norm: bool = True
62
+
63
+ # MoE
64
+ moe: bool = False
65
+ moe_num_routed_experts: int = 2
66
+ moe_routed_scaling_factor: float = 2.5
67
+ moe_routed_intermediate_size: int = 768
68
+ moe_shared_intermediate_size: int = 768
69
 
70
  # attention related
71
  n_kv_heads : int = 0
 
103
  shrink_qk_gdn: int = 2
104
  kda_allow_neg_eigval: bool = False
105
  kda_num_v_heads: Optional[int] = None
106
+ mamba_mimo_dim: Optional[int] = 2
107
+ mamba_ngroups: Optional[int] = 1
108
+ mamba3_rope: bool = True
109
+ mamba3_remove_BC_bias: bool = False
110
+ mamba3_is_id_rms: bool = True
111
+ mamba3_remove_conv: bool = True
112
+ mamba3_is_A_dd: bool = True
113
+ mamba3_add_trapezoid: bool = True
114
 
115
  # optim
116
  optim: str = "adamw" # adamw, spam, stable-spam, muon, muon_moonlight, splus
 
138
 
139
  # data
140
  vocab_size: int = 50304
141
+ bos_id: int = 50256
142
  sequence_length: int = 1024
143
+ intra_doc_masking: bool = False
144
  input_bin: Optional[str] = None
145
  input_val_bin: Optional[str] = None
146
 
 
158
  load_optim: bool = True
159
  load_sched: bool = True
160
  compile: bool = True
161
+ compile_dynamic: bool = False
162
 
163
  # used during training
164
  slw_window: int = 0
 
187
  return tokens
188
 
189
  class DistributedDataLoader:
190
+ def __init__(self, filename_pattern, intra_doc_masking, B, T, process_rank, num_processes, bos_id):
191
  self.process_rank = process_rank
192
  self.num_processes = num_processes
193
+ self.intra_doc_masking = intra_doc_masking
194
+ self.bos_id = bos_id
195
  self.B = B # micro batch size
196
  self.T = T
197
 
 
244
  x = torch.from_numpy(buf.reshape(B, T)) # inputs
245
  y = torch.from_numpy(buf.reshape(B, T)) # targets
246
 
247
+ # compute cu_seqlens, max_seqlen and per-document position_ids for intra-document masking
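+ # e.g. T=8 with bos tokens at positions {0, 3, 5}: seqlens=[3, 2, 3], cu_seqlens=[0, 3, 5, 8], max_seqlen=3, position_ids=[[0, 1, 2, 0, 1, 0, 1, 2]]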
248
+ cu = None
249
+ maxlen = None
250
+ position_ids = None
251
+ if self.intra_doc_masking:
252
+ assert self.B == 1
253
+ starts = (x == self.bos_id).nonzero(as_tuple=True)[1].to(torch.long)
254
+ if starts.numel() == 0 or starts[0] != 0:
255
+ starts = torch.cat([torch.zeros(1, dtype=torch.long), starts])
256
+ ends = torch.cat([starts[1:], torch.tensor([x.numel()])])
257
+ seqlens = (ends - starts).to(torch.int32)
258
+ # cu_seqlens, max_seqlen.
259
+ cu = torch.cat([torch.zeros(1, dtype=torch.int32), seqlens.cumsum(0)]).cuda().to(torch.int32)
260
+ maxlen = int(seqlens.max())
261
+ # position_ids.
262
+ lengths = seqlens.to(torch.long)
263
+ starts_per_token = torch.repeat_interleave(starts.to(torch.long), lengths)
264
+ idx = torch.arange(T, device=x.device, dtype=torch.long)
265
+ position_ids = (idx - starts_per_token).unsqueeze(0).cuda() # keep on device, like cu_seqlens
266
+
267
  # advance current position and load next shard if necessary
268
  self.current_position += B * T * self.num_processes
269
  if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
270
  self.advance()
271
 
272
+ return x.cuda(), y.cuda(), cu, maxlen, position_ids
273
 
274
  def param_groups_mup(model, base_lr_hidden, base_lr_scalar, base_lr_embed, base_lr_head, wd):
275
  groups, seen = [], set()
 
320
 
321
  args = tyro.cli(NanoArgs)
322
 
323
+ if args.intra_doc_masking:
324
+ if args.device_batch_size != 1:
325
+ args.device_batch_size = 1
326
+ print("!!! Forcing device_batch_size to 1 for intra-document masking !!!")
327
+
328
  # set up DDP (distributed data parallel).
329
  assert torch.cuda.is_available()
330
  dist.init_process_group(
 
341
  print(f"using device: {device}")
342
  master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc.
343
  torch._dynamo.config.optimize_ddp=False
344
+ if args.compile_dynamic:
345
+ torch._dynamo.config.allow_unspec_int_on_nn_module=True
346
 
347
  # setup logging.
348
  resume_dir = None
 
413
  assert args.batch_size % (B * ddp_world_size) == 0
414
  accumulation_steps = args.batch_size // (B * ddp_world_size)
415
 
416
+ tokenizer = transformers.AutoTokenizer.from_pretrained("/leonardo_work/BOOST_LCustodi/script/training/temp/hf_models/gpt2", use_fast=True)
417
+
418
  # load dataloaders.
419
  #if args.patch_level_training:
420
  # assert T % args.patch_level_training_size == 0, "sequence length must be divisible by patch level training size in reduced mode"
421
+ train_loader = DistributedDataLoader(args.input_bin, args.intra_doc_masking, B, T, ddp_rank, ddp_world_size, args.bos_id)
422
+ val_loader = DistributedDataLoader(args.input_val_bin, args.intra_doc_masking, B, T, ddp_rank, ddp_world_size, args.bos_id)
423
  print0(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
424
  print0(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
425
 
426
  # load model.
427
  config_hf = DragonConfig(
428
+ mamba3_rope=args.mamba3_rope,
429
+ mamba3_remove_BC_bias=args.mamba3_remove_BC_bias,
430
+ mamba3_is_id_rms=args.mamba3_is_id_rms,
431
+ mamba3_remove_conv=args.mamba3_remove_conv,
432
+ mamba3_is_A_dd=args.mamba3_is_A_dd,
433
+ mamba3_add_trapezoid=args.mamba3_add_trapezoid,
434
+ moe=args.moe,
435
+ moe_num_routed_experts=args.moe_num_routed_experts,
436
+ moe_routed_scaling_factor=args.moe_routed_scaling_factor,
437
+ moe_routed_intermediate_size=args.moe_routed_intermediate_size,
438
+ moe_shared_intermediate_size=args.moe_shared_intermediate_size,
439
+ intra_doc_masking=args.intra_doc_masking,
440
+ seednorm_rank=args.seednorm_rank,
441
+ seednorm_type=args.seednorm_type,
442
+ final_norm=args.final_norm,
443
  mla_kv_rank=args.mla_kv_rank,
444
  rope_gdn=args.rope_gdn,
445
  shrink_qk_da=args.shrink_qk_da,
 
469
  zero_centered_gate=args.zero_centered_gate,
470
  zero_centered_gate_type=args.zero_centered_gate_type,
471
  scalable_softmax=args.scalable_softmax,
472
+ mamba_mimo_dim=args.mamba_mimo_dim,
473
+ mamba_ngroups=args.mamba_ngroups,
474
  resformer=args.resformer,
475
  gate_type=args.gate_type,
476
  gate_act=args.gate_act,
 
530
  # count params. (total & active)
531
  num_params = sum(p.numel() for p in model.parameters())
532
  """model.eval()
533
+ x, y, _, _, _ = train_loader.next_batch()
534
  with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
535
  model(input_ids=x[[0], [0]].unsqueeze(0)).logits.sum().backward()
536
  num_active = sum(p.grad.count_nonzero() for p in model.parameters() if p.grad is not None)
 
541
 
542
  # DDP & compile.
543
  uncompiled_model = model
544
+ model = torch.compile(model, dynamic=args.compile_dynamic) if args.compile else model
545
  model.train()
546
  model = DDP(model, device_ids=[ddp_local_rank], find_unused_parameters=args.resformer)
547
  raw_model = model.module
548
  ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)
549
 
550
+ if args.intra_doc_masking:
551
+ print0("!!! Using intra-document masking !!!")
552
+ print0("It is only compatible with GDN (conv+chunk), DA and GDTPA layers. For DA/GDTPA, kv shift is also compatible. All other config will not have intra-doc masking support!!")
553
+
554
  # load optimizers & schedulers.
555
  if args.use_uscaling:
556
  #assert args.optim == "adamw", "uscaling is only supported with AdamW optimizer currently"
 
626
 
627
  # begin training.
628
  train_loader.reset()
629
+ x, y, cu, maxlen, position_ids = train_loader.next_batch()
 
 
630
 
631
  for iter_ in range(start_iter, start_iter+args.total_iterations+1):
632
  last_iter = (iter_ == start_iter+args.total_iterations)
 
659
  val_loss = torch.zeros((), device=device, dtype=torch.float32)
660
  for _ in range(args.val_iterations):
661
  for _ in range(accumulation_steps):
662
+ inputs, targets, cu, maxlen, position_ids = val_loader.next_batch()
663
  with ctx:
664
+ val_loss += model(input_ids=inputs, labels=targets, just_loss=True, cu_seqlens=cu, max_seqlen=maxlen, position_ids=position_ids).loss.detach()
665
  val_loss /= args.val_iterations * accumulation_steps
666
  dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
667
  val_loss = val_loss.item()
 
712
  for i in range(1, accumulation_steps+1):
713
  # forward pass.
714
  with ctx:
715
+ loss = model(input_ids=x, labels=y, just_loss=True, cu_seqlens=cu, max_seqlen=maxlen, position_ids=position_ids).loss
716
  train_loss = loss.detach()
717
  # prepare next batch.
718
+ x, y, cu, maxlen, position_ids = train_loader.next_batch()
719
  # backward pass.
720
  if i < accumulation_steps:
721
  with model.no_sync():