LiManshu commited on
Commit
bf6be45
·
verified ·
1 Parent(s): b7488e1

Add files using upload-large-folder tool

Browse files
Files changed (49) hide show
  1. .gitattributes +3 -35
  2. README.md +40 -0
  3. checkpoints/best_model.pt +3 -0
  4. configs/model.yaml +43 -0
  5. data/vocab/char_vocab.json +139 -0
  6. inference.py +67 -0
  7. llm/__init__.py +5 -0
  8. llm/__pycache__/__init__.cpython-312.pyc +0 -0
  9. llm/data/__init__.py +1 -0
  10. llm/data/__pycache__/__init__.cpython-312.pyc +0 -0
  11. llm/data/__pycache__/tokenizer.cpython-312.pyc +0 -0
  12. llm/data/collate.py +167 -0
  13. llm/data/dataset.py +164 -0
  14. llm/data/tokenizer.py +126 -0
  15. llm/inference/__init__.py +5 -0
  16. llm/inference/__pycache__/__init__.cpython-312.pyc +0 -0
  17. llm/inference/__pycache__/generate.cpython-312.pyc +0 -0
  18. llm/inference/generate.py +179 -0
  19. llm/model/__init__.py +1 -0
  20. llm/model/__pycache__/__init__.cpython-312.pyc +0 -0
  21. llm/model/__pycache__/attention.cpython-312.pyc +0 -0
  22. llm/model/__pycache__/block.cpython-312.pyc +0 -0
  23. llm/model/__pycache__/embedding.cpython-312.pyc +0 -0
  24. llm/model/__pycache__/ffn.cpython-312.pyc +0 -0
  25. llm/model/__pycache__/norm.cpython-312.pyc +0 -0
  26. llm/model/__pycache__/rope.cpython-312.pyc +0 -0
  27. llm/model/__pycache__/transformer.cpython-312.pyc +0 -0
  28. llm/model/attention.py +435 -0
  29. llm/model/block.py +163 -0
  30. llm/model/embedding.py +35 -0
  31. llm/model/ffn.py +139 -0
  32. llm/model/norm.py +132 -0
  33. llm/model/rope.py +162 -0
  34. llm/model/transformer.py +280 -0
  35. llm/training/__init__.py +15 -0
  36. llm/training/loss.py +91 -0
  37. llm/training/metrics.py +175 -0
  38. llm/training/optim.py +223 -0
  39. llm/training/trainer.py +294 -0
  40. llm/utils/__init__.py +9 -0
  41. llm/utils/__pycache__/__init__.cpython-312.pyc +0 -0
  42. llm/utils/__pycache__/checkpoint.cpython-312.pyc +0 -0
  43. llm/utils/__pycache__/init.cpython-312.pyc +0 -0
  44. llm/utils/checkpoint.py +39 -0
  45. llm/utils/config.py +25 -0
  46. llm/utils/init.py +213 -0
  47. llm/utils/logging.py +18 -0
  48. llm/utils/seed.py +14 -0
  49. requirements.txt +4 -0
.gitattributes CHANGED
@@ -1,35 +1,3 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.pt filter=lfs diff=lfs merge=lfs -text
2
+ *.bin filter=lfs diff=lfs merge=lfs -text
3
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ language:
4
+ - en
5
+ library_name: pytorch
6
+ pipeline_tag: text-generation
7
+ tags:
8
+ - transformer
9
+ - character-level
10
+ - custom-code
11
+ ---
12
+
13
+ # nextShakespeare
14
+
15
+ `nextShakespeare` is a decoder-only, character-level Transformer language model
16
+ trained on Tiny Shakespeare style text. This repo uses custom PyTorch code
17
+ rather than `transformers` native model classes.
18
+
19
+ ## Model assets
20
+
21
+ - Weights: `checkpoints/best_model.pt`
22
+ - Config: `configs/model.yaml`
23
+ - Vocabulary: `data/vocab/char_vocab.json`
24
+ - Runtime code: `llm/`
25
+
26
+ ## Quickstart
27
+
28
+ ```bash
29
+ git clone https://huggingface.co/<your-username>/nextShakespeare
30
+ cd nextShakespeare
31
+ pip install -r requirements.txt
32
+ python inference.py --prompt $'First Citizen:\n' --max_length 200 --temperature 0.8
33
+ ```
34
+
35
+ ## Notes
36
+
37
+ - This is a custom-code checkpoint. It is not directly loadable via
38
+ `AutoModel.from_pretrained()` yet.
39
+ - For a web demo, see the companion Space:
40
+ `https://huggingface.co/spaces/<your-username>/manshu-init`
checkpoints/best_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8bbbcb6c73219ad04334d069ca56b0362282ec43b14772e95ae77539b5589357
3
+ size 1248083791
configs/model.yaml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================
2
+ # Model Architecture
3
+ # ============================================================
4
+
5
+ num_hidden_layers: 12 # Transformer 层数
6
+ hidden_size: 768 # 隐藏层维度
7
+ num_attention_heads: 12 # 注意力头数
8
+ num_key_value_heads: 4 # Key/Value 头数(GQA)
9
+ intermediate_size: 3072 # FFN 中间层维度
10
+
11
+ # ----------------------------
12
+ # Context / Position Encoding
13
+ # ----------------------------
14
+
15
+ max_position_embeddings: 2048 # 最大上下文长度
16
+ rope_theta: 10000 # RoPE 位置编码参数
17
+
18
+ # ----------------------------
19
+ # Attention Optimization
20
+ # ----------------------------
21
+
22
+ sliding_window: 1024 # 滑动窗口注意力大小
23
+ sliding_window_overlap: true # 是否允许窗口重叠
24
+ # 注意:当前所有层都使用滑动窗口
25
+
26
+ # ----------------------------
27
+ # Normalization
28
+ # ----------------------------
29
+
30
+ rms_norm_eps: 1e-5 # RMSNorm 数值稳定项
31
+
32
+ # ----------------------------
33
+ # Embedding
34
+ # ----------------------------
35
+
36
+ tie_word_embeddings: true # 是否绑定输入输出词嵌入
37
+
38
+ # ----------------------------
39
+ # Initialization
40
+ # ----------------------------
41
+
42
+ init_weights: true # 是否启用权重初始化
43
+ init_std: 0.02 # 权重初始化标准差
data/vocab/char_vocab.json ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "char_to_id": {
3
+ "<unk>": 0,
4
+ " ": 1,
5
+ "e": 2,
6
+ "t": 3,
7
+ "o": 4,
8
+ "a": 5,
9
+ "h": 6,
10
+ "s": 7,
11
+ "r": 8,
12
+ "n": 9,
13
+ "i": 10,
14
+ "\n": 11,
15
+ "l": 12,
16
+ "d": 13,
17
+ "u": 14,
18
+ "m": 15,
19
+ "y": 16,
20
+ ",": 17,
21
+ "w": 18,
22
+ "f": 19,
23
+ "c": 20,
24
+ "g": 21,
25
+ "I": 22,
26
+ "b": 23,
27
+ "p": 24,
28
+ ":": 25,
29
+ ".": 26,
30
+ "A": 27,
31
+ "v": 28,
32
+ "k": 29,
33
+ "T": 30,
34
+ "'": 31,
35
+ "E": 32,
36
+ "O": 33,
37
+ "N": 34,
38
+ "R": 35,
39
+ "S": 36,
40
+ "L": 37,
41
+ "C": 38,
42
+ ";": 39,
43
+ "W": 40,
44
+ "U": 41,
45
+ "H": 42,
46
+ "M": 43,
47
+ "B": 44,
48
+ "?": 45,
49
+ "G": 46,
50
+ "!": 47,
51
+ "D": 48,
52
+ "-": 49,
53
+ "F": 50,
54
+ "Y": 51,
55
+ "P": 52,
56
+ "K": 53,
57
+ "V": 54,
58
+ "j": 55,
59
+ "q": 56,
60
+ "x": 57,
61
+ "z": 58,
62
+ "J": 59,
63
+ "Q": 60,
64
+ "Z": 61,
65
+ "X": 62,
66
+ "3": 63,
67
+ "&": 64,
68
+ "$": 65
69
+ },
70
+ "id_to_char": {
71
+ "0": "<unk>",
72
+ "1": " ",
73
+ "2": "e",
74
+ "3": "t",
75
+ "4": "o",
76
+ "5": "a",
77
+ "6": "h",
78
+ "7": "s",
79
+ "8": "r",
80
+ "9": "n",
81
+ "10": "i",
82
+ "11": "\n",
83
+ "12": "l",
84
+ "13": "d",
85
+ "14": "u",
86
+ "15": "m",
87
+ "16": "y",
88
+ "17": ",",
89
+ "18": "w",
90
+ "19": "f",
91
+ "20": "c",
92
+ "21": "g",
93
+ "22": "I",
94
+ "23": "b",
95
+ "24": "p",
96
+ "25": ":",
97
+ "26": ".",
98
+ "27": "A",
99
+ "28": "v",
100
+ "29": "k",
101
+ "30": "T",
102
+ "31": "'",
103
+ "32": "E",
104
+ "33": "O",
105
+ "34": "N",
106
+ "35": "R",
107
+ "36": "S",
108
+ "37": "L",
109
+ "38": "C",
110
+ "39": ";",
111
+ "40": "W",
112
+ "41": "U",
113
+ "42": "H",
114
+ "43": "M",
115
+ "44": "B",
116
+ "45": "?",
117
+ "46": "G",
118
+ "47": "!",
119
+ "48": "D",
120
+ "49": "-",
121
+ "50": "F",
122
+ "51": "Y",
123
+ "52": "P",
124
+ "53": "K",
125
+ "54": "V",
126
+ "55": "j",
127
+ "56": "q",
128
+ "57": "x",
129
+ "58": "z",
130
+ "59": "J",
131
+ "60": "Q",
132
+ "61": "Z",
133
+ "62": "X",
134
+ "63": "3",
135
+ "64": "&",
136
+ "65": "$"
137
+ },
138
+ "vocab_size": 66
139
+ }
inference.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ import yaml
6
+
7
+ from llm.data.tokenizer import CharTokenizer
8
+ from llm.inference.generate import greedy_decode, sample_decode
9
+ from llm.model.transformer import Transformer
10
+ from llm.utils.checkpoint import load_model_only
11
+
12
+
13
def load_yaml(path: Path):
    """Read a UTF-8 YAML file at *path* and return the parsed data."""
    with open(path, "r", encoding="utf-8") as handle:
        return yaml.safe_load(handle)
16
+
17
+
18
def main():
    """Command-line entry point.

    Loads the model config, character tokenizer, and checkpoint, then
    generates text from ``--prompt`` using greedy decoding (temperature 0)
    or temperature / top-k / top-p sampling.
    """
    parser = argparse.ArgumentParser()
    # Use a real newline as the default.  The old default "First Citizen:\\n"
    # put a literal backslash+n into the prompt, which the char-level
    # tokenizer can only map to <unk>.
    parser.add_argument("--prompt", type=str, default="First Citizen:\n")
    parser.add_argument("--max_length", type=int, default=200)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--checkpoint", type=str, default="checkpoints/best_model.pt")
    parser.add_argument("--config", type=str, default="configs/model.yaml")
    parser.add_argument("--vocab", type=str, default="data/vocab/char_vocab.json")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_config = load_yaml(Path(args.config))
    tokenizer = CharTokenizer(vocab_path=args.vocab)
    # vocab_size always comes from the tokenizer, never the YAML, so the two
    # cannot drift apart.
    model_config["vocab_size"] = tokenizer.vocab_size

    model = Transformer(model_config)
    load_model_only(model, args.checkpoint)
    model.to(device)
    model.eval()

    # Shells pass "\n" through as backslash+n; translate that common escape
    # so `--prompt "First Citizen:\n"` (as documented) yields a real newline.
    prompt = args.prompt.replace("\\n", "\n")

    input_ids = tokenizer.encode(prompt)
    if not input_ids:
        # Never feed the model an empty sequence; fall back to <unk> (id 0).
        input_ids = [0]
    input_ids = torch.tensor([input_ids], dtype=torch.long)

    with torch.no_grad():
        if args.temperature == 0:
            # Temperature 0 means deterministic decoding.
            generated_ids = greedy_decode(
                model, input_ids, max_length=args.max_length, device=device
            )
        else:
            generated_ids = sample_decode(
                model,
                input_ids,
                max_length=args.max_length,
                temperature=args.temperature,
                top_k=args.top_k if args.top_k > 0 else None,
                top_p=args.top_p if args.top_p > 0 else None,
                device=device,
            )

    text = tokenizer.decode(generated_ids[0])
    print(text)
64
+
65
+
66
+ if __name__ == "__main__":
67
+ main()
llm/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """
2
+ LLM from Manshu - 从零实现的大语言模型
3
+ """
4
+
5
+ __version__ = "0.1.0"
llm/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (278 Bytes). View file
 
llm/data/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """数据处理模块"""
llm/data/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (218 Bytes). View file
 
llm/data/__pycache__/tokenizer.cpython-312.pyc ADDED
Binary file (5.23 kB). View file
 
