QinFFF committed on
Commit
543fd2e
·
verified ·
1 Parent(s): 115faf4

Remove configuration_llada.py

Browse files
Files changed (1) hide show
  1. configuration_llada.py +0 -463
configuration_llada.py DELETED
@@ -1,463 +0,0 @@
1
- """
2
- LLaDA configuration
3
- """
4
- from transformers import AutoConfig, PretrainedConfig
5
-
6
- from enum import Enum
7
- from os import PathLike
8
- from typing import Union
9
- from dataclasses import asdict, dataclass, field
10
- from glob import glob
11
- from pathlib import Path
12
- from typing import (
13
- Any,
14
- Dict,
15
- Iterable,
16
- List,
17
- Optional,
18
- Tuple,
19
- Type,
20
- TypeVar,
21
- Union,
22
- cast,
23
- )
24
-
25
-
26
# Public names exported by ``from configuration_llada import *``.
# ``LLaDAConfig`` is listed too, so that a star-import also exposes the
# Hugging Face config class defined in this module.
__all__ = [
    "ActivationType",
    "ActivationCheckpointingStrategy",
    "BlockType",
    "LayerNormType",
    "InitFnType",
    "ModelConfig",
    "LLaDAConfig",
]

# Alias for values given either as a plain string or as an ``os.PathLike`` object.
PathOrStr = Union[str, PathLike]
36
-
37
-
38
class StrEnum(str, Enum):
    """
    Stand-in for :class:`enum.StrEnum`, which the standard library only provides from
    Python 3.11 onwards; it is defined here so older interpreters can use it as well.

    ``str(member)`` yields the member's value, and ``repr(member)`` yields that same
    text wrapped in single quotes.
    """

    def __str__(self) -> str:
        return self.value

    def __repr__(self) -> str:
        # ``!s`` routes through ``__str__`` so subclasses overriding it stay consistent.
        return f"'{self!s}'"
49
-
50
-
51
class LayerNormType(StrEnum):
    """
    Identifiers for the layer-norm implementations that can be selected.
    """

    default = "default"
    """
    Standard LayerNorm; equivalent to the version built into PyTorch.
    """

    low_precision = "low_precision"
    """
    The default LayerNorm, in a low-precision variant.
    """

    rms = "rms"
    """
    RMSNorm. Probably the fastest implementation when ``torch.compile`` is used.
    """

    gemma_rms = "gemma_rms"
    """
    The RMSNorm implementation by Gemma. Probably the fastest implementation when
    ``torch.compile`` is used.
    """

    amd_compatible = "amd_compatible"
    """
    A manually implemented LayerNorm that works around an issue with ROCm.
    """
78
-
79
-
80
class ActivationType(StrEnum):
    """
    Identifiers for the activation functions that can be selected.
    """

    gelu = "gelu"
    relu = "relu"
    silu = "silu"
    swiglu = "swiglu"
85
-
86
-
87
class BlockType(StrEnum):
    """
    Identifiers for the transformer block implementations that can be selected.
    """

    sequential = "sequential"
    parallel = "parallel"

    llama = "llama"
    """
    Much like the sequential block, except that operations such as attention are
    implemented slightly differently in order to imitate the behavior of Llama.
    """
96
-
97
-
98
class InitFnType(StrEnum):
    """
    Identifiers for the weight initialization strategies that can be selected.
    """

    mitchell = "mitchell"
    """
    The strategy suggested to us by Mitchell Wortsman from UW: a truncated normal
    distribution whose standard deviation adapts to both the size of the weights
    and the depth of the layer.
    """

    normal = "normal"
    """
    Every weight is drawn from one and the same normal distribution.
    """

    kaiming_normal = "kaiming_normal"
    """
    Every weight is initialized with the Kaiming method from a normal distribution.
    Note that this currently does not work with FSDP.
    """

    fan_in = "fan_in"
    """
    "Fan-in variance scaling": a normal distribution with a standard deviation of
    ``1/sqrt(d_in)``, where ``d_in`` is the input dimensionality of the kernel.
    """

    full_megatron = "full_megatron"
    """
    What metaseq calls "full megatron init". This is the init used for Llama 2.
    """
