import torch
import sentencepiece as spm
from model_optimized import MemoryOptimizedBigramLM
# Device selection: prefer CUDA when available, else fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"使用设备: {device}")
# Load the SentencePiece tokenizer that was trained alongside the model.
sp = spm.SentencePieceProcessor()
sp.load("tokenizer.model")
vocab_size = sp.get_piece_size()
print(f"词汇表大小: {vocab_size}")
# Model hyperparameters (must match the values used during training,
# otherwise the checkpoint's state_dict will not line up).
d_model = 512        # embedding / hidden dimension
max_seq_len = 2048   # maximum context length
h = 8                # number of attention heads
Nx = 6               # number of transformer layers
dropout_rate = 0.2   # dropout probability (inactive once model.eval() is called)
# Instantiate the model with the same configuration as training.
model = MemoryOptimizedBigramLM(
    vocab_size=vocab_size,
    d_model=d_model,
    max_seq_len=max_seq_len,
    h=h,
    Nx=Nx,
    dropout_rate=dropout_rate
)
# Load the most recent trained checkpoint and report its training metrics.
try:
    # NOTE(review): weights_only=False unpickles arbitrary Python objects —
    # only load checkpoints from trusted sources.
    checkpoint = torch.load("saved_models/gpt_model_enhanced_stop_20251004_181034.pth", map_location=device, weights_only=False)
    # Filter out mask-related keys: they are registered buffers, not learned
    # parameters, so they are rebuilt by the model itself.
    state_dict = checkpoint['model_state_dict']
    filtered_state_dict = {k: v for k, v in state_dict.items() if 'mask' not in k}
    # strict=False tolerates the keys removed above.
    model.load_state_dict(filtered_state_dict, strict=False)
    print("✅ 成功加载最新训练模型权重")
    print(f"训练迭代次数: {checkpoint['iteration']}")
    print(f"最终训练损失: {checkpoint['train_losses'][-1]:.4f}")
    print(f"最终验证损失: {checkpoint['valid_losses'][-1]:.4f}")
    print(f"最终训练PPL: {checkpoint['train_ppls'][-1]:.2f}")
    print(f"最终验证PPL: {checkpoint['valid_ppls'][-1]:.2f}")
except Exception as e:
    # Broad catch is acceptable here: this is the script's startup boundary,
    # and any load failure is fatal.
    print(f" 加载模型失败: {e}")
    exit(1)
model = model.to(device)
model.eval()  # disable dropout for deterministic-style inference
def calculate_repetition_rate(text):
    """Return the fraction of adjacent word pairs in *text* that are identical.

    The text is split on whitespace; for each of the ``len(words) - 1``
    consecutive pairs, a pair counts as repeated when both words are equal.

    Args:
        text: Input string (may be empty).

    Returns:
        float in [0.0, 1.0]; 0.0 when the text has fewer than two words.
    """
    words = text.split()
    if len(words) < 2:
        # Fewer than two words -> no adjacent pairs to compare.
        return 0.0
    total_pairs = len(words) - 1
    # Idiomatic pairwise comparison instead of an index loop; the
    # total_pairs > 0 guard of the original is redundant here because the
    # early return above already guarantees at least one pair.
    repeated_count = sum(1 for a, b in zip(words, words[1:]) if a == b)
    return repeated_count / total_pairs