llm/data/collate.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 数据整理函数:padding / batch
3
+ 将多个样本组合成批次
4
+ 处理不同长度的序列(padding)
5
+ 转换为模型需要的张量格式
6
+ """
7
+
8
+
9
+
10
+ # 2026-01-23
11
+
12
+ import torch
13
+
14
+
15
def _pad_to_length(sequences, length, pad_token_id):
    """Right-pad each 1-D tensor in *sequences* to *length* with pad_token_id."""
    padded = []
    for ids in sequences:
        pad_length = length - len(ids)
        if pad_length > 0:
            filler = torch.full((pad_length,), pad_token_id, dtype=ids.dtype)
            ids = torch.cat([ids, filler])
        padded.append(ids)
    return padded


def collate_fn(batch, pad_token_id=0):
    """Collate (input_ids, target_ids) pairs into stacked batch tensors.

    When every sequence in the batch has the same length the tensors are
    stacked directly (fast path, no padding); otherwise both streams are
    right-padded with ``pad_token_id`` to the common maximum length.

    Args:
        batch: list of ``(input_ids, target_ids)`` pairs of 1-D tensors.
        pad_token_id: value used to right-pad shorter sequences (default: 0).

    Returns:
        ``(input_ids_batch, target_ids_batch)``, each of shape
        ``(batch_size, max_seq_len)``.
    """
    input_ids_list = [item[0] for item in batch]
    target_ids_list = [item[1] for item in batch]

    input_lengths = [len(ids) for ids in input_ids_list]
    target_lengths = [len(ids) for ids in target_ids_list]

    # Fast path applies only when inputs and targets share a single length.
    all_same_length = (
        len(set(input_lengths)) == 1
        and len(set(target_lengths)) == 1
        and input_lengths[0] == target_lengths[0]
    )

    if all_same_length:
        input_ids_batch = torch.stack(input_ids_list, dim=0)      # (batch_size, seq_len)
        target_ids_batch = torch.stack(target_ids_list, dim=0)    # (batch_size, seq_len)
    else:
        # Pad both streams to the overall maximum, then stack.
        max_seq_len = max(max(input_lengths), max(target_lengths))
        input_ids_batch = torch.stack(
            _pad_to_length(input_ids_list, max_seq_len, pad_token_id), dim=0
        )
        target_ids_batch = torch.stack(
            _pad_to_length(target_ids_list, max_seq_len, pad_token_id), dim=0
        )

    return input_ids_batch, target_ids_batch
77
+
78
+
79
+ if __name__ == "__main__":
80
+ print("=" * 60)
81
+ print("数据整理函数测试")
82
+ print("=" * 60)
83
+
84
+ # 模拟批次数据
85
+ batch_size = 4
86
+ seq_len = 10
87
+
88
+ print("\n1. 创建模拟批次数据")
89
+ print(f" 批次大小: {batch_size}")
90
+ print(f" 序列长度: {seq_len}")
91
+
92
+ # 创建模拟数据
93
+ batch = []
94
+ for i in range(batch_size):
95
+ input_ids = torch.randint(0, 100, (seq_len,))
96
+ target_ids = torch.randint(0, 100, (seq_len,))
97
+ batch.append((input_ids, target_ids))
98
+ print(f" 样本 {i}: input_ids 形状={input_ids.shape}, target_ids 形状={target_ids.shape}")
99
+
100
+ # 测试 collate_fn
101
+ print("\n2. 测试 collate_fn")
102
+ input_ids_batch, target_ids_batch = collate_fn(batch)
103
+
104
+ print(f" 输入批次形状: {input_ids_batch.shape}")
105
+ print(f" 目标批次形状: {target_ids_batch.shape}")
106
+ print(f" 期望形状: ({batch_size}, {seq_len})")
107
+
108
+ # 验证形状
109
+ assert input_ids_batch.shape == (batch_size, seq_len), \
110
+ f"输入批次形状错误: {input_ids_batch.shape} != ({batch_size}, {seq_len})"
111
+ assert target_ids_batch.shape == (batch_size, seq_len), \
112
+ f"目标批次形状错误: {target_ids_batch.shape} != ({batch_size}, {seq_len})"
113
+
114
+ print(" 形状验证通过")
115
+
116
+ # 验证数据是否正确堆叠
117
+ print("\n3. 验证数据堆叠")
118
+ for i in range(batch_size):
119
+ input_match = torch.equal(input_ids_batch[i], batch[i][0])
120
+ target_match = torch.equal(target_ids_batch[i], batch[i][1])
121
+ print(f" 样本 {i}: input_ids 匹配={input_match}, target_ids 匹配={target_match}")
122
+ assert input_match and target_match, f"样本 {i} 数据不匹配"
123
+
124
+ print(" 数据验证通过")
125
+
126
+ # 测试不同序列长度(需要 padding)
127
+ print("\n4. 测试不同序列长度(需要 padding)")
128
+ batch_variable = [
129
+ (torch.randint(0, 100, (5,)), torch.randint(0, 100, (5,))),
130
+ (torch.randint(0, 100, (8,)), torch.randint(0, 100, (8,))),
131
+ (torch.randint(0, 100, (10,)), torch.randint(0, 100, (10,))),
132
+ ]
133
+
134
+ print(" 样本长度: [5, 8, 10]")
135
+ input_batch_var, target_batch_var = collate_fn(batch_variable, pad_token_id=0)
136
+
137
+ print(f" 输入批次形状: {input_batch_var.shape}")
138
+ print(f" 目标批次形状: {target_batch_var.shape}")
139
+ print(f" 期望形状: (3, 10)")
140
+
141
+ assert input_batch_var.shape == (3, 10), \
142
+ f"输入批次形状错误: {input_batch_var.shape} != (3, 10)"
143
+ assert target_batch_var.shape == (3, 10), \
144
+ f"目标批次形状错误: {target_batch_var.shape} != (3, 10)"
145
+
146
+ # 验证 padding 是否正确
147
+ print("\n5. 验证 padding")
148
+ for i, (orig_input, orig_target) in enumerate(batch_variable):
149
+ orig_len = len(orig_input)
150
+ # 检查原始数据是否正确
151
+ assert torch.equal(input_batch_var[i, :orig_len], orig_input), \
152
+ f"样本 {i} 的 input_ids 数据不匹配"
153
+ assert torch.equal(target_batch_var[i, :orig_len], orig_target), \
154
+ f"样本 {i} 的 target_ids 数据不匹配"
155
+ # 检查 padding 是否正确(应该都是 pad_token_id)
156
+ if orig_len < 10:
157
+ assert torch.all(input_batch_var[i, orig_len:] == 0), \
158
+ f"样本 {i} 的 input_ids padding 不正确"
159
+ assert torch.all(target_batch_var[i, orig_len:] == 0), \
160
+ f"样本 {i} 的 target_ids padding 不正确"
161
+ print(f" 样本 {i}: 长度={orig_len}, padding 验证通过")
162
+
163
+ print(" ✓ Padding 验证通过")
164
+
165
+ print("\n" + "=" * 60)
166
+ print("所有测试完成!")
167
+ print("=" * 60)
llm/data/dataset.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """数据集:CharDataset"""
2
+ # 2026-01-23
3
+
4
+ import numpy as np
5
+ import torch
6
+ from pathlib import Path
7
+
8
+
9
class CharDataset:
    """Character-level language-modeling dataset over a pre-tokenized ID array.

    Serves sliding-window samples: position ``idx`` yields
    ``data[idx : idx + context_window]`` as the input and the same window
    shifted right by one token as the target.
    """

    def __init__(self, data_path, tokenizer, context_window=128):
        """
        Args:
            data_path: path to a ``.npy`` file holding a 1-D array of token IDs.
            tokenizer: tokenizer object (kept for reference, e.g. vocab_size;
                the data on disk is already tokenized).
            context_window: sequence length of each training sample.

        Raises:
            FileNotFoundError: if ``data_path`` does not exist.
        """
        self.tokenizer = tokenizer
        self.context_window = context_window

        data_path = Path(data_path)
        if not data_path.exists():
            raise FileNotFoundError(f"数据文件不存在: {data_path}")

        # One flat 1-D numpy array of token IDs covering the whole corpus.
        self.data = np.load(str(data_path))
        if self.data.ndim > 1:
            self.data = self.data.flatten()

        print(f"数据集加载完成:")
        print(f" 数据文件: {data_path}")
        print(f" 数据长度: {len(self.data):,} tokens")
        print(f" 块大小: {context_window}")
        print(f" 可用样本数: {len(self)}")

    def __len__(self):
        """Number of available sliding-window samples.

        Each sample consumes ``context_window + 1`` tokens (input plus the
        shifted target), so there are ``len(data) - context_window`` start
        positions.  Clamped at 0 so a corpus shorter than the window yields
        an empty dataset instead of a negative length (fix: previously this
        could return a negative number).
        """
        return max(0, len(self.data) - self.context_window)

    def __getitem__(self, idx):
        """Return ``(input_ids, target_ids)`` for sample *idx*.

        Both are LongTensors of shape ``(context_window,)``; the target is
        the input shifted right by one position.

        Raises:
            IndexError: if ``idx`` is outside ``[0, len(self))``.
        """
        if idx < 0 or idx >= len(self):
            raise IndexError(f"索引 {idx} 超出范围 [0, {len(self)})")

        # context_window + 1 tokens: the first W are inputs, the last W targets.
        chunk = self.data[idx:idx + self.context_window + 1]
        input_ids = torch.tensor(chunk[:-1], dtype=torch.long)
        target_ids = torch.tensor(chunk[1:], dtype=torch.long)
        return input_ids, target_ids
79
+
80
+
81
+ if __name__ == "__main__":
82
+ # 添加项目根目录到 Python 路径
83
+ from pathlib import Path
84
+ import sys
85
+
86
+ project_root = Path(__file__).parent.parent.parent
87
+ sys.path.insert(0, str(project_root))
88
+
89
+ from llm.data.tokenizer import CharTokenizer
90
+
91
+ print("=" * 60)
92
+ print("CharDataset 测试")
93
+ print("=" * 60)
94
+
95
+ # 测试参数
96
+ data_path = Path("data/processed/train.npy")
97
+ vocab_path = Path("data/vocab/char_vocab.json")
98
+ block_size = 256
99
+
100
+ # 检查文件是否存在
101
+ if not data_path.exists():
102
+ print(f"错误:数据文件不存在: {data_path}")
103
+ print("请先运行: python scripts/preprocess.py")
104
+ sys.exit(1)
105
+
106
+ if not vocab_path.exists():
107
+ print(f"错误:词汇表文件不存在: {vocab_path}")
108
+ print("请先运行: python scripts/preprocess.py")
109
+ sys.exit(1)
110
+
111
+ # 加载分词器
112
+ print("\n1. 加载分词器")
113
+ tokenizer = CharTokenizer(vocab_path=str(vocab_path))
114
+ print(f" 词汇表大小: {tokenizer.vocab_size}")
115
+
116
+ # 创建数据集
117
+ print("\n2. 创建数据集")
118
+ dataset = CharDataset(
119
+ data_path=str(data_path),
120
+ tokenizer=tokenizer,
121
+ block_size=block_size
122
+ )
123
+
124
+ # 测试数据集长度
125
+ print(f"\n3. 数据集信息")
126
+ print(f" 数据集大小: {len(dataset):,} 个样本")
127
+ print(f" 每个样本输入长度: {block_size}")
128
+ print(f" 每个样本目标长度: {block_size}")
129
+
130
+ # 测试获取样本
131
+ print("\n4. 测试获取样本")
132
+ input_ids, target_ids = dataset[0]
133
+ print(f" 样本 0:")
134
+ print(f" 输入形状: {input_ids.shape}")
135
+ print(f" 目标形状: {target_ids.shape}")
136
+ print(f" 输入前10个 token IDs: {input_ids[:10].tolist()}")
137
+ print(f" 目标前10个 token IDs: {target_ids[:10].tolist()}")
138
+
139
+ # 验证目标是否正确(应该是输入右移一位)
140
+ print(f"\n5. 验证目标序列")
141
+ expected_target = input_ids[1:].tolist()
142
+ actual_target = target_ids[:-1].tolist()
143
+ is_correct = expected_target == actual_target
144
+ print(f" 目标序列是否正确(右移一位): {is_correct}")
145
+ if not is_correct:
146
+ print(f" 期望: {expected_target[:10]}")
147
+ print(f" 实际: {actual_target[:10]}")
148
+
149
+ # 测试解码
150
+ print(f"\n6. 测试解码")
151
+ input_text = tokenizer.decode(input_ids[:50])
152
+ target_text = tokenizer.decode(target_ids[:50])
153
+ print(f" 输入文本(前50个字符): {repr(input_text)}")
154
+ print(f" 目标文本(前50个字符): {repr(target_text)}")
155
+
156
+ # 测试多个样本
157
+ print(f"\n7. 测试多个样本")
158
+ for i in [0, 100, len(dataset) - 1]:
159
+ input_ids, target_ids = dataset[i]
160
+ print(f" 样本 {i}: 输入形状 {input_ids.shape}, 目标形状 {target_ids.shape}")
161
+
162
+ print("\n" + "=" * 60)
163
+ print("所有测试完成!")
164
+ print("=" * 60)
llm/data/tokenizer.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """分词器:CharTokenizer,从切割文本得到 token IDs"""
2
+ # 2026-01-22
3
+
4
+ import json
5
+ from pathlib import Path
6
+ from collections import Counter
7
+
8
+
9
class CharTokenizer:
    """A character-level tokenizer (built for English text)."""

    def __init__(self, vocab_path=None):
        """Create a tokenizer, optionally loading a saved vocabulary.

        Args:
            vocab_path: path to a JSON vocabulary file; loaded if it exists.
        """
        self.vocab_path = vocab_path
        self.char_to_id = {}   # character -> integer ID
        self.id_to_char = {}   # integer ID -> character
        self.vocab_size = 0
        if vocab_path and Path(vocab_path).exists():
            self.load_vocab(vocab_path)

    def build_vocab(self, texts):
        """Build the vocabulary from raw text.

        Characters get IDs by descending frequency, after the reserved
        ``<unk>`` token at ID 0 (no <pad>/<bos>/<eos>: sliding-window
        training does not need them).

        Args:
            texts: a single string or a list of strings.
        """
        corpus = [texts] if isinstance(texts, str) else texts
        frequencies = Counter(''.join(corpus))

        # ID 0 is reserved for the unknown-character token.
        vocab = {'<unk>': 0}
        for char, _count in sorted(frequencies.items(), key=lambda kv: kv[1], reverse=True):
            vocab.setdefault(char, len(vocab))

        self.char_to_id = vocab
        self.id_to_char = {idx: ch for ch, idx in vocab.items()}
        self.vocab_size = len(vocab)

    def encode(self, text):
        """Map each character of *text* to its ID; unknown characters map to <unk>.

        Returns:
            list of integer token IDs.
        """
        unk_id = self.char_to_id.get('<unk>', 0)
        return [self.char_to_id.get(ch, unk_id) for ch in text]

    def decode(self, token_ids):
        """Convert a list (or tensor) of token IDs back into a string.

        ``<unk>`` placeholders (including unknown IDs) are dropped from the
        output.
        """
        # Accept PyTorch tensors transparently.
        if hasattr(token_ids, 'tolist'):
            token_ids = token_ids.tolist()
        pieces = (self.id_to_char.get(token_id, '<unk>') for token_id in token_ids)
        return ''.join(ch for ch in pieces if ch != '<unk>')

    def save_vocab(self, vocab_path):
        """Write the vocabulary to *vocab_path* as JSON, creating parent dirs."""
        Path(vocab_path).parent.mkdir(parents=True, exist_ok=True)
        payload = {
            'char_to_id': self.char_to_id,
            'id_to_char': {str(k): v for k, v in self.id_to_char.items()},
            'vocab_size': self.vocab_size
        }
        with open(vocab_path, 'w', encoding='utf-8') as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)

    def load_vocab(self, vocab_path):
        """Load a vocabulary previously written by :meth:`save_vocab`."""
        with open(vocab_path, 'r', encoding='utf-8') as f:
            payload = json.load(f)
        self.char_to_id = payload['char_to_id']
        # JSON keys are strings; restore integer IDs.
        self.id_to_char = {int(k): v for k, v in payload['id_to_char'].items()}
        self.vocab_size = payload['vocab_size']
llm/inference/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """推理模块:文本生成"""
2
+
3
+ from llm.inference.generate import greedy_decode, sample_decode
4
+
5
+ __all__ = ['greedy_decode', 'sample_decode']
llm/inference/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (359 Bytes). View file
 
llm/inference/__pycache__/generate.cpython-312.pyc ADDED
Binary file (5.8 kB). View file
 
llm/inference/generate.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """文本生成:greedy / sampling"""
2
+ # 2026-01-23
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from llm.model.attention import create_causal_mask
7
+
8
+
9
def greedy_decode(model, input_ids, max_length=100, device="cpu", stop_token_ids=None):
    """Greedy decoding: at every step pick the highest-probability token.

    Args:
        model: Transformer model; called as ``model(ids, mask=...)`` and
            expected to return logits of shape (batch_size, seq_len, vocab_size).
        input_ids: prompt token IDs, shape (batch_size, seq_len).
        max_length: number of NEW tokens to generate (excludes the prompt).
        device: device to run generation on.
        stop_token_ids: optional int or list of ints; once every sequence has
            produced one of these, generation stops early.  NOTE(review): the
            stop token itself is still appended, and already-finished
            sequences keep receiving tokens until all sequences finish.

    Returns:
        generated_ids: token IDs, shape (batch_size, total_length).
    """
    model.eval()
    input_ids = input_ids.to(device)
    generated_ids = input_ids.clone()

    # Normalize stop_token_ids to a list.
    if stop_token_ids is None:
        stop_token_ids = []
    elif isinstance(stop_token_ids, int):
        stop_token_ids = [stop_token_ids]

    batch_size = generated_ids.size(0)
    # Per-sequence flag: has this sequence emitted a stop token yet?
    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

    with torch.no_grad():
        for step in range(max_length):
            # Early exit once every sequence in the batch has finished.
            if finished.all():
                break

            # Current total sequence length (prompt + generated so far).
            seq_len = generated_ids.size(1)

            # Full forward pass over the whole sequence each step (no KV cache).
            causal_mask = create_causal_mask(seq_len, device=device)
            logits = model(generated_ids, mask=causal_mask)  # (batch_size, seq_len, vocab_size)

            # Only the last position predicts the next token.
            next_token_logits = logits[:, -1, :]  # (batch_size, vocab_size)

            # Greedy choice: argmax over the vocabulary.
            next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)  # (batch_size, 1)

            # Mark sequences that just produced a stop token.
            if stop_token_ids:
                for stop_id in stop_token_ids:
                    finished = finished | (next_token_id.squeeze(-1) == stop_id)

            # Append the chosen token (also to finished sequences; see note).
            generated_ids = torch.cat([generated_ids, next_token_id], dim=1)

    return generated_ids
65
+
66
+
67
def sample_decode(
    model,
    input_ids,
    max_length=100,
    temperature=1.0,
    top_k=0,
    top_p=0.0,
    device="cpu",
    stop_token_ids=None
):
    """Sampling decoder with temperature, top-k, and top-p (nucleus) filtering.

    Filters are applied in sequence: temperature scaling, then top-k, then
    top-p, then a single multinomial draw per sequence.

    Args:
        model: Transformer model; called as ``model(ids, mask=...)`` and
            expected to return logits of shape (batch_size, seq_len, vocab_size).
        input_ids: prompt token IDs, shape (batch_size, seq_len).
        max_length: number of NEW tokens to generate (excludes the prompt).
        temperature: sampling temperature (higher = more random, lower = more
            deterministic; 0 falls back to greedy decoding).
        top_k: keep only the k highest-probability tokens (0/None disables).
        top_p: nucleus sampling — keep the smallest set of tokens whose
            cumulative probability reaches p (0.0/None disables).
        device: device to run generation on.
        stop_token_ids: optional int or list of ints; once every sequence has
            produced one of these, generation stops early (the stop token is
            still appended).

    Returns:
        generated_ids: token IDs, shape (batch_size, total_length).
    """
    model.eval()
    input_ids = input_ids.to(device)
    generated_ids = input_ids.clone()

    # Normalize stop_token_ids to a list.
    if stop_token_ids is None:
        stop_token_ids = []
    elif isinstance(stop_token_ids, int):
        stop_token_ids = [stop_token_ids]

    batch_size = generated_ids.size(0)
    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

    # Temperature 0 means deterministic: delegate to greedy decoding.
    if temperature == 0:
        return greedy_decode(model, input_ids, max_length, device, stop_token_ids)

    with torch.no_grad():
        for step in range(max_length):
            # Early exit once every sequence in the batch has finished.
            if finished.all():
                break

            # Current total sequence length (prompt + generated so far).
            seq_len = generated_ids.size(1)

            # Full forward pass over the whole sequence each step (no KV cache).
            causal_mask = create_causal_mask(seq_len, device=device)
            logits = model(generated_ids, mask=causal_mask)  # (batch_size, seq_len, vocab_size)

            # Only the last position predicts the next token.
            next_token_logits = logits[:, -1, :]  # (batch_size, vocab_size)

            # Temperature scaling: divide logits before softmax.
            if temperature != 1.0:
                next_token_logits = next_token_logits / temperature

            # Top-K filtering: everything outside the k best becomes -inf.
            if top_k is not None and top_k > 0:
                top_k_logits, top_k_indices = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)), dim=-1)
                # Scatter the surviving logits onto a -inf canvas.
                mask = torch.full_like(next_token_logits, float('-inf'))
                mask.scatter_(-1, top_k_indices, top_k_logits)
                next_token_logits = mask

            # Top-P (nucleus) filtering, applied after top-k.
            if top_p is not None and top_p > 0.0:
                probs = F.softmax(next_token_logits, dim=-1)

                # Sort tokens by descending probability.
                sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)

                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

                # Drop tokens once the cumulative mass exceeds top_p...
                sorted_indices_to_remove = cumulative_probs > top_p
                # ...but shift the removal mask right so the first token over
                # the threshold is kept — guarantees at least one candidate.
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = False

                # Map the sorted-order mask back to vocabulary order.
                indices_to_remove = sorted_indices_to_remove.scatter(
                    -1, sorted_indices, sorted_indices_to_remove
                )
                next_token_logits[indices_to_remove] = float('-inf')

            # Renormalize over surviving tokens.
            probs = F.softmax(next_token_logits, dim=-1)

            # One multinomial draw per sequence.
            next_token_id = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)

            # Mark sequences that just produced a stop token.
            if stop_token_ids:
                for stop_id in stop_token_ids:
                    finished = finished | (next_token_id.squeeze(-1) == stop_id)

            # Append the chosen token (also to finished sequences).
            generated_ids = torch.cat([generated_ids, next_token_id], dim=1)

    return generated_ids
llm/model/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """模型组件模块"""
llm/model/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (219 Bytes). View file
 
llm/model/__pycache__/attention.cpython-312.pyc ADDED
Binary file (16.3 kB). View file
 
llm/model/__pycache__/block.cpython-312.pyc ADDED
Binary file (6.39 kB). View file
 
llm/model/__pycache__/embedding.cpython-312.pyc ADDED
Binary file (1.45 kB). View file
 
llm/model/__pycache__/ffn.cpython-312.pyc ADDED
Binary file (5.47 kB). View file
 
llm/model/__pycache__/norm.cpython-312.pyc ADDED
Binary file (7.19 kB). View file
 
llm/model/__pycache__/rope.cpython-312.pyc ADDED
Binary file (5.92 kB). View file
 
llm/model/__pycache__/transformer.cpython-312.pyc ADDED
Binary file (10.9 kB). View file
 
llm/model/attention.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """注意力机制:MHA / GQA / 滑动窗口注意力"""
2
+ # 2026-01-23
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import math
7
+
8
+
9
def scaled_dot_product_attention(
    q, k, v, mask=None, sliding_window=None, sliding_window_overlap=True
):
    """Scaled dot-product attention (functional core).

    Args:
        q: queries, (batch_size, num_heads, seq_len, head_dim)
        k: keys,    (batch_size, num_heads, seq_len, head_dim)
        v: values,  (batch_size, num_heads, seq_len, head_dim)
        mask: optional boolean mask, (seq_len, seq_len) or
            (batch_size, seq_len, seq_len); True means "may attend".
        sliding_window: optional window size; when given, a sliding-window
            mask (which already includes causality) is applied first.
        sliding_window_overlap: symmetric window (True) vs. causal-only window.

    Returns:
        output: (batch_size, num_heads, seq_len, head_dim)
        attn_weights: (batch_size, num_heads, seq_len, seq_len)
    """
    seq_len = q.size(-2)
    head_dim = q.size(-1)

    # Similarity scores Q @ K^T, scaled by sqrt(d) for stable softmax.
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)

    # Sliding-window mask first; it already encodes causality, so callers
    # using it typically pass mask=None.
    if sliding_window is not None:
        scores = apply_sliding_window_mask(
            scores, sliding_window, seq_len, sliding_window_overlap
        )

    # Optional explicit mask. Common shapes: (L, L) shared across batch/heads
    # (causal) or (B, L, L) per-batch (padding). Broadcast up to 4-D.
    if mask is not None:
        while mask.dim() < 4:
            mask = mask.unsqueeze(1) if mask.dim() == 3 else mask.unsqueeze(0)
        # Positions with mask == False become -inf, i.e. zero after softmax.
        scores = scores.masked_fill(~mask, float("-inf"))

    # Probability distribution over the key axis.
    attn_weights = torch.softmax(scores, dim=-1)

    # Weighted mixture of the values.
    return torch.matmul(attn_weights, v), attn_weights
58
+
59
+
60
def create_causal_mask(seq_len, device="cpu"):
    """Build a causal (lower-triangular) attention mask.

    Args:
        seq_len: sequence length.
        device: device for the returned tensor.

    Returns:
        (seq_len, seq_len) boolean tensor; entry [i, j] is True iff j <= i,
        i.e. each position may attend to itself and everything before it.
    """
    full = torch.ones(seq_len, seq_len, device=device)
    return torch.tril(full).bool()
73
+
74
+
75
def apply_sliding_window_mask(scores, window_size, seq_len, overlap=True):
    """Apply a causal sliding-window mask to attention scores.

    Replaces the previous O(seq_len) Python loops with vectorized tensor
    comparisons on the relative-position matrix; semantics are unchanged.

    Args:
        scores: attention scores, (batch_size, num_heads, seq_len, seq_len).
        window_size: window width.
        seq_len: sequence length (kept for interface compatibility; the
            effective length is read from `scores`, as before).
        overlap: True  -> symmetric window, j in [i - w//2, i + w//2];
                 False -> causal-only window, j in [i - w + 1, i].

    Returns:
        `scores` with disallowed positions set to -inf. Causality (j <= i)
        is always enforced in addition to the window constraint.
    """
    seq_len = scores.size(-1)
    device = scores.device

    # rel[i, j] = i - j: how far key j lies behind query i.
    pos = torch.arange(seq_len, device=device)
    rel = pos.unsqueeze(1) - pos.unsqueeze(0)  # (seq_len, seq_len)

    if overlap:
        # Symmetric window: |i - j| <= window_size // 2.
        half = window_size // 2
        window_mask = (rel >= -half) & (rel <= half)
    else:
        # Causal-only window: 0 <= i - j < window_size.
        window_mask = (rel >= 0) & (rel < window_size)

    # Intersect with the causal constraint (j <= i) and broadcast over
    # batch and head dimensions.
    combined_mask = window_mask & (rel >= 0)
    combined_mask = combined_mask.unsqueeze(0).unsqueeze(0)  # (1, 1, L, L)

    # Disallowed positions become -inf (zero weight after softmax).
    return scores.masked_fill(~combined_mask, float("-inf"))
120
+
121
+
122
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention (MHA).

    Every query, key and value head is independent (no KV sharing).
    """

    def __init__(self, hidden_size, num_heads):
        """
        Args:
            hidden_size: model hidden dimension
            num_heads: number of attention heads
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        # The hidden dimension must split evenly across the heads.
        assert hidden_size % num_heads == 0, (
            f"hidden_size ({hidden_size}) 必须能被 num_heads ({num_heads}) 整除"
        )

        # Independent linear projections for Q, K and V.
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)

        # Output projection back to hidden_size.
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(
        self, x, mask=None, rope=None, sliding_window=None, sliding_window_overlap=True
    ):
        """
        Args:
            x: input, (batch_size, seq_len, hidden_size)
            mask: attention mask, (batch_size, seq_len, seq_len) or (seq_len, seq_len)
            rope: optional RoPE module applied to Q and K
            sliding_window: optional sliding-window size
            sliding_window_overlap: whether the window is symmetric

        Returns:
            (batch_size, seq_len, hidden_size) tensor
        """
        batch_size, seq_len, hidden_size = x.shape

        # 1. Project to Q, K, V.
        q = self.q_proj(x)  # (batch_size, seq_len, hidden_size)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # 2. Split into heads:
        # (batch_size, seq_len, hidden_size) -> (batch_size, seq_len, num_heads, head_dim)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim)

        # 3. Move the head axis before the sequence axis for attention:
        # (B, L, H, D_h) -> (B, H, L, D_h)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # 4. Rotary position embedding, if provided.
        if rope is not None:
            q, k = rope(q, k)

        # 5. Attention.
        # attn_output: (B, H, L, D_h) — per-head, per-token mixture of values.
        # attn_weights: (B, H, L, L) — attention of token i over all tokens.
        attn_output, attn_weights = scaled_dot_product_attention(
            q, k, v, mask, sliding_window, sliding_window_overlap
        )

        # 6. Move the heads back:
        # (batch_size, num_heads, seq_len, head_dim) -> (batch_size, seq_len, num_heads, head_dim)
        attn_output = attn_output.transpose(1, 2)

        # 7. Merge heads back into the hidden dimension:
        # (batch_size, seq_len, num_heads, head_dim) -> (batch_size, seq_len, hidden_size)
        attn_output = attn_output.contiguous().view(batch_size, seq_len, hidden_size)

        # 8. Final output projection.
        output = self.o_proj(attn_output)

        return output
214
+
215
+
216
class GroupedQueryAttention(nn.Module):
    """
    Grouped-query attention (GQA).

    Several query heads share one key/value head, shrinking the KV cache.
    """

    def __init__(self, hidden_size, num_heads, num_kv_heads):
        """
        Args:
            hidden_size: model hidden dimension
            num_heads: number of query heads
            num_kv_heads: number of key/value heads (must divide num_heads)
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = hidden_size // num_heads

        # hidden_size must split evenly across query heads.
        assert hidden_size % num_heads == 0, (
            f"hidden_size ({hidden_size}) 必须能被 num_heads ({num_heads}) 整除"
        )
        # Each KV head serves an equal-sized group of query heads.
        assert num_heads % num_kv_heads == 0, (
            f"num_heads ({num_heads}) 必须能被 num_kv_heads ({num_kv_heads}) 整除"
        )

        # Query projection: every query head has its own Q.
        self.q_proj = nn.Linear(hidden_size, hidden_size)

        # K and V projections are smaller: D -> num_kv_heads * head_dim,
        # later reshaped to (B, L, num_kv_heads, head_dim).
        kv_hidden_size = num_kv_heads * self.head_dim
        self.k_proj = nn.Linear(hidden_size, kv_hidden_size)
        self.v_proj = nn.Linear(hidden_size, kv_hidden_size)

        # Output projection.
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(
        self, x, mask=None, rope=None, sliding_window=None, sliding_window_overlap=True
    ):
        """
        Args:
            x: input, (batch_size, seq_len, hidden_size)
            mask: attention mask, (batch_size, seq_len, seq_len) or (seq_len, seq_len)
            rope: optional RoPE module applied to Q and K
            sliding_window: optional sliding-window size
            sliding_window_overlap: whether the window is symmetric

        Returns:
            (batch_size, seq_len, hidden_size) tensor
        """
        batch_size, seq_len, hidden_size = x.shape

        # 1. Projections.
        q = self.q_proj(x)  # (batch_size, seq_len, hidden_size)
        k = self.k_proj(x)  # (batch_size, seq_len, num_kv_heads * head_dim)
        v = self.v_proj(x)  # (batch_size, seq_len, num_kv_heads * head_dim)

        # 2. Split Q into query heads.
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
        q = q.transpose(1, 2)  # (batch_size, num_heads, seq_len, head_dim)

        # 3. Split K and V into the (fewer) KV heads.
        k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        k = k.transpose(1, 2)  # (batch_size, num_kv_heads, seq_len, head_dim)

        v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        v = v.transpose(1, 2)  # (batch_size, num_kv_heads, seq_len, head_dim)

        # 4. Rotary position embedding (applied before KV replication).
        if rope is not None:
            q, k = rope(q, k)

        # 5. Replicate each KV head so head counts match,
        #    e.g. 10 query heads with 2 KV heads -> repeat each KV head 5x.
        repeat_kv = self.num_heads // self.num_kv_heads
        k = k.repeat_interleave(
            repeat_kv, dim=1
        )  # (batch_size, num_heads, seq_len, head_dim)
        v = v.repeat_interleave(
            repeat_kv, dim=1
        )  # (batch_size, num_heads, seq_len, head_dim)

        # 6. Attention.
        # attn_output: (B, H, L, D_h) — per-head, per-token mixture of values.
        # attn_weights: (B, H, L, L) — attention of token i over all tokens.
        attn_output, attn_weights = scaled_dot_product_attention(
            q, k, v, mask, sliding_window, sliding_window_overlap
        )

        # 7. Move the heads back and merge them.
        attn_output = attn_output.transpose(
            1, 2
        )  # (batch_size, seq_len, num_heads, head_dim)

        # (batch_size, seq_len, num_heads, head_dim) -> (batch_size, seq_len, hidden_size)
        attn_output = attn_output.contiguous().view(batch_size, seq_len, hidden_size)

        # 8. Output projection.
        output = self.o_proj(attn_output)

        return output
327
+
328
+
329
if __name__ == "__main__":
    print("=" * 60)
    print("注意力机制测试")
    print("=" * 60)

    # Test hyper-parameters (kept consistent with configs/model.yaml).
    hidden_size = 320
    num_heads = 10
    num_kv_heads = 2
    head_dim = hidden_size // num_heads
    batch_size = 2
    seq_len = 10

    print("\n1. 测试参数")
    print(f"  hidden_size: {hidden_size}")
    print(f"  num_heads: {num_heads}")
    print(f"  num_kv_heads: {num_kv_heads}")
    print(f"  head_dim: {head_dim}")
    print(f"  batch_size: {batch_size}")
    print(f"  seq_len: {seq_len}")

    # Exercise the functional attention primitive with a causal mask.
    print("\n2. 测试 scaled_dot_product_attention")
    q = torch.randn(batch_size, num_heads, seq_len, head_dim)
    k = torch.randn(batch_size, num_heads, seq_len, head_dim)
    v = torch.randn(batch_size, num_heads, seq_len, head_dim)
    causal_mask = create_causal_mask(seq_len)

    output, attn_weights = scaled_dot_product_attention(q, k, v, causal_mask)
    print(f"  输入 Q 形状: {q.shape}")
    print(f"  输出形状: {output.shape}")
    print(f"  注意力权重形状: {attn_weights.shape}")
    print(f"  注意力权重和(每行应为1): {attn_weights.sum(dim=-1)[0, 0, :5]}")

    # MHA smoke test.
    print("\n3. 测试 MultiHeadAttention")
    mha = MultiHeadAttention(hidden_size, num_heads)
    x = torch.randn(batch_size, seq_len, hidden_size)
    output_mha = mha(x, mask=causal_mask)
    print(f"  输入形状: {x.shape}")
    print(f"  输出形状: {output_mha.shape}")
    print(f"  参数数量: {sum(p.numel() for p in mha.parameters())}")

    # GQA smoke test; note the smaller K/V projection shapes.
    print("\n4. 测试 GroupedQueryAttention")
    gqa = GroupedQueryAttention(hidden_size, num_heads, num_kv_heads)
    output_gqa = gqa(x, mask=causal_mask)
    print(f"  输入形状: {x.shape}")
    print(f"  输出形状: {output_gqa.shape}")
    print(f"  参数数量: {sum(p.numel() for p in gqa.parameters())}")
    print(f"  Q 投影参数: {gqa.q_proj.weight.shape}")
    print(f"  K 投影参数: {gqa.k_proj.weight.shape}")
    print(f"  V 投影参数: {gqa.v_proj.weight.shape}")

    # GQA combined with rotary position embeddings.
    print("\n5. 测试 GroupedQueryAttention + RoPE")
    from pathlib import Path
    import sys

    # Make the project root importable when running this file directly.
    project_root = Path(__file__).parent.parent.parent
    sys.path.insert(0, str(project_root))
    from llm.model.rope import RoPE

    rope = RoPE(dim=head_dim, max_seq_len=1024, theta=10000.0)
    output_gqa_rope = gqa(x, mask=causal_mask, rope=rope)
    print(f"  输出形状: {output_gqa_rope.shape}")
    print(f"  与无 RoPE 的输出不同: {not torch.allclose(output_gqa, output_gqa_rope)}")

    # Sliding-window mask applied directly to raw scores.
    print("\n6. 测试滑动窗口掩码")
    window_size = 4
    scores = torch.randn(batch_size, num_heads, seq_len, seq_len)
    scores_windowed = apply_sliding_window_mask(
        scores, window_size, seq_len, overlap=True
    )
    print(f"  窗口大小: {window_size}")
    print(f"  原始分数形状: {scores.shape}")
    print(f"  掩码后分数形状: {scores_windowed.shape}")
    # Inspect the mask pattern of the first sample / first head.
    mask_check = scores_windowed[0, 0] != float("-inf")
    print(f"  位置 0 可关注的位置数: {mask_check[0].sum().item()}")
    print(f"  位置 5 可关注的位置数: {mask_check[5].sum().item()}")

    # GQA with a sliding window instead of an explicit mask.
    print("\n7. 测试 GroupedQueryAttention + 滑动窗口")
    output_gqa_window = gqa(
        x, mask=None, rope=rope, sliding_window=window_size, sliding_window_overlap=True
    )
    print(f"  输出形状: {output_gqa_window.shape}")
    print(
        f"  与无滑动窗口的输出不同: {not torch.allclose(output_gqa_rope, output_gqa_window)}"
    )

    # MHA with a sliding window (the unified interface supports it too).
    print("\n8. 测试 MultiHeadAttention + 滑动窗口")
    output_mha_window = mha(
        x, mask=None, rope=rope, sliding_window=window_size, sliding_window_overlap=True
    )
    print(f"  输出形状: {output_mha_window.shape}")
    print(
        f"  与无滑动窗口的输出不同: {not torch.allclose(output_mha, output_mha_window)}"
    )

    print("\n" + "=" * 60)
    print("所有测试完成!")
    print("=" * 60)
llm/model/block.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Transformer Block"""
2
+ # 2026-01-23
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from pathlib import Path
8
+ import sys
9
+
10
+ project_root = Path(__file__).parent.parent.parent
11
+ sys.path.insert(0, str(project_root))
12
+
13
+ from llm.model.attention import GroupedQueryAttention
14
+ from llm.model.ffn import FFN
15
+ from llm.model.norm import RMSNorm
16
+ from llm.model.rope import RoPE
17
+
18
+
19
class TransformerBlock(nn.Module):
    """
    One Transformer layer (block).

    Structure:
      1. attention sub-layer (with residual connection)
      2. FFN sub-layer (with residual connection)

    Both sub-layers use the Pre-Norm pattern:
      x_norm = norm(x); output = x + sublayer(x_norm)
    """

    def __init__(self, config):
        """
        Args:
            config: dict of model hyper-parameters (hidden_size,
                num_attention_heads, num_key_value_heads, intermediate_size,
                rms_norm_eps, max_position_embeddings, rope_theta,
                sliding_window, sliding_window_overlap).
        """
        super().__init__()
        self.config = config

        hidden_size = config["hidden_size"]
        num_heads = config["num_attention_heads"]
        num_kv_heads = config.get("num_key_value_heads", num_heads)
        intermediate_size = config["intermediate_size"]
        # float() guards against a YAML config delivering eps as a string.
        rms_norm_eps = float(config.get("rms_norm_eps", 1e-5))

        # Sliding-window attention settings (None disables windowing).
        self.sliding_window = config.get("sliding_window")
        self.sliding_window_overlap = config.get("sliding_window_overlap", True)

        # Attention sub-layer (grouped-query attention).
        self.attn = GroupedQueryAttention(hidden_size, num_heads, num_kv_heads)

        # Normalization layers (Pre-Norm structure).
        self.attn_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
        self.ffn_norm = RMSNorm(hidden_size, eps=rms_norm_eps)

        # Feed-forward sub-layer.
        self.ffn = FFN(hidden_size, intermediate_size)

        # Rotary position embedding, sized per attention head.
        max_position_embeddings = config.get("max_position_embeddings", 1024)
        rope_theta = config.get("rope_theta", 10000.0)
        head_dim = hidden_size // num_heads
        self.rope = RoPE(head_dim, max_position_embeddings, rope_theta)

    def forward(self, x, mask=None):
        """
        Args:
            x: input, (batch_size, seq_len, hidden_size)
            mask: attention mask, (batch_size, seq_len, seq_len) or (seq_len, seq_len)

        Returns:
            (batch_size, seq_len, hidden_size) tensor
        """
        # 1. Attention sub-layer: normalize first (Pre-Norm), then attend.
        x_norm = self.attn_norm(x)
        attn_output = self.attn(
            x_norm,
            mask=mask,
            rope=self.rope,
            sliding_window=self.sliding_window,
            sliding_window_overlap=self.sliding_window_overlap,
        )
        # Residual connection.
        x = x + attn_output

        # 2. FFN sub-layer: Pre-Norm, then feed-forward.
        x_norm = self.ffn_norm(x)
        ffn_output = self.ffn(x_norm)
        # Residual connection.
        x = x + ffn_output

        return x
99
+
100
+
101
if __name__ == "__main__":
    print("=" * 60)
    print("TransformerBlock 测试")
    print("=" * 60)

    # Config kept consistent with configs/model.yaml.
    config = {
        "hidden_size": 320,
        "num_attention_heads": 10,
        "num_key_value_heads": 2,
        "intermediate_size": 960,
        "rms_norm_eps": 1e-5,
        "max_position_embeddings": 1024,
        "rope_theta": 10000.0,
        "sliding_window": 256,
        "sliding_window_overlap": True,
    }

    batch_size = 2
    seq_len = 64
    hidden_size = config["hidden_size"]

    print("\n1. 测试参数")
    print(f"  hidden_size: {hidden_size}")
    print(f"  num_attention_heads: {config['num_attention_heads']}")
    print(f"  num_key_value_heads: {config['num_key_value_heads']}")
    print(f"  intermediate_size: {config['intermediate_size']}")
    print(f"  batch_size: {batch_size}")
    print(f"  seq_len: {seq_len}")

    # Build one block and run a forward pass.
    print("\n2. 创建 TransformerBlock")
    block = TransformerBlock(config)
    x = torch.randn(batch_size, seq_len, hidden_size)

    print(f"  输入形状: {x.shape}")
    output = block(x)
    print(f"  输出形状: {output.shape}")
    print(f"  形状匹配: {output.shape == x.shape}")

    # Parameter count.
    total_params = sum(p.numel() for p in block.parameters())
    print(f"  参数数量: {total_params:,}")

    # With a tiny input, the residual path should dominate the output.
    print("\n3. 验证残差连接")
    x_small = torch.randn(batch_size, seq_len, hidden_size) * 0.01
    output_small = block(x_small)
    diff = torch.abs(output_small - x_small).mean()
    print(f"  小输入测试: 输入输出差异均值 = {diff.item():.6f}")

    # Gradients should reach every parameter.
    print("\n4. 测试梯度计算")
    loss = output.sum()
    loss.backward()
    print(
        f"  所有参数是否有梯度: {all(p.grad is not None for p in block.parameters())}"
    )

    print("\n" + "=" * 60)
    print("所有测试完成!")
    print("=" * 60)
llm/model/embedding.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """嵌入层:Token Embedding 词嵌入"""
2
+ # 2026-01-23
3
+
4
+ import torch.nn as nn
5
+
6
+
7
class TokenEmbedding(nn.Module):
    """Token embedding: maps integer token ids to dense vectors."""

    def __init__(self, vocab_size, hidden_size):
        """
        Args:
            vocab_size: number of entries in the vocabulary.
            hidden_size: embedding (hidden) dimension.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

        # Lookup table from token id to hidden vector.
        self.embedding = nn.Embedding(vocab_size, hidden_size)

    def forward(self, input_ids):
        """Look up embeddings for a batch of token ids.

        Args:
            input_ids: (batch_size, seq_len) integer tensor of token ids.

        Returns:
            (batch_size, seq_len, hidden_size) float tensor.
        """
        embedded = self.embedding(input_ids)
        return embedded
llm/model/ffn.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """前馈神经网络(FFN)"""
2
+ # 2026-01-23
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
def silu(x):
    """SiLU (a.k.a. swish) activation: SiLU(x) = x * sigmoid(x).

    Args:
        x: input tensor.

    Returns:
        Tensor of the same shape with SiLU applied elementwise.
    """
    gate = torch.sigmoid(x)
    return gate * x
21
+
22
+
23
class FFN(nn.Module):
    """SwiGLU feed-forward network.

    FFN(x) = down_proj( silu(gate_proj(x)) * up_proj(x) )
    """

    def __init__(self, hidden_size, intermediate_size):
        """
        Args:
            hidden_size: model hidden dimension.
            intermediate_size: width of the expanded intermediate layer.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size

        # The three projections of a SwiGLU block:
        #   gate_proj -> activated "gate" deciding which features pass through
        #   up_proj   -> raw (un-activated) features the gate modulates
        #   down_proj -> projection back down to hidden_size
        self.gate_proj = nn.Linear(hidden_size, intermediate_size)
        self.up_proj = nn.Linear(hidden_size, intermediate_size)
        self.down_proj = nn.Linear(intermediate_size, hidden_size)

        # Gate activation function.
        self.activation = silu

    def forward(self, x):
        """Apply the SwiGLU FFN.

        Args:
            x: (batch_size, seq_len, hidden_size) input.

        Returns:
            (batch_size, seq_len, hidden_size) output.
        """
        gated = self.activation(self.gate_proj(x))  # activated gate signal
        lifted = self.up_proj(x)  # un-activated features
        # Elementwise gating, then project back to hidden_size.
        return self.down_proj(gated * lifted)
84
+
85
+
86
if __name__ == "__main__":
    print("=" * 60)
    print("FFN 测试")
    print("=" * 60)

    # Hyper-parameters kept consistent with configs/model.yaml.
    hidden_size = 320
    intermediate_size = 960
    batch_size = 2
    seq_len = 10

    print("\n1. 测试参数")
    print(f"  hidden_size: {hidden_size}")
    print(f"  intermediate_size: {intermediate_size}")
    print(f"  batch_size: {batch_size}")
    print(f"  seq_len: {seq_len}")

    # SiLU sanity check on a few hand-picked values.
    print("\n2. 测试 SiLU 激活函数")
    x_test = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
    silu_output = silu(x_test)
    print(f"  输入: {x_test.tolist()}")
    print(f"  输出: {silu_output.tolist()}")
    print(f"  SiLU(0) 应该接近 0: {abs(silu_output[2].item()) < 0.01}")

    # FFN forward pass keeps the input shape.
    print("\n3. 测试 FFN(SwiGLU 结构)")
    ffn = FFN(hidden_size, intermediate_size)
    x = torch.randn(batch_size, seq_len, hidden_size)

    print(f"  输入形状: {x.shape}")
    output = ffn(x)
    print(f"  输出形状: {output.shape}")
    print(f"  形状匹配: {output.shape == x.shape}")

    # Parameter count and projection shapes.
    total_params = sum(p.numel() for p in ffn.parameters())
    print(f"  参数数量: {total_params}")
    print(f"  gate_proj 参数: {ffn.gate_proj.weight.shape}")
    print(f"  up_proj 参数: {ffn.up_proj.weight.shape}")
    print(f"  down_proj 参数: {ffn.down_proj.weight.shape}")

    # Recompute SwiGLU by hand and compare with forward().
    print("\n4. 验证 SwiGLU 结构")
    gate = silu(ffn.gate_proj(x))
    up = ffn.up_proj(x)
    gate_up = gate * up
    manual_output = ffn.down_proj(gate_up)
    print(f"  手动计算输出形状: {manual_output.shape}")
    print(f"  与 forward 输出一致: {torch.allclose(output, manual_output)}")

    print("\n" + "=" * 60)
    print("所有测试完成!")
    print("=" * 60)
llm/model/norm.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """归一化层:RMSNorm"""
2
+ # 2026-01-22
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learnable scale."""

    def __init__(self, dim, eps=1e-5):
        """
        Args:
            dim: size of the normalized (last) dimension.
            eps: numerical-stability term added under the square root.
        """
        super().__init__()
        # Coerce to float in case a YAML config delivered eps as a string.
        self.eps = float(eps)

        # Learnable per-channel gain (the "gamma"), initialized to ones so
        # the layer starts as a pure RMS normalization.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalize `x` by its RMS over the last dimension and rescale.

        Args:
            x: tensor whose last dimension has size `dim`,
               e.g. (batch_size, seq_len, dim).

        Returns:
            Tensor of the same shape as `x`.
        """
        # Mean of squares over the feature dimension, kept for broadcasting.
        mean_sq = torch.mean(x.pow(2), dim=-1, keepdim=True)
        rms = torch.sqrt(mean_sq + self.eps)
        # Divide by the RMS, then scale by the learnable gain.
        return self.weight * (x / rms)
+
46
+
47
if __name__ == "__main__":
    print("=" * 60)
    print("RMSNorm 测试")
    print("=" * 60)

    # 1. Build the layer.
    dim = 32
    norm = RMSNorm(dim=dim, eps=1e-5)
    print("\n1. 创建 RMSNorm 层")
    print(f"  维度: {dim}")
    print(f"  eps: {norm.eps}")
    print(f"  权重形状: {norm.weight.shape}")
    print(f"  权重初始值(前5个): {norm.weight[:5]}")

    # 2. Random input and its statistics.
    batch_size = 2
    seq_len = 10
    x = torch.randn(batch_size, seq_len, dim)
    print("\n2. 创建测试输入")
    print(f"  输入形状: {x.shape}")
    print("  输入统计:")
    print(f"  - 均值: {x.mean().item():.4f}")
    print(f"  - 标准差: {x.std().item():.4f}")
    print(f"  - 最小值: {x.min().item():.4f}")
    print(f"  - 最大值: {x.max().item():.4f}")

    # 3. Forward pass.
    output = norm(x)
    print("\n3. 前向传播结果")
    print(f"  输出形状: {output.shape}")
    print("  输出统计:")
    print(f"  - 均值: {output.mean().item():.4f}")
    print(f"  - 标准差: {output.std().item():.4f}")

    # 4. After normalization the per-sample RMS should be close to 1
    #    (the learnable weight starts at all-ones).
    print("\n4. 验证归一化效果")
    rms_per_sample = torch.sqrt(torch.mean(output.pow(2), dim=-1))
    print("  每个样本的 RMS(归一化后):")
    print(f"  - 样本1: {rms_per_sample[0].mean().item():.4f}")
    print(f"  - 样本2: {rms_per_sample[1].mean().item():.4f}")
    print(f"  - 平均 RMS: {rms_per_sample.mean().item():.4f}")

    # 5. The scale must be a trainable parameter.
    print("\n5. 验证参数可学习性")
    print(f"  权重是否为 Parameter: {isinstance(norm.weight, nn.Parameter)}")
    print(f"  权重是否需要梯度: {norm.weight.requires_grad}")

    # 6. Gradient-flow check via a trivial loss.
    print("\n6. 测试梯度计算")
    loss = output.sum()
    loss.backward()
    print(f"  权重梯度是否存在: {norm.weight.grad is not None}")
    if norm.weight.grad is not None:
        print(f"  权重梯度形状: {norm.weight.grad.shape}")
        print(f"  权重梯度统计:")
        print(f"  - 均值: {norm.weight.grad.mean().item():.4f}")
        print(f"  - 标准差: {norm.weight.grad.std().item():.4f}")

    # 7. Shape robustness across batch/sequence sizes.
    print("\n7. 测试不同输入形状")
    test_cases = [
        (1, 5, dim),  # single sample
        (4, 20, dim),  # larger batch
        (1, 1, dim),  # single token
    ]

    for i, shape in enumerate(test_cases, 1):
        x_test = torch.randn(*shape)
        output_test = norm(x_test)
        print(f"  测试 {i}: 输入形状 {shape} -> 输出形状 {output_test.shape} ✓")

    # 8. Numerical stability on tiny inputs (eps prevents division by zero).
    print("\n8. 验证数值稳定性")
    x_small = torch.randn(1, 1, dim) * 1e-6
    output_small = norm(x_small)
    print(
        f"  极小输入测试: 输入范围 [{x_small.min().item():.2e}, {x_small.max().item():.2e}]"
    )
    print(f"  输出是否包含 NaN: {torch.isnan(output_small).any().item()}")
    print(f"  输出是否包含 Inf: {torch.isinf(output_small).any().item()}")

    print("\n" + "=" * 60)
    print("所有测试完成!")
    print("=" * 60)
llm/model/rope.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """RoPE 旋转位置编码"""
2
+ # 2026-01-22
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
class RoPE(nn.Module):
    """Rotary position embedding (RoPE).

    Encodes positions by rotating pairs of feature channels of Q and K, so
    that dot products between rotated vectors depend on relative position.
    """

    def __init__(self, dim, max_seq_len=1024, theta=10000.0):
        """
        Args:
            dim: per-head dimension; must be even (channels rotate in pairs).
            max_seq_len: maximum sequence length the model should handle.
            theta: base of the rotation-frequency spectrum.
        """
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.theta = theta

        # Channels are rotated in pairs, so the head dim must be even.
        assert dim % 2 == 0, "dim 必须是偶数"

        # Per-pair inverse frequencies, depending only on the channel index:
        #   inv_freq[i] = 1 / theta**(2i / dim)
        # e.g. for dim=32: exponents 2i/dim = [0, 0.0625, ..., 0.9375], so
        # theta**(2i/dim) grows and inv_freq shrinks — low channel indices
        # rotate fast, high indices rotate slowly.
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        # Buffer: saved/loaded and moved with the module, but not trained.
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, q, k, positions=None):
        """Rotate Q and K according to their positions.

        Args:
            q: (batch_size, num_heads, seq_len, head_dim) query tensor.
            k: (batch_size, num_heads, seq_len, head_dim) key tensor.
            positions: optional 1-D tensor of position indices; defaults to
                [0, 1, ..., seq_len - 1].

        Returns:
            (q_rot, k_rot), same shapes as the inputs.
        """
        # q[b, h, s, d]: sample b, head h, token s, feature component d.
        batch_size, num_heads, seq_len, head_dim = q.shape

        # Default positions 0..seq_len-1 on q's device.
        if positions is None:
            positions = torch.arange(seq_len, device=q.device)

        # Angle matrix: freqs[s, i] = position_s * inv_freq_i.
        freqs = torch.outer(positions.float(), self.inv_freq)

        # Rotation-matrix entries.
        cos = torch.cos(freqs)
        sin = torch.sin(freqs)

        # (seq_len, head_dim // 2) -> (1, 1, seq_len, head_dim // 2) so the
        # angles broadcast over batch and heads.
        cos = cos.unsqueeze(0).unsqueeze(0)
        sin = sin.unsqueeze(0).unsqueeze(0)

        # Split channels into two halves ("real" and "imaginary" parts).
        # NOTE(review): this is the half-split RoPE layout, not the
        # interleaved even/odd layout; the two are not checkpoint-compatible.
        q1, q2 = q.chunk(2, dim=-1)
        k1, k2 = k.chunk(2, dim=-1)

        # Apply the 2-D rotation  [cos -sin; sin cos]  to each (x1, x2) pair:
        #   [x1*cos - x2*sin,
        #    x1*sin + x2*cos]
        q_rot = torch.cat(
            [
                q1 * cos - q2 * sin,  # real part
                q1 * sin + q2 * cos,  # imaginary part
            ],
            dim=-1,
        )

        k_rot = torch.cat(
            [
                k1 * cos - k2 * sin,  # real part
                k1 * sin + k2 * cos,  # imaginary part
            ],
            dim=-1,
        )

        return q_rot, k_rot
106
+
107
+
108
if __name__ == "__main__":
    print("=" * 60)
    print("RoPE 测试")
    print("=" * 60)

    # Build a RoPE layer (head_dim must be even).
    head_dim = 32
    rope = RoPE(dim=head_dim, max_seq_len=1024, theta=10000.0)

    print(f"\n1. 创建 RoPE 层")
    print(f"  头维度: {head_dim}")
    print(f"  最大序列长度: {rope.max_seq_len}")
    print(f"  theta: {rope.theta}")
    print(f"  频率数量: {len(rope.inv_freq)}")
    print(f"  前5个频率: {rope.inv_freq[:5]}")

    # Random Q/K input.
    batch_size = 2
    num_heads = 10
    seq_len = 10

    q = torch.randn(batch_size, num_heads, seq_len, head_dim)
    k = torch.randn(batch_size, num_heads, seq_len, head_dim)

    print(f"\n2. 创建测试输入")
    print(f"  Q 形状: {q.shape}")
    print(f"  K 形状: {k.shape}")

    # Forward pass: shapes must be preserved.
    q_rot, k_rot = rope(q, k)

    print(f"\n3. 前向传播结果")
    print(f"  旋转后 Q 形状: {q_rot.shape}")
    print(f"  旋转后 K 形状: {k_rot.shape}")
    print(f"  形状是否匹配: {q_rot.shape == q.shape and k_rot.shape == k.shape}")

    # Different positions must receive different rotations.
    # (section title below had a mojibake character — restored to 旋)
    print(f"\n4. 验证旋转效果")
    q_pos0 = q_rot[0, 0, 0, :]  # first sample, first head, position 0
    q_pos1 = q_rot[0, 0, 1, :]  # first sample, first head, position 1
    print(f"  位置 0 的 Q(前5个值): {q_pos0[:5]}")
    print(f"  位置 1 的 Q(前5个值): {q_pos1[:5]}")
    print(f"  位置不同,编码不同: {not torch.allclose(q_pos0, q_pos1)}")

    # Explicit (non-contiguous) position indices.
    print(f"\n5. 测试不同位置")
    positions = torch.tensor([0, 5, 10])
    q_custom, k_custom = rope(q[:, :, :3, :], k[:, :, :3, :], positions=positions)
    print(f"  自定义位置: {positions.tolist()}")
    print(f"  输出形状: {q_custom.shape}")

    print(f"\n" + "=" * 60)
    print("所有测试完成!")
    print("=" * 60)
llm/model/transformer.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Decoder-only Transformer 主模型"""
2
+ # 2026-01-23
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from pathlib import Path
8
+ import sys
9
+
10
+ # 添加项目根目录到 Python 路径
11
+ project_root = Path(__file__).parent.parent.parent
12
+ sys.path.insert(0, str(project_root))
13
+
14
+ from llm.model.embedding import TokenEmbedding
15
+ from llm.model.block import TransformerBlock
16
+ from llm.model.norm import RMSNorm
17
+
18
+
19
class Transformer(nn.Module):
    """Decoder-only Transformer language model.

    Pipeline: token embedding -> N TransformerBlocks -> RMSNorm ->
    vocabulary projection (optionally weight-tied to the embedding).
    """

    def __init__(self, config):
        """
        Initialize the model.

        Args:
            config: hyperparameter dict. Required keys: ``hidden_size``,
                ``num_hidden_layers``. Optional keys used here:
                ``vocab_size`` (default 1000, normally supplied by the data
                config), ``rms_norm_eps`` (default 1e-5),
                ``tie_word_embeddings`` (default True), ``init_std``
                (default 0.02), ``init_weights`` (default True). The full
                dict is forwarded to each TransformerBlock.
        """
        super().__init__()
        self.config = config

        vocab_size = config.get("vocab_size", 1000)
        hidden_size = config["hidden_size"]
        num_layers = config["num_hidden_layers"]
        rms_norm_eps = float(config.get("rms_norm_eps", 1e-5))
        tie_word_embeddings = config.get("tie_word_embeddings", True)

        # Token embedding table.
        self.embedding = TokenEmbedding(vocab_size, hidden_size)

        # Stack of decoder blocks.
        self.layers = nn.ModuleList(
            [TransformerBlock(config) for _ in range(num_layers)]
        )

        # Final normalization before the output projection.
        self.norm = RMSNorm(hidden_size, eps=rms_norm_eps)

        # Output head. With tied embeddings the input embedding matrix is
        # reused (transposed) inside forward(), so no separate layer exists.
        if tie_word_embeddings:
            self.lm_head = None
        else:
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

        # Weight initialization (enabled by default).
        init_std = config.get("init_std", 0.02)
        init_weights_enabled = config.get("init_weights", True)
        if init_weights_enabled:
            from llm.utils.init import apply_llm_init
            apply_llm_init(self, std=init_std, init_output_layer=True)

    def forward(self, input_ids, mask=None):
        """
        Forward pass.

        Args:
            input_ids: (batch_size, seq_len) token ids.
            mask: attention mask, (batch_size, seq_len, seq_len) or
                (seq_len, seq_len). NOTE(review): no causal mask is built
                here when ``mask`` is None — the trainer supplies one
                explicitly; confirm TransformerBlock's default behavior.

        Returns:
            logits: (batch_size, seq_len, vocab_size).
        """
        # 1. Embed tokens.
        x = self.embedding(input_ids)  # (batch_size, seq_len, hidden_size)

        # 2. Run the decoder stack.
        for layer in self.layers:
            x = layer(x, mask=mask)  # (batch_size, seq_len, hidden_size)

        # 3. Final normalization.
        x = self.norm(x)  # (batch_size, seq_len, hidden_size)

        # 4. Project hidden states to vocabulary logits.
        if self.lm_head is None:
            # Tied embeddings: F.linear(x, W) computes x @ W.T with
            # W = embedding.weight of shape (vocab_size, hidden_size).
            logits = F.linear(x, self.embedding.embedding.weight)
        else:
            logits = self.lm_head(x)  # (batch_size, seq_len, vocab_size)

        return logits

    def generate(
        self, input_ids, max_length=100, temperature=1.0, top_k=None, top_p=None
    ):
        """
        Autoregressive sampling.

        Args:
            input_ids: (batch_size, start_len) prompt token ids.
            max_length: total sequence length to generate up to (prompt
                included).
            temperature: softmax temperature; higher = more random.
            top_k: keep only the k most likely tokens before sampling.
            top_p: nucleus-sampling cumulative-probability threshold.

        Returns:
            (batch_size, generated_len) tensor of token ids.
        """
        self.eval()  # evaluation mode
        generated_ids = input_ids.clone()

        with torch.no_grad():
            for _ in range(max_length - input_ids.shape[1]):
                # Full re-encode of the sequence so far (no KV cache).
                # NOTE(review): called without a causal mask — confirm the
                # blocks mask internally, otherwise sampling sees the future.
                logits = self.forward(
                    generated_ids
                )  # (batch_size, seq_len, vocab_size)

                # Only the last position predicts the next token.
                next_token_logits = (
                    logits[:, -1, :] / temperature
                )  # (batch_size, vocab_size)

                # Top-K filtering: everything outside the k best -> -inf.
                if top_k is not None:
                    # Bug fix: clamp k so torch.topk never requests more
                    # entries than the vocabulary holds.
                    k = min(top_k, next_token_logits.size(-1))
                    top_k_values, top_k_indices = torch.topk(next_token_logits, k)
                    next_token_logits_filtered = torch.full_like(
                        next_token_logits, float("-inf")
                    )
                    next_token_logits_filtered.scatter_(1, top_k_indices, top_k_values)
                    next_token_logits = next_token_logits_filtered

                # Top-P (nucleus) filtering.
                if top_p is not None:
                    # Sort by probability, descending.
                    sorted_logits, sorted_indices = torch.sort(
                        next_token_logits, descending=True
                    )
                    sorted_probs = F.softmax(sorted_logits, dim=-1)

                    # Cumulative probability mass.
                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

                    # Remove tokens past the top_p mass; shift right so the
                    # single most likely token always survives.
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                        ..., :-1
                    ].clone()
                    sorted_indices_to_remove[..., 0] = False

                    # Map the sorted-order mask back to vocabulary order.
                    indices_to_remove = sorted_indices_to_remove.scatter(
                        1, sorted_indices, sorted_indices_to_remove
                    )
                    next_token_logits[indices_to_remove] = float("-inf")

                # Sample the next token from the filtered distribution.
                probs = F.softmax(next_token_logits, dim=-1)  # (batch_size, vocab_size)
                next_token_id = torch.multinomial(
                    probs, num_samples=1
                )  # (batch_size, 1)

                # Append to the running sequence.
                generated_ids = torch.cat(
                    [generated_ids, next_token_id], dim=1
                )  # (batch_size, seq_len+1)

                # (Simplified: no early stop on EOS; fixed-length generation.)

        return generated_ids