127
-
128
-
129
@dataclass
class ModelConfig:
    """
    LLaDA (model) configuration.

    A plain dataclass holding the architecture, regularization, tokenization and
    initialization settings of the model. Each field is documented by the string
    that follows it.
    """

    # Note that the defaults for these attributes are equivalent to the base GPT2 model.

    d_model: int = 768
    """
    The hidden size of the model.
    """

    n_heads: int = 12
    """
    The number of self-attention heads.
    """

    n_kv_heads: Optional[int] = None
    """
    The number of heads to use for keys and values. Defaults to ``n_heads``.
    Set this to ``None`` or ``n_heads`` for normal multi-head attention.
    Set this to 1 for multi-query attention.
    Set it to some in-between value for Llama2-style grouped query attention.
    """

    n_layers: int = 12
    """
    The number of layers/blocks.
    """

    mlp_ratio: int = 4
    """
    The ratio of the inner MLP dimensionality to ``d_model``.
    This is only used when ``mlp_hidden_size`` is not set.
    """

    mlp_hidden_size: Optional[int] = None
    """
    Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be
    set to ``mlp_ratio * d_model``.
    """

    activation_type: ActivationType = ActivationType.swiglu
    """
    The activation function to use within the MLP layers.
    """

    block_type: BlockType = BlockType.sequential
    """
    The transformer block implementation.
    """

    block_group_size: int = 1
    """
    The number of blocks to group together into a single parent block.
    This has no effect on the number of parameters in the model and is only used to wrap
    groups of blocks together with a single FSDP wrapper during training.
    """

    alibi: bool = False
    """
    If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
    """

    alibi_bias_max: float = 8.0
    """
    Maximum absolute value of ALiBi bias.
    """

    rope: bool = False
    """
    Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
    """

    rope_full_precision: bool = True
    """
    If ``True``, apply RoPE embeddings at full precision regardless of the input type.
    Otherwise, apply RoPE at the precision of the input.
    """

    flash_attention: bool = False
    """
    If ``True``, use ``FlashAttention``.
    """

    attention_dropout: float = 0.1
    """
    The dropout probability within the attention modules.
    """

    multi_query_attention: Optional[bool] = None
    """
    Use the Multi-Query formulation of attention used in PaLM. This reduces the number
    of parameters and is more efficient during inference.
    """

    attention_layer_norm: bool = False
    """
    Apply layer norm to the keys and queries within the attention mechanism.
    This can help stabilize training.
    """

    residual_dropout: float = 0.1
    """
    The dropout probability for the MLP and attention output within each block.
    """

    embedding_dropout: float = 0.1
    """
    The dropout probability for embeddings.
    """

    input_emb_norm: bool = False
    """
    An input hidden_states norm implementation by Gemma.
    """

    layer_norm_type: LayerNormType = LayerNormType.default
    """
    The layernorm implementation to use.
    """

    layer_norm_with_affine: bool = True
    """
    Whether to include bias and weight parameters for the layer norms.
    This only affects layer norms that are immediately followed by a linear layer in the
    forward pass, so everything except QK-norms. To turn off affines for QK norms as well,
    set :attr:`attention_layer_norm_with_affine` to ``False``.
    """

    rms_norm_eps: float = 1e-05
    """
    The rms layernorm eps param.
    """

    attention_layer_norm_with_affine: bool = True
    """
    Toggle affine transform for the QK norms.
    """

    max_sequence_length: int = 1024
    """
    The maximum input sequence length supported by the model.
    """

    rope_theta: float = 10000.0
    """
    The rope base param.
    """

    include_qkv_bias: Optional[bool] = False
    """
    Whether or not to include bias parameters in qkv linear layers.
    """

    include_bias: bool = False
    """
    Whether or not to include bias parameters in linear layers.
    In PaLM, they got rid of all bias terms because they found that large
    models tend to have near 0 bias terms anyway.
    """

    bias_for_layer_norm: Optional[bool] = None
    """
    Whether or not to include bias parameters in layer norm.
    This is separate from the include_bias parameter, because of a ROCm crash when biases
    are disabled in layer norm.
    When this is None (the default), it inherits the setting from include_bias.
    """

    scale_logits: bool = False
    """
    If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
    """

    vocab_size: int = 50257
    """
    Vocabulary size of the model.
    """

    embedding_size: Optional[int] = 50304
    """
    The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
    to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
    next multiple of 128 that's greater than ``vocab_size`` can improve throughput
    substantially.
    """

    weight_tying: bool = True
    """
    Whether to tie output linear weights to the input embedding.
    """

    eos_token_id: int = 50256
    """
    The ID of the end-of-sentence special token.
    """

    pad_token_id: int = 50256
    """
    The ID of the token to use for padding. Defaults to the ID of the EOS token.
    """

    mask_token_id: Optional[int] = 50256
    """
    The ID of the token to use for mask token. Defaults to the ID of the EOS token.
    """

    init_device: Optional[str] = None
    """
    The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
    """

    init_fn: InitFnType = InitFnType.normal
    """
    The weight initialization strategy.
    """

    init_std: float = 0.02
    """
    The standard deviation to use when initializing weights with a "fixed distribution"
    ``init_fn``, such as "normal".
    """

    init_cutoff_factor: Optional[float] = None
    """
    A positive factor used to scale the cutoff values when initializing weights with a
    "fixed distribution" ``init_fn``, such as "normal". Setting this to None means values
    are not cut off.
    """

    precision: Optional[str] = None
    """
    Precision used to train/evaluate with. You shouldn't set this directly.
    See :data:`TrainConfig.precision` instead.
    """

    @property
    def effective_n_kv_heads(self) -> int:
        """
        Resolve the number of key/value heads from ``n_kv_heads`` and ``multi_query_attention``.

        Returns:
            * ``n_kv_heads`` unset: 1 if ``multi_query_attention`` is ``True``, else ``n_heads``.
            * ``n_kv_heads`` set and ``multi_query_attention`` unset: ``n_kv_heads``.
            * both set and in agreement: the value implied by ``multi_query_attention``
              (1 when truthy, ``n_heads`` otherwise).

        Raises:
            ValueError: if both settings are given and ``n_kv_heads`` contradicts the
                value implied by ``multi_query_attention``.
        """
        if self.n_kv_heads is None:
            # Only the boolean flag is available to decide.
            return 1 if self.multi_query_attention is True else self.n_heads

        if self.multi_query_attention is None:
            # Only the explicit head count was given.
            return self.n_kv_heads

        # Both were given: they must describe the same layout.
        n_kv_heads_should_be = 1 if self.multi_query_attention else self.n_heads
        if self.n_kv_heads != n_kv_heads_should_be:
            # ValueError is still an ``Exception``, so existing broad handlers keep working.
            raise ValueError(
                "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
            )
        return n_kv_heads_should_be
