clarenceleo committed on
Commit
e572d32
·
verified ·
1 Parent(s): 7098ca4

Upload 3 files

Browse files
Files changed (3) hide show
  1. bpe_tokenizer.json +0 -0
  2. gtc-2-large-mini.pth +3 -0
  3. inference.py +240 -0
bpe_tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
gtc-2-large-mini.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:968e146c337e17d096435d81f33f2c6e60b50cae4d5854c0aa06a95525702ee3
3
+ size 189023590
inference.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from tokenizers import Tokenizer
4
+ import re
5
+ import argparse
6
+ import sys
7
+ import os
8
+
9
+ # ==================================
10
+ # 模型定义
11
+ # ==================================
12
+
13
class StabilizedDenoisingModel(nn.Module):
    """Token-wise network built from gated residual "denoising" blocks.

    Pipeline: embedding -> two linear projections -> LayerNorm -> a stack of
    ``num_layers`` gated residual blocks -> linear projection to vocabulary
    logits. Every layer used here (embedding, ``nn.Linear``, ``nn.LayerNorm``
    and element-wise ops) acts on the last dimension only, so nothing in this
    module mixes information between sequence positions.

    Args:
        vocab_size: Number of rows in the embedding table and width of the
            output logits.
        embed_dim: Width of the token embeddings.
        hidden_dim: Width of the hidden representation used by every block.
        num_layers: Number of denoising blocks to stack.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers):
        super().__init__()

        # Input side: token ids -> embed_dim -> hidden_dim -> hidden_dim.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.row_transform = nn.Linear(embed_dim, hidden_dim)
        self.dim_transform = nn.Linear(hidden_dim, hidden_dim)
        # One LayerNorm instance is reused after the input projection and
        # after every denoising block (its parameters are shared).
        self.norm = nn.LayerNorm(hidden_dim)

        # Each block is a two-layer MLP: Linear -> ReLU -> Linear.
        blocks = []
        for _ in range(num_layers):
            blocks.append(
                nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, hidden_dim),
                )
            )
        self.denoise_layers = nn.ModuleList(blocks)

        self.output_layer = nn.Linear(hidden_dim, vocab_size)
        self.num_layers = num_layers

    def forward(self, input_seq):
        """Compute vocabulary logits for every position of ``input_seq``.

        Args:
            input_seq: Integer tensor of token ids (it is fed straight into
                ``nn.Embedding``).

        Returns:
            Tensor shaped like ``input_seq`` with one extra trailing
            dimension of size ``vocab_size``.
        """
        state = self.embedding(input_seq)
        state = self.row_transform(state)
        state = self.dim_transform(state)
        state = self.norm(state)

        for block in self.denoise_layers:
            signal = block(state)
            gate = torch.sigmoid(signal)
            # Gated update: subtract the gated signal and add back the
            # rectified signal weighted by the complementary gate.
            denoised = state - gate * signal + (1 - gate) * torch.relu(signal)
            # Residual connection followed by the shared LayerNorm.
            state = self.norm(state + denoised)

        return self.output_layer(state)
47
+
48
+ # ==================================
49
+ # 文本处理函数
50
+ # ==================================
51
+
52
def clean_text(text):
    """Normalize raw input text.

    The text is lower-cased, every character other than ``a-z``, ``0-9``,
    whitespace and the punctuation ``. , ! ? ; : ' " -`` is removed, runs of
    whitespace are collapsed to a single space, and leading/trailing
    whitespace is stripped.

    Args:
        text: The string to normalize.

    Returns:
        The normalized string (possibly empty).
    """
    cleaned = text.lower()
    # Applied in order: first drop disallowed characters, then squeeze
    # whitespace runs so removed characters cannot leave double spaces.
    substitutions = (
        (r'[^a-z0-9\s.,!?;:\'"-]', ''),
        (r'\s+', ' '),
    )
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()
58
+
59
+ # ==================================
60
+ # 流式文本生成函数(修复输出问题)
61
+ # ==================================
62
+
63
def stream_generate_text(model, tokenizer, device, start_text, max_len=100,
                         temperature=0.8, max_context=100, min_prob=0.01):
    """Sample a continuation of ``start_text`` and stream it to stdout.

    The prompt is cleaned with ``clean_text``, encoded, and then extended one
    token at a time by sampling from the temperature-scaled distribution of
    the model's last position. Newly produced text is printed as soon as it
    is available (no trailing newline is written).

    Args:
        model: Callable module mapping a ``(1, seq)`` tensor of token ids to
            logits whose last dimension indexes the vocabulary.
        tokenizer: Object exposing ``encode(text).ids``, ``decode(ids)`` and
            ``token_to_id(token)``.
        device: Device on which the input tensors are created.
        start_text: Prompt text.
        max_len: Maximum number of tokens to generate.
        temperature: Positive softmax temperature.
        max_context: Maximum number of trailing tokens fed to the model.
        min_prob: Tokens whose probability is below this value are excluded
            from sampling; the most likely token is always kept.

    Returns:
        The decoded text of the prompt tokens plus all generated tokens.

    Raises:
        ValueError: If ``temperature`` is not positive or ``max_context`` is
            smaller than 1.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if max_context < 1:
        raise ValueError("max_context must be at least 1")

    model.eval()

    start_text = clean_text(start_text)
    input_ids = tokenizer.encode(start_text).ids
    generated_ids = list(input_ids)

    # Echo the prompt without a newline so the continuation follows it.
    print(start_text, end="", flush=True)

    if not input_ids:
        # The model needs at least one position to predict from; indexing
        # the last position of an empty sequence would fail.
        return tokenizer.decode(generated_ids)

    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)

    # Looked up once; ``None`` when the vocabulary has no "<SEP>" token.
    sep_id = tokenizer.token_to_id("<SEP>")

    # Track how much of the *decoded* sequence has been accounted for. The
    # decoded prompt is not guaranteed to have the same length as the raw
    # prompt string, so measuring the raw string could misalign the slices.
    emitted_chars = len(tokenizer.decode(generated_ids))

    with torch.no_grad():
        for _ in range(max_len):
            # Keep only the most recent tokens as model input.
            if input_tensor.size(1) > max_context:
                input_tensor = input_tensor[:, -max_context:]

            logits = model(input_tensor)
            probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1)

            # Drop unlikely tokens, but always keep the most likely one so
            # the distribution can never become all zeros (which would make
            # the renormalisation divide by zero and the sampler fail).
            top_prob = probs.max(dim=-1, keepdim=True).values
            keep = (probs >= min_prob) | (probs >= top_prob)
            probs = torch.where(keep, probs, torch.zeros_like(probs))
            probs = probs / probs.sum(dim=-1, keepdim=True)

            next_token = torch.multinomial(probs, num_samples=1).item()

            # Stop when the separator token is sampled.
            if sep_id is not None and next_token == sep_id:
                break

            generated_ids.append(next_token)
            next_token_tensor = torch.tensor(
                [[next_token]], dtype=torch.long, device=device
            )
            input_tensor = torch.cat([input_tensor, next_token_tensor], dim=1)

            # Decode the whole sequence (so spacing between tokens is
            # handled by the tokenizer) and print only the new suffix.
            current_text = tokenizer.decode(generated_ids)
            print(current_text[emitted_chars:], end="", flush=True)
            emitted_chars = len(current_text)

    return tokenizer.decode(generated_ids)