180
+
181
+
182
if __name__ == "__main__":
    print("=" * 60)
    print("Transformer 模型测试")
    print("=" * 60)

    # Test parameters (kept in sync with configs/model.yaml).
    config = {
        "vocab_size": 100,  # toy vocabulary size
        "hidden_size": 320,
        "num_hidden_layers": 10,
        "num_attention_heads": 10,
        "num_key_value_heads": 2,
        "intermediate_size": 960,
        "rms_norm_eps": 1e-5,
        "max_position_embeddings": 1024,
        "rope_theta": 10000.0,
        "sliding_window": 256,
        "sliding_window_overlap": True,
        "tie_word_embeddings": True,
    }

    batch_size = 2
    seq_len = 10
    vocab_size = config["vocab_size"]

    print("\n1. 测试参数")
    print(f" vocab_size: {vocab_size}")
    print(f" hidden_size: {config['hidden_size']}")
    print(f" num_hidden_layers: {config['num_hidden_layers']}")
    print(f" batch_size: {batch_size}")
    print(f" seq_len: {seq_len}")

    # Build the model.
    print("\n2. 创建 Transformer 模型")
    model = Transformer(config)

    # Random token-id input.
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

    print(f" 输入形状: {input_ids.shape}")

    # Forward-pass shape check.
    output = model(input_ids)
    print(f" 输出形状: {output.shape}")
    print(f" 输出形状正确: {output.shape == (batch_size, seq_len, vocab_size)}")

    # Parameter count.
    total_params = sum(p.numel() for p in model.parameters())
    print(f" 参数数量: {total_params:,}")

    # Weight-tying check.
    print("\n3. 验证词嵌入共享")
    if model.lm_head is None:
        print(" 使用共享的词嵌入权重(tie_word_embeddings=True)")
        print(f" embedding 参数形状: {model.embedding.embedding.weight.shape}")
    else:
        print(" 使用独立的输出层(tie_word_embeddings=False)")
        print(f" lm_head 参数形状: {model.lm_head.weight.shape}")

    # Gradient check with a simple cross-entropy loss.
    print("\n4. 测试梯度计算")
    targets = torch.randint(0, vocab_size, (batch_size, seq_len))
    loss = F.cross_entropy(output.view(-1, vocab_size), targets.view(-1))
    loss.backward()
    print(f" 损失值: {loss.item():.4f}")
    print(
        f" 所有参数是否有梯度: {all(p.grad is not None for p in model.parameters())}"
    )

    # Generation smoke test.
    print("\n5. 测试文本生成(简化版)")
    start_ids = torch.randint(0, vocab_size, (1, 5))
    print(f" 起始序列长度: {start_ids.shape[1]}")
    print(f" 起始 token IDs: {start_ids[0].tolist()}")

    # Debug aid: inspect the first next-token distribution.
    # (The original comment here was mojibake-corrupted; rewritten.)
    with torch.no_grad():
        logits = model(start_ids)
        next_logits = logits[0, -1, :]
        probs = F.softmax(next_logits, dim=-1)
        top_probs, top_indices = torch.topk(probs, k=5)
        print(" 第一次生成的前5个最可能 token:")
        # Fix: the index from enumerate() was unused — plain zip suffices.
        for idx, prob in zip(top_indices, top_probs):
            print(f" Token {idx.item()}: {prob.item():.4f}")

    # Sample with top-k to add diversity (recommended for untrained models).
    generated = model.generate(
        start_ids,
        max_length=15,
        temperature=1.5,  # higher temperature -> more randomness
        top_k=10,  # sample only from the 10 most likely tokens
    )
    print(f" 生成序列长度: {generated.shape[1]}")
    print(f" 生成的 token IDs: {generated[0].tolist()}")

    print("\n" + "=" * 60)
    print("所有测试完成!")
    print("=" * 60)