385
-
386
class ActivationCheckpointingStrategy(StrEnum):
    """
    Identifiers for the activation checkpointing strategies that can be selected.
    """

    whole_layer = "whole_layer"
    """
    Checkpoint every transformer layer.
    """

    one_in_two = "one_in_two"
    """
    Checkpoint one out of every two transformer layers.
    """

    one_in_three = "one_in_three"
    """
    Checkpoint one out of every three transformer layers.
    """

    one_in_four = "one_in_four"
    """
    Checkpoint one out of every four transformer layers.
    """

    two_in_three = "two_in_three"
    """
    Checkpoint two out of every three transformer layers.
    """

    three_in_four = "three_in_four"
    """
    Checkpoint three out of every four transformer layers.
    """

    four_in_five = "four_in_five"
    """
    Checkpoint four out of every five transformer layers.
    """

    nine_in_ten = "nine_in_ten"
    """
    Checkpoint nine out of every ten transformer layers.
    """

    fine_grained = "fine_grained"
    """
    Concentrate checkpointing where recomputation is cheap and the memory saving is largest.
    """
431
-
432
-
433
class LLaDAConfig(PretrainedConfig):
    """
    ``PretrainedConfig`` subclass for LLaDA.

    The constructor starts from the default value of every :class:`ModelConfig` field,
    overlays the caller's keyword arguments, and forwards the result to
    ``PretrainedConfig.__init__``.
    """

    model_type = "llada"
    keys_to_ignore_at_inference = ["past_key_values"]  # TODO: confirm

    def __init__(self, use_cache: bool = False, **kwargs):
        """
        Args:
            use_cache: forwarded to the parent constructor under the ``use_cache`` key.
            **kwargs: overrides for the ``ModelConfig`` defaults, plus any other keys
                to pass through to the parent constructor.
        """
        # ``vars`` returns the instance's own ``__dict__``; the ModelConfig is a
        # throwaway object, so updating that mapping in place affects nothing else.
        merged = vars(ModelConfig())
        merged.update(kwargs)
        merged["use_cache"] = use_cache
        # Only supply the architecture list when the caller did not.
        merged.setdefault("architectures", ["LLaDAModelLM"])
        super().__init__(**merged)

    @property
    def num_attention_heads(self):
        """Alias for ``n_heads``."""
        return self.n_heads

    @property
    def num_hidden_layers(self):
        """Alias for ``n_layers``."""
        return self.n_layers

    @property
    def hidden_size(self):
        """Alias for ``d_model``."""
        return self.d_model
460
-
461
-
462
# Import-time side effect: register LLaDAConfig under the "llada" model type so that
# it is available for transformers pipelines, auto-loading, etc.
AutoConfig.register("llada", LLaDAConfig)