clarenceleo's picture
Upload 3 files
e572d32 verified
import torch
import torch.nn as nn
from tokenizers import Tokenizer
import re
import argparse
import sys
import os
# ==================================
# 模型定义
# ==================================
class StabilizedDenoisingModel(nn.Module):
def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers):
super(StabilizedDenoisingModel, self).__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.row_transform = nn.Linear(embed_dim, hidden_dim)
self.dim_transform = nn.Linear(hidden_dim, hidden_dim)
self.norm = nn.LayerNorm(hidden_dim)
self.denoise_layers = nn.ModuleList([
nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim)
)
for _ in range(num_layers)
])
self.output_layer = nn.Linear(hidden_dim, vocab_size)
self.num_layers = num_layers
def forward(self, input_seq):
embedded_seq = self.embedding(input_seq)
hidden_space = self.row_transform(embedded_seq)
hidden_space = self.dim_transform(hidden_space)
hidden_space = self.norm(hidden_space)
for denoise_layer in self.denoise_layers:
signal = denoise_layer(hidden_space)
gate = torch.sigmoid(signal)
denoised = hidden_space - gate * signal + (1 - gate) * torch.relu(signal)
hidden_space = self.norm(hidden_space + denoised)
logits = self.output_layer(hidden_space)
return logits
# ==================================
# 文本处理函数
# ==================================
def clean_text(text):
"""清洗输入文本"""
text = text.lower()
text = re.sub(r'[^a-z0-9\s.,!?;:\'"-]', '', text)
text = re.sub(r'\s+', ' ', text).strip()
return text
# ==================================
# 流式文本生成函数(修复输出问题)
# ==================================
def stream_generate_text(model, tokenizer, device, start_text, max_len=100, temperature=0.8):
"""流式生成文本,逐个token输出(修复输出问题)"""
model.eval()
# 清洗输入文本
start_text = clean_text(start_text)
# 编码输入文本
input_ids = tokenizer.encode(start_text).ids
input_tensor = torch.tensor([input_ids], dtype=torch.long).to(device)
generated_ids = input_ids.copy()
# 记录上一次输出的文本长度
last_output_length = len(start_text)
# 输出初始文本(不换行)
print(start_text, end="", flush=True)
for i in range(max_len):
with torch.no_grad():
# 限制输入长度
if input_tensor.size(1) > 100:
input_tensor = input_tensor[:, -100:]
# 预测下一个token
logits = model(input_tensor)
next_token_logits = logits[:, -1, :] / temperature
probs = torch.softmax(next_token_logits, dim=-1)
# 过滤低概率token
probs[probs < 0.01] = 0
probs = probs / probs.sum()
# 采样下一个token
next_token = torch.multinomial(probs, num_samples=1).item()
# 如果生成了终止标记,停止生成
if next_token == tokenizer.token_to_id("<SEP>"):
break
# 添加新token并更新输入
generated_ids.append(next_token)
next_token_tensor = torch.tensor([[next_token]], device=device, dtype=torch.long)
input_tensor = torch.cat([input_tensor, next_token_tensor], dim=1)
# 解码整个序列(确保空格正确)
current_text = tokenizer.decode(generated_ids)
# 只输出新增的部分
new_text = current_text[last_output_length:]
last_output_length = len(current_text)
# 输出新文本
print(new_text, end="", flush=True)
# 返回完整生成的文本
return tokenizer.decode(generated_ids)
# ==================================
# 模型加载和过滤函数
# ==================================
def load_model_with_filtering(model, model_path, device, target_layers):
"""加载模型权重并过滤掉不需要的层"""
try:
checkpoint = torch.load(model_path, map_location=device)
if 'model_state_dict' in checkpoint:
state_dict = checkpoint['model_state_dict']
else:
state_dict = checkpoint
# 过滤状态字典,只保留目标层数
filtered_state_dict = {}
for key, value in state_dict.items():
# 检查是否是denoise层的参数
if key.startswith('denoise_layers'):
# 提取层号
layer_num = int(key.split('.')[1])
# 只保留目标层数范围内的参数
if layer_num < target_layers:
filtered_state_dict[key] = value
else:
# 保留所有其他参数
filtered_state_dict[key] = value
# 加载过滤后的状态字典
model.load_state_dict(filtered_state_dict, strict=False)
print(f"加载模型成功: {model_path}")
print(f"模型层数: {target_layers}")
return True
except Exception as e:
print(f"模型加载失败: {str(e)}")
return False
# ==================================
# 主推理函数(修复输出问题)
# ==================================
def main(model_path, tokenizer_path, model_size="mini"):
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"使用设备: {device}")
# 加载分词器
tokenizer = Tokenizer.from_file(tokenizer_path)
vocab_size = tokenizer.get_vocab_size()
print(f"加载分词器成功,词汇表大小: {vocab_size}")
# 根据模型大小设置层数
if model_size == "large":
num_layers = 16
elif model_size == "mini":
num_layers = 12
elif model_size == "nano":
num_layers = 8
else:
print(f"未知模型大小: {model_size}, 使用默认mini(12层)")
num_layers = 12
# 解析模型参数
model_params = {
"vocab_size": vocab_size,
"embed_dim": 256, # 与训练参数一致
"hidden_dim": 512, # 与训练参数一致
"num_layers": num_layers # 动态设置层数
}
# 初始化模型
model = StabilizedDenoisingModel(**model_params).to(device)
# 加载模型权重
if not load_model_with_filtering(model, model_path, device, num_layers):
return
# 交互式生成
print(f"\n===== GTC-2 Large mini Base Model Text Generator =====")
print("输入文本后按回车生成,输入'quit'退出")
while True:
user_input = input("\n输入: ")
if "activate" in user_input and "venv" in user_input:
print("检测到虚拟环境激活命令,已忽略")
continue # 跳过这次输入
if user_input.lower() == 'quit':
break
# 清空缓冲区
sys.stdout.flush()
# 流式生成文本
print("生成: ", end="", flush=True)
generated_text = stream_generate_text(
model,
tokenizer,
device,
user_input,
max_len=100,
temperature=0.8
)
print("\n") # 生成结束后换行
if __name__ == "__main__":
# 设置命令行参数
parser = argparse.ArgumentParser(description='GTC-2 Base Model 文本生成器')
parser.add_argument('--model', type=str, default='gtc-2-large-mini.pth',
help='模型文件路径 (默认: gtc-2-large-mini.pth)')
parser.add_argument('--tokenizer', type=str, default='bpe_tokenizer.json',
help='分词器文件路径 (默认: bpe_tokenizer.json)')
parser.add_argument('--size', type=str, default='mini', choices=['large', 'mini', 'nano'],
help='模型大小: large(16层), mini(12层), nano(8层) (默认: mini)')
args = parser.parse_args()
main(args.model, args.tokenizer, args.size)