Midstream committed on
Commit
ab7792a
·
verified ·
1 Parent(s): 49403ff

Add smol-difussion-base checkpoint

Browse files
config.json ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_comment_activation": "MLP activation function aligned with ModernBert",
3
+ "_comment_attention": "Attention and RoPE settings for shorter context",
4
+ "_comment_bias": "Bias settings for Linear layers aligned with ModernBert",
5
+ "_comment_dropout": "Dropout rates aligned with ModernBert",
6
+ "_comment_initialization": "Initialization scheme aligned with ModernBert",
7
+ "_comment_misc": "Other settings for decoder-style model",
8
+ "_comment_normalization": "LayerNorm settings aligned with ModernBert",
9
+ "activation_type": "silu",
10
+ "alibi": false,
11
+ "alibi_bias_max": 8.0,
12
+ "architectures": [
13
+ "LLaDAModelLM"
14
+ ],
15
+ "attention_bias": false,
16
+ "attention_dropout": 0.0,
17
+ "attention_layer_norm": false,
18
+ "attention_layer_norm_with_affine": true,
19
+ "auto_map": {
20
+ "AutoConfig": "configuration_llada.LLaDAConfig",
21
+ "AutoModel": "modeling_llada.LLaDAModelLM"
22
+ },
23
+ "bias_for_layer_norm": false,
24
+ "block_group_size": 1,
25
+ "block_type": "llama",
26
+ "bos_token_id": 0,
27
+ "d_model": 576,
28
+ "dtype": "float32",
29
+ "embedding_dropout": 0.0,
30
+ "embedding_size": 50368,
31
+ "eos_token_id": 0,
32
+ "flash_attention": false,
33
+ "include_bias": false,
34
+ "include_qkv_bias": false,
35
+ "init_cutoff_factor": 2.0,
36
+ "init_device": "meta",
37
+ "init_fn": "full_megatron",
38
+ "init_std": 0.02,
39
+ "input_emb_norm": false,
40
+ "is_llama_config": true,
41
+ "layer_norm_type": "default",
42
+ "layer_norm_with_affine": true,
43
+ "loss_normalization": "masked_tokens",
44
+ "mask_token_id": 50256,
45
+ "max_position_embeddings": 8192,
46
+ "max_sequence_length": 2048,
47
+ "mlp_hidden_size": 1536,
48
+ "mlp_ratio": 4,
49
+ "model_type": "llada",
50
+ "multi_query_attention": null,
51
+ "n_heads": 9,
52
+ "n_kv_heads": 3,
53
+ "n_layers": 30,
54
+ "pad_token_id": 50283,
55
+ "precision": null,
56
+ "pretraining_tp": 1,
57
+ "residual_dropout": 0.0,
58
+ "rms_norm_eps": 1e-05,
59
+ "rope": true,
60
+ "rope_full_precision": true,
61
+ "rope_interleaved": false,
62
+ "rope_scaling": null,
63
+ "rope_theta": 100000.0,
64
+ "scale_logits": false,
65
+ "transformers_version": "4.57.1",
66
+ "use_cache": false,
67
+ "vocab_size": 50368,
68
+ "weight_tying": true
69
+ }
configuration_llada.py ADDED
@@ -0,0 +1,470 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLaDA configuration
3
+ """
4
+ from transformers import AutoConfig, PretrainedConfig
5
+
6
+ from enum import Enum
7
+ from os import PathLike
8
+ from typing import Union
9
+ from dataclasses import asdict, dataclass, field
10
+ from glob import glob
11
+ from pathlib import Path
12
+ from typing import (
13
+ Any,
14
+ Dict,
15
+ Iterable,
16
+ List,
17
+ Optional,
18
+ Tuple,
19
+ Type,
20
+ TypeVar,
21
+ Union,
22
+ cast,
23
+ )
24
+
25
+
26
+ __all__ = [
27
+ "ActivationType",
28
+ "ActivationCheckpointingStrategy",
29
+ "BlockType",
30
+ "LayerNormType",
31
+ "InitFnType",
32
+ "ModelConfig",
33
+ ]
34
+
35
+ PathOrStr = Union[str, PathLike]
36
+
37
+
38
class StrEnum(str, Enum):
    """
    String-valued enum base class, equivalent to :class:`enum.StrEnum` from Python 3.11 onwards.
    It is defined locally so that the module also works on older versions of Python.
    """

    def __repr__(self) -> str:
        # The repr is the plain string form wrapped in single quotes.
        return "'" + str(self) + "'"

    def __str__(self) -> str:
        # Members print as their raw string value rather than ``ClassName.member``.
        return self.value
49
+
50
+
51
class LayerNormType(StrEnum):
    """Names of the available layer-norm implementations."""

    # The default LayerNorm implementation, equivalent to PyTorch's built-in version.
    default = "default"

    # A low-precision version of the default LayerNorm.
    low_precision = "low_precision"

    # An RMSNorm implementation. When using ``torch.compile`` this is probably the
    # fastest implementation.
    rms = "rms"

    # An RMSNorm implementation in the style of Gemma. When using ``torch.compile`` this is
    # probably the fastest implementation.
    gemma_rms = "gemma_rms"

    # LayerNorm implemented manually to work around an issue with ROCm.
    amd_compatible = "amd_compatible"
78
+
79
+
80
class ActivationType(StrEnum):
    """Names of the supported activation functions."""

    gelu = "gelu"
    relu = "relu"
    silu = "silu"
    swiglu = "swiglu"
85
+
86
+
87
class BlockType(StrEnum):
    """Names of the available transformer block implementations."""

    sequential = "sequential"
    parallel = "parallel"

    # A block similar to the sequential block, with slightly different implementations of
    # operations like attention so that it imitates the behavior of Llama.
    llama = "llama"
96
+
97
+
98
class InitFnType(StrEnum):
    """Names of the available weight-initialization strategies."""

    # The strategy suggested by Mitchell Wortsman from UW: a truncated normal distribution with
    # an adaptive standard deviation that depends on the size of the weights as well as the
    # depth of the layer.
    mitchell = "mitchell"

    # All weights are initialized from the same normal distribution.
    normal = "normal"

    # All weights are initialized with the Kaiming method from a normal distribution.
    # Note this currently won't work with FSDP.
    kaiming_normal = "kaiming_normal"

    # "Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)``
    # where ``d_in`` is the input dimensionality of the kernel.
    fan_in = "fan_in"

    # What metaseq calls "full megatron init". It is the init used for Llama 2.
    full_megatron = "full_megatron"
127
+
128
+
129
@dataclass
class ModelConfig:
    """
    LLaDA (model) configuration.
    """

    # Note that the defaults for these attributes are equivalent to the base GPT2 model.

    # The hidden size of the model.
    d_model: int = 768

    # The number of self-attention heads.
    n_heads: int = 12

    # The number of heads to use for keys and values. ``None`` (the default) or ``n_heads`` gives
    # normal multi-head attention, 1 gives multi-query attention, and an in-between value gives
    # Llama2-style grouped query attention.
    n_kv_heads: Optional[int] = None

    # The number of layers/blocks.
    n_layers: int = 12

    # The ratio of the inner MLP dimensionality to ``d_model``.
    # Only used when ``mlp_hidden_size`` is not set.
    mlp_ratio: int = 4

    # The exact hidden size for the MLP. When unset, the inner MLP hidden size is
    # ``mlp_ratio * d_model``.
    mlp_hidden_size: Optional[int] = None

    # The activation function to use within the MLP layers.
    activation_type: ActivationType = ActivationType.swiglu

    # The transformer block implementation.
    block_type: BlockType = BlockType.sequential

    # The number of blocks to group together into a single parent block.
    # This has no effect on the number of parameters in the model and is only used to wrap
    # groups of blocks together with a single FSDP wrapper during training.
    block_group_size: int = 1

    # If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
    alibi: bool = False

    # Maximum absolute value of ALiBi bias.
    alibi_bias_max: float = 8.0

    # Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
    rope: bool = False

    # If ``True``, apply RoPE embeddings at full precision regardless of the input type.
    # Otherwise, apply RoPE at the precision of the input.
    rope_full_precision: bool = True

    # If ``True``, use ``FlashAttention``.
    flash_attention: bool = False

    # The dropout probability within the attention modules.
    attention_dropout: float = 0.1

    # Use the Multi-Query formulation of attention used in PaLM. This reduces the number of
    # parameters and is more efficient during inference.
    multi_query_attention: Optional[bool] = None

    # Apply layer norm to the keys and queries within the attention mechanism.
    # This can help stabilize training.
    attention_layer_norm: bool = False

    # The dropout probability for the MLP and attention output within each block.
    residual_dropout: float = 0.1

    # The dropout probability for embeddings.
    embedding_dropout: float = 0.1

    # An input hidden_states norm implementation in the style of Gemma.
    input_emb_norm: bool = False

    # The layernorm implementation to use.
    layer_norm_type: LayerNormType = LayerNormType.default

    # Whether to include bias and weight parameters for the layer norms.
    # This only affects layer norms that are immediately followed by a linear layer in the
    # forward pass, so everything except QK-norms. To turn off affines for QK norms as well,
    # set :attr:`attention_layer_norm_with_affine` to ``False``.
    layer_norm_with_affine: bool = True

    # The eps parameter of the RMS layer norm.
    rms_norm_eps: float = 1e-05

    # Toggle affine transform for the QK norms.
    attention_layer_norm_with_affine: bool = True

    # The maximum input sequence length supported by the model.
    max_sequence_length: int = 1024

    # The RoPE base parameter.
    rope_theta: float = 10000.0

    # Whether or not to include bias parameters in qkv linear layers.
    include_qkv_bias: Optional[bool] = False

    # Whether or not to include bias parameters in linear layers.
    # In PaLM, they got rid of all bias terms because they found that large
    # models tend to have near 0 bias terms anyway.
    include_bias: bool = False

    # Whether or not to include bias parameters in layer norm.
    # This is separate from ``include_bias`` because of a ROCm crash when biases are disabled
    # in layer norm. When this is ``None`` (the default), it inherits the setting from
    # ``include_bias``.
    bias_for_layer_norm: Optional[bool] = None

    # If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
    scale_logits: bool = False

    # Vocabulary size of the model.
    vocab_size: int = 50257

    # The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
    # to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the next
    # multiple of 128 that's greater than ``vocab_size`` can improve throughput substantially.
    embedding_size: Optional[int] = 50304

    # The normalization method for the loss.
    #   "masked_tokens": normalize by the number of masked tokens (default, old behavior).
    #   "total_tokens": normalize by the total number of tokens in the batch (official behavior).
    loss_normalization: str = "masked_tokens"

    # Whether to tie output linear weights to the input embedding.
    weight_tying: bool = True

    # The ID of the end-of-sentence special token.
    eos_token_id: int = 50256

    # The ID of the token to use for padding. Defaults to the ID of the EOS token.
    pad_token_id: int = 50256

    # The ID of the token to use as the mask token. Defaults to the ID of the EOS token.
    mask_token_id: Optional[int] = 50256

    # The torch device to use when initializing the model parameters,
    # e.g. "cpu", "cuda:0", "meta".
    init_device: Optional[str] = None

    # The weight initialization strategy.
    init_fn: InitFnType = InitFnType.normal

    # The standard deviation to use when initializing weights with a "fixed distribution"
    # ``init_fn``, such as "normal".
    init_std: float = 0.02

    # A positive factor used to scale the cutoff values when initializing weights with a
    # "fixed distribution" ``init_fn``, such as "normal". ``None`` means values are not cut off.
    init_cutoff_factor: Optional[float] = None

    # Precision used to train/evaluate with. You shouldn't set this directly.
    # See :data:`TrainConfig.precision` instead.
    precision: Optional[str] = None

    @property
    def effective_n_kv_heads(self) -> int:
        """
        Resolve the number of key/value heads from ``n_kv_heads`` and ``multi_query_attention``.

        :raises Exception: If both settings are given and they disagree with each other.
        """
        mqa = self.multi_query_attention

        # Only the legacy flag (or nothing) is set: derive the head count from it.
        if self.n_kv_heads is None:
            return 1 if mqa is True else self.n_heads

        # Only ``n_kv_heads`` is set: use it as-is.
        if mqa is None:
            return self.n_kv_heads

        # Both are set: they must describe the same layout.
        expected = 1 if mqa else self.n_heads
        if self.n_kv_heads != expected:
            raise Exception(
                "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
            )
        return expected
392
+
393
class ActivationCheckpointingStrategy(StrEnum):
    """Names of the available activation-checkpointing strategies."""

    # Checkpoint every transformer layer.
    whole_layer = "whole_layer"

    # Checkpoint one in two transformer layers.
    one_in_two = "one_in_two"

    # Checkpoint one in three transformer layers.
    one_in_three = "one_in_three"

    # Checkpoint one in four transformer layers.
    one_in_four = "one_in_four"

    # Checkpoint two out of every three transformer layers.
    two_in_three = "two_in_three"

    # Checkpoint three out of every four transformer layers.
    three_in_four = "three_in_four"

    # Checkpoint four out of every five transformer layers.
    four_in_five = "four_in_five"

    # Checkpoint nine out of every ten transformer layers.
    nine_in_ten = "nine_in_ten"

    # Focus checkpointing on where it is cheap to recompute and saves most memory.
    fine_grained = "fine_grained"
438
+
439
+
440
class LLaDAConfig(PretrainedConfig):
    """
    Hugging Face config wrapper for LLaDA.

    The keyword arguments handed to ``PretrainedConfig`` start from the field defaults of
    :class:`ModelConfig` and are then overridden by whatever the caller passes in.
    """

    model_type = "llada"
    keys_to_ignore_at_inference = ["past_key_values"]  # TODO: confirm

    def __init__(self, use_cache: bool = False, **kwargs):
        # Start from the ModelConfig defaults, then layer the caller's overrides on top.
        merged = vars(ModelConfig())
        merged.update(kwargs)
        merged["use_cache"] = use_cache
        # Only fill in the architecture name when the caller did not provide one.
        merged.setdefault("architectures", ["LLaDAModelLM"])
        super().__init__(**merged)

    # The properties below expose LLaDA's own field names under alternative attribute names
    # (presumably the ones generic Hugging Face code looks up -- confirm against callers).

    @property
    def num_attention_heads(self):
        return self.n_heads

    @property
    def num_hidden_layers(self):
        return self.n_layers

    @property
    def hidden_size(self):
        return self.d_model
467
+
468
+
469
# Register the config class under the "llada" model type so that it is available for
# transformers pipelines, auto-loading, etc. This runs as a side effect of importing the module.
AutoConfig.register("llada", LLaDAConfig)
generation_config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 0,
4
+ "eos_token_id": 0,
5
+ "pad_token_id": 50283,
6
+ "transformers_version": "4.57.1",
7
+ "use_cache": false
8
+ }
modeling_llada.py ADDED
@@ -0,0 +1,1890 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import math
5
+ import sys
6
+ from abc import abstractmethod
7
+ from collections import defaultdict
8
+ from functools import partial
9
+ from typing import (
10
+ Callable,
11
+ Dict,
12
+ Iterable,
13
+ List,
14
+ NamedTuple,
15
+ Optional,
16
+ Sequence,
17
+ Set,
18
+ Tuple,
19
+ cast,
20
+ )
21
+ from dataclasses import dataclass, fields
22
+ from typing import List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ import torch.backends.cuda
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ from torch import einsum
29
+ from transformers import PreTrainedModel
30
+ from transformers.modeling_outputs import CausalLMOutputWithPast
31
+ from transformers.models.auto import AutoModel
32
+ from transformers.cache_utils import Cache
33
+
34
+ from .configuration_llada import (
35
+ LLaDAConfig,
36
+ StrEnum,
37
+ InitFnType,
38
+ ActivationType,
39
+ BlockType,
40
+ LayerNormType,
41
+ ModelConfig,
42
+ ActivationCheckpointingStrategy,
43
+ )
44
+
45
# Pick the import location of ``MutableMapping`` based on the interpreter version.
# The whole (major, minor) tuple is compared rather than only ``minor``, so that the check does
# not depend on the major version being 3. From 3.9 on, ``collections.abc.MutableMapping`` can be
# subscripted (as ``BufferCache`` does below); on 3.8 only the ``typing`` alias supports that.
if sys.version_info >= (3, 9):
    from collections.abc import MutableMapping
elif sys.version_info >= (3, 8):
    from typing import MutableMapping
else:
    raise SystemExit("This script supports Python 3.8 or higher")
51
+
52
+ __all__ = [
53
+ "LayerNormBase",
54
+ "LayerNorm",
55
+ "RMSLayerNorm",
56
+ "GemmaRMSLayerNorm",
57
+ "RotaryEmbedding",
58
+ "Activation",
59
+ "GELU",
60
+ "ReLU",
61
+ "SwiGLU",
62
+ "LLaDABlock",
63
+ "LLaDASequentialBlock",
64
+ "LLaDAModel",
65
+ "LLaDAOutput",
66
+ "LLaDAGenerateOutput",
67
+ ]
68
+
69
+
70
# Module-level logger. ``log`` and ``logger`` are both obtained via ``getLogger(__name__)``,
# i.e. they are two names for the logger of this module.
log = logging.getLogger(__name__)


logger = logging.getLogger(__name__)
# NOTE: this sets the level at import time, as a side effect of loading the module.
logger.setLevel(logging.INFO)


# Attach a stream handler with a timestamped format, but only when ``hasHandlers()`` reports
# that no handler is configured yet, so that an existing logging setup is not duplicated.
if not logger.hasHandlers():
    handler = logging.StreamHandler()
    formatter = logging.Formatter('[%(asctime)s][%(levelname)s][%(name)s] %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
82
+
83
class ModuleType(StrEnum):
    """Role of a module within the network, used by the "full_megatron" init to pick a std."""

    # Input projections: att_proj (same as QKV) and ff_proj.
    in_module = "in"
    # Output projections: attn_out and ff_out.
    out_module = "out"
    # Positional (wpe) and token (wte) embeddings.
    emb = "emb"
    # The final output projection (ff_out).
    final_out = "final_out"
88
+
89
+
90
def init_weights(
    config: ModelConfig,
    module: Union[nn.Linear, nn.Embedding],
    d: Optional[int] = None,
    layer_id: Optional[int] = None,
    std_factor: float = 1.0,
    type_of_module: Optional[ModuleType] = None,
) -> None:
    """
    Initialize the weights of a linear or embedding module in place, following ``config.init_fn``.

    :param config: The model config.
    :param module: The linear or embedding submodule to initialize.
    :param d: The effective input dimensionality of the weights. This could be smaller than the
        actual dimensions for fused layers. Falls back to ``config.d_model`` when ``None``; only
        the "mitchell" and "fan_in" methods use it.
    :param layer_id: When set, the standard deviation for the "mitchell" method will be adjusted by
        ``1 / sqrt(2 * (layer_id + 1))``.
    :param std_factor: Multiplier applied to the standard deviation by the "normal", "mitchell"
        and "fan_in" methods.
    :param type_of_module: The role of the module; required by the "full_megatron" method.
    :raises RuntimeError: If "full_megatron" is used without a (known) ``type_of_module``.
    :raises NotImplementedError: If ``config.init_fn`` is not one of the handled methods.
    """
    fan_in_dim = config.d_model if d is None else d

    if config.init_fn == InitFnType.normal:
        std = config.init_std * std_factor
        if config.init_cutoff_factor is None:
            nn.init.normal_(module.weight, mean=0.0, std=std)
        else:
            bound = config.init_cutoff_factor * std
            nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-bound, b=bound)
    elif config.init_fn == InitFnType.mitchell:
        std = std_factor / math.sqrt(fan_in_dim)
        if layer_id is not None:
            # Deeper layers get a smaller standard deviation.
            std = std / math.sqrt(2 * (layer_id + 1))
        nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
    elif config.init_fn == InitFnType.kaiming_normal:
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
    elif config.init_fn == InitFnType.fan_in:
        std = std_factor / math.sqrt(fan_in_dim)
        nn.init.normal_(module.weight, mean=0.0, std=std)
    elif config.init_fn == InitFnType.full_megatron:
        if type_of_module is None:
            raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.")

        # Truncate at 3 standard deviations unless the config says otherwise.
        cutoff = 3 if config.init_cutoff_factor is None else config.init_cutoff_factor

        if type_of_module == ModuleType.in_module or type_of_module == ModuleType.emb:
            # att_proj (same as QKV), ff_proj, positional (wpe) and token (wte) embeddings
            std = config.init_std
        elif type_of_module == ModuleType.out_module:
            # attn_out, ff_out
            std = config.init_std / math.sqrt(2.0 * config.n_layers)
        elif type_of_module == ModuleType.final_out:
            # final output (ff_out)
            std = config.d_model**-0.5
        else:
            raise RuntimeError(f"Unknown module type '{type_of_module}'")

        bound = cutoff * std
        nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-bound, b=bound)
    else:
        raise NotImplementedError(config.init_fn)

    if not isinstance(module, nn.Linear):
        return

    # Linear layers: biases always start at zero.
    if module.bias is not None:
        nn.init.zeros_(module.bias)

    # With the "normal" init, layers flagged as residual are additionally scaled down by
    # ``sqrt(2 * n_layers)``.
    if config.init_fn == InitFnType.normal and getattr(module, "_is_residual", False):
        with torch.no_grad():
            module.weight.div_(math.sqrt(2 * config.n_layers))
168
+
169
+
170
def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
    """
    Replace infinities in ``x`` in place with the extreme finite values of its dtype.

    :param x: The tensor to modify in place.
    :param check_neg_inf: If ``True``, replace ``float("-inf")`` with the minimum value of the dtype.
    :param check_pos_inf: If ``True``, replace ``float("inf")`` with the maximum value of the dtype.
    """
    if check_neg_inf:
        neg_inf_mask = x == float("-inf")
        x.masked_fill_(neg_inf_mask, torch.finfo(x.dtype).min)
    if check_pos_inf:
        pos_inf_mask = x == float("inf")
        x.masked_fill_(pos_inf_mask, torch.finfo(x.dtype).max)
179
+
180
+
181
def activation_checkpoint_function(cfg: ModelConfig):
    """
    Build the activation-checkpointing callable for the given model config.

    :param cfg: The model config; only its three dropout probabilities are read.
    :returns: ``torch.utils.checkpoint.checkpoint`` pre-bound with ``use_reentrant=False`` and a
        ``preserve_rng_state`` flag derived from the dropout settings.
    """
    # The RNG state has to be stashed and restored exactly when the checkpointed region is
    # stochastic, i.e. when some dropout is active; otherwise the recomputed forward pass would
    # draw different dropout masks than the original one. When all dropout is disabled the
    # stash/restore can be skipped.
    uses_dropout = (
        (cfg.attention_dropout != 0.0) or (cfg.embedding_dropout != 0.0) or (cfg.residual_dropout != 0.0)
    )
    from torch.utils.checkpoint import checkpoint

    return partial(
        checkpoint,
        preserve_rng_state=uses_dropout,
        use_reentrant=False,
    )
192
+
193
+
194
class BufferCache(dict, MutableMapping[str, torch.Tensor]):
    """
    A plain dict-based cache for attention biases and other tensors that would normally be
    stored as module buffers.

    Buffers are avoided because of various issues encountered with FSDP, whose handling of
    buffers does not appear to be well-defined: it doesn't shard them, but apparently it does
    synchronize them across processes. That is unwanted since (A) it isn't necessary, and
    (B) these biases sometimes contain `-inf`, which might get turned into NaNs when they're
    synchronized due to casting or some other issue.
    """
203
+
204
+
205
+ def _non_meta_init_device(config: ModelConfig) -> torch.device:
206
+ if config.init_device is not None and config.init_device != "meta":
207
+ return torch.device(config.init_device)
208
+ else:
209
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
210
+
211
+
212
class Dropout(nn.Dropout):
    """An ``nn.Dropout`` that returns its input untouched when the drop probability is zero."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.p != 0.0:
            return F.dropout(input, self.p, self.training, self.inplace)
        # Dropout disabled: skip the functional call and hand back the very same tensor.
        return input