121
+
122
+ # ==================================
123
+ # 模型加载和过滤函数
124
+ # ==================================
125
+
126
def load_model_with_filtering(model, model_path, device, target_layers):
    """Load checkpoint weights into ``model``, dropping surplus denoise layers.

    The checkpoint may either be a state dict itself or a mapping that holds
    one under the ``'model_state_dict'`` key. Entries whose key starts with
    ``denoise_layers`` are kept only when their layer index (the second
    dot-separated component of the key) is below ``target_layers``; all other
    entries are kept unchanged. The result is loaded non-strictly, and any
    parameters that end up missing or unexpected are reported on stdout.

    Args:
        model: Module that receives the weights.
        model_path: Path of the checkpoint file.
        device: ``map_location`` used when reading the checkpoint.
        target_layers: Number of denoise layers to keep.

    Returns:
        ``True`` if the weights were loaded, ``False`` if any error occurred
        (the error message is printed).
    """
    try:
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(model_path, map_location=device)

        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

        def _wanted(key):
            # Non-denoise parameters are always kept; denoise parameters
            # only when their layer index is inside the requested depth.
            if not key.startswith('denoise_layers'):
                return True
            return int(key.split('.')[1]) < target_layers

        filtered_state_dict = {
            key: value for key, value in state_dict.items() if _wanted(key)
        }

        # strict=False tolerates key mismatches, so surface them explicitly;
        # otherwise parameters could silently stay randomly initialised.
        result = model.load_state_dict(filtered_state_dict, strict=False)
        if result.missing_keys:
            print(f"警告: 检查点缺少以下参数(保持随机初始化): {list(result.missing_keys)}")
        if result.unexpected_keys:
            print(f"警告: 检查点包含模型未使用的参数: {list(result.unexpected_keys)}")

        print(f"加载模型成功: {model_path}")
        print(f"模型层数: {target_layers}")
        return True
    except Exception as e:
        # Top-level loading boundary: report the failure and let the caller
        # decide how to proceed instead of propagating the exception.
        print(f"模型加载失败: {str(e)}")
        return False
