# File size: 4,017 Bytes
# d8a76be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
import torch
import pandas as pd
import wandb

# === Initialize the wandb run used for score logging ===
wandb.init(
    project="reward_model_scoring",      # project name (arbitrary)
    name="fomatted_5e-6_1500",           # run name (NOTE(review): "fomatted" looks like a typo for "formatted" — it is the recorded run name, so left as-is)
)

# === Model path: directory where the trained reward model checkpoint was saved ===
rm_path = "/home/ckpt/5e-6/global_step180_hf"  # reward model checkpoint directory

# === Load the tokenizer (includes any special tokens saved with the checkpoint) ===
tokenizer = AutoTokenizer.from_pretrained(rm_path)

# === Load the config and force num_labels=1 so the classification head emits one scalar (the reward) ===
config = AutoConfig.from_pretrained(rm_path)
config.num_labels = 1

# === Load the reward model; device_map="auto" places/shards it across available devices ===
model = AutoModelForSequenceClassification.from_pretrained(
    rm_path,
    config=config,
    device_map="auto"
)
model.eval()  # inference mode: disable dropout etc. for deterministic scoring

# === Wrapper: a batch of texts in -> a batch of scalar reward scores out ===
def get_reward_score(texts):
    """Score a batch of texts with the reward model.

    Args:
        texts: list of strings (prompt + response concatenations).

    Returns:
        List of float reward scores, one per input text.
    """
    # Tokenize the whole batch at once; inputs longer than 8192 tokens are truncated.
    encoded = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=8192,
        return_tensors="pt",
    )
    encoded = encoded.to(model.device)

    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(**encoded).logits

    # num_labels=1 gives logits of shape [batch_size, 1]; squeeze to [batch_size].
    return logits.squeeze(-1).float().cpu().tolist()

# === Load the dataset to score: sample 1500 rows with a fixed seed for reproducibility ===
df = pd.read_parquet("/home/data/formatted_test.parquet").sample(n=1500, random_state=42).reset_index(drop=True)
  # required columns: chosen_prompt, chosen, reject

def format_input(prompt, answer, sep=""):
    """Concatenate a prompt and an answer into a single reward-model input.

    Args:
        prompt: the prompt text.
        answer: the response text.
        sep: optional separator inserted between prompt and answer.
             The default "" preserves the original `prompt + answer`
             behavior; pass e.g. "\n\n" to separate them visually
             (the variant the original comment suggested).

    Returns:
        The combined string fed to the reward model.
    """
    return prompt + sep + answer

# Build model inputs: chosen and rejected answers share the same prompt column.
chosen_texts = [format_input(p, a) for p, a in zip(df["chosen_prompt"], df["chosen"])]
rejected_texts = [format_input(p, a) for p, a in zip(df["chosen_prompt"], df["reject"])]

# === Accumulators for per-sample results ===
chosen_scores, rejected_scores, accs = [], [], []

# === wandb table for per-sample visualization ===
sample_table = wandb.Table(columns=[
    "index", "prompt", "chosen", "rejected",
    "chosen_score", "rejected_score", "delta_score", "acc"
])

# === Score in batches, print running accuracy, and log each sample to the wandb table ===
batch_size = 16
for start in range(0, len(chosen_texts), batch_size):
    # Score the chosen and rejected responses for this batch.
    c_batch_scores = get_reward_score(chosen_texts[start:start + batch_size])
    r_batch_scores = get_reward_score(rejected_texts[start:start + batch_size])

    for offset, (c_score, r_score) in enumerate(zip(c_batch_scores, r_batch_scores)):
        idx = start + offset
        delta = c_score - r_score
        acc = int(delta > 0)  # 1 when the chosen answer got the higher reward

        # Accumulate global results and report the running accuracy.
        chosen_scores.append(c_score)
        rejected_scores.append(r_score)
        accs.append(acc)
        current_accuracy = sum(accs) / len(accs)
        print(f"[{idx}] acc={acc}, chosen_reward={c_score:.3f}, reject_reward={r_score:.3f} | 当前平均准确率: {current_accuracy:.3f}")

        # Add this sample to the wandb table.
        sample_table.add_data(
            idx,
            df.loc[idx, "chosen_prompt"],
            df.loc[idx, "chosen"],
            df.loc[idx, "reject"],
            c_score,
            r_score,
            delta,
            acc
        )

# === Write per-sample scores back into the DataFrame ===
df["chosen_score"] = chosen_scores
df["rejected_score"] = rejected_scores
df["delta_score"] = df["chosen_score"] - df["rejected_score"]
df["acc"] = accs

# === Aggregate summary metrics over the whole sample ===
summary = {
    "final_accuracy": df["acc"].mean(),
    "mean_chosen_score": df["chosen_score"].mean(),
    "mean_rejected_score": df["rejected_score"].mean(),
    "mean_delta_score": df["delta_score"].mean(),
}

print(f"\n✅ Reward Model Accuracy = {summary['final_accuracy']:.3f}")
print(f"📊 mean_chosen = {summary['mean_chosen_score']:.3f}, mean_rejected = {summary['mean_rejected_score']:.3f}, mean_delta = {summary['mean_delta_score']:.3f}")

# === Push the per-sample table and the summary metrics to wandb ===
wandb.log({"samples_table": sample_table, **summary})

# === Close the wandb run ===
wandb.finish()