218
+
219
+
220
class LayerNormBase(nn.Module):
    """
    Common base for the layer-norm variants: owns the optional ``weight``/``bias`` parameters
    and provides the :meth:`build` factory that selects a variant from the config.
    """

    def __init__(
        self,
        config: ModelConfig,
        *,
        size: Optional[int] = None,
        elementwise_affine: Optional[bool] = True,
        eps: float = 1e-05,
    ):
        """
        :param config: The model config.
        :param size: The normalized dimension; falls back to ``config.d_model`` when not given.
        :param elementwise_affine: Whether to create affine parameters. ``None`` defers to
            ``config.layer_norm_with_affine``.
        :param eps: The epsilon stored on the module as ``self.eps``.
        """
        super().__init__()
        self.config = config
        self.eps = eps
        self.normalized_shape = (size or config.d_model,)

        affine = elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine)
        if not affine:
            # No affine transform at all: register both parameters as absent.
            self.register_parameter("bias", None)
            self.register_parameter("weight", None)
            return

        self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device))

        # The layer-norm bias setting inherits from ``include_bias`` when left unset.
        with_bias = self.config.bias_for_layer_norm
        if with_bias is None:
            with_bias = self.config.include_bias

        if with_bias:
            self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device))
        else:
            self.register_parameter("bias", None)

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @classmethod
    def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> LayerNormBase:
        """
        Instantiate the layer-norm variant named by ``config.layer_norm_type``.

        :raises NotImplementedError: If the configured type is not one of the handled variants.
        """
        norm_type = config.layer_norm_type
        if norm_type == LayerNormType.default:
            return LayerNorm(config, size=size, low_precision=False, **kwargs)
        if norm_type == LayerNormType.low_precision:
            return LayerNorm(config, size=size, low_precision=True, **kwargs)
        if norm_type == LayerNormType.rms:
            return RMSLayerNorm(config, size=size, **kwargs)
        if norm_type == LayerNormType.gemma_rms:
            return GemmaRMSLayerNorm(config, size=size, **kwargs)
        raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'")

    def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        """
        Cast ``tensor`` to ``dtype`` (or the active autocast dtype) if autocast is enabled for
        the tensor's device type; otherwise return it unchanged.
        """
        # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
        # `is_autocast_cpu_enabled()` for CPU autocast.
        # See https://github.com/pytorch/pytorch/issues/110966.
        device_type = tensor.device.type
        if device_type == "cuda" and torch.is_autocast_enabled():
            target = torch.get_autocast_gpu_dtype() if dtype is None else dtype
            return tensor.to(dtype=target)
        if device_type == "cpu" and torch.is_autocast_cpu_enabled():
            target = torch.get_autocast_cpu_dtype() if dtype is None else dtype
            return tensor.to(dtype=target)
        return tensor

    def reset_parameters(self):
        """Reset ``weight`` to ones and ``bias`` to zeros, for whichever of them exist."""
        if self.weight is not None:
            torch.nn.init.ones_(self.weight)  # type: ignore
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)  # type: ignore
279
+
280
+
281
class LayerNorm(LayerNormBase):
    """
    The standard layer norm built on ``F.layer_norm``, with an optional low-precision mode.
    """

    def __init__(
        self,
        config: ModelConfig,
        size: Optional[int] = None,
        low_precision: bool = False,
        elementwise_affine: Optional[bool] = None,
        eps: float = 1e-05,
    ):
        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
        self.low_precision = low_precision

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.low_precision:
            return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)

        # Low-precision path: cast the input and parameters to the autocast dtype (when autocast
        # is active), then run the norm with autocast switched off.
        device = x.device
        x_cast = self._cast_if_autocast_enabled(x)
        weight_cast = self.weight
        if weight_cast is not None:
            weight_cast = self._cast_if_autocast_enabled(weight_cast)
        bias_cast = self.bias
        if bias_cast is not None:
            bias_cast = self._cast_if_autocast_enabled(bias_cast)
        with torch.autocast(enabled=False, device_type=device.type):
            return F.layer_norm(
                x_cast, self.normalized_shape, weight=weight_cast, bias=bias_cast, eps=self.eps
            )
