Epsilonoid committed · verified · Commit db56acd · 1 parent: 1d8542d
README.md ADDED
# Diffusion Engram IME Demo

[English Version](#introduction)

本项目探索了一种基于扩散语言模型的输入法实现思路。它基于 [LLaDA](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) 的实现,并融合了 [Engram](https://github.com/deepseek-ai/Engram) 模块,期盼利用语言模型强大的上下文理解能力来提升长句输入的准确性与连贯性。

模型在 [Chinese Fineweb Edu Dataset V2.1](https://huggingface.co/datasets/opencsg/Fineweb-Edu-Chinese-V2.1) 上进行了训练,采用[虎码输入方案](https://www.tiger-code.com/)作为编码标准。

## 使用方式

本项目提供了简易的交互脚本 `example.py`,用于演示核心功能。

### 输入格式
在提示符下输入字符串。
- **汉字**: 请输入其虎码编码的前两位。
- **标点/大写字母/特殊符号**: 直接输入原文即可。符号可以输入半角版本。
- **小写字母**: 输入字母+空格。
- **混合输入**: 支持编码与原文混合输入。

部分实例见 `example.py` 末尾注释。

## 局限性

⚠️ 本项目可能包含 ⚠️:
- 没有认真处理和选择的训练语料
- 拍脑袋想出来的模型架构和超参数
- 低效的模型实现和推理代码

## 杂谈

其实早在 ChatGPT 还没有横空出世、我还没有了解过 transformer 的时候,我就思考过能不能用深度学习模型做出更加强大的输入法。那时候我试着学习形码(最后并没有坚持下来),也常常幻想其他提升输入效率的手段。我曾想输入法也许可以更熟悉语言应当有的语法和语义、能以更高的概率组出合理的句子。当然它的用户界面可能与现在的输入法大相径庭、以输入句子甚至段落为核心,从而能利用起来上下文的信息。(当然做这事的人可能不少。)不过那时没有 vibe coding 帮忙,我的行动力不足以让我把这些想法变成现实。

随着 LLM 的发展,我愈发觉得输入效率很大程度上限制了人机交互的效率。了解到 diffusion language model 的思路后,我觉得这非常适合输入法的场景:模型需要根据完整的上下文来预测每个字(而在自回归模型上做 contrastive decoding 只能感知到上文),并且可以从易到难地逐步推断原文,甚至可以无缝地处理部分词由用户手动选择的情况。之所以选择形码,首要原因是形码不会有多音字的问题,数据处理简单一些,且每个编码上的字数分布更加均匀。不过最后训练下来,效果不太理想。

看到 Engram 的时候,我立刻开始重新尝试这个项目。Engram 所做的对 n-gram 查表,几乎就是完美地承担了“词库”的职责,Engram 模块应当可以大幅度减轻模型主干记忆词库的压力。训练下来,结果确实也比之前好不少。

这个项目离实际可用的输入法还有很大差距:最显然的当然是其没有一个合适的用户界面,要有合适的方式让用户修改候选结果、且能适应零点几秒的延迟;模型的训练数据在类型上很窄,尤其缺乏口语化或文学化的内容;模型的推理几乎没有优化过;模型具体应该做多大、超参数如何选择也没有认真实验过;等等。除此之外,怎么让模型利用已经上屏的部分作为上下文,以及能否针对性地再改造 Engram(和其他各个模块)使其更适合输入法场景,都是潜在的改进方向。

---

## Introduction

This project explores an implementation approach for an Input Method Editor (IME) based on diffusion language models. It builds on the implementation of [LLaDA](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) and incorporates the [Engram](https://github.com/deepseek-ai/Engram) module, aiming to leverage the strong contextual understanding of language models to improve the accuracy and coherence of long-sentence input.

The model was trained on the [Chinese Fineweb Edu Dataset V2.1](https://huggingface.co/datasets/opencsg/Fineweb-Edu-Chinese-V2.1) and uses the [Tiger Code (Huma) input scheme](https://www.tiger-code.com/) as its encoding standard.

## Usage

This project provides a simple interactive script, `example.py`, to demonstrate the core functionality.

### Input Format
Enter a string at the prompt.
- **Chinese characters**: enter the first two letters of the character's Tiger Code encoding.
- **Punctuation/uppercase letters/special symbols**: enter them literally. Symbols may be typed in their half-width forms.
- **Lowercase letters**: enter the letter followed by a space.
- **Mixed input**: encodings and literal text can be mixed freely.

See the comments at the end of `example.py` for examples.
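The segmentation rules above can be sketched as a small pure-Python function. This is a simplified illustration mirroring `process_string_into_pairs` in `example.py`, not a separate API:

```python
def segment(s: str) -> list[str]:
    """Split an input string into per-character input units.

    - Two consecutive lowercase letters form one Tiger Code unit (one Chinese character).
    - A lowercase letter followed by a space is a literal lowercase letter.
    - Everything else (punctuation, uppercase, non-ASCII) passes through unchanged.
    """
    out, i = [], 0
    while i < len(s):
        c = s[i]
        if "a" <= c <= "z":
            if i + 1 < len(s) and "a" <= s[i + 1] <= "z":
                out.append(c + s[i + 1])  # two-letter Tiger Code unit
                i += 2
            elif i + 1 < len(s) and s[i + 1] == " ":
                out.append(c)  # literal lowercase letter; the space is consumed
                i += 2
            else:
                out.append(c)  # lone trailing lowercase letter
                i += 1
        else:
            out.append(c)  # punctuation / uppercase / full-width characters
            i += 1
    return out
```

For example, `segment("nhkz。")` yields `["nh", "kz", "。"]`: two Tiger Code units followed by a literal full-width period.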

## Limitations

⚠️ This project may contain ⚠️:
- Training corpora that were not carefully processed or selected
- A model architecture and hyperparameters conceived on a whim
- An inefficient model implementation and inference code

## Ramblings

Long before ChatGPT took the world by storm, and before I knew anything about transformers, I wondered whether deep learning models could power a more capable IME. At the time I tried to learn a shape-based input method (I didn't stick with it) and often daydreamed about other ways to improve input efficiency. I imagined an IME that was more familiar with the syntax and semantics a language should have, and could therefore compose reasonable sentences with higher probability. Its user interface might look very different from today's IMEs, centered on entering whole sentences or even paragraphs so that contextual information could be exploited. (No doubt plenty of people have had the same idea.) But without "vibe coding" to help back then, I lacked the drive to turn these ideas into reality.

As LLMs developed, I increasingly felt that input efficiency is a major bottleneck in human-computer interaction. When I learned about diffusion language models, the approach struck me as a natural fit for the IME setting: the model predicts each character from the complete context (whereas contrastive decoding on an autoregressive model only sees the preceding text), it can recover the original text progressively from easy positions to hard ones, and it can even seamlessly handle the case where the user has already selected some words by hand. The main reason for choosing a shape-based code is that it avoids the polyphone problem, which simplifies data processing, and characters are distributed more evenly across codes. Still, the first round of training did not produce satisfying results.
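The easy-to-difficult decoding described above is implemented in `example.py`; its core loop can be reduced to the following sketch, where `predict` is a toy stand-in for the model's per-position distribution (not the real model call):

```python
def iterative_unmask(masks, predict, threshold=0.9):
    """Greedy confidence-threshold decoding over a masked sequence.

    `masks` is a list of None (still masked) or str (already decided).
    `predict(seq)` returns, for every position, a (probability, token) pair
    for the most likely token given the full partially-decoded context.
    """
    seq = list(masks)
    while any(tok is None for tok in seq):
        preds = predict(seq)
        progressed = False
        best = (0.0, None, None)  # (probability, index, token)
        for i, tok in enumerate(seq):
            if tok is None:
                p, t = preds[i]
                best = max(best, (p, i, t))
                if p > threshold:  # confident: commit this position now
                    seq[i] = t
                    progressed = True
        if not progressed:  # nothing above threshold: commit the single most certain one
            seq[best[1]] = best[2]
    return seq
```

Each round commits every position the model is confident about; when no position clears the threshold, only the single most certain one is committed, so hard positions are decided last, with the most context available.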

When I saw Engram, I immediately picked this project back up. Engram's n-gram lookup takes on, almost perfectly, the role of a "lexicon", so the Engram module should greatly reduce the pressure on the model backbone to memorize one. After training, the results were indeed much better than before.
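The "lexicon" role rests on hashed n-gram tables: recent tokens are hashed into a fixed-size embedding table whose rows are learned during training. The sketch below shows only this core idea; the function name, key format, and table size are illustrative, not Engram's actual implementation:

```python
import zlib


def ngram_bucket(tokens, n, table_size, seed=0):
    """Hash the last n token ids into a bucket of a fixed-size embedding table.

    Collisions are allowed by design: the table is far smaller than the number
    of distinct n-grams, and colliding n-grams share one learned row.
    """
    key = ",".join(str(t) for t in tokens[-n:]) + f"|{seed}"
    return zlib.crc32(key.encode()) % table_size


# Each position can fetch one learned row per n-gram order (bigram, trigram, ...)
# so "dictionary" knowledge lives in the tables rather than the backbone weights.
```

Because the bucket depends only on the last `n` tokens, the lookup is a cheap, deterministic memory access, which is what makes it behave like a word-list consultation.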

This project is still far from a practically usable IME. The most obvious gap is the lack of a suitable user interface, one that lets users edit candidate results and tolerates a few tenths of a second of latency; the training data is narrow in genre, especially lacking colloquial and literary content; inference is essentially unoptimized; no serious experiments have been run on model size or hyperparameter choices; and so on. Beyond that, letting the model use already-committed text as context, and adapting Engram (and the other modules) specifically for the IME setting, are potential directions for improvement.
config.json ADDED
{
  "activation_type": "silu",
  "alibi": false,
  "alibi_bias_max": 8.0,
  "architectures": [
    "LLaDAModelLM"
  ],
  "attention_dropout": 0.0,
  "attention_layer_norm": false,
  "attention_layer_norm_with_affine": true,
  "auto_map": {
    "AutoConfig": "configuration_llada_engram.LLaDAConfig",
    "AutoModelForCausalLM": "modeling_llada_engram.LLaDAModelLM",
    "AutoModel": "modeling_llada_engram.LLaDAModelLM"
  },
  "bias_for_layer_norm": false,
  "block_group_size": 1,
  "block_type": "llama",
  "d_model": 768,
  "embedding_dropout": 0.0,
  "embedding_size": 7424,
  "eos_token_id": 6624,
  "flash_attention": false,
  "include_bias": false,
  "include_qkv_bias": false,
  "init_cutoff_factor": null,
  "init_device": "meta",
  "init_fn": "mitchell",
  "init_std": 0.02,
  "input_emb_norm": false,
  "layer_norm_type": "rms",
  "layer_norm_with_affine": true,
  "mask_token_id": 7186,
  "max_sequence_length": 128,
  "mlp_hidden_size": 1536,
  "model_type": "llada",
  "multi_query_attention": null,
  "n_heads": 12,
  "n_kv_heads": 12,
  "n_layers": 14,
  "pad_token_id": 6624,
  "precision": "amp_bf16",
  "residual_dropout": 0.0,
  "rms_norm_eps": 1e-05,
  "rope": true,
  "rope_full_precision": true,
  "rope_theta": 500000.0,
  "scale_logits": false,
  "transformers_version": "4.46.3",
  "use_cache": false,
  "vocab_size": 7397,
  "weight_tying": false,
  "engram_config": {
    "tokenizer_name_or_path": "./tokenizer.json",
    "engram_vocab_size": [
      51200,
      51200,
      51200
    ],
    "max_ngram_size": 4,
    "n_embed_per_ngram": 256,
    "n_head_per_ngram": 4,
    "layer_ids": [
      1,
      7
    ],
    "pad_id": 6629,
    "seed": 42,
    "kernel_size": 7
  }
}
configuration_llada_engram.py ADDED
"""
LLaDA configuration
"""
from transformers import AutoConfig, PretrainedConfig

from enum import Enum
from os import PathLike
from dataclasses import dataclass, field
from typing import (
    Any,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)


__all__ = [
    "ActivationType",
    "ActivationCheckpointingStrategy",
    "BlockType",
    "LayerNormType",
    "InitFnType",
    "ModelConfig",
]

PathOrStr = Union[str, PathLike]


class StrEnum(str, Enum):
    """
    This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
    We include this here for compatibility with older versions of Python.
    """

    def __str__(self) -> str:
        return self.value

    def __repr__(self) -> str:
        return f"'{str(self)}'"


class LayerNormType(StrEnum):
    default = "default"
    """
    The default LayerNorm implementation, equivalent to PyTorch's built-in version.
    """

    low_precision = "low_precision"
    """
    A low-precision version of the default LayerNorm.
    """

    rms = "rms"
    """
    An RMSNorm implementation. When using ``torch.compile`` this is
    probably the fastest implementation.
    """

    gemma_rms = "gemma_rms"
    """
    An RMSNorm implementation from Gemma. When using ``torch.compile`` this is
    probably the fastest implementation.
    """

    amd_compatible = "amd_compatible"
    """
    LayerNorm implemented manually to work around an issue with ROCm.
    """


class ActivationType(StrEnum):
    gelu = "gelu"
    relu = "relu"
    silu = "silu"
    swiglu = "swiglu"


class BlockType(StrEnum):
    sequential = "sequential"
    parallel = "parallel"

    llama = "llama"
    """
    A block similar to the sequential block with slightly different
    implementations of operations like attention to imitate the behavior of Llama.
    """


class InitFnType(StrEnum):
    mitchell = "mitchell"
    """
    The strategy suggested to us by Mitchell Wortsman from UW.
    This uses a truncated normal distribution with an adaptive standard deviation that depends
    on the size of the weights as well as the depth of the layer.
    """

    normal = "normal"
    """
    All weights are initialized from the same normal distribution.
    """

    kaiming_normal = "kaiming_normal"
    """
    All weights are initialized with the Kaiming method from a normal distribution.
    Note this currently won't work with FSDP.
    """

    fan_in = "fan_in"
    """
    "Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
    is the input dimensionality of the kernel.
    """

    full_megatron = "full_megatron"
    """
    This is what metaseq calls "full megatron init". It is the init used for Llama 2.
    """


@dataclass
class EngramConfig:
    tokenizer_name_or_path: str = "deepseek-ai/DeepSeek-V3"
    engram_vocab_size: List[int] = field(default_factory=lambda: [129280 * 5, 129280 * 5])
    max_ngram_size: int = 3
    n_embed_per_ngram: int = 512
    n_head_per_ngram: int = 8
    layer_ids: List[int] = field(default_factory=lambda: [1, 15])
    pad_id: int = 2
    seed: int = 0
    kernel_size: int = 7


@dataclass
class ModelConfig:
    """
    LLaDA (model) configuration.
    """

    # Note that the defaults for these attributes are equivalent to the base GPT2 model.

    d_model: int = 768
    """
    The hidden size of the model.
    """

    n_heads: int = 12
    """
    The number of self-attention heads.
    """

    n_kv_heads: Optional[int] = None
    """
    The number of heads to use for keys and values. Defaults to `n_heads`.
    Set this to ``None`` or ``n_heads`` for normal multi-head attention.
    Set this to 1 for multi-query attention.
    Set it to some in-between value for Llama2-style grouped query attention.
    """

    n_layers: int = 12
    """
    The number of layers/blocks.
    """

    mlp_ratio: int = 4
    """
    The ratio of the inner MLP dimensionality to ``d_model``.
    This is only used when ``mlp_hidden_size`` is not set.
    """

    mlp_hidden_size: Optional[int] = None
    """
    Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
    """

    activation_type: ActivationType = ActivationType.swiglu
    """
    The activation function to use within the MLP layers.
    """

    block_type: BlockType = BlockType.sequential
    """
    The transformer block implementation.
    """

    block_group_size: int = 1
    """
    The number of blocks to group together into a single parent block.
    This has no effect on the number of parameters in the model and is only used to wrap groups
    of blocks together with a single FSDP wrapper during training.
    """

    alibi: bool = False
    """
    If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
    """

    alibi_bias_max: float = 8.0
    """
    Maximum absolute value of ALiBi bias.
    """

    rope: bool = False
    """
    Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
    """

    rope_full_precision: bool = True
    """
    If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
    apply RoPE at the precision of the input.
    """

    flash_attention: bool = False
    """
    If ``True``, use ``FlashAttention``.
    """

    attention_dropout: float = 0.1
    """
    The dropout probability within the attention modules.
    """

    multi_query_attention: Optional[bool] = None
    """
    Use the Multi-Query formulation of attention used in PaLM. This reduces the number of parameters
    and is more efficient during inference.
    """

    attention_layer_norm: bool = False
    """
    Apply layer norm to the keys and queries within the attention mechanism.
    This can help stabilize training.
    """

    residual_dropout: float = 0.1
    """
    The dropout probability for the MLP and attention output within each block.
    """

    embedding_dropout: float = 0.1
    """
    The dropout probability for embeddings.
    """

    input_emb_norm: bool = False
    """
    An input hidden_states norm implementation from Gemma.
    """

    layer_norm_type: LayerNormType = LayerNormType.default
    """
    The layernorm implementation to use.
    """

    layer_norm_with_affine: bool = True
    """
    Whether to include bias and weight parameters for the layer norms.
    This only affects layer norms that are immediately followed by a linear layer in the forward pass,
    so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
    to ``False``.
    """

    rms_norm_eps: float = 1e-05
    """
    The RMS layernorm eps param.
    """

    attention_layer_norm_with_affine: bool = True
    """
    Toggle affine transform for the QK norms.
    """

    max_sequence_length: int = 1024
    """
    The maximum input sequence length supported by the model.
    """

    rope_theta: float = 10000.0
    """
    The RoPE base param.
    """

    include_qkv_bias: Optional[bool] = False
    """
    Whether or not to include bias parameters in QKV linear layers.
    """

    include_bias: bool = False
    """
    Whether or not to include bias parameters in linear layers.
    In PaLM, they got rid of all bias terms because they found that large
    models tend to have near 0 bias terms anyway.
    """

    bias_for_layer_norm: Optional[bool] = None
    """
    Whether or not to include bias parameters in layer norm.
    This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
    layer norm.
    When this is None (the default), it inherits the setting from include_bias.
    """

    scale_logits: bool = False
    """
    If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
    """

    vocab_size: int = 50257
    """
    Vocabulary size of the model.
    """

    embedding_size: Optional[int] = 50304
    """
    The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
    to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
    next multiple of 128 that's greater than ``vocab_size`` can improve throughput
    substantially.
    """

    weight_tying: bool = True
    """
    Whether to tie output linear weights to the input embedding.
    """

    eos_token_id: int = 50256
    """
    The ID of the end-of-sentence special token.
    """

    pad_token_id: int = 50256
    """
    The ID of the token to use for padding. Defaults to the ID of the EOS token.
    """

    mask_token_id: Optional[int] = 50256
    """
    The ID of the mask token. Defaults to the ID of the EOS token.
    """

    init_device: Optional[str] = None
    """
    The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
    """

    init_fn: InitFnType = InitFnType.normal
    """
    The weight initialization strategy.
    """

    init_std: float = 0.02
    """
    The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
    as "normal".
    """

    init_cutoff_factor: Optional[float] = None
    """
    A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
    as "normal". Setting this to None means values are not cut off.
    """

    precision: Optional[str] = None
    """
    Precision used to train/evaluate with. You shouldn't set this directly.
    See :data:`TrainConfig.precision` instead.
    """

    engram_config: Optional[EngramConfig] = None

    @property
    def effective_n_kv_heads(self) -> int:
        if self.n_kv_heads is None:
            if self.multi_query_attention is True:
                return 1
            else:
                return self.n_heads
        else:
            if self.multi_query_attention is None:
                return self.n_kv_heads
            if self.multi_query_attention:
                n_kv_heads_should_be = 1
            else:
                n_kv_heads_should_be = self.n_heads
            if self.n_kv_heads == n_kv_heads_should_be:
                return n_kv_heads_should_be
            else:
                raise Exception(
                    "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
                )


class ActivationCheckpointingStrategy(StrEnum):
    whole_layer = "whole_layer"
    """
    Checkpoint every transformer layer.
    """

    one_in_two = "one_in_two"
    """
    Checkpoint one in two transformer layers.
    """

    one_in_three = "one_in_three"
    """
    Checkpoint one in three transformer layers.
    """

    one_in_four = "one_in_four"
    """
    Checkpoint one in four transformer layers.
    """

    two_in_three = "two_in_three"
    """
    Checkpoint two out of every three transformer layers.
    """

    three_in_four = "three_in_four"
    """
    Checkpoint three out of every four transformer layers.
    """

    four_in_five = "four_in_five"
    """
    Checkpoint four out of every five transformer layers.
    """

    nine_in_ten = "nine_in_ten"
    """
    Checkpoint nine out of every ten transformer layers.
    """

    fine_grained = "fine_grained"
    """
    Focus checkpointing where recomputation is cheap and the memory savings are largest.
    """


class LLaDAConfig(PretrainedConfig):
    model_type = "llada"
    keys_to_ignore_at_inference = ["past_key_values"]  # TODO: confirm

    def __init__(self, use_cache: bool = False, **kwargs):
        model_config = ModelConfig()
        all_kwargs = model_config.__dict__
        all_kwargs.update(kwargs)
        all_kwargs.update({"use_cache": use_cache})
        all_kwargs.update(
            {"architectures": all_kwargs.get("architectures", ["LLaDAModelLM"])}
        )
        super().__init__(**all_kwargs)

    @property
    def num_attention_heads(self):
        return self.n_heads

    @property
    def num_hidden_layers(self):
        return self.n_layers

    @property
    def hidden_size(self):
        return self.d_model


# Register the config class so that it is available for transformer pipelines, auto-loading, etc.
AutoConfig.register("llada", LLaDAConfig)
example.py ADDED
from tokenizers import Tokenizer
import torch


def process_string_into_pairs(input_str: str) -> list[str]:
    result = []
    i = 0
    n = len(input_str)

    while i < n:
        char = input_str[i]

        # Is the current character a lowercase letter?
        if "a" <= char <= "z":
            # The next character is also lowercase: the two form a code pair
            if i + 1 < n and "a" <= input_str[i + 1] <= "z":
                result.append(char + input_str[i + 1])
                i += 2  # skip both characters
            # The next character is a space: lone lowercase letter + space
            elif i + 1 < n and input_str[i + 1] == " ":
                result.append(char)
                i += 2  # skip the letter and the following space
            # Lone lowercase letter followed by another character, or at the end
            else:
                result.append(char)
                i += 1
        # The current character is not a lowercase letter
        else:
            result.append(char)
            i += 1

    return result


def get_mask_from_string(input_str: str, tokenizer) -> torch.Tensor:
    pairs = process_string_into_pairs(input_str)
    masks = [
        f"<|mask_{pair}|>" if all(ord(i) < 128 for i in pair) else pair
        for pair in pairs
    ]
    mask_tensor = torch.tensor(
        [tokenizer.token_to_id(mask) for mask in masks], dtype=torch.long
    )
    return mask_tensor


def inference(model, input_str: str, tokenizer, device, threshold=0.9):
    model.eval()

    # Initialize NgramHashMapping
    engram_cfg = model.config.engram_config
    hash_mapping = None
    if engram_cfg is not None:
        from modeling_llada_engram import ModelConfig, NgramHashMapping
        from dataclasses import fields

        # Prepare a ModelConfig for NgramHashMapping, keeping only keys that
        # are actual ModelConfig fields
        backbone_config_dict = model.config.to_dict()
        backbone_config = ModelConfig(
            **{
                k: v
                for k, v in backbone_config_dict.items()
                if k in {f.name for f in fields(ModelConfig)}
            }
        )

        hash_mapping = NgramHashMapping(
            engram_vocab_size=engram_cfg.get("engram_vocab_size", [129280 * 5, 129280 * 5]),
            max_ngram_size=engram_cfg.get("max_ngram_size", 3),
            n_embed_per_ngram=engram_cfg.get("n_embed_per_ngram", 512),
            n_head_per_ngram=engram_cfg.get("n_head_per_ngram", 8),
            layer_ids=engram_cfg.get("layer_ids", [1, 15]),
            pad_id=engram_cfg.get("pad_id", 2),
            seed=engram_cfg.get("seed", 0),
            config=backbone_config,
        )

    with torch.no_grad():
        mask_tensor = get_mask_from_string(input_str, tokenizer).unsqueeze(0).to(device)
        is_masked = mask_tensor >= tokenizer.token_to_id("<|mask|>")
        rounds = 0
        while is_masked.any():
            rounds += 1

            output = model(input_ids=mask_tensor)[0]
            # Logits to probabilities
            output = torch.softmax(output, dim=-1)
            unmasked_any = False
            prob_info = []

            most_certain_token = (0, 0, 0)  # (probability, index, token_id)
            # Check each token that is still masked
            for i in range(mask_tensor.shape[1]):
                if is_masked[0, i]:
                    # Get the token with the highest probability
                    predicted_token = output[0, i].argmax().item()
                    prob_info.append(
                        f"{output[0, i, predicted_token].item():.2f} {tokenizer.id_to_token(predicted_token)}"
                    )
                    most_certain_token = max(
                        most_certain_token,
                        (output[0, i, predicted_token].item(), i, predicted_token),
                    )
                    # If the probability is above the threshold, replace the mask
                    if output[0, i, predicted_token].item() > threshold:
                        mask_tensor[0, i] = predicted_token
                        is_masked[0, i] = False
                        unmasked_any = True
                else:
                    prob_info.append("")
            if not unmasked_any:
                # Nothing cleared the threshold: unmask the single most certain token
                mask_tensor[0, most_certain_token[1]] = most_certain_token[2]
                is_masked[0, most_certain_token[1]] = False

            masked_str = "".join(
                (
                    tokenizer.id_to_token(mask_tensor[0, i].item())
                    if not is_masked[0, i]
                    else tokenizer.id_to_token(mask_tensor[0, i].item())[7:-2]
                )
                for i in range(mask_tensor.shape[1])
            )
            print(masked_str)


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = Tokenizer.from_file("tokenizer.json")

    # Load from the local directory with AutoModel.
    # Note: transformers must be installed and trust_remote_code=True is required.
    try:
        from transformers import AutoModelForCausalLM

        model = AutoModelForCausalLM.from_pretrained(".", trust_remote_code=True).to(device)
    except Exception as e:
        print(f"Failed to load with AutoModel: {e}")
        raise

    # bfloat16 on GPU, float32 on CPU
    model = model.to(torch.bfloat16) if device.type == "cuda" else model.float()
    print("Loaded model. Parameters:", sum(p.numel() for p in model.parameters()))

    threshold = 0.9

    while True:
        input_str = input("Enter a string to process: ")
        inference(model, input_str, tokenizer, device, threshold=threshold)
        print("")  # blank line between runs

    # Input example: nhkzotdgjvdmleunkmiekz。
    # Output: 黄河是中华民族的母亲河。

    # Input example: mdflswsyelfl,eyxxmdswsyelfl,raxxmdelfl,otfixdzhfnjrugfoirmbisunswsyelfl。zhldxxdgun“mdfl”uvelflqhnvxtmdunkmpbofvjcjnnmdunsoirpbucheel。
    # Output: 大型语言模型,也称大语言模型,简称大模型,是一种基于人工神经网络的语言模型。其名称中的“大型”指模型具有庞大的参数量以及巨大的训练数据规模。

    # Input example: hgzz(Go o g l e )otfiwjpmrnxjuchkaf,hdidjifngmrnsdoovsoggn.
    # Output:
    # 谷歌(Google)是一家跨国科技公司,总部位于美国加州山景城.
    # 谷歌(Google)是一家跨国科技公司,总部位于美国加州山景城。
    # 谷歌(Google)是一家跨国科技公司,总部位于美国加州山景城。
    # 谷歌(Google)是一家跨国科技公司,总部位于美国加州山景城。
    # 谷歌(Google)是一家跨国科技公司,总部位于美国加州山景城。

    # Input example: jxvuygvbotghtusvwtvbdt。auwvvbotcbghwhtkshdl?
    # Output:
    # 天对地,雨对风。大陆对长空。山lj对ke树,赤日对ljeb。雷隐隐,雾蒙蒙。日下对天中。风高秋月白,雨tq晚霞红。
    # 天对地,雨对风。大陆对长空。山lj对杂树,赤日对苍eb。雷隐隐,雾蒙蒙。日下对天中。风高秋月白,雨雷晚霞红。
    # 天对地,雨对风。大陆对长空。山lj对杂树,赤日对苍穹。雷隐隐,雾蒙蒙。日下对天中。风高秋月白,雨雷晚霞红。
    # 天对地,雨对风。大陆对长空。山苍对杂树,赤日对苍穹。雷隐隐,雾蒙蒙。日下对天中。风高秋月白,雨雷晚霞红。
    # (Expected Output: 天对地,雨对风。大陆对长空。山花对海树,赤日对苍穹。雷隐隐,雾蒙蒙。日下对天中。风高秋月白,雨霁晚霞红。)
model.safetensors ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:a0114ad14a6671ade8155e31d930bb1c19779dab6af574d139d11678d7152270
size 700907000
modeling_llada_engram.py ADDED
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import math
5
+ import sys
6
+ from abc import abstractmethod
7
+ from collections import defaultdict
8
+ from functools import partial
9
+ from typing import (
10
+ Callable,
11
+ Dict,
12
+ Iterable,
13
+ List,
14
+ NamedTuple,
15
+ Optional,
16
+ Sequence,
17
+ Set,
18
+ Tuple,
19
+ cast,
20
+ )
21
+ from dataclasses import fields
22
+ from typing import Union
23
+
24
+ import torch
25
+ import torch.backends.cuda
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ import torch.nn.utils.rnn as rnn_utils
29
+ from torch import einsum
30
+ from transformers import PreTrainedModel
31
+ from transformers.modeling_outputs import CausalLMOutputWithPast
32
+ from transformers.models.auto import AutoModel
33
+ from transformers.models.auto.tokenization_auto import AutoTokenizer
34
+ from transformers.cache_utils import Cache
35
+ from sympy import isprime
36
+ import numpy as np
37
+
38
+ from configuration_llada_engram import (
39
+ EngramConfig,
40
+ LLaDAConfig,
41
+ StrEnum,
42
+ InitFnType,
43
+ ActivationType,
44
+ BlockType,
45
+ LayerNormType,
46
+ ModelConfig,
47
+ ActivationCheckpointingStrategy,
48
+ )
49
+
50
+ if sys.version_info >= (3, 9):
51
+ from collections.abc import MutableMapping
52
+ elif sys.version_info >= (3, 8):
53
+ from typing import MutableMapping
54
+ else:
55
+ raise SystemExit("This module requires Python 3.8 or higher")
56
+
57
+ __all__ = [
58
+ "LayerNormBase",
59
+ "LayerNorm",
60
+ "RMSLayerNorm",
61
+ "GemmaRMSLayerNorm",
62
+ "RotaryEmbedding",
63
+ "Activation",
64
+ "GELU",
65
+ "ReLU",
66
+ "SwiGLU",
67
+ "LLaDABlock",
68
+ "LLaDASequentialBlock",
69
+ "LLaDAModel",
70
+ "LLaDAOutput",
71
+ "LLaDAGenerateOutput",
72
+ ]
73
+
74
+
75
+ log = logging.getLogger(__name__)
76
+
77
+
78
+ class ModuleType(StrEnum):
79
+ in_module = "in"
80
+ out_module = "out"
81
+ emb = "emb"
82
+ final_out = "final_out"
83
+
84
+
85
+ def init_weights(
86
+ config: ModelConfig,
87
+ module: Union[nn.Linear, nn.Embedding],
88
+ d: Optional[int] = None,
89
+ layer_id: Optional[int] = None,
90
+ std_factor: float = 1.0,
91
+ type_of_module: Optional[ModuleType] = None,
92
+ ) -> None:
93
+ """
94
+ Initialize weights of a linear or embedding module.
95
+ :param config: The model config.
96
+ :param module: The linear or embedding submodule to initialize.
97
+ :param d: The effective input dimensionality of the weights. This could be smaller than the actual dimensions
98
+ for fused layers.
99
+ :param layer_id: When set, the standard deviation for the "mitchell" method will be adjusted by
100
+ ``1 / sqrt(2 * (layer_id + 1))``.
101
+ """
102
+ d = d if d is not None else config.d_model
103
+ if config.init_fn == InitFnType.normal:
104
+ std = config.init_std * std_factor
105
+ if config.init_cutoff_factor is not None:
106
+ cutoff_value = config.init_cutoff_factor * std
107
+ nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
108
+ else:
109
+ nn.init.normal_(module.weight, mean=0.0, std=std)
110
+ elif config.init_fn == InitFnType.mitchell:
111
+ std = std_factor / math.sqrt(d)
112
+ if layer_id is not None:
113
+ std = std / math.sqrt(2 * (layer_id + 1))
114
+ nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
115
+ elif config.init_fn == InitFnType.kaiming_normal:
116
+ nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
117
+ elif config.init_fn == InitFnType.fan_in:
118
+ std = std_factor / math.sqrt(d)
119
+ nn.init.normal_(module.weight, mean=0.0, std=std)
120
+ elif config.init_fn == InitFnType.full_megatron:
121
+ if type_of_module is None:
122
+ raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.")
123
+
124
+ cutoff_factor = config.init_cutoff_factor
125
+ if cutoff_factor is None:
126
+ cutoff_factor = 3
127
+
128
+ if type_of_module == ModuleType.in_module:
129
+ # for att_proj (same as QKV), ff_proj
130
+ std = config.init_std
131
+ elif type_of_module == ModuleType.out_module:
132
+ # for attn_out, ff_out
133
+ std = config.init_std / math.sqrt(2.0 * config.n_layers)
134
+ elif type_of_module == ModuleType.emb:
135
+ # positional embeddings (wpe)
136
+ # token embeddings (wte)
137
+ std = config.init_std
138
+ elif type_of_module == ModuleType.final_out:
139
+ # final output (ff_out)
140
+ std = config.d_model**-0.5
141
+ else:
142
+ raise RuntimeError(f"Unknown module type '{type_of_module}'")
143
+ nn.init.trunc_normal_(
144
+ module.weight,
145
+ mean=0.0,
146
+ std=std,
147
+ a=-cutoff_factor * std,
148
+ b=cutoff_factor * std,
149
+ )
150
+ else:
151
+ raise NotImplementedError(config.init_fn)
152
+
153
+ if isinstance(module, nn.Linear):
154
+ if module.bias is not None:
155
+ nn.init.zeros_(module.bias)
156
+
157
+ if config.init_fn == InitFnType.normal and getattr(module, "_is_residual", False):
158
+ with torch.no_grad():
159
+ module.weight.div_(math.sqrt(2 * config.n_layers))
160
+
161
+
162
+ def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
163
+ """
164
+ Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
165
+ is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``.
166
+ """
167
+ if check_neg_inf:
168
+ x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
169
+ if check_pos_inf:
170
+ x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
171
+
172
+
173
+ def activation_checkpoint_function(cfg: ModelConfig):
174
+ preserve_rng_state = (
175
+ (cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and (cfg.residual_dropout == 0.0)
176
+ )
177
+ from torch.utils.checkpoint import checkpoint
178
+
179
+ return partial(
180
+ checkpoint,
181
+ preserve_rng_state=preserve_rng_state,
182
+ use_reentrant=False,
183
+ )
184
+
185
+
186
+ class BufferCache(dict, MutableMapping[str, torch.Tensor]):
187
+ """
188
+ Cache for attention biases and other things that would normally be stored as buffers.
189
+ We avoid using buffers because we've run into various issues doing so with FSDP.
190
+ In general it appears the way FSDP handles buffers is not well-defined.
191
+ It doesn't shard them but apparently it does synchronize them across processes, which we want to avoid
192
+ since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into
193
+ NaNs when they're synchronized due to casting or some other issue.
194
+ """
195
+
196
+
197
+ def _non_meta_init_device(config: ModelConfig) -> torch.device:
198
+ if config.init_device is not None and config.init_device != "meta":
199
+ return torch.device(config.init_device)
200
+ else:
201
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
202
+
203
+
204
+ class Dropout(nn.Dropout):
205
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
206
+ if self.p == 0.0:
207
+ return input
208
+ else:
209
+ return F.dropout(input, self.p, self.training, self.inplace)
210
+
211
+
212
+ class LayerNormBase(nn.Module):
213
+ def __init__(
214
+ self,
215
+ config: ModelConfig,
216
+ *,
217
+ size: Optional[int] = None,
218
+ elementwise_affine: Optional[bool] = True,
219
+ eps: float = 1e-05,
220
+ ):
221
+ super().__init__()
222
+ self.config = config
223
+ self.eps = eps
224
+ self.normalized_shape = (size or config.d_model,)
225
+ if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
226
+ self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device))
227
+ use_bias = self.config.bias_for_layer_norm
228
+ if use_bias is None:
229
+ use_bias = self.config.include_bias
230
+ if use_bias:
231
+ self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device))
232
+ else:
233
+ self.register_parameter("bias", None)
234
+ else:
235
+ self.register_parameter("bias", None)
236
+ self.register_parameter("weight", None)
237
+
238
+ @abstractmethod
239
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
240
+ raise NotImplementedError
241
+
242
+ @classmethod
243
+ def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> LayerNormBase:
244
+ if config.layer_norm_type == LayerNormType.default:
245
+ return LayerNorm(config, size=size, low_precision=False, **kwargs)
246
+ elif config.layer_norm_type == LayerNormType.low_precision:
247
+ return LayerNorm(config, size=size, low_precision=True, **kwargs)
248
+ elif config.layer_norm_type == LayerNormType.rms:
249
+ return RMSLayerNorm(config, size=size, **kwargs)
250
+ elif config.layer_norm_type == LayerNormType.gemma_rms:
251
+ return GemmaRMSLayerNorm(config, size=size, **kwargs)
252
+ else:
253
+ raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'")
254
+
255
+ def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
256
+ # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
257
+ # `is_autocast_cpu_enabled()` for CPU autocast.
258
+ # See https://github.com/pytorch/pytorch/issues/110966.
259
+ if tensor.device.type == "cuda" and torch.is_autocast_enabled():
260
+ return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_gpu_dtype())
261
+ elif tensor.device.type == "cpu" and torch.is_autocast_cpu_enabled():
262
+ return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_cpu_dtype())
263
+ else:
264
+ return tensor
265
+
266
+ def reset_parameters(self):
267
+ if self.weight is not None:
268
+ torch.nn.init.ones_(self.weight) # type: ignore
269
+ if self.bias is not None:
270
+ torch.nn.init.zeros_(self.bias) # type: ignore
271
+
272
+
273
+ class LayerNorm(LayerNormBase):
274
+ """
275
+ The default :class:`LayerNorm` implementation which can optionally run in low precision.
276
+ """
277
+
278
+ def __init__(
279
+ self,
280
+ config: ModelConfig,
281
+ size: Optional[int] = None,
282
+ low_precision: bool = False,
283
+ elementwise_affine: Optional[bool] = None,
284
+ eps: float = 1e-05,
285
+ ):
286
+ super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
287
+ self.low_precision = low_precision
288
+
289
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
290
+ if self.low_precision:
291
+ module_device = x.device
292
+ downcast_x = self._cast_if_autocast_enabled(x)
293
+ downcast_weight = (
294
+ self._cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
295
+ )
296
+ downcast_bias = self._cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
297
+ with torch.autocast(enabled=False, device_type=module_device.type):
298
+ return F.layer_norm(
299
+ downcast_x, self.normalized_shape, weight=downcast_weight, bias=downcast_bias, eps=self.eps
300
+ )
301
+ else:
302
+ return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)
303
+
304
+
305
+ class RMSLayerNorm(LayerNormBase):
306
+ """
307
+ RMS layer norm, a simplified :class:`LayerNorm` implementation
308
+ """
309
+
310
+ def __init__(
311
+ self,
312
+ config: ModelConfig,
313
+ size: Optional[int] = None,
314
+ elementwise_affine: Optional[bool] = None,
315
+ eps: float = 1e-5,
316
+ ):
317
+ super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps)
318
+
319
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
320
+ with torch.autocast(enabled=False, device_type=x.device.type):
321
+ og_dtype = x.dtype
322
+ x = x.to(torch.float32)
323
+ variance = x.pow(2).mean(-1, keepdim=True)
324
+ x = x * torch.rsqrt(variance + self.eps)
325
+ x = x.to(og_dtype)
326
+
327
+ if self.weight is not None:
328
+ if self.bias is not None:
329
+ return self.weight * x + self.bias
330
+ else:
331
+ return self.weight * x
332
+ else:
333
+ return x
334
+
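The `forward` above reduces to a simple formula; a minimal numpy sketch of the normalization in `RMSLayerNorm` (without the learned affine parameters):

```python
import numpy as np

# y = x / sqrt(mean(x^2) + eps), the core of RMSLayerNorm.forward above.
eps = 1e-5
x = np.array([3.0, 4.0])
rms = np.sqrt(np.mean(x ** 2) + eps)   # sqrt(12.5 + eps) ~= 3.5355
y = x / rms
print(np.round(y, 4).tolist())  # [0.8485, 1.1314]
```

The real module additionally upcasts to float32 before the reduction and applies `self.weight` (and optionally `self.bias`) afterwards.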
335
+
336
+ class GemmaRMSLayerNorm(LayerNormBase):
337
+ """
338
+ Gemma RMS layer norm, a simplified :class:`LayerNorm` implementation
339
+ """
340
+
341
+ def __init__(
342
+ self,
343
+ config: ModelConfig,
344
+ size: Optional[int] = None,
345
+ elementwise_affine: Optional[bool] = None,
346
+ eps: float = 1e-5,
347
+ ):
348
+ super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps)
349
+
350
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
351
+ with torch.autocast(enabled=False, device_type=x.device.type):
352
+ og_dtype = x.dtype
353
+ x = x.to(torch.float32)
354
+ variance = x.pow(2).mean(-1, keepdim=True)
355
+ x = x * torch.rsqrt(variance + self.eps)
356
+ x = x.to(og_dtype)
357
+
358
+ if self.weight is not None:
359
+ if self.bias is not None:
360
+ return x * (1 + self.weight) + self.bias
361
+ else:
362
+ return x * (1 + self.weight)
363
+ else:
364
+ return x
365
+
366
+
367
+ class RotaryEmbedding(nn.Module):
368
+ """
369
+ [Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864).
370
+ """
371
+
372
+ def __init__(self, config: ModelConfig, cache: BufferCache):
373
+ super().__init__()
374
+ self.config = config
375
+ self.__cache = cache
376
+ # Warm up cache.
377
+ self.rope_theta = config.rope_theta
378
+ self.get_rotary_embedding(config.max_sequence_length, _non_meta_init_device(config))
379
+
380
+ def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
381
+ if (
382
+ (pos_sin := self.__cache.get("rope_pos_sin")) is not None
383
+ and (pos_cos := self.__cache.get("rope_pos_cos")) is not None
384
+ and pos_sin.shape[-2] >= seq_len
385
+ and pos_cos.shape[-2] >= seq_len
386
+ ):
387
+ if pos_sin.device != device:
388
+ pos_sin = pos_sin.to(device)
389
+ self.__cache["rope_pos_sin"] = pos_sin
390
+ if pos_cos.device != device:
391
+ pos_cos = pos_cos.to(device)
392
+ self.__cache["rope_pos_cos"] = pos_cos
393
+ return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]
394
+
395
+ with torch.autocast(device.type, enabled=False):
396
+ dim = self.config.d_model // self.config.n_heads
397
+ inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
398
+ seq = torch.arange(seq_len, device=device, dtype=torch.float)
399
+ freqs = einsum("i , j -> i j", seq, inv_freq)
400
+ positions = torch.cat((freqs, freqs), dim=-1)
401
+ pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
402
+ self.__cache["rope_pos_sin"] = pos_sin
403
+ self.__cache["rope_pos_cos"] = pos_cos
404
+ return pos_sin, pos_cos
405
+
406
+ def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
407
+ B, nh, T, hs = x.size()
408
+ x = x.view(B, nh, T, 2, hs // 2)
409
+ x1, x2 = x.unbind(dim=-2)
410
+ return torch.cat((-x2, x1), dim=-1)
411
+
412
+ def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
413
+ return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)
414
+
415
+ def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
416
+ if self.config.rope_full_precision:
417
+ q_, k_ = q.float(), k.float()
418
+ else:
419
+ q_, k_ = q, k
420
+
421
+ with torch.autocast(q.device.type, enabled=False):
422
+ query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None
423
+ pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device)
424
+ pos_sin = pos_sin.type_as(q_)
425
+ pos_cos = pos_cos.type_as(q_)
426
+ q_ = self.apply_rotary_pos_emb(
427
+ pos_sin[:, :, key_len - query_len : key_len, :],
428
+ pos_cos[:, :, key_len - query_len : key_len, :],
429
+ q_,
430
+ )
431
+ k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_)
432
+ return q_.type_as(q), k_.type_as(k)
433
+
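The `rotate_half` helper above is the core rotation RoPE builds on; a tiny numpy sketch on a single head vector:

```python
import numpy as np

# rotate_half splits the head dimension into two halves (x1, x2) and
# emits (-x2, x1), i.e. a 90-degree rotation in each (x1_i, x2_i) plane.
x = np.array([1.0, 2.0, 3.0, 4.0])   # one head vector, hs = 4
x1, x2 = x[:2], x[2:]
rotated = np.concatenate([-x2, x1])
print(rotated.tolist())  # [-3.0, -4.0, 1.0, 2.0]
```

`apply_rotary_pos_emb` then combines `t * cos + rotate_half(t) * sin` with position-dependent sin/cos tables.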
434
+
435
+ class Activation(nn.Module):
436
+ def __init__(self, config: ModelConfig):
437
+ super().__init__()
438
+ self.config = config
439
+
440
+ @abstractmethod
441
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
442
+ raise NotImplementedError
443
+
444
+ @property
445
+ @abstractmethod
446
+ def output_multiplier(self) -> float:
447
+ raise NotImplementedError
448
+
449
+ @classmethod
450
+ def build(cls, config: ModelConfig) -> Activation:
451
+ if config.activation_type == ActivationType.gelu:
452
+ return cast(Activation, GELU(approximate="none"))
453
+ elif config.activation_type == ActivationType.relu:
454
+ return cast(Activation, ReLU(inplace=False))
455
+ elif config.activation_type == ActivationType.silu:
456
+ return cast(Activation, SiLU(inplace=False))
457
+ elif config.activation_type == ActivationType.swiglu:
458
+ return SwiGLU(config)
459
+ else:
460
+ raise NotImplementedError(f"Unknown activation: '{config.activation_type}'")
461
+
462
+
463
+ class GELU(nn.GELU):
464
+ @property
465
+ def output_multiplier(self) -> float:
466
+ return 1.0
467
+
468
+
469
+ class ReLU(nn.ReLU):
470
+ @property
471
+ def output_multiplier(self) -> float:
472
+ return 1.0
473
+
474
+ class SiLU(nn.SiLU):
475
+ @property
476
+ def output_multiplier(self) -> float:
477
+ return 1.0
478
+
479
+ class SwiGLU(Activation):
480
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
481
+ x, gate = x.chunk(2, dim=-1)
482
+ return F.silu(gate) * x
483
+
484
+ @property
485
+ def output_multiplier(self) -> float:
486
+ return 0.5
487
+
488
+
489
+ def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
490
+ att_bias = torch.triu(
491
+ torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
492
+ diagonal=1,
493
+ )
494
+ att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
495
+ return att_bias.view(1, 1, seq_len, seq_len) # type: ignore
496
+
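The mask built by `causal_attention_bias` above looks like this; a small numpy sketch (using `-inf` where the real code uses `torch.finfo(dtype).min`):

```python
import numpy as np

# Zeros on and below the diagonal, a large negative value strictly above it,
# so each position can only attend to itself and earlier positions.
seq_len = 3
bias = np.triu(np.ones((seq_len, seq_len)), k=1)
bias[bias == 1] = -np.inf
print(bias[0].tolist())  # [0.0, -inf, -inf]
```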
497
+
498
+ def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor:
499
+ if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len:
500
+ if causal_bias.device != device:
501
+ causal_bias = causal_bias.to(device)
502
+ cache["causal_attention_bias"] = causal_bias
503
+ return causal_bias
504
+ with torch.autocast(device.type, enabled=False):
505
+ causal_bias = causal_attention_bias(seq_len, device)
506
+ cache["causal_attention_bias"] = causal_bias
507
+ return causal_bias
508
+
509
+
510
+ def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
511
+ alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)
512
+
513
+ # shape: (1, 1, seq_len, seq_len)
514
+ alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
515
+ alibi_bias.abs_().mul_(-1)
516
+
517
+ # shape: (n_heads,)
518
+ m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
519
+ m.mul_(config.alibi_bias_max / config.n_heads)
520
+
521
+ # shape: (1, n_heads, seq_len, seq_len)
522
+ return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1))) # type: ignore
523
+
524
+ class ShortConv(nn.Module):
525
+ def __init__(
526
+ self,
527
+ hidden_size: int,
528
+ kernel_size: int = 7,  # default kernel size of 7
529
+ dilation: int = 1,
530
+ norm_eps: float = 1e-5,
531
+ hc_mult: int = 1,
532
+ activation: bool = True,
533
+ ):
534
+ super().__init__()
535
+ self.activation = activation
536
+ self.kernel_size = kernel_size
537
+ self.dilation = dilation
538
+
539
+ # Same-padding for an odd kernel (e.g. 7): (K-1)/2
540
+ # K=7 -> padding=3
541
+ self.padding = (kernel_size - 1) // 2 * dilation
542
+
543
+ # A standard depthwise convolution; PyTorch's built-in padding suffices here
544
+ # because same padding for an odd kernel is symmetric, so no manual F.pad is needed
545
+ self.conv = nn.Conv1d(
546
+ in_channels=hidden_size,
547
+ out_channels=hidden_size,
548
+ kernel_size=kernel_size,
549
+ groups=hidden_size,
550
+ bias=False,
551
+ padding=self.padding,  # symmetric same padding
552
+ dilation=dilation,
553
+ )
554
+
555
+ self.norm = nn.RMSNorm(hidden_size, eps=norm_eps)
556
+
557
+ if self.activation:
558
+ self.act_fn = nn.SiLU()
559
+
560
+ self.reset_parameters()
561
+
562
+ def reset_parameters(self):
563
+ nn.init.zeros_(self.conv.weight)
564
+
565
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
566
+ # x: [B, L, D]
567
+ x_norm = self.norm(x)
568
+
569
+ # [B, L, D] -> [B, D, L]
570
+ x_bct = x_norm.transpose(1, 2)
571
+
572
+ # depthwise convolution (padding handled by nn.Conv1d)
573
+ y_bct = self.conv(x_bct)
574
+
575
+ if self.activation:
576
+ y_bct = self.act_fn(y_bct)
577
+
578
+ # [B, D, L] -> [B, L, D]
579
+ y = y_bct.transpose(1, 2).contiguous()
580
+
581
+ return y
582
+
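The padding arithmetic in `ShortConv` above can be checked with the standard Conv1d output-length formula; a small sketch:

```python
# With an odd kernel K and dilation d, padding = (K - 1) // 2 * d keeps the
# sequence length unchanged (the "same padding" ShortConv relies on).
def same_padding(kernel_size: int, dilation: int) -> int:
    return (kernel_size - 1) // 2 * dilation

def conv1d_out_len(L: int, kernel_size: int, dilation: int) -> int:
    # nn.Conv1d output length with stride 1: L + 2p - d*(K-1)
    p = same_padding(kernel_size, dilation)
    return L + 2 * p - dilation * (kernel_size - 1)

print(conv1d_out_len(16, 7, 1), conv1d_out_len(16, 7, 4))  # 16 16
```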
583
+ def find_next_prime(start, seen_primes):
584
+ candidate = start + 1
585
+ while True:
586
+ if isprime(candidate) and candidate not in seen_primes:
587
+ return candidate
588
+ candidate += 1
589
+
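The prime search above gives every hash head a distinct prime modulus at least as large as its configured vocab size, so the heads decorrelate their collisions. A self-contained sketch of the same selection (with a naive `is_prime` in place of `sympy.isprime`, and an illustrative vocab size of 1000):

```python
# Pick three distinct primes >= 1000, one per hash head, chaining the search
# start exactly like calculate_vocab_size_across_layers does.
def is_prime(n: int) -> bool:
    if n < 2:
        return False
    i = 2
    while i * i <= n:
        if n % i == 0:
            return False
        i += 1
    return True

def next_prime(start: int, seen: set) -> int:
    candidate = start + 1
    while not (is_prime(candidate) and candidate not in seen):
        candidate += 1
    return candidate

seen, primes = set(), []
start = 1000 - 1  # per-head vocab size of 1000 (illustrative)
for _ in range(3):
    p = next_prime(start, seen)
    seen.add(p)
    primes.append(p)
    start = p
print(primes)  # [1009, 1013, 1019]
```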
590
+ class NgramHashMapping:
591
+ def __init__(
592
+ self,
593
+ engram_vocab_size,
594
+ max_ngram_size,
595
+ n_embed_per_ngram,
596
+ n_head_per_ngram,
597
+ layer_ids,
598
+ pad_id,
599
+ seed,
600
+ config: ModelConfig,
601
+ ):
602
+ self.vocab_size_per_ngram = engram_vocab_size
603
+ self.max_ngram_size = max_ngram_size
604
+ self.n_embed_per_ngram = n_embed_per_ngram
605
+ self.n_head_per_ngram = n_head_per_ngram
606
+ self.pad_id = pad_id
607
+ self.layer_ids = layer_ids
608
+
609
+ self.tokenizer_vocab_size = config.vocab_size
610
+
611
+ max_long = np.iinfo(np.int64).max
612
+ M_max = int(max_long // self.tokenizer_vocab_size)
613
+ half_bound = max(1, M_max // 2)
614
+ PRIME_1 = 10007
615
+
616
+ self.layer_multipliers = {}
617
+
618
+ for layer_id in self.layer_ids:
619
+ base_seed = int(seed + PRIME_1 * int(layer_id))
620
+ g = np.random.default_rng(base_seed)
621
+ r = g.integers(
622
+ low=0,
623
+ high=half_bound,
624
+ size=(self.max_ngram_size,),
625
+ dtype=np.int64
626
+ )
627
+ multipliers = r * 2 + 1
628
+ self.layer_multipliers[layer_id] = multipliers
629
+
630
+ self.vocab_size_across_layers = self.calculate_vocab_size_across_layers()
631
+
632
+ def calculate_vocab_size_across_layers(self):
633
+ seen_primes = set()
634
+ vocab_size_across_layers = {}
635
+
636
+ for layer_id in self.layer_ids:
637
+ all_ngram_vocab_sizes = []
638
+ for ngram in range(2, self.max_ngram_size + 1):
639
+ current_ngram_heads_sizes = []
640
+
641
+ vocab_size = self.vocab_size_per_ngram[ngram - 2]
642
+ num_head = self.n_head_per_ngram
643
+ current_prime_search_start = vocab_size - 1
644
+
645
+ for _ in range(num_head):
646
+ found_prime = find_next_prime(
647
+ current_prime_search_start,
648
+ seen_primes
649
+ )
650
+ seen_primes.add(found_prime)
651
+ current_ngram_heads_sizes.append(found_prime)
652
+ current_prime_search_start = found_prime
653
+
654
+ all_ngram_vocab_sizes.append(current_ngram_heads_sizes)
655
+ vocab_size_across_layers[layer_id] = all_ngram_vocab_sizes
656
+
657
+ return vocab_size_across_layers
658
+
659
+ def _get_ngram_hashes(
660
+ self,
661
+ input_ids: np.ndarray,
662
+ layer_id: int,
663
+ ) -> np.ndarray:
664
+ x = np.asarray(input_ids, dtype=np.int64)
665
+ B, T = x.shape
666
+
667
+ multipliers = self.layer_multipliers[layer_id]
668
+
669
+ def shift_k(k: int) -> np.ndarray:
670
+ if k == 0: return x
671
+ shifted = np.pad(x, ((0, 0), (k, 0)),
672
+ mode='constant', constant_values=self.pad_id)[:, :T]
673
+ return shifted
674
+
675
+ base_shifts = [shift_k(k) for k in range(self.max_ngram_size)]
676
+
677
+ all_hashes = []
678
+
679
+ for n in range(2, self.max_ngram_size + 1):
680
+ n_gram_index = n - 2
681
+ tokens = base_shifts[:n]
682
+ mix = (tokens[0] * multipliers[0])
683
+ for k in range(1, n):
684
+ mix = np.bitwise_xor(mix, tokens[k] * multipliers[k])
685
+ num_heads_for_this_ngram = self.n_head_per_ngram
686
+ head_vocab_sizes = self.vocab_size_across_layers[layer_id][n_gram_index]
687
+
688
+ for j in range(num_heads_for_this_ngram):
689
+ mod = int(head_vocab_sizes[j])
690
+ head_hash = mix % mod
691
+ all_hashes.append(head_hash.astype(np.int64, copy=False))
692
+
693
+ return np.stack(all_hashes, axis=2)
694
+
695
+ def hash(self, input_ids):
696
+ hash_ids_for_all_layers = {}
697
+ for layer_id in self.layer_ids:
698
+ hash_ids_for_all_layers[layer_id] = self._get_ngram_hashes(input_ids, layer_id=layer_id)
699
+ return hash_ids_for_all_layers
700
+
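The hashing scheme in `_get_ngram_hashes` above boils down to: shift the sequence, multiply each offset's tokens by a per-offset odd multiplier, XOR the products, then reduce modulo a per-head prime. A minimal numpy sketch for a single 2-gram head (all constants illustrative):

```python
import numpy as np

pad_id = 0
x = np.array([[5, 7, 11, 13]], dtype=np.int64)   # (B=1, T=4)
multipliers = np.array([3, 5], dtype=np.int64)   # odd multipliers for a 2-gram
prime = 101                                      # this head's modulus

# shift right by one, padding on the left (the k=1 shift)
shifted = np.pad(x, ((0, 0), (1, 0)), constant_values=pad_id)[:, :4]
mix = np.bitwise_xor(x * multipliers[0], shifted * multipliers[1])
hashes = mix % prime
print(hashes.tolist())  # [[15, 12, 2, 16]]
```

Odd multipliers keep the low bits of the products well mixed, and distinct prime moduli per head make the heads collide on different token pairs.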
701
+ class TorchNgramHashMapping:
702
+ """
703
+ Torch implementation of the n-gram hashing, run on the GPU.
704
+ The multipliers and per-head prime moduli come from an existing NgramHashMapping,
705
+ so the hash values and head ordering match the numpy version exactly.
706
+ Output: dict[layer_id] -> (B, T, num_hash_heads) [long]
707
+ """
708
+ def __init__(self, np_mapping: NgramHashMapping, device: torch.device):
709
+ self.layer_ids = list(np_mapping.layer_ids)
710
+ self.max_ngram_size = int(np_mapping.max_ngram_size)
711
+ self.n_head_per_ngram = int(np_mapping.n_head_per_ngram)
712
+ self.pad_id = int(np_mapping.pad_id)
713
+
714
+ # per-layer multipliers: (max_ngram_size,)
715
+ self._multipliers = {
716
+ lid: torch.tensor(np_mapping.layer_multipliers[lid], dtype=torch.long, device=device)
717
+ for lid in self.layer_ids
718
+ }
719
+
720
+ # per-layer moduli: a list where mods[n-2] has shape (n_head_per_ngram,)
721
+ self._mods = {}
722
+ for lid in self.layer_ids:
723
+ mods_per_n = []
724
+ for n in range(2, self.max_ngram_size + 1):
725
+ head_mods = np_mapping.vocab_size_across_layers[lid][n - 2]
726
+ mods_per_n.append(torch.tensor(head_mods, dtype=torch.long, device=device))
727
+ self._mods[lid] = mods_per_n
728
+
729
+ self.num_hash_heads = (self.max_ngram_size - 1) * self.n_head_per_ngram
730
+
731
+ def hash(self, input_ids: torch.Tensor) -> Dict[int, torch.Tensor]:
732
+ """
733
+ input_ids: (B, T) long tensor on target device
734
+ return: {layer_id: (B, T, num_hash_heads) long}
735
+ """
736
+ x = input_ids.to(torch.long)
737
+ B, T = x.shape
738
+
739
+ # shift right by k positions (pad on the left): shifts[k] has shape (B, T)
740
+ shifts = [x]
741
+ for k in range(1, self.max_ngram_size):
742
+ shifts.append(F.pad(x, (k, 0), value=self.pad_id)[:, :T])
743
+
744
+ out: Dict[int, torch.Tensor] = {}
745
+ for lid in self.layer_ids:
746
+ multipliers = self._multipliers[lid]
747
+ heads_per_layer = []
748
+
749
+ for n in range(2, self.max_ngram_size + 1):
750
+ mix = shifts[0] * multipliers[0]
751
+ for k in range(1, n):
752
+ mix = torch.bitwise_xor(mix, shifts[k] * multipliers[k])
753
+
754
+ mods = self._mods[lid][n - 2] # (H,)
755
+ # (B, T, 1) % (1, 1, H) -> (B, T, H)
756
+ head_hash = mix.unsqueeze(-1) % mods.view(1, 1, -1)
757
+ heads_per_layer.append(head_hash)
758
+
759
+ out[lid] = torch.cat(heads_per_layer, dim=-1)
760
+
761
+ return out
762
+
763
+ class MultiHeadEmbedding(nn.Module):
764
+ def __init__(self, list_of_N: List[int], D: int):
765
+ super().__init__()
766
+ self.num_heads = len(list_of_N)
767
+ self.embedding_dim = D
768
+
769
+ offsets = [0]
770
+ for n in list_of_N[:-1]:
771
+ offsets.append(offsets[-1] + n)
772
+
773
+ self.register_buffer("offsets", torch.tensor(offsets, dtype=torch.long))
774
+
775
+ total_N = sum(list_of_N)
776
+ self.embedding = nn.Embedding(num_embeddings=total_N, embedding_dim=D)
777
+
778
+ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
779
+ shifted_input_ids = input_ids + self.offsets
780
+ output = self.embedding(shifted_input_ids)
781
+
782
+ return output
783
+
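The offset trick in `MultiHeadEmbedding` above fuses one embedding table per head into a single table; a numpy sketch with illustrative sizes:

```python
import numpy as np

# Heads with vocab sizes [3, 5] share one fused table of 8 rows; each head's
# local ids are shifted by the cumulative offsets [0, 3] before one lookup.
list_of_N = [3, 5]
offsets = np.array([0, 3])
table = np.arange(8 * 2).reshape(8, 2)   # fused table, embedding dim 2

ids = np.array([[2, 4]])                 # (T=1, heads=2); head-local ids
fused_ids = ids + offsets                # head 1's id 4 maps to row 3 + 4 = 7
out = table[fused_ids]                   # (1, 2, 2)
print(fused_ids.tolist())  # [[2, 7]]
```

A single fused `nn.Embedding` keeps all head tables in one parameter, so one gather serves every hash head at once.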
784
+ class Engram(nn.Module):
785
+ def __init__(self, layer_id: int, config: ModelConfig):
786
+ super().__init__()
787
+ self.layer_id = layer_id
788
+ self.engram_cfg = config.engram_config
789
+ self.backbone_config = config
790
+ engram_cfg = self.engram_cfg
791
+ backbone_config = self.backbone_config
792
+ self.hash_mapping = NgramHashMapping(
793
+ engram_vocab_size = engram_cfg.engram_vocab_size,
794
+ max_ngram_size = engram_cfg.max_ngram_size,
795
+ n_embed_per_ngram = engram_cfg.n_embed_per_ngram,
796
+ n_head_per_ngram = engram_cfg.n_head_per_ngram,
797
+ layer_ids = engram_cfg.layer_ids,
798
+ pad_id = engram_cfg.pad_id,
799
+ seed = engram_cfg.seed,
800
+ config = backbone_config,
801
+ )
802
+ self.multi_head_embedding = MultiHeadEmbedding(
803
+ list_of_N = [x for y in self.hash_mapping.vocab_size_across_layers[self.layer_id] for x in y],
804
+ D = engram_cfg.n_embed_per_ngram // engram_cfg.n_head_per_ngram,
805
+ )
806
+
807
+ # ShortConv with dilation set to the maximum n-gram size
808
+ self.short_conv = ShortConv(
809
+ hidden_size = backbone_config.d_model,
810
+ kernel_size = engram_cfg.kernel_size,
811
+ dilation = engram_cfg.max_ngram_size,
812
+ hc_mult = 1  # fixed to 1
813
+ )
814
+
815
+ engram_hidden_size = (engram_cfg.max_ngram_size-1) * engram_cfg.n_embed_per_ngram
816
+
817
+ # Single projection layers instead of a ModuleList
818
+ self.value_proj = nn.Linear(engram_hidden_size, backbone_config.d_model)
819
+
820
+ # a single key projection
821
+ self.key_proj = nn.Linear(engram_hidden_size, backbone_config.d_model)
822
+
823
+ # one RMSNorm each for key and query
824
+ self.norm_key = nn.RMSNorm(backbone_config.d_model)
825
+ self.norm_query = nn.RMSNorm(backbone_config.d_model)
826
+
827
+ self.reset_parameters()
828
+ # Cache for the Torch hash mapping (built lazily per device)
829
+ self._torch_hash_mapping: Optional[TorchNgramHashMapping] = None
830
+ self._torch_hash_device: Optional[torch.device] = None
831
+
832
+ def reset_parameters(self):
833
+ init_weights(
834
+ self.backbone_config,
835
+ self.multi_head_embedding.embedding,
836
+ type_of_module=ModuleType.emb,
837
+ )
838
+ init_weights(
839
+ self.backbone_config,
840
+ self.value_proj,
841
+ layer_id=self.layer_id,
842
+ type_of_module=ModuleType.in_module,
843
+ )
844
+ init_weights(
845
+ self.backbone_config,
846
+ self.key_proj,
847
+ layer_id=self.layer_id,
848
+ type_of_module=ModuleType.in_module,
849
+ )
850
+ self.short_conv.reset_parameters()
851
+
852
+ def forward(self, hidden_states, input_ids, engram_hash=None):
853
+ """
854
+ hidden_states: [B, L, D]
855
+ input_ids: [B, L]
856
+ engram_hash: [B, L, NumHeads] (Optional)
857
+ """
858
+ # 1. Hash lookup
859
+ if engram_hash is None:
860
+ # Prefer the GPU hash mapping to avoid CPU<->GPU round-trips
861
+ cur_dev = hidden_states.device
862
+ if self._torch_hash_mapping is None or self._torch_hash_device != cur_dev:
863
+ self._torch_hash_mapping = TorchNgramHashMapping(self.hash_mapping, device=cur_dev)
864
+ self._torch_hash_device = cur_dev
865
+ hash_input_ids = self._torch_hash_mapping.hash(input_ids)[self.layer_id]
866
+ else:
867
+ hash_input_ids = engram_hash
868
+ embeddings = self.multi_head_embedding(hash_input_ids).flatten(start_dim=-2)
869
+
870
+ # 2. Compute the gate (no per-head loop needed)
871
+ # key branch
872
+ key = self.key_proj(embeddings)
873
+ normed_key = self.norm_key(key)
874
+
875
+ # Query 部分 (直接使用 hidden_states)
876
+ query = hidden_states
877
+ normed_query = self.norm_query(query)
878
+
879
+ # Gate 计算
880
+ # [B, L, D] * [B, L, D] -> sum(dim=-1) -> [B, L]
881
+ gate = (normed_key * normed_query).sum(dim=-1) / math.sqrt(self.backbone_config.d_model)
882
+ gate = gate.abs().clamp_min(1e-6).sqrt() * gate.sign()
883
+ gate = gate.sigmoid().unsqueeze(-1) # [B, L, 1]
884
+
885
+ # 3. 融合 Value
886
+ value = gate * self.value_proj(embeddings) # [B, L, 1] * [B, L, D] -> [B, L, D]
887
+
888
+ # 4. Short Conv
889
+ output = value + self.short_conv(value)
890
+
891
+ return output
892
+
893
+
894
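The gating in `Engram.forward` compresses the scaled key–query dot product with a signed square root before the sigmoid, softening very large scores while preserving their sign. A torch-free scalar sketch of that computation (hypothetical helper name, assuming the semantics above; the real module operates on batched tensors):

```python
import math

def engram_gate(key, query, d_model):
    """Scalar sketch of the Engram gate: scaled dot product,
    signed square root, then sigmoid (assumed semantics)."""
    score = sum(k * q for k, q in zip(key, query)) / math.sqrt(d_model)
    # Signed sqrt compresses large magnitudes while keeping the sign.
    score = math.copysign(math.sqrt(max(abs(score), 1e-6)), score)
    return 1.0 / (1.0 + math.exp(-score))
```

A positive key–query alignment pushes the gate above 0.5, a negative one below, so the n-gram value is blended in only where it agrees with the backbone's hidden state.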
+class LLaDABlock(nn.Module):
+    """
+    A base class for transformer block implementations.
+    """
+
+    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
+        super().__init__()
+        self.layer_id = layer_id
+        self.config = config
+        self.hidden_size = (
+            config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
+        )
+        self.__cache = cache
+        assert config.d_model % config.n_heads == 0
+
+        self.engram = None
+        if config.engram_config is not None and layer_id in config.engram_config.layer_ids:
+            self.engram = Engram(layer_id, config)
+
+        self._activation_checkpoint_fn = None
+
+        # Dropout.
+        self.dropout = Dropout(config.residual_dropout)
+
+        # Layer norms.
+        self.k_norm: Optional[LayerNormBase] = None
+        self.q_norm: Optional[LayerNormBase] = None
+        if config.attention_layer_norm:
+            self.k_norm = LayerNormBase.build(
+                config,
+                size=(config.d_model // config.n_heads) * config.effective_n_kv_heads,
+                elementwise_affine=config.attention_layer_norm_with_affine,
+            )
+            self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)
+
+        # Activation function.
+        self.act = Activation.build(config)
+        assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
+
+        # Attention output projection.
+        self.attn_out = nn.Linear(
+            config.d_model, config.d_model, bias=config.include_bias, device=config.init_device
+        )
+
+        # Feed-forward output projection.
+        self.ff_out = nn.Linear(
+            int(self.act.output_multiplier * self.hidden_size),
+            config.d_model,
+            bias=config.include_bias,
+            device=config.init_device,
+        )
+        self.ff_out._is_residual = True  # type: ignore
+
+        # Rotary embeddings.
+        if self.config.rope:
+            self.rotary_emb = RotaryEmbedding(config, self.__cache)
+
+        self.flash_attn_func = None
+        if config.flash_attention:
+            try:
+                from flash_attn import flash_attn_func  # type: ignore
+
+                self.flash_attn_func = flash_attn_func
+            except ModuleNotFoundError:
+                pass
+
+    def reset_parameters(self):
+        if self.engram is not None:
+            self.engram.reset_parameters()
+        if self.k_norm is not None:
+            self.k_norm.reset_parameters()
+        if self.q_norm is not None:
+            self.q_norm.reset_parameters()
+        init_weights(
+            self.config,
+            self.attn_out,
+            d=self.config.d_model,
+            layer_id=self.layer_id,
+            type_of_module=ModuleType.out_module,
+        )
+        init_weights(
+            self.config,
+            self.ff_out,
+            d=self.ff_out.in_features,
+            layer_id=self.layer_id,
+            type_of_module=ModuleType.out_module,
+        )
+
+    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
+        if strategy == ActivationCheckpointingStrategy.fine_grained:
+            self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
+        else:
+            self._activation_checkpoint_fn = None
+
+    @classmethod
+    def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
+        target_dtype = input_dtype
+        # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
+        # `is_autocast_cpu_enabled()` for CPU autocast.
+        # See https://github.com/pytorch/pytorch/issues/110966.
+        if bias.device.type == "cuda" and torch.is_autocast_enabled():
+            target_dtype = torch.get_autocast_gpu_dtype()
+        elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
+            target_dtype = torch.get_autocast_cpu_dtype()
+        if bias.dtype != target_dtype:
+            bias = bias.to(target_dtype)
+            ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
+        return bias
+
+    def _scaled_dot_product_attention(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        dropout_p: float = 0.0,
+        is_causal: bool = False,
+    ) -> torch.Tensor:
+        """
+        Computes scaled dot product attention on query, key and value tensors, using an optional
+        attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.
+        """
+        if self.flash_attn_func is not None and attn_mask is None:
+            r = self.flash_attn_func(
+                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=False
+            )
+            return r.transpose(1, 2)
+        else:
+            # torch's sdpa doesn't support GQA, so we're doing this
+            assert k.size(1) == v.size(1)
+            num_kv_heads = k.size(1)
+            num_q_heads = q.size(1)
+            if num_q_heads != num_kv_heads:
+                assert num_q_heads % num_kv_heads == 0
+                k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
+                v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
+
+            # Note: for MDM, causal masking is disabled and no attn_mask is passed.
+            return F.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attn_mask=None,
+                dropout_p=dropout_p,
+                is_causal=False,
+            )
+
+    def attention(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        attention_bias: Optional[torch.Tensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+        B, T, C = q.size()  # batch size, sequence length, d_model
+        dtype = k.dtype
+
+        # Optionally apply layer norm to keys and queries.
+        if self.q_norm is not None and self.k_norm is not None:
+            q = self.q_norm(q).to(dtype=dtype)
+            k = self.k_norm(k).to(dtype=dtype)
+
+        # Move head forward to be next to the batch dim.
+        # shape: (B, nh, T, hs)
+        q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
+        # shape: (B, n_kv_h, T, hs)
+        k = k.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
+        # shape: (B, n_kv_h, T, hs)
+        v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
+
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            k = torch.cat((past_key, k), dim=-2)
+            v = torch.cat((past_value, v), dim=-2)
+
+        present = (k, v) if use_cache else None
+        query_len, key_len = q.shape[-2], k.shape[-2]  # could be different if layer_past is not None
+
+        if self.config.rope:
+            # Apply rotary embeddings.
+            q, k = self.rotary_emb(q, k)
+
+        if attention_bias is not None:
+            # Resize and cast attention bias.
+            # The current dtype of the attention bias might not match the dtype that the SDP attn function will
+            # run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding
+            # as down-casting the attention bias to the autocast precision will result in -infs, which will
+            # cause the SDP attn function to produce NaNs.
+            attention_bias = self._cast_attn_bias(
+                attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype
+            )
+
+        # Get the attention scores.
+        # shape: (B, nh, T, hs)
+        att = self._scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=None,
+            dropout_p=0.0 if not self.training else self.config.attention_dropout,
+            is_causal=False,
+        )
+
+        # Re-assemble all head outputs side-by-side.
+        att = att.transpose(1, 2).contiguous().view(B, T, C)
+
+        # Apply output projection.
+        return self.attn_out(att), present
+
+    @abstractmethod
+    def forward(
+        self,
+        x: torch.Tensor,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_bias: Optional[torch.FloatTensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        use_cache: bool = False,
+        engram_hash: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+        raise NotImplementedError
+
+    @classmethod
+    def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> LLaDABlock:
+        if config.block_type == BlockType.sequential:
+            return LLaDASequentialBlock(layer_id, config, cache)
+        elif config.block_type == BlockType.llama:
+            return LLaDALlamaBlock(layer_id, config, cache)
+        else:
+            raise NotImplementedError(f"Unknown block type: '{config.block_type}'")
+
+
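The `repeat_interleave` fallback in `_scaled_dot_product_attention` expands each KV head so that a contiguous block of query heads shares it. A small torch-free sketch of the resulting head mapping (hypothetical helper, not part of the source):

```python
def kv_head_for_query_head(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """With repeat_interleave on dim=1, each block of
    (num_q_heads // num_kv_heads) consecutive query heads
    ends up aligned with a single repeated KV head."""
    assert num_q_heads % num_kv_heads == 0
    return q_head // (num_q_heads // num_kv_heads)
```

So with 8 query heads and 2 KV heads, query heads 0-3 read KV head 0 and query heads 4-7 read KV head 1, which is what the materialized repeat produces.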
+class LLaDASequentialBlock(LLaDABlock):
+    """
+    This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
+    (plus another skip connection).
+    """
+
+    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
+        super().__init__(layer_id, config, cache)
+        # Layer norms.
+        self.attn_norm = LayerNorm.build(config)
+        self.ff_norm = LayerNorm.build(config)
+        # Attention input projection. Projects x -> (q, k, v)
+        head_dim = config.d_model // config.n_heads
+        self.fused_dims = (
+            config.d_model,
+            config.effective_n_kv_heads * head_dim,
+            config.effective_n_kv_heads * head_dim,
+        )
+        self.att_proj = nn.Linear(
+            config.d_model, sum(self.fused_dims), bias=config.include_bias | config.include_qkv_bias, device=config.init_device
+        )
+        # Feed-forward input projection.
+        self.ff_proj = nn.Linear(
+            config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
+        )
+
+    def reset_parameters(self):
+        super().reset_parameters()
+        self.attn_norm.reset_parameters()
+        self.ff_norm.reset_parameters()
+        # NOTE: the standard deviation for these weights does not depend on the layer.
+        init_weights(
+            self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
+        )
+        init_weights(
+            self.config, self.ff_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_bias: Optional[torch.Tensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        use_cache: bool = False,
+        engram_hash: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+        if self.engram is not None:
+            assert input_ids is not None
+            x = x + self.engram(x, input_ids, engram_hash=engram_hash)
+
+        # Get query, key, value projections.
+        # shape:
+        #  - for regular attn q, k, v: (batch_size, seq_len, d_model)
+        #  - for multi-query attn q: (batch_size, seq_len, d_model)
+        #                      k, v: (batch_size, seq_len, d_model // n_heads)
+        #  - for group query attn q: (batch_size, seq_len, d_model)
+        #                      k, v: (batch_size, seq_len, d_model // n_kv_heads)
+        if self._activation_checkpoint_fn is not None:
+            q, k, v = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)).split(
+                self.fused_dims, dim=-1
+            )
+        else:
+            q, k, v = self.att_proj(self.attn_norm(x)).split(self.fused_dims, dim=-1)
+
+        # Get attention scores.
+        if self._activation_checkpoint_fn is not None:
+            att, cache = self._activation_checkpoint_fn(  # type: ignore
+                self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
+            )
+        else:
+            att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
+
+        # Add the attention output (residual connection).
+        # shape: (B, T, C)
+        x = x + self.dropout(att)
+
+        # Add feed-forward projection.
+        # shape: (batch_size, seq_len, d_model)
+        og_x = x
+        if self._activation_checkpoint_fn is not None:
+            x = self._activation_checkpoint_fn(self.ff_norm, x)  # type: ignore
+        else:
+            x = self.ff_norm(x)
+        x = self.ff_proj(x)
+        if self._activation_checkpoint_fn is not None:
+            x = self._activation_checkpoint_fn(self.act, x)  # type: ignore
+        else:
+            x = self.act(x)
+        x = self.ff_out(x)
+        x = self.dropout(x)
+        x = og_x + x
+
+        return x, cache
+
+
+class LLaDALlamaBlock(LLaDABlock):
+    """
+    This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
+    (plus another skip connection). This block is similar to `LLaDASequentialBlock`
+    but some operations have slightly different implementations to imitate the
+    behavior of Llama.
+    """
+
+    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
+        super().__init__(layer_id, config, cache)
+        # Layer norms.
+        self.attn_norm = LayerNorm.build(config)
+        self.ff_norm = LayerNorm.build(config)
+        self.__cache = cache
+
+        # Attention input projection. Projects x -> (q, k, v)
+        head_dim = config.d_model // config.n_heads
+        q_proj_out_dim = config.d_model
+        k_proj_out_dim = config.effective_n_kv_heads * head_dim
+        v_proj_out_dim = config.effective_n_kv_heads * head_dim
+        self.q_proj = nn.Linear(
+            config.d_model, q_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
+        )
+        self.k_proj = nn.Linear(
+            config.d_model, k_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
+        )
+        self.v_proj = nn.Linear(
+            config.d_model, v_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
+        )
+
+        # Feed-forward input projection.
+        self.ff_proj = nn.Linear(
+            config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
+        )
+        # Newly added: up projection for the gated (SwiGLU-style) MLP.
+        self.up_proj = nn.Linear(
+            config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
+        )
+
+    def reset_parameters(self):
+        super().reset_parameters()
+        self.attn_norm.reset_parameters()
+        self.ff_norm.reset_parameters()
+        # NOTE: the standard deviation for these weights does not depend on the layer.
+        init_weights(self.config, self.q_proj, d=self.config.d_model, layer_id=None)
+        init_weights(self.config, self.k_proj, d=self.config.d_model, layer_id=None)
+        init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None)
+        init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None)
+        init_weights(self.config, self.up_proj, d=self.config.d_model, layer_id=None)  # newly added
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_bias: Optional[torch.Tensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        use_cache: bool = False,
+        engram_hash: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+        if self.engram is not None:
+            assert input_ids is not None
+            x = x + self.engram(x, input_ids, engram_hash=engram_hash)
+
+        # Get query, key, value projections.
+        # shape:
+        #  - for regular attn q, k, v: (batch_size, seq_len, d_model)
+        #  - for multi-query attn q: (batch_size, seq_len, d_model)
+        #                      k, v: (batch_size, seq_len, d_model // n_heads)
+        #  - for group query attn q: (batch_size, seq_len, d_model)
+        #                      k, v: (batch_size, seq_len, d_model // n_kv_heads)
+        x_normed = self.attn_norm(x)
+        q = self.q_proj(x_normed)
+        k = self.k_proj(x_normed)
+        v = self.v_proj(x_normed)
+
+        # Get attention scores.
+        if self._activation_checkpoint_fn is not None:
+            att, cache = self._activation_checkpoint_fn(  # type: ignore
+                self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
+            )
+        else:
+            att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
+
+        # Add the attention output (residual connection).
+        # shape: (B, T, C)
+        x = x + self.dropout(att)
+
+        # Add feed-forward projection.
+        # shape: (batch_size, seq_len, d_model)
+        og_x = x
+        if self._activation_checkpoint_fn is not None:
+            x = self._activation_checkpoint_fn(self.ff_norm, x)  # type: ignore
+        else:
+            x = self.ff_norm(x)
+        x, x_up = self.ff_proj(x), self.up_proj(x)  # newly added
+        if self._activation_checkpoint_fn is not None:
+            x = self._activation_checkpoint_fn(self.act, x)  # type: ignore
+        else:
+            x = self.act(x)
+        x = x * x_up  # newly added
+        x = self.ff_out(x)
+        x = self.dropout(x)
+        x = og_x + x
+
+        return x, cache
+
+
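The `ff_proj`/`up_proj` pair in `LLaDALlamaBlock` forms a gated (SwiGLU-style) MLP: the activated gate branch is multiplied elementwise by the up branch before `ff_out`. A scalar sketch with plain floats standing in for the linear layers (hypothetical names, assuming a SiLU activation):

```python
import math

def silu(x: float) -> float:
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))

def gated_mlp(x: float, w_gate: float, w_up: float, w_down: float) -> float:
    """Scalar sketch of the gated feed-forward above:
    ff_proj -> act, multiplied elementwise by up_proj, then ff_out.
    Weights here are plain floats standing in for linear layers."""
    return w_down * (silu(w_gate * x) * (w_up * x))
```

The multiplicative `x_up` path lets the network gate the nonlinear branch per dimension, which is the key difference from the plain MLP in `LLaDASequentialBlock`.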
+class LLaDAOutput(NamedTuple):
+    logits: torch.FloatTensor
+    """
+    A tensor of shape `(batch_size, seq_len, vocab_size)` representing the log probabilities
+    for the next token *before* normalization via (log) softmax.
+    """
+
+    attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]
+    """
+    Attention keys and values from each block.
+    """
+
+    hidden_states: Optional[Tuple[torch.Tensor]]
+    """
+    Hidden states from each block.
+    """
+
+
+class LLaDAGenerateOutput(NamedTuple):
+    token_ids: torch.LongTensor
+    """
+    The generated token IDs, a tensor of shape `(batch_size, beam_size, max_steps)`.
+    These do *not* include the original input IDs.
+    """
+
+    scores: torch.FloatTensor
+    """
+    The scores of the generated sequences, a tensor of shape `(batch_size, beam_size)`.
+    """
+
+
+class LLaDABlockGroup(nn.ModuleList):
+    def __init__(self, config: ModelConfig, layer_offset: int, modules: Optional[Iterable[nn.Module]] = None):
+        super().__init__(modules)
+        self.config = config
+        self.layer_offset = layer_offset
+        self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
+        self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_bias: Optional[torch.FloatTensor] = None,
+        layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
+        use_cache: bool = False,
+        engram_hashes: Optional[Dict[int, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
+        attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
+        for block_idx, block in enumerate(self):
+            layer_past = None if layers_past is None else layers_past[block_idx]
+            block_idx += self.layer_offset
+            if (
+                (self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
+                or (
+                    self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
+                    and block_idx % 2 == 0
+                )
+                or (
+                    self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
+                    and block_idx % 3 == 0
+                )
+                or (
+                    self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
+                    and block_idx % 4 == 0
+                )
+            ):
+                # shape: (batch_size, seq_len, d_model)
+                x, cache = self._activation_checkpoint_fn(  # type: ignore
+                    block, x, input_ids=input_ids, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache,
+                    engram_hash=None if engram_hashes is None else engram_hashes.get(block_idx)
+                )
+            else:
+                # shape: (batch_size, seq_len, d_model)
+                x, cache = block(x, input_ids=input_ids, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache,
+                                 engram_hash=None if engram_hashes is None else engram_hashes.get(block_idx))
+            if attn_key_values is not None:
+                assert cache is not None
+                attn_key_values.append(cache)
+        return x, attn_key_values
+
+    def reset_parameters(self):
+        for block in self:
+            block.reset_parameters()
+
+    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
+        self.activation_checkpointing_strategy = strategy
+        for block in self:
+            block.set_activation_checkpointing(strategy)
+
+
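The `one_in_k` branches in `LLaDABlockGroup.forward` checkpoint every k-th global layer index (after adding `layer_offset`). A minimal sketch of that selection rule (hypothetical helper, not from the source):

```python
def checkpointed_layers(n_layers: int, every: int) -> list:
    """Global layer indices that get activation checkpointing under a
    'one_in_k' strategy, mirroring the `block_idx % k == 0` test above.
    `every=1` corresponds to the whole_layer strategy."""
    return [i for i in range(n_layers) if i % every == 0]
```

This trades recomputation for memory: only the selected layers discard their activations during the forward pass.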
+class LLaDAModel(nn.Module):
+    def __init__(self, config: ModelConfig, init_params: bool = True):
+        super().__init__()
+        self.config = config
+        self.__cache = BufferCache()
+
+        # Validate config.
+        if self.config.alibi and self.config.flash_attention:
+            raise Exception("ALiBi is currently not supported with FlashAttention")
+
+        if self.config.alibi and self.config.rope:
+            raise Exception("ALiBi and RoPE are mutually exclusive")
+
+        if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size:
+            if self.config.embedding_size < self.config.vocab_size:
+                raise Exception("embedding size should be at least as big as vocab size")
+            elif self.config.embedding_size % 128 != 0:
+                import warnings
+
+                warnings.warn(
+                    "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
+                )
+
+        self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
+        self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config)
+
+        if not (
+            0 < self.config.block_group_size <= self.config.n_layers
+            and self.config.n_layers % self.config.block_group_size == 0
+        ):
+            raise Exception("n layers must be divisible by block group size")
+
+        torch.backends.cuda.enable_flash_sdp(True)
+        torch.backends.cuda.enable_mem_efficient_sdp(False)  # this is super slow so make sure torch won't use it
+
+        self.transformer = nn.ModuleDict(
+            dict(
+                wte=nn.Embedding(
+                    config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
+                ),
+                emb_drop=Dropout(config.embedding_dropout),
+                ln_f=LayerNorm.build(config),
+            )
+        )
+
+        blocks = [LLaDABlock.build(i, config, self.__cache) for i in range(config.n_layers)]
+        if self.config.block_group_size > 1:
+            block_groups = [
+                LLaDABlockGroup(config, i, blocks[i : i + config.block_group_size])
+                for i in range(0, config.n_layers, config.block_group_size)
+            ]
+            self.transformer.update({"block_groups": nn.ModuleList(block_groups)})
+        else:
+            self.transformer.update({"blocks": nn.ModuleList(blocks)})
+
+        if not (self.config.alibi or self.config.rope):
+            self.transformer.update(
+                {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
+            )
+        if not config.weight_tying:
+            self.transformer.update(
+                {
+                    "ff_out": nn.Linear(
+                        config.d_model,
+                        config.embedding_size or config.vocab_size,
+                        bias=config.include_bias,
+                        device=config.init_device,
+                    )
+                }
+            )
+        # When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
+        if init_params and self.config.init_device != "meta":
+            self.reset_parameters()
+        self.__num_fwd_flops: Optional[int] = None
+
+        # Warm up cache.
+        if self.config.alibi:
+            get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
+            self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
+
+    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
+        self.activation_checkpointing_strategy = strategy
+        if self.config.block_group_size != 1:
+            for block_group in self.transformer.block_groups:
+                block_group.set_activation_checkpointing(strategy)
+        else:
+            for block in self.transformer.blocks:
+                block.set_activation_checkpointing(strategy)
+
+    @property
+    def device(self) -> torch.device:
+        device: torch.device = self.transformer.wte.weight.device  # type: ignore
+        if device.type == "meta":
+            return _non_meta_init_device(self.config)
+        else:
+            return device
+
+    def reset_parameters(self):
+        log.info("Initializing model parameters...")
+        # Top-level embeddings / linear layers.
+        init_weights(
+            self.config,
+            self.transformer.wte,  # type: ignore
+            std_factor=(0.5 * math.sqrt(self.config.d_model)) if self.config.scale_logits else 1.0,
+            type_of_module=ModuleType.emb,
+        )
+        if hasattr(self.transformer, "wpe"):
+            init_weights(self.config, self.transformer.wpe, type_of_module=ModuleType.emb)  # type: ignore
+
+        # Top-level layer norm.
+        self.transformer.ln_f.reset_parameters()  # type: ignore
+
+        # Output weights.
+        if hasattr(self.transformer, "ff_out"):
+            init_weights(self.config, self.transformer.ff_out, type_of_module=ModuleType.final_out)  # type: ignore
+
+        # Let the blocks handle themselves.
+        if self.config.block_group_size == 1:
+            for block in self.transformer.blocks:
+                block.reset_parameters()
+        else:
+            for block_group in self.transformer.block_groups:
+                block_group.reset_parameters()
+
+    def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
+        if (alibi_bias := self.__cache.get("alibi_attention_bias")) is not None and alibi_bias.shape[
+            -1
+        ] >= seq_len:
+            if alibi_bias.device != device:
+                alibi_bias = alibi_bias.to(device)
+                self.__cache["alibi_attention_bias"] = alibi_bias
+            return alibi_bias
+        with torch.autocast(device.type, enabled=False):
+            alibi_bias = alibi_attention_bias(seq_len, self.config, device)
+        self.__cache["alibi_attention_bias"] = alibi_bias
+        return alibi_bias
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        input_embeddings: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        attention_bias: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
+        use_cache: bool = False,
+        last_logits_only: bool = False,
+        output_hidden_states: Optional[bool] = None,
+        engram_hashes: Optional[Dict[int, torch.Tensor]] = None,
+    ) -> LLaDAOutput:
+        """
+        :param input_ids: A tensor of shape `(batch_size, seq_len)`.
+        :param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input
+            embeddings. When provided, it is treated as the output of the input embedding layer.
+        :param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
+            which input IDs are masked. A `1` value in the mask means that
+            the corresponding input ID should *not* be ignored. A `0` means
+            that the corresponding input ID is masked.
+            This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
+            library.
+        :param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`,
+            `(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used
+            to introduce causal or other biases.
+            If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]`
+            indicates that the i-th element in the sequence is allowed to attend to the j-th
+            element in the sequence.
+            If the tensor is a float tensor, it will just be added to the attention
+            scores before the softmax.
+            The default is causal, which corresponds to a lower-diagonal byte matrix of ones.
+        :param past_key_values: Pre-computed keys and values for each attention block.
+            Can be used to speed up sequential decoding. The `input_ids` which have
+            their past given to this model should not be passed as `input_ids` as they have already been computed.
+        :param use_cache: If `True`, return key and value tensors for each block.
+        :param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
+            This can speed up decoding when you only care about the next token.
+        """
+        # Basic MDM model config checks.
+        assert not self.config.alibi, "ALiBi length extrapolation is not supported for MDM."
+        assert self.config.rope, "RoPE must be used in the Llama encoder for MDM."
+        assert (past_key_values is None and not use_cache), "The KV cache is not supported for MDM."
+
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else False
+
+        if past_key_values:
+            assert len(past_key_values) == self.config.n_layers
+
+        batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
+        if past_key_values is None:
+            past_length = 0
+        else:
+            past_length = past_key_values[0][0].size(-2)
+
+        # Get embeddings of input.
+        # shape: (batch_size, seq_len, d_model)
+        x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings  # type: ignore
+
+        if self.config.input_emb_norm:
+            x = x * (self.config.d_model**0.5)
+
+        if not (self.config.alibi or self.config.rope):
+            # Get positional embeddings.
+            # shape: (1, seq_len)
+            pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
+            # shape: (1, seq_len, d_model)
+            pos_emb = self.transformer.wpe(pos)  # type: ignore
+            x = pos_emb + x
+
+        # Add input + positional embeddings and apply dropout.
+        # shape: (batch_size, seq_len, d_model)
+        x = self.transformer.emb_drop(x)  # type: ignore
+
+        # Transform the attention mask into what the blocks expect.
+        if attention_mask is not None and 0.0 in attention_mask:
+            # shape: (batch_size, 1, 1, seq_len)
+            attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
+            attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
+        else:
+            attention_mask = None
+
+        # Merge attention mask with attention bias.
+        if (
+            attention_bias is not None
+            or attention_mask is not None
+            or self.config.alibi
+            # NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
+            # with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
+            # scores correctly.
+            or past_key_values is not None
+        ):
+            if attention_bias is None and self.config.alibi:
+                attention_bias = get_causal_attention_bias(
+                    self.__cache, past_length + seq_len, x.device
+                ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
+            elif attention_bias is None:
+                attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
+            elif attention_bias.dtype in (torch.int8, torch.bool):
+                attention_bias = attention_bias.to(dtype=torch.float)
+                attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
+
+            # Transform to the right shape and data type.
+            mask_len = seq_len
+            if attention_mask is not None:
+                mask_len = attention_mask.shape[-1]
+            elif past_key_values is not None:
+                mask_len = past_key_values[0][0].shape[-2] + seq_len
+            attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
+
+            # Add in the masking bias.
+            if attention_mask is not None:
+                attention_bias = attention_bias + attention_mask
+                # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
+                # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
+                # it can produce NaNs.
+                ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)
+
+        attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
+
+        # Decoder layers.
+        all_hidden_states = []
+
+        # Apply blocks one-by-one.
+        if self.config.block_group_size == 1:
+            for block_idx, block in enumerate(self.transformer.blocks):
+                if output_hidden_states:
+                    # Add hidden states.
+                    all_hidden_states.append(x)
+
+                layer_past = None if past_key_values is None else past_key_values[block_idx]
+                if (
+                    (self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
1690
+ or (
1691
+ self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
1692
+ and block_idx % 2 == 0
1693
+ )
1694
+ or (
1695
+ self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
1696
+ and block_idx % 3 == 0
1697
+ )
1698
+ or (
1699
+ self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
1700
+ and block_idx % 4 == 0
1701
+ )
1702
+ ):
1703
+ # shape: (batch_size, seq_len, d_model)
1704
+ x, cache = self._activation_checkpoint_fn(
1705
+ block, x, input_ids=input_ids, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache,
1706
+ engram_hash=None if engram_hashes is None else engram_hashes.get(block_idx)
1707
+ )
1708
+ else:
1709
+ # shape: (batch_size, seq_len, d_model)
1710
+ x, cache = block(x, input_ids=input_ids, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache,
1711
+ engram_hash=None if engram_hashes is None else engram_hashes.get(block_idx))
1712
+ if attn_key_values is not None:
1713
+ assert cache is not None
1714
+ attn_key_values.append(cache)
1715
+ else:
1716
+ for group_idx, block_group in enumerate(self.transformer.block_groups):
1717
+ if output_hidden_states:
1718
+ # add hidden states
1719
+ all_hidden_states.append(x)
1720
+
1721
+ layers_past = (
1722
+ None
1723
+ if past_key_values is None
1724
+ else past_key_values[
1725
+ group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
1726
+ ]
1727
+ )
1728
+ x, cache = block_group(
1729
+ x, input_ids=input_ids, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache,
1730
+ engram_hashes=engram_hashes
1731
+ )
1732
+ if attn_key_values is not None:
1733
+ assert cache is not None
1734
+ attn_key_values.extend(cache)
1735
+
1736
+ if last_logits_only:
1737
+ # shape: (batch_size, 1, d_model)
1738
+ x = x[:, -1, :].unsqueeze(1)
1739
+
1740
+ # Apply final layer norm.
1741
+ # shape: (batch_size, seq_len or 1, d_model)
1742
+ x = self.transformer.ln_f(x) # type: ignore
1743
+ if output_hidden_states:
1744
+ # add final hidden state post-final-layernorm, following HuggingFace's convention
1745
+ all_hidden_states.append(x)
1746
+
1747
+ # Get logits.
1748
+ # shape: (batch_size, seq_len or 1, vocab_size)
1749
+ if self.config.weight_tying:
1750
+ logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore
1751
+ else:
1752
+ logits = self.transformer.ff_out(x) # type: ignore
1753
+ if self.config.scale_logits:
1754
+ logits.mul_(1 / math.sqrt(self.config.d_model))
1755
+
1756
+ return LLaDAOutput(logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None) # type: ignore[arg-type]
1757
+
1758
+
1759
+def create_model_config_from_pretrained_config(config: LLaDAConfig):
+    """
+    Utility function: build a `ModelConfig` from a pretrained `LLaDAConfig`,
+    re-hydrating a nested `engram_config` dict into an `EngramConfig`.
+    """
+
+    kwargs = {}
+    for field in fields(ModelConfig):
+        val = getattr(config, field.name, None)
+        if field.name == "engram_config" and isinstance(val, dict):
+            val = EngramConfig(**val)
+        kwargs[field.name] = val
+
+    model_config = ModelConfig(**kwargs)
+    return model_config
+
+
+class LLaDAModelLM(PreTrainedModel):
+    """
+    Extremely barebones HF model wrapper.
+    """
+
+    config_class = LLaDAConfig
+    base_model_prefix = "model"
+    _no_split_modules = ["LLaDABlock", "LLaDASequentialBlock", "LLaDALlamaBlock"]
+
+    def __init__(self, config: LLaDAConfig, model: Optional[LLaDAModel] = None, init_params: bool = False):
+        super().__init__(config)
+
+        if not model:
+            model_config = create_model_config_from_pretrained_config(config)
+            # Initialize model (always on CPU to start with so we don't run out of GPU memory).
+            model_config.init_device = "cpu"
+            self.model = LLaDAModel(model_config, init_params=init_params)
+        else:
+            self.model = model
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        attention_bias: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[Cache] = None,  # This is a hack mitigation of an issue in transformers `4.39.x`
+        engram_hashes: Optional[Dict[int, torch.Tensor]] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        if use_cache is None:
+            use_cache = self.config.use_cache
+
+        if output_attentions:
+            raise ValueError("output_attentions is not yet supported in LLaDA")
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model.forward(
+            input_ids=input_ids,
+            input_embeddings=inputs_embeds,
+            attention_mask=attention_mask,
+            attention_bias=attention_bias,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_hidden_states=output_hidden_states,
+            engram_hashes=engram_hashes,
+        )
+
+        logits = outputs.logits
+        hidden_states = outputs.hidden_states
+
+        loss = None
+        if labels is not None:
+            import warnings
+
+            warnings.warn("Note that for LLaDA, you cannot calculate the loss here.", UserWarning)
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            logits=logits,
+            past_key_values=outputs.attn_key_values,
+            hidden_states=hidden_states,
+        )
+
+    def can_generate(self) -> bool:
+        return True
+
+    def prepare_inputs_for_generation(
+        self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
+    ):
+        if past_key_values:
+            # This is because we want the model to only process the last generated token.
+            input_ids = input_ids[:, -1:]
+        model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
+
+        model_inputs.update(kwargs)
+        model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
+        return model_inputs
+
+    # TODO: these are required to make the implementation complete.
+    # def resize_position_embeddings(self, new_num_position_embeddings: int):
+    #     pass
+    #
+    # def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
+    #     pass
+    #
+    # def _reorder_cache(self, past_key_values, beam_idx):
+    #     pass
+
+    def get_input_embeddings(self) -> torch.nn.Module:
+        return self.model.transformer.wte
+
+    def set_input_embeddings(self, value: torch.nn.Module):
+        self.model.transformer.wte = value
+
+    def get_output_embeddings(self):
+        if self.config.weight_tying:
+            return self.model.transformer.wte
+        else:
+            return self.model.transformer.ff_out
+
+    def set_output_embeddings(self, value: torch.nn.Module):
+        if self.config.weight_tying:
+            self.model.transformer.wte = value
+        else:
+            self.model.transformer.ff_out = value
+
+    def tie_weights(self):
+        if self.config.weight_tying:
+            self.model.transformer.ff_out = self.model.transformer.wte
+
+
+# Register the model so that it is available for transformer pipelines, auto-loading, etc.
+AutoModel.register(LLaDAConfig, LLaDAModelLM)
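
The forward pass above folds a 1/0 keep/pad attention mask into the additive attention bias by mapping kept positions to a zero bias and padded positions to the dtype's minimum value, which drives their post-softmax attention weights to (effectively) zero. A minimal torch-free sketch of that `(1.0 - mask) * finfo.min` conversion, with float32's minimum hard-coded for illustration:

```python
# Additive attention-mask conversion, as in the forward pass above:
# kept positions (mask == 1.0) contribute ~zero bias, padded positions
# (mask == 0.0) a huge negative bias that softmax turns into ~0 weight.

FLOAT32_MIN = -3.4028234663852886e38  # value of torch.finfo(torch.float32).min


def mask_to_bias(attention_mask):
    """Map a list of 1.0/0.0 keep/pad flags to an additive bias."""
    return [(1.0 - m) * FLOAT32_MIN for m in attention_mask]


bias = mask_to_bias([1.0, 1.0, 0.0])
# first two entries are zero; the last is a very large negative number
```

Note that adding two such biases together can underflow to `-inf`, which is exactly why the model calls `ensure_finite_` after summing the mask into the bias.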
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,8 @@
+{
+  "tokenizer_class": "PreTrainedTokenizerFast",
+  "tokenizer_file": "tokenizer.json",
+  "bos_token": null,
+  "eos_token": null,
+  "pad_token": 6629,
+  "unk_token": 6630
+}
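
The block loop in the model code above checkpoints activations on a fixed cadence: `whole_layer` recomputes every block, `one_in_two` every 2nd block, `one_in_three` every 3rd, and so on, always starting from block 0 (`block_idx % step == 0`). A small sketch of that selection rule in plain Python (the string strategy names stand in for the `ActivationCheckpointingStrategy` enum members):

```python
# Activation-checkpointing cadence, mirroring the block-loop conditions
# above: a block is checkpointed when its index is a multiple of the
# strategy's step size.

CADENCE = {"whole_layer": 1, "one_in_two": 2, "one_in_three": 3, "one_in_four": 4}


def should_checkpoint(strategy, block_idx):
    """Return True if this block's activations should be recomputed in backward."""
    step = CADENCE.get(strategy)
    return step is not None and block_idx % step == 0


checkpointed = [i for i in range(8) if should_checkpoint("one_in_three", i)]
# -> blocks 0, 3, 6
```

Trading recompute for memory this way lets the cadence be tuned per run; `None`/unknown strategies checkpoint nothing.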