clarenceleo committed on
Commit
aead756
·
verified ·
1 Parent(s): 7568bc5

Upload 3 files

Browse files
Files changed (3) hide show
  1. best_model.pth +3 -0
  2. bpe_tokenizer.json +0 -0
  3. inference.py +198 -0
best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d1bdc1af98e559d9a99c1780e0e31ce136a2c4f9fa142d2726ae9b61a6c8b91
3
+ size 189026677
bpe_tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
inference.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from tokenizers import Tokenizer
4
+ import re
5
+ import argparse
6
+ import sys
7
+
8
# ==================================
# Model definition
# ==================================
11
+
12
class StabilizedDenoisingModel(nn.Module):
    """Token-wise denoising network.

    Pipeline: embedding -> linear projections -> LayerNorm -> a stack of
    gated residual "denoise" blocks -> linear projection to vocab logits.
    Note there is no attention/recurrence: each position is processed
    independently of the others.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers):
        super().__init__()
        # Submodule creation order is kept identical to the original so
        # saved checkpoints (state-dict keys) still load unchanged.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.row_transform = nn.Linear(embed_dim, hidden_dim)
        self.dim_transform = nn.Linear(hidden_dim, hidden_dim)
        self.norm = nn.LayerNorm(hidden_dim)

        def _denoise_block():
            # One hidden->hidden MLP used as a denoising signal generator.
            return nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )

        self.denoise_layers = nn.ModuleList(
            [_denoise_block() for _ in range(num_layers)]
        )

        self.output_layer = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_seq):
        """Map (batch, seq) token ids to (batch, seq, vocab_size) logits."""
        h = self.embedding(input_seq)
        h = self.dim_transform(self.row_transform(h))
        h = self.norm(h)

        for layer in self.denoise_layers:
            signal = layer(h)
            gate = torch.sigmoid(signal)
            # Gated mix: subtract the gated signal, add back a rectified
            # complement, then renormalize through the shared LayerNorm.
            denoised = h - gate * signal + (1 - gate) * torch.relu(signal)
            h = self.norm(h + denoised)

        return self.output_layer(h)
45
+
46
# ==================================
# Text processing helpers
# ==================================
49
+
50
def clean_text(text):
    """Normalize raw input text.

    Lowercases, strips every character outside [a-z0-9], whitespace and
    basic punctuation, then collapses runs of whitespace to single spaces.
    """
    lowered = text.lower()
    filtered = re.sub(r'[^a-z0-9\s.,!?;:\'"-]', '', lowered)
    return re.sub(r'\s+', ' ', filtered).strip()
56
+
57
# ==================================
# Streaming text generation (fixes incremental-output issue)
# ==================================
60
+
61
def stream_generate_text(model, tokenizer, device, start_text, max_len=100, temperature=0.8):
    """Generate text token by token, streaming each increment to stdout.

    Args:
        model: module mapping a (1, seq) LongTensor of token ids to
            (1, seq, vocab) logits.
        tokenizer: tokenizer exposing encode(...).ids, decode(ids) and
            token_to_id(name).
        device: torch device the input tensors are moved to.
        start_text: prompt text; normalized via clean_text() before encoding.
        max_len: maximum number of new tokens to generate.
        temperature: softmax temperature (> 0); lower values are greedier.

    Returns:
        The full decoded text (prompt plus generated tokens).
    """
    model.eval()

    # Normalize the prompt before encoding.
    start_text = clean_text(start_text)

    # Encode the prompt.
    input_ids = tokenizer.encode(start_text).ids
    input_tensor = torch.tensor([input_ids], dtype=torch.long).to(device)

    generated_ids = input_ids.copy()

    # Length of text already printed, so only the new suffix is emitted.
    last_output_length = len(start_text)

    # Echo the prompt without a trailing newline.
    print(start_text, end="", flush=True)

    # Hoist the stop-token lookup out of the loop; may be None if the
    # tokenizer has no <SEP> token (in which case generation never stops early,
    # matching the original `next_token == None` comparison).
    sep_id = tokenizer.token_to_id("<SEP>")

    for _ in range(max_len):
        with torch.no_grad():
            # Cap the context window at the last 100 tokens.
            if input_tensor.size(1) > 100:
                input_tensor = input_tensor[:, -100:]

            # Predict the next-token distribution.
            logits = model(input_tensor)
            next_token_logits = logits[:, -1, :] / temperature
            probs = torch.softmax(next_token_logits, dim=-1)

            # Drop low-probability tokens, then renormalize.
            probs[probs < 0.01] = 0
            total = probs.sum()
            if total > 0:
                probs = probs / total
                next_token = torch.multinomial(probs, num_samples=1).item()
            else:
                # BUGFIX: if every token fell below the 0.01 cutoff (e.g. a
                # near-uniform distribution over a large vocabulary), the old
                # code divided by zero and fed NaNs to torch.multinomial,
                # crashing generation. Fall back to greedy argmax instead.
                next_token = next_token_logits.argmax(dim=-1).item()

            # Stop on the separator token, if the tokenizer defines one.
            if sep_id is not None and next_token == sep_id:
                break

            # Append the new token and extend the model input.
            generated_ids.append(next_token)
            next_token_tensor = torch.tensor([[next_token]], device=device, dtype=torch.long)
            input_tensor = torch.cat([input_tensor, next_token_tensor], dim=1)

            # Re-decode the full sequence (keeps spacing correct) and print
            # only the newly produced suffix.
            current_text = tokenizer.decode(generated_ids)
            new_text = current_text[last_output_length:]
            last_output_length = len(current_text)
            print(new_text, end="", flush=True)

    # Return the complete generated text.
    return tokenizer.decode(generated_ids)
119
+
120
# ==================================
# Main inference entry point (fixes incremental-output issue)
# ==================================
123
+
124
def main(model_path, tokenizer_path):
    """Load the tokenizer and model weights, then run an interactive
    generation loop until the user types 'quit'."""
    # Pick the best available device: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"使用设备: {device}")

    # Load the tokenizer.
    tokenizer = Tokenizer.from_file(tokenizer_path)
    vocab_size = tokenizer.get_vocab_size()
    print(f"加载分词器成功,词汇表大小: {vocab_size}")

    # Architecture hyperparameters; must match the training configuration.
    model = StabilizedDenoisingModel(
        vocab_size=vocab_size,
        embed_dim=256,
        hidden_dim=512,
        num_layers=16,
    ).to(device)

    # Load the checkpoint, accepting either a raw state dict or a dict
    # wrapping it under 'model_state_dict'.
    try:
        checkpoint = torch.load(model_path, map_location=device)
        state = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
        model.load_state_dict(state)
        print(f"加载模型成功: {model_path}")
    except Exception as e:
        print(f"模型加载失败: {str(e)}")
        return

    # Interactive generation loop.
    print("\n===== GTC-2 Large Base Model Text Generator (Early Research Preview) =====")
    print("输入文本后按回车生成,输入'quit'退出")

    while True:
        user_input = input("\n输入: ")

        # Ignore pasted shell commands that activate a virtualenv.
        if "activate" in user_input and "venv" in user_input:
            print("检测到虚拟环境激活命令,已忽略")
            continue

        if user_input.lower() == 'quit':
            break

        # Flush any pending output before streaming.
        sys.stdout.flush()

        # Stream the generated text token by token.
        print("生成: ", end="", flush=True)
        stream_generate_text(
            model,
            tokenizer,
            device,
            user_input,
            max_len=100,
            temperature=0.8,
        )

        print("\n")  # blank line after each generation
187
+
188
+ if __name__ == "__main__":
189
+ # 设置命令行参数
190
+ parser = argparse.ArgumentParser(description='GTC-2 Large Base Model 文本生成器')
191
+ parser.add_argument('--model', type=str, default='best_model.pth',
192
+ help='模型文件路径 (默认: best_model.pth)')
193
+ parser.add_argument('--tokenizer', type=str, default='bpe_tokenizer.json',
194
+ help='分词器文件路径 (默认: bpe_tokenizer.json)')
195
+
196
+ args = parser.parse_args()
197
+
198
+ main(args.model, args.tokenizer)