311
+
312
+
313
class RMSLayerNorm(LayerNormBase):
    """
    Root-mean-square layer norm: the input is divided by the RMS of its last dimension,
    then optionally scaled by ``weight`` and shifted by ``bias``.
    """

    def __init__(
        self,
        config: ModelConfig,
        size: Optional[int] = None,
        elementwise_affine: Optional[bool] = None,
        eps: float = 1e-5,
    ):
        # NOTE: the ``eps`` argument is accepted for signature compatibility only; the epsilon
        # actually used always comes from ``config.rms_norm_eps``.
        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` in float32 and return the result in the dtype ``x`` arrived in."""
        with torch.autocast(enabled=False, device_type=x.device.type):
            input_dtype = x.dtype
            x_fp32 = x.to(torch.float32)
            mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
            normed = (x_fp32 * torch.rsqrt(mean_square + self.eps)).to(input_dtype)

        if self.weight is None:
            return normed
        if self.bias is None:
            return self.weight * normed
        return self.weight * normed + self.bias
342
+
343
+
344
class GemmaRMSLayerNorm(LayerNormBase):
    """
    Gemma-style RMS layer norm: identical to RMS normalization except that the learned scale
    is applied as ``(1 + weight)`` instead of ``weight``.
    """

    def __init__(
        self,
        config: ModelConfig,
        size: Optional[int] = None,
        elementwise_affine: Optional[bool] = None,
        eps: float = 1e-5,
    ):
        # NOTE: the ``eps`` argument is accepted for signature compatibility only; the epsilon
        # actually used always comes from ``config.rms_norm_eps``.
        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` in float32 and return the result in the dtype ``x`` arrived in."""
        with torch.autocast(enabled=False, device_type=x.device.type):
            input_dtype = x.dtype
            x_fp32 = x.to(torch.float32)
            mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
            normed = (x_fp32 * torch.rsqrt(mean_square + self.eps)).to(input_dtype)

        if self.weight is None:
            return normed
        if self.bias is None:
            return normed * (1 + self.weight)
        return normed * (1 + self.weight) + self.bias
373
+
374
+
375
class RotaryEmbedding(nn.Module):
    """
    [Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864).

    The sin/cos tables are stored in the shared cache under the keys ``"rope_pos_sin"`` and
    ``"rope_pos_cos"`` and are rebuilt whenever a longer sequence is requested.
    """

    def __init__(self, config: ModelConfig, cache: BufferCache):
        super().__init__()
        self.config = config
        self.__cache = cache
        # `rope_theta` must be set before the warm-up call below, which reads it.
        self.rope_theta = config.rope_theta
        # Warm up cache.
        self.get_rotary_embedding(config.max_sequence_length, _non_meta_init_device(config))

    def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(pos_sin, pos_cos)`` covering ``seq_len`` positions on ``device``."""
        cache = self.__cache
        pos_sin = cache.get("rope_pos_sin")
        pos_cos = cache.get("rope_pos_cos")
        cache_is_usable = (
            pos_sin is not None
            and pos_cos is not None
            and pos_sin.shape[-2] >= seq_len
            and pos_cos.shape[-2] >= seq_len
        )
        if cache_is_usable:
            # Move the cached tables to the requested device (and keep them there).
            if pos_sin.device != device:
                pos_sin = pos_sin.to(device)
                cache["rope_pos_sin"] = pos_sin
            if pos_cos.device != device:
                pos_cos = pos_cos.to(device)
                cache["rope_pos_cos"] = pos_cos
            return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]

        # Cache miss (or cached tables too short): rebuild with autocast disabled.
        with torch.autocast(device.type, enabled=False):
            head_dim = self.config.d_model // self.config.n_heads
            exponents = torch.arange(0, head_dim, 2, device=device, dtype=torch.float) / head_dim
            inv_freq = 1.0 / (self.rope_theta ** exponents)
            positions_1d = torch.arange(seq_len, device=device, dtype=torch.float)
            freqs = einsum("i , j -> i j", positions_1d, inv_freq)
            angles = torch.cat((freqs, freqs), dim=-1)
            # shape: (1, 1, seq_len, head_dim)
            pos_sin = angles.sin()[None, None, :, :]
            pos_cos = angles.cos()[None, None, :, :]
        cache["rope_pos_sin"] = pos_sin
        cache["rope_pos_cos"] = pos_cos
        return pos_sin, pos_cos

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last dimension of a 4-D tensor into halves ``(a, b)`` and return ``(-b, a)``."""
        B, nh, T, hs = x.size()
        first_half, second_half = x.view(B, nh, T, 2, hs // 2).unbind(dim=-2)
        return torch.cat((-second_half, first_half), dim=-1)

    def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Rotate ``t`` with the given sin/cos tables, returning a tensor of ``t``'s dtype."""
        rotated = (t * pos_cos) + (self.rotate_half(t) * pos_sin)
        return rotated.to(t.dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to queries and keys; outputs keep the dtypes of ``q`` and ``k``."""
        q_, k_ = (q.float(), k.float()) if self.config.rope_full_precision else (q, k)

        with torch.autocast(q.device.type, enabled=False):
            # Lengths could be different if layer_past is not None.
            query_len, key_len = q_.shape[-2], k_.shape[-2]
            pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device)
            pos_sin = pos_sin.type_as(q_)
            pos_cos = pos_cos.type_as(q_)
            # Queries occupy the last `query_len` positions of the key range.
            q_start = key_len - query_len
            q_ = self.apply_rotary_pos_emb(
                pos_sin[:, :, q_start:key_len, :],
                pos_cos[:, :, q_start:key_len, :],
                q_,
            )
            k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_)
        return q_.type_as(q), k_.type_as(k)
441
+
442
+
443
class Activation(nn.Module):
    """
    Base class for MLP activations.

    Subclasses implement ``forward`` and ``output_multiplier``; ``build`` picks the concrete
    activation for ``config.activation_type``.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @property
    @abstractmethod
    def output_multiplier(self) -> float:
        """Ratio of the activation's output width to its input width."""
        raise NotImplementedError

    @classmethod
    def build(cls, config: ModelConfig) -> Activation:
        """Instantiate the activation selected by ``config.activation_type``."""
        kind = config.activation_type
        if kind == ActivationType.gelu:
            return cast(Activation, GELU(approximate="none"))
        if kind == ActivationType.relu:
            return cast(Activation, ReLU(inplace=False))
        if kind == ActivationType.silu:
            return cast(Activation, SiLU(inplace=False))
        if kind == ActivationType.swiglu:
            return SwiGLU(config)
        raise NotImplementedError(f"Unknown activation: '{config.activation_type}'")
469
+
470
+
471
class GELU(nn.GELU):
    """``nn.GELU`` extended with the ``output_multiplier`` property."""

    @property
    def output_multiplier(self) -> float:
        # Element-wise activation: output width equals input width.
        return 1.0
475
+
476
+
477
class ReLU(nn.ReLU):
    """``nn.ReLU`` extended with the ``output_multiplier`` property."""

    @property
    def output_multiplier(self) -> float:
        # Element-wise activation: output width equals input width.
        return 1.0
481
+
482
class SiLU(nn.SiLU):
    """``nn.SiLU`` extended with the ``output_multiplier`` property."""

    @property
    def output_multiplier(self) -> float:
        # Element-wise activation: output width equals input width.
        return 1.0
486
+
487
class SwiGLU(Activation):
    """Gated activation: the input is split in two along the last dim and one half gates the other via SiLU."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        value, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * value

    @property
    def output_multiplier(self) -> float:
        # Half of the input features are consumed as the gate.
        return 0.5
495
+
496
+
497
def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
    """
    Build an additive causal bias of shape ``(1, 1, seq_len, seq_len)``.

    Entries strictly above the diagonal hold the most negative finite float32 value;
    all other entries are zero.
    """
    most_negative = torch.finfo(torch.float).min
    # Fill everything with the mask value, then keep only the strict upper triangle.
    bias = torch.full((seq_len, seq_len), most_negative, device=device, dtype=torch.float).triu(diagonal=1)
    return bias.view(1, 1, seq_len, seq_len)  # type: ignore
504
+
505
+
506
def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor:
    """
    Return the causal attention bias stored in ``cache`` under ``"causal_attention_bias"``.

    A cached bias is reused when its last dimension is at least ``seq_len`` (it is moved to
    ``device`` if necessary and is NOT sliced down to ``seq_len``). Otherwise a new bias of
    exactly ``seq_len`` is built and cached.
    """
    cached = cache.get("causal_attention_bias")
    if cached is None or cached.shape[-1] < seq_len:
        # Nothing usable in the cache: build a fresh bias with autocast disabled.
        with torch.autocast(device.type, enabled=False):
            fresh = causal_attention_bias(seq_len, device)
        cache["causal_attention_bias"] = fresh
        return fresh

    if cached.device != device:
        cached = cached.to(device)
        cache["causal_attention_bias"] = cached
    return cached
516
+
517
+
518
def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
    """
    Build the ALiBi bias of shape ``(1, n_heads, seq_len, seq_len)``.

    Each entry is ``-|i - j|`` scaled by a per-head slope ``1 / 2**m`` where
    ``m = head_index * alibi_bias_max / n_heads`` for ``head_index`` in ``1..n_heads``.
    """
    n_heads = config.n_heads

    # shape: (1, 1, 1, seq_len)
    key_positions = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)
    # shape: (1, 1, seq_len, 1)
    query_positions = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)

    # shape: (1, 1, seq_len, seq_len) -- negative absolute distance between positions.
    distance = key_positions - query_positions
    distance.abs_().mul_(-1)

    # shape: (n_heads,)
    exponents = torch.arange(1, n_heads + 1, dtype=torch.float, device=device)
    exponents.mul_(config.alibi_bias_max / n_heads)

    # shape: (1, n_heads, seq_len, seq_len)
    slopes = 1.0 / (2 ** exponents.view(1, n_heads, 1, 1))
    return distance * slopes  # type: ignore
531
+
532
+
533
class LLaDABlock(nn.Module):
    """
    A base class for transformer block implementations.

    It owns what every block variant shares: residual dropout, optional q/k layer norms, the
    activation, the attention and feed-forward output projections, optional rotary embeddings
    and the attention routine. Subclasses add the input projections and implement ``forward``.
    """

    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
        super().__init__()
        self.layer_id = layer_id
        self.config = config
        # Width of the MLP hidden layer: explicit value if configured, else `mlp_ratio * d_model`.
        self.hidden_size = (
            config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
        )
        self.__cache = cache
        assert config.d_model % config.n_heads == 0

        self._activation_checkpoint_fn = None

        # Dropout.
        self.dropout = Dropout(config.residual_dropout)

        # Layer norms.
        self.k_norm: Optional[LayerNormBase] = None
        self.q_norm: Optional[LayerNormBase] = None
        if config.attention_layer_norm:
            self.k_norm = LayerNormBase.build(
                config,
                size=(config.d_model // config.n_heads) * config.effective_n_kv_heads,
                elementwise_affine=config.attention_layer_norm_with_affine,
            )
            self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)

        # Activation function.
        self.act = Activation.build(config)
        assert (self.act.output_multiplier * self.hidden_size) % 1 == 0

        # Attention output projection.
        self.attn_out = nn.Linear(
            config.d_model, config.d_model, bias=config.include_bias, device=config.init_device
        )

        # Feed-forward output projection.
        self.ff_out = nn.Linear(
            int(self.act.output_multiplier * self.hidden_size),
            config.d_model,
            bias=config.include_bias,
            device=config.init_device,
        )
        self.ff_out._is_residual = True  # type: ignore

        # Rotary embeddings.
        if self.config.rope:
            self.rotary_emb = RotaryEmbedding(config, self.__cache)

        # Optional FlashAttention kernel; stays `None` when not requested or not installed.
        self.flash_attn_func = None
        if config.flash_attention:
            try:
                from flash_attn import flash_attn_func  # type: ignore
            except ModuleNotFoundError:
                # Best-effort fallback, but tell the user instead of silently ignoring the setting.
                import warnings

                warnings.warn(
                    "flash_attention is enabled in the config but the 'flash_attn' package is not installed; "
                    "falling back to torch.nn.functional.scaled_dot_product_attention.",
                    UserWarning,
                )
            else:
                self.flash_attn_func = flash_attn_func

    def reset_parameters(self):
        """Re-initialize the q/k norms (when present) and both output projections."""
        if self.k_norm is not None:
            self.k_norm.reset_parameters()
        if self.q_norm is not None:
            self.q_norm.reset_parameters()
        init_weights(
            self.config,
            self.attn_out,
            d=self.config.d_model,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )
        init_weights(
            self.config,
            self.ff_out,
            d=self.ff_out.in_features,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )

    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
        """Enable intra-block checkpointing only for the ``fine_grained`` strategy; disable it otherwise."""
        if strategy == ActivationCheckpointingStrategy.fine_grained:
            self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
        else:
            self._activation_checkpoint_fn = None

    @classmethod
    def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
        """
        Cast ``bias`` to the autocast dtype active on its device (or to ``input_dtype`` when
        autocast is off), then run ``ensure_finite_`` on it with the negative-infinity check enabled.
        """
        target_dtype = input_dtype
        # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
        # `is_autocast_cpu_enabled()` for CPU autocast.
        # See https://github.com/pytorch/pytorch/issues/110966.
        if bias.device.type == "cuda" and torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
            target_dtype = torch.get_autocast_cpu_dtype()
        if bias.dtype != target_dtype:
            bias = bias.to(target_dtype)
            ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
        return bias

    def _scaled_dot_product_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """
        Computes bidirectional scaled dot product attention on query, key and value tensors,
        applying dropout if a probability greater than 0.0 is specified.

        NOTE: ``attn_mask`` only decides which kernel is used (FlashAttention is taken only when
        it is ``None``); neither ``attn_mask`` nor ``is_causal`` is forwarded to the attention
        kernel, which always runs unmasked and non-causal.
        """
        if self.flash_attn_func is not None and attn_mask is None:
            # The flash kernel is called with the head and sequence dims swapped.
            r = self.flash_attn_func(
                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=False
            )
            return r.transpose(1, 2)
        else:
            # torch's sdpa doesn't support GQA, so we're doing this
            assert k.size(1) == v.size(1)
            num_kv_heads = k.size(1)
            num_q_heads = q.size(1)
            if num_q_heads != num_kv_heads:
                assert num_q_heads % num_kv_heads == 0
                # Repeat each kv head so that k/v have as many heads as q.
                k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
                v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)

            # Modify: MDM set causal to False, and with no attn_mask.
            return F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=False,
            )

    def attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attention_bias: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Run multi-head attention on already-projected ``q``, ``k``, ``v`` and apply the output projection.

        Returns ``(output, present)`` where ``present`` is the ``(k, v)`` pair (including any
        ``layer_past`` prefix, captured before rotary embeddings are applied) when ``use_cache``
        is set, otherwise ``None``.

        NOTE: ``attention_bias`` is sliced and cast below but is not passed on to
        ``_scaled_dot_product_attention``; attention always runs without a mask.
        """
        B, T, C = q.size()  # batch size, sequence length, d_model
        dtype = k.dtype

        # Optionally apply layer norm to keys and queries.
        if self.q_norm is not None and self.k_norm is not None:
            q = self.q_norm(q).to(dtype=dtype)
            k = self.k_norm(k).to(dtype=dtype)

        # Move head forward to be next to the batch dim.
        # shape: (B, nh, T, hs)
        q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
        # shape: (B, n_kv_h, T, hs)
        k = k.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
        # shape: (B, n_kv_h, T, hs)
        v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)

        if layer_past is not None:
            # Prepend cached keys/values along the sequence dimension.
            past_key, past_value = layer_past
            k = torch.cat((past_key, k), dim=-2)
            v = torch.cat((past_value, v), dim=-2)

        present = (k, v) if use_cache else None
        query_len, key_len = q.shape[-2], k.shape[-2]  # could be different if layer_past not None

        if self.config.rope:
            # Apply rotary embeddings.
            q, k = self.rotary_emb(q, k)

        if attention_bias is not None:
            # Resize and cast attention bias.
            # The current dtype of the attention bias might not match the dtype that the SDP attn function will
            # run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding
            # as down-casting the attention bias to the autocast precision will result in -infs, which will
            # cause the SDP attn function to produce NaNs.
            attention_bias = self._cast_attn_bias(
                attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype
            )

        # Get the attention scores.
        # shape: (B, nh, T, hs)
        att = self._scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=0.0 if not self.training else self.config.attention_dropout,
            is_causal=False,
        )

        # Re-assemble all head outputs side-by-side.
        att = att.transpose(1, 2).contiguous().view(B, T, C)

        # Apply output projection.
        return self.attn_out(att), present

    @abstractmethod
    def forward(
        self,
        x: torch.Tensor,
        attention_bias: Optional[torch.FloatTensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        raise NotImplementedError

    @classmethod
    def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> LLaDABlock:
        """Instantiate the block implementation selected by ``config.block_type``."""
        if config.block_type == BlockType.sequential:
            return LLaDASequentialBlock(layer_id, config, cache)
        elif config.block_type == BlockType.llama:
            return LLaDALlamaBlock(layer_id, config, cache)
        else:
            raise NotImplementedError(f"Unknown block type: '{config.block_type}'")
756
+
757
+
758
class LLaDASequentialBlock(LLaDABlock):
    """
    A pre-norm transformer block with a fused q/k/v projection: attention over ``LN(x)`` is added
    to ``x``, then the MLP over ``LN`` of that sum is added on top (two residual connections).
    """

    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
        super().__init__(layer_id, config, cache)
        # Layer norms.
        self.attn_norm = LayerNorm.build(config)
        self.ff_norm = LayerNorm.build(config)
        # Attention input projection. Projects x -> (q, k, v)
        head_dim = config.d_model // config.n_heads
        kv_dim = config.effective_n_kv_heads * head_dim
        self.fused_dims = (config.d_model, kv_dim, kv_dim)
        self.att_proj = nn.Linear(
            config.d_model,
            sum(self.fused_dims),
            bias=config.include_bias | config.include_qkv_bias,
            device=config.init_device,
        )
        # Feed-forward input projection.
        self.ff_proj = nn.Linear(
            config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
        )

    def reset_parameters(self):
        """Re-initialize the shared parts, both layer norms and both input projections."""
        super().reset_parameters()
        self.attn_norm.reset_parameters()
        self.ff_norm.reset_parameters()
        # NOTE: the standard deviation for these weights does not depend on the layer.
        for input_proj in (self.att_proj, self.ff_proj):
            init_weights(
                self.config, input_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
            )

    def forward(
        self,
        x: torch.Tensor,
        attention_bias: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        checkpoint = self._activation_checkpoint_fn

        def run(fn, *args, **kwargs):
            # Route the call through activation checkpointing when it is enabled.
            if checkpoint is not None:
                return checkpoint(fn, *args, **kwargs)
            return fn(*args, **kwargs)

        # Get query, key, value projections by splitting the fused projection along the last dim.
        # Widths are `self.fused_dims`: d_model for q, effective_n_kv_heads * head_dim for k and v.
        q, k, v = self.att_proj(run(self.attn_norm, x)).split(self.fused_dims, dim=-1)

        # Get attention scores.
        att, cache = run(  # type: ignore
            self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
        )

        # Add attention scores.
        # shape: (B, T, C)
        x = x + self.dropout(att)

        # Add feed-forward projection.
        # shape: (batch_size, seq_len, d_model)
        residual = x
        hidden = run(self.ff_norm, x)  # type: ignore
        hidden = self.ff_proj(hidden)
        hidden = run(self.act, hidden)  # type: ignore
        hidden = self.ff_out(hidden)
        hidden = self.dropout(hidden)

        return residual + hidden, cache
846
+
847
+
848
class LLaDALlamaBlock(LLaDABlock):
    """
    A pre-norm transformer block in the style of Llama. It differs from `LLaDASequentialBlock`
    in that q, k and v use separate projections and the MLP is gated:
    ``ff_out(act(ff_proj(h)) * up_proj(h))`` with ``h = LN(x + Attention(LN(x)))``.
    Both the attention and the MLP outputs are added back through residual connections.
    """

    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
        super().__init__(layer_id, config, cache)
        # Layer norms.
        self.attn_norm = LayerNorm.build(config)
        self.ff_norm = LayerNorm.build(config)
        self.__cache = cache

        # Attention input projections. Project x -> q, k, v separately.
        head_dim = config.d_model // config.n_heads
        kv_out_dim = config.effective_n_kv_heads * head_dim
        qkv_bias = config.include_bias | config.include_qkv_bias
        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=qkv_bias, device=config.init_device)
        self.k_proj = nn.Linear(config.d_model, kv_out_dim, bias=qkv_bias, device=config.init_device)
        self.v_proj = nn.Linear(config.d_model, kv_out_dim, bias=qkv_bias, device=config.init_device)

        # Feed-forward input projection (the branch that goes through the activation).
        self.ff_proj = nn.Linear(
            config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
        )
        # Second feed-forward input projection (the multiplicative "up" branch).
        self.up_proj = nn.Linear(
            config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
        )

    # [MODIFIED]: `type_of_module` was previously not passed correctly here.
    def reset_parameters(self):
        """Re-initialize the shared parts, both layer norms and all five input projections."""
        super().reset_parameters()  # This correctly initializes attn_out and ff_out as 'out_module'
        self.attn_norm.reset_parameters()
        self.ff_norm.reset_parameters()

        # Initialize every input projection as an 'in_module' (std independent of the layer).
        for input_proj in (self.q_proj, self.k_proj, self.v_proj, self.ff_proj, self.up_proj):
            init_weights(
                self.config, input_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
            )

    def forward(
        self,
        x: torch.Tensor,
        attention_bias: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        checkpoint = self._activation_checkpoint_fn

        def run(fn, *args, **kwargs):
            # Route the call through activation checkpointing when it is enabled.
            if checkpoint is not None:
                return checkpoint(fn, *args, **kwargs)
            return fn(*args, **kwargs)

        # Get query, key, value projections.
        # q has width d_model; k and v have width effective_n_kv_heads * head_dim.
        # (The attention norm is always called directly, never through checkpointing.)
        attn_input = self.attn_norm(x)
        q = self.q_proj(attn_input)
        k = self.k_proj(attn_input)
        v = self.v_proj(attn_input)

        # Get attention scores.
        att, cache = run(  # type: ignore
            self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
        )

        # Add attention scores.
        # shape: (B, T, C)
        x = x + self.dropout(att)

        # Add feed-forward projection.
        # shape: (batch_size, seq_len, d_model)
        residual = x
        ff_input = run(self.ff_norm, x)  # type: ignore
        gate = self.ff_proj(ff_input)
        up = self.up_proj(ff_input)
        gate = run(self.act, gate)  # type: ignore
        hidden = self.ff_out(gate * up)
        hidden = self.dropout(hidden)

        return residual + hidden, cache
949
+
950
+
951
class LLaDAOutput(NamedTuple):
    """Result of a model forward pass."""

    # A tensor of shape `(batch_size, seq_len, vocab_size)` holding the unnormalized scores
    # (logits) for each token, i.e. *before* normalization via (log) softmax.
    logits: torch.FloatTensor

    # Attention keys and values from each block.
    attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]

    # Hidden states from each block.
    hidden_states: Optional[Tuple[torch.Tensor]]
967
+
968
+
969
class LLaDAGenerateOutput(NamedTuple):
    """Result of a generation call."""

    # The generated token IDs, a tensor of shape `(batch_size, beam_size, max_steps)`.
    # These do *not* include the original input IDs.
    token_ids: torch.LongTensor

    # The scores of the generated sequences, a tensor of shape `(batch_size, beam_size)`.
    scores: torch.FloatTensor
980
+
981
+
982
class LLaDABlockGroup(nn.ModuleList):
    """
    A run of consecutive blocks executed as one unit.

    ``layer_offset`` is added to each block's local index to obtain the index used for
    deciding which blocks are activation-checkpointed.
    """

    def __init__(self, config: ModelConfig, layer_offset: int, modules: Optional[Iterable[nn.Module]] = None):
        super().__init__(modules)
        self.config = config
        self.layer_offset = layer_offset
        self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
        self._activation_checkpoint_fn = activation_checkpoint_function(self.config)

    def forward(
        self,
        x: torch.Tensor,
        attention_bias: Optional[torch.FloatTensor] = None,
        layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
        """Run every block in order; returns the final hidden state and, if ``use_cache``, each block's cache."""
        strategy = self.activation_checkpointing_strategy
        attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None

        for local_idx, block in enumerate(self):
            # `layers_past` is indexed by position inside the group, not by the offset index.
            layer_past = None if layers_past is None else layers_past[local_idx]
            layer_idx = local_idx + self.layer_offset

            if strategy == ActivationCheckpointingStrategy.whole_layer:
                use_checkpoint = True
            elif strategy == ActivationCheckpointingStrategy.one_in_two:
                use_checkpoint = layer_idx % 2 == 0
            elif strategy == ActivationCheckpointingStrategy.one_in_three:
                use_checkpoint = layer_idx % 3 == 0
            elif strategy == ActivationCheckpointingStrategy.one_in_four:
                use_checkpoint = layer_idx % 4 == 0
            else:
                use_checkpoint = False

            # shape: (batch_size, seq_len, d_model)
            if use_checkpoint:
                x, cache = self._activation_checkpoint_fn(  # type: ignore
                    block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
                )
            else:
                x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)

            if attn_key_values is not None:
                assert cache is not None
                attn_key_values.append(cache)

        return x, attn_key_values

    def reset_parameters(self):
        """Re-initialize the parameters of every block in the group."""
        for member in self:
            member.reset_parameters()

    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
        """Record ``strategy`` on the group and forward it to every block."""
        self.activation_checkpointing_strategy = strategy
        for member in self:
            member.set_activation_checkpointing(strategy)
1036
+
1037
+
1038
+ class LLaDAModel(nn.Module):
1039
def __init__(self, config: ModelConfig, init_params: bool = True):
    """
    Validate ``config``, build the embedding / blocks / final norm / optional output projection,
    and (unless ``init_params`` is false or ``init_device`` is ``"meta"``) initialize the weights.
    """
    super().__init__()
    self.config = config
    self.__cache = BufferCache()

    # Validate config.
    if config.alibi and config.flash_attention:
        raise Exception("ALiBi is currently not supported with FlashAttention")

    if config.alibi and config.rope:
        raise Exception("ALiBi and RoPE are mutually exclusive")

    embedding_size = config.embedding_size
    if embedding_size is not None and embedding_size != config.vocab_size:
        if embedding_size < config.vocab_size:
            raise Exception("embedding size should be at least as big as vocab size")
        elif embedding_size % 128 != 0:
            import warnings

            warnings.warn(
                "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
            )

    self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
    self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config)

    # The group size must be in (0, n_layers] and divide n_layers evenly.
    group_size, n_layers = config.block_group_size, config.n_layers
    if group_size <= 0 or group_size > n_layers or n_layers % group_size != 0:
        raise Exception("n layers must be divisible by block group size")

    # NOTE: these two calls change process-wide torch backend settings.
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_mem_efficient_sdp(False)  # this is super slow so make sure torch won't use it

    token_embedding = nn.Embedding(
        config.embedding_size or config.vocab_size,
        config.d_model,
        device=config.init_device,
        padding_idx=config.pad_token_id,
    )
    self.transformer = nn.ModuleDict(
        dict(
            wte=token_embedding,
            emb_drop=Dropout(config.embedding_dropout),
            ln_f=LayerNorm.build(config),
        )
    )

    # All blocks share this model's buffer cache.
    blocks = [LLaDABlock.build(layer_id, config, self.__cache) for layer_id in range(n_layers)]
    if group_size > 1:
        block_groups = [
            LLaDABlockGroup(config, start, blocks[start : start + group_size])
            for start in range(0, n_layers, group_size)
        ]
        self.transformer.update({"block_groups": nn.ModuleList(block_groups)})
    else:
        self.transformer.update({"blocks": nn.ModuleList(blocks)})

    # Learned absolute position embeddings are only used when neither ALiBi nor RoPE is.
    if not (config.alibi or config.rope):
        self.transformer.update(
            {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
        )

    # A separate output projection exists only when weights are not tied.
    if not config.weight_tying:
        output_proj = nn.Linear(
            config.d_model,
            config.embedding_size or config.vocab_size,
            bias=config.include_bias,
            device=config.init_device,
        )
        self.transformer.update({"ff_out": output_proj})

    # When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
    if init_params and config.init_device != "meta":
        self.reset_parameters()
    self.__num_fwd_flops: Optional[int] = None

    # Warm up cache.
    if config.alibi:
        warmup_device = _non_meta_init_device(config)
        get_causal_attention_bias(self.__cache, config.max_sequence_length, warmup_device)
        self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
1117
+
1118
def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
    """Record ``strategy`` on the model and forward it to every block (or block group)."""
    self.activation_checkpointing_strategy = strategy
    # Blocks live under `block_groups` unless the group size is exactly 1.
    if self.config.block_group_size != 1:
        children = self.transformer.block_groups
    else:
        children = self.transformer.blocks
    for child in children:
        child.set_activation_checkpointing(strategy)
1126
+
1127
@property
def device(self) -> torch.device:
    """
    Device of the token-embedding weights; when they live on the ``meta`` device the result
    of ``_non_meta_init_device(self.config)`` is returned instead.
    """
    weight_device: torch.device = self.transformer.wte.weight.device  # type: ignore
    if weight_device.type != "meta":
        return weight_device
    return _non_meta_init_device(self.config)
1134
+
1135
def reset_parameters(self):
    """Initialize the top-level embeddings, final norm and output projection, then every block."""
    log.info("Initializing model parameters...")
    config = self.config
    transformer = self.transformer

    # Top-level embeddings / linear layers.
    wte_std_factor = (0.5 * math.sqrt(config.d_model)) if config.scale_logits else 1.0
    init_weights(
        config,
        transformer.wte,  # type: ignore
        std_factor=wte_std_factor,
        type_of_module=ModuleType.emb,
    )
    if hasattr(transformer, "wpe"):
        init_weights(config, transformer.wpe, type_of_module=ModuleType.emb)  # type: ignore

    # Top-level layer norm.
    transformer.ln_f.reset_parameters()  # type: ignore

    # Output weights.
    if hasattr(transformer, "ff_out"):
        init_weights(config, transformer.ff_out, type_of_module=ModuleType.final_out)  # type: ignore

    # Let the blocks handle themselves.
    children = transformer.blocks if config.block_group_size == 1 else transformer.block_groups
    for child in children:
        child.reset_parameters()
1161
+
1162
def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
    """
    Return the ALiBi attention bias, computing and caching it when necessary.

    :param seq_len: Minimum sequence length the bias has to cover.
    :param device: Device the returned bias must live on.
    :returns: The cached bias when its last dimension is at least ``seq_len``
        (moved to ``device`` if needed, and possibly larger than ``seq_len``),
        otherwise a freshly computed bias, which replaces the cache entry.
    """
    cached = self.__cache.get("alibi_attention_bias")
    if cached is None or cached.shape[-1] < seq_len:
        # Cache miss (or cached bias too short): rebuild with autocast disabled.
        with torch.autocast(device.type, enabled=False):
            fresh = alibi_attention_bias(seq_len, self.config, device)
        self.__cache["alibi_attention_bias"] = fresh
        return fresh

    if cached.device != device:
        cached = cached.to(device)
        self.__cache["alibi_attention_bias"] = cached
    return cached
1174
+
1175
def forward(
    self,
    input_ids: torch.LongTensor,
    input_embeddings: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    attention_bias: Optional[torch.Tensor] = None,
    past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
    use_cache: bool = False,
    last_logits_only: bool = False,
    output_hidden_states: Optional[bool] = None,
) -> LLaDAOutput:
    """
    Run the transformer stack and return logits (plus optional hidden states).

    :param input_ids: A tensor of shape `(batch_size, seq_len)`.
    :param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input
        embeddings. When provided, it is treated as the output of the input embedding layer.
    :param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
        which input IDs are masked. A `1` value in the mask means that
        the corresponding input ID should *not* be ignored. A `0` means
        that the corresponding input ID is masked.

        This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
        library.
    :param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`,
        `(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used
        to introduce causal or other biases.

        If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]`
        indicates that the i-th element in the sequence is allowed to attend to the j-th
        element in the sequence.

        If the tensor is a float tensor, it will just be added to the attention
        scores before the softmax.

        The default is causal, which corresponds to a lower-diagonal byte matrix of ones.
    :param past_key_values: Pre-computed keys and values for each attention block.
        Can be used to speed up sequential decoding. The `input_ids` which have
        their past given to this model should not be passed as `input_ids` as they have already been computed.
        NOTE: the assertion at the top of this method currently rejects any non-`None` value.
    :param use_cache: If `True`, return key and value tensors for each block.
        NOTE: the assertion at the top of this method currently rejects `True`.
    :param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
        This can speed up decoding when you only care about the next token.
    :param output_hidden_states: If `True`, also return the input of every block (or block
        group) followed by the final, post-layer-norm hidden state. `None` is treated as `False`.
    :returns: An `LLaDAOutput` built from the logits, the collected key/value caches
        (`None` unless `use_cache`) and the tuple of hidden states (`None` unless requested).
    """
    # Add Basic MDM Model config check
    # NOTE: these are plain `assert`s, so they are skipped when Python runs with `-O`.
    assert not self.config.alibi, "Alibi length extrapolation is not supported for MDM."
    assert self.config.rope, "Rope must be used in Llama-Encoder for MDM."
    assert (past_key_values is None and not use_cache), "The kvcache is not suppotred for MDM."

    output_hidden_states = output_hidden_states if output_hidden_states is not None else False

    # Unreachable while the kv-cache assertion above is active.
    if past_key_values:
        assert len(past_key_values) == self.config.n_layers

    batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
    if past_key_values is None:
        past_length = 0
    else:
        past_length = past_key_values[0][0].size(-2)

    # Get embeddings of input.
    # shape: (batch_size, seq_len, d_model)

    x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings  # type: ignore

    # Optionally scale the embeddings by sqrt(d_model).
    if self.config.input_emb_norm:
        x = x * (self.config.d_model**0.5)

    # Learned positional embeddings are only used when neither ALiBi nor RoPE is enabled.
    if not (self.config.alibi or self.config.rope):
        # Get positional embeddings.
        # shape: (1, seq_len)
        pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
        # shape: (1, seq_len, d_model)
        pos_emb = self.transformer.wpe(pos)  # type: ignore
        x = pos_emb + x

    # Add input + positional embeddings and apply dropout.
    # shape: (batch_size, seq_len, d_model)
    x = self.transformer.emb_drop(x)  # type: ignore

    # Transform the attention mask into what the blocks expect.
    # A mask without any zero masks nothing, so it is dropped entirely.
    if attention_mask is not None and 0.0 in attention_mask:
        # shape: (batch_size, 1, 1, seq_len)
        attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
        # Kept positions become 0, masked positions become the most negative float.
        attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
    else:
        attention_mask = None

    # Merge attention mask with attention bias.
    if (
        attention_bias is not None
        or attention_mask is not None
        or self.config.alibi
        # NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
        # with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
        # scores correctly.
        or past_key_values is not None
    ):
        if attention_bias is None and self.config.alibi:
            attention_bias = get_causal_attention_bias(
                self.__cache, past_length + seq_len, x.device
            ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
        elif attention_bias is None:
            attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
        elif attention_bias.dtype in (torch.int8, torch.bool):
            # Convert an "allowed to attend" mask into an additive float bias.
            attention_bias = attention_bias.to(dtype=torch.float)
            attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)

        # Transform to the right shape and data type.
        mask_len = seq_len
        if attention_mask is not None:
            mask_len = attention_mask.shape[-1]
        elif past_key_values is not None:
            mask_len = past_key_values[0][0].shape[-2] + seq_len
        attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)

        # Add in the masking bias.
        if attention_mask is not None:
            attention_bias = attention_bias + attention_mask
            # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
            # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
            # it can produce NaNs.
            ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)

    attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None

    # decoder layers
    all_hidden_states = []

    # Apply blocks one-by-one.
    if self.config.block_group_size == 1:
        for block_idx, block in enumerate(self.transformer.blocks):
            if output_hidden_states:
                # add hidden states
                all_hidden_states.append(x)

            layer_past = None if past_key_values is None else past_key_values[block_idx]
            # Decide whether this block runs through the activation-checkpoint wrapper:
            # every block, or every 2nd / 3rd / 4th block, depending on the strategy.
            if (
                (self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
                or (
                    self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
                    and block_idx % 2 == 0
                )
                or (
                    self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
                    and block_idx % 3 == 0
                )
                or (
                    self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
                    and block_idx % 4 == 0
                )
            ):
                # shape: (batch_size, seq_len, d_model)
                x, cache = self._activation_checkpoint_fn(
                    block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
                )
            else:
                # shape: (batch_size, seq_len, d_model)
                x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
            if attn_key_values is not None:
                assert cache is not None
                attn_key_values.append(cache)
    else:
        for group_idx, block_group in enumerate(self.transformer.block_groups):
            if output_hidden_states:
                # add hidden states
                all_hidden_states.append(x)

            # Slice out the cached keys/values belonging to the blocks of this group.
            layers_past = (
                None
                if past_key_values is None
                else past_key_values[
                    group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
                ]
            )
            x, cache = block_group(
                x, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache
            )
            if attn_key_values is not None:
                assert cache is not None
                attn_key_values.extend(cache)

    if last_logits_only:
        # shape: (batch_size, 1, d_model)
        x = x[:, -1, :].unsqueeze(1)

    # Apply final layer norm.
    # shape: (batch_size, seq_len or 1, d_model)
    x = self.transformer.ln_f(x)  # type: ignore
    if output_hidden_states:
        # add final hidden state post-final-layernorm, following HuggingFace's convention
        all_hidden_states.append(x)

    # Get logits.
    # shape: (batch_size, seq_len or 1, vocab_size)
    if self.config.weight_tying:
        # Tied weights: reuse the token embedding matrix as the output projection.
        logits = F.linear(x, self.transformer.wte.weight, None)  # type: ignore
    else:
        logits = self.transformer.ff_out(x)  # type: ignore
    if self.config.scale_logits:
        # In-place scaling by 1/sqrt(d_model).
        logits.mul_(1 / math.sqrt(self.config.d_model))

    return LLaDAOutput(logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None)  # type: ignore[arg-type]
1375
+
1376
+
1377
def create_model_config_from_pretrained_config(config: LLaDAConfig):
    """
    Build a ``ModelConfig`` from a HuggingFace-style ``LLaDAConfig``.

    Every dataclass field declared on ``ModelConfig`` is read from the attribute
    of the same name on ``config``; a missing attribute raises ``AttributeError``.
    """
    field_values = {spec.name: getattr(config, spec.name) for spec in fields(ModelConfig)}
    return ModelConfig(**field_values)
1388
+
1389
+
1390
+ from transformers.modeling_outputs import ModelOutput
1391
+
1392
+ from transformers.loss.loss_utils import fixed_cross_entropy
1393
+
1394
def ForMaskedLMLoss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    vocab_size: int,
    num_items_in_batch: Optional[torch.Tensor] = None,
    ignore_index: int = -100,
    # Optional per-token importance weights.
    per_token_weights: Optional[torch.Tensor] = None,
    loss_normalization: str = "masked_tokens",
    **kwargs,
):
    """
    Compute the masked-language-model loss, optionally weighted per token.

    :param logits: Prediction scores; flattened to ``(-1, vocab_size)``.
    :param labels: Target ids; flattened to ``(-1,)``. Positions equal to
        ``ignore_index`` do not contribute to the loss.
    :param vocab_size: Size of the last dimension of ``logits``.
    :param num_items_in_batch: Optional count used as the normalizer. When omitted
        in the weighted path, the number of non-ignored labels is used.
    :param ignore_index: Label value marking positions to skip.
    :param per_token_weights: Optional weights with as many elements as ``labels``.
        When ``None`` the loss is delegated to ``fixed_cross_entropy`` and
        ``loss_normalization`` is not consulted.
    :param loss_normalization: Weighted path only. ``"total_tokens"`` divides the
        weighted sum by ``labels.numel()``; any other value divides by the number
        of valid tokens.
    :param kwargs: Forwarded to ``fixed_cross_entropy`` in the unweighted path only.
    :returns: A scalar loss tensor.
    :raises ValueError: If ``per_token_weights`` does not have one element per label.
    """
    # Upcast to float to avoid precision issues, then flatten the tokens.
    logits = logits.float().view(-1, vocab_size)
    # Both paths need the labels on the logits' device (the weighted path used to skip this).
    labels = labels.view(-1).to(logits.device)

    if per_token_weights is None:
        # No weights supplied: use the stock implementation.
        return fixed_cross_entropy(logits, labels, num_items_in_batch, ignore_index, **kwargs)

    # --- Weighted path ---
    # 1. Per-token loss without reduction; ignored positions come out as 0.
    per_token_loss = F.cross_entropy(
        logits,
        labels,
        ignore_index=ignore_index,
        reduction='none',
    )

    # 2. Flatten the weights so they line up with the per-token loss.
    weights = per_token_weights.reshape(-1).to(per_token_loss.device)
    if per_token_loss.shape != weights.shape:
        raise ValueError(
            f"Shape mismatch between per_token_loss ({per_token_loss.shape}) and weights ({weights.shape})."
            "Please ensure per_token_weights are correctly expanded."
        )

    # 3. Apply the weights and sum.
    total_weighted_loss = (per_token_loss * weights).sum()

    # 4. Normalize.
    if loss_normalization == "total_tokens":
        # Normalize by total number of tokens (official implementation behavior)
        return total_weighted_loss / labels.numel()

    # Normalize by number of valid (masked) tokens (old behavior)
    if num_items_in_batch is None:
        num_valid_tokens = (labels != ignore_index).sum()
    else:
        num_valid_tokens = num_items_in_batch
    if torch.is_tensor(num_valid_tokens):
        num_valid_tokens = num_valid_tokens.to(total_weighted_loss.device)

    if num_valid_tokens > 0:
        return total_weighted_loss / num_valid_tokens
    # Nothing to normalize by: return zero, but keep it attached to the autograd
    # graph so that a subsequent `.backward()` still works.
    return total_weighted_loss * 0.0
