Fredtt3 commited on
Commit
bbd9c60
·
verified ·
1 Parent(s): 3b5f344

Create configs_llada.py

Browse files
Files changed (1) hide show
  1. configs_llada.py +454 -0
configs_llada.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLaDA configuration
3
+ """
4
+ from transformers import AutoConfig
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+ from enum import Enum
8
+ from os import PathLike
9
+ from typing import Union
10
+ from dataclasses import dataclass
11
+ from typing import (
12
+ Optional,
13
+ Union,
14
+ )
15
+
16
+
17
+ __all__ = [
18
+ "ActivationType",
19
+ "ActivationCheckpointingStrategy",
20
+ "BlockType",
21
+ "LayerNormType",
22
+ "InitFnType",
23
+ "ModelConfig",
24
+ ]
25
+
26
+ PathOrStr = Union[str, PathLike]
27
+
28
+
29
class StrEnum(str, Enum):
    """
    Backport of Python 3.11's :class:`enum.StrEnum` for older interpreters.

    Members behave as plain strings: ``str()`` yields the member's value and
    ``repr()`` quotes it like a string literal.
    """

    def __str__(self) -> str:
        # Render as the raw value rather than the default "ClassName.member".
        return self.value

    def __repr__(self) -> str:
        # Single-quote the value so the repr reads like a string literal.
        value_as_text = str(self)
        return "'" + value_as_text + "'"
40
+
41
+
42
class LayerNormType(StrEnum):
    """Which layer-norm implementation the model should use."""

    default = "default"
    """
    The default LayerNorm implementation, equivalent to PyTorch's built-in version.
    """

    low_precision = "low_precision"
    """
    A low-precision version of the default LayerNorm.
    """

    rms = "rms"
    """
    An RMSNorm implementation. When using ``torch.compile`` this is
    probably the fastest implementation.
    """

    gemma_rms = "gemma_rms"
    """
    An RMSNorm implementation by Gemma. When using ``torch.compile`` this is
    probably the fastest implementation.
    """

    amd_compatible = "amd_compatible"
    """
    LayerNorm implemented manually to work around an issue with ROCm.
    """
69
+
70
+
71
class ActivationType(StrEnum):
    """The activation function used within the MLP layers."""

    gelu = "gelu"
    relu = "relu"
    silu = "silu"
    swiglu = "swiglu"
76
+
77
+
78
class BlockType(StrEnum):
    """The transformer block implementation variant."""

    sequential = "sequential"
    parallel = "parallel"

    llama = "llama"
    """
    A block similar to the sequential block with slightly different
    implementations of operations like attention to imitate the behavior of Llama.
    """
87
+
88
+
89
class InitFnType(StrEnum):
    """The weight-initialization strategy for the model's parameters."""

    mitchell = "mitchell"
    """
    The strategy suggested to us by Mitchell Wortsman from UW.
    This uses a truncated normal distribution with an adaptive standard deviation that depends
    on the size of the weights as well as the depth of the layer.
    """

    normal = "normal"
    """
    All weights are initialized from the same normal distribution.
    """

    kaiming_normal = "kaiming_normal"
    """
    All weights are initialized with the Kaiming method from a normal distribution.
    Note this currently won't work with FSDP.
    """

    fan_in = "fan_in"
    """
    "Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
    is the input dimensionality of the kernel.
    """

    full_megatron = "full_megatron"
    """
    This is what metaseq calls "full megatron init". It is the init used for Llama 2.
    """
