import torch
import torch.nn as nn
from tokenizers import Tokenizer
import re
import argparse
import sys


# ==================================
# Model definition
# ==================================
class StabilizedDenoisingModel(nn.Module):
    """Token-level denoising language model.

    Embeds a token sequence, projects it into a hidden space, then applies a
    stack of gated residual "denoising" layers before projecting back onto the
    vocabulary.  The forward pass returns per-position logits over the vocab.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.row_transform = nn.Linear(embed_dim, hidden_dim)
        self.dim_transform = nn.Linear(hidden_dim, hidden_dim)
        # Single LayerNorm instance shared by every residual step below.
        self.norm = nn.LayerNorm(hidden_dim)
        self.denoise_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim)
            ) for _ in range(num_layers)
        ])
        self.output_layer = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_seq):
        """Map token ids (batch, seq) to logits (batch, seq, vocab_size)."""
        embedded_seq = self.embedding(input_seq)
        hidden_space = self.row_transform(embedded_seq)
        hidden_space = self.dim_transform(hidden_space)
        hidden_space = self.norm(hidden_space)
        for denoise_layer in self.denoise_layers:
            signal = denoise_layer(hidden_space)
            # Sigmoid gate blends a subtractive "noise removal" term with a
            # rectified additive term, then re-normalizes the residual sum.
            gate = torch.sigmoid(signal)
            denoised = hidden_space - gate * signal + (1 - gate) * torch.relu(signal)
            hidden_space = self.norm(hidden_space + denoised)
        logits = self.output_layer(hidden_space)
        return logits


# ==================================
# Text processing helpers
# ==================================
def clean_text(text):
    """Lowercase, strip characters outside a basic ASCII set, collapse whitespace."""
    text = text.lower()
    text = re.sub(r'[^a-z0-9\s.,!?;:\'"-]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text


# ==================================
# Streaming text generation
# ==================================
def stream_generate_text(model, tokenizer, device, start_text, max_len=100, temperature=0.8):
    """Generate up to ``max_len`` tokens, streaming output to stdout as it goes.

    Args:
        model: the language model; called as ``model(input_tensor)``.
        tokenizer: a ``tokenizers.Tokenizer`` providing encode/decode.
        device: torch device the model lives on.
        start_text: user prompt; cleaned via :func:`clean_text` before encoding.
        max_len: maximum number of new tokens to sample.
        temperature: softmax temperature applied to the next-token logits.

    Returns:
        The full decoded text (prompt + generated continuation).
    """
    model.eval()

    # Clean and encode the prompt.
    start_text = clean_text(start_text)
    input_ids = tokenizer.encode(start_text).ids
    input_tensor = torch.tensor([input_ids], dtype=torch.long).to(device)
    generated_ids = input_ids.copy()

    # Track how much text has already been printed so that each step only
    # emits the newly-decoded suffix (re-decoding the whole sequence keeps
    # BPE spacing/merges correct).
    last_output_length = len(start_text)

    # Echo the prompt without a trailing newline.
    print(start_text, end="", flush=True)

    # NOTE(review): ``token_to_id("")`` looks like a placeholder for a real
    # end-of-sequence token (e.g. "[EOS]"); for an unknown token it returns
    # None, in which case no stop condition can ever trigger — confirm the
    # intended special token against the tokenizer config.  Hoisted out of
    # the loop since it is invariant.
    eos_id = tokenizer.token_to_id("")

    for _ in range(max_len):
        with torch.no_grad():
            # Cap the context window at the last 100 tokens.
            if input_tensor.size(1) > 100:
                input_tensor = input_tensor[:, -100:]

            # Predict the next-token distribution.
            logits = model(input_tensor)
            next_token_logits = logits[:, -1, :] / temperature
            probs = torch.softmax(next_token_logits, dim=-1)

            # Filter out low-probability tokens, then renormalize.  If the
            # filter removed all mass (every prob < 0.01), dividing by the
            # zero sum would produce NaNs and crash torch.multinomial, so
            # fall back to the unfiltered distribution in that case.
            probs[probs < 0.01] = 0
            total = probs.sum()
            if total.item() > 0:
                probs = probs / total
            else:
                probs = torch.softmax(next_token_logits, dim=-1)

            # Sample the next token.
            next_token = torch.multinomial(probs, num_samples=1).item()

            # Stop on the end-of-sequence token, if one is defined.
            if eos_id is not None and next_token == eos_id:
                break

            # Append the new token and extend the model input.
            generated_ids.append(next_token)
            next_token_tensor = torch.tensor([[next_token]], device=device, dtype=torch.long)
            input_tensor = torch.cat([input_tensor, next_token_tensor], dim=1)

            # Decode the whole sequence and print only the new suffix.
            current_text = tokenizer.decode(generated_ids)
            new_text = current_text[last_output_length:]
            last_output_length = len(current_text)
            print(new_text, end="", flush=True)

    # Return the complete generated text.
    return tokenizer.decode(generated_ids)


# ==================================
# Main inference entry point
# ==================================
def main(model_path, tokenizer_path):
    """Load tokenizer + model weights and run an interactive generation loop."""
    # Pick the best available device: CUDA, then Apple MPS, then CPU.
    device = torch.device(
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"使用设备: {device}")

    # Load the tokenizer.
    tokenizer = Tokenizer.from_file(tokenizer_path)
    vocab_size = tokenizer.get_vocab_size()
    print(f"加载分词器成功,词汇表大小: {vocab_size}")

    # Model hyperparameters — must match the values used at training time.
    model_params = {
        "vocab_size": vocab_size,
        "embed_dim": 256,
        "hidden_dim": 512,
        "num_layers": 16
    }

    model = StabilizedDenoisingModel(**model_params).to(device)

    # Load model weights.  SECURITY NOTE: torch.load deserializes via pickle;
    # only load checkpoint files from trusted sources.
    try:
        checkpoint = torch.load(model_path, map_location=device)
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)
        print(f"加载模型成功: {model_path}")
    except Exception as e:
        print(f"模型加载失败: {str(e)}")
        return

    # Interactive generation loop.
    print("\n===== GTC-2 Large Base Model Text Generator (Early Research Preview) =====")
    print("输入文本后按回车生成,输入'quit'退出")

    while True:
        user_input = input("\n输入: ")

        # Ignore accidentally-pasted virtualenv activation commands.
        if "activate" in user_input and "venv" in user_input:
            print("检测到虚拟环境激活命令,已忽略")
            continue

        if user_input.lower() == 'quit':
            break

        sys.stdout.flush()

        # Stream the generated text; output happens inside the generator.
        print("生成: ", end="", flush=True)
        stream_generate_text(
            model, tokenizer, device, user_input,
            max_len=100, temperature=0.8
        )
        print("\n")  # newline after generation finishes


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='GTC-2 Large Base Model 文本生成器')
    parser.add_argument('--model', type=str, default='best_model.pth',
                        help='模型文件路径 (默认: best_model.pth)')
    parser.add_argument('--tokenizer', type=str, default='bpe_tokenizer.json',
                        help='分词器文件路径 (默认: bpe_tokenizer.json)')
    args = parser.parse_args()
    main(args.model, args.tokenizer)