1479
+
1480
+
1481
+
1482
+
1483
@dataclass
class CausalLMOutputWithPastAndMLMProb(ModelOutput):
    """
    Output container returned by `LLaDAModelLM.forward` when `return_dict` is true.

    It carries the usual loss / logits / cache / hidden-state fields plus the
    masking probability that was passed into the forward call.
    """

    # Loss value; stays `None` when no labels were given.
    loss: Optional[torch.FloatTensor] = None
    # Prediction scores produced by the wrapped model.
    logits: Optional[torch.FloatTensor] = None
    # Filled from the wrapped model's `attn_key_values`.
    past_key_values: Optional[List[torch.FloatTensor]] = None
    # Hidden states collected by the wrapped model, if requested.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Never populated by the forward pass in this file (attention outputs are unsupported).
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Mean of the `current_mlm_prob` tensor passed to forward, or `None`.
    current_mlm_prob: Optional[torch.FloatTensor] = None
1491
+
1492
class LLaDAModelLM(PreTrainedModel):
    """
    Extremely barebones HF model wrapper.

    Wraps an `LLaDAModel` (stored as `self.model`) behind the HuggingFace
    `PreTrainedModel` interface.
    """

    # Configuration class associated with this model.
    config_class = LLaDAConfig
    # Name of the attribute that holds the wrapped model.
    base_model_prefix = "model"
    # Module class names listed as not to be split.
    _no_split_modules = ["LLaDABlock", "LLaDASequentialBlock", "LLaDALlamaBlock"]