118
+
119
+
120
@dataclass
class ModelConfig():
    """
    LLaDA (model) configuration.

    A plain dataclass holding every architectural hyperparameter; its
    ``__dict__`` supplies the defaults for :class:`LLaDAConfig`.
    """

    # Note that the defaults for these attributes are equivalent to the base GPT2 model.

    d_model: int = 768
    """
    The hidden size of the model.
    """

    n_heads: int = 12
    """
    The number of self-attention heads.
    """

    n_kv_heads: Optional[int] = None
    """
    The number of heads to use for keys and values. Defaults to `n_heads`.
    Set this to ``None`` or ``n_heads`` for normal multi-head attention.
    Set this to 1 for multi-query attention.
    Set it to some in-between value for Llama2-style grouped query attention.
    """

    n_layers: int = 12
    """
    The number of layers/blocks.
    """

    mlp_ratio: int = 4
    """
    The ratio of the inner MLP dimensionality to ``d_model``.
    This is only used when ``mlp_hidden_size`` is not set.
    """

    mlp_hidden_size: Optional[int] = None
    """
    Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
    """

    activation_type: ActivationType = ActivationType.swiglu
    """
    The activation function to use within the MLP layers.
    """

    block_type: BlockType = BlockType.sequential
    """
    The transformer block implementation.
    """

    block_group_size: int = 1
    """
    The number of blocks to group together into a single parent block.
    This has no effect on the number of parameters in the model and is only used to wrap groups
    of blocks together with a single FSDP wrapper during training.
    """

    alibi: bool = False
    """
    If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
    """

    alibi_bias_max: float = 8.0
    """
    Maximum absolute value of ALiBi bias.
    """

    rope: bool = False
    """
    Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
    """

    rope_full_precision: bool = True
    """
    If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
    apply RoPE at the precision of the input.
    """

    flash_attention: bool = False
    """
    If ``True``, use ``FlashAttention``.
    """

    attention_dropout: float = 0.1
    """
    The dropout probability within the attention modules.
    """

    multi_query_attention: Optional[bool] = None
    """
    Use the Multi-Query formulation of attention used in PaLM. This reduces the number of parameters
    and is more efficient during inference.
    """

    attention_layer_norm: bool = False
    """
    Apply layer norm to the keys and queries within the attention mechanism.
    This can help stabilize training.
    """

    residual_dropout: float = 0.1
    """
    The dropout probability for the MLP and attention output within each block.
    """

    embedding_dropout: float = 0.1
    """
    The dropout probability for embeddings.
    """

    input_emb_norm: bool = False
    """
    An input hidden_states norm implementation by Gemma.
    """

    layer_norm_type: LayerNormType = LayerNormType.default
    """
    The layernorm implementation to use.
    """

    layer_norm_with_affine: bool = True
    """
    Whether to include bias and weight parameters for the layer norms.
    This only affects layer norms that are immediately followed by a linear layer in the forward pass,
    so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
    to ``False``.
    """

    rms_norm_eps: float = 1e-05
    """
    The rms layernorm eps param.
    """

    attention_layer_norm_with_affine: bool = True
    """
    Toggle affine transform for the QK norms.
    """

    max_sequence_length: int = 1024
    """
    The maximum input sequence length supported by the model.
    """

    rope_theta: float = 10000.0
    """
    The rope base param.
    """

    include_qkv_bias: Optional[bool] = False
    """
    Whether or not to include bias parameters in qkv linear layers.
    """

    include_bias: bool = False
    """
    Whether or not to include bias parameters in linear layers.
    In PaLM, they got rid of all bias terms because they found that large
    models tend to have near 0 bias terms anyway.
    """

    bias_for_layer_norm: Optional[bool] = None
    """
    Whether or not to include bias parameters in layer norm.
    This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
    layer norm.
    When this is None (the default), it inherits the setting from include_bias.
    """

    scale_logits: bool = False
    """
    If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
    """

    vocab_size: int = 50257
    """
    Vocabulary size of the model.
    """

    embedding_size: Optional[int] = 50304
    """
    The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
    to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
    next multiple of 128 that's greater than ``vocab_size`` can improve throughput
    substantially.
    """

    weight_tying: bool = True
    """
    Whether to tie output linear weights to the input embedding.
    """

    eos_token_id: int = 50256
    """
    The ID of the end-of-sentence special token.
    """

    pad_token_id: int = 50256
    """
    The ID of the token to use for padding. Defaults to the ID of the EOS token.
    """

    mask_token_id: Optional[int] = 50256
    """
    The ID of the token to use for mask token. Defaults to the ID of the EOS token.
    """

    init_device: Optional[str] = None
    """
    The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
    """

    init_fn: InitFnType = InitFnType.normal
    """
    The weight initialization strategy.
    """

    init_std: float = 0.02
    """
    The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
    as "normal".
    """

    init_cutoff_factor: Optional[float] = None
    """
    A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
    as "normal". Setting this to None means values are not cutoff.
    """

    precision: Optional[str] = None
    """
    Precision used to train/evaluate with. You shouldn't set this directly.
    See :data:`TrainConfig.precision` instead.
    """

    @property
    def effective_n_kv_heads(self) -> int:
        """
        Resolve the actual number of key/value heads from ``n_kv_heads`` and
        ``multi_query_attention``.

        Returns ``n_kv_heads`` when it is set (after checking it does not
        contradict ``multi_query_attention``); otherwise 1 for multi-query
        attention or ``n_heads`` for standard multi-head attention.

        Raises:
            Exception: if both ``n_kv_heads`` and ``multi_query_attention``
                are set and they disagree.
        """
        if self.n_kv_heads is None:
            # No explicit KV-head count: derive it from the MQA flag.
            if self.multi_query_attention is True:
                return 1
            else:
                return self.n_heads
        else:
            if self.multi_query_attention is None:
                return self.n_kv_heads
            # Both settings present: only accept them when consistent.
            if self.multi_query_attention:
                n_kv_heads_should_be = 1
            else:
                n_kv_heads_should_be = self.n_heads
            if self.n_kv_heads == n_kv_heads_should_be:
                return n_kv_heads_should_be
            else:
                raise Exception(
                    "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
                )
