| | import torch, wandb, pandas as pd |
| | from tqdm import tqdm |
| | from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig |
| |
|
| | |
# --- Run configuration -------------------------------------------------------
# Token budget per formatted (prompt + reply) sequence; longer inputs are
# truncated by the tokenizer before scoring.
max_length: int = 8192
# Preference pairs per forward pass (each pair contributes two sequences).
batch_size: int = 16
# Reward-model checkpoint directory.
rm_path: str = "/home/rm5.0_9e-6"
# Parquet file with the evaluation preference pairs.
data_path: str = "/home/data/raw/test/1159-L6_format_full_label_v5.0safe.parquet"
| |
|
| | |
# Start the W&B run. The scoring setup is stored in the run config so runs
# for different checkpoints / datasets can be distinguished and reproduced.
wandb.init(
    project="reward_model_scoring",
    name="5.0_9e-6",
    config={
        "rm_path": rm_path,
        "data_path": data_path,
        "batch_size": batch_size,
        "max_length": max_length,
    },
)
| |
|
| | |
# Load tokenizer, config and model. trust_remote_code is passed consistently
# to all three: a checkpoint whose tokenizer needs custom code usually needs it
# for its config/model classes too, and AutoConfig would otherwise fail.
tokenizer = AutoTokenizer.from_pretrained(rm_path, trust_remote_code=True)
# Inputs are left-padded, so the final token of every row is in the last column.
tokenizer.padding_side = "left"

config = AutoConfig.from_pretrained(rm_path, trust_remote_code=True)
config.num_labels = 1  # single scalar reward per sequence
# Transformers *ForSequenceClassification heads use config.pad_token_id to find
# the last non-padding token (and reject batches > 1 without it). If the
# checkpoint config lacks one, reuse the tokenizer's pad id, i.e. the id used
# to pad the batches.
if getattr(config, "pad_token_id", None) is None and tokenizer.pad_token_id is not None:
    config.pad_token_id = tokenizer.pad_token_id

model = AutoModelForSequenceClassification.from_pretrained(
    rm_path, config=config, device_map="auto", trust_remote_code=True,
)
model.eval()

# With device_map="auto" the weights may be spread over several devices; inputs
# are sent to the device of the first parameter (presumably the embedding
# layer -- confirm for this architecture).
device = next(model.parameters()).device
| |
|
| | |
# Evaluation set: a reproducible random subset of at most 1500 preference
# pairs (columns read downstream: chosen_prompt, chosen, reject). Capping at
# len(df) avoids pandas' ValueError on smaller files; the fixed random_state
# makes results comparable across checkpoints.
df = pd.read_parquet(data_path)
df = df.sample(n=min(1500, len(df)), random_state=0).reset_index(drop=True)
| |
|
def format_input(prompt, reply):
    """Build the text the reward model scores for one (prompt, reply) pair.

    The two strings are concatenated and trailing newlines are stripped. The
    tokenizer's EOS string, preceded by a single space, is then appended
    unless the text already ends with it.
    """
    eos = tokenizer.eos_token
    joined = (prompt + reply).rstrip("\n")
    return joined if joined.endswith(eos) else joined + " " + eos
| |
|
def encode_batch(chosen_texts, rejected_texts, tokenizer, max_length, device):
    """Tokenize chosen/rejected texts into one left-padded batch.

    Every sequence is tokenized without special tokens and truncated to
    ``max_length`` tokens. It is then forced to end with
    ``tokenizer.eos_token_id`` by overwriting its last token, so truncated
    sequences still end in EOS. Chosen and rejected rows are left-padded to one
    common width so both halves run through the model in a single forward pass.

    Args:
        chosen_texts: Formatted strings for the preferred replies.
        rejected_texts: Formatted strings for the dispreferred replies,
            aligned pairwise with ``chosen_texts``.
        tokenizer: Callable returning ``input_ids`` / ``attention_mask`` lists
            and exposing ``eos_token_id`` and ``pad_token_id``.
        max_length: Maximum number of tokens kept per sequence.
        device: Device on which the returned tensors are created.

    Returns:
        ``(input_ids, attention_mask, split)``. Both tensors are ``torch.long``
        with shape ``(2 * len(chosen_texts), L)``. Rows ``[:split]`` hold the
        chosen sequences and rows ``[split:]`` the rejected ones.

    Raises:
        ValueError: If the two text lists differ in length, or the tokenizer
            has no ``pad_token_id`` to pad with.
    """
    if len(chosen_texts) != len(rejected_texts):
        raise ValueError(
            f"chosen/rejected batch sizes differ: "
            f"{len(chosen_texts)} vs {len(rejected_texts)}")
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        raise ValueError(
            "tokenizer has no pad_token_id; set tokenizer.pad_token before encoding")
    eos_id = tokenizer.eos_token_id

    def _tokenize(texts):
        """Tokenize ``texts`` and force each sequence to end with EOS."""
        enc = tokenizer(texts, add_special_tokens=False,
                        truncation=True, max_length=max_length, padding=False)
        out_ids, out_masks = [], []
        for ids, mask in zip(enc["input_ids"], enc["attention_mask"]):
            ids, mask = list(ids), list(mask)
            if ids:
                # Overwrite (not append) so truncated rows stay within max_length.
                ids[-1] = eos_id
                mask[-1] = 1
            else:
                # Empty encoding: nothing to overwrite, so EOS is the only token.
                ids, mask = [eos_id], [1]
            out_ids.append(ids)
            out_masks.append(mask)
        return out_ids, out_masks

    ids1, mask1 = _tokenize(chosen_texts)
    ids2, mask2 = _tokenize(rejected_texts)
    all_ids, all_masks = ids1 + ids2, mask1 + mask2

    # Shared width for chosen and rejected rows. Left padding keeps each row's
    # final (EOS) token in the last column.
    joint_max = max(map(len, all_ids), default=0)
    padded_ids = [[pad_id] * (joint_max - len(x)) + x for x in all_ids]
    padded_masks = [[0] * (joint_max - len(x)) + x for x in all_masks]

    n_rows = len(all_ids)
    # reshape keeps a 2-D (0, 0) shape for an empty batch; no-op otherwise.
    input_ids = torch.tensor(padded_ids, dtype=torch.long, device=device).reshape(n_rows, joint_max)
    attn_masks = torch.tensor(padded_masks, dtype=torch.long, device=device).reshape(n_rows, joint_max)
    return input_ids, attn_masks, len(chosen_texts)