llm/training/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """训练模块"""
2
+
3
+ from llm.training.metrics import (
4
+ calculate_perplexity,
5
+ calculate_accuracy,
6
+ calculate_top_k_accuracy,
7
+ calculate_metrics,
8
+ )
9
+
10
+ __all__ = [
11
+ "calculate_perplexity",
12
+ "calculate_accuracy",
13
+ "calculate_top_k_accuracy",
14
+ "calculate_metrics",
15
+ ]
llm/training/loss.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """损失函数:交叉熵损失"""
2
+ # 2026-01-23
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+
9
class CrossEntropyLoss(nn.Module):
    """Token-level cross-entropy loss for language modeling.

    Flattens the (batch, seq) grid of predictions and targets and
    delegates to ``F.cross_entropy``.
    """

    def __init__(self, ignore_index=-1):
        """
        Args:
            ignore_index: target value excluded from the loss
                (typically padding positions).
        """
        super().__init__()
        self.ignore_index = ignore_index

    def forward(self, logits, targets):
        """
        Compute the mean cross-entropy over all (non-ignored) positions.

        Args:
            logits: (batch_size, seq_len, vocab_size) model outputs.
            targets: (batch_size, seq_len) target token ids.

        Returns:
            Scalar loss tensor.
        """
        vocab_size = logits.size(-1)
        # Flatten every sequence position into one big batch so a single
        # cross-entropy call scores all of them uniformly.
        flat_logits = logits.view(-1, vocab_size)   # (B*S, V)
        flat_targets = targets.view(-1)             # (B*S,)

        # F.cross_entropy = log_softmax + negative log-likelihood in one pass.
        return F.cross_entropy(
            flat_logits, flat_targets, ignore_index=self.ignore_index
        )