def test_output_optimized(prompt, max_new_tokens=300):
    """Generate text for *prompt* with tuned sampling parameters and score it.

    Uses the module-level ``sp`` (tokenizer), ``model``, and ``device``.
    Prints the full generation, the extracted response, and its adjacent-word
    repetition rate.

    Args:
        prompt: Input prompt string fed to the model.
        max_new_tokens: Number of tokens to sample beyond the prompt.

    Returns:
        (response, repetition_rate): the generated text with the prompt
        stripped, and its repetition rate as computed by
        ``calculate_repetition_rate``.
    """
    # Best parameter combination found during earlier tuning runs.
    temperature = 0.8
    top_k = 50
    repetition_penalty = 1.3
    print(f"\n{'='*80}")
    print(f"优化参数: temperature={temperature}, top_k={top_k}, repetition_penalty={repetition_penalty}")
    print(f"输入提示: {prompt}")
    print(f"{'='*80}")
    # Encode the prompt to token ids and move it onto the model's device.
    prompt_tokens = sp.encode(prompt, out_type=int)
    context = torch.tensor([prompt_tokens], dtype=torch.long, device=device)
    # Sample without tracking gradients (inference only).
    with torch.no_grad():
        generated_tokens = model.generate(
            context,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            repetition_penalty=repetition_penalty
        )[0].tolist()
    generated_text = sp.decode(generated_tokens)
    # Strip the prompt from the decoded output to isolate the response.
    # Bug fix: str.find returns -1 when the prompt is not present verbatim
    # (possible after the encode/decode round-trip), which previously made
    # the slice start at len(prompt) - 1 and silently corrupt the response.
    prompt_pos = generated_text.find(prompt)
    if prompt_pos == -1:
        response = generated_text.strip()
    else:
        response = generated_text[prompt_pos + len(prompt):].strip()
    # Score the response by its adjacent-word repetition rate.
    repetition_rate = calculate_repetition_rate(response)
    print(f"完整输出:")
    print(f"{generated_text}")
    print(f"\n提取的响应:")
    print(f"{response}")
    print(f"\n评估指标:")
    print(f" 输出长度: {len(response)} 字符")
    print(f" 重复率: {repetition_rate:.4f}")
    return response, repetition_rate
# Run the optimized-parameter generation test across several prompt themes.
print("开始使用优化参数测试模型输出...")
test_prompts = [
    "关键词: 信 天涯 晚风",
    "关键词: 风 雾 寂寞",
    "关键词: 贴心 改变 自信",
    "关键词: 午夜 寒冬 心动",
    "关键词: 思考 推理 分析",
    "关键词: 月光 思念 远方",
    "关键词: 梦想 坚持 成功",
    "关键词: 春天 希望 新生"
]
total_repetition_rate = 0
total_responses = len(test_prompts)
for i, prompt in enumerate(test_prompts, 1):
    print(f"\n🔬 测试 {i}/{total_responses}")
    response, repetition_rate = test_output_optimized(prompt)
    total_repetition_rate += repetition_rate
    # Grade each response by its adjacent-word repetition rate.
    if repetition_rate == 0.0:
        print(f"✅ 输出质量优秀 - 无重复")
    elif repetition_rate < 0.05:
        print(f"✅ 输出质量良好 - 轻微重复")
    elif repetition_rate < 0.1:
        print(f"⚠️ 输出质量一般 - 中等重复")
    else:
        print(f"❌ 输出质量较差 - 严重重复")
# Summarize: mean repetition rate across all prompts.
avg_repetition_rate = total_repetition_rate / total_responses
print(f"\n{'='*80}")
print("🎯 最终测试结果总结")
print(f"{'='*80}")
print(f"测试提示数量: {total_responses}")
print(f"平均重复率: {avg_repetition_rate:.4f}")
print(f"最佳参数组合: temperature=0.8, top_k=50, repetition_penalty=1.3")
print(f"生成长度: 300 tokens")
if avg_repetition_rate == 0.0:
    print(f"🎉 优化成功!所有输出均无重复")
elif avg_repetition_rate < 0.05:
    print(f"✅ 优化效果良好!平均重复率很低")
elif avg_repetition_rate < 0.1:
    print(f"⚠️ 优化效果一般,仍有改进空间")
else:
    print(f"❌ 需要进一步优化")
print(f"\n优化前问题: 大量重复词汇(如'兄弟'、'兄弟姐妹'等)")
print(f"优化后效果: 重复率显著降低,输出多样性提高")