376
+
377
class ActivationCheckpointingStrategy(StrEnum):
    """How much activation checkpointing to apply during training."""

    whole_layer = "whole_layer"
    """
    Checkpoint every transformer layer.
    """

    one_in_two = "one_in_two"
    """
    Checkpoint one in two transformer layers.
    """

    one_in_three = "one_in_three"
    """
    Checkpoint one in three transformer layers.
    """

    one_in_four = "one_in_four"
    """
    Checkpoint one in four transformer layers.
    """

    two_in_three = "two_in_three"
    """
    Checkpoint two out of every three transformer layers.
    """

    three_in_four = "three_in_four"
    """
    Checkpoint three out of every four transformer layers.
    """

    four_in_five = "four_in_five"
    """
    Checkpoint four out of every five transformer layers.
    """

    nine_in_ten = "nine_in_ten"
    """
    Checkpoint nine out of every ten transformer layers.
    """

    fine_grained = "fine_grained"
    """
    Focus checkpointing on where it is cheap to recompute and saves most memory.
    """
422
+
423
+
424
class LLaDAConfig(PretrainedConfig):
    """
    HuggingFace ``PretrainedConfig`` wrapper around :class:`ModelConfig`.

    Defaults are taken from a fresh :class:`ModelConfig`; any keyword
    argument overrides the matching default before everything is handed to
    ``PretrainedConfig.__init__``, which stores the values as attributes.
    """

    model_type = "llada"
    keys_to_ignore_at_inference = ["past_key_values"]  # TODO: confirm

    def __init__(self, use_cache: bool = False, **kwargs):
        """
        Args:
            use_cache: Stored on the config as ``use_cache``; defaults to
                ``False``.
            **kwargs: Overrides for any :class:`ModelConfig` field, plus any
                extra ``PretrainedConfig`` attributes (e.g. ``architectures``).
        """
        # Copy the dataclass defaults instead of aliasing its __dict__, so the
        # updates below never mutate the ModelConfig instance's own dict.
        all_kwargs = dict(ModelConfig().__dict__)
        all_kwargs.update(kwargs)
        all_kwargs["use_cache"] = use_cache
        # Only fill in the default architecture when the caller gave none.
        all_kwargs.setdefault("architectures", ["LLaDAModelLM"])
        super().__init__(**all_kwargs)

    @property
    def num_attention_heads(self):
        """Alias for ``n_heads`` expected by the transformers API."""
        return self.n_heads

    @property
    def num_hidden_layers(self):
        """Alias for ``n_layers`` expected by the transformers API."""
        return self.n_layers

    @property
    def hidden_size(self):
        """Alias for ``d_model`` expected by the transformers API."""
        return self.d_model
451
+
452
+
453
# Register the config class so that it is available for transformers pipelines,
# AutoConfig.from_pretrained auto-loading, etc.
AutoConfig.register("llada", LLaDAConfig)