1500
+
1501
def __init__(self, config: LLaDAConfig, model: Optional[LLaDAModel] = None, init_params: bool = False):
    """
    Wrap an existing ``LLaDAModel`` or build a new one from ``config``.

    :param config: HuggingFace-style configuration.
    :param model: Pre-built model to wrap; when falsy, a new one is created.
    :param init_params: Forwarded to ``LLaDAModel`` when a new model is created.
    """
    super().__init__(config)

    if model:
        self.model = model
        return

    model_config = create_model_config_from_pretrained_config(config)
    # Always start on CPU so that construction cannot exhaust GPU memory.
    model_config.init_device = "cpu"
    self.model = LLaDAModel(model_config, init_params=init_params)
1511
+
1512
def forward(
    self,
    input_ids: torch.LongTensor = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    attention_bias: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[Cache] = None,  # This is a hack mitigation of an issue in transformers `4.39.x`
    current_mlm_prob: Optional[torch.Tensor] = None,
    **kwargs
) -> Union[Tuple, CausalLMOutputWithPastAndMLMProb]:
    """
    Run the wrapped model and, when ``labels`` are given, compute the MLM loss.

    :param current_mlm_prob: Optional masking probabilities. When given together
        with ``labels``, each token of a sequence is weighted by
        ``1 / (current_mlm_prob + 1e-8)`` in the loss.
        NOTE(review): the ``unsqueeze(1).expand(-1, seq_len)`` below presumably
        expects one probability per sequence, i.e. shape ``(batch_size,)`` — confirm.
    :param kwargs: Forwarded to ``ForMaskedLMLoss``.
    :returns: A ``CausalLMOutputWithPastAndMLMProb`` when ``return_dict`` is true,
        otherwise a tuple with the loss (if any) followed by the model outputs.
    :raises ValueError: If ``output_attentions`` is truthy.
    """
    if use_cache is None:
        use_cache = self.config.use_cache

    if output_attentions:
        raise ValueError("output_attentions is not yet supported in LLaDA")

    if return_dict is None:
        return_dict = self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model.forward(
        input_ids=input_ids,
        input_embeddings=inputs_embeds,
        attention_mask=attention_mask,
        attention_bias=attention_bias,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_hidden_states=output_hidden_states,
    )
    logits = outputs.logits

    loss = None
    if labels is not None:
        token_weights = None
        if current_mlm_prob is not None:
            # Importance weights; the small epsilon guards against division by zero.
            inverse_prob = 1.0 / (current_mlm_prob.to(logits.device) + 1e-8)
            # Broadcast each per-sequence weight over that sequence's tokens.
            token_weights = inverse_prob.unsqueeze(1).expand(-1, logits.shape[1])

        loss = ForMaskedLMLoss(
            logits,
            labels,
            vocab_size=self.config.vocab_size,
            per_token_weights=token_weights,
            loss_normalization=getattr(self.config, "loss_normalization", "masked_tokens"),
            **kwargs
        )

    if not return_dict:
        output = (logits,) + outputs[1:]
        if loss is None:
            return output
        return (loss,) + output

    if current_mlm_prob is None:
        mean_mlm_prob = None
    else:
        mean_mlm_prob = current_mlm_prob.mean()

    return CausalLMOutputWithPastAndMLMProb(
        loss=loss,
        logits=logits,
        past_key_values=outputs.attn_key_values,
        hidden_states=outputs.hidden_states,
        current_mlm_prob=mean_mlm_prob,
    )
