import pandas as pd
import torch
import wandb
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

# === Parameters ===
rm_path   = "/home/rm5.0_9e-6"
data_path = "/home/data/raw/test/1159-L6_format_full_label_v5.0safe.parquet"
batch_size = 16
max_length = 8192

# === wandb ===
wandb.init(project="reward_model_scoring", name="5.0_9e-6")

# === Model & tokenizer ===
tokenizer = AutoTokenizer.from_pretrained(rm_path, trust_remote_code=True)
tokenizer.padding_side = "left"      # left-pad so the score is read from the rightmost (final) token
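if tokenizer.pad_token is None:      # assumption: EOS doubles as the pad token for
    tokenizer.pad_token = tokenizer.eos_token  # models shipped without an explicit one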
config = AutoConfig.from_pretrained(rm_path)
config.num_labels = 1                                    # reward head
model = AutoModelForSequenceClassification.from_pretrained(
        rm_path, config=config, device_map="auto")
model.eval()

device = next(model.parameters()).device

# === Data ===
df = pd.read_parquet(data_path).sample(n=1500).reset_index(drop=True)

def format_input(prompt, reply):
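    # Concatenate prompt and reply, ensuring the text ends with EOS so the
    # sequence-classification head always scores a consistent final token.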
    txt = (prompt + reply).rstrip("\n")
    if not txt.endswith(tokenizer.eos_token):
        txt += " " + tokenizer.eos_token
    return txt

def encode_batch(chosen_texts, rejected_texts, tokenizer, max_length, device):
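    """Tokenize, force a trailing EOS, and left-pad chosen + rejected to one length.

    Returns stacked input_ids / attention_mask (chosen batch first) and the split index.
    """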
    # 1. tokenize chosen and rejected separately, no padding yet
    ch = tokenizer(chosen_texts,  add_special_tokens=False,
                   truncation=True, max_length=max_length, padding=False)
    rj = tokenizer(rejected_texts, add_special_tokens=False,
                   truncation=True, max_length=max_length, padding=False)
    ids1, mask1 = ch["input_ids"], ch["attention_mask"]
    ids2, mask2 = rj["input_ids"], rj["attention_mask"]

    # 2. force the final token to be EOS (truncation may have cut it off)
    for arr_ids, arr_mask in ((ids1, mask1), (ids2, mask2)):
        for i in range(len(arr_ids)):
            arr_ids[i][-1]  = tokenizer.eos_token_id
            arr_mask[i][-1] = 1

    # 3. left-pad every sequence to the joint maximum length
    joint_max = max(max(len(x) for x in ids1), max(len(x) for x in ids2))
    lpad = lambda seq, pad: [pad]*(joint_max-len(seq)) + seq
    ids1  = [lpad(x, tokenizer.pad_token_id) for x in ids1]
    ids2  = [lpad(x, tokenizer.pad_token_id) for x in ids2]
    mask1 = [lpad(x, 0)                      for x in mask1]
    mask2 = [lpad(x, 0)                      for x in mask2]

    input_ids  = torch.tensor(ids1 + ids2,  dtype=torch.long).to(device)
    attn_masks = torch.tensor(mask1 + mask2, dtype=torch.long).to(device)
    return input_ids, attn_masks, len(chosen_texts)

# === Inference ===
chosen_scores, rejected_scores, accs = [], [], []
sample_table = wandb.Table(columns=["index","prompt","chosen","rejected",
                                    "chosen_score","rejected_score","delta","acc"])

for i in tqdm(range(0, len(df), batch_size)):
    batch = df.iloc[i:i+batch_size]
    chosen_texts   = [format_input(p, a) for p,a in zip(batch["chosen_prompt"], batch["chosen"])]
    rejected_texts = [format_input(p, a) for p,a in zip(batch["chosen_prompt"], batch["reject"])]

    input_ids, attn_masks, split = encode_batch(chosen_texts, rejected_texts, tokenizer, max_length, device)

    with torch.no_grad():
        rewards = model(input_ids=input_ids, attention_mask=attn_masks).logits.squeeze(-1)
        # De-normalize if the RM config stores reward statistics; getattr keeps
        # the check from raising AttributeError when they are absent.
        mean = getattr(config, "mean", None)
        std  = getattr(config, "std", None)
        if std is not None and mean is not None:
            rewards = rewards * std + mean

    chosen_r, rejected_r = rewards[:split], rewards[split:]

    for j in range(len(chosen_r)):
        idx = i + j
        c, r = chosen_r[j].item(), rejected_r[j].item()
        delta = c - r
        acc   = int(delta > 0)
        chosen_scores.append(c); rejected_scores.append(r); accs.append(acc)
        avg_acc = sum(accs) / len(accs)
        print(f"[{idx}] acc={acc}, chosen={c:.3f}, rejected={r:.3f}, Δ={delta:.3f} | avg acc={avg_acc:.3f}")

        sample_table.add_data(idx, batch["chosen_prompt"].iloc[j],
                                  batch["chosen"].iloc[j], batch["reject"].iloc[j],
                                  c, r, delta, acc)

# === Results ===
df["chosen_score"] = chosen_scores
df["rejected_score"] = rejected_scores
df["delta"]  = df["chosen_score"] - df["rejected_score"]
df["acc"]    = accs

accuracy     = df["acc"].mean()
mean_chosen  = df["chosen_score"].mean()
mean_reject  = df["rejected_score"].mean()
mean_delta   = df["delta"].mean()

print(f"\n✅ Accuracy  = {accuracy:.3f}")
print(f"📊 mean_chosen = {mean_chosen:.3f}, mean_rejected = {mean_reject:.3f}, mean_delta = {mean_delta:.3f}")

wandb.log({
    "samples_table": sample_table,
    "final_accuracy": accuracy,
    "mean_chosen_score": mean_chosen,
    "mean_rejected_score": mean_reject,
    "mean_delta_score": mean_delta,
})
wandb.finish()