| |
|
| | |
# Per-pair results, appended in dataset order (pair index == df row index).
chosen_scores, rejected_scores, accs = [], [], []
sample_table = wandb.Table(columns=["index","prompt","chosen","rejected",
                                    "chosen_score","rejected_score","delta","acc"])

# Optional affine rescaling stored in the checkpoint config. Standard HF
# configs have no `std` / `mean` attributes, so read them with getattr rather
# than letting attribute access raise AttributeError.
reward_std = getattr(config, "std", None)
reward_mean = getattr(config, "mean", None)
n_correct = 0  # running count of acc == 1, for the live average

for i in tqdm(range(0, len(df), batch_size)):
    batch = df.iloc[i:i+batch_size]
    prompts = batch["chosen_prompt"]
    # Both replies are scored against the same prompt column.
    chosen_texts = [format_input(p, a) for p, a in zip(prompts, batch["chosen"])]
    rejected_texts = [format_input(p, a) for p, a in zip(prompts, batch["reject"])]

    input_ids, attn_masks, split = encode_batch(chosen_texts, rejected_texts, tokenizer, max_length, device)

    with torch.no_grad():
        rewards = model(input_ids=input_ids, attention_mask=attn_masks).logits.squeeze(-1)
    # Rescale in fp32 in case the model runs in a reduced-precision dtype.
    rewards = rewards.float()
    if reward_std is not None and reward_mean is not None:
        rewards = rewards * reward_std + reward_mean
    # One device->host copy per batch instead of one .item() sync per score.
    rewards = rewards.cpu().tolist()
    chosen_r, rejected_r = rewards[:split], rewards[split:]

    for j, (c, r) in enumerate(zip(chosen_r, rejected_r)):
        idx = i + j
        delta = c - r
        acc = int(delta > 0)  # ties count as incorrect
        chosen_scores.append(c); rejected_scores.append(r); accs.append(acc)
        n_correct += acc
        avg_acc = n_correct / len(accs)
        print(f"[{idx}] acc={acc}, chosen={c:.3f}, rejected={r:.3f}, Δ={delta:.3f} | avg acc={avg_acc:.3f}")

        sample_table.add_data(idx, batch["chosen_prompt"].iloc[j],
                              batch["chosen"].iloc[j], batch["reject"].iloc[j],
                              c, r, delta, acc)
| |
|
| | |
# Attach the per-pair results to the sampled frame, then summarise and log.
df = df.assign(
    chosen_score=chosen_scores,
    rejected_score=rejected_scores,
    delta=lambda frame: frame["chosen_score"] - frame["rejected_score"],
    acc=accs,
)

accuracy, mean_chosen, mean_reject, mean_delta = (
    df[column].mean() for column in ("acc", "chosen_score", "rejected_score", "delta")
)

print(f"\n✅ Accuracy = {accuracy:.3f}")
print(f"📊 mean_chosen = {mean_chosen:.3f}, mean_rejected = {mean_reject:.3f}, mean_delta = {mean_delta:.3f}")

wandb.log(dict(
    samples_table=sample_table,
    final_accuracy=accuracy,
    mean_chosen_score=mean_chosen,
    mean_rejected_score=mean_reject,
    mean_delta_score=mean_delta,
))
wandb.finish()
| |
|