1590
+
1591
def can_generate(self) -> bool:
    """Report that this model supports generation; always returns ``True``."""
    supports_generation = True
    return supports_generation
1593
+
1594
def prepare_inputs_for_generation(
    self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
):
    """
    Assemble the keyword arguments for one generation step.

    :param input_ids: Token ids generated so far.
    :param past_key_values: Cached keys/values; when truthy only the last token
        of ``input_ids`` is kept, because earlier tokens are already processed.
    :param kwargs: Extra entries copied into the result unchanged.
    :returns: A dict with ``input_ids``, ``past_key_values``, the extra entries
        and ``use_cache`` (taken from ``kwargs`` or else from the config).
    """
    step_ids = input_ids[:, -1:] if past_key_values else input_ids
    model_inputs = {"input_ids": step_ids, "past_key_values": past_key_values, **kwargs}
    model_inputs["use_cache"] = kwargs.get("use_cache", self.config.use_cache)
    return model_inputs
1605
+
1606
+ # TODO: these are required to make the implementation complete.
1607
+ # def resize_position_embeddings(self, new_num_position_embeddings: int):
1608
+ # pass
1609
+ #
1610
+ # def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
1611
+ # pass
1612
+ #
1613
+ # def _reorder_cache(self, past_key_values, beam_idx):
1614
+ # pass
1615
+
1616
def get_input_embeddings(self) -> torch.nn.Module:
    """Return the token embedding module (``wte``) of the wrapped model."""
    transformer = self.model.transformer
    return transformer.wte
1618
+
1619
def set_input_embeddings(self, value: torch.nn.Module):
    """Replace the token embedding module (``wte``) of the wrapped model with ``value``."""
    setattr(self.model.transformer, "wte", value)
1621
+
1622
def get_output_embeddings(self):
    """
    Return the module that produces the output logits.

    With weight tying this is the token embedding (``wte``); otherwise it is the
    separate output projection (``ff_out``).
    """
    transformer = self.model.transformer
    return transformer.wte if self.config.weight_tying else transformer.ff_out
1627
+
1628
def set_output_embeddings(self, value: torch.nn.Module):
    """
    Replace the module that produces the output logits.

    With weight tying ``value`` replaces the token embedding (``wte``); otherwise
    it replaces the separate output projection (``ff_out``).
    """
    attribute = "wte" if self.config.weight_tying else "ff_out"
    setattr(self.model.transformer, attribute, value)
1633
+
1634
def tie_weights(self):
    """Point ``ff_out`` at the token embedding when the config enables weight tying."""
    if not self.config.weight_tying:
        return
    transformer = self.model.transformer
    transformer.ff_out = transformer.wte
1637
+
1638
# BUG: block-wise inference is still known to have problems.
@torch.inference_mode()
def generate(
    self,
    input_ids: torch.LongTensor,
    mask_token_id: Optional[int] = None,
    attention_mask: Optional[torch.Tensor] = None,
    max_new_tokens: int = 50,
    num_diffusion_steps: int = 10,
    temperature_mlm: float = 1.0,
    do_sample: bool = True,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    debug: bool = False,
    tokenizer=None,
    block_size: int = 5,
    decode_top_k_positions: int = 2,
    **kwargs
) -> torch.LongTensor:
    """
    Block-wise diffusion generation.

    ``max_new_tokens`` mask tokens are appended to the prompt and then filled in
    block by block, left to right. For every block the model is run on the full
    sequence up to ``num_diffusion_steps`` times; each step commits the candidate
    tokens at the still-masked positions with the highest confidence.

    Args:
        input_ids: Prompt token ids, shape (batch_size, seq_len).
        mask_token_id: Id of the mask token. When ``None``, ``self.config.mask_token_id``
            is used.
        attention_mask: Attention mask for the prompt, shape (batch_size, seq_len).
            Defaults to all ones. It is extended with ones for the new positions.
        max_new_tokens: Number of new tokens to generate.
        num_diffusion_steps: Maximum number of refinement steps per block.
        temperature_mlm: Softmax temperature used when sampling candidates.
        do_sample: Sample candidates when ``True``, otherwise take the argmax.
        top_k: Accepted for API compatibility; currently not applied.
        top_p: Accepted for API compatibility; currently not applied.
        debug: Log the state of the sequence after every step.
        tokenizer: Optional tokenizer used to render tokens in debug logs.
        block_size: Number of positions per block.
        decode_top_k_positions: Minimum number of masked positions committed per
            step. More positions are committed when that is required to finish the
            block within the remaining steps, so no mask token is left behind.
        **kwargs: Ignored.

    Returns:
        The completed sequences, shape (batch_size, seq_len + max_new_tokens).

    Raises:
        ValueError: If no mask token id is available, if ``block_size`` or
            ``num_diffusion_steps`` is smaller than 1, or if sampling is requested
            with a non-positive temperature.
    """
    if mask_token_id is None:
        # Fall back to the id stored on the model config.
        mask_token_id = getattr(self.config, "mask_token_id", None)
    if mask_token_id is None:
        raise ValueError("mask_token_id must be provided or defined on the model config.")
    if block_size < 1:
        raise ValueError(f"block_size must be >= 1, got {block_size}.")
    if num_diffusion_steps < 1:
        raise ValueError(f"num_diffusion_steps must be >= 1, got {num_diffusion_steps}.")
    if do_sample and temperature_mlm <= 0:
        raise ValueError(f"temperature_mlm must be > 0 when sampling, got {temperature_mlm}.")

    batch_size, original_seq_len = input_ids.shape
    device = input_ids.device
    total_len = original_seq_len + max_new_tokens

    # 1. Append `max_new_tokens` mask tokens to the prompt.
    mask_tokens = torch.full((batch_size, max_new_tokens), mask_token_id,
                             dtype=input_ids.dtype, device=device)
    current_sequence = torch.cat([input_ids, mask_tokens], dim=1)

    # Extend the attention mask so that the new positions are attended to.
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
    extended_attention_mask = torch.cat([
        attention_mask,
        torch.ones((batch_size, max_new_tokens), dtype=torch.bool, device=device)
    ], dim=1)

    if debug:
        logger.info("=" * 80)
        logger.info("🚀 开始分块扩散生成过程")
        logger.info(f"📊 参数: max_new_tokens={max_new_tokens}, block_size={block_size}, num_diffusion_steps_per_block={num_diffusion_steps}")
        initial_text = self._tokens_to_text(current_sequence[0], tokenizer, mask_token_id)
        logger.info(f"📝 初始序列 (Batch 0): {initial_text}")
        logger.info("=" * 80)

    # 2. Outer loop: one block of new positions at a time.
    num_blocks = math.ceil(max_new_tokens / block_size)
    for block_idx in range(num_blocks):
        # Absolute [start, end) of the block inside the full sequence.
        block_start_abs = original_seq_len + block_idx * block_size
        block_end_abs = min(block_start_abs + block_size, total_len)
        if block_start_abs >= block_end_abs:
            continue

        if debug:
            logger.info(f"\n🧱 === 开始处理 Block {block_idx + 1}/{num_blocks} (位置: {block_start_abs} to {block_end_abs-1}) ===")

        # 3. Inner loop: iteratively refine the current block.
        for step in range(num_diffusion_steps):
            if debug:
                logger.info(f" 🔄 --- 内部迭代 {step + 1}/{num_diffusion_steps} ---")

            current_block_tokens = current_sequence[:, block_start_abs:block_end_abs]
            is_mask_in_block = (current_block_tokens == mask_token_id)

            # Checked before the forward pass so a finished block costs no extra model call.
            if not is_mask_in_block.any():
                if debug: logger.info(" [信息] 当前块已填满,提前进入下一块。")
                break

            # a. The forward pass always sees the full sequence (global context).
            outputs = self.forward(
                input_ids=current_sequence,
                attention_mask=extended_attention_mask,
                return_dict=True
            )
            # b. Only the logits of the current block are used.
            block_logits = outputs.logits[:, block_start_abs:block_end_abs, :]

            # c. Candidate token for every position of the block.
            if do_sample:
                sample_probs = torch.softmax(block_logits / temperature_mlm, dim=-1)
                block_candidate_tokens = torch.multinomial(
                    sample_probs.view(-1, sample_probs.size(-1)), 1
                ).view(sample_probs.shape[0], sample_probs.shape[1])
            else:
                block_candidate_tokens = torch.argmax(block_logits, dim=-1)

            # d. Confidence = highest (temperature-free) probability at each position.
            confidence_scores, _ = torch.max(torch.softmax(block_logits, dim=-1), dim=-1)
            # Already filled positions get -1 so that top-k prefers masked ones.
            masked_confidence_scores = confidence_scores.masked_fill(~is_mask_in_block, -1.0)

            # e. Number of positions to commit in this step: at least
            #    `decode_top_k_positions`, and enough to finish within the remaining steps.
            max_masks_in_block = int(is_mask_in_block.sum(dim=1).max().item())
            steps_left = num_diffusion_steps - step
            needed_per_step = math.ceil(max_masks_in_block / steps_left)
            k = min(max(decode_top_k_positions, needed_per_step), max_masks_in_block)

            _, top_k_indices_in_block = torch.topk(masked_confidence_scores, k=k, dim=1)
            block_update_mask = torch.zeros_like(confidence_scores, dtype=torch.bool, device=device)
            block_update_mask.scatter_(1, top_k_indices_in_block, True)
            # Safety net: rows with fewer than k masks must never overwrite filled positions.
            block_update_mask = block_update_mask & is_mask_in_block

            if debug:
                logger.info(f" [解码] 计划更新 {block_update_mask.sum().item()} 个位置。")

            # f. Write the selected candidates back into the sequence.
            updated_block_tokens = torch.where(
                block_update_mask,
                block_candidate_tokens,
                current_block_tokens
            )
            # The snapshot is only needed for the debug diff.
            prev_sequence_for_debug = current_sequence.clone() if debug else None
            current_sequence[:, block_start_abs:block_end_abs] = updated_block_tokens

            if debug:
                self._debug_block_step_changes(
                    block_idx + 1,
                    step + 1,
                    prev_sequence_for_debug,
                    current_sequence,
                    block_start_abs,
                    block_end_abs,
                    tokenizer,
                    mask_token_id
                )

    if debug:
        logger.info("\n" + "=" * 80)
        logger.info("🎉 分块扩散生成完成!")
        for batch_idx in range(batch_size):
            final_text = self._tokens_to_text(current_sequence[batch_idx], tokenizer, mask_token_id)
            logger.info(f"📝 Batch {batch_idx} 最终序列: {final_text}")
        logger.info("=" * 80)

    return current_sequence
