own_gpt / test_final_optimized.py
AISkywalker's picture
Upload 21 files
8b57151 verified
import torch
import sentencepiece as spm
from model_optimized import MemoryOptimizedBigramLM
# 设备设置
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"使用设备: {device}")
# 加载tokenizer
sp = spm.SentencePieceProcessor()
sp.load("tokenizer.model")
vocab_size = sp.get_piece_size()
print(f"词汇表大小: {vocab_size}")
# 模型参数(与训练时保持一致)
d_model = 512
max_seq_len = 2048
h = 8
Nx = 6
dropout_rate = 0.2
# 创建模型
model = MemoryOptimizedBigramLM(
vocab_size=vocab_size,
d_model=d_model,
max_seq_len=max_seq_len,
h=h,
Nx=Nx,
dropout_rate=dropout_rate
)
# 加载最新的训练模型权重
try:
checkpoint = torch.load("saved_models/gpt_model_enhanced_stop_20251004_181034.pth", map_location=device, weights_only=False)
# 过滤掉mask相关的键,因为它们不是模型参数而是缓冲区
state_dict = checkpoint['model_state_dict']
filtered_state_dict = {k: v for k, v in state_dict.items() if 'mask' not in k}
model.load_state_dict(filtered_state_dict, strict=False)
print("✅ 成功加载最新训练模型权重")
print(f"训练迭代次数: {checkpoint['iteration']}")
print(f"最终训练损失: {checkpoint['train_losses'][-1]:.4f}")
print(f"最终验证损失: {checkpoint['valid_losses'][-1]:.4f}")
print(f"最终训练PPL: {checkpoint['train_ppls'][-1]:.2f}")
print(f"最终验证PPL: {checkpoint['valid_ppls'][-1]:.2f}")
except Exception as e:
print(f" 加载模型失败: {e}")
exit(1)
model = model.to(device)
model.eval()
def calculate_repetition_rate(text):
"""计算文本的重复率"""
words = text.split()
if len(words) < 2:
return 0.0
# 计算连续重复的比率
repeated_count = 0
total_pairs = len(words) - 1
for i in range(total_pairs):
if words[i] == words[i+1]:
repeated_count += 1
return repeated_count / total_pairs if total_pairs > 0 else 0.0
def test_output_optimized(prompt, max_new_tokens=300):
"""使用优化参数测试模型输出功能"""
# 最佳参数组合(根据测试结果)
temperature = 0.8
top_k = 50
repetition_penalty = 1.3
print(f"\n{'='*80}")
print(f"优化参数: temperature={temperature}, top_k={top_k}, repetition_penalty={repetition_penalty}")
print(f"输入提示: {prompt}")
print(f"{'='*80}")
# 编码提示文本
prompt_tokens = sp.encode(prompt, out_type=int)
# 转换为tensor
context = torch.tensor([prompt_tokens], dtype=torch.long, device=device)
# 生成响应
with torch.no_grad():
generated_tokens = model.generate(
context,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
repetition_penalty=repetition_penalty
)[0].tolist()
generated_text = sp.decode(generated_tokens)
# 提取生成的响应部分(去掉prompt)
response_start = generated_text.find(prompt) + len(prompt)
response = generated_text[response_start:].strip()
# 计算重复率
repetition_rate = calculate_repetition_rate(response)
print(f"完整输出:")
print(f"{generated_text}")
print(f"\n提取的响应:")
print(f"{response}")
print(f"\n评估指标:")
print(f" 输出长度: {len(response)} 字符")
print(f" 重复率: {repetition_rate:.4f}")
return response, repetition_rate
# 测试多个不同类型的提示
print("开始使用优化参数测试模型输出...")
test_prompts = [
"关键词: 信 天涯 晚风",
"关键词: 风 雾 寂寞",
"关键词: 贴心 改变 自信",
"关键词: 午夜 寒冬 心动",
"关键词: 思考 推理 分析",
"关键词: 月光 思念 远方",
"关键词: 梦想 坚持 成功",
"关键词: 春天 希望 新生"
]
total_repetition_rate = 0
total_responses = len(test_prompts)
for i, prompt in enumerate(test_prompts, 1):
print(f"\n🔬 测试 {i}/{total_responses}")
response, repetition_rate = test_output_optimized(prompt)
total_repetition_rate += repetition_rate
# 评估输出质量
if repetition_rate == 0.0:
print(f"✅ 输出质量优秀 - 无重复")
elif repetition_rate < 0.05:
print(f"✅ 输出质量良好 - 轻微重复")
elif repetition_rate < 0.1:
print(f"⚠️ 输出质量一般 - 中等重复")
else:
print(f"❌ 输出质量较差 - 严重重复")
# 计算平均重复率
avg_repetition_rate = total_repetition_rate / total_responses
print(f"\n{'='*80}")
print("🎯 最终测试结果总结")
print(f"{'='*80}")
print(f"测试提示数量: {total_responses}")
print(f"平均重复率: {avg_repetition_rate:.4f}")
print(f"最佳参数组合: temperature=0.8, top_k=50, repetition_penalty=1.3")
print(f"生成长度: 300 tokens")
if avg_repetition_rate == 0.0:
print(f"🎉 优化成功!所有输出均无重复")
elif avg_repetition_rate < 0.05:
print(f"✅ 优化效果良好!平均重复率很低")
elif avg_repetition_rate < 0.1:
print(f"⚠️ 优化效果一般,仍有改进空间")
else:
print(f"❌ 需要进一步优化")
print(f"\n优化前问题: 大量重复词汇(如'兄弟'、'兄弟姐妹'等)")
print(f"优化后效果: 重复率显著降低,输出多样性提高")