47
+
48
+
49
if __name__ == "__main__":
    print("=" * 60)
    print("损失函数测试")
    print("=" * 60)

    # Exercise the loss on random data.
    print("\n1. 测试 CrossEntropyLoss")
    criterion = CrossEntropyLoss()

    batch_size, seq_len, vocab_size = 2, 10, 100

    # logits needs requires_grad=True so the backward check below works.
    logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True)
    targets = torch.randint(0, vocab_size, (batch_size, seq_len))

    print(f" Logits 形状: {logits.shape}")
    print(f" Targets 形状: {targets.shape}")

    loss = criterion(logits, targets)
    print(f" 损失值: {loss.item():.4f}")

    # The loss must reduce to a scalar.
    print(f" 损失是否为标量: {loss.dim() == 0}")

    # Backward pass.
    loss.backward()
    print(f" 损失可以反向传播: True")
    print(f" Logits 梯度形状: {logits.grad.shape}")

    # ignore_index handling.
    print("\n2. 测试 ignore_index")
    criterion_ignore = CrossEntropyLoss(ignore_index=-1)
    targets_with_ignore = targets.clone()
    targets_with_ignore[0, 0] = -1  # mark one position as ignored
    loss_ignore = criterion_ignore(logits, targets_with_ignore)
    print(f" 使用 ignore_index 的损失值: {loss_ignore.item():.4f}")

    print("\n" + "=" * 60)
    print("所有测试完成!")
    print("=" * 60)
