|
|
import torch
|
|
|
import sentencepiece as spm
|
|
|
from model_optimized import MemoryOptimizedBigramLM
|
|
|
|
|
|
|
|
|
# Select compute device: prefer CUDA when a GPU is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(f"使用设备: {device}")

# Load the SentencePiece tokenizer.
# NOTE(review): assumes "tokenizer.model" exists in the current working
# directory and matches the checkpoint loaded below — confirm.
sp = spm.SentencePieceProcessor()

sp.load("tokenizer.model")

# Vocabulary size is fed into the model constructor below.
vocab_size = sp.get_piece_size()

print(f"词汇表大小: {vocab_size}")
|
|
|
|
|
|
|
|
|
# Model hyperparameters — must match those used when the checkpoint below
# was trained, or the state_dict load will fail / silently mismatch.
d_model = 512        # embedding / hidden width
max_seq_len = 2048   # maximum context length in tokens
h = 8                # presumably number of attention heads — TODO confirm in model_optimized
Nx = 6               # presumably number of transformer blocks — TODO confirm in model_optimized
dropout_rate = 0.2   # dropout probability (inactive once model.eval() is called)
|
|
|
|
|
|
|
|
|
# Build the model skeleton; trained weights are loaded from a checkpoint below.
model = MemoryOptimizedBigramLM(
    vocab_size=vocab_size,
    d_model=d_model,
    max_seq_len=max_seq_len,
    h=h,
    Nx=Nx,
    dropout_rate=dropout_rate
)
|
|
|
|
|
|
|
|
|
try:
    # NOTE(review): weights_only=False deserializes arbitrary pickled
    # objects — only safe because this checkpoint is a locally produced
    # training artifact, never untrusted input.
    checkpoint = torch.load("saved_models/gpt_model_enhanced_stop_20251004_181034.pth", map_location=device, weights_only=False)

    state_dict = checkpoint['model_state_dict']
    # Drop any keys containing 'mask' — presumably registered attention-mask
    # buffers that the model rebuilds itself; TODO confirm in model_optimized.
    filtered_state_dict = {k: v for k, v in state_dict.items() if 'mask' not in k}

    # strict=False tolerates the mask keys filtered out above.
    model.load_state_dict(filtered_state_dict, strict=False)
    print("✅ 成功加载最新训练模型权重")
    print(f"训练迭代次数: {checkpoint['iteration']}")
    print(f"最终训练损失: {checkpoint['train_losses'][-1]:.4f}")
    print(f"最终验证损失: {checkpoint['valid_losses'][-1]:.4f}")
    print(f"最终训练PPL: {checkpoint['train_ppls'][-1]:.2f}")
    print(f"最终验证PPL: {checkpoint['valid_ppls'][-1]:.2f}")
except Exception as e:
    # Any load failure (missing file, bad/missing keys) aborts the script.
    print(f" 加载模型失败: {e}")
    exit(1)

model = model.to(device)
model.eval()  # inference mode: disables dropout
|
|
|
|
|
|
def calculate_repetition_rate(text):
    """Return the fraction of adjacent word pairs in *text* that are identical.

    The text is split on whitespace; each position where a word equals its
    immediate successor counts as one repetition.  Texts with fewer than two
    words have no pairs and score 0.0.

    Args:
        text: input string (may be empty).

    Returns:
        float in [0.0, 1.0]: repeated adjacent pairs / total adjacent pairs.
    """
    words = text.split()
    if len(words) < 2:
        return 0.0
    # zip(words, words[1:]) yields every adjacent pair; sum() counts the
    # identical ones.  After the early return, len(words) - 1 >= 1, so the
    # original trailing "if total_pairs > 0" guard was redundant and is gone.
    repeated_count = sum(a == b for a, b in zip(words, words[1:]))
    return repeated_count / (len(words) - 1)
