# rm_code / reward_acc.py
# hahayang012's picture
# Upload folder using huggingface_hub
# d8a76be verified
import torch, wandb, pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
# === Parameters ===
# Batching: `batch_size` is the number of dataframe rows scored per loop step
# (each row contributes one chosen and one rejected sequence to the forward
# pass); `max_length` is the truncation limit handed to the tokenizer.
batch_size: int = 16
max_length: int = 8192
# Locations: `rm_path` is passed to the `from_pretrained` loaders,
# `data_path` is the parquet file read with pandas.
rm_path: str = "/home/rm5.0_9e-6"
data_path: str = "/home/data/raw/test/1159-L6_format_full_label_v5.0safe.parquet"
# === wandb ===
# Open the tracking run that later receives the sample table and the
# summary metrics.
wandb.init(
    project="reward_model_scoring",
    name="5.0_9e-6",
)

# === Model & tokenizer ===
tokenizer = AutoTokenizer.from_pretrained(rm_path, trust_remote_code=True)
tokenizer.padding_side = "left"  # ← modified

# Load the checkpoint's config and force a single output label so the
# classification head emits one scalar per sequence.
config = AutoConfig.from_pretrained(rm_path)
config.num_labels = 1  # reward head

# `.eval()` returns the module itself, so it can be chained onto the load.
model = AutoModelForSequenceClassification.from_pretrained(
    rm_path,
    config=config,
    device_map="auto",
).eval()

# Inputs are moved to the device that holds the model's first parameter.
device = next(model.parameters()).device
# === Data ===
# Load the preference pairs and keep a random subset of at most 1500 rows.
# DataFrame.sample draws without replacement by default and raises ValueError
# when n exceeds the number of rows, so the sample size is capped at len(df).
# NOTE(review): no random_state is passed, so a different subset is scored on
# every run — pass a fixed seed if runs need to be comparable.
df = pd.read_parquet(data_path)
df = df.sample(n=min(1500, len(df))).reset_index(drop=True)
def format_input(prompt, reply):  # ← modified
    """Join a prompt and a reply into one string that ends with the EOS token.

    Trailing newlines are stripped from the concatenation. If the result does
    not already end with ``tokenizer.eos_token`` (the module-level tokenizer),
    a single space followed by the EOS token is appended.

    Args:
        prompt: Prompt text placed first.
        reply: Reply text concatenated directly after the prompt.

    Returns:
        The combined text, guaranteed to end with the EOS token string.
    """
    text = (prompt + reply).rstrip("\n")
    eos = tokenizer.eos_token
    if text.endswith(eos):
        return text
    return text + " " + eos
def encode_batch(chosen_texts, rejected_texts, tokenizer, max_length, device):
    """Tokenize chosen and rejected texts into one left-padded batch.

    Both lists are tokenized without special tokens and truncated to
    ``max_length``. The last token of every sequence is then forced to the
    EOS id (so a truncated sequence still ends with EOS), and all sequences
    are left-padded to the longest length found across both lists.

    Args:
        chosen_texts: List of strings for the preferred responses.
        rejected_texts: List of strings for the rejected responses.
        tokenizer: Callable returning a mapping with ``input_ids`` and
            ``attention_mask`` (lists of lists); must expose
            ``eos_token_id`` and ``pad_token_id``.
        max_length: Truncation limit forwarded to the tokenizer.
        device: Device the resulting tensors are moved to.

    Returns:
        Tuple ``(input_ids, attn_masks, split)``: two ``torch.long`` tensors
        whose first ``split`` rows are the chosen sequences and whose
        remaining rows are the rejected ones; ``split`` is
        ``len(chosen_texts)``.
    """
    def _tokenize(texts):
        encoded = tokenizer(texts, add_special_tokens=False,
                            truncation=True, max_length=max_length,
                            padding=False)
        # Copy the rows so the tokenizer's own output is never mutated.
        ids = [list(row) for row in encoded["input_ids"]]
        masks = [list(row) for row in encoded["attention_mask"]]
        return ids, masks

    # 1. Tokenize; chosen rows first, rejected rows after them.
    chosen_ids, chosen_masks = _tokenize(chosen_texts)
    rejected_ids, rejected_masks = _tokenize(rejected_texts)
    all_ids = chosen_ids + rejected_ids
    all_masks = chosen_masks + rejected_masks

    # 2. Make sure every sequence ends with EOS. A sequence that tokenized to
    #    nothing gets a lone EOS instead of raising IndexError.
    eos_id = tokenizer.eos_token_id
    for ids, mask in zip(all_ids, all_masks):
        if ids:
            ids[-1] = eos_id
            mask[-1] = 1
        else:
            ids.append(eos_id)
            mask.append(1)

    # 3. Left-pad to the joint maximum length. Padded positions carry
    #    attention mask 0; when the tokenizer defines no pad token, fall back
    #    to the EOS id so the tensor can still be built.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = eos_id
    width = max(len(ids) for ids in all_ids)

    def _left_pad(seq, fill):
        return [fill] * (width - len(seq)) + seq

    input_ids = torch.tensor([_left_pad(ids, pad_id) for ids in all_ids],
                             dtype=torch.long).to(device)
    attn_masks = torch.tensor([_left_pad(mask, 0) for mask in all_masks],
                              dtype=torch.long).to(device)
    return input_ids, attn_masks, len(chosen_texts)