llm/training/metrics.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """评估指标:困惑度、准确率等"""
2
+ # 2026-01-23
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+
8
def calculate_perplexity(loss):
    """Return the perplexity ``PPL = exp(loss)``.

    Perplexity is the standard language-modeling metric: it expresses the
    model's uncertainty about the next token, and lower is better.

    Args:
        loss: cross-entropy loss (scalar or tensor).

    Returns:
        Tensor of the same shape as ``loss``.
    """
    return torch.exp(loss)
24
+
25
+
26
def calculate_accuracy(logits, targets, ignore_index=-1):
    """
    Compute next-token prediction accuracy.

    Args:
        logits: (batch_size, seq_len, vocab_size) model outputs.
        targets: (batch_size, seq_len) target token ids.
        ignore_index: target value excluded from the statistics (typically
            padding); pass None to count every position.

    Returns:
        Accuracy as a Python float in [0, 1] (0.0 when nothing to score).
    """
    # Greedy predictions.
    predictions = torch.argmax(logits, dim=-1)  # (batch_size, seq_len)

    if ignore_index is not None:
        # Bug fix: the previous `ignore_index >= 0` guard never fired for
        # the (negative) default sentinel -1, so padded positions were
        # silently counted as errors. Mask whenever a sentinel is given.
        mask = targets != ignore_index
        matches = (predictions == targets) & mask
        total = int(mask.sum().item())
    else:
        matches = predictions == targets
        total = targets.numel()

    if total == 0:
        return 0.0
    return matches.sum().item() / total
57
+
58
+
59
def calculate_top_k_accuracy(logits, targets, k=5, ignore_index=-1):
    """
    Compute Top-K accuracy: a position counts as correct when the target
    token appears among the k highest-scoring predictions.

    Args:
        logits: (batch_size, seq_len, vocab_size) model outputs.
        targets: (batch_size, seq_len) target token ids.
        k: number of top candidates to consider (default: 5).
        ignore_index: target value excluded from the statistics; pass None
            to count every position.

    Returns:
        Top-K accuracy as a Python float in [0, 1].
    """
    # Indices of the k most likely tokens per position.
    _, top_k_indices = torch.topk(logits, k, dim=-1)  # (B, S, k)

    # Broadcast targets against the k candidates.
    targets_expanded = targets.unsqueeze(-1).expand_as(top_k_indices)  # (B, S, k)

    # Hit if the target appears anywhere in the top-k.
    matches = (top_k_indices == targets_expanded).any(dim=-1)  # (B, S)

    if ignore_index is not None:
        # Bug fix: `ignore_index >= 0` never masked the (negative) default
        # sentinel -1; mask whenever a sentinel is provided.
        mask = targets != ignore_index
        matches = matches & mask
        total = int(mask.sum().item())
    else:
        total = targets.numel()

    if total == 0:
        return 0.0
    return matches.sum().item() / total
92
+
93
+
94
def calculate_metrics(logits, targets, loss, ignore_index=-1):
    """
    Bundle all evaluation metrics into one dict.

    Args:
        logits: (batch_size, seq_len, vocab_size) model outputs.
        targets: (batch_size, seq_len) target token ids.
        loss: scalar loss (tensor or float).
        ignore_index: target value excluded from accuracy metrics.

    Returns:
        Dict with keys ``perplexity``, ``accuracy``, ``top5_accuracy``
        and ``loss``.
    """
    loss_value = loss.item() if isinstance(loss, torch.Tensor) else loss
    return {
        "perplexity": calculate_perplexity(loss).item(),
        "accuracy": calculate_accuracy(logits, targets, ignore_index=ignore_index),
        "top5_accuracy": calculate_top_k_accuracy(
            logits, targets, k=5, ignore_index=ignore_index
        ),
        "loss": loss_value,
    }
122
+
123
+
124
if __name__ == "__main__":
    print("=" * 60)
    print("评估指标测试")
    print("=" * 60)

    batch_size, seq_len, vocab_size = 4, 10, 100

    print(f"\n测试参数:")
    print(f" batch_size: {batch_size}")
    print(f" seq_len: {seq_len}")
    print(f" vocab_size: {vocab_size}")

    # Synthetic logits/targets and a matching cross-entropy loss.
    logits = torch.randn(batch_size, seq_len, vocab_size)
    targets = torch.randint(0, vocab_size, (batch_size, seq_len))
    loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))

    print(f"\n1. 测试困惑度")
    ppl = calculate_perplexity(loss)
    print(f" 损失: {loss.item():.4f}")
    print(f" 困惑度: {ppl.item():.4f}")

    print(f"\n2. 测试准确率")
    accuracy = calculate_accuracy(logits, targets)
    print(f" 准确率: {accuracy:.4f} ({accuracy*100:.2f}%)")

    print(f"\n3. 测试 Top-5 准确率")
    top5_acc = calculate_top_k_accuracy(logits, targets, k=5)
    print(f" Top-5 准确率: {top5_acc:.4f} ({top5_acc*100:.2f}%)")

    print(f"\n4. 测试 ignore_index")
    targets_with_ignore = targets.clone()
    targets_with_ignore[0, :3] = -1  # mark the first three positions ignored
    accuracy_ignore = calculate_accuracy(logits, targets_with_ignore, ignore_index=-1)
    print(f" 使用 ignore_index 的准确率: {accuracy_ignore:.4f} ({accuracy_ignore*100:.2f}%)")

    print(f"\n5. 测试综合指标")
    all_metrics = calculate_metrics(logits, targets, loss)
    print(" 所有指标:")
    for key, value in all_metrics.items():
        if isinstance(value, float):
            print(f" {key}: {value:.4f}")
        else:
            print(f" {key}: {value}")

    print("\n" + "=" * 60)
    print("所有测试完成!")
    print("=" * 60)