|
|
|
|
|
|
def test_output_optimized(prompt, max_new_tokens=300):
    """Generate a continuation of *prompt* with tuned sampling parameters.

    Uses the module-level ``model``, ``sp`` tokenizer and ``device``.
    Prints the full generation, the extracted response, and quality metrics.

    Args:
        prompt: input text to condition the model on.
        max_new_tokens: number of tokens to sample beyond the prompt.

    Returns:
        (response, repetition_rate): the generated continuation (with the
        prompt stripped when it can be located in the decoded output) and
        its adjacent-word repetition rate.
    """
    # Sampling parameters empirically chosen to minimize repetition.
    temperature = 0.8
    top_k = 50
    repetition_penalty = 1.3

    print(f"\n{'='*80}")
    print(f"优化参数: temperature={temperature}, top_k={top_k}, repetition_penalty={repetition_penalty}")
    print(f"输入提示: {prompt}")
    print(f"{'='*80}")

    prompt_tokens = sp.encode(prompt, out_type=int)

    # Shape (1, len(prompt_tokens)): a batch of one sequence.
    context = torch.tensor([prompt_tokens], dtype=torch.long, device=device)

    with torch.no_grad():
        generated_tokens = model.generate(
            context,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            repetition_penalty=repetition_penalty
        )[0].tolist()

    generated_text = sp.decode(generated_tokens)

    # BUG FIX: str.find returns -1 when the decoded text does not contain
    # the prompt verbatim (SentencePiece decode need not round-trip
    # whitespace exactly).  The old code then sliced from -1 + len(prompt),
    # cutting the response at an arbitrary offset.  Fall back to the full
    # decoded text when the prompt cannot be located.
    prompt_pos = generated_text.find(prompt)
    if prompt_pos != -1:
        response = generated_text[prompt_pos + len(prompt):].strip()
    else:
        response = generated_text.strip()

    repetition_rate = calculate_repetition_rate(response)

    print(f"完整输出:")
    print(f"{generated_text}")
    print(f"\n提取的响应:")
    print(f"{response}")
    print(f"\n评估指标:")
    print(f"  输出长度: {len(response)} 字符")
    print(f"  重复率: {repetition_rate:.4f}")

    return response, repetition_rate
|
|
|
|
|
|
|
|
|
print("开始使用优化参数测试模型输出...")

# Keyword prompts covering varied themes; each is generated from once.
test_prompts = [
    "关键词: 信 天涯 晚风",
    "关键词: 风 雾 寂寞",
    "关键词: 贴心 改变 自信",
    "关键词: 午夜 寒冬 心动",
    "关键词: 思考 推理 分析",
    "关键词: 月光 思念 远方",
    "关键词: 梦想 坚持 成功",
    "关键词: 春天 希望 新生"
]

# Running sum of per-prompt repetition rates, averaged at the end.
total_repetition_rate = 0
total_responses = len(test_prompts)

for i, prompt in enumerate(test_prompts, 1):
    print(f"\n🔬 测试 {i}/{total_responses}")
    response, repetition_rate = test_output_optimized(prompt)
    total_repetition_rate += repetition_rate

    # Grade each sample by its adjacent-word repetition rate.
    if repetition_rate == 0.0:
        print(f"✅ 输出质量优秀 - 无重复")
    elif repetition_rate < 0.05:
        print(f"✅ 输出质量良好 - 轻微重复")
    elif repetition_rate < 0.1:
        print(f"⚠️ 输出质量一般 - 中等重复")
    else:
        print(f"❌ 输出质量较差 - 严重重复")

# Aggregate quality over all prompts.
avg_repetition_rate = total_repetition_rate / total_responses

print(f"\n{'='*80}")
print("🎯 最终测试结果总结")
print(f"{'='*80}")
print(f"测试提示数量: {total_responses}")
print(f"平均重复率: {avg_repetition_rate:.4f}")
print(f"最佳参数组合: temperature=0.8, top_k=50, repetition_penalty=1.3")
print(f"生成长度: 300 tokens")

if avg_repetition_rate == 0.0:
    print(f"🎉 优化成功!所有输出均无重复")
elif avg_repetition_rate < 0.05:
    print(f"✅ 优化效果良好!平均重复率很低")
elif avg_repetition_rate < 0.1:
    print(f"⚠️ 优化效果一般,仍有改进空间")
else:
    print(f"❌ 需要进一步优化")

print(f"\n优化前问题: 大量重复词汇(如'兄弟'、'兄弟姐妹'等)")
print(f"优化后效果: 重复率显著降低,输出多样性提高")