# === Inference ===
# Per-sample results, appended in dataframe row order.
chosen_scores, rejected_scores, accs = [], [], []
sample_table = wandb.Table(columns=["index", "prompt", "chosen", "rejected",
                                    "chosen_score", "rejected_score", "delta", "acc"])

# Optional affine rescaling of the raw scores. Read with getattr so a config
# that does not define `std` / `mean` skips the rescaling instead of raising
# AttributeError.  # ← modified
reward_std = getattr(config, "std", None)
reward_mean = getattr(config, "mean", None)

for i in tqdm(range(0, len(df), batch_size)):
    batch = df.iloc[i:i + batch_size]
    prompts = batch["chosen_prompt"].tolist()
    chosen_replies = batch["chosen"].tolist()
    rejected_replies = batch["reject"].tolist()

    # Both candidates are paired with the same prompt column.
    chosen_texts = [format_input(p, a) for p, a in zip(prompts, chosen_replies)]
    rejected_texts = [format_input(p, a) for p, a in zip(prompts, rejected_replies)]
    input_ids, attn_masks, split = encode_batch(
        chosen_texts, rejected_texts, tokenizer, max_length, device)

    with torch.no_grad():
        rewards = model(input_ids=input_ids, attention_mask=attn_masks).logits.squeeze(-1)
    if reward_std is not None and reward_mean is not None:
        rewards = rewards * reward_std + reward_mean

    # Pull the whole batch to the host once instead of calling .item() twice
    # per sample.
    reward_values = rewards.float().cpu().tolist()
    chosen_r, rejected_r = reward_values[:split], reward_values[split:]

    for j, (c, r) in enumerate(zip(chosen_r, rejected_r)):
        idx = i + j
        delta = c - r
        # A pair counts as correct when the chosen reply scores strictly higher.
        acc = int(delta > 0)
        chosen_scores.append(c)
        rejected_scores.append(r)
        accs.append(acc)
        avg_acc = sum(accs) / len(accs)
        print(f"[{idx}] acc={acc}, chosen={c:.3f}, rejected={r:.3f}, Δ={delta:.3f} | avg acc={avg_acc:.3f}")
        sample_table.add_data(idx, prompts[j],
                              chosen_replies[j], rejected_replies[j],
                              c, r, delta, acc)
# === Results ===
# Attach the per-sample scores to the evaluation frame (same row order as the
# inference loop produced them).
df["chosen_score"] = chosen_scores
df["rejected_score"] = rejected_scores
df["delta"] = df["chosen_score"].sub(df["rejected_score"])
df["acc"] = accs

# Aggregate metrics over all scored pairs.
accuracy = df["acc"].mean()
mean_chosen = df["chosen_score"].mean()
mean_reject = df["rejected_score"].mean()
mean_delta = df["delta"].mean()

print(f"\n✅ Accuracy = {accuracy:.3f}")
print(f"📊 mean_chosen = {mean_chosen:.3f}, mean_rejected = {mean_reject:.3f}, mean_delta = {mean_delta:.3f}")

# Log the sample table together with the summary metrics in a single call,
# then close the run.
final_metrics = {
    "final_accuracy": accuracy,
    "mean_chosen_score": mean_chosen,
    "mean_rejected_score": mean_reject,
    "mean_delta_score": mean_delta,
}
wandb.log({"samples_table": sample_table, **final_metrics})
wandb.finish()