llm/training/optim.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """优化器:AdamW / 学习率调度"""
2
+ # 2026-01-23
3
+
4
+ import math
5
+ import torch.optim as optim
6
+ from torch.optim.lr_scheduler import LambdaLR
7
+
8
+
9
def get_optimizer(model, config):
    """Build an AdamW optimizer from a training-config dict.

    Recognized keys (all optional, coerced to float):
        learning_rate (1e-4), weight_decay (0.01),
        beta1 (0.9), beta2 (0.999), eps (1e-8).

    Args:
        model: module whose parameters will be optimized.
        config: training configuration dict.

    Returns:
        A configured ``torch.optim.AdamW`` instance.
    """
    # float() tolerates YAML/string-typed values such as "1e-4".
    hyper = {
        "lr": float(config.get("learning_rate", 1e-4)),
        "weight_decay": float(config.get("weight_decay", 0.01)),
        "betas": (
            float(config.get("beta1", 0.9)),
            float(config.get("beta2", 0.999)),
        ),
        "eps": float(config.get("eps", 1e-8)),
    }
    return optim.AdamW(model.parameters(), **hyper)
41
+
42
+
43
def get_lr_scheduler(optimizer, config):
    """Create a learning-rate scheduler with linear warmup.

    Args:
        optimizer: the optimizer to schedule.
        config: training-config dict with optional keys:
            lr_scheduler: "cosine" (default), "linear" or "constant".
            warmup_steps: warmup length in steps (default: 100).
            max_steps: total training steps (default: 10000).

    Returns:
        A ``LambdaLR`` scheduler, or None for "constant".

    Raises:
        ValueError: on an unknown ``lr_scheduler`` value.
    """
    kind = config.get("lr_scheduler", "cosine")
    total_steps = config.get("max_steps", 10000)
    warmup = config.get("warmup_steps", 100)

    def _warmup_factor(step):
        # Linear ramp from 0 to 1 across the warmup phase.
        return step / warmup if warmup > 0 else 1.0

    def _progress(step):
        # Post-warmup progress, clamped to [0, 1].
        return min((step - warmup) / (total_steps - warmup), 1.0)

    if kind == "cosine":
        # Cosine annealing after warmup: factor 1 -> 0.
        def schedule(step):
            if step < warmup:
                return _warmup_factor(step)
            # Degenerate configs (max_steps <= warmup) pin the minimum.
            if total_steps <= warmup:
                return 0.0
            return 0.5 * (1.0 + math.cos(_progress(step) * math.pi))

        return LambdaLR(optimizer, schedule)

    if kind == "linear":
        # Linear decay after warmup: factor 1 -> 0.1.
        def schedule(step):
            if step < warmup:
                return _warmup_factor(step)
            if total_steps <= warmup:
                return 0.1
            return 1.0 - 0.9 * _progress(step)

        return LambdaLR(optimizer, schedule)

    if kind == "constant":
        # Constant learning rate: no scheduler object at all.
        return None

    raise ValueError(
        f"未知的学习率调度器类型: {kind}。"
        f"支持的类型: cosine, linear, constant"
    )
113
+
114
+
115
if __name__ == "__main__":
    import sys
    import io

    # Force UTF-8 output on Windows consoles.
    if sys.platform == "win32":
        sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")

    print("=" * 60)
    print("优化器和学习率调度器测试")
    print("=" * 60)

    # Minimal model for exercising the optimizer.
    import torch.nn as nn

    class SimpleModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(10, 1)

        def forward(self, x):
            return self.linear(x)

    model = SimpleModel()

    # Shared test configuration.
    config = {
        "learning_rate": 1e-3,
        "weight_decay": 0.01,
        "beta1": 0.9,
        "beta2": 0.999,
        "eps": 1e-8,
        "lr_scheduler": "cosine",
        "warmup_steps": 10,
        "max_steps": 100,
    }

    # Optimizer construction.
    # (Bug fix: this section title string was mojibake-corrupted; restored.)
    print("\n1. 测试优化器创建")
    optimizer = get_optimizer(model, config)
    print(f" 优化器类型: {type(optimizer).__name__}")
    print(f" 学习率: {optimizer.param_groups[0]['lr']}")
    print(f" 权重衰减: {optimizer.param_groups[0]['weight_decay']}")
    print(f" Beta1: {optimizer.param_groups[0]['betas'][0]}")
    print(f" Beta2: {optimizer.param_groups[0]['betas'][1]}")

    # Cosine schedule with warmup.
    print("\n2. 测试余弦退火学习率调度器(带预热)")
    scheduler = get_lr_scheduler(optimizer, config)
    print(f" 调度器类型: {type(scheduler).__name__}")
    print(f" 初始学习率: {optimizer.param_groups[0]['lr']:.6f}")

    # Step through training and watch the LR evolve.
    print("\n3. 模拟训练步骤,观察学习率变化(cosine)")
    lrs = []
    for step in range(0, 101, 10):
        current_lr = optimizer.param_groups[0]["lr"]
        lrs.append(current_lr)
        print(f" 步数 {step:3d}: 学习率 = {current_lr:.6f}")
        # optimizer.step() must precede scheduler.step(); drive it with a
        # dummy loss over the parameters (test-only).
        dummy_loss = sum(p.sum() for p in model.parameters())
        dummy_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if scheduler is not None:
            scheduler.step()

    # Linear schedule with warmup.
    print("\n4. 测试线性学习率调度器(带预热)")
    config_linear = config.copy()
    config_linear["lr_scheduler"] = "linear"
    optimizer_linear = get_optimizer(model, config_linear)
    scheduler_linear = get_lr_scheduler(optimizer_linear, config_linear)
    print(f" 调度器类型: {type(scheduler_linear).__name__}")

    print("\n5. 模拟训练步骤,观察学习率变化(linear)")
    for step in range(0, 101, 10):
        current_lr = optimizer_linear.param_groups[0]["lr"]
        print(f" 步数 {step:3d}: 学习率 = {current_lr:.6f}")
        dummy_loss = sum(p.sum() for p in model.parameters())
        dummy_loss.backward()
        optimizer_linear.step()
        optimizer_linear.zero_grad()
        if scheduler_linear is not None:
            scheduler_linear.step()

    # Constant learning rate: no scheduler object.
    print("\n6. 测试常数学习率")
    config_constant = config.copy()
    config_constant["lr_scheduler"] = "constant"
    optimizer_constant = get_optimizer(model, config_constant)
    scheduler_constant = get_lr_scheduler(optimizer_constant, config_constant)
    print(f" 调度器类型: {scheduler_constant}")
    print(f" 学习率: {optimizer_constant.param_groups[0]['lr']:.6f}")

    # Unknown scheduler types must raise ValueError.
    print("\n7. 测试错误类型处理")
    try:
        config_error = config.copy()
        config_error["lr_scheduler"] = "invalid"
        scheduler_error = get_lr_scheduler(optimizer, config_error)
    except ValueError as e:
        print(f" 正确捕获错误: {e}")

    print("\n" + "=" * 60)
    print("所有测试完成!")
    print("=" * 60)
llm/training/trainer.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """训练器:训练循环"""
2
+ # 2026-01-23
3
+
4
+ import torch
5
+ from pathlib import Path
6
+ from collections import deque
7
+
8
+ from llm.training.loss import CrossEntropyLoss
9
+ from llm.training.optim import get_optimizer, get_lr_scheduler
10
+ from llm.utils.checkpoint import save_checkpoint
11
+ from llm.utils.logging import ProgressBar
12
+ from llm.model.attention import create_causal_mask
13
+
14
+
15
class Trainer:
    """Training-loop driver: optimization, evaluation and checkpointing."""

    def __init__(self, model, train_loader, val_loader, config):
        """
        Initialize the trainer.

        Args:
            model: Transformer model to train.
            train_loader: training DataLoader yielding (input_ids, target_ids).
            val_loader: validation DataLoader (may be falsy to skip eval).
            config: training-config dict (device, learning_rate, num_epochs,
                max_steps, gradient_accumulation_steps, save/eval settings...).
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config

        # Device placement.
        device_str = config.get("device", "cpu")
        self.device = torch.device(device_str)
        self.model.to(self.device)

        # Loss function.
        self.criterion = CrossEntropyLoss()

        # Optimizer.
        self.optimizer = get_optimizer(model, config)

        # Learning-rate scheduler (None means constant LR).
        self.scheduler = get_lr_scheduler(self.optimizer, config)

        # Training state.
        self.global_step = 0
        self.current_epoch = 0
        self.best_val_loss = float("inf")

        # Gradient accumulation factor.
        self.gradient_accumulation_steps = config.get("gradient_accumulation_steps", 1)

        # Checkpoint configuration.
        self.save_steps = config.get("save_steps", 500)
        self.eval_steps = config.get("eval_steps", 500)
        self.save_total_limit = config.get("save_total_limit", 3)
        self.checkpoint_dir = Path(config.get("checkpoint_dir", "checkpoints"))
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Saved checkpoint paths, bounded manually (no deque maxlen) so the
        # evicted entry's file can also be deleted from disk.
        self.saved_checkpoints = deque()
        self.best_checkpoint_path = None  # path of the current best model

    def train_step(self, batch):
        """
        Run one micro-batch: forward, loss, backward, and (every
        `gradient_accumulation_steps` calls) a parameter update.

        Args:
            batch: (input_ids, target_ids), each (batch_size, seq_len).

        Returns:
            The un-scaled loss value for this micro-batch (float).
        """
        input_ids, target_ids = batch
        input_ids = input_ids.to(self.device)
        target_ids = target_ids.to(self.device)

        # Causal mask so each position only attends to the past.
        seq_len = input_ids.size(1)
        causal_mask = create_causal_mask(seq_len, device=self.device)

        # Forward pass.
        logits = self.model(input_ids, mask=causal_mask)

        # Loss.
        loss = self.criterion(logits, target_ids)

        # Gradient accumulation: scale the loss down so the summed gradient
        # matches one full-batch step.
        loss = loss / self.gradient_accumulation_steps

        # Backward pass (gradients accumulate across micro-batches).
        loss.backward()

        # Update parameters only when the accumulation window is complete.
        if (self.global_step + 1) % self.gradient_accumulation_steps == 0:
            # Gradient clipping.
            max_grad_norm = self.config.get("max_grad_norm", 1.0)
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), max_grad_norm
            )

            # Parameter update.
            self.optimizer.step()
            self.optimizer.zero_grad()

            # Learning-rate schedule update.
            if self.scheduler is not None:
                self.scheduler.step()

        self.global_step += 1

        return loss.item() * self.gradient_accumulation_steps  # undo the scaling

    def evaluate(self, max_batches=10):
        """
        Evaluate on the validation set.

        Args:
            max_batches: cap on evaluated batches (falsy disables the cap).

        Returns:
            val_loss: mean validation loss (inf when nothing was evaluated).
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in self.val_loader:
                if max_batches and num_batches >= max_batches:
                    break  # limit the number of evaluated batches
                input_ids, target_ids = batch
                input_ids = input_ids.to(self.device)
                target_ids = target_ids.to(self.device)

                # Causal mask.
                seq_len = input_ids.size(1)
                causal_mask = create_causal_mask(seq_len, device=self.device)

                # Forward pass.
                logits = self.model(input_ids, mask=causal_mask)

                # Loss.
                loss = self.criterion(logits, target_ids)
                total_loss += loss.item()
                num_batches += 1

        avg_loss = total_loss / num_batches if num_batches > 0 else float("inf")
        self.model.train()  # restore training mode

        return avg_loss

    def save_checkpoint(self, is_best=False):
        """
        Save a checkpoint and manage on-disk retention.

        Args:
            is_best: True to save/rotate the best-model checkpoint instead
                of a periodic one.

        Returns:
            Path of the file written.
        """
        # Pass a bare name; checkpoint.py appends the "_step_N" suffix.
        checkpoint_name = "best_model" if is_best else "checkpoint"

        checkpoint_path = save_checkpoint(
            model=self.model,
            optimizer=self.optimizer,
            epoch=self.current_epoch,
            step=self.global_step,
            loss=self.best_val_loss if is_best else None,
            checkpoint_dir=self.checkpoint_dir,
            name=checkpoint_name,
        )

        # Retention bookkeeping.
        if is_best:
            # Keep only the newest best-model file on disk.
            if self.best_checkpoint_path is not None and self.best_checkpoint_path.exists():
                self.best_checkpoint_path.unlink()
            self.best_checkpoint_path = checkpoint_path
        else:
            # When the queue is full, evict (and delete) the oldest one
            # before recording the new checkpoint.
            if len(self.saved_checkpoints) >= self.save_total_limit:
                old_checkpoint = self.saved_checkpoints.popleft()
                if old_checkpoint.exists():
                    old_checkpoint.unlink()
                    print(f"删除旧检查点: {old_checkpoint}")

            # Record the new checkpoint.
            self.saved_checkpoints.append(checkpoint_path)

        return checkpoint_path

    def train(self):
        """
        Main training loop over epochs and batches, with periodic
        evaluation and checkpointing.
        """
        num_epochs = self.config.get("num_epochs", 10)
        max_steps = self.config.get("max_steps", None)

        print("=" * 60)
        print("开始训练")
        print("=" * 60)
        print(f"设备: {self.device}")
        print(f"训练样本数: {len(self.train_loader.dataset)}")
        print(f"验证样本数: {len(self.val_loader.dataset) if self.val_loader else 0}")
        print(f"批次大小: {self.config.get('batch_size', 4)}")
        print(f"梯度累积步数: {self.gradient_accumulation_steps}")
        print(f"最大步数: {max_steps}")
        print(f"训练轮数: {num_epochs}")
        print("=" * 60)

        for epoch in range(num_epochs):
            self.current_epoch = epoch
            self.model.train()

            # Per-epoch progress bar.
            pbar = ProgressBar(
                total=len(self.train_loader), desc=f"Epoch {epoch+1}/{num_epochs}"
            )

            epoch_loss = 0.0
            num_batches = 0

            for batch_idx, batch in enumerate(self.train_loader):
                # Stop once the global step budget is exhausted.
                if max_steps is not None and self.global_step >= max_steps:
                    print(f"\n达到最大步数 {max_steps},停止训练")
                    pbar.close()
                    return

                # One training step.
                loss = self.train_step(batch)
                epoch_loss += loss
                num_batches += 1

                # Advance the progress bar.
                pbar.update(1)

                # Periodic evaluation.
                saved_best_at_this_step = False
                if self.global_step % self.eval_steps == 0 and self.val_loader:
                    val_loss = self.evaluate()
                    # Current learning rate, for logging.
                    current_lr = self.optimizer.param_groups[0]["lr"]

                    print(
                        f"\n步数 {self.global_step}: "
                        f"训练损失={loss:.4f}, "
                        f"验证损失={val_loss:.4f}, "
                        f"学习率={current_lr:.2e}"
                    )

                    # Track and save the best model.
                    if val_loss < self.best_val_loss:
                        self.best_val_loss = val_loss
                        checkpoint_path = self.save_checkpoint(is_best=True)
                        print(f"保存最佳模型: {checkpoint_path}")
                        saved_best_at_this_step = True

                # Periodic checkpoint (skipped when a best model was just
                # saved at this very step).
                if self.global_step % self.save_steps == 0 and not saved_best_at_this_step:
                    checkpoint_path = self.save_checkpoint(is_best=False)
                    print(f"保存检查点: {checkpoint_path}")

            pbar.close()

            # End-of-epoch evaluation and summary.
            avg_epoch_loss = epoch_loss / num_batches if num_batches > 0 else 0.0
            print(f"\nEpoch {epoch+1}/{num_epochs} 完成:")
            print(f" 平均训练损失: {avg_epoch_loss:.4f}")

            if self.val_loader:
                val_loss = self.evaluate()
                print(f" 验证损失: {val_loss:.4f}")

        print("\n" + "=" * 60)
        print("训练完成!")
        print("=" * 60)
        print(f"最佳验证损失: {self.best_val_loss:.4f}")

        # Report the best-model path.
        if self.best_checkpoint_path is not None and self.best_checkpoint_path.exists():
            print(f"\n最佳模型路径: {self.best_checkpoint_path}")
            print(" (文件名包含 'best_model' 的检查点就是最佳模型)")
        else:
            print("\n警告: 未找到最佳模型检查点")

        # List retained periodic checkpoints.
        if self.saved_checkpoints:
            print(f"\n保存的检查点数量: {len(self.saved_checkpoints)}")
            print("检查点列表:")
            for i, cp_path in enumerate(self.saved_checkpoints, 1):
                print(f" {i}. {cp_path}")