158
+
159
+ # ==================================
160
+ # 主推理函数(修复输出问题)
161
+ # ==================================
162
+
163
def main(model_path, tokenizer_path, model_size="mini"):
    """Run the interactive text-generation loop.

    Selects a device (CUDA, then MPS, then CPU), loads the tokenizer and the
    model weights, and then repeatedly reads a prompt from stdin and streams
    a generated continuation to stdout until the user enters ``quit`` or
    closes the input stream.

    Args:
        model_path: Path of the model checkpoint.
        tokenizer_path: Path of the tokenizer JSON file.
        model_size: ``"large"`` (16 layers), ``"mini"`` (12 layers) or
            ``"nano"`` (8 layers); any other value falls back to 12 layers.

    Returns:
        None. Returns early if the tokenizer or the model cannot be loaded.
    """
    # Pick the device; the MPS backend attribute is looked up defensively
    # because it does not exist on every torch build.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device_name = "cuda"
    elif mps_backend is not None and mps_backend.is_available():
        device_name = "mps"
    else:
        device_name = "cpu"
    device = torch.device(device_name)
    print(f"使用设备: {device}")

    # Load the tokenizer; report failures the same way model loading does.
    try:
        tokenizer = Tokenizer.from_file(tokenizer_path)
    except Exception as e:
        print(f"分词器加载失败: {str(e)}")
        return
    vocab_size = tokenizer.get_vocab_size()
    print(f"加载分词器成功,词汇表大小: {vocab_size}")

    # Map the model size name to its number of denoise layers.
    layers_by_size = {"large": 16, "mini": 12, "nano": 8}
    num_layers = layers_by_size.get(model_size)
    if num_layers is None:
        print(f"未知模型大小: {model_size}, 使用默认mini(12层)")
        num_layers = 12

    model_params = {
        "vocab_size": vocab_size,
        "embed_dim": 256,
        "hidden_dim": 512,
        "num_layers": num_layers,
    }
    model = StabilizedDenoisingModel(**model_params).to(device)

    if not load_model_with_filtering(model, model_path, device, num_layers):
        return

    print("\n===== GTC-2 Large mini Base Model Text Generator =====")
    print("输入文本后按回车生成,输入'quit'退出")

    while True:
        try:
            user_input = input("\n输入: ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt ends the session cleanly.
            print()
            break

        # A pasted virtual-environment activation command is not a prompt.
        if "activate" in user_input and "venv" in user_input:
            print("检测到虚拟环境激活命令,已忽略")
            continue

        if user_input.strip().lower() == 'quit':
            break

        # Prompts that are empty after cleaning (blank input, or input made
        # only of characters the cleaner removes) give the model nothing to
        # condition on, so ask again instead of generating.
        if not clean_text(user_input):
            print("输入为空或不包含可用字符,请重新输入")
            continue

        sys.stdout.flush()

        print("生成: ", end="", flush=True)
        stream_generate_text(
            model,
            tokenizer,
            device,
            user_input,
            max_len=100,
            temperature=0.8,
        )

        # Finish the streamed line and leave a blank line after it.
        print("\n")
228
if __name__ == "__main__":
    # Command-line entry point: collect the checkpoint path, tokenizer path
    # and model size, then start the interactive generator.
    cli = argparse.ArgumentParser(description='GTC-2 Base Model 文本生成器')
    cli.add_argument(
        '--model',
        type=str,
        default='gtc-2-large-mini.pth',
        help='模型文件路径 (默认: gtc-2-large-mini.pth)',
    )
    cli.add_argument(
        '--tokenizer',
        type=str,
        default='bpe_tokenizer.json',
        help='分词器文件路径 (默认: bpe_tokenizer.json)',
    )
    cli.add_argument(
        '--size',
        type=str,
        default='mini',
        choices=['large', 'mini', 'nano'],
        help='模型大小: large(16层), mini(12层), nano(8层) (默认: mini)',
    )

    options = cli.parse_args()
    main(options.model, options.tokenizer, options.size)