Add files using upload-large-folder tool
- .gitattributes +3 -35
- README.md +40 -0
- checkpoints/best_model.pt +3 -0
- configs/model.yaml +43 -0
- data/vocab/char_vocab.json +139 -0
- inference.py +67 -0
- llm/__init__.py +5 -0
- llm/__pycache__/__init__.cpython-312.pyc +0 -0
- llm/data/__init__.py +1 -0
- llm/data/__pycache__/__init__.cpython-312.pyc +0 -0
- llm/data/__pycache__/tokenizer.cpython-312.pyc +0 -0
- llm/data/collate.py +167 -0
- llm/data/dataset.py +164 -0
- llm/data/tokenizer.py +126 -0
- llm/inference/__init__.py +5 -0
- llm/inference/__pycache__/__init__.cpython-312.pyc +0 -0
- llm/inference/__pycache__/generate.cpython-312.pyc +0 -0
- llm/inference/generate.py +179 -0
- llm/model/__init__.py +1 -0
- llm/model/__pycache__/__init__.cpython-312.pyc +0 -0
- llm/model/__pycache__/attention.cpython-312.pyc +0 -0
- llm/model/__pycache__/block.cpython-312.pyc +0 -0
- llm/model/__pycache__/embedding.cpython-312.pyc +0 -0
- llm/model/__pycache__/ffn.cpython-312.pyc +0 -0
- llm/model/__pycache__/norm.cpython-312.pyc +0 -0
- llm/model/__pycache__/rope.cpython-312.pyc +0 -0
- llm/model/__pycache__/transformer.cpython-312.pyc +0 -0
- llm/model/attention.py +435 -0
- llm/model/block.py +163 -0
- llm/model/embedding.py +35 -0
- llm/model/ffn.py +139 -0
- llm/model/norm.py +132 -0
- llm/model/rope.py +162 -0
- llm/model/transformer.py +280 -0
- llm/training/__init__.py +15 -0
- llm/training/loss.py +91 -0
- llm/training/metrics.py +175 -0
- llm/training/optim.py +223 -0
- llm/training/trainer.py +294 -0
- llm/utils/__init__.py +9 -0
- llm/utils/__pycache__/__init__.cpython-312.pyc +0 -0
- llm/utils/__pycache__/checkpoint.cpython-312.pyc +0 -0
- llm/utils/__pycache__/init.cpython-312.pyc +0 -0
- llm/utils/checkpoint.py +39 -0
- llm/utils/config.py +25 -0
- llm/utils/init.py +213 -0
- llm/utils/logging.py +18 -0
- llm/utils/seed.py +14 -0
- requirements.txt +4 -0
.gitattributes
CHANGED
@@ -1,35 +1,3 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,40 @@
---
license: mit
language:
- en
library_name: pytorch
pipeline_tag: text-generation
tags:
- transformer
- character-level
- custom-code
---

# nextShakespeare

`nextShakespeare` is a decoder-only, character-level Transformer language model
trained on Tiny Shakespeare style text. This repo uses custom PyTorch code
rather than `transformers` native model classes.

## Model assets

- Weights: `checkpoints/best_model.pt`
- Config: `configs/model.yaml`
- Vocabulary: `data/vocab/char_vocab.json`
- Runtime code: `llm/`

## Quickstart

```bash
git clone https://huggingface.co/<your-username>/nextShakespeare
cd nextShakespeare
pip install -r requirements.txt
python inference.py --prompt "First Citizen:\n" --max_length 200 --temperature 0.8
```

## Notes

- This is a custom-code checkpoint. It is not directly loadable via
  `AutoModel.from_pretrained()` yet.
- For a web demo, see the companion Space:
  `https://huggingface.co/spaces/<your-username>/manshu-init`
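Because this is a custom-code checkpoint, programmatic loading goes through the repo's own modules rather than `transformers`. A minimal sketch, mirroring what `inference.py` in this commit does (the `Transformer`, `CharTokenizer`, and `load_model_only` signatures are taken from the files below):

```python
# Minimal loading sketch, assuming the repo layout and APIs used by inference.py.
from pathlib import Path

import torch
import yaml

from llm.data.tokenizer import CharTokenizer
from llm.model.transformer import Transformer
from llm.utils.checkpoint import load_model_only

config = yaml.safe_load(Path("configs/model.yaml").read_text(encoding="utf-8"))
tokenizer = CharTokenizer(vocab_path="data/vocab/char_vocab.json")
config["vocab_size"] = tokenizer.vocab_size  # 66 for the vocab shipped here

model = Transformer(config)
load_model_only(model, "checkpoints/best_model.pt")
model.eval()
```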
checkpoints/best_model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8bbbcb6c73219ad04334d069ca56b0362282ec43b14772e95ae77539b5589357
size 1248083791
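The file above is a Git LFS pointer, not the weights themselves; the actual ~1.2 GB file is fetched with `git lfs pull`. A small standard-library sketch (nothing repo-specific) for checking that a pulled file matches the pointer's `oid` and `size`:

```python
# Verify the pulled LFS object against the pointer's oid/size shown above.
import hashlib
from pathlib import Path

path = Path("checkpoints/best_model.pt")
print(path.stat().st_size)  # expect 1248083791 once LFS content is present

h = hashlib.sha256()
with path.open("rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        h.update(chunk)
print(h.hexdigest())  # expect 8bbbcb6c73219ad04334d069ca56b0362282ec43b14772e95ae77539b5589357
```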
configs/model.yaml
ADDED
@@ -0,0 +1,43 @@
# ============================================================
# Model Architecture
# ============================================================

num_hidden_layers: 12        # number of Transformer layers
hidden_size: 768             # hidden dimension
num_attention_heads: 12      # number of attention heads
num_key_value_heads: 4       # number of key/value heads (GQA)
intermediate_size: 3072      # FFN intermediate dimension

# ----------------------------
# Context / Position Encoding
# ----------------------------

max_position_embeddings: 2048  # maximum context length
rope_theta: 10000              # RoPE base frequency

# ----------------------------
# Attention Optimization
# ----------------------------

sliding_window: 1024           # sliding-window attention size
sliding_window_overlap: true   # whether windows may overlap
# Note: currently every layer uses the sliding window

# ----------------------------
# Normalization
# ----------------------------

rms_norm_eps: 1e-5             # RMSNorm numerical-stability epsilon

# ----------------------------
# Embedding
# ----------------------------

tie_word_embeddings: true      # whether to tie the input and output embeddings

# ----------------------------
# Initialization
# ----------------------------

init_weights: true             # whether to run weight initialization
init_std: 0.02                 # standard deviation for weight initialization
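For reference, the attention shapes implied by these values, computed with plain arithmetic; the split matches how the GQA module in `llm/model/attention.py` below sizes its projections:

```python
# Derived attention dimensions from configs/model.yaml (no repo code involved).
import yaml

with open("configs/model.yaml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

head_dim = cfg["hidden_size"] // cfg["num_attention_heads"]                       # 768 // 12 = 64
kv_width = cfg["num_key_value_heads"] * head_dim                                  # 4 * 64 = 256 (k_proj / v_proj output)
queries_per_kv_head = cfg["num_attention_heads"] // cfg["num_key_value_heads"]    # 12 // 4 = 3

print(head_dim, kv_width, queries_per_kv_head)  # 64 256 3
```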
data/vocab/char_vocab.json
ADDED
@@ -0,0 +1,139 @@
{
  "char_to_id": {
    "<unk>": 0,
    " ": 1,
    "e": 2,
    "t": 3,
    "o": 4,
    "a": 5,
    "h": 6,
    "s": 7,
    "r": 8,
    "n": 9,
    "i": 10,
    "\n": 11,
    "l": 12,
    "d": 13,
    "u": 14,
    "m": 15,
    "y": 16,
    ",": 17,
    "w": 18,
    "f": 19,
    "c": 20,
    "g": 21,
    "I": 22,
    "b": 23,
    "p": 24,
    ":": 25,
    ".": 26,
    "A": 27,
    "v": 28,
    "k": 29,
    "T": 30,
    "'": 31,
    "E": 32,
    "O": 33,
    "N": 34,
    "R": 35,
    "S": 36,
    "L": 37,
    "C": 38,
    ";": 39,
    "W": 40,
    "U": 41,
    "H": 42,
    "M": 43,
    "B": 44,
    "?": 45,
    "G": 46,
    "!": 47,
    "D": 48,
    "-": 49,
    "F": 50,
    "Y": 51,
    "P": 52,
    "K": 53,
    "V": 54,
    "j": 55,
    "q": 56,
    "x": 57,
    "z": 58,
    "J": 59,
    "Q": 60,
    "Z": 61,
    "X": 62,
    "3": 63,
    "&": 64,
    "$": 65
  },
  "id_to_char": {
    "0": "<unk>",
    "1": " ",
    "2": "e",
    "3": "t",
    "4": "o",
    "5": "a",
    "6": "h",
    "7": "s",
    "8": "r",
    "9": "n",
    "10": "i",
    "11": "\n",
    "12": "l",
    "13": "d",
    "14": "u",
    "15": "m",
    "16": "y",
    "17": ",",
    "18": "w",
    "19": "f",
    "20": "c",
    "21": "g",
    "22": "I",
    "23": "b",
    "24": "p",
    "25": ":",
    "26": ".",
    "27": "A",
    "28": "v",
    "29": "k",
    "30": "T",
    "31": "'",
    "32": "E",
    "33": "O",
    "34": "N",
    "35": "R",
    "36": "S",
    "37": "L",
    "38": "C",
    "39": ";",
    "40": "W",
    "41": "U",
    "42": "H",
    "43": "M",
    "44": "B",
    "45": "?",
    "46": "G",
    "47": "!",
    "48": "D",
    "49": "-",
    "50": "F",
    "51": "Y",
    "52": "P",
    "53": "K",
    "54": "V",
    "55": "j",
    "56": "q",
    "57": "x",
    "58": "z",
    "59": "J",
    "60": "Q",
    "61": "Z",
    "62": "X",
    "63": "3",
    "64": "&",
    "65": "$"
  },
  "vocab_size": 66
}
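The two maps are exact inverses, with IDs roughly ordered by character frequency in the training text. A quick round-trip using only the standard library (no repo code):

```python
# Map a prompt fragment to IDs and back using char_vocab.json directly.
import json

with open("data/vocab/char_vocab.json", encoding="utf-8") as f:
    vocab = json.load(f)

char_to_id = vocab["char_to_id"]
id_to_char = vocab["id_to_char"]

ids = [char_to_id.get(c, char_to_id["<unk>"]) for c in "First"]
print(ids)                                        # [50, 10, 8, 7, 3]
print("".join(id_to_char[str(i)] for i in ids))   # First
```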
inference.py
ADDED
@@ -0,0 +1,67 @@
import argparse
from pathlib import Path

import torch
import yaml

from llm.data.tokenizer import CharTokenizer
from llm.inference.generate import greedy_decode, sample_decode
from llm.model.transformer import Transformer
from llm.utils.checkpoint import load_model_only


def load_yaml(path: Path):
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, default="First Citizen:\\n")
    parser.add_argument("--max_length", type=int, default=200)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--checkpoint", type=str, default="checkpoints/best_model.pt")
    parser.add_argument("--config", type=str, default="configs/model.yaml")
    parser.add_argument("--vocab", type=str, default="data/vocab/char_vocab.json")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_config = load_yaml(Path(args.config))
    tokenizer = CharTokenizer(vocab_path=args.vocab)
    model_config["vocab_size"] = tokenizer.vocab_size

    model = Transformer(model_config)
    load_model_only(model, args.checkpoint)
    model.to(device)
    model.eval()

    input_ids = tokenizer.encode(args.prompt)
    if not input_ids:
        input_ids = [0]
    input_ids = torch.tensor([input_ids], dtype=torch.long)

    with torch.no_grad():
        if args.temperature == 0:
            generated_ids = greedy_decode(
                model, input_ids, max_length=args.max_length, device=device
            )
        else:
            generated_ids = sample_decode(
                model,
                input_ids,
                max_length=args.max_length,
                temperature=args.temperature,
                top_k=args.top_k if args.top_k > 0 else None,
                top_p=args.top_p if args.top_p > 0 else None,
                device=device,
            )

    text = tokenizer.decode(generated_ids[0])
    print(text)


if __name__ == "__main__":
    main()
llm/__init__.py
ADDED
@@ -0,0 +1,5 @@
"""
LLM from Manshu - a large language model implemented from scratch
"""

__version__ = "0.1.0"
llm/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (278 Bytes)
llm/data/__init__.py
ADDED
@@ -0,0 +1 @@
"""Data processing module"""
llm/data/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (218 Bytes)
llm/data/__pycache__/tokenizer.cpython-312.pyc
ADDED
Binary file (5.23 kB)
llm/data/collate.py
ADDED
@@ -0,0 +1,167 @@
"""
Collation utilities: padding / batching
Combine multiple samples into a batch
Handle sequences of different lengths (padding)
Convert to the tensor format the model expects
"""


# 2026-01-23

import torch


def collate_fn(batch, pad_token_id=0):
    """
    Collate a batch (supports padding; stacks directly when all samples share the same length)

    Args:
        batch: list of samples, each element is (input_ids, target_ids)
            - input_ids: input sequence of shape (seq_len,)
            - target_ids: target sequence of shape (seq_len,)
        pad_token_id: padding token ID (default: 0)

    Returns:
        (input_ids_batch, target_ids_batch)
            - input_ids_batch: batched input sequences of shape (batch_size, max_seq_len)
            - target_ids_batch: batched target sequences of shape (batch_size, max_seq_len)
    """
    # Separate input_ids and target_ids
    input_ids_list = [item[0] for item in batch]
    target_ids_list = [item[1] for item in batch]

    # Check whether all samples have the same length
    input_lengths = [len(ids) for ids in input_ids_list]
    target_lengths = [len(ids) for ids in target_ids_list]

    all_same_length = (
        len(set(input_lengths)) == 1 and
        len(set(target_lengths)) == 1 and
        input_lengths[0] == target_lengths[0]
    )

    if all_same_length:
        # All samples have the same length: stack directly (efficient, no padding needed)
        input_ids_batch = torch.stack(input_ids_list, dim=0)  # (batch_size, seq_len)
        target_ids_batch = torch.stack(target_ids_list, dim=0)  # (batch_size, seq_len)
    else:
        # Samples differ in length: padding is required
        max_seq_len = max(max(input_lengths), max(target_lengths))

        # Pad input_ids
        padded_input_ids = []
        for ids in input_ids_list:
            pad_length = max_seq_len - len(ids)
            if pad_length > 0:
                padded = torch.cat([ids, torch.full((pad_length,), pad_token_id, dtype=ids.dtype)])
            else:
                padded = ids
            padded_input_ids.append(padded)

        # Pad target_ids
        padded_target_ids = []
        for ids in target_ids_list:
            pad_length = max_seq_len - len(ids)
            if pad_length > 0:
                padded = torch.cat([ids, torch.full((pad_length,), pad_token_id, dtype=ids.dtype)])
            else:
                padded = ids
            padded_target_ids.append(padded)

        # Stack
        input_ids_batch = torch.stack(padded_input_ids, dim=0)  # (batch_size, max_seq_len)
        target_ids_batch = torch.stack(padded_target_ids, dim=0)  # (batch_size, max_seq_len)

    return input_ids_batch, target_ids_batch


if __name__ == "__main__":
    print("=" * 60)
    print("Collate function test")
    print("=" * 60)

    # Simulated batch data
    batch_size = 4
    seq_len = 10

    print("\n1. Create simulated batch data")
    print(f"   Batch size: {batch_size}")
    print(f"   Sequence length: {seq_len}")

    # Create simulated data
    batch = []
    for i in range(batch_size):
        input_ids = torch.randint(0, 100, (seq_len,))
        target_ids = torch.randint(0, 100, (seq_len,))
        batch.append((input_ids, target_ids))
        print(f"   Sample {i}: input_ids shape={input_ids.shape}, target_ids shape={target_ids.shape}")

    # Test collate_fn
    print("\n2. Test collate_fn")
    input_ids_batch, target_ids_batch = collate_fn(batch)

    print(f"   Input batch shape: {input_ids_batch.shape}")
    print(f"   Target batch shape: {target_ids_batch.shape}")
    print(f"   Expected shape: ({batch_size}, {seq_len})")

    # Verify shapes
    assert input_ids_batch.shape == (batch_size, seq_len), \
        f"Wrong input batch shape: {input_ids_batch.shape} != ({batch_size}, {seq_len})"
    assert target_ids_batch.shape == (batch_size, seq_len), \
        f"Wrong target batch shape: {target_ids_batch.shape} != ({batch_size}, {seq_len})"

    print("   Shape check passed")

    # Verify the data was stacked correctly
    print("\n3. Verify stacking")
    for i in range(batch_size):
        input_match = torch.equal(input_ids_batch[i], batch[i][0])
        target_match = torch.equal(target_ids_batch[i], batch[i][1])
        print(f"   Sample {i}: input_ids match={input_match}, target_ids match={target_match}")
        assert input_match and target_match, f"Sample {i} data mismatch"

    print("   Data check passed")

    # Test different sequence lengths (padding required)
    print("\n4. Test different sequence lengths (padding required)")
    batch_variable = [
        (torch.randint(0, 100, (5,)), torch.randint(0, 100, (5,))),
        (torch.randint(0, 100, (8,)), torch.randint(0, 100, (8,))),
        (torch.randint(0, 100, (10,)), torch.randint(0, 100, (10,))),
    ]

    print("   Sample lengths: [5, 8, 10]")
    input_batch_var, target_batch_var = collate_fn(batch_variable, pad_token_id=0)

    print(f"   Input batch shape: {input_batch_var.shape}")
    print(f"   Target batch shape: {target_batch_var.shape}")
    print(f"   Expected shape: (3, 10)")

    assert input_batch_var.shape == (3, 10), \
        f"Wrong input batch shape: {input_batch_var.shape} != (3, 10)"
    assert target_batch_var.shape == (3, 10), \
        f"Wrong target batch shape: {target_batch_var.shape} != (3, 10)"

    # Verify the padding
    print("\n5. Verify padding")
    for i, (orig_input, orig_target) in enumerate(batch_variable):
        orig_len = len(orig_input)
        # Check that the original data survived intact
        assert torch.equal(input_batch_var[i, :orig_len], orig_input), \
            f"Sample {i} input_ids data mismatch"
        assert torch.equal(target_batch_var[i, :orig_len], orig_target), \
            f"Sample {i} target_ids data mismatch"
        # Check the padding itself (should all be pad_token_id)
        if orig_len < 10:
            assert torch.all(input_batch_var[i, orig_len:] == 0), \
                f"Sample {i} input_ids padding is wrong"
            assert torch.all(target_batch_var[i, orig_len:] == 0), \
                f"Sample {i} target_ids padding is wrong"
        print(f"   Sample {i}: length={orig_len}, padding check passed")

    print("   ✓ Padding check passed")

    print("\n" + "=" * 60)
    print("All tests finished!")
    print("=" * 60)
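A sketch of how `collate_fn` would typically be plugged into a PyTorch `DataLoader`. The toy in-memory dataset is hypothetical, and `functools.partial` is just one way of fixing `pad_token_id`:

```python
# Wiring collate_fn into a DataLoader (toy data, not the repo's CharDataset).
from functools import partial

import torch
from torch.utils.data import DataLoader

from llm.data.collate import collate_fn

# Hypothetical dataset: a list of (input_ids, target_ids) pairs of equal length.
pairs = [(torch.randint(0, 66, (16,)), torch.randint(0, 66, (16,))) for _ in range(8)]

loader = DataLoader(
    pairs,
    batch_size=4,
    shuffle=True,
    collate_fn=partial(collate_fn, pad_token_id=0),
)

for input_ids, target_ids in loader:
    print(input_ids.shape, target_ids.shape)  # torch.Size([4, 16]) torch.Size([4, 16])
    break
```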
llm/data/dataset.py
ADDED
@@ -0,0 +1,164 @@
"""Dataset: CharDataset"""
# 2026-01-23

import numpy as np
import torch
from pathlib import Path


class CharDataset:
    """Character-level dataset"""

    def __init__(self, data_path, tokenizer, context_window=128):
        """
        Initialize the dataset

        Args:
            data_path: path to the data file (.npy format, containing token IDs)
            tokenizer: tokenizer (used for vocab_size etc.; the data itself is already preprocessed)
            context_window: block size (sequence length) used during training
        """
        self.tokenizer = tokenizer
        self.context_window = context_window

        # Load the data
        data_path = Path(data_path)
        if not data_path.exists():
            raise FileNotFoundError(f"Data file does not exist: {data_path}")

        # Load the numpy array (1-D array containing all token IDs)
        self.data = np.load(str(data_path))

        # Make sure the data is a 1-D array
        if self.data.ndim > 1:
            self.data = self.data.flatten()

        print(f"Dataset loaded:")
        print(f"  Data file: {data_path}")
        print(f"  Data length: {len(self.data):,} tokens")
        print(f"  Block size: {context_window}")
        print(f"  Available samples: {len(self)}")

    def __len__(self):
        """
        Return the dataset size

        Returns:
            number of available samples (number of sliding windows)
        """
        # Each sample needs block_size + 1 tokens
        # The last sample starts at len(data) - block_size - 1,
        # so the total number of samples is len(data) - block_size
        return len(self.data) - self.context_window

    def __getitem__(self, idx):
        """
        Fetch a single sample

        Args:
            idx: sample index

        Returns:
            input_ids: input sequence of shape (block_size,)
            target_ids: target sequence (input shifted right by one), shape (block_size,)
        """
        # Bounds check
        if idx < 0 or idx >= len(self):
            raise IndexError(f"Index {idx} out of range [0, {len(self)})")

        # Take a sequence of length context_window + 1
        # e.g. context_window=512 -> take 513 tokens
        chunk = self.data[idx:idx + self.context_window + 1]

        # The first context_window tokens are the input, the last context_window the target (shifted by one)
        # e.g. [0, 1, 2, ..., 255] -> [1, 2, 3, ..., 256]
        input_ids = torch.tensor(chunk[:-1], dtype=torch.long)
        target_ids = torch.tensor(chunk[1:], dtype=torch.long)

        return input_ids, target_ids


if __name__ == "__main__":
    # Add the project root to the Python path
    from pathlib import Path
    import sys

    project_root = Path(__file__).parent.parent.parent
    sys.path.insert(0, str(project_root))

    from llm.data.tokenizer import CharTokenizer

    print("=" * 60)
    print("CharDataset test")
    print("=" * 60)

    # Test parameters
    data_path = Path("data/processed/train.npy")
    vocab_path = Path("data/vocab/char_vocab.json")
    block_size = 256

    # Check that the files exist
    if not data_path.exists():
        print(f"Error: data file does not exist: {data_path}")
        print("Run first: python scripts/preprocess.py")
        sys.exit(1)

    if not vocab_path.exists():
        print(f"Error: vocabulary file does not exist: {vocab_path}")
        print("Run first: python scripts/preprocess.py")
        sys.exit(1)

    # Load the tokenizer
    print("\n1. Load the tokenizer")
    tokenizer = CharTokenizer(vocab_path=str(vocab_path))
    print(f"   Vocabulary size: {tokenizer.vocab_size}")

    # Create the dataset
    print("\n2. Create the dataset")
    dataset = CharDataset(
        data_path=str(data_path),
        tokenizer=tokenizer,
        context_window=block_size  # __init__ takes context_window, not block_size
    )

    # Dataset size
    print(f"\n3. Dataset info")
    print(f"   Dataset size: {len(dataset):,} samples")
    print(f"   Input length per sample: {block_size}")
    print(f"   Target length per sample: {block_size}")

    # Fetch a sample
    print("\n4. Fetch a sample")
    input_ids, target_ids = dataset[0]
    print(f"   Sample 0:")
    print(f"   Input shape: {input_ids.shape}")
    print(f"   Target shape: {target_ids.shape}")
    print(f"   First 10 input token IDs: {input_ids[:10].tolist()}")
    print(f"   First 10 target token IDs: {target_ids[:10].tolist()}")

    # Verify the target (should be the input shifted right by one)
    print(f"\n5. Verify the target sequence")
    expected_target = input_ids[1:].tolist()
    actual_target = target_ids[:-1].tolist()
    is_correct = expected_target == actual_target
    print(f"   Target sequence correct (shifted by one): {is_correct}")
    if not is_correct:
        print(f"   Expected: {expected_target[:10]}")
        print(f"   Actual:   {actual_target[:10]}")

    # Test decoding
    print(f"\n6. Test decoding")
    input_text = tokenizer.decode(input_ids[:50])
    target_text = tokenizer.decode(target_ids[:50])
    print(f"   Input text (first 50 chars): {repr(input_text)}")
    print(f"   Target text (first 50 chars): {repr(target_text)}")

    # Test multiple samples
    print(f"\n7. Test multiple samples")
    for i in [0, 100, len(dataset) - 1]:
        input_ids, target_ids = dataset[i]
        print(f"   Sample {i}: input shape {input_ids.shape}, target shape {target_ids.shape}")

    print("\n" + "=" * 60)
    print("All tests finished!")
    print("=" * 60)
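A self-contained check of the shift-by-one behaviour using a small synthetic `.npy` file. The `tiny.npy` name is illustrative, and the vocab file from this commit is only needed because `CharDataset` takes a tokenizer argument:

```python
# Toy check of CharDataset's sliding windows and shifted targets.
import numpy as np
import torch

from llm.data.dataset import CharDataset
from llm.data.tokenizer import CharTokenizer

np.save("tiny.npy", np.arange(20, dtype=np.int64))  # pretend token IDs 0..19

tok = CharTokenizer(vocab_path="data/vocab/char_vocab.json")
ds = CharDataset(data_path="tiny.npy", tokenizer=tok, context_window=8)

x, y = ds[0]
print(x.tolist())                   # [0, 1, 2, 3, 4, 5, 6, 7]
print(y.tolist())                   # [1, 2, 3, 4, 5, 6, 7, 8]
assert torch.equal(x[1:], y[:-1])   # targets are the inputs shifted right by one
print(len(ds))                      # 20 - 8 = 12 windows
```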
llm/data/tokenizer.py
ADDED
@@ -0,0 +1,126 @@
"""Tokenizer: CharTokenizer, splits text into token IDs"""
# 2026-01-22

import json
from pathlib import Path
from collections import Counter


class CharTokenizer:
    """Character-level tokenizer (for English text)"""

    def __init__(self, vocab_path=None):
        """
        Initialize the tokenizer

        Args:
            vocab_path: path to the vocabulary file (JSON format)
        """
        self.vocab_path = vocab_path
        self.char_to_id = {}  # char -> ID mapping
        self.id_to_char = {}  # ID -> char mapping
        self.vocab_size = 0

        if vocab_path and Path(vocab_path).exists():
            self.load_vocab(vocab_path)

    def build_vocab(self, texts):
        """
        Build the vocabulary from text

        Args:
            texts: a list of texts or a single text string
        """
        # Normalize: if a single str is given, wrap it in a list
        if isinstance(texts, str):
            texts = [texts]

        # Concatenate all texts
        all_chars = ''.join(texts)

        # Count characters (Python iterates per character; count each character's occurrences)
        char_counts = Counter(all_chars)

        # Build the char -> ID mapping
        self.char_to_id = {
            '<unk>': 0,  # unknown character
            # '<pad>': 1,  # padding; sliding windows make '<pad>' unnecessary
            # '<bos>': 1,  # beginning-of-sequence
            # '<eos>': 2,  # end-of-sequence

        }

        # Add characters sorted by frequency
        sorted_chars = sorted(char_counts.items(), key=lambda x: x[1], reverse=True)
        for char, count in sorted_chars:
            if char not in self.char_to_id:
                self.char_to_id[char] = len(self.char_to_id)

        # Build the reverse mapping
        self.id_to_char = {id: char for char, id in self.char_to_id.items()}
        self.vocab_size = len(self.char_to_id)

    def encode(self, text):
        """
        Encode: split text into characters, then map them to IDs

        Args:
            text: input text string

        Returns:
            token_ids: list of token IDs
        """
        token_ids = []
        # Iterate over the string; each character is handled separately
        for char in text:
            # Look up the character's ID; fall back to <unk> when missing
            char_id = self.char_to_id.get(char, self.char_to_id.get('<unk>', 0))
            token_ids.append(char_id)
        return token_ids

    def decode(self, token_ids):
        """
        Decode: convert a list of IDs back into text

        Args:
            token_ids: list or tensor of token IDs

        Returns:
            text: decoded text string
        """
        # If it is a PyTorch tensor, convert it to a list
        if hasattr(token_ids, 'tolist'):
            token_ids = token_ids.tolist()

        # Convert each ID back to a character
        chars = []
        for id in token_ids:
            char = self.id_to_char.get(id, '<unk>')
            # Filter special tokens
            if char not in ['<unk>']:
                chars.append(char)

        # Join into text
        return ''.join(chars)

    def save_vocab(self, vocab_path):
        """Save the vocabulary to a file"""
        vocab_data = {
            'char_to_id': self.char_to_id,
            'id_to_char': {str(k): v for k, v in self.id_to_char.items()},
            'vocab_size': self.vocab_size
        }

        Path(vocab_path).parent.mkdir(parents=True, exist_ok=True)

        with open(vocab_path, 'w', encoding='utf-8') as f:
            json.dump(vocab_data, f, ensure_ascii=False, indent=2)

    def load_vocab(self, vocab_path):
        """Load the vocabulary from a file"""
        with open(vocab_path, 'r', encoding='utf-8') as f:
            vocab_data = json.load(f)

        self.char_to_id = vocab_data['char_to_id']
        self.id_to_char = {int(k): v for k, v in vocab_data['id_to_char'].items()}
        self.vocab_size = vocab_data['vocab_size']
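A small usage sketch of the tokenizer on its own: build a vocabulary from raw text, then round-trip a string. Unseen characters map to `<unk>`, which `decode` drops:

```python
# Build a vocabulary from a toy string and round-trip it.
from llm.data.tokenizer import CharTokenizer

tok = CharTokenizer()
tok.build_vocab("to be, or not to be")
print(tok.vocab_size)                      # 9: eight distinct characters plus <unk>

ids = tok.encode("to be")
print(ids)                                 # IDs ordered by character frequency in the build text
print(tok.decode(ids))                     # "to be"
print(repr(tok.decode(tok.encode("z"))))   # '' - "z" was never seen, so it becomes <unk> and is dropped
```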
llm/inference/__init__.py
ADDED
@@ -0,0 +1,5 @@
"""Inference module: text generation"""

from llm.inference.generate import greedy_decode, sample_decode

__all__ = ['greedy_decode', 'sample_decode']
llm/inference/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (359 Bytes)
llm/inference/__pycache__/generate.cpython-312.pyc
ADDED
Binary file (5.8 kB)
llm/inference/generate.py
ADDED
@@ -0,0 +1,179 @@
"""Text generation: greedy / sampling"""
# 2026-01-23

import torch
import torch.nn.functional as F
from llm.model.attention import create_causal_mask


def greedy_decode(model, input_ids, max_length=100, device="cpu", stop_token_ids=None):
    """
    Greedy decoding: pick the highest-probability token at each step

    Args:
        model: Transformer model
        input_ids: input token IDs of shape (batch_size, seq_len)
        max_length: maximum number of generated tokens (excluding the input length)
        device: device
        stop_token_ids: list of stop token IDs; generation stops early when one is produced (optional)

    Returns:
        generated_ids: generated token IDs of shape (batch_size, total_length)
    """
    model.eval()
    input_ids = input_ids.to(device)
    generated_ids = input_ids.clone()

    if stop_token_ids is None:
        stop_token_ids = []
    elif isinstance(stop_token_ids, int):
        stop_token_ids = [stop_token_ids]

    batch_size = generated_ids.size(0)
    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

    with torch.no_grad():
        for step in range(max_length):
            # Exit early once every sequence has finished
            if finished.all():
                break

            # Current sequence length
            seq_len = generated_ids.size(1)

            # Build the causal mask
            causal_mask = create_causal_mask(seq_len, device=device)

            # Forward pass
            logits = model(generated_ids, mask=causal_mask)  # (batch_size, seq_len, vocab_size)

            # Logits at the last position
            next_token_logits = logits[:, -1, :]  # (batch_size, vocab_size)

            # Pick the highest-probability token (greedy)
            next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)  # (batch_size, 1)

            # Check for stop tokens
            if stop_token_ids:
                for stop_id in stop_token_ids:
                    finished = finished | (next_token_id.squeeze(-1) == stop_id)

            # Append the generated token
            generated_ids = torch.cat([generated_ids, next_token_id], dim=1)

    return generated_ids


def sample_decode(
    model,
    input_ids,
    max_length=100,
    temperature=1.0,
    top_k=0,
    top_p=0.0,
    device="cpu",
    stop_token_ids=None
):
    """
    Sampling decode: temperature, top-k and top-p (nucleus) sampling

    Args:
        model: Transformer model
        input_ids: input token IDs of shape (batch_size, seq_len)
        max_length: maximum number of generated tokens (excluding the input length)
        temperature: sampling temperature (higher = more random, lower = more deterministic, 0 = greedy)
        top_k: top-k sampling (sample only from the k most likely tokens, 0 = disabled)
        top_p: top-p (nucleus) sampling (keep tokens up to cumulative probability p, 0.0 = disabled)
        device: device
        stop_token_ids: list of stop token IDs; generation stops early when one is produced (optional)

    Returns:
        generated_ids: generated token IDs of shape (batch_size, total_length)
    """
    model.eval()
    input_ids = input_ids.to(device)
    generated_ids = input_ids.clone()

    if stop_token_ids is None:
        stop_token_ids = []
    elif isinstance(stop_token_ids, int):
        stop_token_ids = [stop_token_ids]

    batch_size = generated_ids.size(0)
    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

    # Temperature 0 means greedy decoding
    if temperature == 0:
        return greedy_decode(model, input_ids, max_length, device, stop_token_ids)

    with torch.no_grad():
        for step in range(max_length):
            # Exit early once every sequence has finished
            if finished.all():
                break

            # Current sequence length
            seq_len = generated_ids.size(1)

            # Build the causal mask
            causal_mask = create_causal_mask(seq_len, device=device)

            # Forward pass
            logits = model(generated_ids, mask=causal_mask)  # (batch_size, seq_len, vocab_size)

            # Logits at the last position
            next_token_logits = logits[:, -1, :]  # (batch_size, vocab_size)

            # Apply temperature
            if temperature != 1.0:
                next_token_logits = next_token_logits / temperature

            # Top-K sampling
            if top_k is not None and top_k > 0:
                # Take the top-k values and their indices
                top_k_logits, top_k_indices = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)), dim=-1)

                # Build a mask: set every non-top-k position to -inf
                mask = torch.full_like(next_token_logits, float('-inf'))
                mask.scatter_(-1, top_k_indices, top_k_logits)
                next_token_logits = mask

            # Top-P (nucleus) sampling
            if top_p is not None and top_p > 0.0:
                # First compute the probability distribution
                probs = F.softmax(next_token_logits, dim=-1)

                # Sort probabilities in descending order
                sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)

                # Cumulative probabilities
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

                # Find the first position where the cumulative probability exceeds top_p
                # Keep that position and everything before it (always keep at least the first token)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token is always kept
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = False

                # Map the sorted mask back to the original indices
                indices_to_remove = sorted_indices_to_remove.scatter(
                    -1, sorted_indices, sorted_indices_to_remove
                )
                next_token_logits[indices_to_remove] = float('-inf')

            # Softmax to get the probability distribution
            probs = F.softmax(next_token_logits, dim=-1)

            # Sample from the distribution
            next_token_id = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)

            # Check for stop tokens
            if stop_token_ids:
                for stop_id in stop_token_ids:
                    finished = finished | (next_token_id.squeeze(-1) == stop_id)

            # Append the generated token
            generated_ids = torch.cat([generated_ids, next_token_id], dim=1)

    return generated_ids
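Both decoders only assume that `model(input_ids, mask=...)` returns `(batch, seq_len, vocab_size)` logits, so they can be exercised with a stand-in module. `ToyLM` below is a hypothetical stand-in, not the repo's `Transformer`:

```python
# Drive greedy_decode with a toy model to show the expected forward interface.
import torch
import torch.nn as nn

from llm.inference.generate import greedy_decode


class ToyLM(nn.Module):
    def __init__(self, vocab_size=66, hidden=32):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, input_ids, mask=None):  # mask is accepted but unused here
        return self.head(self.emb(input_ids))  # (batch, seq_len, vocab_size)


model = ToyLM()
prompt = torch.tensor([[1, 2, 3]])                            # (batch=1, seq_len=3)
out = greedy_decode(model, prompt, max_length=5, device="cpu")
print(out.shape)                                              # torch.Size([1, 8]) - prompt plus 5 new tokens
```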
llm/model/__init__.py
ADDED
@@ -0,0 +1 @@
"""Model component module"""
llm/model/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (219 Bytes)
llm/model/__pycache__/attention.cpython-312.pyc
ADDED
Binary file (16.3 kB)
llm/model/__pycache__/block.cpython-312.pyc
ADDED
Binary file (6.39 kB)
llm/model/__pycache__/embedding.cpython-312.pyc
ADDED
Binary file (1.45 kB)
llm/model/__pycache__/ffn.cpython-312.pyc
ADDED
Binary file (5.47 kB)
llm/model/__pycache__/norm.cpython-312.pyc
ADDED
Binary file (7.19 kB)
llm/model/__pycache__/rope.cpython-312.pyc
ADDED
Binary file (5.92 kB)
llm/model/__pycache__/transformer.cpython-312.pyc
ADDED
Binary file (10.9 kB)
llm/model/attention.py
ADDED
@@ -0,0 +1,435 @@
"""Attention mechanisms: MHA / GQA / sliding-window attention"""
# 2026-01-23

import torch
import torch.nn as nn
import math


def scaled_dot_product_attention(
    q, k, v, mask=None, sliding_window=None, sliding_window_overlap=True
):
    """
    Scaled dot-product attention (base function)

    Args:
        q: Query of shape (batch_size, num_heads, seq_len, head_dim)
        k: Key of shape (batch_size, num_heads, seq_len, head_dim)
        v: Value of shape (batch_size, num_heads, seq_len, head_dim)
        mask: attention mask of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len);
            True means the position may be attended to, False means it may not
        sliding_window: sliding window size (optional; applies sliding-window attention when given)
        sliding_window_overlap: whether sliding windows may overlap

    Returns:
        output: attention output of shape (batch_size, num_heads, seq_len, head_dim)
        attn_weights: attention weights of shape (batch_size, num_heads, seq_len, seq_len)
    """
    batch_size, num_heads, seq_len, head_dim = q.shape

    # Attention scores: Q @ K^T
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)

    # Apply the sliding-window mask first (if any)
    # Note: the sliding-window mask already includes the causal mask, so an extra mask is usually unnecessary
    if sliding_window is not None:
        scores = apply_sliding_window_mask(
            scores, sliding_window, seq_len, sliding_window_overlap
        )

    # Apply the regular mask: positions where mask is False are set to -inf
    # Common shapes: (L, L) (shared by all batches/heads, e.g. causal mask); (B, L, L) (per-batch, e.g. padding mask)
    if mask is not None:
        # If the mask is 2-D, expand it to 4-D to match the shape of scores
        if mask.dim() == 2:
            mask = mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, seq_len)
        elif mask.dim() == 3:
            mask = mask.unsqueeze(1)  # (batch_size, 1, seq_len, seq_len)
        # Positions where mask is False become -inf and turn into 0 after softmax
        scores = scores.masked_fill(~mask, float("-inf"))

    # Softmax: convert to a probability distribution
    attn_weights = torch.softmax(scores, dim=-1)

    # Weighted sum: attn_weights @ V
    output = torch.matmul(attn_weights, v)

    return output, attn_weights


def create_causal_mask(seq_len, device="cpu"):
    """
    Create a causal mask (lower-triangular matrix)

    Args:
        seq_len: sequence length
        device: device

    Returns:
        causal mask of shape (seq_len, seq_len), True on the lower triangle
    """
    mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
    return mask.bool()


def apply_sliding_window_mask(scores, window_size, seq_len, overlap=True):
    """
    Apply the sliding-window mask

    Args:
        scores: attention scores of shape (batch_size, num_heads, seq_len, seq_len)
        window_size: window size
        seq_len: sequence length
        overlap: whether windows may overlap (True uses a symmetric window, False an asymmetric one)

    Returns:
        scores with the mask applied
    """
    batch_size, num_heads, seq_len, _ = scores.shape

    # Build the sliding-window mask
    # The range of positions that position i may attend to
    window_mask = torch.zeros(seq_len, seq_len, device=scores.device, dtype=torch.bool)

    if overlap:
        # Symmetric window: position i may attend to [max(0, i-window_size//2), min(seq_len, i+window_size//2+1)]
        for i in range(seq_len):
            start = max(0, i - window_size // 2)
            end = min(seq_len, i + window_size // 2 + 1)
            window_mask[i, start:end] = True
    else:
        # Asymmetric window: position i may attend to [max(0, i-window_size+1), i+1]
        for i in range(seq_len):
            start = max(0, i - window_size + 1)
            end = i + 1
            window_mask[i, start:end] = True

    # Combine with the causal mask (lower triangle): both causality and the window limit must hold
    causal_mask = torch.tril(
        torch.ones(seq_len, seq_len, device=scores.device, dtype=torch.bool)
    )
    combined_mask = causal_mask & window_mask

    # Expand dimensions to match the shape of scores
    combined_mask = combined_mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, seq_len)

    # Apply the mask: positions that may not be attended to become -inf
    scores = scores.masked_fill(~combined_mask, float("-inf"))

    return scores


class MultiHeadAttention(nn.Module):
    """
    Multi-head attention (MHA)

    Every Query, Key and Value head is independent
    """

    def __init__(self, hidden_size, num_heads):
        """
        Initialize multi-head attention

        Args:
            hidden_size: hidden dimension
            num_heads: number of attention heads
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        # hidden_size must be divisible by num_heads
        assert hidden_size % num_heads == 0, (
            f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
        )

        # Query, Key and Value projections, one linear layer each
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)

        # Output projection, a single linear layer
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(
        self, x, mask=None, rope=None, sliding_window=None, sliding_window_overlap=True
    ):
        """
        Forward pass

        Args:
            x: input tensor of shape (batch_size, seq_len, hidden_size)
            mask: attention mask of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len)
            rope: RoPE positional-encoding module (optional)
            sliding_window: sliding window size (optional; applies sliding-window attention when given)
            sliding_window_overlap: whether sliding windows may overlap

        Returns:
            output tensor of shape (batch_size, seq_len, hidden_size)
        """
        batch_size, seq_len, hidden_size = x.shape

        # 1. Project to Q, K, V
        q = self.q_proj(x)  # (batch_size, seq_len, hidden_size)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # 2. Reshape into multi-head form (split into heads)
        # (batch_size, seq_len, hidden_size) -> (batch_size, seq_len, num_heads, head_dim)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim)

        # 3. Transpose for the attention computation
        # (batch_size, num_heads, seq_len, head_dim)
        # (B, L, H, D_h) -> (B, H, L, D_h)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # 4. Apply RoPE positional encoding (if any)
        if rope is not None:
            q, k = rope(q, k)

        # 5. Compute attention
        # attn_output: (B, H, L, D_h), for head h, token i's vector blended from all tokens' Values weighted by attention
        # attn_weights: (B, H, L, L), attention weights of token i over all tokens for batch=b, head=h
        attn_output, attn_weights = scaled_dot_product_attention(
            q, k, v, mask, sliding_window, sliding_window_overlap
        )

        # 6. Transpose back and merge heads
        # (batch_size, num_heads, seq_len, head_dim) -> (batch_size, seq_len, num_heads, head_dim)
        attn_output = attn_output.transpose(1, 2)

        # 7. Reshape to the original shape
        # (batch_size, seq_len, num_heads, head_dim) -> (batch_size, seq_len, hidden_size)
        attn_output = attn_output.contiguous().view(batch_size, seq_len, hidden_size)

        # 8. Output projection
        output = self.o_proj(attn_output)

        return output


class GroupedQueryAttention(nn.Module):
    """
    Grouped-query attention (GQA)

    Several Query heads share one group of Key-Value heads, reducing KV-cache memory
    """

    def __init__(self, hidden_size, num_heads, num_kv_heads):
        """
        Initialize GQA

        Args:
            hidden_size: hidden dimension
            num_heads: number of Query heads
            num_kv_heads: number of Key-Value heads
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = hidden_size // num_heads

        # hidden_size must be divisible by num_heads
        assert hidden_size % num_heads == 0, (
            f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
        )
        # num_heads must be divisible by num_kv_heads
        assert num_heads % num_kv_heads == 0, (
            f"num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
        )

        # Query projection: every head has its own Q
        self.q_proj = nn.Linear(hidden_size, hidden_size)

        # Key and Value projections: several Query heads share the KV heads
        # D -> (num_kv_heads × head_dim)
        # later (B, L, 256) -> (B, L, 4, 64)
        kv_hidden_size = num_kv_heads * self.head_dim
        self.k_proj = nn.Linear(hidden_size, kv_hidden_size)
        self.v_proj = nn.Linear(hidden_size, kv_hidden_size)

        # Output projection
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(
        self, x, mask=None, rope=None, sliding_window=None, sliding_window_overlap=True
    ):
        """
        Forward pass

        Args:
            x: input tensor of shape (batch_size, seq_len, hidden_size)
            mask: attention mask of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len)
            rope: RoPE positional-encoding module (optional)
            sliding_window: sliding window size (optional; applies sliding-window attention when given)
            sliding_window_overlap: whether sliding windows may overlap

        Returns:
            output tensor of shape (batch_size, seq_len, hidden_size)
        """
        batch_size, seq_len, hidden_size = x.shape

        # 1. Projections
        q = self.q_proj(x)  # (batch_size, seq_len, hidden_size)
        k = self.k_proj(x)  # (batch_size, seq_len, num_kv_heads * head_dim)
        v = self.v_proj(x)  # (batch_size, seq_len, num_kv_heads * head_dim)

        # 2. Reshape Q
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
        q = q.transpose(1, 2)  # (batch_size, num_heads, seq_len, head_dim)

        # 3. Reshape K and V
        k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        k = k.transpose(1, 2)  # (batch_size, num_kv_heads, seq_len, head_dim)

        v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        v = v.transpose(1, 2)  # (batch_size, num_kv_heads, seq_len, head_dim)

        # 4. Apply RoPE positional encoding (if any)
        if rope is not None:
            q, k = rope(q, k)

        # 5. Repeat K and V to match the number of Q heads
        # e.g. 10 Q heads and 2 KV heads: each KV head is repeated 5 times
        repeat_kv = self.num_heads // self.num_kv_heads
        k = k.repeat_interleave(
            repeat_kv, dim=1
        )  # (batch_size, num_heads, seq_len, head_dim)
        v = v.repeat_interleave(
            repeat_kv, dim=1
        )  # (batch_size, num_heads, seq_len, head_dim)

        # 6. Compute attention
        # attn_output: (B, H, L, D_h), for head h, token i's vector blended from all tokens' Values weighted by attention
        # attn_weights: (B, H, L, L), attention weights of token i over all tokens for batch=b, head=h
        attn_output, attn_weights = scaled_dot_product_attention(
            q, k, v, mask, sliding_window, sliding_window_overlap
        )

        # 7. Transpose and reshape
        attn_output = attn_output.transpose(
            1, 2
        )  # (batch_size, seq_len, num_heads, head_dim)

        # (batch_size, seq_len, num_heads, head_dim) -> (batch_size, seq_len, hidden_size)
        attn_output = attn_output.contiguous().view(batch_size, seq_len, hidden_size)

        # 8. Output projection
        output = self.o_proj(attn_output)

        return output


if __name__ == "__main__":
    print("=" * 60)
    print("Attention mechanism test")
    print("=" * 60)

    # Test parameters
    hidden_size = 320
    num_heads = 10
    num_kv_heads = 2
    head_dim = hidden_size // num_heads
    batch_size = 2
    seq_len = 10

    print("\n1. Test parameters")
    print(f"   hidden_size: {hidden_size}")
    print(f"   num_heads: {num_heads}")
    print(f"   num_kv_heads: {num_kv_heads}")
    print(f"   head_dim: {head_dim}")
    print(f"   batch_size: {batch_size}")
    print(f"   seq_len: {seq_len}")

    # Test the base function
    print("\n2. Test scaled_dot_product_attention")
    q = torch.randn(batch_size, num_heads, seq_len, head_dim)
    k = torch.randn(batch_size, num_heads, seq_len, head_dim)
    v = torch.randn(batch_size, num_heads, seq_len, head_dim)
    causal_mask = create_causal_mask(seq_len)

    output, attn_weights = scaled_dot_product_attention(q, k, v, causal_mask)
    print(f"   Input Q shape: {q.shape}")
    print(f"   Output shape: {output.shape}")
    print(f"   Attention weight shape: {attn_weights.shape}")
    print(f"   Attention weight sums (each row should be 1): {attn_weights.sum(dim=-1)[0, 0, :5]}")

    # Test MHA
    print("\n3. Test MultiHeadAttention")
    mha = MultiHeadAttention(hidden_size, num_heads)
    x = torch.randn(batch_size, seq_len, hidden_size)
    output_mha = mha(x, mask=causal_mask)
    print(f"   Input shape: {x.shape}")
    print(f"   Output shape: {output_mha.shape}")
    print(f"   Parameter count: {sum(p.numel() for p in mha.parameters())}")

    # Test GQA
    print("\n4. Test GroupedQueryAttention")
    gqa = GroupedQueryAttention(hidden_size, num_heads, num_kv_heads)
    output_gqa = gqa(x, mask=causal_mask)
    print(f"   Input shape: {x.shape}")
    print(f"   Output shape: {output_gqa.shape}")
    print(f"   Parameter count: {sum(p.numel() for p in gqa.parameters())}")
    print(f"   Q projection weights: {gqa.q_proj.weight.shape}")
    print(f"   K projection weights: {gqa.k_proj.weight.shape}")
    print(f"   V projection weights: {gqa.v_proj.weight.shape}")

    # Test GQA + RoPE
    print("\n5. Test GroupedQueryAttention + RoPE")
    from pathlib import Path
    import sys

    # Add the project root to the Python path
    project_root = Path(__file__).parent.parent.parent
    sys.path.insert(0, str(project_root))
    from llm.model.rope import RoPE

    rope = RoPE(dim=head_dim, max_seq_len=1024, theta=10000.0)
    output_gqa_rope = gqa(x, mask=causal_mask, rope=rope)
|
| 395 |
+
print(f" 输出形状: {output_gqa_rope.shape}")
|
| 396 |
+
print(f" 与无 RoPE 的输出不同: {not torch.allclose(output_gqa, output_gqa_rope)}")
|
| 397 |
+
|
| 398 |
+
# 测试滑动窗口掩码
|
| 399 |
+
print("\n6. 测试滑动窗口掩码")
|
| 400 |
+
window_size = 4
|
| 401 |
+
scores = torch.randn(batch_size, num_heads, seq_len, seq_len)
|
| 402 |
+
scores_windowed = apply_sliding_window_mask(
|
| 403 |
+
scores, window_size, seq_len, overlap=True
|
| 404 |
+
)
|
| 405 |
+
print(f" 窗口大小: {window_size}")
|
| 406 |
+
print(f" 原始分数形状: {scores.shape}")
|
| 407 |
+
print(f" 掩码后分数形状: {scores_windowed.shape}")
|
| 408 |
+
# 检查第一个样本第一个头的掩码
|
| 409 |
+
mask_check = scores_windowed[0, 0] != float("-inf")
|
| 410 |
+
print(f" 位置 0 可关注的位置数: {mask_check[0].sum().item()}")
|
| 411 |
+
print(f" 位置 5 可关注的位置数: {mask_check[5].sum().item()}")
|
| 412 |
+
|
| 413 |
+
# 测试 GQA + 滑动窗口
|
| 414 |
+
print("\n7. 测试 GroupedQueryAttention + 滑动窗口")
|
| 415 |
+
output_gqa_window = gqa(
|
| 416 |
+
x, mask=None, rope=rope, sliding_window=window_size, sliding_window_overlap=True
|
| 417 |
+
)
|
| 418 |
+
print(f" 输出形状: {output_gqa_window.shape}")
|
| 419 |
+
print(
|
| 420 |
+
f" 与无滑动窗口的输出不同: {not torch.allclose(output_gqa_rope, output_gqa_window)}"
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
# 测试 MHA + 滑动窗口(统一后 MHA 也支持滑动窗口)
|
| 424 |
+
print("\n8. 测试 MultiHeadAttention + 滑动窗口")
|
| 425 |
+
output_mha_window = mha(
|
| 426 |
+
x, mask=None, rope=rope, sliding_window=window_size, sliding_window_overlap=True
|
| 427 |
+
)
|
| 428 |
+
print(f" 输出形状: {output_mha_window.shape}")
|
| 429 |
+
print(
|
| 430 |
+
f" 与无滑动窗口的输出不同: {not torch.allclose(output_mha, output_mha_window)}"
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
print("\n" + "=" * 60)
|
| 434 |
+
print("所有测试完成!")
|
| 435 |
+
print("=" * 60)
|
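
Not part of the diff, but a quick way to see what GQA buys with the configs/model.yaml numbers (hidden_size 320, 10 query heads, 2 KV heads): the K/V projections shrink to 2 × 32 = 64 dims while the module keeps the same input/output interface as MHA. A minimal sketch, assuming the repo root is on `sys.path`; the `MultiHeadAttention` internals are defined earlier in `attention.py` and are not repeated here.

```python
# Illustration only (not in the commit): compare MHA vs. GQA from this file.
import torch
from llm.model.attention import MultiHeadAttention, GroupedQueryAttention

hidden_size, num_heads, num_kv_heads = 320, 10, 2
mha = MultiHeadAttention(hidden_size, num_heads)
gqa = GroupedQueryAttention(hidden_size, num_heads, num_kv_heads)

print(sum(p.numel() for p in mha.parameters()))   # full-size K/V projections
print(sum(p.numel() for p in gqa.parameters()))   # K/V projected to num_kv_heads * head_dim = 64

x = torch.randn(2, 16, hidden_size)
print(gqa(x).shape)   # torch.Size([2, 16, 320]) -- same interface as MHA
```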
llm/model/block.py
ADDED
|
@@ -0,0 +1,163 @@
|
| 1 |
+
"""Transformer Block"""
|
| 2 |
+
# 2026-01-23
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
project_root = Path(__file__).parent.parent.parent
|
| 11 |
+
sys.path.insert(0, str(project_root))
|
| 12 |
+
|
| 13 |
+
from llm.model.attention import GroupedQueryAttention
|
| 14 |
+
from llm.model.ffn import FFN
|
| 15 |
+
from llm.model.norm import RMSNorm
|
| 16 |
+
from llm.model.rope import RoPE
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TransformerBlock(nn.Module):
|
| 20 |
+
"""
|
| 21 |
+
Transformer 层(Block)
|
| 22 |
+
|
| 23 |
+
结构:
|
| 24 |
+
1. 注意力层(带残差连接)
|
| 25 |
+
2. FFN 层(带残差连接)
|
| 26 |
+
|
| 27 |
+
每个子层都使用 Pre-Norm 结构:
|
| 28 |
+
- Pre-Norm: x_norm = norm(x), output = x + sublayer(x_norm)
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(self, config):
|
| 32 |
+
"""
|
| 33 |
+
初始化 Transformer Block
|
| 34 |
+
|
| 35 |
+
参数:
|
| 36 |
+
config: 配置字典,包含模型参数
|
| 37 |
+
"""
|
| 38 |
+
super().__init__()
|
| 39 |
+
self.config = config
|
| 40 |
+
|
| 41 |
+
hidden_size = config["hidden_size"]
|
| 42 |
+
num_heads = config["num_attention_heads"]
|
| 43 |
+
num_kv_heads = config.get("num_key_value_heads", num_heads)
|
| 44 |
+
intermediate_size = config["intermediate_size"]
|
| 45 |
+
rms_norm_eps = float(config.get("rms_norm_eps", 1e-5))
|
| 46 |
+
|
| 47 |
+
# 滑动窗口配置
|
| 48 |
+
self.sliding_window = config.get("sliding_window")
|
| 49 |
+
self.sliding_window_overlap = config.get("sliding_window_overlap", True)
|
| 50 |
+
|
| 51 |
+
# 注意力层(使用 GQA)
|
| 52 |
+
self.attn = GroupedQueryAttention(hidden_size, num_heads, num_kv_heads)
|
| 53 |
+
|
| 54 |
+
# 归一化层(Pre-Norm 结构)
|
| 55 |
+
self.attn_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
|
| 56 |
+
self.ffn_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
|
| 57 |
+
|
| 58 |
+
# FFN 层
|
| 59 |
+
self.ffn = FFN(hidden_size, intermediate_size)
|
| 60 |
+
|
| 61 |
+
# RoPE 位置编码(如果配置中有)
|
| 62 |
+
max_position_embeddings = config.get("max_position_embeddings", 1024)
|
| 63 |
+
rope_theta = config.get("rope_theta", 10000.0)
|
| 64 |
+
head_dim = hidden_size // num_heads
|
| 65 |
+
self.rope = RoPE(head_dim, max_position_embeddings, rope_theta)
|
| 66 |
+
|
| 67 |
+
def forward(self, x, mask=None):
|
| 68 |
+
"""
|
| 69 |
+
前向传播
|
| 70 |
+
|
| 71 |
+
参数:
|
| 72 |
+
x: 输入张量,形状为 (batch_size, seq_len, hidden_size)
|
| 73 |
+
mask: 注意力掩码,形状为 (batch_size, seq_len, seq_len) 或 (seq_len, seq_len)
|
| 74 |
+
|
| 75 |
+
返回:
|
| 76 |
+
输出张量,形状为 (batch_size, seq_len, hidden_size)
|
| 77 |
+
"""
|
| 78 |
+
# 1. 注意力层(Pre-Norm + 残差连接)
|
| 79 |
+
# Pre-Norm: 先归一化,再计算注意力
|
| 80 |
+
x_norm = self.attn_norm(x)
|
| 81 |
+
attn_output = self.attn(
|
| 82 |
+
x_norm,
|
| 83 |
+
mask=mask,
|
| 84 |
+
rope=self.rope,
|
| 85 |
+
sliding_window=self.sliding_window,
|
| 86 |
+
sliding_window_overlap=self.sliding_window_overlap,
|
| 87 |
+
)
|
| 88 |
+
# 残差连接
|
| 89 |
+
x = x + attn_output
|
| 90 |
+
|
| 91 |
+
# 2. FFN 层(Pre-Norm + 残差连接)
|
| 92 |
+
# Pre-Norm: 先归一化,再计算 FFN
|
| 93 |
+
x_norm = self.ffn_norm(x)
|
| 94 |
+
ffn_output = self.ffn(x_norm)
|
| 95 |
+
# 残差连接
|
| 96 |
+
x = x + ffn_output
|
| 97 |
+
|
| 98 |
+
return x
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
if __name__ == "__main__":
|
| 102 |
+
print("=" * 60)
|
| 103 |
+
print("TransformerBlock 测试")
|
| 104 |
+
print("=" * 60)
|
| 105 |
+
|
| 106 |
+
# 测试参数(与 configs/model.yaml 一致)
|
| 107 |
+
config = {
|
| 108 |
+
"hidden_size": 320,
|
| 109 |
+
"num_attention_heads": 10,
|
| 110 |
+
"num_key_value_heads": 2,
|
| 111 |
+
"intermediate_size": 960,
|
| 112 |
+
"rms_norm_eps": 1e-5,
|
| 113 |
+
"max_position_embeddings": 1024,
|
| 114 |
+
"rope_theta": 10000.0,
|
| 115 |
+
"sliding_window": 256,
|
| 116 |
+
"sliding_window_overlap": True,
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
batch_size = 2
|
| 120 |
+
seq_len = 64
|
| 121 |
+
hidden_size = config["hidden_size"]
|
| 122 |
+
|
| 123 |
+
print("\n1. 测试参数")
|
| 124 |
+
print(f" hidden_size: {hidden_size}")
|
| 125 |
+
print(f" num_attention_heads: {config['num_attention_heads']}")
|
| 126 |
+
print(f" num_key_value_heads: {config['num_key_value_heads']}")
|
| 127 |
+
print(f" intermediate_size: {config['intermediate_size']}")
|
| 128 |
+
print(f" batch_size: {batch_size}")
|
| 129 |
+
print(f" seq_len: {seq_len}")
|
| 130 |
+
|
| 131 |
+
# 创建 TransformerBlock
|
| 132 |
+
print("\n2. 创建 TransformerBlock")
|
| 133 |
+
block = TransformerBlock(config)
|
| 134 |
+
x = torch.randn(batch_size, seq_len, hidden_size)
|
| 135 |
+
|
| 136 |
+
print(f" 输入形状: {x.shape}")
|
| 137 |
+
output = block(x)
|
| 138 |
+
print(f" 输出形状: {output.shape}")
|
| 139 |
+
print(f" 形状匹配: {output.shape == x.shape}")
|
| 140 |
+
|
| 141 |
+
# 检查参数数量
|
| 142 |
+
total_params = sum(p.numel() for p in block.parameters())
|
| 143 |
+
print(f" 参数数量: {total_params:,}")
|
| 144 |
+
|
| 145 |
+
# 验证残差连接
|
| 146 |
+
print("\n3. 验证残差连接")
|
| 147 |
+
# 如果输入很小,输出应该接近输入(因为残差连接)
|
| 148 |
+
x_small = torch.randn(batch_size, seq_len, hidden_size) * 0.01
|
| 149 |
+
output_small = block(x_small)
|
| 150 |
+
diff = torch.abs(output_small - x_small).mean()
|
| 151 |
+
print(f" 小输入测试: 输入输出差异均值 = {diff.item():.6f}")
|
| 152 |
+
|
| 153 |
+
# 测试梯度
|
| 154 |
+
print("\n4. 测试梯度计算")
|
| 155 |
+
loss = output.sum()
|
| 156 |
+
loss.backward()
|
| 157 |
+
print(
|
| 158 |
+
f" 所有参数是否有梯度: {all(p.grad is not None for p in block.parameters())}"
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
print("\n" + "=" * 60)
|
| 162 |
+
print("所有测试完成!")
|
| 163 |
+
print("=" * 60)
|
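
As a sanity check on the Pre-Norm wiring described in the docstring, the block's forward pass can be reproduced by hand as two residual updates. This is a sketch, not part of the diff; it assumes the repo root is importable.

```python
import torch
from llm.model.block import TransformerBlock

config = {"hidden_size": 320, "num_attention_heads": 10,
          "num_key_value_heads": 2, "intermediate_size": 960}
block = TransformerBlock(config)
x = torch.randn(2, 8, 320)

# Pre-Norm spelled out: x <- x + Attn(Norm(x)); x <- x + FFN(Norm(x))
y = x + block.attn(block.attn_norm(x), mask=None, rope=block.rope,
                   sliding_window=block.sliding_window,
                   sliding_window_overlap=block.sliding_window_overlap)
manual = y + block.ffn(block.ffn_norm(y))
print(torch.allclose(block(x), manual))   # True -- matches forward()
```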
llm/model/embedding.py
ADDED
|
@@ -0,0 +1,35 @@
|
| 1 |
+
"""嵌入层:Token Embedding 词嵌入"""
|
| 2 |
+
# 2026-01-23
|
| 3 |
+
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class TokenEmbedding(nn.Module):
|
| 8 |
+
"""词嵌入"""
|
| 9 |
+
|
| 10 |
+
def __init__(self, vocab_size, hidden_size):
|
| 11 |
+
"""
|
| 12 |
+
初始化词嵌入层
|
| 13 |
+
|
| 14 |
+
参数:
|
| 15 |
+
vocab_size: 词汇表大小
|
| 16 |
+
hidden_size: 隐藏层维度
|
| 17 |
+
"""
|
| 18 |
+
super().__init__()
|
| 19 |
+
self.vocab_size = vocab_size
|
| 20 |
+
self.hidden_size = hidden_size
|
| 21 |
+
|
| 22 |
+
# 词嵌入层:将 token ID 映射为向量
|
| 23 |
+
self.embedding = nn.Embedding(vocab_size, hidden_size)
|
| 24 |
+
|
| 25 |
+
def forward(self, input_ids):
|
| 26 |
+
"""
|
| 27 |
+
前向传播
|
| 28 |
+
|
| 29 |
+
参数:
|
| 30 |
+
input_ids: Token ID 张量,形状为 (batch_size, seq_len)
|
| 31 |
+
|
| 32 |
+
返回:
|
| 33 |
+
嵌入向量,形状为 (batch_size, seq_len, hidden_size)
|
| 34 |
+
"""
|
| 35 |
+
return self.embedding(input_ids)
|
llm/model/ffn.py
ADDED
|
@@ -0,0 +1,139 @@
|
| 1 |
+
"""前馈神经网络(FFN)"""
|
| 2 |
+
# 2026-01-23
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def silu(x):
|
| 9 |
+
"""
|
| 10 |
+
SiLU 激活函数(Sigmoid Linear Unit)
|
| 11 |
+
|
| 12 |
+
公式: SiLU(x) = x * sigmoid(x)
|
| 13 |
+
|
| 14 |
+
参数:
|
| 15 |
+
x: 输入张量
|
| 16 |
+
|
| 17 |
+
返回:
|
| 18 |
+
激活后的张量
|
| 19 |
+
"""
|
| 20 |
+
return x * torch.sigmoid(x)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class FFN(nn.Module):
|
| 24 |
+
"""
|
| 25 |
+
前馈神经网络(使用 SwiGLU 结构)
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(self, hidden_size, intermediate_size):
|
| 29 |
+
"""
|
| 30 |
+
初始化 FFN
|
| 31 |
+
|
| 32 |
+
参数:
|
| 33 |
+
hidden_size: 隐藏层维度
|
| 34 |
+
intermediate_size: 中间层维度
|
| 35 |
+
"""
|
| 36 |
+
super().__init__()
|
| 37 |
+
self.hidden_size = hidden_size
|
| 38 |
+
self.intermediate_size = intermediate_size
|
| 39 |
+
|
| 40 |
+
# FFN(x) = W2 · activation(W1 · x)
|
| 41 |
+
# FFN(x) = W_down ( activation(W_gate · x) ⊙ (W_up · x) )
|
| 42 |
+
# SwiGLU 结构需要三个投影层:
|
| 43 |
+
# gate_proj: 用于门控(gate),经过激活函数,生成“门控信号”,决定哪些特征应该被放大或抑制
|
| 44 |
+
# up_proj: 用于上投影(up),不经过激活函数,提供被门控调制的“原始特征”
|
| 45 |
+
# down_proj: 用于下投影(down),将中间层映射回 hidden_size
|
| 46 |
+
self.gate_proj = nn.Linear(hidden_size, intermediate_size)
|
| 47 |
+
self.up_proj = nn.Linear(hidden_size, intermediate_size)
|
| 48 |
+
self.down_proj = nn.Linear(intermediate_size, hidden_size)
|
| 49 |
+
|
| 50 |
+
# 激活函数
|
| 51 |
+
self.activation = silu
|
| 52 |
+
|
| 53 |
+
def forward(self, x):
|
| 54 |
+
"""
|
| 55 |
+
前向传播
|
| 56 |
+
|
| 57 |
+
使用 SwiGLU 结构:
|
| 58 |
+
- gate = activation(gate_proj(x))
|
| 59 |
+
- up = up_proj(x)
|
| 60 |
+
- output = down_proj(gate * up)
|
| 61 |
+
|
| 62 |
+
参数:
|
| 63 |
+
x: 输入张量,形状为 (batch_size, seq_len, hidden_size)
|
| 64 |
+
|
| 65 |
+
返回:
|
| 66 |
+
输出张量,形状为 (batch_size, seq_len, hidden_size)
|
| 67 |
+
"""
|
| 68 |
+
# SwiGLU 结构
|
| 69 |
+
# 1. 计算门控值(经过激活函数)
|
| 70 |
+
gate = self.activation(
|
| 71 |
+
self.gate_proj(x)
|
| 72 |
+
) # (batch_size, seq_len, intermediate_size)
|
| 73 |
+
|
| 74 |
+
# 2. 计算上投影值(不经过激活函数)
|
| 75 |
+
up = self.up_proj(x) # (batch_size, seq_len, intermediate_size)
|
| 76 |
+
|
| 77 |
+
# 3. 门控乘法:gate * up(逐元素相乘)
|
| 78 |
+
gate_up = gate * up # (batch_size, seq_len, intermediate_size)
|
| 79 |
+
|
| 80 |
+
# 4. 下投影回 hidden_size
|
| 81 |
+
output = self.down_proj(gate_up) # (batch_size, seq_len, hidden_size)
|
| 82 |
+
|
| 83 |
+
return output
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
if __name__ == "__main__":
|
| 87 |
+
print("=" * 60)
|
| 88 |
+
print("FFN 测试")
|
| 89 |
+
print("=" * 60)
|
| 90 |
+
|
| 91 |
+
# 测试参数(与 configs/model.yaml 一致)
|
| 92 |
+
hidden_size = 320
|
| 93 |
+
intermediate_size = 960
|
| 94 |
+
batch_size = 2
|
| 95 |
+
seq_len = 10
|
| 96 |
+
|
| 97 |
+
print("\n1. 测试参数")
|
| 98 |
+
print(f" hidden_size: {hidden_size}")
|
| 99 |
+
print(f" intermediate_size: {intermediate_size}")
|
| 100 |
+
print(f" batch_size: {batch_size}")
|
| 101 |
+
print(f" seq_len: {seq_len}")
|
| 102 |
+
|
| 103 |
+
# 测试 SiLU 激活函数
|
| 104 |
+
print("\n2. 测试 SiLU 激活函数")
|
| 105 |
+
x_test = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
|
| 106 |
+
silu_output = silu(x_test)
|
| 107 |
+
print(f" 输入: {x_test.tolist()}")
|
| 108 |
+
print(f" 输出: {silu_output.tolist()}")
|
| 109 |
+
print(f" SiLU(0) 应该接近 0: {abs(silu_output[2].item()) < 0.01}")
|
| 110 |
+
|
| 111 |
+
# 测试 FFN
|
| 112 |
+
print("\n3. 测试 FFN(SwiGLU 结构)")
|
| 113 |
+
ffn = FFN(hidden_size, intermediate_size)
|
| 114 |
+
x = torch.randn(batch_size, seq_len, hidden_size)
|
| 115 |
+
|
| 116 |
+
print(f" 输入形状: {x.shape}")
|
| 117 |
+
output = ffn(x)
|
| 118 |
+
print(f" 输出形状: {output.shape}")
|
| 119 |
+
print(f" 形状匹配: {output.shape == x.shape}")
|
| 120 |
+
|
| 121 |
+
# 检查参数数量
|
| 122 |
+
total_params = sum(p.numel() for p in ffn.parameters())
|
| 123 |
+
print(f" 参数数量: {total_params}")
|
| 124 |
+
print(f" gate_proj 参数: {ffn.gate_proj.weight.shape}")
|
| 125 |
+
print(f" up_proj 参数: {ffn.up_proj.weight.shape}")
|
| 126 |
+
print(f" down_proj 参数: {ffn.down_proj.weight.shape}")
|
| 127 |
+
|
| 128 |
+
# 验证 SwiGLU 结构
|
| 129 |
+
print("\n4. 验证 SwiGLU 结构")
|
| 130 |
+
gate = silu(ffn.gate_proj(x))
|
| 131 |
+
up = ffn.up_proj(x)
|
| 132 |
+
gate_up = gate * up
|
| 133 |
+
manual_output = ffn.down_proj(gate_up)
|
| 134 |
+
print(f" 手动计算输出形状: {manual_output.shape}")
|
| 135 |
+
print(f" 与 forward 输出一致: {torch.allclose(output, manual_output)}")
|
| 136 |
+
|
| 137 |
+
print("\n" + "=" * 60)
|
| 138 |
+
print("所有测试完成!")
|
| 139 |
+
print("=" * 60)
|
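
The FFN above is the standard SwiGLU gating. A short check (illustrative, not in the diff) that the hand-rolled `silu` matches PyTorch's built-in and that `forward()` is just `down(SiLU(gate(x)) ⊙ up(x))`:

```python
import torch
import torch.nn.functional as F
from llm.model.ffn import FFN, silu

x = torch.linspace(-3.0, 3.0, 7)
print(torch.allclose(silu(x), F.silu(x)))        # same activation

ffn = FFN(hidden_size=320, intermediate_size=960)
h = torch.randn(2, 5, 320)
manual = ffn.down_proj(F.silu(ffn.gate_proj(h)) * ffn.up_proj(h))
print(torch.allclose(ffn(h), manual))            # SwiGLU in one line
```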
llm/model/norm.py
ADDED
|
@@ -0,0 +1,132 @@
|
| 1 |
+
"""归一化层:RMSNorm"""
|
| 2 |
+
# 2026-01-22
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class RMSNorm(nn.Module):
|
| 9 |
+
"""RMSNorm 归一化层"""
|
| 10 |
+
|
| 11 |
+
def __init__(self, dim, eps=1e-5):
|
| 12 |
+
"""
|
| 13 |
+
初始化 RMSNorm
|
| 14 |
+
|
| 15 |
+
参数:
|
| 16 |
+
dim: 特征维度
|
| 17 |
+
eps: 数值稳定项,防止除以零
|
| 18 |
+
"""
|
| 19 |
+
super().__init__()
|
| 20 |
+
# 确保 eps 是浮点数(防止 YAML 解析为字符串)
|
| 21 |
+
self.eps = float(eps)
|
| 22 |
+
|
| 23 |
+
# 创建可学习的缩放参数,初始化为全1
|
| 24 |
+
# nn.Parameter 表示这是模型参数,会被优化器更新
|
| 25 |
+
# weight 就是 γ 参数
|
| 26 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 27 |
+
|
| 28 |
+
def forward(self, x):
|
| 29 |
+
"""
|
| 30 |
+
前向传播
|
| 31 |
+
|
| 32 |
+
参数:
|
| 33 |
+
x: 输入张量,形状为 (batch_size, seq_len, dim) 或其他形状
|
| 34 |
+
|
| 35 |
+
返回:
|
| 36 |
+
归一化后的张量
|
| 37 |
+
"""
|
| 38 |
+
# 计算均方根(Root Mean Square)
|
| 39 |
+
# x.pow(2) 计算每个元素的平方
|
| 40 |
+
# mean(-1, keepdim=True) 在最后一个维度上求均值,保持维度
|
| 41 |
+
rms = torch.sqrt(torch.mean(x.pow(2), dim=-1, keepdim=True) + self.eps)
|
| 42 |
+
|
| 43 |
+
# 归一化:除以 RMS,然后乘以可学习的权重
|
| 44 |
+
return x / rms * self.weight
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
if __name__ == "__main__":
|
| 48 |
+
print("=" * 60)
|
| 49 |
+
print("RMSNorm 测试")
|
| 50 |
+
print("=" * 60)
|
| 51 |
+
|
| 52 |
+
# 1. 创建 RMSNorm 层
|
| 53 |
+
dim = 32
|
| 54 |
+
norm = RMSNorm(dim=dim, eps=1e-5)
|
| 55 |
+
print("\n1. 创建 RMSNorm 层")
|
| 56 |
+
print(f" 维度: {dim}")
|
| 57 |
+
print(f" eps: {norm.eps}")
|
| 58 |
+
print(f" 权重形状: {norm.weight.shape}")
|
| 59 |
+
print(f" 权重初始值(前5个): {norm.weight[:5]}")
|
| 60 |
+
|
| 61 |
+
# 2. 创建测试输入
|
| 62 |
+
batch_size = 2
|
| 63 |
+
seq_len = 10
|
| 64 |
+
x = torch.randn(batch_size, seq_len, dim)
|
| 65 |
+
print("\n2. 创建测试输入")
|
| 66 |
+
print(f" 输入形状: {x.shape}")
|
| 67 |
+
print(" 输入统计:")
|
| 68 |
+
print(f" - 均值: {x.mean().item():.4f}")
|
| 69 |
+
print(f" - 标准差: {x.std().item():.4f}")
|
| 70 |
+
print(f" - 最小值: {x.min().item():.4f}")
|
| 71 |
+
print(f" - 最大值: {x.max().item():.4f}")
|
| 72 |
+
|
| 73 |
+
# 3. 前向传播
|
| 74 |
+
output = norm(x)
|
| 75 |
+
print("\n3. 前向传播结果")
|
| 76 |
+
print(f" 输出形状: {output.shape}")
|
| 77 |
+
print(" 输出统计:")
|
| 78 |
+
print(f" - 均值: {output.mean().item():.4f}")
|
| 79 |
+
print(f" - 标准差: {output.std().item():.4f}")
|
| 80 |
+
|
| 81 |
+
# 4. 验证归一化效果
|
| 82 |
+
print("\n4. 验证归一化效果")
|
| 83 |
+
# 计算每个样本的 RMS(应该接近1,因为权重初始化为1)
|
| 84 |
+
rms_per_sample = torch.sqrt(torch.mean(output.pow(2), dim=-1))
|
| 85 |
+
print(" 每个样本的 RMS(归一化后):")
|
| 86 |
+
print(f" - 样本1: {rms_per_sample[0].mean().item():.4f}")
|
| 87 |
+
print(f" - 样本2: {rms_per_sample[1].mean().item():.4f}")
|
| 88 |
+
print(f" - 平均 RMS: {rms_per_sample.mean().item():.4f}")
|
| 89 |
+
|
| 90 |
+
# 5. 验证参数是否可学习
|
| 91 |
+
print("\n5. 验证参数可学习性")
|
| 92 |
+
print(f" 权重是否为 Parameter: {isinstance(norm.weight, nn.Parameter)}")
|
| 93 |
+
print(f" 权重是否需要梯度: {norm.weight.requires_grad}")
|
| 94 |
+
|
| 95 |
+
# 6. 测试梯度计算
|
| 96 |
+
print("\n6. 测试梯度计算")
|
| 97 |
+
loss = output.sum() # 简单的损失函数
|
| 98 |
+
loss.backward()
|
| 99 |
+
print(f" 权重梯度是否存在: {norm.weight.grad is not None}")
|
| 100 |
+
if norm.weight.grad is not None:
|
| 101 |
+
print(f" 权重梯度形状: {norm.weight.grad.shape}")
|
| 102 |
+
print(f" 权重梯度统计:")
|
| 103 |
+
print(f" - 均值: {norm.weight.grad.mean().item():.4f}")
|
| 104 |
+
print(f" - 标准差: {norm.weight.grad.std().item():.4f}")
|
| 105 |
+
|
| 106 |
+
# 7. 测试不同输入形状
|
| 107 |
+
print("\n7. 测试不同输入形状")
|
| 108 |
+
test_cases = [
|
| 109 |
+
(1, 5, dim), # 单个样本
|
| 110 |
+
(4, 20, dim), # 多个样本
|
| 111 |
+
(1, 1, dim), # 单个 token
|
| 112 |
+
]
|
| 113 |
+
|
| 114 |
+
for i, shape in enumerate(test_cases, 1):
|
| 115 |
+
x_test = torch.randn(*shape)
|
| 116 |
+
output_test = norm(x_test)
|
| 117 |
+
print(f" 测试 {i}: 输入形状 {shape} -> 输出形状 {output_test.shape} ✓")
|
| 118 |
+
|
| 119 |
+
# 8. 验证数值稳定性
|
| 120 |
+
print("\n8. 验证数值稳定性")
|
| 121 |
+
# 测试非常小的输入
|
| 122 |
+
x_small = torch.randn(1, 1, dim) * 1e-6
|
| 123 |
+
output_small = norm(x_small)
|
| 124 |
+
print(
|
| 125 |
+
f" 极小输入测试: 输入范围 [{x_small.min().item():.2e}, {x_small.max().item():.2e}]"
|
| 126 |
+
)
|
| 127 |
+
print(f" 输出是否包含 NaN: {torch.isnan(output_small).any().item()}")
|
| 128 |
+
print(f" 输出是否包含 Inf: {torch.isinf(output_small).any().item()}")
|
| 129 |
+
|
| 130 |
+
print("\n" + "=" * 60)
|
| 131 |
+
print("所有测试完成!")
|
| 132 |
+
print("=" * 60)
|
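
For reference, the normalization computed above is y = x / sqrt(mean(x²) + ε) · γ, with γ the learnable `weight`. A minimal check (not part of the commit):

```python
import torch
from llm.model.norm import RMSNorm

norm = RMSNorm(dim=8, eps=1e-5)
x = torch.randn(3, 4, 8)
manual = x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-5) * norm.weight
print(torch.allclose(norm(x), manual))   # True
```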
llm/model/rope.py
ADDED
|
@@ -0,0 +1,162 @@
|
| 1 |
+
"""RoPE 旋转位置编码"""
|
| 2 |
+
# 2026-01-22
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class RoPE(nn.Module):
|
| 9 |
+
"""旋转位置编码(Rotary Position Embedding)"""
|
| 10 |
+
|
| 11 |
+
def __init__(self, dim, max_seq_len=1024, theta=10000.0):
|
| 12 |
+
"""
|
| 13 |
+
初始化 RoPE
|
| 14 |
+
|
| 15 |
+
参数:
|
| 16 |
+
dim: 每个注意力头的维度,必须是偶数
|
| 17 |
+
max_seq_len: 模型能处理的最大序列长度
|
| 18 |
+
theta: 旋转频率参数
|
| 19 |
+
"""
|
| 20 |
+
super().__init__()
|
| 21 |
+
self.dim = dim
|
| 22 |
+
self.max_seq_len = max_seq_len
|
| 23 |
+
self.theta = theta
|
| 24 |
+
|
| 25 |
+
# 确保 dim 是偶数
|
| 26 |
+
assert dim % 2 == 0, "dim 必须是偶数"
|
| 27 |
+
|
| 28 |
+
# 只依赖于 dim,是维度分量
|
| 29 |
+
# 预计算频率,每个维度对对应一个频率,2i/dim
|
| 30 |
+
# inv_freq[i] = 1 / (theta^(2i/dim))
|
| 31 |
+
# torch.arange(0, dim, 2) = [0, 2, 4, 6, ..., 30] # 16 个数
|
| 32 |
+
# 除以 dim: [0/32, 2/32, 4/32, ..., 30/32]
|
| 33 |
+
# = [0, 0.0625, 0.125, 0.1875, ..., 0.9375]
|
| 34 |
+
|
| 35 |
+
# theta 的幂次: [10000^0, 10000^0.0625, 10000^0.125, ...]
|
| 36 |
+
# = [1, 1.47, 2.15, 3.16, ...] # 逐渐增大
|
| 37 |
+
|
| 38 |
+
# 取倒数: [1/1, 1/1.47, 1/2.15, ...]
|
| 39 |
+
# = [1.0, 0.68, 0.46, 0.32, ...] # 逐渐减小
|
| 40 |
+
|
| 41 |
+
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
|
| 42 |
+
# register_buffer 注册为缓冲区,不参与梯度计算,但会随模型保存/加载,节省内存
|
| 43 |
+
self.register_buffer("inv_freq", inv_freq)
|
| 44 |
+
|
| 45 |
+
def forward(self, q, k, positions=None):
|
| 46 |
+
"""
|
| 47 |
+
应用旋转位置编码到 Query 和 Key
|
| 48 |
+
|
| 49 |
+
参数:
|
| 50 |
+
q: Query 张量,形状为 (batch_size, num_heads, seq_len, head_dim)
|
| 51 |
+
k: Key 张量,形状为 (batch_size, num_heads, seq_len, head_dim)
|
| 52 |
+
positions: 位置索引,如果为 None 则使用序列位置
|
| 53 |
+
|
| 54 |
+
返回:
|
| 55 |
+
旋转后的 q 和 k
|
| 56 |
+
"""
|
| 57 |
+
# q[b, h, s, d]
|
| 58 |
+
# 第 b 个样本,第 h 个注意力头,序列里第 s 个 token,在该 head 下的第 d 个特征分量
|
| 59 |
+
batch_size, num_heads, seq_len, head_dim = q.shape
|
| 60 |
+
|
| 61 |
+
# 如果没有提供位置,使用序列位置 [0, 1, 2, ..., seq_len-1]
|
| 62 |
+
# 生成位置索引
|
| 63 |
+
if positions is None:
|
| 64 |
+
positions = torch.arange(seq_len, device=q.device)
|
| 65 |
+
|
| 66 |
+
# 计算角度矩阵 freqs,第 s 个位置,第 d 个维度对应的频率,s * inv_freq
|
| 67 |
+
freqs = torch.outer(positions.float(), self.inv_freq)
|
| 68 |
+
|
| 69 |
+
# 计算 cos 和 sin,旋转矩阵的参数
|
| 70 |
+
cos = torch.cos(freqs)
|
| 71 |
+
sin = torch.sin(freqs)
|
| 72 |
+
|
| 73 |
+
# 扩展维度以匹配 q 和 k 的形状
|
| 74 |
+
# (seq_len, head_dim // 2) -> (1, 1, seq_len, head_dim // 2)
|
| 75 |
+
cos = cos.unsqueeze(0).unsqueeze(0)
|
| 76 |
+
sin = sin.unsqueeze(0).unsqueeze(0)
|
| 77 |
+
|
| 78 |
+
# 将 q 和 k 分成两部分(实部和虚部)
|
| 79 |
+
# q: (batch_size, num_heads, seq_len, head_dim)
|
| 80 |
+
# 分成两部分: q1 和 q2,各 (batch_size, num_heads, seq_len, head_dim // 2)
|
| 81 |
+
q1, q2 = q.chunk(2, dim=-1)
|
| 82 |
+
k1, k2 = k.chunk(2, dim=-1)
|
| 83 |
+
|
| 84 |
+
# 应用旋转矩阵
|
| 85 |
+
# 旋转矩阵: [cos -sin] 作用在 [x1]
|
| 86 |
+
# [sin cos] [x2]
|
| 87 |
+
# 结果: [x1*cos - x2*sin]
|
| 88 |
+
# [x1*sin + x2*cos]
|
| 89 |
+
q_rot = torch.cat(
|
| 90 |
+
[
|
| 91 |
+
q1 * cos - q2 * sin, # 实部
|
| 92 |
+
q1 * sin + q2 * cos, # 虚部
|
| 93 |
+
],
|
| 94 |
+
dim=-1,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
k_rot = torch.cat(
|
| 98 |
+
[
|
| 99 |
+
k1 * cos - k2 * sin, # 实部
|
| 100 |
+
k1 * sin + k2 * cos, # 虚部
|
| 101 |
+
],
|
| 102 |
+
dim=-1,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
return q_rot, k_rot
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
if __name__ == "__main__":
|
| 109 |
+
print("=" * 60)
|
| 110 |
+
print("RoPE 测试")
|
| 111 |
+
print("=" * 60)
|
| 112 |
+
|
| 113 |
+
# 创建 RoPE 层
|
| 114 |
+
head_dim = 32 # 必须是偶数
|
| 115 |
+
rope = RoPE(dim=head_dim, max_seq_len=1024, theta=10000.0)
|
| 116 |
+
|
| 117 |
+
print(f"\n1. 创建 RoPE 层")
|
| 118 |
+
print(f" 头维度: {head_dim}")
|
| 119 |
+
print(f" 最大序列长度: {rope.max_seq_len}")
|
| 120 |
+
print(f" theta: {rope.theta}")
|
| 121 |
+
print(f" 频率数量: {len(rope.inv_freq)}")
|
| 122 |
+
print(f" 前5个频率: {rope.inv_freq[:5]}")
|
| 123 |
+
|
| 124 |
+
# 创建测试输入
|
| 125 |
+
batch_size = 2
|
| 126 |
+
num_heads = 10
|
| 127 |
+
seq_len = 10
|
| 128 |
+
|
| 129 |
+
q = torch.randn(batch_size, num_heads, seq_len, head_dim)
|
| 130 |
+
k = torch.randn(batch_size, num_heads, seq_len, head_dim)
|
| 131 |
+
|
| 132 |
+
print(f"\n2. 创建测试输入")
|
| 133 |
+
print(f" Q 形状: {q.shape}")
|
| 134 |
+
print(f" K 形状: {k.shape}")
|
| 135 |
+
|
| 136 |
+
# 前向传播
|
| 137 |
+
q_rot, k_rot = rope(q, k)
|
| 138 |
+
|
| 139 |
+
print(f"\n3. 前向传播结果")
|
| 140 |
+
print(f" 旋转后 Q 形状: {q_rot.shape}")
|
| 141 |
+
print(f" 旋转后 K 形状: {k_rot.shape}")
|
| 142 |
+
print(f" 形状是否匹配: {q_rot.shape == q.shape and k_rot.shape == k.shape}")
|
| 143 |
+
|
| 144 |
+
# 验证旋转效果
|
| 145 |
+
print(f"\n4. 验证���转效果")
|
| 146 |
+
# 位置 0 和位置 1 的旋转角度应该不同
|
| 147 |
+
q_pos0 = q_rot[0, 0, 0, :] # 第一个样本,第一个头,位置0
|
| 148 |
+
q_pos1 = q_rot[0, 0, 1, :] # 第一个样本,第一个头,位置1
|
| 149 |
+
print(f" 位置 0 的 Q(前5个值): {q_pos0[:5]}")
|
| 150 |
+
print(f" 位置 1 的 Q(前5个值): {q_pos1[:5]}")
|
| 151 |
+
print(f" 位置不同,编码不同: {not torch.allclose(q_pos0, q_pos1)}")
|
| 152 |
+
|
| 153 |
+
# 测试不同位置
|
| 154 |
+
print(f"\n5. 测试不同位置")
|
| 155 |
+
positions = torch.tensor([0, 5, 10])
|
| 156 |
+
q_custom, k_custom = rope(q[:, :, :3, :], k[:, :, :3, :], positions=positions)
|
| 157 |
+
print(f" 自定义位置: {positions.tolist()}")
|
| 158 |
+
print(f" 输出形状: {q_custom.shape}")
|
| 159 |
+
|
| 160 |
+
print(f"\n" + "=" * 60)
|
| 161 |
+
print("所有测试完成!")
|
| 162 |
+
print("=" * 60)
|
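
The property that makes RoPE useful is that the attention score between a query rotated at position m and a key rotated at position n depends only on the offset m − n. A small numerical check of that claim against the implementation above (illustrative only, not in the diff):

```python
import torch
from llm.model.rope import RoPE

rope = RoPE(dim=32, max_seq_len=1024, theta=10000.0)
q = torch.randn(1, 1, 1, 32)
k = torch.randn(1, 1, 1, 32)

def score(m, n):
    # rotate q at position m and k at position n, then take their dot product
    q_rot, _ = rope(q, q, positions=torch.tensor([m]))
    _, k_rot = rope(k, k, positions=torch.tensor([n]))
    return (q_rot * k_rot).sum()

# same relative offset (2) at different absolute positions -> same score
print(torch.allclose(score(3, 1), score(10, 8), atol=1e-5))
```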
llm/model/transformer.py
ADDED
|
@@ -0,0 +1,280 @@
|
| 1 |
+
"""Decoder-only Transformer 主模型"""
|
| 2 |
+
# 2026-01-23
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
# 添加项目根目录到 Python 路径
|
| 11 |
+
project_root = Path(__file__).parent.parent.parent
|
| 12 |
+
sys.path.insert(0, str(project_root))
|
| 13 |
+
|
| 14 |
+
from llm.model.embedding import TokenEmbedding
|
| 15 |
+
from llm.model.block import TransformerBlock
|
| 16 |
+
from llm.model.norm import RMSNorm
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class Transformer(nn.Module):
|
| 20 |
+
"""完整的 Transformer 模型(Decoder-only)"""
|
| 21 |
+
|
| 22 |
+
def __init__(self, config):
|
| 23 |
+
"""
|
| 24 |
+
初始化 Transformer 模型
|
| 25 |
+
|
| 26 |
+
参数:
|
| 27 |
+
config: 配置字典,包含模型参数
|
| 28 |
+
"""
|
| 29 |
+
super().__init__()
|
| 30 |
+
self.config = config
|
| 31 |
+
|
| 32 |
+
vocab_size = config.get("vocab_size", 1000) # 需要从数据配置中获取
|
| 33 |
+
hidden_size = config["hidden_size"]
|
| 34 |
+
num_layers = config["num_hidden_layers"]
|
| 35 |
+
rms_norm_eps = float(config.get("rms_norm_eps", 1e-5))
|
| 36 |
+
tie_word_embeddings = config.get("tie_word_embeddings", True)
|
| 37 |
+
|
| 38 |
+
# 词嵌入层
|
| 39 |
+
self.embedding = TokenEmbedding(vocab_size, hidden_size)
|
| 40 |
+
|
| 41 |
+
# Transformer 层(多个 Block)
|
| 42 |
+
self.layers = nn.ModuleList(
|
| 43 |
+
[TransformerBlock(config) for _ in range(num_layers)]
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
# 最终归一化层
|
| 47 |
+
self.norm = RMSNorm(hidden_size, eps=rms_norm_eps)
|
| 48 |
+
|
| 49 |
+
# 输出层(用于生成下一个 token 的概率)
|
| 50 |
+
# 如果 tie_word_embeddings=True,则共享输入和输出词嵌入权重
|
| 51 |
+
if tie_word_embeddings:
|
| 52 |
+
# 绑定输入和输出词嵌入(共享权重)
|
| 53 |
+
self.lm_head = None
|
| 54 |
+
# 注意:实际使用时,输出层使用 embedding.weight 的转置
|
| 55 |
+
else:
|
| 56 |
+
# 独立的输出层
|
| 57 |
+
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
|
| 58 |
+
|
| 59 |
+
# 权重初始化(默认启用)
|
| 60 |
+
init_std = config.get("init_std", 0.02)
|
| 61 |
+
init_weights_enabled = config.get("init_weights", True)
|
| 62 |
+
if init_weights_enabled:
|
| 63 |
+
from llm.utils.init import apply_llm_init
|
| 64 |
+
apply_llm_init(self, std=init_std, init_output_layer=True)
|
| 65 |
+
|
| 66 |
+
def forward(self, input_ids, mask=None):
|
| 67 |
+
"""
|
| 68 |
+
前向传播
|
| 69 |
+
|
| 70 |
+
参数:
|
| 71 |
+
input_ids: Token ID 张量,形状为 (batch_size, seq_len)
|
| 72 |
+
mask: 注意力掩码,形状为 (batch_size, seq_len, seq_len) 或 (seq_len, seq_len)
|
| 73 |
+
|
| 74 |
+
返回:
|
| 75 |
+
logits: 下一个 token 的 logits,形状为 (batch_size, seq_len, vocab_size)
|
| 76 |
+
"""
|
| 77 |
+
# 1. 词嵌入
|
| 78 |
+
x = self.embedding(input_ids) # (batch_size, seq_len, hidden_size)
|
| 79 |
+
|
| 80 |
+
# 2. 通过所有 Transformer 层
|
| 81 |
+
for layer in self.layers:
|
| 82 |
+
x = layer(x, mask=mask) # (batch_size, seq_len, hidden_size)
|
| 83 |
+
|
| 84 |
+
# 3. 最终归一化
|
| 85 |
+
x = self.norm(x) # (batch_size, seq_len, hidden_size)
|
| 86 |
+
|
| 87 |
+
# 4. 输出层(生成下一个 token 的概率分布)
|
| 88 |
+
if self.lm_head is None:
|
| 89 |
+
# 使用共享的词嵌入权重(转置)
|
| 90 |
+
# embedding.weight: (vocab_size, hidden_size)
|
| 91 |
+
# 我们需要: (hidden_size, vocab_size) -> 使用转置
|
| 92 |
+
logits = F.linear(x, self.embedding.embedding.weight)
|
| 93 |
+
else:
|
| 94 |
+
# 使用独立的输出层
|
| 95 |
+
logits = self.lm_head(x) # (batch_size, seq_len, vocab_size)
|
| 96 |
+
|
| 97 |
+
return logits
|
| 98 |
+
|
| 99 |
+
def generate(
|
| 100 |
+
self, input_ids, max_length=100, temperature=1.0, top_k=None, top_p=None
|
| 101 |
+
):
|
| 102 |
+
"""
|
| 103 |
+
文本生成(自回归生成)
|
| 104 |
+
|
| 105 |
+
参数:
|
| 106 |
+
input_ids: 起始 token ID 张量,形状为 (batch_size, start_len)
|
| 107 |
+
max_length: 最大生成长度(包括输入)
|
| 108 |
+
temperature: 温度参数,控制生成的随机性
|
| 109 |
+
top_k: Top-K 采样(保留概率最大的 k 个 token)
|
| 110 |
+
top_p: Nucleus 采样(保留概率累积和达到 p 的 token)
|
| 111 |
+
|
| 112 |
+
返回:
|
| 113 |
+
generated_ids: 生成的 token ID 序列,形状为 (batch_size, generated_len)
|
| 114 |
+
"""
|
| 115 |
+
self.eval() # 设置为评估模式
|
| 116 |
+
generated_ids = input_ids.clone()
|
| 117 |
+
|
| 118 |
+
with torch.no_grad():
|
| 119 |
+
for _ in range(max_length - input_ids.shape[1]):
|
| 120 |
+
# 获取当前位置的输出 logits
|
| 121 |
+
logits = self.forward(
|
| 122 |
+
generated_ids
|
| 123 |
+
) # (batch_size, seq_len, vocab_size)
|
| 124 |
+
|
| 125 |
+
# 获取最后一个位置的 logits(下一步要预测的 token)
|
| 126 |
+
next_token_logits = (
|
| 127 |
+
logits[:, -1, :] / temperature
|
| 128 |
+
) # (batch_size, vocab_size)
|
| 129 |
+
|
| 130 |
+
# Top-K 采样
|
| 131 |
+
if top_k is not None:
|
| 132 |
+
# 只保留 top-k 个最大的 logits,其余设为 -inf
|
| 133 |
+
top_k_values, top_k_indices = torch.topk(next_token_logits, top_k)
|
| 134 |
+
next_token_logits_filtered = torch.full_like(
|
| 135 |
+
next_token_logits, float("-inf")
|
| 136 |
+
)
|
| 137 |
+
next_token_logits_filtered.scatter_(1, top_k_indices, top_k_values)
|
| 138 |
+
next_token_logits = next_token_logits_filtered
|
| 139 |
+
|
| 140 |
+
# Top-P (Nucleus) 采样
|
| 141 |
+
if top_p is not None:
|
| 142 |
+
# 按概率排序
|
| 143 |
+
sorted_logits, sorted_indices = torch.sort(
|
| 144 |
+
next_token_logits, descending=True
|
| 145 |
+
)
|
| 146 |
+
sorted_probs = F.softmax(sorted_logits, dim=-1)
|
| 147 |
+
|
| 148 |
+
# 计算累积概率
|
| 149 |
+
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
|
| 150 |
+
|
| 151 |
+
# 找到第一个累积概率超过 top_p 的位置
|
| 152 |
+
# 移除该位置及之后的所有位置(至少保留一个 token)
|
| 153 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
| 154 |
+
# 通过移位确保至少保留第一个 token
|
| 155 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
|
| 156 |
+
..., :-1
|
| 157 |
+
].clone()
|
| 158 |
+
sorted_indices_to_remove[..., 0] = False
|
| 159 |
+
|
| 160 |
+
# 将排序后的掩码映射回原始索引
|
| 161 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
| 162 |
+
1, sorted_indices, sorted_indices_to_remove
|
| 163 |
+
)
|
| 164 |
+
next_token_logits[indices_to_remove] = float("-inf")
|
| 165 |
+
|
| 166 |
+
# 计算概率并采样
|
| 167 |
+
probs = F.softmax(next_token_logits, dim=-1) # (batch_size, vocab_size)
|
| 168 |
+
next_token_id = torch.multinomial(
|
| 169 |
+
probs, num_samples=1
|
| 170 |
+
) # (batch_size, 1)
|
| 171 |
+
|
| 172 |
+
# 将新生成的 token 添加到序列中
|
| 173 |
+
generated_ids = torch.cat(
|
| 174 |
+
[generated_ids, next_token_id], dim=1
|
| 175 |
+
) # (batch_size, seq_len+1)
|
| 176 |
+
|
| 177 |
+
# 如果所有序列都生成了结束符,可以提前停止(这里简化处理,生成固定长度)
|
| 178 |
+
|
| 179 |
+
return generated_ids
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
if __name__ == "__main__":
|
| 183 |
+
print("=" * 60)
|
| 184 |
+
print("Transformer 模型测试")
|
| 185 |
+
print("=" * 60)
|
| 186 |
+
|
| 187 |
+
# 测试参数(与 configs/model.yaml 一致)
|
| 188 |
+
config = {
|
| 189 |
+
"vocab_size": 100, # 示例词汇表大小
|
| 190 |
+
"hidden_size": 320,
|
| 191 |
+
"num_hidden_layers": 10,
|
| 192 |
+
"num_attention_heads": 10,
|
| 193 |
+
"num_key_value_heads": 2,
|
| 194 |
+
"intermediate_size": 960,
|
| 195 |
+
"rms_norm_eps": 1e-5,
|
| 196 |
+
"max_position_embeddings": 1024,
|
| 197 |
+
"rope_theta": 10000.0,
|
| 198 |
+
"sliding_window": 256,
|
| 199 |
+
"sliding_window_overlap": True,
|
| 200 |
+
"tie_word_embeddings": True,
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
batch_size = 2
|
| 204 |
+
seq_len = 10
|
| 205 |
+
vocab_size = config["vocab_size"]
|
| 206 |
+
|
| 207 |
+
print("\n1. 测试参数")
|
| 208 |
+
print(f" vocab_size: {vocab_size}")
|
| 209 |
+
print(f" hidden_size: {config['hidden_size']}")
|
| 210 |
+
print(f" num_hidden_layers: {config['num_hidden_layers']}")
|
| 211 |
+
print(f" batch_size: {batch_size}")
|
| 212 |
+
print(f" seq_len: {seq_len}")
|
| 213 |
+
|
| 214 |
+
# 创建 Transformer 模型
|
| 215 |
+
print("\n2. 创建 Transformer 模型")
|
| 216 |
+
model = Transformer(config)
|
| 217 |
+
|
| 218 |
+
# 创建测试输入
|
| 219 |
+
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
|
| 220 |
+
|
| 221 |
+
print(f" 输入形状: {input_ids.shape}")
|
| 222 |
+
|
| 223 |
+
# 前向传播
|
| 224 |
+
output = model(input_ids)
|
| 225 |
+
print(f" 输出形状: {output.shape}")
|
| 226 |
+
print(f" 输出形状正确: {output.shape == (batch_size, seq_len, vocab_size)}")
|
| 227 |
+
|
| 228 |
+
# 检查参数数量
|
| 229 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 230 |
+
print(f" 参数数量: {total_params:,}")
|
| 231 |
+
|
| 232 |
+
# 验证词嵌入共享
|
| 233 |
+
print("\n3. 验证词嵌入共享")
|
| 234 |
+
if model.lm_head is None:
|
| 235 |
+
print(" 使用共享的词嵌入权重(tie_word_embeddings=True)")
|
| 236 |
+
print(f" embedding 参数形状: {model.embedding.embedding.weight.shape}")
|
| 237 |
+
else:
|
| 238 |
+
print(" 使用独立的输出层(tie_word_embeddings=False)")
|
| 239 |
+
print(f" lm_head 参数形状: {model.lm_head.weight.shape}")
|
| 240 |
+
|
| 241 |
+
# 测试梯度
|
| 242 |
+
print("\n4. 测试梯度计算")
|
| 243 |
+
# 创建简单的损失函数(交叉熵)
|
| 244 |
+
targets = torch.randint(0, vocab_size, (batch_size, seq_len))
|
| 245 |
+
loss = F.cross_entropy(output.view(-1, vocab_size), targets.view(-1))
|
| 246 |
+
loss.backward()
|
| 247 |
+
print(f" 损失值: {loss.item():.4f}")
|
| 248 |
+
print(
|
| 249 |
+
f" 所有参数是否有梯度: {all(p.grad is not None for p in model.parameters())}"
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
# 测试生成(简化版)
|
| 253 |
+
print("\n5. 测试文本生成(简化版)")
|
| 254 |
+
start_ids = torch.randint(0, vocab_size, (1, 5))
|
| 255 |
+
print(f" 起始序列长度: {start_ids.shape[1]}")
|
| 256 |
+
print(f" 起始 token IDs: {start_ids[0].tolist()}")
|
| 257 |
+
|
| 258 |
+
    # 添加调试:查看第一次生成的 logits 分布
|
| 259 |
+
with torch.no_grad():
|
| 260 |
+
logits = model(start_ids)
|
| 261 |
+
next_logits = logits[0, -1, :]
|
| 262 |
+
probs = F.softmax(next_logits, dim=-1)
|
| 263 |
+
top_probs, top_indices = torch.topk(probs, k=5)
|
| 264 |
+
print(" 第一次生成的前5个最可能 token:")
|
| 265 |
+
for i, (idx, prob) in enumerate(zip(top_indices, top_probs)):
|
| 266 |
+
print(f" Token {idx.item()}: {prob.item():.4f}")
|
| 267 |
+
|
| 268 |
+
# 使用 top_k 采样增加多样性(未训练模型建议使用)
|
| 269 |
+
generated = model.generate(
|
| 270 |
+
start_ids,
|
| 271 |
+
max_length=15,
|
| 272 |
+
temperature=1.5, # 提高温度增加随机性
|
| 273 |
+
top_k=10, # 只从前10个最可能的 token 中采样
|
| 274 |
+
)
|
| 275 |
+
print(f" 生成序列长度: {generated.shape[1]}")
|
| 276 |
+
print(f" 生成的 token IDs: {generated[0].tolist()}")
|
| 277 |
+
|
| 278 |
+
print("\n" + "=" * 60)
|
| 279 |
+
print("所有测试完成!")
|
| 280 |
+
print("=" * 60)
|
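
A usage sketch tying `forward()` and `generate()` together (not part of the diff). Note that `generate()` above calls `forward()` without a mask and only reads the last position's logits; for teacher-forced scoring, an explicit causal mask from `attention.create_causal_mask` is assumed here.

```python
import torch
from llm.model.attention import create_causal_mask
from llm.model.transformer import Transformer

config = {"vocab_size": 100, "hidden_size": 320, "num_hidden_layers": 2,
          "num_attention_heads": 10, "num_key_value_heads": 2,
          "intermediate_size": 960}
model = Transformer(config)

prompt = torch.randint(0, 100, (1, 5))
# teacher-forced scoring with an explicit causal mask
logits = model(prompt, mask=create_causal_mask(prompt.shape[1]))
print(logits.shape)                        # torch.Size([1, 5, 100])

# autoregressive sampling with temperature + top-k / top-p filtering
out = model.generate(prompt, max_length=12, temperature=1.0, top_k=10, top_p=0.9)
print(out.shape)                           # torch.Size([1, 12])
```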
llm/training/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
| 1 |
+
"""训练模块"""
|
| 2 |
+
|
| 3 |
+
from llm.training.metrics import (
|
| 4 |
+
calculate_perplexity,
|
| 5 |
+
calculate_accuracy,
|
| 6 |
+
calculate_top_k_accuracy,
|
| 7 |
+
calculate_metrics,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"calculate_perplexity",
|
| 12 |
+
"calculate_accuracy",
|
| 13 |
+
"calculate_top_k_accuracy",
|
| 14 |
+
"calculate_metrics",
|
| 15 |
+
]
|
llm/training/loss.py
ADDED
|
@@ -0,0 +1,91 @@
|
| 1 |
+
"""损失函数:交叉熵损失"""
|
| 2 |
+
# 2026-01-23
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class CrossEntropyLoss(nn.Module):
|
| 10 |
+
"""交叉熵损失"""
|
| 11 |
+
|
| 12 |
+
def __init__(self, ignore_index=-1):
|
| 13 |
+
"""
|
| 14 |
+
初始化交叉熵损失
|
| 15 |
+
|
| 16 |
+
参数:
|
| 17 |
+
ignore_index: 要忽略的目标索引(通常用于 padding)
|
| 18 |
+
"""
|
| 19 |
+
super().__init__()
|
| 20 |
+
self.ignore_index = ignore_index
|
| 21 |
+
|
| 22 |
+
def forward(self, logits, targets):
|
| 23 |
+
"""
|
| 24 |
+
计算交叉熵损失
|
| 25 |
+
|
| 26 |
+
参数:
|
| 27 |
+
logits: 模型输出,形状为 (batch_size, seq_len, vocab_size)
|
| 28 |
+
targets: 目标 token IDs,形状为 (batch_size, seq_len)
|
| 29 |
+
|
| 30 |
+
返回:
|
| 31 |
+
损失值(标量)
|
| 32 |
+
"""
|
| 33 |
+
# 重塑为 (batch_size * seq_len, vocab_size) 和 (batch_size * seq_len,)
|
| 34 |
+
# 这样可以将所有位置的预测和目标展平,统一计算损失
|
| 35 |
+
logits_flat = logits.view(
|
| 36 |
+
-1, logits.size(-1)
|
| 37 |
+
) # (batch_size * seq_len, vocab_size)
|
| 38 |
+
targets_flat = targets.view(-1) # (batch_size * seq_len,)
|
| 39 |
+
|
| 40 |
+
# 计算交叉熵损失
|
| 41 |
+
# F.cross_entropy 内部会先对 logits 应用 log_softmax,然后计算负对数似然
|
| 42 |
+
loss = F.cross_entropy(
|
| 43 |
+
logits_flat, targets_flat, ignore_index=self.ignore_index
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
return loss
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if __name__ == "__main__":
|
| 50 |
+
print("=" * 60)
|
| 51 |
+
print("损失函数测试")
|
| 52 |
+
print("=" * 60)
|
| 53 |
+
|
| 54 |
+
# 测试交叉熵损失
|
| 55 |
+
print("\n1. 测试 CrossEntropyLoss")
|
| 56 |
+
criterion = CrossEntropyLoss()
|
| 57 |
+
|
| 58 |
+
batch_size = 2
|
| 59 |
+
seq_len = 10
|
| 60 |
+
vocab_size = 100
|
| 61 |
+
|
| 62 |
+
# 创建随机 logits 和 targets
|
| 63 |
+
# 注意:logits 需要梯度才能进行反向传播测试
|
| 64 |
+
logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True)
|
| 65 |
+
targets = torch.randint(0, vocab_size, (batch_size, seq_len))
|
| 66 |
+
|
| 67 |
+
print(f" Logits 形状: {logits.shape}")
|
| 68 |
+
print(f" Targets 形状: {targets.shape}")
|
| 69 |
+
|
| 70 |
+
loss = criterion(logits, targets)
|
| 71 |
+
print(f" 损失值: {loss.item():.4f}")
|
| 72 |
+
|
| 73 |
+
# 验证损失是否为标量
|
| 74 |
+
print(f" 损失是否为标量: {loss.dim() == 0}")
|
| 75 |
+
|
| 76 |
+
# 测试梯度
|
| 77 |
+
loss.backward()
|
| 78 |
+
print(f" 损失可以反向传播: True")
|
| 79 |
+
print(f" Logits 梯度形状: {logits.grad.shape}")
|
| 80 |
+
|
| 81 |
+
# 测试 ignore_index
|
| 82 |
+
print("\n2. 测试 ignore_index")
|
| 83 |
+
criterion_ignore = CrossEntropyLoss(ignore_index=-1)
|
| 84 |
+
targets_with_ignore = targets.clone()
|
| 85 |
+
targets_with_ignore[0, 0] = -1 # 设置一个忽略的 token
|
| 86 |
+
loss_ignore = criterion_ignore(logits, targets_with_ignore)
|
| 87 |
+
print(f" 使用 ignore_index 的损失值: {loss_ignore.item():.4f}")
|
| 88 |
+
|
| 89 |
+
print("\n" + "=" * 60)
|
| 90 |
+
print("所有测试完成!")
|
| 91 |
+
print("=" * 60)
|
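
How this loss is typically fed for next-token prediction: the targets are the input ids shifted left by one, with the trailing position set to `ignore_index`. The actual shifting in this repo lives in `llm/data/` (dataset/collate, not shown in this part of the diff), so the pairing below is a hypothetical illustration.

```python
import torch
from llm.training.loss import CrossEntropyLoss

criterion = CrossEntropyLoss(ignore_index=-1)

token_ids = torch.randint(0, 100, (2, 8))
targets = torch.full_like(token_ids, -1)
targets[:, :-1] = token_ids[:, 1:]        # position t predicts token t+1

logits = torch.randn(2, 8, 100, requires_grad=True)   # stand-in for model(token_ids)
loss = criterion(logits, targets)
loss.backward()
print(loss.item())
```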
llm/training/metrics.py
ADDED
|
@@ -0,0 +1,175 @@
|
| 1 |
+
"""评估指标:困惑度、准确率等"""
|
| 2 |
+
# 2026-01-23
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def calculate_perplexity(loss):
|
| 9 |
+
"""
|
| 10 |
+
计算困惑度(Perplexity)
|
| 11 |
+
|
| 12 |
+
困惑度是语言模型常用的评估指标,表示模型对下一个 token 的不确定性。
|
| 13 |
+
困惑度越低,模型越好。
|
| 14 |
+
|
| 15 |
+
公式: PPL = exp(loss)
|
| 16 |
+
|
| 17 |
+
参数:
|
| 18 |
+
loss: 交叉熵损失值(标量或张量)
|
| 19 |
+
|
| 20 |
+
返回:
|
| 21 |
+
困惑度值
|
| 22 |
+
"""
|
| 23 |
+
return torch.exp(loss)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def calculate_accuracy(logits, targets, ignore_index=-1):
|
| 27 |
+
"""
|
| 28 |
+
计算准确率(Accuracy)
|
| 29 |
+
|
| 30 |
+
参数:
|
| 31 |
+
logits: 模型输出,形状为 (batch_size, seq_len, vocab_size)
|
| 32 |
+
targets: 目标 token IDs,形状为 (batch_size, seq_len)
|
| 33 |
+
ignore_index: 要忽略的目标索引(通常用于 padding)
|
| 34 |
+
|
| 35 |
+
返回:
|
| 36 |
+
准确率(标量,0-1 之间)
|
| 37 |
+
"""
|
| 38 |
+
# 获取预测的 token IDs
|
| 39 |
+
predictions = torch.argmax(logits, dim=-1) # (batch_size, seq_len)
|
| 40 |
+
|
| 41 |
+
# 计算匹配的数量(忽略 ignore_index)
|
| 42 |
+
if ignore_index >= 0:
|
| 43 |
+
# 创建掩码:忽略 ignore_index 的位置
|
| 44 |
+
mask = (targets != ignore_index)
|
| 45 |
+
# 计算匹配的数量(只考虑非忽略的位置)
|
| 46 |
+
matches = (predictions == targets) & mask
|
| 47 |
+
total = mask.sum()
|
| 48 |
+
else:
|
| 49 |
+
# 不忽略任何位置
|
| 50 |
+
matches = (predictions == targets)
|
| 51 |
+
total = targets.numel()
|
| 52 |
+
|
| 53 |
+
# 计算准确率
|
| 54 |
+
accuracy = matches.sum().float() / total.float() if total > 0 else 0.0
|
| 55 |
+
|
| 56 |
+
return accuracy.item() if isinstance(accuracy, torch.Tensor) else accuracy
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def calculate_top_k_accuracy(logits, targets, k=5, ignore_index=-1):
|
| 60 |
+
"""
|
| 61 |
+
计算 Top-K 准确率
|
| 62 |
+
|
| 63 |
+
参数:
|
| 64 |
+
logits: 模型输出,形状为 (batch_size, seq_len, vocab_size)
|
| 65 |
+
targets: 目标 token IDs,形状为 (batch_size, seq_len)
|
| 66 |
+
k: Top-K 值(默认: 5)
|
| 67 |
+
ignore_index: 要忽略的目标索引
|
| 68 |
+
|
| 69 |
+
返回:
|
| 70 |
+
Top-K 准确率(标量,0-1 之间)
|
| 71 |
+
"""
|
| 72 |
+
# 获取 top-k 个最可能的 token IDs
|
| 73 |
+
_, top_k_indices = torch.topk(logits, k, dim=-1) # (batch_size, seq_len, k)
|
| 74 |
+
|
| 75 |
+
# 扩展 targets 维度以匹配 top_k_indices
|
| 76 |
+
targets_expanded = targets.unsqueeze(-1).expand_as(top_k_indices) # (batch_size, seq_len, k)
|
| 77 |
+
|
| 78 |
+
# 检查目标是否在 top-k 中
|
| 79 |
+
matches = (top_k_indices == targets_expanded).any(dim=-1) # (batch_size, seq_len)
|
| 80 |
+
|
| 81 |
+
# 计算准确率(忽略 ignore_index)
|
| 82 |
+
if ignore_index >= 0:
|
| 83 |
+
mask = (targets != ignore_index)
|
| 84 |
+
matches = matches & mask
|
| 85 |
+
total = mask.sum()
|
| 86 |
+
else:
|
| 87 |
+
total = targets.numel()
|
| 88 |
+
|
| 89 |
+
top_k_accuracy = matches.sum().float() / total.float() if total > 0 else 0.0
|
| 90 |
+
|
| 91 |
+
return top_k_accuracy.item() if isinstance(top_k_accuracy, torch.Tensor) else top_k_accuracy
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def calculate_metrics(logits, targets, loss, ignore_index=-1):
|
| 95 |
+
"""
|
| 96 |
+
计算所有评估指标
|
| 97 |
+
|
| 98 |
+
参数:
|
| 99 |
+
logits: 模型输出,形状为 (batch_size, seq_len, vocab_size)
|
| 100 |
+
targets: 目标 token IDs,形状为 (batch_size, seq_len)
|
| 101 |
+
loss: 损失值(标量)
|
| 102 |
+
ignore_index: 要忽略的目标索引
|
| 103 |
+
|
| 104 |
+
返回:
|
| 105 |
+
包含所有指标的字典
|
| 106 |
+
"""
|
| 107 |
+
metrics = {}
|
| 108 |
+
|
| 109 |
+
# 困惑度
|
| 110 |
+
metrics["perplexity"] = calculate_perplexity(loss).item()
|
| 111 |
+
|
| 112 |
+
# 准确率
|
| 113 |
+
metrics["accuracy"] = calculate_accuracy(logits, targets, ignore_index=ignore_index)
|
| 114 |
+
|
| 115 |
+
# Top-5 准确率
|
| 116 |
+
metrics["top5_accuracy"] = calculate_top_k_accuracy(logits, targets, k=5, ignore_index=ignore_index)
|
| 117 |
+
|
| 118 |
+
# 损失
|
| 119 |
+
metrics["loss"] = loss.item() if isinstance(loss, torch.Tensor) else loss
|
| 120 |
+
|
| 121 |
+
return metrics
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
if __name__ == "__main__":
|
| 125 |
+
print("=" * 60)
|
| 126 |
+
print("评估指标测试")
|
| 127 |
+
print("=" * 60)
|
| 128 |
+
|
| 129 |
+
# 测试参数
|
| 130 |
+
batch_size = 4
|
| 131 |
+
seq_len = 10
|
| 132 |
+
vocab_size = 100
|
| 133 |
+
|
| 134 |
+
print(f"\n测试参数:")
|
| 135 |
+
print(f" batch_size: {batch_size}")
|
| 136 |
+
print(f" seq_len: {seq_len}")
|
| 137 |
+
print(f" vocab_size: {vocab_size}")
|
| 138 |
+
|
| 139 |
+
# 创建模拟数据
|
| 140 |
+
logits = torch.randn(batch_size, seq_len, vocab_size)
|
| 141 |
+
targets = torch.randint(0, vocab_size, (batch_size, seq_len))
|
| 142 |
+
loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
|
| 143 |
+
|
| 144 |
+
print(f"\n1. 测试困惑度")
|
| 145 |
+
ppl = calculate_perplexity(loss)
|
| 146 |
+
print(f" 损失: {loss.item():.4f}")
|
| 147 |
+
print(f" 困惑度: {ppl.item():.4f}")
|
| 148 |
+
|
| 149 |
+
print(f"\n2. 测试准确率")
|
| 150 |
+
accuracy = calculate_accuracy(logits, targets)
|
| 151 |
+
print(f" 准确率: {accuracy:.4f} ({accuracy*100:.2f}%)")
|
| 152 |
+
|
| 153 |
+
print(f"\n3. 测试 Top-5 准确率")
|
| 154 |
+
top5_acc = calculate_top_k_accuracy(logits, targets, k=5)
|
| 155 |
+
print(f" Top-5 准确率: {top5_acc:.4f} ({top5_acc*100:.2f}%)")
|
| 156 |
+
|
| 157 |
+
print(f"\n4. 测试 ignore_index")
|
| 158 |
+
# 设置一些 ignore_index
|
| 159 |
+
targets_with_ignore = targets.clone()
|
| 160 |
+
targets_with_ignore[0, :3] = -1 # 前3个设为忽略
|
| 161 |
+
accuracy_ignore = calculate_accuracy(logits, targets_with_ignore, ignore_index=-1)
|
| 162 |
+
print(f" 使用 ignore_index 的准确率: {accuracy_ignore:.4f} ({accuracy_ignore*100:.2f}%)")
|
| 163 |
+
|
| 164 |
+
print(f"\n5. 测试综合指标")
|
| 165 |
+
all_metrics = calculate_metrics(logits, targets, loss)
|
| 166 |
+
print(" 所有指标:")
|
| 167 |
+
for key, value in all_metrics.items():
|
| 168 |
+
if isinstance(value, float):
|
| 169 |
+
print(f" {key}: {value:.4f}")
|
| 170 |
+
else:
|
| 171 |
+
print(f" {key}: {value}")
|
| 172 |
+
|
| 173 |
+
print("\n" + "=" * 60)
|
| 174 |
+
print("所有测试完成!")
|
| 175 |
+
print("=" * 60)
|
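
A small anchor for interpreting the numbers above: a model that spreads probability uniformly over the vocabulary has cross-entropy ln(V), so its perplexity is exactly V. (Illustration, not in the diff.)

```python
import math
import torch
from llm.training.metrics import calculate_perplexity

vocab_size = 100
uniform_loss = torch.tensor(math.log(vocab_size))   # loss of a uniform predictor
print(calculate_perplexity(uniform_loss))            # ~= tensor(100.)
```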
llm/training/optim.py
ADDED
|
@@ -0,0 +1,223 @@
|
"""Optimizer: AdamW / learning-rate schedules"""
# 2026-01-23

import math
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR


def get_optimizer(model, config):
    """
    Build the optimizer (AdamW).

    Args:
        model: the model
        config: training config dict, expected keys:
            - learning_rate: learning rate (default: 1e-4)
            - weight_decay: weight decay (default: 0.01)
            - beta1: Adam beta1 (default: 0.9)
            - beta2: Adam beta2 (default: 0.999)
            - eps: Adam epsilon (default: 1e-8)

    Returns:
        AdamW optimizer
    """
    lr = float(config.get("learning_rate", 1e-4))
    weight_decay = float(config.get("weight_decay", 0.01))
    beta1 = float(config.get("beta1", 0.9))
    beta2 = float(config.get("beta2", 0.999))
    eps = float(config.get("eps", 1e-8))

    optimizer = optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay,
        betas=(beta1, beta2),
        eps=eps,
    )

    return optimizer


def get_lr_scheduler(optimizer, config):
    """
    Build the learning-rate scheduler (with warmup support).

    Args:
        optimizer: the optimizer
        config: training config dict, expected keys:
            - lr_scheduler: scheduler type, one of:
                - "cosine": cosine annealing (with warmup)
                - "linear": linear decay (with warmup)
                - "constant": constant learning rate
            - warmup_steps: number of warmup steps (default: 100)
            - max_steps: maximum number of training steps (default: 10000)

    Returns:
        LR scheduler (None if "constant")
    """
    scheduler_type = config.get("lr_scheduler", "cosine")
    max_steps = config.get("max_steps", 10000)
    warmup_steps = config.get("warmup_steps", 100)

    if scheduler_type == "cosine":
        # Cosine annealing schedule (with warmup)
        def lr_lambda(step):
            if step < warmup_steps:
                # Warmup phase: ramp linearly from 0 to 1
                return step / warmup_steps if warmup_steps > 0 else 1.0
            else:
                # Cosine annealing phase: decay from 1 to 0
                # Edge case: if max_steps <= warmup_steps, return the minimum directly
                if max_steps <= warmup_steps:
                    return 0.0
                progress = (step - warmup_steps) / (max_steps - warmup_steps)
                # Clamp progress to [0, 1]
                progress = min(progress, 1.0)
                # Cosine annealing: 0.5 * (1 + cos(pi * progress))
                return 0.5 * (1.0 + math.cos(progress * math.pi))

        scheduler = LambdaLR(optimizer, lr_lambda)

    elif scheduler_type == "linear":
        # Linear decay schedule (with warmup)
        def lr_lambda(step):
            if step < warmup_steps:
                # Warmup phase: ramp linearly from 0 to 1
                return step / warmup_steps if warmup_steps > 0 else 1.0
            else:
                # Linear decay phase: decay from 1 to 0.1
                # Edge case: if max_steps <= warmup_steps, return the minimum directly
                if max_steps <= warmup_steps:
                    return 0.1
                progress = (step - warmup_steps) / (max_steps - warmup_steps)
                # Clamp progress to [0, 1]
                progress = min(progress, 1.0)
                # Linear decay: from 1.0 down to 0.1
                return 1.0 - 0.9 * progress

        scheduler = LambdaLR(optimizer, lr_lambda)

    elif scheduler_type == "constant":
        # Constant learning rate (no scheduler)
        scheduler = None

    else:
        raise ValueError(
            f"Unknown LR scheduler type: {scheduler_type}. "
            "Supported types: cosine, linear, constant"
        )

    return scheduler


if __name__ == "__main__":
    import sys
    import io

    # Force UTF-8 output (Windows compatibility)
    if sys.platform == "win32":
        sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")

    print("=" * 60)
    print("Optimizer and LR scheduler tests")
    print("=" * 60)

    # Build a tiny model for testing
    import torch.nn as nn

    class SimpleModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(10, 1)

        def forward(self, x):
            return self.linear(x)

    model = SimpleModel()

    # Test config
    config = {
        "learning_rate": 1e-3,
        "weight_decay": 0.01,
        "beta1": 0.9,
        "beta2": 0.999,
        "eps": 1e-8,
        "lr_scheduler": "cosine",
        "warmup_steps": 10,
        "max_steps": 100,
    }

    # Test optimizer creation
    print("\n1. Test optimizer creation")
    optimizer = get_optimizer(model, config)
    print(f"   Optimizer type: {type(optimizer).__name__}")
    print(f"   Learning rate: {optimizer.param_groups[0]['lr']}")
    print(f"   Weight decay: {optimizer.param_groups[0]['weight_decay']}")
    print(f"   Beta1: {optimizer.param_groups[0]['betas'][0]}")
    print(f"   Beta2: {optimizer.param_groups[0]['betas'][1]}")

    # Test LR scheduler (cosine)
    print("\n2. Test cosine-annealing LR scheduler (with warmup)")
    scheduler = get_lr_scheduler(optimizer, config)
    print(f"   Scheduler type: {type(scheduler).__name__}")
    print(f"   Initial learning rate: {optimizer.param_groups[0]['lr']:.6f}")

    # Simulate training steps and watch the learning rate
    print("\n3. Simulate training steps, watch the LR (cosine)")
    lrs = []
    # Step once per simulated step so the printed step numbers match the
    # actual number of scheduler.step() calls
    for step in range(101):
        current_lr = optimizer.param_groups[0]["lr"]
        lrs.append(current_lr)
        if step % 10 == 0:
            print(f"   Step {step:3d}: lr = {current_lr:.6f}")
        # Simulate an optimization step (optimizer.step() first, then scheduler.step())
        # Build a dummy loss and backpropagate (for testing only)
        dummy_loss = sum(p.sum() for p in model.parameters())
        dummy_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if scheduler is not None:
            scheduler.step()

    # Test the linear scheduler
    print("\n4. Test linear LR scheduler (with warmup)")
    config_linear = config.copy()
    config_linear["lr_scheduler"] = "linear"
    optimizer_linear = get_optimizer(model, config_linear)
    scheduler_linear = get_lr_scheduler(optimizer_linear, config_linear)
    print(f"   Scheduler type: {type(scheduler_linear).__name__}")

    print("\n5. Simulate training steps, watch the LR (linear)")
    for step in range(101):
        current_lr = optimizer_linear.param_groups[0]["lr"]
        if step % 10 == 0:
            print(f"   Step {step:3d}: lr = {current_lr:.6f}")
        # Simulate an optimization step (optimizer.step() first, then scheduler.step())
        dummy_loss = sum(p.sum() for p in model.parameters())
        dummy_loss.backward()
        optimizer_linear.step()
        optimizer_linear.zero_grad()
        if scheduler_linear is not None:
            scheduler_linear.step()

    # Test constant learning rate
    print("\n6. Test constant learning rate")
    config_constant = config.copy()
    config_constant["lr_scheduler"] = "constant"
    optimizer_constant = get_optimizer(model, config_constant)
    scheduler_constant = get_lr_scheduler(optimizer_constant, config_constant)
    print(f"   Scheduler type: {scheduler_constant}")
    print(f"   Learning rate: {optimizer_constant.param_groups[0]['lr']:.6f}")

    # Test handling of unknown scheduler types
    print("\n7. Test error handling")
    try:
        config_error = config.copy()
        config_error["lr_scheduler"] = "invalid"
        scheduler_error = get_lr_scheduler(optimizer, config_error)
    except ValueError as e:
        print(f"   Correctly caught error: {e}")

    print("\n" + "=" * 60)
    print("All tests complete!")
    print("=" * 60)
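A quick sanity check of the warmup-plus-cosine multiplier implemented above (a hand-computed sketch, not part of the commit; the numbers assume warmup_steps=10 and max_steps=100, matching the test config):

import math

warmup_steps, max_steps = 10, 100

def cosine_mult(step):
    # mirrors lr_lambda from get_lr_scheduler above
    if step < warmup_steps:
        return step / warmup_steps
    progress = min((step - warmup_steps) / (max_steps - warmup_steps), 1.0)
    return 0.5 * (1.0 + math.cos(progress * math.pi))

print(cosine_mult(5))    # 0.5  -- halfway through warmup
print(cosine_mult(55))   # 0.5  -- progress = 45/90 = 0.5, cos(pi/2) = 0
print(cosine_mult(100))  # ~0.0 -- fully annealed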
llm/training/trainer.py
ADDED
|
@@ -0,0 +1,294 @@
"""Trainer: the training loop"""
# 2026-01-23

import torch
from pathlib import Path
from collections import deque

from llm.training.loss import CrossEntropyLoss
from llm.training.optim import get_optimizer, get_lr_scheduler
from llm.utils.checkpoint import save_checkpoint
from llm.utils.logging import ProgressBar
from llm.model.attention import create_causal_mask


class Trainer:
    """Trainer"""

    def __init__(self, model, train_loader, val_loader, config):
        """
        Initialize the trainer.

        Args:
            model: Transformer model
            train_loader: training DataLoader
            val_loader: validation DataLoader
            config: training config dict
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config

        # Device
        device_str = config.get("device", "cpu")
        self.device = torch.device(device_str)
        self.model.to(self.device)

        # Loss function
        self.criterion = CrossEntropyLoss()

        # Optimizer
        self.optimizer = get_optimizer(model, config)

        # LR scheduler
        self.scheduler = get_lr_scheduler(self.optimizer, config)

        # Training state
        self.global_step = 0
        self.current_epoch = 0
        self.best_val_loss = float("inf")

        # Gradient accumulation
        self.gradient_accumulation_steps = config.get("gradient_accumulation_steps", 1)

        # Checkpointing
        self.save_steps = config.get("save_steps", 500)
        self.eval_steps = config.get("eval_steps", 500)
        self.save_total_limit = config.get("save_total_limit", 3)
        self.checkpoint_dir = Path(config.get("checkpoint_dir", "checkpoints"))
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Paths of saved checkpoints (used to enforce save_total_limit)
        # Note: no maxlen here; the queue is managed manually so old files can be deleted
        self.saved_checkpoints = deque()
        self.best_checkpoint_path = None  # path of the best model

    def train_step(self, batch):
        """
        Single training step.

        Args:
            batch: (input_ids, target_ids)
                - input_ids: input token IDs, shape (batch_size, seq_len)
                - target_ids: target token IDs, shape (batch_size, seq_len)

        Returns:
            Loss value (scalar)
        """
        input_ids, target_ids = batch
        input_ids = input_ids.to(self.device)
        target_ids = target_ids.to(self.device)

        # Build the causal mask
        seq_len = input_ids.size(1)
        causal_mask = create_causal_mask(seq_len, device=self.device)

        # Forward pass
        logits = self.model(input_ids, mask=causal_mask)

        # Compute the loss
        loss = self.criterion(logits, target_ids)

        # Gradient accumulation: divide the loss by the number of accumulation steps
        loss = loss / self.gradient_accumulation_steps

        # Backward pass
        loss.backward()

        # Gradient accumulation: only update parameters once enough steps were accumulated
        if (self.global_step + 1) % self.gradient_accumulation_steps == 0:
            # Gradient clipping
            max_grad_norm = self.config.get("max_grad_norm", 1.0)
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), max_grad_norm
            )

            # Update parameters
            self.optimizer.step()
            self.optimizer.zero_grad()

            # Update the learning rate
            if self.scheduler is not None:
                self.scheduler.step()

        self.global_step += 1

        return loss.item() * self.gradient_accumulation_steps  # return the un-scaled loss value

    def evaluate(self, max_batches=10):
        """
        Evaluate the model on the validation set.

        Returns:
            val_loss: validation loss
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in self.val_loader:
                if max_batches and num_batches >= max_batches:
                    break  # limit the number of evaluation batches
                input_ids, target_ids = batch
                input_ids = input_ids.to(self.device)
                target_ids = target_ids.to(self.device)

                # Build the causal mask
                seq_len = input_ids.size(1)
                causal_mask = create_causal_mask(seq_len, device=self.device)

                # Forward pass
                logits = self.model(input_ids, mask=causal_mask)

                # Compute the loss
                loss = self.criterion(logits, target_ids)
                total_loss += loss.item()
                num_batches += 1

        avg_loss = total_loss / num_batches if num_batches > 0 else float("inf")
        self.model.train()  # back to training mode

        return avg_loss

    def save_checkpoint(self, is_best=False):
        """
        Save a checkpoint.

        Args:
            is_best: whether this is the best model
        """
        # Pass a simple name; checkpoint.py appends the step consistently
        checkpoint_name = "best_model" if is_best else "checkpoint"

        checkpoint_path = save_checkpoint(
            model=self.model,
            optimizer=self.optimizer,
            epoch=self.current_epoch,
            step=self.global_step,
            loss=self.best_val_loss if is_best else None,
            checkpoint_dir=self.checkpoint_dir,
            name=checkpoint_name,
        )

        # Track the saved checkpoint
        if is_best:
            # Delete the previous best model (only the latest one is kept)
            if self.best_checkpoint_path is not None and self.best_checkpoint_path.exists():
                self.best_checkpoint_path.unlink()
            self.best_checkpoint_path = checkpoint_path
        else:
            # Before adding a new checkpoint, delete the oldest file if the queue is full
            if len(self.saved_checkpoints) >= self.save_total_limit:
                old_checkpoint = self.saved_checkpoints.popleft()
                if old_checkpoint.exists():
                    old_checkpoint.unlink()
                    print(f"Deleted old checkpoint: {old_checkpoint}")

            # Add the new checkpoint to the queue
            self.saved_checkpoints.append(checkpoint_path)

        return checkpoint_path

    def train(self):
        """
        Training loop.
        """
        num_epochs = self.config.get("num_epochs", 10)
        max_steps = self.config.get("max_steps", None)

        print("=" * 60)
        print("Starting training")
        print("=" * 60)
        print(f"Device: {self.device}")
        print(f"Training samples: {len(self.train_loader.dataset)}")
        print(f"Validation samples: {len(self.val_loader.dataset) if self.val_loader else 0}")
        print(f"Batch size: {self.config.get('batch_size', 4)}")
        print(f"Gradient accumulation steps: {self.gradient_accumulation_steps}")
        print(f"Max steps: {max_steps}")
        print(f"Epochs: {num_epochs}")
        print("=" * 60)

        for epoch in range(num_epochs):
            self.current_epoch = epoch
            self.model.train()

            # Progress bar
            pbar = ProgressBar(
                total=len(self.train_loader), desc=f"Epoch {epoch+1}/{num_epochs}"
            )

            epoch_loss = 0.0
            num_batches = 0

            for batch_idx, batch in enumerate(self.train_loader):
                # Stop if the maximum number of steps was reached
                if max_steps is not None and self.global_step >= max_steps:
                    print(f"\nReached max steps {max_steps}, stopping training")
                    pbar.close()
                    return

                # One training step
                loss = self.train_step(batch)
                epoch_loss += loss
                num_batches += 1

                # Update the progress bar
                pbar.update(1)

                # Periodic evaluation
                saved_best_at_this_step = False
                if self.global_step % self.eval_steps == 0 and self.val_loader:
                    val_loss = self.evaluate()
                    # Current learning rate
                    current_lr = self.optimizer.param_groups[0]["lr"]

                    print(
                        f"\nStep {self.global_step}: "
                        f"train loss={loss:.4f}, "
                        f"val loss={val_loss:.4f}, "
                        f"lr={current_lr:.2e}"
                    )

                    # Save the best model
                    if val_loss < self.best_val_loss:
                        self.best_val_loss = val_loss
                        checkpoint_path = self.save_checkpoint(is_best=True)
                        print(f"Saved best model: {checkpoint_path}")
                        saved_best_at_this_step = True

                # Periodic checkpointing (skipped if the best model was already saved at this step)
                if self.global_step % self.save_steps == 0 and not saved_best_at_this_step:
                    checkpoint_path = self.save_checkpoint(is_best=False)
                    print(f"Saved checkpoint: {checkpoint_path}")

            pbar.close()

            # End-of-epoch evaluation
            avg_epoch_loss = epoch_loss / num_batches if num_batches > 0 else 0.0
            print(f"\nEpoch {epoch+1}/{num_epochs} done:")
            print(f"  Average training loss: {avg_epoch_loss:.4f}")

            if self.val_loader:
                val_loss = self.evaluate()
                print(f"  Validation loss: {val_loss:.4f}")

        print("\n" + "=" * 60)
        print("Training complete!")
        print("=" * 60)
        print(f"Best validation loss: {self.best_val_loss:.4f}")

        # Show the path of the best model
        if self.best_checkpoint_path is not None and self.best_checkpoint_path.exists():
            print(f"\nBest model path: {self.best_checkpoint_path}")
            print("  (the checkpoint whose filename contains 'best_model' is the best model)")
        else:
            print("\nWarning: no best-model checkpoint found")

        # List all saved checkpoints
        if self.saved_checkpoints:
            print(f"\nNumber of saved checkpoints: {len(self.saved_checkpoints)}")
            print("Checkpoint list:")
            for i, cp_path in enumerate(self.saved_checkpoints, 1):
                print(f"  {i}. {cp_path}")
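A minimal sketch of how the trainer is meant to be wired up (illustrative only: the toy tensors, batch size, and config values here are hypothetical, and the real model config comes from the YAML files elsewhere in the repo):

import torch
from torch.utils.data import DataLoader, TensorDataset
# from llm.model.transformer import Transformer
# from llm.training.trainer import Trainer

# Toy (input_ids, target_ids) pairs, the format Trainer.train_step expects
inputs = torch.randint(0, 100, (64, 32))
targets = torch.randint(0, 100, (64, 32))
train_loader = DataLoader(TensorDataset(inputs, targets), batch_size=4)
val_loader = DataLoader(TensorDataset(inputs[:16], targets[:16]), batch_size=4)

config = {
    "device": "cpu",
    "learning_rate": 1e-4,
    "lr_scheduler": "cosine",
    "warmup_steps": 10,
    "max_steps": 100,
    "num_epochs": 1,
    "eval_steps": 50,
    "save_steps": 50,
    "checkpoint_dir": "checkpoints",
}

# model = Transformer(model_config)                        # model_config loaded from YAML
# Trainer(model, train_loader, val_loader, config).train()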
llm/utils/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
"""Utility module"""

from llm.utils.init import init_weights, init_weights_with_scaling, apply_llm_init

__all__ = [
    "init_weights",
    "init_weights_with_scaling",
    "apply_llm_init",
]
llm/utils/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (358 Bytes).
|
|
|
llm/utils/__pycache__/checkpoint.cpython-312.pyc
ADDED
|
Binary file (1.96 kB).
|
|
|
llm/utils/__pycache__/init.cpython-312.pyc
ADDED
|
Binary file (10.1 kB).
|
|
|
llm/utils/checkpoint.py
ADDED
|
@@ -0,0 +1,39 @@
"""Checkpoint saving and loading"""
# 2026-01-23

import torch
from pathlib import Path


def save_checkpoint(model, optimizer, epoch, step, loss, checkpoint_dir, name='checkpoint'):
    """Save a checkpoint"""
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Append the step here consistently, so callers do not duplicate it
    checkpoint_path = checkpoint_dir / f"{name}_step_{step}.pt"

    torch.save({
        'epoch': epoch,
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, checkpoint_path)

    return checkpoint_path


def load_checkpoint(model, optimizer, checkpoint_path):
    """Load a checkpoint (to resume training)"""
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['step'], checkpoint.get('loss', None)


def load_model_only(model, checkpoint_path):
    """Load only the model weights (for inference; no optimizer needed)"""
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    return checkpoint.get('epoch', 0), checkpoint.get('step', 0), checkpoint.get('loss', None)
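A small usage sketch of the helpers above (stand-in nn.Linear model and AdamW optimizer, shown only to illustrate the save/resume round trip; file names follow the {name}_step_{step}.pt pattern used by save_checkpoint):

import torch.nn as nn
import torch.optim as optim
from llm.utils.checkpoint import save_checkpoint, load_checkpoint, load_model_only

model = nn.Linear(8, 8)                      # stand-in model for the sketch
optimizer = optim.AdamW(model.parameters())

path = save_checkpoint(model, optimizer, epoch=1, step=500,
                       loss=2.31, checkpoint_dir="checkpoints", name="checkpoint")
print(path)                                  # checkpoints/checkpoint_step_500.pt

epoch, step, loss = load_checkpoint(model, optimizer, path)   # resume training
epoch, step, loss = load_model_only(model, path)              # inference only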
llm/utils/config.py
ADDED
|
@@ -0,0 +1,25 @@
"""Config loading: read YAML"""

import yaml
from pathlib import Path


def load_config(config_path):
    """Load a single config file"""
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    return config


def load_all_configs(config_dir='configs'):
    """Load all config files"""
    config_dir = Path(config_dir)
    configs = {}

    for config_file in ['model.yaml', 'train.yaml', 'data.yaml']:
        config_path = config_dir / config_file
        if config_path.exists():
            name = config_file.replace('.yaml', '')
            configs[name] = load_config(config_path)

    return configs
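For reference, a short sketch of calling these loaders (paths are the defaults assumed by the code above; any of train.yaml or data.yaml that is absent is simply skipped by load_all_configs):

from llm.utils.config import load_config, load_all_configs

model_cfg = load_config("configs/model.yaml")
configs = load_all_configs("configs")   # e.g. {"model": {...}} if only model.yaml exists
print(configs.keys())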
llm/utils/init.py
ADDED
|
@@ -0,0 +1,213 @@
"""Weight initialization: init strategies for LLM model weights"""

import torch
import torch.nn as nn
import math


def init_weights(module, std=0.02):
    """
    Initialize model weights (LLM-style, following GPT/LLaMA).

    Args:
        module: PyTorch module
        std: standard deviation of the normal distribution (default: 0.02)

    Strategy:
        - nn.Embedding: normal N(0, std)
        - nn.Linear: normal N(0, std), biases initialized to 0
        - RMSNorm: learnable scale initialized to 1.0 (handled by RMSNorm itself)
    """
    if isinstance(module, nn.Embedding):
        # Embedding layer: normal initialization
        nn.init.normal_(module.weight, mean=0.0, std=std)
    elif isinstance(module, nn.Linear):
        # Linear layer: normal weights, zero biases
        nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def init_weights_with_scaling(module, hidden_size=None, std=0.02):
    """
    Initialize model weights (with scaling, intended for the output layer).

    Args:
        module: PyTorch module
        hidden_size: hidden dimension (used to scale the output layer)
        std: base standard deviation (default: 0.02)

    Strategy:
        - nn.Embedding: normal N(0, std)
        - nn.Linear:
            - output layer (if hidden_size is given): N(0, std / sqrt(hidden_size))
            - other layers: N(0, std)
        - biases: initialized to 0
    """
    if isinstance(module, nn.Embedding):
        # Embedding layer: normal initialization
        nn.init.normal_(module.weight, mean=0.0, std=std)
    elif isinstance(module, nn.Linear):
        # Linear layer
        if hidden_size is not None:
            # Output layer: use the scaled standard deviation
            output_std = std / math.sqrt(hidden_size)
            nn.init.normal_(module.weight, mean=0.0, std=output_std)
        else:
            # Regular linear layer: standard normal initialization
            nn.init.normal_(module.weight, mean=0.0, std=std)

        # Biases initialized to 0
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def apply_llm_init(model, std=0.02, init_output_layer=True):
    """
    Apply the LLM weight initialization to an entire model.

    Args:
        model: Transformer model
        std: standard deviation of the normal distribution (default: 0.02)
        init_output_layer: whether the output layer gets the scaled init (default: True)

    Returns:
        The initialized model
    """
    # Get hidden_size (used for the output-layer init)
    hidden_size = None
    if hasattr(model, "config"):
        hidden_size = model.config.get("hidden_size")
    elif hasattr(model, "hidden_size"):
        hidden_size = model.hidden_size

    # Walk all modules and initialize them
    for module in model.modules():
        if isinstance(module, (nn.Embedding, nn.Linear)):
            if init_output_layer and isinstance(module, nn.Linear):
                # Check whether this is the output layer (lm_head)
                # With tie_word_embeddings=True, lm_head may be None and the embedding weights are reused
                if hasattr(model, "lm_head") and module is model.lm_head:
                    # Output layer: scaled initialization
                    init_weights_with_scaling(module, hidden_size=hidden_size, std=std)
                else:
                    # Regular linear layer
                    init_weights(module, std=std)
            else:
                # Standard initialization
                init_weights(module, std=std)

    return model


if __name__ == "__main__":
    import sys
    import io

    # Force UTF-8 output (Windows compatibility)
    if sys.platform == "win32":
        sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")

    print("=" * 60)
    print("Weight initialization tests")
    print("=" * 60)

    # Test parameters
    vocab_size = 100
    hidden_size = 320
    intermediate_size = 960
    std = 0.02

    print("\nTest parameters:")
    print(f"  vocab_size: {vocab_size}")
    print(f"  hidden_size: {hidden_size}")
    print(f"  intermediate_size: {intermediate_size}")
    print(f"  std: {std}")

    # Test embedding initialization
    print("\n1. Test embedding initialization")
    embedding = nn.Embedding(vocab_size, hidden_size)
    init_weights(embedding, std=std)
    weight_mean = embedding.weight.mean().item()
    weight_std = embedding.weight.std().item()
    print(f"  Embedding weight mean: {weight_mean:.6f} (should be close to 0)")
    print(f"  Embedding weight std: {weight_std:.6f} (should be close to {std})")

    # Test linear initialization
    print("\n2. Test linear initialization")
    linear = nn.Linear(hidden_size, intermediate_size)
    init_weights(linear, std=std)
    weight_mean = linear.weight.mean().item()
    weight_std = linear.weight.std().item()
    bias_mean = linear.bias.mean().item() if linear.bias is not None else 0.0
    print(f"  Linear weight mean: {weight_mean:.6f} (should be close to 0)")
    print(f"  Linear weight std: {weight_std:.6f} (should be close to {std})")
    print(f"  Linear bias mean: {bias_mean:.6f} (should be 0)")

    # Test output-layer initialization (with scaling)
    print("\n3. Test output-layer initialization (with scaling)")
    output_layer = nn.Linear(hidden_size, vocab_size, bias=False)
    init_weights_with_scaling(output_layer, hidden_size=hidden_size, std=std)
    weight_mean = output_layer.weight.mean().item()
    weight_std = output_layer.weight.std().item()
    expected_std = std / math.sqrt(hidden_size)
    print(f"  Output-layer weight mean: {weight_mean:.6f} (should be close to 0)")
    print(f"  Output-layer weight std: {weight_std:.6f}")
    print(f"  Expected std: {expected_std:.6f}")

    # Test full-model initialization
    print("\n4. Test full-model initialization")
    from llm.model.transformer import Transformer

    config = {
        "vocab_size": vocab_size,
        "hidden_size": hidden_size,
        "num_hidden_layers": 2,  # keep the layer count small for testing
        "num_attention_heads": 10,
        "num_key_value_heads": 2,
        "intermediate_size": intermediate_size,
        "rms_norm_eps": 1e-5,
        "max_position_embeddings": 1024,
        "rope_theta": 10000.0,
        "sliding_window": 256,
        "sliding_window_overlap": True,
        "tie_word_embeddings": True,
    }

    model = Transformer(config)

    # Weight statistics before initialization
    print("  Weight statistics before initialization:")
    for name, param in model.named_parameters():
        if "embedding" in name or "weight" in name:
            print(f"    {name}: mean={param.data.mean().item():.6f}, std={param.data.std().item():.6f}")

    # Apply the initialization
    apply_llm_init(model, std=std, init_output_layer=True)

    # Weight statistics after initialization
    print("\n  Weight statistics after initialization:")
    for name, param in model.named_parameters():
        if "embedding" in name or ("weight" in name and "norm" not in name):
            print(f"    {name}: mean={param.data.mean().item():.6f}, std={param.data.std().item():.6f}")

    # Verify the effect of the initialization
    print("\n5. Verify the initialization")
    embedding_weight = model.embedding.embedding.weight
    print(f"  Embedding weight mean: {embedding_weight.mean().item():.6f}")
    print(f"  Embedding weight std: {embedding_weight.std().item():.6f}")

    # Inspect the linear layers of the first Transformer block
    first_block = model.layers[0]
    attn_q_proj = first_block.attn.q_proj
    print(f"  Attention Q projection weight mean: {attn_q_proj.weight.mean().item():.6f}")
    print(f"  Attention Q projection weight std: {attn_q_proj.weight.std().item():.6f}")

    ffn_gate_proj = first_block.ffn.gate_proj
    print(f"  FFN gate projection weight mean: {ffn_gate_proj.weight.mean().item():.6f}")
    print(f"  FFN gate projection weight std: {ffn_gate_proj.weight.std().item():.6f}")

    print("\n" + "=" * 60)
    print("All tests complete!")
    print("=" * 60)
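A minimal sketch of applying the initializer to a generic module (a stand-in nn.Sequential rather than the repo's Transformer; with no config/hidden_size attribute and no lm_head, every Linear simply receives the standard N(0, std) init):

import torch.nn as nn
from llm.utils.init import apply_llm_init

net = nn.Sequential(
    nn.Embedding(100, 32),
    nn.Linear(32, 32),
    nn.Linear(32, 100),
)
apply_llm_init(net, std=0.02)
print(net[1].weight.std())   # should be close to 0.02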
llm/utils/logging.py
ADDED
|
@@ -0,0 +1,18 @@
"""Logging: tqdm / tensorboard"""

from tqdm import tqdm


class ProgressBar:
    """Progress bar"""

    def __init__(self, total, desc=""):
        self.pbar = tqdm(total=total, desc=desc)

    def update(self, n=1):
        """Advance the progress bar"""
        self.pbar.update(n)

    def close(self):
        """Close the progress bar"""
        self.pbar.close()
llm/utils/seed.py
ADDED
|
@@ -0,0 +1,14 @@
"""Random seed setup"""

import random
import numpy as np
import torch


def set_seed(seed=42):
    """Set all random seeds"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
torch>=2.0.0
pyyaml>=6.0
numpy>=2.0.0
huggingface_hub>=0.30.0