File size: 5,669 Bytes
8b57151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import torch
import sentencepiece as spm
from model_optimized import MemoryOptimizedBigramLM

# 设备设置
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"使用设备: {device}")

# 加载tokenizer
sp = spm.SentencePieceProcessor()
sp.load("tokenizer.model")
vocab_size = sp.get_piece_size()
print(f"词汇表大小: {vocab_size}")

# 模型参数(与训练时保持一致)
d_model = 512
max_seq_len = 2048
h = 8
Nx = 6
dropout_rate = 0.2

# 创建模型
model = MemoryOptimizedBigramLM(
    vocab_size=vocab_size,
    d_model=d_model,
    max_seq_len=max_seq_len,
    h=h,
    Nx=Nx,
    dropout_rate=dropout_rate
)

# 加载最新的训练模型权重
try:
    checkpoint = torch.load("saved_models/gpt_model_enhanced_stop_20251004_181034.pth", map_location=device, weights_only=False)
    
    # 过滤掉mask相关的键,因为它们不是模型参数而是缓冲区
    state_dict = checkpoint['model_state_dict']
    filtered_state_dict = {k: v for k, v in state_dict.items() if 'mask' not in k}
    
    model.load_state_dict(filtered_state_dict, strict=False)
    print("✅ 成功加载最新训练模型权重")
    print(f"训练迭代次数: {checkpoint['iteration']}")
    print(f"最终训练损失: {checkpoint['train_losses'][-1]:.4f}")
    print(f"最终验证损失: {checkpoint['valid_losses'][-1]:.4f}")
    print(f"最终训练PPL: {checkpoint['train_ppls'][-1]:.2f}")
    print(f"最终验证PPL: {checkpoint['valid_ppls'][-1]:.2f}")
except Exception as e:
    print(f" 加载模型失败: {e}")
    exit(1)

model = model.to(device)
model.eval()

def calculate_repetition_rate(text):
    """计算文本的重复率"""
    words = text.split()
    if len(words) < 2:
        return 0.0
    
    # 计算连续重复的比率
    repeated_count = 0
    total_pairs = len(words) - 1
    
    for i in range(total_pairs):
        if words[i] == words[i+1]:
            repeated_count += 1
    
    return repeated_count / total_pairs if total_pairs > 0 else 0.0

def test_output_optimized(prompt, max_new_tokens=300):
    """使用优化参数测试模型输出功能"""
    # 最佳参数组合(根据测试结果)
    temperature = 0.8
    top_k = 50
    repetition_penalty = 1.3
    
    print(f"\n{'='*80}")
    print(f"优化参数: temperature={temperature}, top_k={top_k}, repetition_penalty={repetition_penalty}")
    print(f"输入提示: {prompt}")
    print(f"{'='*80}")
    
    # 编码提示文本
    prompt_tokens = sp.encode(prompt, out_type=int)
    
    # 转换为tensor
    context = torch.tensor([prompt_tokens], dtype=torch.long, device=device)
    
    # 生成响应
    with torch.no_grad():
        generated_tokens = model.generate(
            context, 
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            repetition_penalty=repetition_penalty
        )[0].tolist()
        
        generated_text = sp.decode(generated_tokens)
        
        # 提取生成的响应部分(去掉prompt)
        response_start = generated_text.find(prompt) + len(prompt)
        response = generated_text[response_start:].strip()
        
        # 计算重复率
        repetition_rate = calculate_repetition_rate(response)
        
        print(f"完整输出:")
        print(f"{generated_text}")
        print(f"\n提取的响应:")
        print(f"{response}")
        print(f"\n评估指标:")
        print(f"  输出长度: {len(response)} 字符")
        print(f"  重复率: {repetition_rate:.4f}")
        
        return response, repetition_rate

# 测试多个不同类型的提示
print("开始使用优化参数测试模型输出...")
test_prompts = [
    "关键词: 信 天涯 晚风",
    "关键词: 风 雾 寂寞",
    "关键词: 贴心 改变 自信",
    "关键词: 午夜 寒冬 心动",
    "关键词: 思考 推理 分析",
    "关键词: 月光 思念 远方",
    "关键词: 梦想 坚持 成功",
    "关键词: 春天 希望 新生"
]

total_repetition_rate = 0
total_responses = len(test_prompts)

for i, prompt in enumerate(test_prompts, 1):
    print(f"\n🔬 测试 {i}/{total_responses}")
    response, repetition_rate = test_output_optimized(prompt)
    total_repetition_rate += repetition_rate
    
    # 评估输出质量
    if repetition_rate == 0.0:
        print(f"✅ 输出质量优秀 - 无重复")
    elif repetition_rate < 0.05:
        print(f"✅ 输出质量良好 - 轻微重复")
    elif repetition_rate < 0.1:
        print(f"⚠️ 输出质量一般 - 中等重复")
    else:
        print(f"❌ 输出质量较差 - 严重重复")

# 计算平均重复率
avg_repetition_rate = total_repetition_rate / total_responses

print(f"\n{'='*80}")
print("🎯 最终测试结果总结")
print(f"{'='*80}")
print(f"测试提示数量: {total_responses}")
print(f"平均重复率: {avg_repetition_rate:.4f}")
print(f"最佳参数组合: temperature=0.8, top_k=50, repetition_penalty=1.3")
print(f"生成长度: 300 tokens")

if avg_repetition_rate == 0.0:
    print(f"🎉 优化成功!所有输出均无重复")
elif avg_repetition_rate < 0.05:
    print(f"✅ 优化效果良好!平均重复率很低")
elif avg_repetition_rate < 0.1:
    print(f"⚠️ 优化效果一般,仍有改进空间")
else:
    print(f"❌ 需要进一步优化")

print(f"\n优化前问题: 大量重复词汇(如'兄弟'、'兄弟姐妹'等)")
print(f"优化后效果: 重复率显著降低,输出多样性提高")