1815
+
1816
+ # --- 辅助调试函数 ---
1817
+
1818
+ # 新增一个用于分块调试的函数
1819
def _debug_block_step_changes(self, block_num, step_num, prev_seq, curr_seq, block_start, block_end, tokenizer, mask_token_id):
    """
    Log which positions of one block changed during a single refinement step.

    Only the slice ``[block_start, block_end)`` of ``prev_seq`` and ``curr_seq``
    is compared. The full current sequence is additionally logged when the batch
    holds exactly one sequence.
    """
    before = prev_seq[:, block_start:block_end]
    after = curr_seq[:, block_start:block_end]

    differs = before != after
    if not differs.any():
        logger.info(f" [结果] Block {block_num} Step {step_num}: 无变化。")
        return

    single_sequence = curr_seq.shape[0] == 1
    for batch_idx in range(curr_seq.shape[0]):
        changed_indices = torch.where(differs[batch_idx])[0]
        if len(changed_indices) == 0:
            continue

        absolute_positions = [idx.item() + block_start for idx in changed_indices]
        logger.info(f" [结果] B{batch_idx} | 发生变化的位置: {absolute_positions}")

        # Per-position detail, restricted to the current block.
        for idx in changed_indices:
            prev_text = self._token_to_text(before[batch_idx, idx].item(), tokenizer, mask_token_id)
            curr_text = self._token_to_text(after[batch_idx, idx].item(), tokenizer, mask_token_id)
            logger.info(f" - Pos {idx.item() + block_start}: {prev_text} -> {curr_text}")

        # Show the whole sequence only for a batch of one.
        if single_sequence:
            full_text = self._tokens_to_text(curr_seq[batch_idx], tokenizer, mask_token_id)
            logger.info(f" [序列] {full_text}")
