# ESFT-GTC-2 / inference.py
# (upload residue: clarenceleo, "Upload 4 files", commit 6d5fd7f verified)
import torch
import torch.nn as nn
from tokenizers import Tokenizer
import re
import argparse
import sys
# ==================================
# Model definition
# ==================================
class StabilizedDenoisingModel(nn.Module):
def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers):
super(StabilizedDenoisingModel, self).__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.row_transform = nn.Linear(embed_dim, hidden_dim)
self.dim_transform = nn.Linear(hidden_dim, hidden_dim)
self.norm = nn.LayerNorm(hidden_dim)
self.denoise_layers = nn.ModuleList([
nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim)
)
for _ in range(num_layers)
])
self.output_layer = nn.Linear(hidden_dim, vocab_size)
def forward(self, input_seq):
embedded_seq = self.embedding(input_seq)
hidden_space = self.row_transform(embedded_seq)
hidden_space = self.dim_transform(hidden_space)
hidden_space = self.norm(hidden_space)
for denoise_layer in self.denoise_layers:
signal = denoise_layer(hidden_space)
gate = torch.sigmoid(signal)
denoised = hidden_space - gate * signal + (1 - gate) * torch.relu(signal)
hidden_space = self.norm(hidden_space + denoised)
logits = self.output_layer(hidden_space)
return logits
# ==================================
# Text processing helpers
# ==================================
def clean_text(text):
    """Normalize raw input text.

    Lowercases, strips every character outside ascii letters/digits,
    whitespace and basic punctuation, then collapses runs of whitespace
    into single spaces and trims the ends.
    """
    lowered = text.lower()
    # Keep only the character classes the tokenizer was trained on.
    filtered = re.sub(r'[^a-z0-9\s.,!?;:\'"-]', '', lowered)
    return re.sub(r'\s+', ' ', filtered).strip()
# ==================================
# Streaming text generation (output fix)
# ==================================
def stream_generate_text(model, tokenizer, device, start_text, max_len=100, temperature=0.8):
    """Autoregressively generate text, streaming each new chunk to stdout.

    Args:
        model: language model mapping (batch, seq) token ids to logits.
        tokenizer: tokenizers.Tokenizer-like object (encode/decode/token_to_id).
        device: torch device the model lives on.
        start_text: prompt text; normalized with clean_text() before encoding.
        max_len: maximum number of tokens to generate.
        temperature: softmax temperature (> 0); lower values are greedier.

    Returns:
        The full decoded text (prompt plus generated continuation).
    """
    model.eval()
    # Normalize the prompt the same way the training data was cleaned.
    start_text = clean_text(start_text)
    input_ids = tokenizer.encode(start_text).ids
    input_tensor = torch.tensor([input_ids], dtype=torch.long).to(device)
    generated_ids = input_ids.copy()
    # Track how much decoded text has already been printed so that only the
    # newly generated suffix is emitted on each step.
    last_output_length = len(start_text)
    # Hoist the end-of-sequence lookup out of the loop. May be None if the
    # tokenizer has no <SEP> token; int tokens never equal None, so
    # generation then simply runs to max_len.
    sep_id = tokenizer.token_to_id("<SEP>")
    for _ in range(max_len):
        with torch.no_grad():
            # Keep a sliding context window of the most recent 100 tokens.
            if input_tensor.size(1) > 100:
                input_tensor = input_tensor[:, -100:]
            logits = model(input_tensor)
            next_token_logits = logits[:, -1, :] / temperature
            probs = torch.softmax(next_token_logits, dim=-1)
            # Drop low-probability tokens, but fall back to the unfiltered
            # distribution when *everything* falls below the threshold —
            # otherwise the renormalization divides by zero and produces
            # NaNs, which crashes torch.multinomial.
            filtered = probs.clone()
            filtered[filtered < 0.01] = 0
            total = filtered.sum()
            if total.item() > 0:
                probs = filtered / total
            next_token = torch.multinomial(probs, num_samples=1).item()
            # Stop on the end-of-sequence marker.
            if next_token == sep_id:
                break
            generated_ids.append(next_token)
            next_token_tensor = torch.tensor([[next_token]], device=device, dtype=torch.long)
            input_tensor = torch.cat([input_tensor, next_token_tensor], dim=1)
            # Decode the whole sequence so subword-merge spacing stays
            # correct, then print only the newly produced suffix.
            current_text = tokenizer.decode(generated_ids)
            new_text = current_text[last_output_length:]
            last_output_length = len(current_text)
            print(new_text, end="", flush=True)
    return tokenizer.decode(generated_ids)
# ==================================
# Main inference entry point (output fix)
# ==================================
def main(model_path, tokenizer_path):
    """Load the tokenizer and model weights, then run an interactive generation loop."""
    # Pick the best available device: CUDA, then Apple MPS, then CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
    print(f"使用设备: {device}")
    # Load the trained BPE tokenizer from disk.
    tokenizer = Tokenizer.from_file(tokenizer_path)
    vocab_size = tokenizer.get_vocab_size()
    print(f"加载分词器成功,词汇表大小: {vocab_size}")
    # Model hyperparameters — these must match the values used at training
    # time, otherwise load_state_dict below fails on shape mismatch.
    model_params = {
        "vocab_size": vocab_size,
        "embed_dim": 256,   # must match training config
        "hidden_dim": 512,  # must match training config
        "num_layers": 16    # must match training config
    }
    # Instantiate the model and move it to the selected device.
    model = StabilizedDenoisingModel(**model_params).to(device)
    # Load weights; supports both full checkpoints ({'model_state_dict': ...})
    # and bare state dicts. Any failure aborts startup with a message.
    try:
        checkpoint = torch.load(model_path, map_location=device)
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)
        print(f"加载模型成功: {model_path}")
    except Exception as e:
        print(f"模型加载失败: {str(e)}")
        return
    # Interactive REPL: read a prompt, stream the generated continuation.
    print("\n===== ESFT GTC-2 Chatbot =====")
    print("输入文本后按回车生成,输入'quit'退出")
    while True:
        user_input = input("\n输入: ")
        # Ignore pasted shell commands that activate a virtualenv.
        if "activate" in user_input and "venv" in user_input:
            print("检测到虚拟环境激活命令,已忽略")
            continue  # skip this input
        if user_input.lower() == 'quit':
            break
        # Flush any pending output before streaming begins.
        sys.stdout.flush()
        # Stream-generate the continuation token by token.
        print("生成: ", end="", flush=True)
        generated_text = stream_generate_text(
            model,
            tokenizer,
            device,
            user_input,
            max_len=100,
            temperature=0.8
        )
        print("\n")  # newline after streaming finishes
if __name__ == "__main__":
# 设置命令行参数
parser = argparse.ArgumentParser(description='ESFT GTC-2 Chatbot')
parser.add_argument('--model', type=str, default='sft_best_model.pth',
help='模型文件路径 (默认: best_model.pth)')
parser.add_argument('--tokenizer', type=str, default='bpe_tokenizer.json',
help='分词器文件路径 (默认: bpe_tokenizer.json)')
args = parser.parse_args()
main(args.model, args.tokenizer)