Init
Browse files- README.md +75 -0
- config.json +71 -0
- configuration_llada_engram.py +477 -0
- example.py +172 -0
- model.safetensors +3 -0
- modeling_llada_engram.py +1895 -0
- tokenizer.json +0 -0
- tokenizer_config.json +8 -0
README.md
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Diffusion Engram IME Demo
|
| 2 |
+
|
| 3 |
+
[English Version](#introduction)
|
| 4 |
+
|
| 5 |
+
本项目探索了一种基于扩散语言模型的输入法实现思路。它基于 [LLaDA](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) 的实现,并融合了 [Engram](https://github.com/deepseek-ai/Engram) 模块,期盼利用语言模型强大的上下文理解能力来提升长句输入的准确性与连贯性。
|
| 6 |
+
|
| 7 |
+
模型在 [Chinese Fineweb Edu Dataset V2.1](https://huggingface.co/datasets/opencsg/Fineweb-Edu-Chinese-V2.1) 上进行了训练,采用[虎码输入方案](https://www.tiger-code.com/)作为编码标准。
|
| 8 |
+
|
| 9 |
+
## 使用方式
|
| 10 |
+
|
| 11 |
+
本项目提供了简易的交互脚本 `example.py`,用于演示核心功能。
|
| 12 |
+
|
| 13 |
+
### 输入格式
|
| 14 |
+
在提示符下输入字符串。
|
| 15 |
+
- **汉字**: 请输入其虎码编码的前两位。
|
| 16 |
+
- **标点/大写字母/特殊符号**: 直接输入原文即可。符号可以输入半角版本。
|
| 17 |
+
- **小写字母**: 输入字母+空格。
|
| 18 |
+
- **混合输入**: 支持编码与原文混合输入。
|
| 19 |
+
|
| 20 |
+
部分实例见 `example.py` 末尾注释。
|
| 21 |
+
|
| 22 |
+
## 局限性
|
| 23 |
+
|
| 24 |
+
⚠️本项目可能包含⚠️:
|
| 25 |
+
- 没有认真处理和选择的训练语料
|
| 26 |
+
- 拍脑袋想出来的模型架构和超参数
|
| 27 |
+
- 低效的模型实现和推理代码
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
## 杂谈
|
| 31 |
+
其实早在 ChatGPT 还没有横空出世、我还没有了解过 transformer 的时候,我就思考过能不能用深度学习模型做出更加强大的输入法。那时候我试着学习形码(最后并没有坚持下来),也常常幻想其他提升输入效率的手段。我曾想输入法也许可以更熟悉语言应当有的语法和语义、能以更高的概率组出合理的句子。当然它的用户界面可能与现在的输入法大相径庭、以输入句子甚至段落为核心,从而能利用起来上下文的信息。(当然做这事的人可能不少。)不过那时没有 vibe coding 帮忙,我的行动力不足以让我把这些想法变成现实。
|
| 32 |
+
|
| 33 |
+
随着 LLM 的发展,我愈发觉得输入效率很大程度上限制了人机交互的效率。了解到 diffusion language model 的思路后,我觉得这非常适合输入法的场景:模型需要根据完整的上下文来预测每个字(而在自回归模型上做 constrastive decoding 只能感知到上文),并且可以从易到难地逐步推断原文,甚至可以无缝地处理部分词由用户手动选择的情况。之所以选择形码,首要原因是形码不会有多音字的问题,数据处理简单一些,且每个编码上的字数分布更加均匀。不过最后训练下来,效果不太理想。
|
| 34 |
+
|
| 35 |
+
看到 Engram 的时候,我立刻开始重新尝试这个项目。Engram 所做的对 n-gram 查表,几乎就是完美地承担了“词库”的职责,Engram 模块应当可以大幅度减轻模型主干记忆词库的压力。训练下来,结果确实也比之前好不少。
|
| 36 |
+
|
| 37 |
+
这个项目离实际可用的输入法还有很大差距:最显然的当然是其没有一个合适的用户界面,要有合适的方式让用户修改候选结果、且能适应零点几秒的延迟;模型的训练数据在类型上很窄,尤其缺乏口语化或文学化的内容;模型的推理几乎没有优化过;模型具体应该做多大、超参数如何选择也没有认真实验过;等等。除此之外,怎么让模型利用已经上屏的部分作为上下文,以及能否针对性地再改造 Engram(和其他各个模块)使其更适合输入法场景,都是潜在的改进方向。
|
| 38 |
+
|
| 39 |
+
---
|
| 40 |
+
|
| 41 |
+
## Introduction
|
| 42 |
+
|
| 43 |
+
This project explores an implementation idea for an Input Method Editor (IME) based on diffusion language models. It is built upon the implementation of [LLaDA](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) and incorporates the [Engram](https://github.com/deepseek-ai/Engram) module, aiming to leverage the powerful context understanding capabilities of language models to improve the accuracy and coherence of long sentence input.
|
| 44 |
+
|
| 45 |
+
The model is trained on the [Chinese Fineweb Edu Dataset V2.1](https://huggingface.co/datasets/opencsg/Fineweb-Edu-Chinese-V2.1) and uses the [Tiger Code (Huma) input scheme](https://www.tiger-code.com/) as the encoding standard.
|
| 46 |
+
|
| 47 |
+
## Usage
|
| 48 |
+
|
| 49 |
+
This project provides a simple interactive script `example.py` to demonstrate core functionalities.
|
| 50 |
+
|
| 51 |
+
### Input Format
|
| 52 |
+
Enter a string at the prompt.
|
| 53 |
+
- **Chinese Characters**: Please enter the first two characters of their Tiger Code encoding.
|
| 54 |
+
- **Punctuation/Uppercase Letters/Special Symbols**: Enter the original characters directly. Symbols can be entered in their half-width versions.
|
| 55 |
+
- **Lowercase Letters**: Enter the letter followed by a space.
|
| 56 |
+
- **Mixed Input**: Supports mixing encoding and original text input.
|
| 57 |
+
|
| 58 |
+
See comments at the end of `example.py` for some examples.
|
| 59 |
+
|
| 60 |
+
## Limitations
|
| 61 |
+
|
| 62 |
+
⚠️ This project may contain ⚠️:
|
| 63 |
+
- Training corpora that have not been carefully processed or selected
|
| 64 |
+
- Model architecture and hyperparameters conceived on a whim
|
| 65 |
+
- Inefficient model implementation and inference code
|
| 66 |
+
|
| 67 |
+
## Ramblings
|
| 68 |
+
|
| 69 |
+
Actually, long before ChatGPT took the world by storm and before I knew anything about transformers, I wondered if deep learning models could be used to create a more powerful IME. At that time, I tried learning shape-based input methods (though I didn't stick with it) and often fantasized about other means to improve input efficiency. I thought that an IME could perhaps be more familiar with the syntax and semantics that language should have, and group reasonable sentences with higher probability. Of course, its user interface might be vastly different from current IMEs, focusing on inputting sentences or even paragraphs to utilize context information. (Of course, many people might be doing this.) However, without "vibe coding" to help me then, my lack of execution prevented me from turning these ideas into reality.
|
| 70 |
+
|
| 71 |
+
With the development of LLMs, I increasingly feel that input efficiency largely limits the efficiency of human-computer interaction. After learning about the idea of diffusion language models, I felt this was very suitable for IME scenarios: the model needs to predict each character based on the complete context (whereas contrastive decoding on autoregressive models can only perceive the preceding text), and can infer the original text gradually from easy to difficult, and can even seamlessly handle cases where some words are manually selected by the user. The primary reason for choosing a shape-based code is that it avoids the problem of polyphones, simplifying data processing, and the character distribution for each code is more uniform. However, the initial training results were not ideal.
|
| 72 |
+
|
| 73 |
+
When I saw Engram, I immediately started retrying this project. Engram's n-gram lookup almost perfectly assumes the responsibility of a "lexicon". The Engram module should significantly reduce the pressure on the model backbone to memorize the lexicon. After training, the results are indeed much better than before.
|
| 74 |
+
|
| 75 |
+
This project is still far from being a practically usable IME: the most obvious gap is the lack of a suitable user interface that allows users to modify candidate results and adapt to a latency of a few tenths of a second; the training data is narrow in type, especially lacking colloquial or literary content; the model's inference is hardly optimized; no serious experiments have been done on how large the model should be or how to select hyperparameters; etc. Besides, how to let the model use the already entered text as context, and whether Engram (and other modules) can be specifically transformed to better suit IME scenarios, are potential directions for improvement.
|
config.json
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"activation_type": "silu",
|
| 3 |
+
"alibi": false,
|
| 4 |
+
"alibi_bias_max": 8.0,
|
| 5 |
+
"architectures": [
|
| 6 |
+
"LLaDAModelLM"
|
| 7 |
+
],
|
| 8 |
+
"attention_dropout": 0.0,
|
| 9 |
+
"attention_layer_norm": false,
|
| 10 |
+
"attention_layer_norm_with_affine": true,
|
| 11 |
+
"auto_map": {
|
| 12 |
+
"AutoConfig": "configuration_llada_engram.LLaDAConfig",
|
| 13 |
+
"AutoModelForCausalLM": "modeling_llada_engram.LLaDAModelLM",
|
| 14 |
+
"AutoModel": "modeling_llada_engram.LLaDAModelLM"
|
| 15 |
+
},
|
| 16 |
+
"bias_for_layer_norm": false,
|
| 17 |
+
"block_group_size": 1,
|
| 18 |
+
"block_type": "llama",
|
| 19 |
+
"d_model": 768,
|
| 20 |
+
"embedding_dropout": 0.0,
|
| 21 |
+
"embedding_size": 7424,
|
| 22 |
+
"eos_token_id": 6624,
|
| 23 |
+
"flash_attention": false,
|
| 24 |
+
"include_bias": false,
|
| 25 |
+
"include_qkv_bias": false,
|
| 26 |
+
"init_cutoff_factor": null,
|
| 27 |
+
"init_device": "meta",
|
| 28 |
+
"init_fn": "mitchell",
|
| 29 |
+
"init_std": 0.02,
|
| 30 |
+
"input_emb_norm": false,
|
| 31 |
+
"layer_norm_type": "rms",
|
| 32 |
+
"layer_norm_with_affine": true,
|
| 33 |
+
"mask_token_id": 7186,
|
| 34 |
+
"max_sequence_length": 128,
|
| 35 |
+
"mlp_hidden_size": 1536,
|
| 36 |
+
"model_type": "llada",
|
| 37 |
+
"multi_query_attention": null,
|
| 38 |
+
"n_heads": 12,
|
| 39 |
+
"n_kv_heads": 12,
|
| 40 |
+
"n_layers": 14,
|
| 41 |
+
"pad_token_id": 6624,
|
| 42 |
+
"precision": "amp_bf16",
|
| 43 |
+
"residual_dropout": 0.0,
|
| 44 |
+
"rms_norm_eps": 1e-05,
|
| 45 |
+
"rope": true,
|
| 46 |
+
"rope_full_precision": true,
|
| 47 |
+
"rope_theta": 500000.0,
|
| 48 |
+
"scale_logits": false,
|
| 49 |
+
"transformers_version": "4.46.3",
|
| 50 |
+
"use_cache": false,
|
| 51 |
+
"vocab_size": 7397,
|
| 52 |
+
"weight_tying": false,
|
| 53 |
+
"engram_config": {
|
| 54 |
+
"tokenizer_name_or_path": "./tokenizer.json",
|
| 55 |
+
"engram_vocab_size": [
|
| 56 |
+
51200,
|
| 57 |
+
51200,
|
| 58 |
+
51200
|
| 59 |
+
],
|
| 60 |
+
"max_ngram_size": 4,
|
| 61 |
+
"n_embed_per_ngram": 256,
|
| 62 |
+
"n_head_per_ngram": 4,
|
| 63 |
+
"layer_ids": [
|
| 64 |
+
1,
|
| 65 |
+
7
|
| 66 |
+
],
|
| 67 |
+
"pad_id": 6629,
|
| 68 |
+
"seed": 42,
|
| 69 |
+
"kernel_size": 7
|
| 70 |
+
}
|
| 71 |
+
}
|
configuration_llada_engram.py
ADDED
|
@@ -0,0 +1,477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLaDA configuration
|
| 3 |
+
"""
|
| 4 |
+
from transformers import AutoConfig, PretrainedConfig
|
| 5 |
+
|
| 6 |
+
from enum import Enum
|
| 7 |
+
from os import PathLike
|
| 8 |
+
from typing import Union
|
| 9 |
+
from dataclasses import asdict, dataclass, field
|
| 10 |
+
from glob import glob
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import (
|
| 13 |
+
Any,
|
| 14 |
+
Dict,
|
| 15 |
+
Iterable,
|
| 16 |
+
List,
|
| 17 |
+
Optional,
|
| 18 |
+
Tuple,
|
| 19 |
+
Type,
|
| 20 |
+
TypeVar,
|
| 21 |
+
Union,
|
| 22 |
+
cast,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
__all__ = [
|
| 27 |
+
"ActivationType",
|
| 28 |
+
"ActivationCheckpointingStrategy",
|
| 29 |
+
"BlockType",
|
| 30 |
+
"LayerNormType",
|
| 31 |
+
"InitFnType",
|
| 32 |
+
"ModelConfig",
|
| 33 |
+
]
|
| 34 |
+
|
| 35 |
+
PathOrStr = Union[str, PathLike]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class StrEnum(str, Enum):
|
| 39 |
+
"""
|
| 40 |
+
This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
|
| 41 |
+
We include this here for compatibility with older version of Python.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def __str__(self) -> str:
|
| 45 |
+
return self.value
|
| 46 |
+
|
| 47 |
+
def __repr__(self) -> str:
|
| 48 |
+
return f"'{str(self)}'"
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class LayerNormType(StrEnum):
|
| 52 |
+
default = "default"
|
| 53 |
+
"""
|
| 54 |
+
The default LayerNorm implementation, equivalent to PyTorch's built-in version.
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
low_precision = "low_precision"
|
| 58 |
+
"""
|
| 59 |
+
A low-precision version of the default LayerNorm.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
rms = "rms"
|
| 63 |
+
"""
|
| 64 |
+
An RMSNorm implementation. When using ``torch.compile`` this is
|
| 65 |
+
probably the fastest implementation.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
gemma_rms = "gemma_rms"
|
| 69 |
+
"""
|
| 70 |
+
An RMSNorm implementation by gemmma. When using ``torch.compile`` this is
|
| 71 |
+
probably the fastest implementation.
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
amd_compatible = "amd_compatible"
|
| 75 |
+
"""
|
| 76 |
+
LayerNorm implemented manually to work around an issue with ROCm.
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class ActivationType(StrEnum):
|
| 81 |
+
gelu = "gelu"
|
| 82 |
+
relu = "relu"
|
| 83 |
+
silu = "silu"
|
| 84 |
+
swiglu = "swiglu"
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class BlockType(StrEnum):
|
| 88 |
+
sequential = "sequential"
|
| 89 |
+
parallel = "parallel"
|
| 90 |
+
|
| 91 |
+
llama = "llama"
|
| 92 |
+
"""
|
| 93 |
+
A block similar to the sequential block with slightly different
|
| 94 |
+
implementations of operations like attention to imitate the behavior of Llama.
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class InitFnType(StrEnum):
|
| 99 |
+
mitchell = "mitchell"
|
| 100 |
+
"""
|
| 101 |
+
The strategy suggested to us by Mitchell Wortsman from UW.
|
| 102 |
+
This uses a truncated normal distribution with an adaptive standard deviation that depends
|
| 103 |
+
on the size of the weights as well as the depth of the layer.
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
normal = "normal"
|
| 107 |
+
"""
|
| 108 |
+
All weights are initialized from the same normal distribution.
|
| 109 |
+
"""
|
| 110 |
+
|
| 111 |
+
kaiming_normal = "kaiming_normal"
|
| 112 |
+
"""
|
| 113 |
+
All weights are initialized with the Kaiming method from a normal distribution.
|
| 114 |
+
Note this currently won't work with FSDP.
|
| 115 |
+
"""
|
| 116 |
+
|
| 117 |
+
fan_in = "fan_in"
|
| 118 |
+
"""
|
| 119 |
+
"Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
|
| 120 |
+
is the input dimensionality of the kernel.
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
full_megatron = "full_megatron"
|
| 124 |
+
"""
|
| 125 |
+
This is what metaseq calls "full megatron init". It is the init used for Llama 2.
|
| 126 |
+
"""
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
@dataclass
|
| 130 |
+
class EngramConfig:
|
| 131 |
+
tokenizer_name_or_path: str = "deepseek-ai/DeepSeek-V3"
|
| 132 |
+
engram_vocab_size: List[int] = field(default_factory=lambda: [129280*5, 129280*5])
|
| 133 |
+
max_ngram_size: int = 3
|
| 134 |
+
n_embed_per_ngram: int = 512
|
| 135 |
+
n_head_per_ngram: int = 8
|
| 136 |
+
layer_ids: List[int] = field(default_factory=lambda: [1, 15])
|
| 137 |
+
pad_id: int = 2
|
| 138 |
+
seed: int = 0
|
| 139 |
+
kernel_size: int = 7
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
@dataclass
|
| 143 |
+
class ModelConfig():
|
| 144 |
+
"""
|
| 145 |
+
LLaDA (model) configuration.
|
| 146 |
+
"""
|
| 147 |
+
|
| 148 |
+
# Note that the defaults for these attributes are equivalent to the base GPT2 model.
|
| 149 |
+
|
| 150 |
+
d_model: int = 768
|
| 151 |
+
"""
|
| 152 |
+
The hidden size of the model.
|
| 153 |
+
"""
|
| 154 |
+
|
| 155 |
+
n_heads: int = 12
|
| 156 |
+
"""
|
| 157 |
+
The number of self-attention heads.
|
| 158 |
+
"""
|
| 159 |
+
|
| 160 |
+
n_kv_heads: Optional[int] = None
|
| 161 |
+
"""
|
| 162 |
+
The number of heads to use for keys and values. Defaults to `n_heads`.
|
| 163 |
+
Set this to ``None`` or ``n_heads`` for normal multi-head attention.
|
| 164 |
+
Set this to 1 for multi-query attention.
|
| 165 |
+
Set it to some in-between value for Llama2-style grouped query attention.
|
| 166 |
+
"""
|
| 167 |
+
|
| 168 |
+
n_layers: int = 12
|
| 169 |
+
"""
|
| 170 |
+
The number of layers/blocks.
|
| 171 |
+
"""
|
| 172 |
+
|
| 173 |
+
mlp_ratio: int = 4
|
| 174 |
+
"""
|
| 175 |
+
The ratio of the inner MLP dimensionality to ``d_model``.
|
| 176 |
+
This is only used when ``mlp_hidden_size`` is not set.
|
| 177 |
+
"""
|
| 178 |
+
|
| 179 |
+
mlp_hidden_size: Optional[int] = None
|
| 180 |
+
"""
|
| 181 |
+
Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
|
| 182 |
+
"""
|
| 183 |
+
|
| 184 |
+
activation_type: ActivationType = ActivationType.swiglu
|
| 185 |
+
"""
|
| 186 |
+
The activation function to use within the MLP layers.
|
| 187 |
+
"""
|
| 188 |
+
|
| 189 |
+
block_type: BlockType = BlockType.sequential
|
| 190 |
+
"""
|
| 191 |
+
The transformer block implementation.
|
| 192 |
+
"""
|
| 193 |
+
|
| 194 |
+
block_group_size: int = 1
|
| 195 |
+
"""
|
| 196 |
+
The number of blocks to group together into a single parent block.
|
| 197 |
+
This has no affect on the number of parameters in the model and is only used to wrap groups
|
| 198 |
+
of blocks together with a single FSDP wrapper during training.
|
| 199 |
+
"""
|
| 200 |
+
|
| 201 |
+
alibi: bool = False
|
| 202 |
+
"""
|
| 203 |
+
If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
|
| 204 |
+
"""
|
| 205 |
+
|
| 206 |
+
alibi_bias_max: float = 8.0
|
| 207 |
+
"""
|
| 208 |
+
Maximum absolute value of ALiBi bias.
|
| 209 |
+
"""
|
| 210 |
+
|
| 211 |
+
rope: bool = False
|
| 212 |
+
"""
|
| 213 |
+
Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
|
| 214 |
+
"""
|
| 215 |
+
|
| 216 |
+
rope_full_precision: bool = True
|
| 217 |
+
"""
|
| 218 |
+
If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
|
| 219 |
+
apply RoPE at the precision of the input.
|
| 220 |
+
"""
|
| 221 |
+
|
| 222 |
+
flash_attention: bool = False
|
| 223 |
+
"""
|
| 224 |
+
If ``True``, use ``FlashAttention``.
|
| 225 |
+
"""
|
| 226 |
+
|
| 227 |
+
attention_dropout: float = 0.1
|
| 228 |
+
"""
|
| 229 |
+
The dropout probability within the attention modules.
|
| 230 |
+
"""
|
| 231 |
+
|
| 232 |
+
multi_query_attention: Optional[bool] = None
|
| 233 |
+
"""
|
| 234 |
+
Use the Multi-Query formulation of attention used in PaLM. This reduces the number of parameters
|
| 235 |
+
and is more efficient during inference.
|
| 236 |
+
"""
|
| 237 |
+
|
| 238 |
+
attention_layer_norm: bool = False
|
| 239 |
+
"""
|
| 240 |
+
Apply layer norm to the keys and queries within the attention mechanism.
|
| 241 |
+
This can help stabilize training.
|
| 242 |
+
"""
|
| 243 |
+
|
| 244 |
+
residual_dropout: float = 0.1
|
| 245 |
+
"""
|
| 246 |
+
The dropout probability for the MLP and attention output within each block.
|
| 247 |
+
"""
|
| 248 |
+
|
| 249 |
+
embedding_dropout: float = 0.1
|
| 250 |
+
"""
|
| 251 |
+
The dropout probability for embeddings.
|
| 252 |
+
"""
|
| 253 |
+
|
| 254 |
+
input_emb_norm: bool = False
|
| 255 |
+
"""
|
| 256 |
+
An input hidden_states norm implementation by gemmma.
|
| 257 |
+
"""
|
| 258 |
+
|
| 259 |
+
layer_norm_type: LayerNormType = LayerNormType.default
|
| 260 |
+
"""
|
| 261 |
+
The layernorm implementation to use.
|
| 262 |
+
"""
|
| 263 |
+
|
| 264 |
+
layer_norm_with_affine: bool = True
|
| 265 |
+
"""
|
| 266 |
+
Whether to include bias and weight parameters for the layer norms.
|
| 267 |
+
This only affects layer norms that are immediately followed by a linear layer in the forward pass,
|
| 268 |
+
so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
|
| 269 |
+
to ``False``.
|
| 270 |
+
"""
|
| 271 |
+
|
| 272 |
+
rms_norm_eps: float = 1e-05
|
| 273 |
+
"""
|
| 274 |
+
The rms layernorm eps param.
|
| 275 |
+
"""
|
| 276 |
+
|
| 277 |
+
attention_layer_norm_with_affine: bool = True
|
| 278 |
+
"""
|
| 279 |
+
Toggle affine transform for the QK norms.
|
| 280 |
+
"""
|
| 281 |
+
|
| 282 |
+
max_sequence_length: int = 1024
|
| 283 |
+
"""
|
| 284 |
+
The maximum input sequence length supported by the model.
|
| 285 |
+
"""
|
| 286 |
+
|
| 287 |
+
rope_theta: float = 10000.0
|
| 288 |
+
"""
|
| 289 |
+
The rope base param.
|
| 290 |
+
"""
|
| 291 |
+
|
| 292 |
+
include_qkv_bias: Optional[bool] = False
|
| 293 |
+
"""
|
| 294 |
+
Whether or not to include bias parameters in qkv linear layers.
|
| 295 |
+
"""
|
| 296 |
+
|
| 297 |
+
include_bias: bool = False
|
| 298 |
+
"""
|
| 299 |
+
Whether or not to include bias parameters in linear layers.
|
| 300 |
+
In PaLM, they got rid of all bias terms because they found that large
|
| 301 |
+
models tend to have near 0 bias terms anyway.
|
| 302 |
+
"""
|
| 303 |
+
|
| 304 |
+
bias_for_layer_norm: Optional[bool] = None
|
| 305 |
+
"""
|
| 306 |
+
Whether or not to include bias parameters in layer norm.
|
| 307 |
+
This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
|
| 308 |
+
layer norm.
|
| 309 |
+
When this is None (the default), it inherits the setting from include_bias.
|
| 310 |
+
"""
|
| 311 |
+
|
| 312 |
+
scale_logits: bool = False
|
| 313 |
+
"""
|
| 314 |
+
If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
|
| 315 |
+
"""
|
| 316 |
+
|
| 317 |
+
vocab_size: int = 50257
|
| 318 |
+
"""
|
| 319 |
+
Vocabulary size of the model.
|
| 320 |
+
"""
|
| 321 |
+
|
| 322 |
+
embedding_size: Optional[int] = 50304
|
| 323 |
+
"""
|
| 324 |
+
The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
|
| 325 |
+
to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
|
| 326 |
+
next multiple of 128 that's greater than ``vocab_size`` can improve throughput
|
| 327 |
+
substantially.
|
| 328 |
+
"""
|
| 329 |
+
|
| 330 |
+
weight_tying: bool = True
|
| 331 |
+
"""
|
| 332 |
+
Whether to tie output linear weights to the input embedding.
|
| 333 |
+
"""
|
| 334 |
+
|
| 335 |
+
eos_token_id: int = 50256
|
| 336 |
+
"""
|
| 337 |
+
The ID of the end-of-sentence special token.
|
| 338 |
+
"""
|
| 339 |
+
|
| 340 |
+
pad_token_id: int = 50256
|
| 341 |
+
"""
|
| 342 |
+
The ID of the token to use for padding. Defaults to the ID of the EOS token.
|
| 343 |
+
"""
|
| 344 |
+
|
| 345 |
+
mask_token_id: Optional[int] = 50256
|
| 346 |
+
"""
|
| 347 |
+
The ID of the token to use for mask token. Defaults to the ID of the EOS token.
|
| 348 |
+
"""
|
| 349 |
+
|
| 350 |
+
init_device: Optional[str] = None
|
| 351 |
+
"""
|
| 352 |
+
The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
|
| 353 |
+
"""
|
| 354 |
+
|
| 355 |
+
init_fn: InitFnType = InitFnType.normal
|
| 356 |
+
"""
|
| 357 |
+
The weight initialization strategy.
|
| 358 |
+
"""
|
| 359 |
+
|
| 360 |
+
init_std: float = 0.02
|
| 361 |
+
"""
|
| 362 |
+
The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
|
| 363 |
+
as "normal".
|
| 364 |
+
"""
|
| 365 |
+
|
| 366 |
+
init_cutoff_factor: Optional[float] = None
|
| 367 |
+
"""
|
| 368 |
+
A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
|
| 369 |
+
as "normal". Setting this to None means values are not cutoff.
|
| 370 |
+
"""
|
| 371 |
+
|
| 372 |
+
precision: Optional[str] = None
|
| 373 |
+
"""
|
| 374 |
+
Precision used to train/evaluate with. You shouldn't set this directly.
|
| 375 |
+
See :data:`TrainConfig.precision` instead.
|
| 376 |
+
"""
|
| 377 |
+
|
| 378 |
+
engram_config: Optional[EngramConfig] = None
|
| 379 |
+
|
| 380 |
+
@property
|
| 381 |
+
def effective_n_kv_heads(self) -> int:
|
| 382 |
+
if self.n_kv_heads is None:
|
| 383 |
+
if self.multi_query_attention is True:
|
| 384 |
+
return 1
|
| 385 |
+
else:
|
| 386 |
+
return self.n_heads
|
| 387 |
+
else:
|
| 388 |
+
if self.multi_query_attention is None:
|
| 389 |
+
return self.n_kv_heads
|
| 390 |
+
if self.multi_query_attention:
|
| 391 |
+
n_kv_heads_should_be = 1
|
| 392 |
+
else:
|
| 393 |
+
n_kv_heads_should_be = self.n_heads
|
| 394 |
+
if self.n_kv_heads == n_kv_heads_should_be:
|
| 395 |
+
return n_kv_heads_should_be
|
| 396 |
+
else:
|
| 397 |
+
raise Exception(
|
| 398 |
+
"You can't set `multi_query_attention` and `n_kv_heads` at the same time."
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
class ActivationCheckpointingStrategy(StrEnum):
|
| 402 |
+
whole_layer = "whole_layer"
|
| 403 |
+
"""
|
| 404 |
+
Checkpoint every transformer layer.
|
| 405 |
+
"""
|
| 406 |
+
|
| 407 |
+
one_in_two = "one_in_two"
|
| 408 |
+
"""
|
| 409 |
+
Checkpoint one in two transformer layers.
|
| 410 |
+
"""
|
| 411 |
+
|
| 412 |
+
one_in_three = "one_in_three"
|
| 413 |
+
"""
|
| 414 |
+
Checkpoint one in three transformer layers.
|
| 415 |
+
"""
|
| 416 |
+
|
| 417 |
+
one_in_four = "one_in_four"
|
| 418 |
+
"""
|
| 419 |
+
Checkpoint one in four transformer layers.
|
| 420 |
+
"""
|
| 421 |
+
|
| 422 |
+
two_in_three = "two_in_three"
|
| 423 |
+
"""
|
| 424 |
+
Checkpoint two out of every three transformer layers.
|
| 425 |
+
"""
|
| 426 |
+
|
| 427 |
+
three_in_four = "three_in_four"
|
| 428 |
+
"""
|
| 429 |
+
Checkpoint three out of four of every transformer layers.
|
| 430 |
+
"""
|
| 431 |
+
|
| 432 |
+
four_in_five = "four_in_five"
|
| 433 |
+
"""
|
| 434 |
+
Checkpoint four out of five of every transformer layers.
|
| 435 |
+
"""
|
| 436 |
+
|
| 437 |
+
nine_in_ten = "nine_in_ten"
|
| 438 |
+
"""
|
| 439 |
+
Checkpoint nine out of ten of every transformer layers.
|
| 440 |
+
"""
|
| 441 |
+
|
| 442 |
+
fine_grained = "fine_grained"
|
| 443 |
+
"""
|
| 444 |
+
Focus checkpointing on where it is cheap to recompute and saves most memory.
|
| 445 |
+
"""
|
| 446 |
+
|
| 447 |
+
class LLaDAConfig(PretrainedConfig):
|
| 448 |
+
model_type = "llada"
|
| 449 |
+
keys_to_ignore_at_inference = ["past_key_values"] # TODO: confirm
|
| 450 |
+
|
| 451 |
+
def __init__(self, use_cache: bool = False, **kwargs):
|
| 452 |
+
model_config = ModelConfig()
|
| 453 |
+
all_kwargs = model_config.__dict__
|
| 454 |
+
all_kwargs.update(kwargs)
|
| 455 |
+
all_kwargs.update({"use_cache": use_cache})
|
| 456 |
+
all_kwargs.update(
|
| 457 |
+
{
|
| 458 |
+
"architectures": all_kwargs.get("architectures", ["LLaDAModelLM"])
|
| 459 |
+
}
|
| 460 |
+
)
|
| 461 |
+
super().__init__(**all_kwargs)
|
| 462 |
+
|
| 463 |
+
@property
|
| 464 |
+
def num_attention_heads(self):
|
| 465 |
+
return self.n_heads
|
| 466 |
+
|
| 467 |
+
@property
|
| 468 |
+
def num_hidden_layers(self):
|
| 469 |
+
return self.n_layers
|
| 470 |
+
|
| 471 |
+
@property
|
| 472 |
+
def hidden_size(self):
|
| 473 |
+
return self.d_model
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
# Register the config class so that it is available for transformer pipelines, auto-loading etc.
|
| 477 |
+
AutoConfig.register("llada", LLaDAConfig)
|
example.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tokenizers import Tokenizer
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import time
|
| 5 |
+
import os
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def process_string_into_pairs(input_str: str) -> list[str]:
|
| 10 |
+
result = []
|
| 11 |
+
i = 0
|
| 12 |
+
n = len(input_str)
|
| 13 |
+
|
| 14 |
+
while i < n:
|
| 15 |
+
char = input_str[i]
|
| 16 |
+
|
| 17 |
+
# 检查当前字符是否为小写字母
|
| 18 |
+
if "a" <= char <= "z":
|
| 19 |
+
# 检查是否有下一个字符,并且下一个字符也是小写字母(配对情况)
|
| 20 |
+
if i + 1 < n and "a" <= input_str[i + 1] <= "z":
|
| 21 |
+
result.append(char + input_str[i + 1])
|
| 22 |
+
i += 2 # 跳过两个字符
|
| 23 |
+
# 检查是否有下一个字符,并且下一个字符是空格(落单小写字母+空格 的特殊情况)
|
| 24 |
+
elif i + 1 < n and input_str[i + 1] == " ":
|
| 25 |
+
result.append(char)
|
| 26 |
+
i += 2 # 跳过当前字母和后面的空格
|
| 27 |
+
# 其他情况(落单小写字母,后面是其他字符或已到末尾)
|
| 28 |
+
else:
|
| 29 |
+
result.append(char)
|
| 30 |
+
i += 1 # 只跳过当前一个字符
|
| 31 |
+
# 如果当前字符不是小写字母
|
| 32 |
+
else:
|
| 33 |
+
result.append(char)
|
| 34 |
+
i += 1 # 只跳过当前一个字符
|
| 35 |
+
|
| 36 |
+
return result
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_mask_from_string(input_str: str, tokenizer) -> torch.Tensor:
|
| 40 |
+
pairs = process_string_into_pairs(input_str)
|
| 41 |
+
masks = [
|
| 42 |
+
f"<|mask_{pair}|>" if all(ord(i) < 128 for i in pair) else pair
|
| 43 |
+
for pair in pairs
|
| 44 |
+
]
|
| 45 |
+
mask_tensor = torch.tensor(
|
| 46 |
+
[tokenizer.token_to_id(mask) for mask in masks], dtype=torch.long
|
| 47 |
+
)
|
| 48 |
+
return mask_tensor
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def inference(model, input_str: str, tokenizer, device, threshold=0.9):
|
| 52 |
+
model.eval()
|
| 53 |
+
|
| 54 |
+
# Initialize NgramHashMapping
|
| 55 |
+
engram_cfg = model.config.engram_config
|
| 56 |
+
hash_mapping = None
|
| 57 |
+
if engram_cfg is not None:
|
| 58 |
+
from modeling_llada_engram import ModelConfig, EngramConfig, NgramHashMapping
|
| 59 |
+
from dataclasses import fields
|
| 60 |
+
# Prepare ModelConfig for NgramHashMapping
|
| 61 |
+
backbone_config_dict = model.config.to_dict()
|
| 62 |
+
# Filter out keys not in ModelConfig if necessary, but ModelConfig usually matches LLaDAConfig fields
|
| 63 |
+
backbone_config = ModelConfig(**{k: v for k, v in backbone_config_dict.items() if k in [f.name for f in fields(ModelConfig)]})
|
| 64 |
+
|
| 65 |
+
hash_mapping = NgramHashMapping(
|
| 66 |
+
engram_vocab_size = engram_cfg.get('engram_vocab_size', [129280*5, 129280*5]),
|
| 67 |
+
max_ngram_size = engram_cfg.get('max_ngram_size', 3),
|
| 68 |
+
n_embed_per_ngram = engram_cfg.get('n_embed_per_ngram', 512),
|
| 69 |
+
n_head_per_ngram = engram_cfg.get('n_head_per_ngram', 8),
|
| 70 |
+
layer_ids = engram_cfg.get('layer_ids', [1, 15]),
|
| 71 |
+
pad_id = engram_cfg.get('pad_id', 2),
|
| 72 |
+
seed = engram_cfg.get('seed', 0),
|
| 73 |
+
config = backbone_config,
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
with torch.no_grad():
|
| 77 |
+
mask_tensor = get_mask_from_string(input_str, tokenizer).unsqueeze(0).to(device)
|
| 78 |
+
# is_masked = torch.ones(mask_tensor.shape, dtype=torch.bool, device=device)
|
| 79 |
+
is_masked = mask_tensor >= tokenizer.token_to_id("<|mask|>")
|
| 80 |
+
rounds = 0
|
| 81 |
+
while is_masked.any():
|
| 82 |
+
rounds += 1
|
| 83 |
+
|
| 84 |
+
output = model(input_ids=mask_tensor)[0]
|
| 85 |
+
# Logit to probability
|
| 86 |
+
output = torch.softmax(output, dim=-1)
|
| 87 |
+
unmasked_any = False
|
| 88 |
+
prob_info = []
|
| 89 |
+
|
| 90 |
+
most_certain_token = (0, 0, 0) # (probability, index, token_id)
|
| 91 |
+
# Check each token that still is_masked
|
| 92 |
+
for i in range(mask_tensor.shape[1]):
|
| 93 |
+
if is_masked[0, i]:
|
| 94 |
+
# Get the token with the highest probability
|
| 95 |
+
predicted_token = output[0, i].argmax().item()
|
| 96 |
+
prob_info.append(
|
| 97 |
+
f"{output[0, i, predicted_token].item():.2f} {tokenizer.id_to_token(predicted_token)}"
|
| 98 |
+
)
|
| 99 |
+
most_certain_token = max(
|
| 100 |
+
most_certain_token,
|
| 101 |
+
(output[0, i, predicted_token].item(), i, predicted_token)
|
| 102 |
+
)
|
| 103 |
+
# If the probability is above the threshold, replace the mask
|
| 104 |
+
if output[0, i, predicted_token].item() > threshold:
|
| 105 |
+
mask_tensor[0, i] = predicted_token
|
| 106 |
+
is_masked[0, i] = False
|
| 107 |
+
unmasked_any = True
|
| 108 |
+
else:
|
| 109 |
+
prob_info.append("")
|
| 110 |
+
if not unmasked_any:
|
| 111 |
+
# Unmask the most certain one
|
| 112 |
+
mask_tensor[0, most_certain_token[1]] = most_certain_token[2]
|
| 113 |
+
is_masked[0, most_certain_token[1]] = False
|
| 114 |
+
|
| 115 |
+
masked_str = "".join(
|
| 116 |
+
(
|
| 117 |
+
tokenizer.id_to_token(mask_tensor[0, i].item())
|
| 118 |
+
if not is_masked[0, i]
|
| 119 |
+
else tokenizer.id_to_token(mask_tensor[0, i].item())[7:-2]
|
| 120 |
+
)
|
| 121 |
+
for i in range(mask_tensor.shape[1])
|
| 122 |
+
)
|
| 123 |
+
print(masked_str)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
if __name__ == "__main__":
|
| 127 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 128 |
+
tokenizer = Tokenizer.from_file("tokenizer.json")
|
| 129 |
+
|
| 130 |
+
# Load from local directory using AutoModel
|
| 131 |
+
# Note: Ensure you have transformers installed and trust_remote_code=True
|
| 132 |
+
try:
|
| 133 |
+
from transformers import AutoModelForCausalLM
|
| 134 |
+
model = AutoModelForCausalLM.from_pretrained(".", trust_remote_code=True).to(device)
|
| 135 |
+
except Exception as e:
|
| 136 |
+
print(f"Failed to load with AutoModel: {e}")
|
| 137 |
+
print("Falling back to manual loading (if needed, but prefer AutoModel for validation)")
|
| 138 |
+
# Fallback code removed for clarity as we want to enforce AutoModel structure
|
| 139 |
+
raise e
|
| 140 |
+
|
| 141 |
+
# To bfloat16
|
| 142 |
+
model = model.to(torch.bfloat16) if device.type == "cuda" else model.float()
|
| 143 |
+
print("Loaded model. Parameters:", sum(p.numel() for p in model.parameters()))
|
| 144 |
+
|
| 145 |
+
threshold = 0.9
|
| 146 |
+
|
| 147 |
+
while True:
|
| 148 |
+
input_str = input("Enter a string to process: ")
|
| 149 |
+
inference(model, input_str, tokenizer, device, threshold=threshold)
|
| 150 |
+
print("") # 空行分隔
|
| 151 |
+
|
| 152 |
+
# Input example: nhkzotdgjvdmleunkmiekz。
|
| 153 |
+
# Output: 黄河是中华民族的母亲河。
|
| 154 |
+
|
| 155 |
+
# Input example: mdflswsyelfl,eyxxmdswsyelfl,raxxmdelfl,otfixdzhfnjrugfoirmbisunswsyelfl。zhldxxdgun“mdfl”uvelflqhnvxtmdunkmpbofvjcjnnmdunsoirpbucheel。
|
| 156 |
+
# Output: 大型语言模型,也称大语言模型,简称大模型,是一种基于人工神经网络的语言模型。其名称中的“大型”指模型具有庞大的参数量以及巨大的训练数据规模。
|
| 157 |
+
|
| 158 |
+
# Input example: hgzz(Go o g l e )otfiwjpmrnxjuchkaf,hdidjifngmrnsdoovsoggn.
|
| 159 |
+
# Output:
|
| 160 |
+
# 谷歌(Google)是一家跨国科技公司,总部位于美国加州山景城.
|
| 161 |
+
# 谷歌(Google)是一家跨国科技公司,总部位于美国加州山景城。
|
| 162 |
+
# 谷歌(Google)是一家跨国科技公司,总部位于美国加州山景城。
|
| 163 |
+
# 谷歌(Google)是一家跨国科技公司,总部位于美国加州山景城。
|
| 164 |
+
# 谷歌(Google)是一家跨国科技公司,总部位于美国加州山景城。
|
| 165 |
+
|
| 166 |
+
# Input example: jxvuygvbotghtusvwtvbdt。auwvvbotcbghwhtkshdl?
|
| 167 |
+
# Output:
|
| 168 |
+
# 天对地,雨对风。大陆对长空。山lj对ke树,赤日对ljeb。雷隐隐,雾蒙蒙。日下对天中。风高秋月白,雨tq晚霞红。
|
| 169 |
+
# 天对地,雨对风。大陆对长空。山lj对杂树,赤日对苍eb。雷隐隐,雾蒙蒙。日下对天中。风高秋月白,雨雷晚霞红。
|
| 170 |
+
# 天对地,雨对风。大陆对长空。山lj对杂树,赤日对苍穹。雷隐隐,雾蒙蒙。日下对天中。风高秋月白,雨雷晚霞红。
|
| 171 |
+
# 天对地,雨对风。大陆对长空。山苍对杂树,赤日对苍穹。雷隐隐,雾蒙蒙。日下对天中。风高秋月白,雨雷晚霞红。
|
| 172 |
+
# (Expected Output: 天对地,雨对风。大陆对长空。山花对海树,赤日对苍穹。雷隐隐,雾蒙蒙。日下对天中。风高秋月白,雨霁晚霞红。)
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a0114ad14a6671ade8155e31d930bb1c19779dab6af574d139d11678d7152270
|
| 3 |
+
size 700907000
|
modeling_llada_engram.py
ADDED
|
@@ -0,0 +1,1895 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
import sys
|
| 6 |
+
from abc import abstractmethod
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
from functools import partial
|
| 9 |
+
from typing import (
|
| 10 |
+
Callable,
|
| 11 |
+
Dict,
|
| 12 |
+
Iterable,
|
| 13 |
+
List,
|
| 14 |
+
NamedTuple,
|
| 15 |
+
Optional,
|
| 16 |
+
Sequence,
|
| 17 |
+
Set,
|
| 18 |
+
Tuple,
|
| 19 |
+
cast,
|
| 20 |
+
)
|
| 21 |
+
from dataclasses import fields
|
| 22 |
+
from typing import List, Optional, Tuple, Union
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
import torch.backends.cuda
|
| 26 |
+
import torch.nn as nn
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
import torch.nn.utils.rnn as rnn_utils
|
| 29 |
+
from torch import einsum
|
| 30 |
+
from transformers import PreTrainedModel
|
| 31 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 32 |
+
from transformers.models.auto import AutoModel
|
| 33 |
+
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
| 34 |
+
from transformers.cache_utils import Cache
|
| 35 |
+
from sympy import isprime
|
| 36 |
+
import numpy as np
|
| 37 |
+
|
| 38 |
+
from configuration_llada_engram import (
|
| 39 |
+
EngramConfig,
|
| 40 |
+
LLaDAConfig,
|
| 41 |
+
StrEnum,
|
| 42 |
+
InitFnType,
|
| 43 |
+
ActivationType,
|
| 44 |
+
BlockType,
|
| 45 |
+
LayerNormType,
|
| 46 |
+
ModelConfig,
|
| 47 |
+
ActivationCheckpointingStrategy,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
if sys.version_info.minor > 8:
|
| 51 |
+
from collections.abc import MutableMapping
|
| 52 |
+
elif sys.version_info.minor == 8:
|
| 53 |
+
from typing import MutableMapping
|
| 54 |
+
else:
|
| 55 |
+
raise SystemExit("This script supports Python 3.8 or higher")
|
| 56 |
+
|
| 57 |
+
__all__ = [
|
| 58 |
+
"LayerNormBase",
|
| 59 |
+
"LayerNorm",
|
| 60 |
+
"RMSLayerNorm",
|
| 61 |
+
"GemmaRMSLayerNorm",
|
| 62 |
+
"RotaryEmbedding",
|
| 63 |
+
"Activation",
|
| 64 |
+
"GELU",
|
| 65 |
+
"ReLU",
|
| 66 |
+
"SwiGLU",
|
| 67 |
+
"LLaDABlock",
|
| 68 |
+
"LLaDASequentialBlock",
|
| 69 |
+
"LLaDAModel",
|
| 70 |
+
"LLaDAOutput",
|
| 71 |
+
"LLaDAGenerateOutput",
|
| 72 |
+
]
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
log = logging.getLogger(__name__)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class ModuleType(StrEnum):
|
| 79 |
+
in_module = "in"
|
| 80 |
+
out_module = "out"
|
| 81 |
+
emb = "emb"
|
| 82 |
+
final_out = "final_out"
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def init_weights(
|
| 86 |
+
config: ModelConfig,
|
| 87 |
+
module: Union[nn.Linear, nn.Embedding],
|
| 88 |
+
d: Optional[int] = None,
|
| 89 |
+
layer_id: Optional[int] = None,
|
| 90 |
+
std_factor: float = 1.0,
|
| 91 |
+
type_of_module: Optional[ModuleType] = None,
|
| 92 |
+
) -> None:
|
| 93 |
+
"""
|
| 94 |
+
Initialize weights of a linear or embedding module.
|
| 95 |
+
:param config: The model config.
|
| 96 |
+
:param module: The linear or embedding submodule to initialize.
|
| 97 |
+
:param d: The effective input dimensionality of the weights. This could be smaller than the actual dimensions
|
| 98 |
+
for fused layers.
|
| 99 |
+
:param layer_id: When set, the standard deviation for the "mitchell" method will be adjusted by
|
| 100 |
+
``1 / sqrt(2 * (layer_id + 1))``.
|
| 101 |
+
"""
|
| 102 |
+
d = d if d is not None else config.d_model
|
| 103 |
+
if config.init_fn == InitFnType.normal:
|
| 104 |
+
std = config.init_std * std_factor
|
| 105 |
+
if config.init_cutoff_factor is not None:
|
| 106 |
+
cutoff_value = config.init_cutoff_factor * std
|
| 107 |
+
nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
|
| 108 |
+
else:
|
| 109 |
+
nn.init.normal_(module.weight, mean=0.0, std=std)
|
| 110 |
+
elif config.init_fn == InitFnType.mitchell:
|
| 111 |
+
std = std_factor / math.sqrt(d)
|
| 112 |
+
if layer_id is not None:
|
| 113 |
+
std = std / math.sqrt(2 * (layer_id + 1))
|
| 114 |
+
nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
|
| 115 |
+
elif config.init_fn == InitFnType.kaiming_normal:
|
| 116 |
+
nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
|
| 117 |
+
elif config.init_fn == InitFnType.fan_in:
|
| 118 |
+
std = std_factor / math.sqrt(d)
|
| 119 |
+
nn.init.normal_(module.weight, mean=0.0, std=std)
|
| 120 |
+
elif config.init_fn == InitFnType.full_megatron:
|
| 121 |
+
if type_of_module is None:
|
| 122 |
+
raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.")
|
| 123 |
+
|
| 124 |
+
cutoff_factor = config.init_cutoff_factor
|
| 125 |
+
if cutoff_factor is None:
|
| 126 |
+
cutoff_factor = 3
|
| 127 |
+
|
| 128 |
+
if type_of_module == ModuleType.in_module:
|
| 129 |
+
# for att_proj (same as QKV), ff_proj
|
| 130 |
+
std = config.init_std
|
| 131 |
+
elif type_of_module == ModuleType.out_module:
|
| 132 |
+
# for attn_out, ff_out
|
| 133 |
+
std = config.init_std / math.sqrt(2.0 * config.n_layers)
|
| 134 |
+
elif type_of_module == ModuleType.emb:
|
| 135 |
+
# positional embeddings (wpe)
|
| 136 |
+
# token embeddings (wte)
|
| 137 |
+
std = config.init_std
|
| 138 |
+
elif type_of_module == ModuleType.final_out:
|
| 139 |
+
# final output (ff_out)
|
| 140 |
+
std = config.d_model**-0.5
|
| 141 |
+
else:
|
| 142 |
+
raise RuntimeError(f"Unknown module type '{type_of_module}'")
|
| 143 |
+
nn.init.trunc_normal_(
|
| 144 |
+
module.weight,
|
| 145 |
+
mean=0.0,
|
| 146 |
+
std=std,
|
| 147 |
+
a=-cutoff_factor * std,
|
| 148 |
+
b=cutoff_factor * std,
|
| 149 |
+
)
|
| 150 |
+
else:
|
| 151 |
+
raise NotImplementedError(config.init_fn)
|
| 152 |
+
|
| 153 |
+
if isinstance(module, nn.Linear):
|
| 154 |
+
if module.bias is not None:
|
| 155 |
+
nn.init.zeros_(module.bias)
|
| 156 |
+
|
| 157 |
+
if config.init_fn == InitFnType.normal and getattr(module, "_is_residual", False):
|
| 158 |
+
with torch.no_grad():
|
| 159 |
+
module.weight.div_(math.sqrt(2 * config.n_layers))
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
|
| 163 |
+
"""
|
| 164 |
+
Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
|
| 165 |
+
is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``.
|
| 166 |
+
"""
|
| 167 |
+
if check_neg_inf:
|
| 168 |
+
x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
|
| 169 |
+
if check_pos_inf:
|
| 170 |
+
x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def activation_checkpoint_function(cfg: ModelConfig):
|
| 174 |
+
preserve_rng_state = (
|
| 175 |
+
(cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and (cfg.residual_dropout == 0.0)
|
| 176 |
+
)
|
| 177 |
+
from torch.utils.checkpoint import checkpoint
|
| 178 |
+
|
| 179 |
+
return partial(
|
| 180 |
+
checkpoint,
|
| 181 |
+
preserve_rng_state=preserve_rng_state,
|
| 182 |
+
use_reentrant=False,
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class BufferCache(dict, MutableMapping[str, torch.Tensor]):
|
| 187 |
+
"""
|
| 188 |
+
Cache for attention biases and other things that would normally be stored as buffers.
|
| 189 |
+
We avoid using buffers because we've run into various issues doing so with FSDP.
|
| 190 |
+
In general it appears the way FSDP handles buffers is not well-defined.
|
| 191 |
+
It doesn't shard them but apparently it does synchronize them across processes, which we want to avoid
|
| 192 |
+
since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into
|
| 193 |
+
NaNs when they're synchronized due to casting or some other issue.
|
| 194 |
+
"""
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def _non_meta_init_device(config: ModelConfig) -> torch.device:
|
| 198 |
+
if config.init_device is not None and config.init_device != "meta":
|
| 199 |
+
return torch.device(config.init_device)
|
| 200 |
+
else:
|
| 201 |
+
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class Dropout(nn.Dropout):
|
| 205 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
| 206 |
+
if self.p == 0.0:
|
| 207 |
+
return input
|
| 208 |
+
else:
|
| 209 |
+
return F.dropout(input, self.p, self.training, self.inplace)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class LayerNormBase(nn.Module):
|
| 213 |
+
def __init__(
|
| 214 |
+
self,
|
| 215 |
+
config: ModelConfig,
|
| 216 |
+
*,
|
| 217 |
+
size: Optional[int] = None,
|
| 218 |
+
elementwise_affine: Optional[bool] = True,
|
| 219 |
+
eps: float = 1e-05,
|
| 220 |
+
):
|
| 221 |
+
super().__init__()
|
| 222 |
+
self.config = config
|
| 223 |
+
self.eps = eps
|
| 224 |
+
self.normalized_shape = (size or config.d_model,)
|
| 225 |
+
if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
|
| 226 |
+
self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device))
|
| 227 |
+
use_bias = self.config.bias_for_layer_norm
|
| 228 |
+
if use_bias is None:
|
| 229 |
+
use_bias = self.config.include_bias
|
| 230 |
+
if use_bias:
|
| 231 |
+
self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device))
|
| 232 |
+
else:
|
| 233 |
+
self.register_parameter("bias", None)
|
| 234 |
+
else:
|
| 235 |
+
self.register_parameter("bias", None)
|
| 236 |
+
self.register_parameter("weight", None)
|
| 237 |
+
|
| 238 |
+
@abstractmethod
|
| 239 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 240 |
+
raise NotImplementedError
|
| 241 |
+
|
| 242 |
+
@classmethod
|
| 243 |
+
def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> LayerNormBase:
|
| 244 |
+
if config.layer_norm_type == LayerNormType.default:
|
| 245 |
+
return LayerNorm(config, size=size, low_precision=False, **kwargs)
|
| 246 |
+
elif config.layer_norm_type == LayerNormType.low_precision:
|
| 247 |
+
return LayerNorm(config, size=size, low_precision=True, **kwargs)
|
| 248 |
+
elif config.layer_norm_type == LayerNormType.rms:
|
| 249 |
+
return RMSLayerNorm(config, size=size, **kwargs)
|
| 250 |
+
elif config.layer_norm_type == LayerNormType.gemma_rms:
|
| 251 |
+
return GemmaRMSLayerNorm(config, size=size, **kwargs)
|
| 252 |
+
else:
|
| 253 |
+
raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'")
|
| 254 |
+
|
| 255 |
+
def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
|
| 256 |
+
# NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
|
| 257 |
+
# `is_autocast_cpu_enabled()` for CPU autocast.
|
| 258 |
+
# See https://github.com/pytorch/pytorch/issues/110966.
|
| 259 |
+
if tensor.device.type == "cuda" and torch.is_autocast_enabled():
|
| 260 |
+
return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_gpu_dtype())
|
| 261 |
+
elif tensor.device.type == "cpu" and torch.is_autocast_cpu_enabled():
|
| 262 |
+
return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_cpu_dtype())
|
| 263 |
+
else:
|
| 264 |
+
return tensor
|
| 265 |
+
|
| 266 |
+
def reset_parameters(self):
|
| 267 |
+
if self.weight is not None:
|
| 268 |
+
torch.nn.init.ones_(self.weight) # type: ignore
|
| 269 |
+
if self.bias is not None:
|
| 270 |
+
torch.nn.init.zeros_(self.bias) # type: ignore
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class LayerNorm(LayerNormBase):
|
| 274 |
+
"""
|
| 275 |
+
The default :class:`LayerNorm` implementation which can optionally run in low precision.
|
| 276 |
+
"""
|
| 277 |
+
|
| 278 |
+
def __init__(
|
| 279 |
+
self,
|
| 280 |
+
config: ModelConfig,
|
| 281 |
+
size: Optional[int] = None,
|
| 282 |
+
low_precision: bool = False,
|
| 283 |
+
elementwise_affine: Optional[bool] = None,
|
| 284 |
+
eps: float = 1e-05,
|
| 285 |
+
):
|
| 286 |
+
super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
|
| 287 |
+
self.low_precision = low_precision
|
| 288 |
+
|
| 289 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 290 |
+
if self.low_precision:
|
| 291 |
+
module_device = x.device
|
| 292 |
+
downcast_x = self._cast_if_autocast_enabled(x)
|
| 293 |
+
downcast_weight = (
|
| 294 |
+
self._cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
|
| 295 |
+
)
|
| 296 |
+
downcast_bias = self._cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
|
| 297 |
+
with torch.autocast(enabled=False, device_type=module_device.type):
|
| 298 |
+
return F.layer_norm(
|
| 299 |
+
downcast_x, self.normalized_shape, weight=downcast_weight, bias=downcast_bias, eps=self.eps
|
| 300 |
+
)
|
| 301 |
+
else:
|
| 302 |
+
return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
class RMSLayerNorm(LayerNormBase):
|
| 306 |
+
"""
|
| 307 |
+
RMS layer norm, a simplified :class:`LayerNorm` implementation
|
| 308 |
+
"""
|
| 309 |
+
|
| 310 |
+
def __init__(
|
| 311 |
+
self,
|
| 312 |
+
config: ModelConfig,
|
| 313 |
+
size: Optional[int] = None,
|
| 314 |
+
elementwise_affine: Optional[bool] = None,
|
| 315 |
+
eps: float = 1e-5,
|
| 316 |
+
):
|
| 317 |
+
super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps)
|
| 318 |
+
|
| 319 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 320 |
+
with torch.autocast(enabled=False, device_type=x.device.type):
|
| 321 |
+
og_dtype = x.dtype
|
| 322 |
+
x = x.to(torch.float32)
|
| 323 |
+
variance = x.pow(2).mean(-1, keepdim=True)
|
| 324 |
+
x = x * torch.rsqrt(variance + self.eps)
|
| 325 |
+
x = x.to(og_dtype)
|
| 326 |
+
|
| 327 |
+
if self.weight is not None:
|
| 328 |
+
if self.bias is not None:
|
| 329 |
+
return self.weight * x + self.bias
|
| 330 |
+
else:
|
| 331 |
+
return self.weight * x
|
| 332 |
+
else:
|
| 333 |
+
return x
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
class GemmaRMSLayerNorm(LayerNormBase):
|
| 337 |
+
"""
|
| 338 |
+
Gemma RMS layer norm, a simplified :class:`LayerNorm` implementation
|
| 339 |
+
"""
|
| 340 |
+
|
| 341 |
+
def __init__(
|
| 342 |
+
self,
|
| 343 |
+
config: ModelConfig,
|
| 344 |
+
size: Optional[int] = None,
|
| 345 |
+
elementwise_affine: Optional[bool] = None,
|
| 346 |
+
eps: float = 1e-5,
|
| 347 |
+
):
|
| 348 |
+
super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps)
|
| 349 |
+
|
| 350 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 351 |
+
with torch.autocast(enabled=False, device_type=x.device.type):
|
| 352 |
+
og_dtype = x.dtype
|
| 353 |
+
x = x.to(torch.float32)
|
| 354 |
+
variance = x.pow(2).mean(-1, keepdim=True)
|
| 355 |
+
x = x * torch.rsqrt(variance + self.eps)
|
| 356 |
+
x = x.to(og_dtype)
|
| 357 |
+
|
| 358 |
+
if self.weight is not None:
|
| 359 |
+
if self.bias is not None:
|
| 360 |
+
return x * (1 + self.weight) + self.bias
|
| 361 |
+
else:
|
| 362 |
+
return x * (1 + self.weight)
|
| 363 |
+
else:
|
| 364 |
+
return x
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
class RotaryEmbedding(nn.Module):
|
| 368 |
+
"""
|
| 369 |
+
[Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864).
|
| 370 |
+
"""
|
| 371 |
+
|
| 372 |
+
def __init__(self, config: ModelConfig, cache: BufferCache):
|
| 373 |
+
super().__init__()
|
| 374 |
+
self.config = config
|
| 375 |
+
self.__cache = cache
|
| 376 |
+
# Warm up cache.
|
| 377 |
+
self.rope_theta = config.rope_theta
|
| 378 |
+
self.get_rotary_embedding(config.max_sequence_length, _non_meta_init_device(config))
|
| 379 |
+
|
| 380 |
+
def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 381 |
+
if (
|
| 382 |
+
(pos_sin := self.__cache.get("rope_pos_sin")) is not None
|
| 383 |
+
and (pos_cos := self.__cache.get("rope_pos_cos")) is not None
|
| 384 |
+
and pos_sin.shape[-2] >= seq_len
|
| 385 |
+
and pos_cos.shape[-2] >= seq_len
|
| 386 |
+
):
|
| 387 |
+
if pos_sin.device != device:
|
| 388 |
+
pos_sin = pos_sin.to(device)
|
| 389 |
+
self.__cache["rope_pos_sin"] = pos_sin
|
| 390 |
+
if pos_cos.device != device:
|
| 391 |
+
pos_cos = pos_cos.to(device)
|
| 392 |
+
self.__cache["rope_pos_cos"] = pos_cos
|
| 393 |
+
return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]
|
| 394 |
+
|
| 395 |
+
with torch.autocast(device.type, enabled=False):
|
| 396 |
+
dim = self.config.d_model // self.config.n_heads
|
| 397 |
+
inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
|
| 398 |
+
seq = torch.arange(seq_len, device=device, dtype=torch.float)
|
| 399 |
+
freqs = einsum("i , j -> i j", seq, inv_freq)
|
| 400 |
+
positions = torch.cat((freqs, freqs), dim=-1)
|
| 401 |
+
pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
|
| 402 |
+
self.__cache["rope_pos_sin"] = pos_sin
|
| 403 |
+
self.__cache["rope_pos_cos"] = pos_cos
|
| 404 |
+
return pos_sin, pos_cos
|
| 405 |
+
|
| 406 |
+
def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
|
| 407 |
+
B, nh, T, hs = x.size()
|
| 408 |
+
x = x.view(B, nh, T, 2, hs // 2)
|
| 409 |
+
x1, x2 = x.unbind(dim=-2)
|
| 410 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 411 |
+
|
| 412 |
+
def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
| 413 |
+
return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)
|
| 414 |
+
|
| 415 |
+
def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 416 |
+
if self.config.rope_full_precision:
|
| 417 |
+
q_, k_ = q.float(), k.float()
|
| 418 |
+
else:
|
| 419 |
+
q_, k_ = q, k
|
| 420 |
+
|
| 421 |
+
with torch.autocast(q.device.type, enabled=False):
|
| 422 |
+
query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None
|
| 423 |
+
pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device)
|
| 424 |
+
pos_sin = pos_sin.type_as(q_)
|
| 425 |
+
pos_cos = pos_cos.type_as(q_)
|
| 426 |
+
q_ = self.apply_rotary_pos_emb(
|
| 427 |
+
pos_sin[:, :, key_len - query_len : key_len, :],
|
| 428 |
+
pos_cos[:, :, key_len - query_len : key_len, :],
|
| 429 |
+
q_,
|
| 430 |
+
)
|
| 431 |
+
k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_)
|
| 432 |
+
return q_.type_as(q), k_.type_as(k)
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
class Activation(nn.Module):
|
| 436 |
+
def __init__(self, config: ModelConfig):
|
| 437 |
+
super().__init__()
|
| 438 |
+
self.config = config
|
| 439 |
+
|
| 440 |
+
@abstractmethod
|
| 441 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 442 |
+
raise NotImplementedError
|
| 443 |
+
|
| 444 |
+
@property
|
| 445 |
+
@abstractmethod
|
| 446 |
+
def output_multiplier(self) -> float:
|
| 447 |
+
raise NotImplementedError
|
| 448 |
+
|
| 449 |
+
@classmethod
|
| 450 |
+
def build(cls, config: ModelConfig) -> Activation:
|
| 451 |
+
if config.activation_type == ActivationType.gelu:
|
| 452 |
+
return cast(Activation, GELU(approximate="none"))
|
| 453 |
+
elif config.activation_type == ActivationType.relu:
|
| 454 |
+
return cast(Activation, ReLU(inplace=False))
|
| 455 |
+
elif config.activation_type == ActivationType.silu:
|
| 456 |
+
return cast(Activation, SiLU(inplace=False))
|
| 457 |
+
elif config.activation_type == ActivationType.swiglu:
|
| 458 |
+
return SwiGLU(config)
|
| 459 |
+
else:
|
| 460 |
+
raise NotImplementedError(f"Unknown activation: '{config.activation_type}'")
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
class GELU(nn.GELU):
|
| 464 |
+
@property
|
| 465 |
+
def output_multiplier(self) -> float:
|
| 466 |
+
return 1.0
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
class ReLU(nn.ReLU):
|
| 470 |
+
@property
|
| 471 |
+
def output_multiplier(self) -> float:
|
| 472 |
+
return 1.0
|
| 473 |
+
|
| 474 |
+
class SiLU(nn.SiLU):
|
| 475 |
+
@property
|
| 476 |
+
def output_multiplier(self) -> float:
|
| 477 |
+
return 1.0
|
| 478 |
+
|
| 479 |
+
class SwiGLU(Activation):
|
| 480 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 481 |
+
x, gate = x.chunk(2, dim=-1)
|
| 482 |
+
return F.silu(gate) * x
|
| 483 |
+
|
| 484 |
+
@property
|
| 485 |
+
def output_multiplier(self) -> float:
|
| 486 |
+
return 0.5
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
|
| 490 |
+
att_bias = torch.triu(
|
| 491 |
+
torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
|
| 492 |
+
diagonal=1,
|
| 493 |
+
)
|
| 494 |
+
att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
|
| 495 |
+
return att_bias.view(1, 1, seq_len, seq_len) # type: ignore
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor:
|
| 499 |
+
if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len:
|
| 500 |
+
if causal_bias.device != device:
|
| 501 |
+
causal_bias = causal_bias.to(device)
|
| 502 |
+
cache["causal_attention_bias"] = causal_bias
|
| 503 |
+
return causal_bias
|
| 504 |
+
with torch.autocast(device.type, enabled=False):
|
| 505 |
+
causal_bias = causal_attention_bias(seq_len, device)
|
| 506 |
+
cache["causal_attention_bias"] = causal_bias
|
| 507 |
+
return causal_bias
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
|
| 511 |
+
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)
|
| 512 |
+
|
| 513 |
+
# shape: (1, 1, seq_len, seq_len)
|
| 514 |
+
alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
|
| 515 |
+
alibi_bias.abs_().mul_(-1)
|
| 516 |
+
|
| 517 |
+
# shape: (n_heads,)
|
| 518 |
+
m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
|
| 519 |
+
m.mul_(config.alibi_bias_max / config.n_heads)
|
| 520 |
+
|
| 521 |
+
# shape: (1, n_heads, seq_len, seq_len)
|
| 522 |
+
return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1))) # type: ignore
|
| 523 |
+
|
| 524 |
+
class ShortConv(nn.Module):
|
| 525 |
+
def __init__(
|
| 526 |
+
self,
|
| 527 |
+
hidden_size: int,
|
| 528 |
+
kernel_size: int = 7, # 修改默认值为 7
|
| 529 |
+
dilation: int = 1,
|
| 530 |
+
norm_eps: float = 1e-5,
|
| 531 |
+
hc_mult: int = 1,
|
| 532 |
+
activation: bool = True,
|
| 533 |
+
):
|
| 534 |
+
super().__init__()
|
| 535 |
+
self.activation = activation
|
| 536 |
+
self.kernel_size = kernel_size
|
| 537 |
+
self.dilation = dilation
|
| 538 |
+
|
| 539 |
+
# 针对奇数核(如7)的 Same Padding 计算: (K-1)/2
|
| 540 |
+
# K=7 -> padding=3
|
| 541 |
+
self.padding = (kernel_size - 1) // 2 * dilation
|
| 542 |
+
|
| 543 |
+
# 标准卷积,这次我们可以直接用 PyTorch 的 padding,
|
| 544 |
+
# 因为奇数核的 same padding 是对称的,不需要手动 F.pad
|
| 545 |
+
self.conv = nn.Conv1d(
|
| 546 |
+
in_channels=hidden_size,
|
| 547 |
+
out_channels=hidden_size,
|
| 548 |
+
kernel_size=kernel_size,
|
| 549 |
+
groups=hidden_size,
|
| 550 |
+
bias=False,
|
| 551 |
+
padding=self.padding, # 直接设置 padding
|
| 552 |
+
dilation=dilation,
|
| 553 |
+
)
|
| 554 |
+
|
| 555 |
+
self.norm = nn.RMSNorm(hidden_size, eps=norm_eps)
|
| 556 |
+
|
| 557 |
+
if self.activation:
|
| 558 |
+
self.act_fn = nn.SiLU()
|
| 559 |
+
|
| 560 |
+
self.reset_parameters()
|
| 561 |
+
|
| 562 |
+
def reset_parameters(self):
|
| 563 |
+
nn.init.zeros_(self.conv.weight)
|
| 564 |
+
|
| 565 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 566 |
+
# x: [B, L, D]
|
| 567 |
+
x_norm = self.norm(x)
|
| 568 |
+
|
| 569 |
+
# [B, L, D] -> [B, D, L]
|
| 570 |
+
x_bct = x_norm.transpose(1, 2)
|
| 571 |
+
|
| 572 |
+
# 卷积 (自动 padding)
|
| 573 |
+
y_bct = self.conv(x_bct)
|
| 574 |
+
|
| 575 |
+
if self.activation:
|
| 576 |
+
y_bct = self.act_fn(y_bct)
|
| 577 |
+
|
| 578 |
+
# [B, D, L] -> [B, L, D]
|
| 579 |
+
y = y_bct.transpose(1, 2).contiguous()
|
| 580 |
+
|
| 581 |
+
return y
|
| 582 |
+
|
| 583 |
+
def find_next_prime(start, seen_primes):
|
| 584 |
+
candidate = start + 1
|
| 585 |
+
while True:
|
| 586 |
+
if isprime(candidate) and candidate not in seen_primes:
|
| 587 |
+
return candidate
|
| 588 |
+
candidate += 1
|
| 589 |
+
|
| 590 |
+
class NgramHashMapping:
|
| 591 |
+
def __init__(
|
| 592 |
+
self,
|
| 593 |
+
engram_vocab_size,
|
| 594 |
+
max_ngram_size,
|
| 595 |
+
n_embed_per_ngram,
|
| 596 |
+
n_head_per_ngram,
|
| 597 |
+
layer_ids,
|
| 598 |
+
pad_id,
|
| 599 |
+
seed,
|
| 600 |
+
config: ModelConfig,
|
| 601 |
+
):
|
| 602 |
+
self.vocab_size_per_ngram = engram_vocab_size
|
| 603 |
+
self.max_ngram_size = max_ngram_size
|
| 604 |
+
self.n_embed_per_ngram = n_embed_per_ngram
|
| 605 |
+
self.n_head_per_ngram = n_head_per_ngram
|
| 606 |
+
self.pad_id = pad_id
|
| 607 |
+
self.layer_ids = layer_ids
|
| 608 |
+
|
| 609 |
+
self.tokenizer_vocab_size = config.vocab_size
|
| 610 |
+
|
| 611 |
+
max_long = np.iinfo(np.int64).max
|
| 612 |
+
M_max = int(max_long // self.tokenizer_vocab_size)
|
| 613 |
+
half_bound = max(1, M_max // 2)
|
| 614 |
+
PRIME_1 = 10007
|
| 615 |
+
|
| 616 |
+
self.layer_multipliers = {}
|
| 617 |
+
|
| 618 |
+
for layer_id in self.layer_ids:
|
| 619 |
+
base_seed = int(seed + PRIME_1 * int(layer_id))
|
| 620 |
+
g = np.random.default_rng(base_seed)
|
| 621 |
+
r = g.integers(
|
| 622 |
+
low=0,
|
| 623 |
+
high=half_bound,
|
| 624 |
+
size=(self.max_ngram_size,),
|
| 625 |
+
dtype=np.int64
|
| 626 |
+
)
|
| 627 |
+
multipliers = r * 2 + 1
|
| 628 |
+
self.layer_multipliers[layer_id] = multipliers
|
| 629 |
+
|
| 630 |
+
self.vocab_size_across_layers = self.calculate_vocab_size_across_layers()
|
| 631 |
+
|
| 632 |
+
def calculate_vocab_size_across_layers(self):
|
| 633 |
+
seen_primes = set()
|
| 634 |
+
vocab_size_across_layers = {}
|
| 635 |
+
|
| 636 |
+
for layer_id in self.layer_ids:
|
| 637 |
+
all_ngram_vocab_sizes = []
|
| 638 |
+
for ngram in range(2, self.max_ngram_size + 1):
|
| 639 |
+
current_ngram_heads_sizes = []
|
| 640 |
+
|
| 641 |
+
vocab_size = self.vocab_size_per_ngram[ngram - 2]
|
| 642 |
+
num_head = self.n_head_per_ngram
|
| 643 |
+
current_prime_search_start = vocab_size - 1
|
| 644 |
+
|
| 645 |
+
for _ in range(num_head):
|
| 646 |
+
found_prime = find_next_prime(
|
| 647 |
+
current_prime_search_start,
|
| 648 |
+
seen_primes
|
| 649 |
+
)
|
| 650 |
+
seen_primes.add(found_prime)
|
| 651 |
+
current_ngram_heads_sizes.append(found_prime)
|
| 652 |
+
current_prime_search_start = found_prime
|
| 653 |
+
|
| 654 |
+
all_ngram_vocab_sizes.append(current_ngram_heads_sizes)
|
| 655 |
+
vocab_size_across_layers[layer_id] = all_ngram_vocab_sizes
|
| 656 |
+
|
| 657 |
+
return vocab_size_across_layers
|
| 658 |
+
|
| 659 |
+
def _get_ngram_hashes(
|
| 660 |
+
self,
|
| 661 |
+
input_ids: np.ndarray,
|
| 662 |
+
layer_id: int,
|
| 663 |
+
) -> np.ndarray:
|
| 664 |
+
x = np.asarray(input_ids, dtype=np.int64)
|
| 665 |
+
B, T = x.shape
|
| 666 |
+
|
| 667 |
+
multipliers = self.layer_multipliers[layer_id]
|
| 668 |
+
|
| 669 |
+
def shift_k(k: int) -> np.ndarray:
|
| 670 |
+
if k == 0: return x
|
| 671 |
+
shifted = np.pad(x, ((0, 0), (k, 0)),
|
| 672 |
+
mode='constant', constant_values=self.pad_id)[:, :T]
|
| 673 |
+
return shifted
|
| 674 |
+
|
| 675 |
+
base_shifts = [shift_k(k) for k in range(self.max_ngram_size)]
|
| 676 |
+
|
| 677 |
+
all_hashes = []
|
| 678 |
+
|
| 679 |
+
for n in range(2, self.max_ngram_size + 1):
|
| 680 |
+
n_gram_index = n - 2
|
| 681 |
+
tokens = base_shifts[:n]
|
| 682 |
+
mix = (tokens[0] * multipliers[0])
|
| 683 |
+
for k in range(1, n):
|
| 684 |
+
mix = np.bitwise_xor(mix, tokens[k] * multipliers[k])
|
| 685 |
+
num_heads_for_this_ngram = self.n_head_per_ngram
|
| 686 |
+
head_vocab_sizes = self.vocab_size_across_layers[layer_id][n_gram_index]
|
| 687 |
+
|
| 688 |
+
for j in range(num_heads_for_this_ngram):
|
| 689 |
+
mod = int(head_vocab_sizes[j])
|
| 690 |
+
head_hash = mix % mod
|
| 691 |
+
all_hashes.append(head_hash.astype(np.int64, copy=False))
|
| 692 |
+
|
| 693 |
+
return np.stack(all_hashes, axis=2)
|
| 694 |
+
|
| 695 |
+
def hash(self, input_ids):
|
| 696 |
+
hash_ids_for_all_layers = {}
|
| 697 |
+
for layer_id in self.layer_ids:
|
| 698 |
+
hash_ids_for_all_layers[layer_id] = self._get_ngram_hashes(input_ids, layer_id=layer_id)
|
| 699 |
+
return hash_ids_for_all_layers
|
| 700 |
+
|
| 701 |
+
class TorchNgramHashMapping:
|
| 702 |
+
"""
|
| 703 |
+
在 GPU 上进行 n-gram 哈希计算的 Torch 实现。
|
| 704 |
+
由现有的 NgramHashMapping 提供 multipliers 与每 head 的素数模数组,
|
| 705 |
+
以确保与 numpy 版本一致的哈希结果与 head 排列顺序。
|
| 706 |
+
输出: dict[layer_id] -> (B, T, num_hash_heads) [long]
|
| 707 |
+
"""
|
| 708 |
+
def __init__(self, np_mapping: NgramHashMapping, device: torch.device):
|
| 709 |
+
self.layer_ids = list(np_mapping.layer_ids)
|
| 710 |
+
self.max_ngram_size = int(np_mapping.max_ngram_size)
|
| 711 |
+
self.n_head_per_ngram = int(np_mapping.n_head_per_ngram)
|
| 712 |
+
self.pad_id = int(np_mapping.pad_id)
|
| 713 |
+
|
| 714 |
+
# 每层 multipliers: (max_ngram_size,)
|
| 715 |
+
self._multipliers = {
|
| 716 |
+
lid: torch.tensor(np_mapping.layer_multipliers[lid], dtype=torch.long, device=device)
|
| 717 |
+
for lid in self.layer_ids
|
| 718 |
+
}
|
| 719 |
+
|
| 720 |
+
# 每层 mods: 列表,mods[n-2] = (n_head_per_ngram,)
|
| 721 |
+
self._mods = {}
|
| 722 |
+
for lid in self.layer_ids:
|
| 723 |
+
mods_per_n = []
|
| 724 |
+
for n in range(2, self.max_ngram_size + 1):
|
| 725 |
+
head_mods = np_mapping.vocab_size_across_layers[lid][n - 2]
|
| 726 |
+
mods_per_n.append(torch.tensor(head_mods, dtype=torch.long, device=device))
|
| 727 |
+
self._mods[lid] = mods_per_n
|
| 728 |
+
|
| 729 |
+
self.num_hash_heads = (self.max_ngram_size - 1) * self.n_head_per_ngram
|
| 730 |
+
|
| 731 |
+
def hash(self, input_ids: torch.Tensor) -> Dict[int, torch.Tensor]:
|
| 732 |
+
"""
|
| 733 |
+
input_ids: (B, T) long tensor on target device
|
| 734 |
+
return: {layer_id: (B, T, num_hash_heads) long}
|
| 735 |
+
"""
|
| 736 |
+
x = input_ids.to(torch.long)
|
| 737 |
+
B, T = x.shape
|
| 738 |
+
|
| 739 |
+
# 右移 k 位(左侧 pad): shifts[k] shape (B, T)
|
| 740 |
+
shifts = [x]
|
| 741 |
+
for k in range(1, self.max_ngram_size):
|
| 742 |
+
shifts.append(F.pad(x, (k, 0), value=self.pad_id)[:, :T])
|
| 743 |
+
|
| 744 |
+
out: Dict[int, torch.Tensor] = {}
|
| 745 |
+
for lid in self.layer_ids:
|
| 746 |
+
multipliers = self._multipliers[lid]
|
| 747 |
+
heads_per_layer = []
|
| 748 |
+
|
| 749 |
+
for n in range(2, self.max_ngram_size + 1):
|
| 750 |
+
mix = shifts[0] * multipliers[0]
|
| 751 |
+
for k in range(1, n):
|
| 752 |
+
mix = torch.bitwise_xor(mix, shifts[k] * multipliers[k])
|
| 753 |
+
|
| 754 |
+
mods = self._mods[lid][n - 2] # (H,)
|
| 755 |
+
# (B, T, 1) % (1, 1, H) -> (B, T, H)
|
| 756 |
+
head_hash = mix.unsqueeze(-1) % mods.view(1, 1, -1)
|
| 757 |
+
heads_per_layer.append(head_hash)
|
| 758 |
+
|
| 759 |
+
out[lid] = torch.cat(heads_per_layer, dim=-1)
|
| 760 |
+
|
| 761 |
+
return out
|
| 762 |
+
|
| 763 |
+
class MultiHeadEmbedding(nn.Module):
|
| 764 |
+
def __init__(self, list_of_N: List[int], D: int):
|
| 765 |
+
super().__init__()
|
| 766 |
+
self.num_heads = len(list_of_N)
|
| 767 |
+
self.embedding_dim = D
|
| 768 |
+
|
| 769 |
+
offsets = [0]
|
| 770 |
+
for n in list_of_N[:-1]:
|
| 771 |
+
offsets.append(offsets[-1] + n)
|
| 772 |
+
|
| 773 |
+
self.register_buffer("offsets", torch.tensor(offsets, dtype=torch.long))
|
| 774 |
+
|
| 775 |
+
total_N = sum(list_of_N)
|
| 776 |
+
self.embedding = nn.Embedding(num_embeddings=total_N, embedding_dim=D)
|
| 777 |
+
|
| 778 |
+
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 779 |
+
shifted_input_ids = input_ids + self.offsets
|
| 780 |
+
output = self.embedding(shifted_input_ids)
|
| 781 |
+
|
| 782 |
+
return output
|
| 783 |
+
|
| 784 |
+
class Engram(nn.Module):
|
| 785 |
+
def __init__(self, layer_id: int, config: ModelConfig):
|
| 786 |
+
super().__init__()
|
| 787 |
+
self.layer_id = layer_id
|
| 788 |
+
self.engram_cfg = config.engram_config
|
| 789 |
+
self.backbone_config = config
|
| 790 |
+
engram_cfg = self.engram_cfg
|
| 791 |
+
backbone_config = self.backbone_config
|
| 792 |
+
self.hash_mapping = NgramHashMapping(
|
| 793 |
+
engram_vocab_size = engram_cfg.engram_vocab_size,
|
| 794 |
+
max_ngram_size = engram_cfg.max_ngram_size,
|
| 795 |
+
n_embed_per_ngram = engram_cfg.n_embed_per_ngram,
|
| 796 |
+
n_head_per_ngram = engram_cfg.n_head_per_ngram,
|
| 797 |
+
layer_ids = engram_cfg.layer_ids,
|
| 798 |
+
pad_id = engram_cfg.pad_id,
|
| 799 |
+
seed = engram_cfg.seed,
|
| 800 |
+
config = backbone_config,
|
| 801 |
+
)
|
| 802 |
+
self.multi_head_embedding = MultiHeadEmbedding(
|
| 803 |
+
list_of_N = [x for y in self.hash_mapping.vocab_size_across_layers[self.layer_id] for x in y],
|
| 804 |
+
D = engram_cfg.n_embed_per_ngram // engram_cfg.n_head_per_ngram,
|
| 805 |
+
)
|
| 806 |
+
|
| 807 |
+
# 修改 ShortConv 调用
|
| 808 |
+
self.short_conv = ShortConv(
|
| 809 |
+
hidden_size = backbone_config.d_model,
|
| 810 |
+
kernel_size = engram_cfg.kernel_size,
|
| 811 |
+
dilation = engram_cfg.max_ngram_size,
|
| 812 |
+
hc_mult = 1 # 设为 1
|
| 813 |
+
)
|
| 814 |
+
|
| 815 |
+
engram_hidden_size = (engram_cfg.max_ngram_size-1) * engram_cfg.n_embed_per_ngram
|
| 816 |
+
|
| 817 |
+
# --- 修改点:不再使用 ModuleList,而是单个层 ---
|
| 818 |
+
self.value_proj = nn.Linear(engram_hidden_size, backbone_config.d_model)
|
| 819 |
+
|
| 820 |
+
# 只需要 1 个 Key Projection
|
| 821 |
+
self.key_proj = nn.Linear(engram_hidden_size, backbone_config.d_model)
|
| 822 |
+
|
| 823 |
+
# 只需要 1 组 Norm
|
| 824 |
+
self.norm_key = nn.RMSNorm(backbone_config.d_model)
|
| 825 |
+
self.norm_query = nn.RMSNorm(backbone_config.d_model)
|
| 826 |
+
|
| 827 |
+
self.reset_parameters()
|
| 828 |
+
# Torch 版哈希缓存(按设备惰性构建)
|
| 829 |
+
self._torch_hash_mapping: Optional[TorchNgramHashMapping] = None
|
| 830 |
+
self._torch_hash_device: Optional[torch.device] = None
|
| 831 |
+
|
| 832 |
+
def reset_parameters(self):
|
| 833 |
+
init_weights(
|
| 834 |
+
self.backbone_config,
|
| 835 |
+
self.multi_head_embedding.embedding,
|
| 836 |
+
type_of_module=ModuleType.emb,
|
| 837 |
+
)
|
| 838 |
+
init_weights(
|
| 839 |
+
self.backbone_config,
|
| 840 |
+
self.value_proj,
|
| 841 |
+
layer_id=self.layer_id,
|
| 842 |
+
type_of_module=ModuleType.in_module,
|
| 843 |
+
)
|
| 844 |
+
init_weights(
|
| 845 |
+
self.backbone_config,
|
| 846 |
+
self.key_proj,
|
| 847 |
+
layer_id=self.layer_id,
|
| 848 |
+
type_of_module=ModuleType.in_module,
|
| 849 |
+
)
|
| 850 |
+
self.short_conv.reset_parameters()
|
| 851 |
+
|
| 852 |
+
def forward(self, hidden_states, input_ids, engram_hash=None):
|
| 853 |
+
"""
|
| 854 |
+
hidden_states: [B, L, D] <-- 标准形状
|
| 855 |
+
input_ids: [B, L]
|
| 856 |
+
engram_hash: [B, L, NumHeads] (Optional)
|
| 857 |
+
"""
|
| 858 |
+
# 1. 查表 (不变)
|
| 859 |
+
if engram_hash is None:
|
| 860 |
+
# 优先使用 GPU 版哈希,避免 CPU<->GPU 往返
|
| 861 |
+
cur_dev = hidden_states.device
|
| 862 |
+
if self._torch_hash_mapping is None or self._torch_hash_device != cur_dev:
|
| 863 |
+
self._torch_hash_mapping = TorchNgramHashMapping(self.hash_mapping, device=cur_dev)
|
| 864 |
+
self._torch_hash_device = cur_dev
|
| 865 |
+
hash_input_ids = self._torch_hash_mapping.hash(input_ids)[self.layer_id]
|
| 866 |
+
else:
|
| 867 |
+
hash_input_ids = engram_hash
|
| 868 |
+
embeddings = self.multi_head_embedding(hash_input_ids).flatten(start_dim=-2)
|
| 869 |
+
|
| 870 |
+
# 2. 计算 Gate (不需要循环了)
|
| 871 |
+
# Key 部分
|
| 872 |
+
key = self.key_proj(embeddings)
|
| 873 |
+
normed_key = self.norm_key(key)
|
| 874 |
+
|
| 875 |
+
# Query 部分 (直接使用 hidden_states)
|
| 876 |
+
query = hidden_states
|
| 877 |
+
normed_query = self.norm_query(query)
|
| 878 |
+
|
| 879 |
+
# Gate 计算
|
| 880 |
+
# [B, L, D] * [B, L, D] -> sum(dim=-1) -> [B, L]
|
| 881 |
+
gate = (normed_key * normed_query).sum(dim=-1) / math.sqrt(self.backbone_config.d_model)
|
| 882 |
+
gate = gate.abs().clamp_min(1e-6).sqrt() * gate.sign()
|
| 883 |
+
gate = gate.sigmoid().unsqueeze(-1) # [B, L, 1]
|
| 884 |
+
|
| 885 |
+
# 3. 融合 Value
|
| 886 |
+
value = gate * self.value_proj(embeddings) # [B, L, 1] * [B, L, D] -> [B, L, D]
|
| 887 |
+
|
| 888 |
+
# 4. Short Conv
|
| 889 |
+
output = value + self.short_conv(value)
|
| 890 |
+
|
| 891 |
+
return output
|
| 892 |
+
|
| 893 |
+
|
| 894 |
+
class LLaDABlock(nn.Module):
|
| 895 |
+
"""
|
| 896 |
+
A base class for transformer block implementations.
|
| 897 |
+
"""
|
| 898 |
+
|
| 899 |
+
def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
|
| 900 |
+
super().__init__()
|
| 901 |
+
self.layer_id = layer_id
|
| 902 |
+
self.config = config
|
| 903 |
+
self.hidden_size = (
|
| 904 |
+
config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
|
| 905 |
+
)
|
| 906 |
+
self.__cache = cache
|
| 907 |
+
assert config.d_model % config.n_heads == 0
|
| 908 |
+
|
| 909 |
+
self.engram = None
|
| 910 |
+
if config.engram_config is not None and layer_id in config.engram_config.layer_ids:
|
| 911 |
+
self.engram = Engram(layer_id, config)
|
| 912 |
+
|
| 913 |
+
self._activation_checkpoint_fn = None
|
| 914 |
+
|
| 915 |
+
# Dropout.
|
| 916 |
+
self.dropout = Dropout(config.residual_dropout)
|
| 917 |
+
|
| 918 |
+
# Layer norms.
|
| 919 |
+
self.k_norm: Optional[LayerNormBase] = None
|
| 920 |
+
self.q_norm: Optional[LayerNormBase] = None
|
| 921 |
+
if config.attention_layer_norm:
|
| 922 |
+
self.k_norm = LayerNormBase.build(
|
| 923 |
+
config,
|
| 924 |
+
size=(config.d_model // config.n_heads) * config.effective_n_kv_heads,
|
| 925 |
+
elementwise_affine=config.attention_layer_norm_with_affine,
|
| 926 |
+
)
|
| 927 |
+
self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)
|
| 928 |
+
|
| 929 |
+
# Activation function.
|
| 930 |
+
self.act = Activation.build(config)
|
| 931 |
+
assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
|
| 932 |
+
|
| 933 |
+
# Attention output projection.
|
| 934 |
+
self.attn_out = nn.Linear(
|
| 935 |
+
config.d_model, config.d_model, bias=config.include_bias, device=config.init_device
|
| 936 |
+
)
|
| 937 |
+
|
| 938 |
+
# Feed-forward output projection.
|
| 939 |
+
self.ff_out = nn.Linear(
|
| 940 |
+
int(self.act.output_multiplier * self.hidden_size),
|
| 941 |
+
config.d_model,
|
| 942 |
+
bias=config.include_bias,
|
| 943 |
+
device=config.init_device,
|
| 944 |
+
)
|
| 945 |
+
self.ff_out._is_residual = True # type: ignore
|
| 946 |
+
|
| 947 |
+
# Rotary embeddings.
|
| 948 |
+
if self.config.rope:
|
| 949 |
+
self.rotary_emb = RotaryEmbedding(config, self.__cache)
|
| 950 |
+
|
| 951 |
+
self.flash_attn_func = None
|
| 952 |
+
if config.flash_attention:
|
| 953 |
+
try:
|
| 954 |
+
from flash_attn import flash_attn_func # type: ignore
|
| 955 |
+
|
| 956 |
+
self.flash_attn_func = flash_attn_func
|
| 957 |
+
except ModuleNotFoundError:
|
| 958 |
+
pass
|
| 959 |
+
|
| 960 |
+
def reset_parameters(self):
|
| 961 |
+
if self.engram is not None:
|
| 962 |
+
self.engram.reset_parameters()
|
| 963 |
+
if self.k_norm is not None:
|
| 964 |
+
self.k_norm.reset_parameters()
|
| 965 |
+
if self.q_norm is not None:
|
| 966 |
+
self.q_norm.reset_parameters()
|
| 967 |
+
init_weights(
|
| 968 |
+
self.config,
|
| 969 |
+
self.attn_out,
|
| 970 |
+
d=self.config.d_model,
|
| 971 |
+
layer_id=self.layer_id,
|
| 972 |
+
type_of_module=ModuleType.out_module,
|
| 973 |
+
)
|
| 974 |
+
init_weights(
|
| 975 |
+
self.config,
|
| 976 |
+
self.ff_out,
|
| 977 |
+
d=self.ff_out.in_features,
|
| 978 |
+
layer_id=self.layer_id,
|
| 979 |
+
type_of_module=ModuleType.out_module,
|
| 980 |
+
)
|
| 981 |
+
|
| 982 |
+
def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
|
| 983 |
+
if strategy == ActivationCheckpointingStrategy.fine_grained:
|
| 984 |
+
self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
|
| 985 |
+
else:
|
| 986 |
+
self._activation_checkpoint_fn = None
|
| 987 |
+
|
| 988 |
+
@classmethod
|
| 989 |
+
def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
|
| 990 |
+
target_dtype = input_dtype
|
| 991 |
+
# NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
|
| 992 |
+
# `is_autocast_cpu_enabled()` for CPU autocast.
|
| 993 |
+
# See https://github.com/pytorch/pytorch/issues/110966.
|
| 994 |
+
if bias.device.type == "cuda" and torch.is_autocast_enabled():
|
| 995 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
| 996 |
+
elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
|
| 997 |
+
target_dtype = torch.get_autocast_cpu_dtype()
|
| 998 |
+
if bias.dtype != target_dtype:
|
| 999 |
+
bias = bias.to(target_dtype)
|
| 1000 |
+
ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
|
| 1001 |
+
return bias
|
| 1002 |
+
|
| 1003 |
+
def _scaled_dot_product_attention(
|
| 1004 |
+
self,
|
| 1005 |
+
q: torch.Tensor,
|
| 1006 |
+
k: torch.Tensor,
|
| 1007 |
+
v: torch.Tensor,
|
| 1008 |
+
attn_mask: Optional[torch.Tensor] = None,
|
| 1009 |
+
dropout_p: float = 0.0,
|
| 1010 |
+
is_causal: bool = False,
|
| 1011 |
+
) -> torch.Tensor:
|
| 1012 |
+
"""
|
| 1013 |
+
Computes scaled dot product attention on query, key and value tensors, using an optional
|
| 1014 |
+
attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.
|
| 1015 |
+
"""
|
| 1016 |
+
if self.flash_attn_func is not None and attn_mask is None:
|
| 1017 |
+
r = self.flash_attn_func(
|
| 1018 |
+
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=False
|
| 1019 |
+
)
|
| 1020 |
+
return r.transpose(1, 2)
|
| 1021 |
+
else:
|
| 1022 |
+
# torch's sdpa doesn't support GQA, so we're doing this
|
| 1023 |
+
assert k.size(1) == v.size(1)
|
| 1024 |
+
num_kv_heads = k.size(1)
|
| 1025 |
+
num_q_heads = q.size(1)
|
| 1026 |
+
if num_q_heads != num_kv_heads:
|
| 1027 |
+
assert num_q_heads % num_kv_heads == 0
|
| 1028 |
+
k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
|
| 1029 |
+
v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
|
| 1030 |
+
|
| 1031 |
+
# Modify: MDM set causal to False, and with no attn_mask.
|
| 1032 |
+
return F.scaled_dot_product_attention(
|
| 1033 |
+
q,
|
| 1034 |
+
k,
|
| 1035 |
+
v,
|
| 1036 |
+
attn_mask=None,
|
| 1037 |
+
dropout_p=dropout_p,
|
| 1038 |
+
is_causal=False,
|
| 1039 |
+
)
|
| 1040 |
+
|
| 1041 |
+
def attention(
|
| 1042 |
+
self,
|
| 1043 |
+
q: torch.Tensor,
|
| 1044 |
+
k: torch.Tensor,
|
| 1045 |
+
v: torch.Tensor,
|
| 1046 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 1047 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 1048 |
+
use_cache: bool = False,
|
| 1049 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 1050 |
+
B, T, C = q.size() # batch size, sequence length, d_model
|
| 1051 |
+
dtype = k.dtype
|
| 1052 |
+
|
| 1053 |
+
# Optionally apply layer norm to keys and queries.
|
| 1054 |
+
if self.q_norm is not None and self.k_norm is not None:
|
| 1055 |
+
q = self.q_norm(q).to(dtype=dtype)
|
| 1056 |
+
k = self.k_norm(k).to(dtype=dtype)
|
| 1057 |
+
|
| 1058 |
+
# Move head forward to be next to the batch dim.
|
| 1059 |
+
# shape: (B, nh, T, hs)
|
| 1060 |
+
q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
|
| 1061 |
+
# shape: (B, n_kv_h, T, hs)
|
| 1062 |
+
k = k.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
|
| 1063 |
+
# shape: (B, n_kv_h, T, hs)
|
| 1064 |
+
v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
|
| 1065 |
+
|
| 1066 |
+
if layer_past is not None:
|
| 1067 |
+
past_key, past_value = layer_past
|
| 1068 |
+
k = torch.cat((past_key, k), dim=-2)
|
| 1069 |
+
v = torch.cat((past_value, v), dim=-2)
|
| 1070 |
+
|
| 1071 |
+
present = (k, v) if use_cache else None
|
| 1072 |
+
query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
|
| 1073 |
+
|
| 1074 |
+
if self.config.rope:
|
| 1075 |
+
# Apply rotary embeddings.
|
| 1076 |
+
q, k = self.rotary_emb(q, k)
|
| 1077 |
+
|
| 1078 |
+
if attention_bias is not None:
|
| 1079 |
+
# Resize and cast attention bias.
|
| 1080 |
+
# The current dtype of the attention bias might not match the dtype that the SDP attn function will
|
| 1081 |
+
# run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding
|
| 1082 |
+
# as down-casting the attention bias to the autocast precision will result in -infs, which will
|
| 1083 |
+
# cause the SDP attn function to produce NaNs.
|
| 1084 |
+
attention_bias = self._cast_attn_bias(
|
| 1085 |
+
attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype
|
| 1086 |
+
)
|
| 1087 |
+
|
| 1088 |
+
# Get the attention scores.
|
| 1089 |
+
# shape: (B, nh, T, hs)
|
| 1090 |
+
att = self._scaled_dot_product_attention(
|
| 1091 |
+
q,
|
| 1092 |
+
k,
|
| 1093 |
+
v,
|
| 1094 |
+
attn_mask=None,
|
| 1095 |
+
dropout_p=0.0 if not self.training else self.config.attention_dropout,
|
| 1096 |
+
is_causal=False,
|
| 1097 |
+
)
|
| 1098 |
+
|
| 1099 |
+
# Re-assemble all head outputs side-by-side.
|
| 1100 |
+
att = att.transpose(1, 2).contiguous().view(B, T, C)
|
| 1101 |
+
|
| 1102 |
+
# Apply output projection.
|
| 1103 |
+
return self.attn_out(att), present
|
| 1104 |
+
|
| 1105 |
+
@abstractmethod
|
| 1106 |
+
def forward(
|
| 1107 |
+
self,
|
| 1108 |
+
x: torch.Tensor,
|
| 1109 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1110 |
+
attention_bias: Optional[torch.FloatTensor] = None,
|
| 1111 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 1112 |
+
use_cache: bool = False,
|
| 1113 |
+
engram_hash: Optional[torch.Tensor] = None,
|
| 1114 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 1115 |
+
raise NotImplementedError
|
| 1116 |
+
|
| 1117 |
+
@classmethod
|
| 1118 |
+
def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> LLaDABlock:
|
| 1119 |
+
if config.block_type == BlockType.sequential:
|
| 1120 |
+
return LLaDASequentialBlock(layer_id, config, cache)
|
| 1121 |
+
elif config.block_type == BlockType.llama:
|
| 1122 |
+
return LLaDALlamaBlock(layer_id, config, cache)
|
| 1123 |
+
else:
|
| 1124 |
+
raise NotImplementedError(f"Unknown block type: '{config.block_type}'")
|
| 1125 |
+
|
| 1126 |
+
|
| 1127 |
+
class LLaDASequentialBlock(LLaDABlock):
|
| 1128 |
+
"""
|
| 1129 |
+
This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
|
| 1130 |
+
(plus another skip connection).
|
| 1131 |
+
"""
|
| 1132 |
+
|
| 1133 |
+
def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
|
| 1134 |
+
super().__init__(layer_id, config, cache)
|
| 1135 |
+
# Layer norms.
|
| 1136 |
+
self.attn_norm = LayerNorm.build(config)
|
| 1137 |
+
self.ff_norm = LayerNorm.build(config)
|
| 1138 |
+
# Attention input projection. Projects x -> (q, k, v)
|
| 1139 |
+
head_dim = config.d_model // config.n_heads
|
| 1140 |
+
self.fused_dims = (
|
| 1141 |
+
config.d_model,
|
| 1142 |
+
config.effective_n_kv_heads * head_dim,
|
| 1143 |
+
config.effective_n_kv_heads * head_dim,
|
| 1144 |
+
)
|
| 1145 |
+
self.att_proj = nn.Linear(
|
| 1146 |
+
config.d_model, sum(self.fused_dims), bias=config.include_bias | config.include_qkv_bias, device=config.init_device
|
| 1147 |
+
)
|
| 1148 |
+
# Feed-forward input projection.
|
| 1149 |
+
self.ff_proj = nn.Linear(
|
| 1150 |
+
config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
|
| 1151 |
+
)
|
| 1152 |
+
|
| 1153 |
+
def reset_parameters(self):
|
| 1154 |
+
super().reset_parameters()
|
| 1155 |
+
self.attn_norm.reset_parameters()
|
| 1156 |
+
self.ff_norm.reset_parameters()
|
| 1157 |
+
# NOTE: the standard deviation for these weights does not depend on the layer.
|
| 1158 |
+
init_weights(
|
| 1159 |
+
self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
|
| 1160 |
+
)
|
| 1161 |
+
init_weights(
|
| 1162 |
+
self.config, self.ff_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
|
| 1163 |
+
)
|
| 1164 |
+
|
| 1165 |
+
def forward(
|
| 1166 |
+
self,
|
| 1167 |
+
x: torch.Tensor,
|
| 1168 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1169 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 1170 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 1171 |
+
use_cache: bool = False,
|
| 1172 |
+
engram_hash: Optional[torch.Tensor] = None,
|
| 1173 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 1174 |
+
if self.engram is not None:
|
| 1175 |
+
assert input_ids is not None
|
| 1176 |
+
x = x + self.engram(x, input_ids, engram_hash=engram_hash)
|
| 1177 |
+
|
| 1178 |
+
# Get query, key, value projections.
|
| 1179 |
+
# shape:
|
| 1180 |
+
# - for regular attn q, k, v: (batch_size, seq_len, d_model)
|
| 1181 |
+
# - for multi-query attn q: (batch_size, seq_len, d_model)
|
| 1182 |
+
# k, v: (batch_size, seq_len, d_model // n_heads)
|
| 1183 |
+
# - for group query attn q: (batch_size, seq_len, d_model)
|
| 1184 |
+
# k, v: (batch_size, seq_len, d_model // n_kv_heads)
|
| 1185 |
+
if self._activation_checkpoint_fn is not None:
|
| 1186 |
+
q, k, v = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)).split(
|
| 1187 |
+
self.fused_dims, dim=-1
|
| 1188 |
+
)
|
| 1189 |
+
else:
|
| 1190 |
+
q, k, v = self.att_proj(self.attn_norm(x)).split(self.fused_dims, dim=-1)
|
| 1191 |
+
|
| 1192 |
+
# Get attention scores.
|
| 1193 |
+
if self._activation_checkpoint_fn is not None:
|
| 1194 |
+
att, cache = self._activation_checkpoint_fn( # type: ignore
|
| 1195 |
+
self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
|
| 1196 |
+
)
|
| 1197 |
+
else:
|
| 1198 |
+
att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
|
| 1199 |
+
|
| 1200 |
+
# Add attention scores.
|
| 1201 |
+
# shape: (B, T, C)
|
| 1202 |
+
x = x + self.dropout(att)
|
| 1203 |
+
|
| 1204 |
+
# Add feed-forward projection.
|
| 1205 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1206 |
+
og_x = x
|
| 1207 |
+
if self._activation_checkpoint_fn is not None:
|
| 1208 |
+
x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
|
| 1209 |
+
else:
|
| 1210 |
+
x = self.ff_norm(x)
|
| 1211 |
+
x = self.ff_proj(x)
|
| 1212 |
+
if self._activation_checkpoint_fn is not None:
|
| 1213 |
+
x = self._activation_checkpoint_fn(self.act, x) # type: ignore
|
| 1214 |
+
else:
|
| 1215 |
+
x = self.act(x)
|
| 1216 |
+
x = self.ff_out(x)
|
| 1217 |
+
x = self.dropout(x)
|
| 1218 |
+
x = og_x + x
|
| 1219 |
+
|
| 1220 |
+
return x, cache
|
| 1221 |
+
|
| 1222 |
+
|
| 1223 |
+
class LLaDALlamaBlock(LLaDABlock):
|
| 1224 |
+
"""
|
| 1225 |
+
This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
|
| 1226 |
+
(plus another skip connection). This block is similar to `LLaDASequentialBlock`
|
| 1227 |
+
but some operations have slightly different implementations to imitate the
|
| 1228 |
+
behavior of Llama.
|
| 1229 |
+
"""
|
| 1230 |
+
|
| 1231 |
+
def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
|
| 1232 |
+
super().__init__(layer_id, config, cache)
|
| 1233 |
+
# Layer norms.
|
| 1234 |
+
self.attn_norm = LayerNorm.build(config)
|
| 1235 |
+
self.ff_norm = LayerNorm.build(config)
|
| 1236 |
+
self.__cache = cache
|
| 1237 |
+
|
| 1238 |
+
# Attention input projection. Projects x -> (q, k, v)
|
| 1239 |
+
head_dim = config.d_model // config.n_heads
|
| 1240 |
+
q_proj_out_dim = config.d_model
|
| 1241 |
+
k_proj_out_dim = config.effective_n_kv_heads * head_dim
|
| 1242 |
+
v_proj_out_dim = config.effective_n_kv_heads * head_dim
|
| 1243 |
+
self.q_proj = nn.Linear(
|
| 1244 |
+
config.d_model, q_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
|
| 1245 |
+
)
|
| 1246 |
+
self.k_proj = nn.Linear(
|
| 1247 |
+
config.d_model, k_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
|
| 1248 |
+
)
|
| 1249 |
+
self.v_proj = nn.Linear(
|
| 1250 |
+
config.d_model, v_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
|
| 1251 |
+
)
|
| 1252 |
+
|
| 1253 |
+
# Feed-forward input projection.
|
| 1254 |
+
self.ff_proj = nn.Linear(
|
| 1255 |
+
config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
|
| 1256 |
+
)
|
| 1257 |
+
# new add
|
| 1258 |
+
self.up_proj = nn.Linear(
|
| 1259 |
+
config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
|
| 1260 |
+
)
|
| 1261 |
+
|
| 1262 |
+
def reset_parameters(self):
|
| 1263 |
+
super().reset_parameters()
|
| 1264 |
+
self.attn_norm.reset_parameters()
|
| 1265 |
+
self.ff_norm.reset_parameters()
|
| 1266 |
+
# NOTE: the standard deviation for these weights does not depend on the layer.
|
| 1267 |
+
init_weights(self.config, self.q_proj, d=self.config.d_model, layer_id=None)
|
| 1268 |
+
init_weights(self.config, self.k_proj, d=self.config.d_model, layer_id=None)
|
| 1269 |
+
init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None)
|
| 1270 |
+
init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None)
|
| 1271 |
+
init_weights(self.config, self.up_proj, d=self.config.d_model, layer_id=None) # new add
|
| 1272 |
+
|
| 1273 |
+
def forward(
|
| 1274 |
+
self,
|
| 1275 |
+
x: torch.Tensor,
|
| 1276 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1277 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 1278 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 1279 |
+
use_cache: bool = False,
|
| 1280 |
+
engram_hash: Optional[torch.Tensor] = None,
|
| 1281 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 1282 |
+
if self.engram is not None:
|
| 1283 |
+
assert input_ids is not None
|
| 1284 |
+
x = x + self.engram(x, input_ids, engram_hash=engram_hash)
|
| 1285 |
+
|
| 1286 |
+
# Get query, key, value projections.
|
| 1287 |
+
# shape:
|
| 1288 |
+
# - for regular attn q, k, v: (batch_size, seq_len, d_model)
|
| 1289 |
+
# - for multi-query attn q: (batch_size, seq_len, d_model)
|
| 1290 |
+
# k, v: (batch_size, seq_len, d_model // n_heads)
|
| 1291 |
+
# - for group query attn q: (batch_size, seq_len, d_model)
|
| 1292 |
+
# k, v: (batch_size, seq_len, d_model // n_kv_heads)
|
| 1293 |
+
x_normed = self.attn_norm(x)
|
| 1294 |
+
q = self.q_proj(x_normed)
|
| 1295 |
+
k = self.k_proj(x_normed)
|
| 1296 |
+
v = self.v_proj(x_normed)
|
| 1297 |
+
|
| 1298 |
+
# Get attention scores.
|
| 1299 |
+
if self._activation_checkpoint_fn is not None:
|
| 1300 |
+
att, cache = self._activation_checkpoint_fn( # type: ignore
|
| 1301 |
+
self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
|
| 1302 |
+
)
|
| 1303 |
+
else:
|
| 1304 |
+
att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
|
| 1305 |
+
|
| 1306 |
+
# Add attention scores.
|
| 1307 |
+
# shape: (B, T, C)
|
| 1308 |
+
x = x + self.dropout(att)
|
| 1309 |
+
|
| 1310 |
+
# Add feed-forward projection.
|
| 1311 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1312 |
+
og_x = x
|
| 1313 |
+
if self._activation_checkpoint_fn is not None:
|
| 1314 |
+
x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
|
| 1315 |
+
else:
|
| 1316 |
+
x = self.ff_norm(x)
|
| 1317 |
+
x, x_up = self.ff_proj(x), self.up_proj(x) # new add
|
| 1318 |
+
if self._activation_checkpoint_fn is not None:
|
| 1319 |
+
x = self._activation_checkpoint_fn(self.act, x) # type: ignore
|
| 1320 |
+
else:
|
| 1321 |
+
x = self.act(x)
|
| 1322 |
+
x = x * x_up # new add
|
| 1323 |
+
x = self.ff_out(x)
|
| 1324 |
+
x = self.dropout(x)
|
| 1325 |
+
x = og_x + x
|
| 1326 |
+
|
| 1327 |
+
return x, cache
|
| 1328 |
+
|
| 1329 |
+
|
| 1330 |
+
class LLaDAOutput(NamedTuple):
|
| 1331 |
+
logits: torch.FloatTensor
|
| 1332 |
+
"""
|
| 1333 |
+
A tensor of shape `(batch_size, seq_len, vocab_size)` representing the log probabilities
|
| 1334 |
+
for the next token *before* normalization via (log) softmax.
|
| 1335 |
+
"""
|
| 1336 |
+
|
| 1337 |
+
attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]
|
| 1338 |
+
"""
|
| 1339 |
+
Attention keys and values from each block.
|
| 1340 |
+
"""
|
| 1341 |
+
|
| 1342 |
+
hidden_states: Optional[Tuple[torch.Tensor]]
|
| 1343 |
+
"""
|
| 1344 |
+
Hidden states from each block.
|
| 1345 |
+
"""
|
| 1346 |
+
|
| 1347 |
+
|
| 1348 |
+
class LLaDAGenerateOutput(NamedTuple):
|
| 1349 |
+
token_ids: torch.LongTensor
|
| 1350 |
+
"""
|
| 1351 |
+
The generated token IDs, a tensor of shape `(batch_size, beam_size, max_steps)`.
|
| 1352 |
+
These do *not* include the original input IDs.
|
| 1353 |
+
"""
|
| 1354 |
+
|
| 1355 |
+
scores: torch.FloatTensor
|
| 1356 |
+
"""
|
| 1357 |
+
The scores of the generated sequences, a tensor of shape `(batch_size, beam_size)`.
|
| 1358 |
+
"""
|
| 1359 |
+
|
| 1360 |
+
|
| 1361 |
+
class LLaDABlockGroup(nn.ModuleList):
|
| 1362 |
+
def __init__(self, config: ModelConfig, layer_offset: int, modules: Optional[Iterable[nn.Module]] = None):
|
| 1363 |
+
super().__init__(modules)
|
| 1364 |
+
self.config = config
|
| 1365 |
+
self.layer_offset = layer_offset
|
| 1366 |
+
self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
|
| 1367 |
+
self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
|
| 1368 |
+
|
| 1369 |
+
def forward(
|
| 1370 |
+
self,
|
| 1371 |
+
x: torch.Tensor,
|
| 1372 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1373 |
+
attention_bias: Optional[torch.FloatTensor] = None,
|
| 1374 |
+
layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
| 1375 |
+
use_cache: bool = False,
|
| 1376 |
+
engram_hashes: Optional[Dict[int, torch.Tensor]] = None,
|
| 1377 |
+
) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
|
| 1378 |
+
attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
|
| 1379 |
+
for block_idx, block in enumerate(self):
|
| 1380 |
+
layer_past = None if layers_past is None else layers_past[block_idx]
|
| 1381 |
+
block_idx += self.layer_offset
|
| 1382 |
+
if (
|
| 1383 |
+
(self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
|
| 1384 |
+
or (
|
| 1385 |
+
self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
|
| 1386 |
+
and block_idx % 2 == 0
|
| 1387 |
+
)
|
| 1388 |
+
or (
|
| 1389 |
+
self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
|
| 1390 |
+
and block_idx % 3 == 0
|
| 1391 |
+
)
|
| 1392 |
+
or (
|
| 1393 |
+
self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
|
| 1394 |
+
and block_idx % 4 == 0
|
| 1395 |
+
)
|
| 1396 |
+
):
|
| 1397 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1398 |
+
x, cache = self._activation_checkpoint_fn( # type: ignore
|
| 1399 |
+
block, x, input_ids=input_ids, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache,
|
| 1400 |
+
engram_hash=None if engram_hashes is None else engram_hashes.get(block_idx)
|
| 1401 |
+
)
|
| 1402 |
+
else:
|
| 1403 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1404 |
+
x, cache = block(x, input_ids=input_ids, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache,
|
| 1405 |
+
engram_hash=None if engram_hashes is None else engram_hashes.get(block_idx))
|
| 1406 |
+
if attn_key_values is not None:
|
| 1407 |
+
assert cache is not None
|
| 1408 |
+
attn_key_values.append(cache)
|
| 1409 |
+
return x, attn_key_values
|
| 1410 |
+
|
| 1411 |
+
def reset_parameters(self):
|
| 1412 |
+
for block in self:
|
| 1413 |
+
block.reset_parameters()
|
| 1414 |
+
|
| 1415 |
+
def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
|
| 1416 |
+
self.activation_checkpointing_strategy = strategy
|
| 1417 |
+
for block in self:
|
| 1418 |
+
block.set_activation_checkpointing(strategy)
|
| 1419 |
+
|
| 1420 |
+
|
| 1421 |
+
class LLaDAModel(nn.Module):
|
| 1422 |
+
def __init__(self, config: ModelConfig, init_params: bool = True):
|
| 1423 |
+
super().__init__()
|
| 1424 |
+
self.config = config
|
| 1425 |
+
self.__cache = BufferCache()
|
| 1426 |
+
|
| 1427 |
+
# Validate config.
|
| 1428 |
+
if self.config.alibi and self.config.flash_attention:
|
| 1429 |
+
raise Exception("ALiBi is currently not supported with FlashAttention")
|
| 1430 |
+
|
| 1431 |
+
if self.config.alibi and self.config.rope:
|
| 1432 |
+
raise Exception("ALiBi and RoPE are mutually exclusive")
|
| 1433 |
+
|
| 1434 |
+
if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size:
|
| 1435 |
+
if self.config.embedding_size < self.config.vocab_size:
|
| 1436 |
+
raise Exception("embedding size should be at least as big as vocab size")
|
| 1437 |
+
elif self.config.embedding_size % 128 != 0:
|
| 1438 |
+
import warnings
|
| 1439 |
+
|
| 1440 |
+
warnings.warn(
|
| 1441 |
+
"Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
|
| 1442 |
+
)
|
| 1443 |
+
|
| 1444 |
+
self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
|
| 1445 |
+
self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config)
|
| 1446 |
+
|
| 1447 |
+
if not (
|
| 1448 |
+
0 < self.config.block_group_size <= self.config.n_layers
|
| 1449 |
+
and self.config.n_layers % self.config.block_group_size == 0
|
| 1450 |
+
):
|
| 1451 |
+
raise Exception("n layers must be divisible by block group size")
|
| 1452 |
+
|
| 1453 |
+
torch.backends.cuda.enable_flash_sdp(True)
|
| 1454 |
+
torch.backends.cuda.enable_mem_efficient_sdp(False) # this is super slow so make sure torch won't use it
|
| 1455 |
+
|
| 1456 |
+
self.transformer = nn.ModuleDict(
|
| 1457 |
+
dict(
|
| 1458 |
+
wte=nn.Embedding(
|
| 1459 |
+
config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
|
| 1460 |
+
),
|
| 1461 |
+
emb_drop=Dropout(config.embedding_dropout),
|
| 1462 |
+
ln_f=LayerNorm.build(config),
|
| 1463 |
+
)
|
| 1464 |
+
)
|
| 1465 |
+
|
| 1466 |
+
blocks = [LLaDABlock.build(i, config, self.__cache) for i in range(config.n_layers)]
|
| 1467 |
+
if self.config.block_group_size > 1:
|
| 1468 |
+
block_groups = [
|
| 1469 |
+
LLaDABlockGroup(config, i, blocks[i : i + config.block_group_size])
|
| 1470 |
+
for i in range(0, config.n_layers, config.block_group_size)
|
| 1471 |
+
]
|
| 1472 |
+
self.transformer.update({"block_groups": nn.ModuleList(block_groups)})
|
| 1473 |
+
else:
|
| 1474 |
+
self.transformer.update({"blocks": nn.ModuleList(blocks)})
|
| 1475 |
+
|
| 1476 |
+
if not (self.config.alibi or self.config.rope):
|
| 1477 |
+
self.transformer.update(
|
| 1478 |
+
{"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
|
| 1479 |
+
)
|
| 1480 |
+
if not config.weight_tying:
|
| 1481 |
+
self.transformer.update(
|
| 1482 |
+
{
|
| 1483 |
+
"ff_out": nn.Linear(
|
| 1484 |
+
config.d_model,
|
| 1485 |
+
config.embedding_size or config.vocab_size,
|
| 1486 |
+
bias=config.include_bias,
|
| 1487 |
+
device=config.init_device,
|
| 1488 |
+
)
|
| 1489 |
+
}
|
| 1490 |
+
)
|
| 1491 |
+
# When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
|
| 1492 |
+
if init_params and self.config.init_device != "meta":
|
| 1493 |
+
self.reset_parameters()
|
| 1494 |
+
self.__num_fwd_flops: Optional[int] = None
|
| 1495 |
+
|
| 1496 |
+
# Warm up cache.
|
| 1497 |
+
if self.config.alibi:
|
| 1498 |
+
get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
|
| 1499 |
+
self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
|
| 1500 |
+
|
| 1501 |
+
def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
|
| 1502 |
+
self.activation_checkpointing_strategy = strategy
|
| 1503 |
+
if self.config.block_group_size != 1:
|
| 1504 |
+
for block_group in self.transformer.block_groups:
|
| 1505 |
+
block_group.set_activation_checkpointing(strategy)
|
| 1506 |
+
else:
|
| 1507 |
+
for block in self.transformer.blocks:
|
| 1508 |
+
block.set_activation_checkpointing(strategy)
|
| 1509 |
+
|
| 1510 |
+
@property
|
| 1511 |
+
def device(self) -> torch.device:
|
| 1512 |
+
device: torch.device = self.transformer.wte.weight.device # type: ignore
|
| 1513 |
+
if device.type == "meta":
|
| 1514 |
+
return _non_meta_init_device(self.config)
|
| 1515 |
+
else:
|
| 1516 |
+
return device
|
| 1517 |
+
|
| 1518 |
+
def reset_parameters(self):
|
| 1519 |
+
log.info("Initializing model parameters...")
|
| 1520 |
+
# Top-level embeddings / linear layers.
|
| 1521 |
+
init_weights(
|
| 1522 |
+
self.config,
|
| 1523 |
+
self.transformer.wte, # type: ignore
|
| 1524 |
+
std_factor=(0.5 * math.sqrt(self.config.d_model)) if self.config.scale_logits else 1.0,
|
| 1525 |
+
type_of_module=ModuleType.emb,
|
| 1526 |
+
)
|
| 1527 |
+
if hasattr(self.transformer, "wpe"):
|
| 1528 |
+
init_weights(self.config, self.transformer.wpe, type_of_module=ModuleType.emb) # type: ignore
|
| 1529 |
+
|
| 1530 |
+
# Top-level layer norm.
|
| 1531 |
+
self.transformer.ln_f.reset_parameters() # type: ignore
|
| 1532 |
+
|
| 1533 |
+
# Output weights.
|
| 1534 |
+
if hasattr(self.transformer, "ff_out"):
|
| 1535 |
+
init_weights(self.config, self.transformer.ff_out, type_of_module=ModuleType.final_out) # type: ignore
|
| 1536 |
+
|
| 1537 |
+
# Let the blocks handle themselves.
|
| 1538 |
+
if self.config.block_group_size == 1:
|
| 1539 |
+
for block in self.transformer.blocks:
|
| 1540 |
+
block.reset_parameters()
|
| 1541 |
+
else:
|
| 1542 |
+
for block_group in self.transformer.block_groups:
|
| 1543 |
+
block_group.reset_parameters()
|
| 1544 |
+
|
| 1545 |
+
def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
|
| 1546 |
+
if (alibi_bias := self.__cache.get("alibi_attention_bias")) is not None and alibi_bias.shape[
|
| 1547 |
+
-1
|
| 1548 |
+
] >= seq_len:
|
| 1549 |
+
if alibi_bias.device != device:
|
| 1550 |
+
alibi_bias = alibi_bias.to(device)
|
| 1551 |
+
self.__cache["alibi_attention_bias"] = alibi_bias
|
| 1552 |
+
return alibi_bias
|
| 1553 |
+
with torch.autocast(device.type, enabled=False):
|
| 1554 |
+
alibi_bias = alibi_attention_bias(seq_len, self.config, device)
|
| 1555 |
+
self.__cache["alibi_attention_bias"] = alibi_bias
|
| 1556 |
+
return alibi_bias
|
| 1557 |
+
|
| 1558 |
+
def forward(
|
| 1559 |
+
self,
|
| 1560 |
+
input_ids: torch.LongTensor,
|
| 1561 |
+
input_embeddings: Optional[torch.FloatTensor] = None,
|
| 1562 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1563 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 1564 |
+
past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
| 1565 |
+
use_cache: bool = False,
|
| 1566 |
+
last_logits_only: bool = False,
|
| 1567 |
+
output_hidden_states: Optional[bool] = None,
|
| 1568 |
+
engram_hashes: Optional[Dict[int, torch.Tensor]] = None,
|
| 1569 |
+
) -> LLaDAOutput:
|
| 1570 |
+
"""
|
| 1571 |
+
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
|
| 1572 |
+
:param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input
|
| 1573 |
+
embeddings. When provided, it is treated as the output of the input embedding layer.
|
| 1574 |
+
:param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
|
| 1575 |
+
which input IDs are masked. A `1` value in the mask means that
|
| 1576 |
+
the corresponding input ID should *not* be ignored. A `0` means
|
| 1577 |
+
that the corresponding input ID is masked.
|
| 1578 |
+
This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
|
| 1579 |
+
library.
|
| 1580 |
+
:param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`,
|
| 1581 |
+
`(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used
|
| 1582 |
+
to introduce causal or other biases.
|
| 1583 |
+
If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]`
|
| 1584 |
+
indicates that the i-th element in the sequence is allowed to attend to the j-th
|
| 1585 |
+
element in the sequence.
|
| 1586 |
+
If the tensor is a float tensor, it will just be added to the attention
|
| 1587 |
+
scores before the softmax.
|
| 1588 |
+
The default is causal, which corresponds to a lower-diagonal byte matrix of ones.
|
| 1589 |
+
:param past_key_values: Pre-computed keys and values for each attention block.
|
| 1590 |
+
Can be used to speed up sequential decoding. The `input_ids` which have
|
| 1591 |
+
their past given to this model should not be passed as `input_ids` as they have already been computed.
|
| 1592 |
+
:param use_cache: If `True`, return key and value tensors for each block.
|
| 1593 |
+
:param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
|
| 1594 |
+
This can speed up decoding when you only care about the next token.
|
| 1595 |
+
"""
|
| 1596 |
+
# Add Basic MDM Model config check
|
| 1597 |
+
assert not self.config.alibi, "Alibi length extrapolation is not supported for MDM."
|
| 1598 |
+
assert self.config.rope, "Rope must be used in Llama-Encoder for MDM."
|
| 1599 |
+
assert (past_key_values is None and not use_cache), "The kvcache is not suppotred for MDM."
|
| 1600 |
+
|
| 1601 |
+
output_hidden_states = output_hidden_states if output_hidden_states is not None else False
|
| 1602 |
+
|
| 1603 |
+
if past_key_values:
|
| 1604 |
+
assert len(past_key_values) == self.config.n_layers
|
| 1605 |
+
|
| 1606 |
+
batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
|
| 1607 |
+
if past_key_values is None:
|
| 1608 |
+
past_length = 0
|
| 1609 |
+
else:
|
| 1610 |
+
past_length = past_key_values[0][0].size(-2)
|
| 1611 |
+
|
| 1612 |
+
# Get embeddings of input.
|
| 1613 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1614 |
+
x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore
|
| 1615 |
+
|
| 1616 |
+
if self.config.input_emb_norm:
|
| 1617 |
+
x = x * (self.config.d_model**0.5)
|
| 1618 |
+
|
| 1619 |
+
if not (self.config.alibi or self.config.rope):
|
| 1620 |
+
# Get positional embeddings.
|
| 1621 |
+
# shape: (1, seq_len)
|
| 1622 |
+
pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
|
| 1623 |
+
# shape: (1, seq_len, d_model)
|
| 1624 |
+
pos_emb = self.transformer.wpe(pos) # type: ignore
|
| 1625 |
+
x = pos_emb + x
|
| 1626 |
+
|
| 1627 |
+
# Add input + positional embeddings and apply dropout.
|
| 1628 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1629 |
+
x = self.transformer.emb_drop(x) # type: ignore
|
| 1630 |
+
|
| 1631 |
+
# Transform the attention mask into what the blocks expect.
|
| 1632 |
+
if attention_mask is not None and 0.0 in attention_mask:
|
| 1633 |
+
# shape: (batch_size, 1, 1, seq_len)
|
| 1634 |
+
attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
|
| 1635 |
+
attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
|
| 1636 |
+
else:
|
| 1637 |
+
attention_mask = None
|
| 1638 |
+
|
| 1639 |
+
# Merge attention mask with attention bias.
|
| 1640 |
+
if (
|
| 1641 |
+
attention_bias is not None
|
| 1642 |
+
or attention_mask is not None
|
| 1643 |
+
or self.config.alibi
|
| 1644 |
+
# NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
|
| 1645 |
+
# with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
|
| 1646 |
+
# scores correctly.
|
| 1647 |
+
or past_key_values is not None
|
| 1648 |
+
):
|
| 1649 |
+
if attention_bias is None and self.config.alibi:
|
| 1650 |
+
attention_bias = get_causal_attention_bias(
|
| 1651 |
+
self.__cache, past_length + seq_len, x.device
|
| 1652 |
+
) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
|
| 1653 |
+
elif attention_bias is None:
|
| 1654 |
+
attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
|
| 1655 |
+
elif attention_bias.dtype in (torch.int8, torch.bool):
|
| 1656 |
+
attention_bias = attention_bias.to(dtype=torch.float)
|
| 1657 |
+
attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
|
| 1658 |
+
|
| 1659 |
+
# Transform to the right shape and data type.
|
| 1660 |
+
mask_len = seq_len
|
| 1661 |
+
if attention_mask is not None:
|
| 1662 |
+
mask_len = attention_mask.shape[-1]
|
| 1663 |
+
elif past_key_values is not None:
|
| 1664 |
+
mask_len = past_key_values[0][0].shape[-2] + seq_len
|
| 1665 |
+
attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
|
| 1666 |
+
|
| 1667 |
+
# Add in the masking bias.
|
| 1668 |
+
if attention_mask is not None:
|
| 1669 |
+
attention_bias = attention_bias + attention_mask
|
| 1670 |
+
# Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
|
| 1671 |
+
# `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
|
| 1672 |
+
# it can produce NaNs.
|
| 1673 |
+
ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)
|
| 1674 |
+
|
| 1675 |
+
attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
|
| 1676 |
+
|
| 1677 |
+
# decoder layers
|
| 1678 |
+
all_hidden_states = []
|
| 1679 |
+
|
| 1680 |
+
# Apply blocks one-by-one.
|
| 1681 |
+
if self.config.block_group_size == 1:
|
| 1682 |
+
for block_idx, block in enumerate(self.transformer.blocks):
|
| 1683 |
+
if output_hidden_states:
|
| 1684 |
+
# add hidden states
|
| 1685 |
+
all_hidden_states.append(x)
|
| 1686 |
+
|
| 1687 |
+
layer_past = None if past_key_values is None else past_key_values[block_idx]
|
| 1688 |
+
if (
|
| 1689 |
+
(self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
|
| 1690 |
+
or (
|
| 1691 |
+
self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
|
| 1692 |
+
and block_idx % 2 == 0
|
| 1693 |
+
)
|
| 1694 |
+
or (
|
| 1695 |
+
self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
|
| 1696 |
+
and block_idx % 3 == 0
|
| 1697 |
+
)
|
| 1698 |
+
or (
|
| 1699 |
+
self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
|
| 1700 |
+
and block_idx % 4 == 0
|
| 1701 |
+
)
|
| 1702 |
+
):
|
| 1703 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1704 |
+
x, cache = self._activation_checkpoint_fn(
|
| 1705 |
+
block, x, input_ids=input_ids, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache,
|
| 1706 |
+
engram_hash=None if engram_hashes is None else engram_hashes.get(block_idx)
|
| 1707 |
+
)
|
| 1708 |
+
else:
|
| 1709 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1710 |
+
x, cache = block(x, input_ids=input_ids, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache,
|
| 1711 |
+
engram_hash=None if engram_hashes is None else engram_hashes.get(block_idx))
|
| 1712 |
+
if attn_key_values is not None:
|
| 1713 |
+
assert cache is not None
|
| 1714 |
+
attn_key_values.append(cache)
|
| 1715 |
+
else:
|
| 1716 |
+
for group_idx, block_group in enumerate(self.transformer.block_groups):
|
| 1717 |
+
if output_hidden_states:
|
| 1718 |
+
# add hidden states
|
| 1719 |
+
all_hidden_states.append(x)
|
| 1720 |
+
|
| 1721 |
+
layers_past = (
|
| 1722 |
+
None
|
| 1723 |
+
if past_key_values is None
|
| 1724 |
+
else past_key_values[
|
| 1725 |
+
group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
|
| 1726 |
+
]
|
| 1727 |
+
)
|
| 1728 |
+
x, cache = block_group(
|
| 1729 |
+
x, input_ids=input_ids, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache,
|
| 1730 |
+
engram_hashes=engram_hashes
|
| 1731 |
+
)
|
| 1732 |
+
if attn_key_values is not None:
|
| 1733 |
+
assert cache is not None
|
| 1734 |
+
attn_key_values.extend(cache)
|
| 1735 |
+
|
| 1736 |
+
if last_logits_only:
|
| 1737 |
+
# shape: (batch_size, 1, d_model)
|
| 1738 |
+
x = x[:, -1, :].unsqueeze(1)
|
| 1739 |
+
|
| 1740 |
+
# Apply final layer norm.
|
| 1741 |
+
# shape: (batch_size, seq_len or 1, d_model)
|
| 1742 |
+
x = self.transformer.ln_f(x) # type: ignore
|
| 1743 |
+
if output_hidden_states:
|
| 1744 |
+
# add final hidden state post-final-layernorm, following HuggingFace's convention
|
| 1745 |
+
all_hidden_states.append(x)
|
| 1746 |
+
|
| 1747 |
+
# Get logits.
|
| 1748 |
+
# shape: (batch_size, seq_len or 1, vocab_size)
|
| 1749 |
+
if self.config.weight_tying:
|
| 1750 |
+
logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore
|
| 1751 |
+
else:
|
| 1752 |
+
logits = self.transformer.ff_out(x) # type: ignore
|
| 1753 |
+
if self.config.scale_logits:
|
| 1754 |
+
logits.mul_(1 / math.sqrt(self.config.d_model))
|
| 1755 |
+
|
| 1756 |
+
return LLaDAOutput(logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None) # type: ignore[arg-type]
|
| 1757 |
+
|
| 1758 |
+
|
| 1759 |
+
def create_model_config_from_pretrained_config(config: LLaDAConfig):
|
| 1760 |
+
"""
|
| 1761 |
+
Utility function
|
| 1762 |
+
"""
|
| 1763 |
+
|
| 1764 |
+
kwargs = {}
|
| 1765 |
+
for field in fields(ModelConfig):
|
| 1766 |
+
val = getattr(config, field.name, None)
|
| 1767 |
+
if field.name == "engram_config" and isinstance(val, dict):
|
| 1768 |
+
val = EngramConfig(**val)
|
| 1769 |
+
kwargs[field.name] = val
|
| 1770 |
+
|
| 1771 |
+
model_config = ModelConfig(**kwargs)
|
| 1772 |
+
return model_config
|
| 1773 |
+
|
| 1774 |
+
|
| 1775 |
+
class LLaDAModelLM(PreTrainedModel):
|
| 1776 |
+
"""
|
| 1777 |
+
Extremely barebones HF model wrapper.
|
| 1778 |
+
"""
|
| 1779 |
+
|
| 1780 |
+
config_class = LLaDAConfig
|
| 1781 |
+
base_model_prefix = "model"
|
| 1782 |
+
_no_split_modules = ["LLaDABlock", "LLaDASequentialBlock", "LLaDALlamaBlock"]
|
| 1783 |
+
|
| 1784 |
+
def __init__(self, config: LLaDAConfig, model: Optional[LLaDAModel] = None, init_params: bool = False):
|
| 1785 |
+
super().__init__(config)
|
| 1786 |
+
|
| 1787 |
+
if not model:
|
| 1788 |
+
model_config = create_model_config_from_pretrained_config(config)
|
| 1789 |
+
# Initialize model (always on CPU to start with so we don't run out of GPU memory).
|
| 1790 |
+
model_config.init_device = "cpu"
|
| 1791 |
+
self.model = LLaDAModel(model_config, init_params=init_params)
|
| 1792 |
+
else:
|
| 1793 |
+
self.model = model
|
| 1794 |
+
|
| 1795 |
+
def forward(
|
| 1796 |
+
self,
|
| 1797 |
+
input_ids: torch.LongTensor = None,
|
| 1798 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1799 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1800 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 1801 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 1802 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1803 |
+
use_cache: Optional[bool] = None,
|
| 1804 |
+
output_attentions: Optional[bool] = None,
|
| 1805 |
+
output_hidden_states: Optional[bool] = None,
|
| 1806 |
+
return_dict: Optional[bool] = None,
|
| 1807 |
+
cache_position: Optional[Cache] = None, # This is a hack mitigation of an issue in transformers `4.39.x`
|
| 1808 |
+
engram_hashes: Optional[Dict[int, torch.Tensor]] = None,
|
| 1809 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 1810 |
+
if use_cache is None:
|
| 1811 |
+
use_cache = self.config.use_cache
|
| 1812 |
+
|
| 1813 |
+
if output_attentions:
|
| 1814 |
+
raise ValueError("output_attentions is not yet supported in LLaDA")
|
| 1815 |
+
|
| 1816 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1817 |
+
|
| 1818 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 1819 |
+
outputs = self.model.forward(
|
| 1820 |
+
input_ids=input_ids,
|
| 1821 |
+
input_embeddings=inputs_embeds,
|
| 1822 |
+
attention_mask=attention_mask,
|
| 1823 |
+
attention_bias=attention_bias,
|
| 1824 |
+
past_key_values=past_key_values,
|
| 1825 |
+
use_cache=use_cache,
|
| 1826 |
+
output_hidden_states=output_hidden_states,
|
| 1827 |
+
engram_hashes=engram_hashes,
|
| 1828 |
+
)
|
| 1829 |
+
|
| 1830 |
+
logits = outputs.logits
|
| 1831 |
+
hidden_states = outputs.hidden_states
|
| 1832 |
+
|
| 1833 |
+
loss = None
|
| 1834 |
+
if labels is not None:
|
| 1835 |
+
import warnings
|
| 1836 |
+
warnings.warn("Note that for LLaDA, you cannot calculate the loss here.", UserWarning)
|
| 1837 |
+
if not return_dict:
|
| 1838 |
+
output = (logits,) + outputs[1:]
|
| 1839 |
+
return (loss,) + output if loss is not None else output
|
| 1840 |
+
|
| 1841 |
+
return CausalLMOutputWithPast(
|
| 1842 |
+
logits=logits,
|
| 1843 |
+
past_key_values=outputs.attn_key_values,
|
| 1844 |
+
hidden_states=hidden_states,
|
| 1845 |
+
)
|
| 1846 |
+
|
| 1847 |
+
def can_generate(self) -> bool:
|
| 1848 |
+
return True
|
| 1849 |
+
|
| 1850 |
+
def prepare_inputs_for_generation(
|
| 1851 |
+
self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
|
| 1852 |
+
):
|
| 1853 |
+
if past_key_values:
|
| 1854 |
+
# This is because we want the model to only process the last generated token.
|
| 1855 |
+
input_ids = input_ids[:, -1:]
|
| 1856 |
+
model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
|
| 1857 |
+
|
| 1858 |
+
model_inputs.update(kwargs)
|
| 1859 |
+
model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
|
| 1860 |
+
return model_inputs
|
| 1861 |
+
|
| 1862 |
+
# TODO: these are required to make the implementation complete.
|
| 1863 |
+
# def resize_position_embeddings(self, new_num_position_embeddings: int):
|
| 1864 |
+
# pass
|
| 1865 |
+
#
|
| 1866 |
+
# def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
|
| 1867 |
+
# pass
|
| 1868 |
+
#
|
| 1869 |
+
# def _reorder_cache(self, past_key_values, beam_idx):
|
| 1870 |
+
# pass
|
| 1871 |
+
|
| 1872 |
+
def get_input_embeddings(self) -> torch.nn.Module:
|
| 1873 |
+
return self.model.transformer.wte
|
| 1874 |
+
|
| 1875 |
+
def set_input_embeddings(self, value: torch.nn.Module):
|
| 1876 |
+
self.model.transformer.wte = value
|
| 1877 |
+
|
| 1878 |
+
def get_output_embeddings(self):
|
| 1879 |
+
if self.config.weight_tying:
|
| 1880 |
+
return self.model.transformer.wte
|
| 1881 |
+
else:
|
| 1882 |
+
return self.model.transformer.ff_out
|
| 1883 |
+
|
| 1884 |
+
def set_output_embeddings(self, value: torch.nn.Module):
|
| 1885 |
+
if self.config.weight_tying:
|
| 1886 |
+
self.model.transformer.wte = value
|
| 1887 |
+
else:
|
| 1888 |
+
self.model.transformer.ff_out = value
|
| 1889 |
+
|
| 1890 |
+
def tie_weights(self):
|
| 1891 |
+
if self.config.weight_tying:
|
| 1892 |
+
self.model.transformer.ff_out = self.model.transformer.wte
|
| 1893 |
+
|
| 1894 |
+
# Register the model so that it is available for transformer pipelines, auto-loading, etc.
|
| 1895 |
+
AutoModel.register(LLaDAConfig, LLaDAModelLM)
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"tokenizer_class": "PreTrainedTokenizerFast",
|
| 3 |
+
"tokenizer_file": "tokenizer.json",
|
| 4 |
+
"bos_token": null,
|
| 5 |
+
"eos_token": null,
|
| 6 |
+
"pad_token": 6629,
|
| 7 |
+
"unk_token": 6630
|
| 8 |
+
}
|