1847
+
1848
+ # 旧的调试函数可以保留或删除,这里我保留并重命名以示区别
1849
+ def _debug_full_step_changes(self, *args, **kwargs):
1850
+ # ... (这个函数逻辑可以保持不变,但现在可能不太常用了)
1851
+ pass
1852
+
1853
+ def _tokens_to_text(self, tokens: torch.Tensor, tokenizer, mask_token_id: int) -> str:
1854
+ # ... (这个函数逻辑保持不变)
1855
+ if tokenizer is None:
1856
+ token_strs = []
1857
+ for token_id in tokens.tolist():
1858
+ if token_id == mask_token_id:
1859
+ token_strs.append("[MASK]")
1860
+ else:
1861
+ token_strs.append(f"<{token_id}>")
1862
+ return " ".join(token_strs)
1863
+ else:
1864
+ try:
1865
+ temp_tokens = tokens.clone()
1866
+ temp_tokens[temp_tokens == mask_token_id] = tokenizer.mask_token_id if hasattr(tokenizer, 'mask_token_id') else -1
1867
+ text = tokenizer.decode(temp_tokens, skip_special_tokens=False)
1868
+ # return text.replace(tokenizer.decode([-1]), "[MASK]")
1869
+ return text
1870
+ except Exception as e:
1871
+
1872
+ logger.warning(f"Tokenizer解码失败: {e}")
1873
+ import pdb
1874
+ pdb.set_trace()
1875
+ return self._tokens_to_text(tokens, None, mask_token_id)
1876
+
1877
+ def _token_to_text(self, token_id: int, tokenizer, mask_token_id: int) -> str:
1878
+ # ... (这个函数逻辑保持不变)
1879
+ if token_id == mask_token_id:
1880
+ return "[MASK]"
1881
+ if tokenizer is None:
1882
+ return f"<{token_id}>"
1883
+ else:
1884
+ try:
1885
+ return f"'{tokenizer.decode([token_id], skip_special_tokens=False)}'"
1886
+ except:
1887
+ return f"<{token_id}>"
1888
+
1889
+ # Register the model so that it is available for transformer pipelines, auto-loading, etc.
1890
+ # AutoModel.register(LLaDAConfig, LLaDAModelLM)
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1246d096518b82ec939353ff553b5152d15ea6296bdd04ea8f0f4af5178a1424
3
+ size 540955107
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": {
3
+ "content": "[CLS]",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "mask_token": {
10
+ "content": "[MASK]",
11
+ "lstrip": true,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "[PAD]",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "sep_token": {
24
+ "content": "[SEP]",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "unk_token": {
31
+ "content": "[UNK]",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ }
37
+ }
tokenizer/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,945 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "|||IP_ADDRESS|||",
5
+ "lstrip": false,
6
+ "normalized": true,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": false
10
+ },
11
+ "1": {
12
+ "content": "<|padding|>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "50254": {
20
+ "content": " ",
21
+ "lstrip": false,
22
+ "normalized": true,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": false
26
+ },
27
+ "50255": {
28
+ "content": " ",
29
+ "lstrip": false,
30
+ "normalized": true,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": false
34
+ },
35
+ "50256": {
36
+ "content": " ",
37
+ "lstrip": false,
38
+ "normalized": true,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": false
42
+ },
43
+ "50257": {
44
+ "content": " ",
45
+ "lstrip": false,
46
+ "normalized": true,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": false
50
+ },
51
+ "50258": {
52
+ "content": " ",
53
+ "lstrip": false,
54
+ "normalized": true,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": false
58
+ },
59
+ "50259": {
60
+ "content": " ",
61
+ "lstrip": false,
62
+ "normalized": true,
63
+ "rstrip": false,
64
+ "single_word": false,
65
+ "special": false
66
+ },
67
+ "50260": {
68
+ "content": " ",
69
+ "lstrip": false,
70
+ "normalized": true,
71
+ "rstrip": false,
72
+ "single_word": false,
73
+ "special": false
74
+ },
75
+ "50261": {
76
+ "content": " ",
77
+ "lstrip": false,
78
+ "normalized": true,
79
+ "rstrip": false,
80
+ "single_word": false,
81
+ "special": false
82
+ },
83
+ "50262": {
84
+ "content": " ",
85
+ "lstrip": false,
86
+ "normalized": true,
87
+ "rstrip": false,
88
+ "single_word": false,
89
+ "special": false
90
+ },
91
+ "50263": {
92
+ "content": " ",
93
+ "lstrip": false,
94
+ "normalized": true,
95
+ "rstrip": false,
96
+ "single_word": false,
97
+ "special": false
98
+ },
99
+ "50264": {
100
+ "content": " ",
101
+ "lstrip": false,
102
+ "normalized": true,
103
+ "rstrip": false,
104
+ "single_word": false,
105
+ "special": false
106
+ },
107
+ "50265": {
108
+ "content": " ",
109
+ "lstrip": false,
110
+ "normalized": true,
111
+ "rstrip": false,
112
+ "single_word": false,
113
+ "special": false
114
+ },
115
+ "50266": {
116
+ "content": " ",
117
+ "lstrip": false,
118
+ "normalized": true,
119
+ "rstrip": false,
120
+ "single_word": false,
121
+ "special": false
122
+ },
123
+ "50267": {
124
+ "content": " ",
125
+ "lstrip": false,
126
+ "normalized": true,
127
+ "rstrip": false,
128
+ "single_word": false,
129
+ "special": false
130
+ },
131
+ "50268": {
132
+ "content": " ",
133
+ "lstrip": false,
134
+ "normalized": true,
135
+ "rstrip": false,
136
+ "single_word": false,
137
+ "special": false
138
+ },
139
+ "50269": {
140
+ "content": " ",
141
+ "lstrip": false,
142
+ "normalized": true,
143
+ "rstrip": false,
144
+ "single_word": false,
145
+ "special": false
146
+ },
147
+ "50270": {
148
+ "content": " ",
149
+ "lstrip": false,
150
+ "normalized": true,
151
+ "rstrip": false,
152
+ "single_word": false,
153
+ "special": false
154
+ },
155
+ "50271": {
156
+ "content": " ",
157
+ "lstrip": false,
158
+ "normalized": true,
159
+ "rstrip": false,
160
+ "single_word": false,
161
+ "special": false
162
+ },
163
+ "50272": {
164
+ "content": " ",
165
+ "lstrip": false,
166
+ "normalized": true,
167
+ "rstrip": false,
168
+ "single_word": false,
169
+ "special": false
170
+ },
171
+ "50273": {
172
+ "content": " ",
173
+ "lstrip": false,
174
+ "normalized": true,
175
+ "rstrip": false,
176
+ "single_word": false,
177
+ "special": false
178
+ },
179
+ "50274": {
180
+ "content": " ",
181
+ "lstrip": false,
182
+ "normalized": true,
183
+ "rstrip": false,
184
+ "single_word": false,
185
+ "special": false
186
+ },
187
+ "50275": {
188
+ "content": " ",
189
+ "lstrip": false,
190
+ "normalized": true,
191
+ "rstrip": false,
192
+ "single_word": false,
193
+ "special": false
194
+ },
195
+ "50276": {
196
+ "content": " ",
197
+ "lstrip": false,
198
+ "normalized": true,
199
+ "rstrip": false,
200
+ "single_word": false,
201
+ "special": false
202
+ },
203
+ "50277": {
204
+ "content": "|||EMAIL_ADDRESS|||",
205
+ "lstrip": false,
206
+ "normalized": true,
207
+ "rstrip": false,
208
+ "single_word": false,
209
+ "special": false
210
+ },
211
+ "50278": {
212
+ "content": "|||PHONE_NUMBER|||",
213
+ "lstrip": false,
214
+ "normalized": true,
215
+ "rstrip": false,
216
+ "single_word": false,
217
+ "special": false
218
+ },
219
+ "50279": {
220
+ "content": "<|endoftext|>",
221
+ "lstrip": false,
222
+ "normalized": false,
223
+ "rstrip": false,
224
+ "single_word": false,
225
+ "special": true
226
+ },
227
+ "50280": {
228
+ "content": "[UNK]",
229
+ "lstrip": false,
230
+ "normalized": false,
231
+ "rstrip": false,
232
+ "single_word": false,
233
+ "special": true
234
+ },
235
+ "50281": {
236
+ "content": "[CLS]",
237
+ "lstrip": false,
238
+ "normalized": false,
239
+ "rstrip": false,
240
+ "single_word": false,
241
+ "special": true
242
+ },
243
+ "50282": {
244
+ "content": "[SEP]",
245
+ "lstrip": false,
246
+ "normalized": false,
247
+ "rstrip": false,
248
+ "single_word": false,
249
+ "special": true
250
+ },
251
+ "50283": {
252
+ "content": "[PAD]",
253
+ "lstrip": false,
254
+ "normalized": false,
255
+ "rstrip": false,
256
+ "single_word": false,
257
+ "special": true
258
+ },
259
+ "50284": {
260
+ "content": "[MASK]",
261
+ "lstrip": true,
262
+ "normalized": false,
263
+ "rstrip": false,
264
+ "single_word": false,
265
+ "special": true
266
+ },
267
+ "50285": {
268
+ "content": "[unused0]",
269
+ "lstrip": false,
270
+ "normalized": true,
271
+ "rstrip": false,
272
+ "single_word": false,
273
+ "special": false
274
+ },
275
+ "50286": {
276
+ "content": "[unused1]",
277
+ "lstrip": false,
278
+ "normalized": true,
279
+ "rstrip": false,
280
+ "single_word": false,
281
+ "special": false
282
+ },
283
+ "50287": {
284
+ "content": "[unused2]",
285
+ "lstrip": false,
286
+ "normalized": true,
287
+ "rstrip": false,
288
+ "single_word": false,
289
+ "special": false
290
+ },
291
+ "50288": {
292
+ "content": "[unused3]",
293
+ "lstrip": false,
294
+ "normalized": true,
295
+ "rstrip": false,
296
+ "single_word": false,
297
+ "special": false
298
+ },
299
+ "50289": {
300
+ "content": "[unused4]",
301
+ "lstrip": false,
302
+ "normalized": true,
303
+ "rstrip": false,
304
+ "single_word": false,
305
+ "special": false
306
+ },
307
+ "50290": {
308
+ "content": "[unused5]",
309
+ "lstrip": false,
310
+ "normalized": true,
311
+ "rstrip": false,
312
+ "single_word": false,
313
+ "special": false
314
+ },
315
+ "50291": {
316
+ "content": "[unused6]",
317
+ "lstrip": false,
318
+ "normalized": true,
319
+ "rstrip": false,
320
+ "single_word": false,
321
+ "special": false
322
+ },
323
+ "50292": {
324
+ "content": "[unused7]",
325
+ "lstrip": false,
326
+ "normalized": true,
327
+ "rstrip": false,
328
+ "single_word": false,
329
+ "special": false
330
+ },
331
+ "50293": {
332
+ "content": "[unused8]",
333
+ "lstrip": false,
334
+ "normalized": true,
335
+ "rstrip": false,
336
+ "single_word": false,
337
+ "special": false
338
+ },
339
+ "50294": {
340
+ "content": "[unused9]",
341
+ "lstrip": false,
342
+ "normalized": true,
343
+ "rstrip": false,
344
+ "single_word": false,
345
+ "special": false
346
+ },
347
+ "50295": {
348
+ "content": "[unused10]",
349
+ "lstrip": false,
350
+ "normalized": true,
351
+ "rstrip": false,
352
+ "single_word": false,
353
+ "special": false
354
+ },
355
+ "50296": {
356
+ "content": "[unused11]",
357
+ "lstrip": false,
358
+ "normalized": true,
359
+ "rstrip": false,
360
+ "single_word": false,
361
+ "special": false
362
+ },
363
+ "50297": {
364
+ "content": "[unused12]",
365
+ "lstrip": false,
366
+ "normalized": true,
367
+ "rstrip": false,
368
+ "single_word": false,
369
+ "special": false
370
+ },
371
+ "50298": {
372
+ "content": "[unused13]",
373
+ "lstrip": false,
374
+ "normalized": true,
375
+ "rstrip": false,
376
+ "single_word": false,
377
+ "special": false
378
+ },
379
+ "50299": {
380
+ "content": "[unused14]",
381
+ "lstrip": false,
382
+ "normalized": true,
383
+ "rstrip": false,
384
+ "single_word": false,
385
+ "special": false
386
+ },
387
+ "50300": {
388
+ "content": "[unused15]",
389
+ "lstrip": false,
390
+ "normalized": true,
391
+ "rstrip": false,
392
+ "single_word": false,
393
+ "special": false
394
+ },
395
+ "50301": {
396
+ "content": "[unused16]",
397
+ "lstrip": false,
398
+ "normalized": true,
399
+ "rstrip": false,
400
+ "single_word": false,
401
+ "special": false
402
+ },
403
+ "50302": {
404
+ "content": "[unused17]",
405
+ "lstrip": false,
406
+ "normalized": true,
407
+ "rstrip": false,
408
+ "single_word": false,
409
+ "special": false
410
+ },
411
+ "50303": {
412
+ "content": "[unused18]",
413
+ "lstrip": false,
414
+ "normalized": true,
415
+ "rstrip": false,
416
+ "single_word": false,
417
+ "special": false
418
+ },
419
+ "50304": {
420
+ "content": "[unused19]",
421
+ "lstrip": false,
422
+ "normalized": true,
423
+ "rstrip": false,
424
+ "single_word": false,
425
+ "special": false
426
+ },
427
+ "50305": {
428
+ "content": "[unused20]",
429
+ "lstrip": false,
430
+ "normalized": true,
431
+ "rstrip": false,
432
+ "single_word": false,
433
+ "special": false
434
+ },
435
+ "50306": {
436
+ "content": "[unused21]",
437
+ "lstrip": false,
438
+ "normalized": true,
439
+ "rstrip": false,
440
+ "single_word": false,
441
+ "special": false
442
+ },
443
+ "50307": {
444
+ "content": "[unused22]",
445
+ "lstrip": false,
446
+ "normalized": true,
447
+ "rstrip": false,
448
+ "single_word": false,
449
+ "special": false
450
+ },
451
+ "50308": {
452
+ "content": "[unused23]",
453
+ "lstrip": false,
454
+ "normalized": true,
455
+ "rstrip": false,
456
+ "single_word": false,
457
+ "special": false
458
+ },
459
+ "50309": {
460
+ "content": "[unused24]",
461
+ "lstrip": false,
462
+ "normalized": true,
463
+ "rstrip": false,
464
+ "single_word": false,
465
+ "special": false
466
+ },
467
+ "50310": {
468
+ "content": "[unused25]",
469
+ "lstrip": false,
470
+ "normalized": true,
471
+ "rstrip": false,
472
+ "single_word": false,
473
+ "special": false
474
+ },
475
+ "50311": {
476
+ "content": "[unused26]",
477
+ "lstrip": false,
478
+ "normalized": true,
479
+ "rstrip": false,
480
+ "single_word": false,
481
+ "special": false
482
+ },
483
+ "50312": {
484
+ "content": "[unused27]",
485
+ "lstrip": false,
486
+ "normalized": true,
487
+ "rstrip": false,
488
+ "single_word": false,
489
+ "special": false
490
+ },
491
+ "50313": {
492
+ "content": "[unused28]",
493
+ "lstrip": false,
494
+ "normalized": true,
495
+ "rstrip": false,
496
+ "single_word": false,
497
+ "special": false
498
+ },
499
+ "50314": {
500
+ "content": "[unused29]",
501
+ "lstrip": false,
502
+ "normalized": true,
503
+ "rstrip": false,
504
+ "single_word": false,
505
+ "special": false
506
+ },
507
+ "50315": {
508
+ "content": "[unused30]",
509
+ "lstrip": false,
510
+ "normalized": true,
511
+ "rstrip": false,
512
+ "single_word": false,
513
+ "special": false
514
+ },
515
+ "50316": {
516
+ "content": "[unused31]",
517
+ "lstrip": false,
518
+ "normalized": true,
519
+ "rstrip": false,
520
+ "single_word": false,
521
+ "special": false
522
+ },
523
+ "50317": {
524
+ "content": "[unused32]",
525
+ "lstrip": false,
526
+ "normalized": true,
527
+ "rstrip": false,
528
+ "single_word": false,
529
+ "special": false
530
+ },
531
+ "50318": {
532
+ "content": "[unused33]",
533
+ "lstrip": false,
534
+ "normalized": true,
535
+ "rstrip": false,
536
+ "single_word": false,
537
+ "special": false
538
+ },
539
+ "50319": {
540
+ "content": "[unused34]",
541
+ "lstrip": false,
542
+ "normalized": true,
543
+ "rstrip": false,
544
+ "single_word": false,
545
+ "special": false
546
+ },
547
+ "50320": {
548
+ "content": "[unused35]",
549
+ "lstrip": false,
550
+ "normalized": true,
551
+ "rstrip": false,
552
+ "single_word": false,
553
+ "special": false
554
+ },
555
+ "50321": {
556
+ "content": "[unused36]",
557
+ "lstrip": false,
558
+ "normalized": true,
559
+ "rstrip": false,
560
+ "single_word": false,
561
+ "special": false
562
+ },
563
+ "50322": {
564
+ "content": "[unused37]",
565
+ "lstrip": false,
566
+ "normalized": true,
567
+ "rstrip": false,
568
+ "single_word": false,
569
+ "special": false
570
+ },
571
+ "50323": {
572
+ "content": "[unused38]",
573
+ "lstrip": false,
574
+ "normalized": true,
575
+ "rstrip": false,
576
+ "single_word": false,
577
+ "special": false
578
+ },
579
+ "50324": {
580
+ "content": "[unused39]",
581
+ "lstrip": false,
582
+ "normalized": true,
583
+ "rstrip": false,
584
+ "single_word": false,
585
+ "special": false
586
+ },
587
+ "50325": {
588
+ "content": "[unused40]",
589
+ "lstrip": false,
590
+ "normalized": true,
591
+ "rstrip": false,
592
+ "single_word": false,
593
+ "special": false
594
+ },
595
+ "50326": {
596
+ "content": "[unused41]",
597
+ "lstrip": false,
598
+ "normalized": true,
599
+ "rstrip": false,
600
+ "single_word": false,
601
+ "special": false
602
+ },
603
+ "50327": {
604
+ "content": "[unused42]",
605
+ "lstrip": false,
606
+ "normalized": true,
607
+ "rstrip": false,
608
+ "single_word": false,
609
+ "special": false
610
+ },
611
+ "50328": {
612
+ "content": "[unused43]",
613
+ "lstrip": false,
614
+ "normalized": true,
615
+ "rstrip": false,
616
+ "single_word": false,
617
+ "special": false
618
+ },
619
+ "50329": {
620
+ "content": "[unused44]",
621
+ "lstrip": false,
622
+ "normalized": true,
623
+ "rstrip": false,
624
+ "single_word": false,
625
+ "special": false
626
+ },
627
+ "50330": {
628
+ "content": "[unused45]",
629
+ "lstrip": false,
630
+ "normalized": true,
631
+ "rstrip": false,
632
+ "single_word": false,
633
+ "special": false
634
+ },
635
+ "50331": {
636
+ "content": "[unused46]",
637
+ "lstrip": false,
638
+ "normalized": true,
639
+ "rstrip": false,
640
+ "single_word": false,
641
+ "special": false
642
+ },
643
+ "50332": {
644
+ "content": "[unused47]",
645
+ "lstrip": false,
646
+ "normalized": true,
647
+ "rstrip": false,
648
+ "single_word": false,
649
+ "special": false
650
+ },
651
+ "50333": {
652
+ "content": "[unused48]",
653
+ "lstrip": false,
654
+ "normalized": true,
655
+ "rstrip": false,
656
+ "single_word": false,
657
+ "special": false
658
+ },
659
+ "50334": {
660
+ "content": "[unused49]",
661
+ "lstrip": false,
662
+ "normalized": true,
663
+ "rstrip": false,
664
+ "single_word": false,
665
+ "special": false
666
+ },
667
+ "50335": {
668
+ "content": "[unused50]",
669
+ "lstrip": false,
670
+ "normalized": true,
671
+ "rstrip": false,
672
+ "single_word": false,
673
+ "special": false
674
+ },
675
+ "50336": {
676
+ "content": "[unused51]",
677
+ "lstrip": false,
678
+ "normalized": true,
679
+ "rstrip": false,
680
+ "single_word": false,
681
+ "special": false
682
+ },
683
+ "50337": {
684
+ "content": "[unused52]",
685
+ "lstrip": false,
686
+ "normalized": true,
687
+ "rstrip": false,
688
+ "single_word": false,
689
+ "special": false
690
+ },
691
+ "50338": {
692
+ "content": "[unused53]",
693
+ "lstrip": false,
694
+ "normalized": true,
695
+ "rstrip": false,
696
+ "single_word": false,
697
+ "special": false
698
+ },
699
+ "50339": {
700
+ "content": "[unused54]",
701
+ "lstrip": false,
702
+ "normalized": true,
703
+ "rstrip": false,
704
+ "single_word": false,
705
+ "special": false
706
+ },
707
+ "50340": {
708
+ "content": "[unused55]",
709
+ "lstrip": false,
710
+ "normalized": true,
711
+ "rstrip": false,
712
+ "single_word": false,
713
+ "special": false
714
+ },
715
+ "50341": {
716
+ "content": "[unused56]",
717
+ "lstrip": false,
718
+ "normalized": true,
719
+ "rstrip": false,
720
+ "single_word": false,
721
+ "special": false
722
+ },
723
+ "50342": {
724
+ "content": "[unused57]",
725
+ "lstrip": false,
726
+ "normalized": true,
727
+ "rstrip": false,
728
+ "single_word": false,
729
+ "special": false
730
+ },
731
+ "50343": {
732
+ "content": "[unused58]",
733
+ "lstrip": false,
734
+ "normalized": true,
735
+ "rstrip": false,
736
+ "single_word": false,
737
+ "special": false
738
+ },
739
+ "50344": {
740
+ "content": "[unused59]",
741
+ "lstrip": false,
742
+ "normalized": true,
743
+ "rstrip": false,
744
+ "single_word": false,
745
+ "special": false
746
+ },
747
+ "50345": {
748
+ "content": "[unused60]",
749
+ "lstrip": false,
750
+ "normalized": true,
751
+ "rstrip": false,
752
+ "single_word": false,
753
+ "special": false
754
+ },
755
+ "50346": {
756
+ "content": "[unused61]",
757
+ "lstrip": false,
758
+ "normalized": true,
759
+ "rstrip": false,
760
+ "single_word": false,
761
+ "special": false
762
+ },
763
+ "50347": {
764
+ "content": "[unused62]",
765
+ "lstrip": false,
766
+ "normalized": true,
767
+ "rstrip": false,
768
+ "single_word": false,
769
+ "special": false
770
+ },
771
+ "50348": {
772
+ "content": "[unused63]",
773
+ "lstrip": false,
774
+ "normalized": true,
775
+ "rstrip": false,
776
+ "single_word": false,
777
+ "special": false
778
+ },
779
+ "50349": {
780
+ "content": "[unused64]",
781
+ "lstrip": false,
782
+ "normalized": true,
783
+ "rstrip": false,
784
+ "single_word": false,
785
+ "special": false
786
+ },
787
+ "50350": {
788
+ "content": "[unused65]",
789
+ "lstrip": false,
790
+ "normalized": true,
791
+ "rstrip": false,
792
+ "single_word": false,
793
+ "special": false
794
+ },
795
+ "50351": {
796
+ "content": "[unused66]",
797
+ "lstrip": false,
798
+ "normalized": true,
799
+ "rstrip": false,
800
+ "single_word": false,
801
+ "special": false
802
+ },
803
+ "50352": {
804
+ "content": "[unused67]",
805
+ "lstrip": false,
806
+ "normalized": true,
807
+ "rstrip": false,
808
+ "single_word": false,
809
+ "special": false
810
+ },
811
+ "50353": {
812
+ "content": "[unused68]",
813
+ "lstrip": false,
814
+ "normalized": true,
815
+ "rstrip": false,
816
+ "single_word": false,
817
+ "special": false
818
+ },
819
+ "50354": {
820
+ "content": "[unused69]",
821
+ "lstrip": false,
822
+ "normalized": true,
823
+ "rstrip": false,
824
+ "single_word": false,
825
+ "special": false
826
+ },
827
+ "50355": {
828
+ "content": "[unused70]",
829
+ "lstrip": false,
830
+ "normalized": true,
831
+ "rstrip": false,
832
+ "single_word": false,
833
+ "special": false
834
+ },
835
+ "50356": {
836
+ "content": "[unused71]",
837
+ "lstrip": false,
838
+ "normalized": true,
839
+ "rstrip": false,
840
+ "single_word": false,
841
+ "special": false
842
+ },
843
+ "50357": {
844
+ "content": "[unused72]",
845
+ "lstrip": false,
846
+ "normalized": true,
847
+ "rstrip": false,
848
+ "single_word": false,
849
+ "special": false
850
+ },
851
+ "50358": {
852
+ "content": "[unused73]",
853
+ "lstrip": false,
854
+ "normalized": true,
855
+ "rstrip": false,
856
+ "single_word": false,
857
+ "special": false
858
+ },
859
+ "50359": {
860
+ "content": "[unused74]",
861
+ "lstrip": false,
862
+ "normalized": true,
863
+ "rstrip": false,
864
+ "single_word": false,
865
+ "special": false
866
+ },
867
+ "50360": {
868
+ "content": "[unused75]",
869
+ "lstrip": false,
870
+ "normalized": true,
871
+ "rstrip": false,
872
+ "single_word": false,
873
+ "special": false
874
+ },
875
+ "50361": {
876
+ "content": "[unused76]",
877
+ "lstrip": false,
878
+ "normalized": true,
879
+ "rstrip": false,
880
+ "single_word": false,
881
+ "special": false
882
+ },
883
+ "50362": {
884
+ "content": "[unused77]",
885
+ "lstrip": false,
886
+ "normalized": true,
887
+ "rstrip": false,
888
+ "single_word": false,
889
+ "special": false
890
+ },
891
+ "50363": {
892
+ "content": "[unused78]",
893
+ "lstrip": false,
894
+ "normalized": true,
895
+ "rstrip": false,
896
+ "single_word": false,
897
+ "special": false
898
+ },
899
+ "50364": {
900
+ "content": "[unused79]",
901
+ "lstrip": false,
902
+ "normalized": true,
903
+ "rstrip": false,
904
+ "single_word": false,
905
+ "special": false
906
+ },
907
+ "50365": {
908
+ "content": "[unused80]",
909
+ "lstrip": false,
910
+ "normalized": true,
911
+ "rstrip": false,
912
+ "single_word": false,
913
+ "special": false
914
+ },
915
+ "50366": {
916
+ "content": "[unused81]",
917
+ "lstrip": false,
918
+ "normalized": true,
919
+ "rstrip": false,
920
+ "single_word": false,
921
+ "special": false
922
+ },
923
+ "50367": {
924
+ "content": "[unused82]",
925
+ "lstrip": false,
926
+ "normalized": true,
927
+ "rstrip": false,
928
+ "single_word": false,
929
+ "special": false
930
+ }
931
+ },
932
+ "clean_up_tokenization_spaces": true,
933
+ "cls_token": "[CLS]",
934
+ "extra_special_tokens": {},
935
+ "mask_token": "[MASK]",
936
+ "model_input_names": [
937
+ "input_ids",
938
+ "attention_mask"
939
+ ],
940
+ "model_max_length": 8192,
941
+ "pad_token": "[PAD]",
942
+ "sep_token": "[SEP]",
943
+ "tokenizer_class": "PreTrainedTokenizerFast",
944
+ "unk_token": "[UNK]"
945
+ }