llm/utils/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
"""Utility module: re-exports the public weight-initialization helpers."""

from llm.utils.init import init_weights, init_weights_with_scaling, apply_llm_init

# Explicit public API of llm.utils.
__all__ = [
    "init_weights",
    "init_weights_with_scaling",
    "apply_llm_init",
]
llm/utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (358 Bytes). View file
 
llm/utils/__pycache__/checkpoint.cpython-312.pyc ADDED
Binary file (1.96 kB). View file
 
llm/utils/__pycache__/init.cpython-312.pyc ADDED
Binary file (10.1 kB). View file
 
llm/utils/checkpoint.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """检查点保存和加载"""
2
+ # 2026-01-23
3
+
4
+ import torch
5
+ from pathlib import Path
6
+
7
+
8
def save_checkpoint(model, optimizer, epoch, step, loss, checkpoint_dir, name='checkpoint'):
    """Persist full training state to disk.

    Args:
        model: module whose ``state_dict`` is saved.
        optimizer: optimizer whose ``state_dict`` is saved.
        epoch: current epoch number (stored for resuming).
        step: current global step (also embedded in the file name).
        loss: last loss value to record alongside the weights.
        checkpoint_dir: directory for the checkpoint (created if missing).
        name: file-name prefix (default: ``'checkpoint'``).

    Returns:
        Path of the written ``.pt`` file.
    """
    out_dir = Path(checkpoint_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Embed the step in the file name so successive saves never collide.
    target = out_dir / f"{name}_step_{step}.pt"

    payload = {
        'epoch': epoch,
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    torch.save(payload, target)

    return target
25
+
26
+
27
def load_checkpoint(model, optimizer, checkpoint_path):
    """Restore model and optimizer state to resume training.

    Args:
        model: module to load ``model_state_dict`` into.
        optimizer: optimizer to load ``optimizer_state_dict`` into.
        checkpoint_path: path to a checkpoint written by ``save_checkpoint``.

    Returns:
        Tuple ``(epoch, step, loss)``; ``loss`` is ``None`` for old
        checkpoints that did not record it.
    """
    # weights_only=False is required to restore the full optimizer state on
    # torch >= 2.6 (where the default flipped to True); these checkpoints are
    # produced by this project, i.e. trusted input.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['step'], checkpoint.get('loss', None)
33
+
34
+
35
def load_model_only(model, checkpoint_path):
    """Load only the model weights (for inference; no optimizer needed).

    Args:
        model: module to load ``model_state_dict`` into.
        checkpoint_path: path to a checkpoint written by ``save_checkpoint``.

    Returns:
        Tuple ``(epoch, step, loss)`` with defaults ``(0, 0, None)`` when
        the checkpoint lacks those fields.
    """
    # weights_only=False keeps pre-2.6 torch.load behavior (the default
    # changed in torch 2.6); checkpoints here are self-produced, trusted data.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    return checkpoint.get('epoch', 0), checkpoint.get('step', 0), checkpoint.get('loss', None)
llm/utils/config.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """配置加载:读取 YAML"""
2
+
3
+ import yaml
4
+ from pathlib import Path
5
+
6
+
7
def load_config(config_path):
    """Read one YAML config file and return its contents as a dict.

    Args:
        config_path: path to the YAML file.

    Returns:
        Parsed configuration (whatever ``yaml.safe_load`` yields).
    """
    with open(config_path, 'r', encoding='utf-8') as fh:
        return yaml.safe_load(fh)
12
+
13
+
14
def load_all_configs(config_dir='configs'):
    """Load every known config file found under ``config_dir``.

    Looks for ``model.yaml``, ``train.yaml`` and ``data.yaml``; files that do
    not exist are silently skipped.

    Args:
        config_dir: directory containing the YAML files (default: ``configs``).

    Returns:
        Dict mapping the file stem (e.g. ``'model'``) to its parsed config.
    """
    base = Path(config_dir)
    loaded = {}

    for stem in ('model', 'train', 'data'):
        candidate = base / f"{stem}.yaml"
        if candidate.exists():
            loaded[stem] = load_config(candidate)

    return loaded
llm/utils/init.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """权重初始化:LLM 模型权重初始化策略"""
2
+
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import math
7
+
8
+
9
def init_weights(module, std=0.02):
    """
    Initialize a module's weights LLM-style (as in GPT/LLaMA).

    Args:
        module: PyTorch module to initialize.
        std: standard deviation of the normal distribution (default: 0.02).

    Strategy:
        - nn.Embedding: weights drawn from N(0, std)
        - nn.Linear: weights drawn from N(0, std), bias zeroed
        - RMSNorm: its learnable scale already defaults to 1.0, so it is
          intentionally left untouched here
    """
    if isinstance(module, nn.Linear):
        # Linear layer: normal weights, zero bias.
        nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        # Token embedding: plain normal initialization.
        nn.init.normal_(module.weight, mean=0.0, std=std)
30
+
31
+
32
def init_weights_with_scaling(module, hidden_size=None, std=0.02):
    """
    Initialize a module's weights, with depth-aware scaling for output layers.

    Args:
        module: PyTorch module to initialize.
        hidden_size: hidden dimension; when given, linear weights use the
            scaled std ``std / sqrt(hidden_size)`` (meant for the lm_head).
        std: base standard deviation (default: 0.02).

    Strategy:
        - nn.Embedding: N(0, std)
        - nn.Linear with hidden_size: N(0, std / sqrt(hidden_size))
        - nn.Linear without hidden_size: N(0, std)
        - biases are always zeroed
    """
    if isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=std)
        return

    if not isinstance(module, nn.Linear):
        # Other module types (norms, activations, ...) are left alone.
        return

    # Shrink the std for output projections to keep initial logits small.
    scale = std if hidden_size is None else std / math.sqrt(hidden_size)
    nn.init.normal_(module.weight, mean=0.0, std=scale)

    if module.bias is not None:
        nn.init.zeros_(module.bias)
64
+
65
+
66
def apply_llm_init(model, std=0.02, init_output_layer=True):
    """
    Apply LLM-style weight initialization to an entire model, in place.

    Args:
        model: Transformer model.
        std: standard deviation of the normal distribution (default: 0.02).
        init_output_layer: when True, the lm_head gets the scaled
            initialization instead of the base one (default: True).

    Returns:
        The same model, with weights initialized.
    """
    # Resolve hidden_size so the output layer can be scale-initialized.
    # NOTE: with tie_word_embeddings=True the model may have no separate
    # lm_head; then every Linear simply gets the standard init.
    hidden_size = None
    if hasattr(model, "config"):
        hidden_size = model.config.get("hidden_size")
    elif hasattr(model, "hidden_size"):
        hidden_size = model.hidden_size

    output_head = getattr(model, "lm_head", None)

    for submodule in model.modules():
        if not isinstance(submodule, (nn.Embedding, nn.Linear)):
            continue
        is_head = (
            init_output_layer
            and isinstance(submodule, nn.Linear)
            and output_head is not None
            and submodule is output_head
        )
        if is_head:
            # Output projection: shrink the std by sqrt(hidden_size).
            init_weights_with_scaling(submodule, hidden_size=hidden_size, std=std)
        else:
            # Everything else: standard LLM init.
            init_weights(submodule, std=std)

    return model
102
+
103
+
104
if __name__ == "__main__":
    # Manual smoke test: exercises all three init helpers and prints
    # weight statistics for visual inspection. Run as a script, not pytest.
    import sys
    import io

    # Force UTF-8 stdout so the Chinese output renders on Windows consoles.
    if sys.platform == "win32":
        sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")

    print("=" * 60)
    print("权重初始化测试")
    print("=" * 60)

    # Test parameters.
    vocab_size = 100
    hidden_size = 320
    intermediate_size = 960
    std = 0.02

    print(f"\n测试参数:")
    print(f" vocab_size: {vocab_size}")
    print(f" hidden_size: {hidden_size}")
    print(f" intermediate_size: {intermediate_size}")
    print(f" std: {std}")

    # 1. Embedding initialization.
    print("\n1. 测试 Embedding 初始化")
    embedding = nn.Embedding(vocab_size, hidden_size)
    init_weights(embedding, std=std)
    weight_mean = embedding.weight.mean().item()
    weight_std = embedding.weight.std().item()
    print(f" Embedding 权重均值: {weight_mean:.6f} (应该接近 0)")
    print(f" Embedding 权重标准差: {weight_std:.6f} (应该接近 {std})")

    # 2. Linear initialization.
    print("\n2. 测试 Linear 初始化")
    linear = nn.Linear(hidden_size, intermediate_size)
    init_weights(linear, std=std)
    weight_mean = linear.weight.mean().item()
    weight_std = linear.weight.std().item()
    bias_mean = linear.bias.mean().item() if linear.bias is not None else 0.0
    print(f" Linear 权重均值: {weight_mean:.6f} (应该接近 0)")
    print(f" Linear 权重标准差: {weight_std:.6f} (应该接近 {std})")
    print(f" Linear 偏置均值: {bias_mean:.6f} (应该为 0)")

    # 3. Output-layer initialization (with scaling).
    print("\n3. 测试输出层初始化(带缩放)")
    output_layer = nn.Linear(hidden_size, vocab_size, bias=False)
    init_weights_with_scaling(output_layer, hidden_size=hidden_size, std=std)
    weight_mean = output_layer.weight.mean().item()
    weight_std = output_layer.weight.std().item()
    expected_std = std / math.sqrt(hidden_size)
    print(f" 输出层权重均值: {weight_mean:.6f} (应该接近 0)")
    print(f" 输出层权重标准差: {weight_std:.6f}")
    print(f" 期望标准差: {expected_std:.6f}")

    # 4. Full-model initialization.
    print("\n4. 测试完整模型初始化")
    from llm.model.transformer import Transformer

    config = {
        "vocab_size": vocab_size,
        "hidden_size": hidden_size,
        "num_hidden_layers": 2,  # keep the model small for this test
        "num_attention_heads": 10,
        "num_key_value_heads": 2,
        "intermediate_size": intermediate_size,
        "rms_norm_eps": 1e-5,
        "max_position_embeddings": 1024,
        "rope_theta": 10000.0,
        "sliding_window": 256,
        "sliding_window_overlap": True,
        "tie_word_embeddings": True,
    }

    model = Transformer(config)

    # Weight statistics BEFORE initialization (for comparison).
    print(" 初始化前的权重统计:")
    for name, param in model.named_parameters():
        if "embedding" in name or "weight" in name:
            print(f" {name}: mean={param.data.mean().item():.6f}, std={param.data.std().item():.6f}")

    # Apply the initialization under test.
    apply_llm_init(model, std=std, init_output_layer=True)

    # Weight statistics AFTER initialization (norm scales excluded: they stay 1.0).
    print("\n 初始化后的权重统计:")
    for name, param in model.named_parameters():
        if "embedding" in name or ("weight" in name and "norm" not in name):
            print(f" {name}: mean={param.data.mean().item():.6f}, std={param.data.std().item():.6f}")

    # 5. Spot-check individual layers.
    print("\n5. 验证初始化效果")
    embedding_weight = model.embedding.embedding.weight
    print(f" Embedding 权重均值: {embedding_weight.mean().item():.6f}")
    print(f" Embedding 权重标准差: {embedding_weight.std().item():.6f}")

    # Inspect the linear projections of the first Transformer block.
    first_block = model.layers[0]
    attn_q_proj = first_block.attn.q_proj
    print(f" Attention Q 投影权重均值: {attn_q_proj.weight.mean().item():.6f}")
    print(f" Attention Q 投影权重标准差: {attn_q_proj.weight.std().item():.6f}")

    ffn_gate_proj = first_block.ffn.gate_proj
    print(f" FFN Gate 投影权重均值: {ffn_gate_proj.weight.mean().item():.6f}")
    print(f" FFN Gate 投影权重标准差: {ffn_gate_proj.weight.std().item():.6f}")

    print("\n" + "=" * 60)
    print("所有测试完成!")
    print("=" * 60)
llm/utils/logging.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """日志:tqdm / tensorboard"""
2
+
3
+ from tqdm import tqdm
4
+
5
+
6
class ProgressBar:
    """Thin wrapper around tqdm, usable directly or as a context manager.

    Context-manager use guarantees the bar is closed even when the loop
    body raises:

        with ProgressBar(total, desc="train") as pbar:
            ...
            pbar.update()
    """

    def __init__(self, total, desc=""):
        """Create the bar.

        Args:
            total: expected total number of updates.
            desc: label shown to the left of the bar (default: "").
        """
        self.pbar = tqdm(total=total, desc=desc)

    def update(self, n=1):
        """Advance the bar by n steps (default: 1)."""
        self.pbar.update(n)

    def close(self):
        """Finalize and release the bar."""
        self.pbar.close()

    def __enter__(self):
        # Enables ``with ProgressBar(...) as pbar:``.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always close the bar; never suppress the exception.
        self.close()
        return False
llm/utils/seed.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """随机种子设置"""
2
+
3
+ import random
4
+ import numpy as np
5
+ import torch
6
+
7
+
8
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs for reproducible runs.

    Args:
        seed: seed value applied to every generator (default: 42).
    """
    # Seed each framework's global generator with the same value.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # CUDA has its own per-device generators; seed them all when present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch>=2.0.0
2
+ pyyaml>=6.0
3
+ numpy>=2.0.0
4
+ huggingface_hub>=0.30.0