| | |
| | import os, math, argparse, warnings |
| | import pandas as pd |
| | import numpy as np |
| | import torch |
| | from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification |
| | |
def build_left_padded_inputs(tokenizer, texts, max_length, device):
    """Tokenize *texts* with strict left padding and move tensors to *device*.

    Left padding guarantees that column ``[:, -1]`` of the result is the
    final non-pad token of every sequence.

    Returns a ``(input_ids, attention_mask)`` tuple of tensors on *device*.
    """
    tokenizer.padding_side = "left"
    # Some tokenizers ship without a pad token: reuse EOS when it exists,
    # otherwise register a dedicated one.
    # NOTE(review): registering a brand-new pad token grows the vocabulary,
    # and nothing here resizes the model's embedding matrix — confirm the
    # checkpoint already defines a pad token, or resize embeddings at load.
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is None:
            tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
        else:
            tokenizer.pad_token = tokenizer.eos_token

    batch = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    return batch["input_ids"].to(device), batch["attention_mask"].to(device)
|
@torch.inference_mode()
def score_texts_last_token(reward_model, tokenizer, texts, max_length, device):
    """Score every text with the reward model's scalar head.

    Returns a plain Python list of ``len(texts)`` float scores.
    Inputs are strictly left-padded, so position ``[:, -1]`` always holds
    the last non-pad token of each sequence.
    """
    input_ids, attention_mask = build_left_padded_inputs(
        tokenizer, texts, max_length, device
    )

    # Run the backbone, then apply the score head to every position.
    last_hidden = reward_model.model(input_ids, attention_mask).last_hidden_state
    per_token = reward_model.score(last_hidden)

    # Collapse a trailing singleton score dimension: [B, T, 1] -> [B, T].
    if per_token.dim() == 3 and per_token.size(-1) == 1:
        per_token = per_token.squeeze(-1)

    # Left padding => the last column is the score at the final real token.
    final = per_token[:, -1]

    # Replace NaNs with a huge negative so a broken sample can never "win".
    final = torch.nan_to_num(final, nan=-1e30)
    return final.detach().float().cpu().tolist()
|
| | |
def join_prompt_answer(prompt, answer, joiner="\n"):
    """Concatenate *prompt* and *answer* with *joiner*.

    Both pieces are right-stripped first; ``None`` is treated as "".
    """
    parts = [(prompt or "").rstrip(), (answer or "").rstrip()]
    return joiner.join(parts)
|
| | |
def main():
    """Pairwise reward-model evaluation: chosen vs. reject accuracy.

    Reads a parquet file with ``chosen_prompt``/``chosen``/``reject``
    columns, scores both completions of every row with the reward model,
    and reports how often the chosen answer outscores the rejected one.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--data_path", type=str, required=True,
                    help="包含列 chosen_prompt/chosen/reject 的 parquet 路径")
    ap.add_argument("--batch_size", type=int, default=16)
    ap.add_argument("--max_length", type=int, default=1024)
    ap.add_argument("--joiner", type=str, default="")
    # Was a hard-coded path used in two places; the old location stays as
    # the default so existing invocations keep working.
    ap.add_argument("--model_path", type=str, default="/home/rm5.0_9e-6",
                    help="reward model / tokenizer checkpoint directory")

    args = ap.parse_args()

    if not os.path.exists(args.data_path):
        raise FileNotFoundError(args.data_path)

    df = pd.read_parquet(args.data_path)
    for col in ["chosen_prompt", "chosen", "reject"]:
        if col not in df.columns:
            raise ValueError(f"缺列 `{col}`,实际列:{list(df.columns)}")

    reward_model = AutoModelForSequenceClassification.from_pretrained(
        args.model_path,
        num_labels=1,
        torch_dtype=torch.bfloat16,
        use_cache=False,
    )
    # from_pretrained already returns an eval-mode model; make it explicit
    # so dropout can never perturb the scores.
    reward_model.eval()
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    device = next(reward_model.parameters()).device

    total = len(df)
    correct = 0
    seen = 0

    print(f"Loaded {total} samples from {args.data_path}")
    print("Start evaluating (pairwise chosen vs reject)...\n" + "-" * 70)

    for start in range(0, total, args.batch_size):
        end = min(start + args.batch_size, total)
        batch = df.iloc[start:end]

        # Interleave texts: even indices = chosen, odd indices = reject.
        pair_texts = []
        for _, row in batch.iterrows():
            pair_texts.append(join_prompt_answer(row["chosen_prompt"], row["chosen"], args.joiner))
            pair_texts.append(join_prompt_answer(row["chosen_prompt"], row["reject"], args.joiner))

        scores = score_texts_last_token(
            reward_model=reward_model,
            tokenizer=tokenizer,
            texts=pair_texts,
            max_length=args.max_length,
            device=device,
        )

        for i in range(len(batch)):
            chosen_score = float(scores[2 * i])
            reject_score = float(scores[2 * i + 1])
            seen += 1
            is_correct = chosen_score > reject_score
            correct += int(is_correct)
            running_acc = correct / seen

            print(
                f"[{seen:6d}] "
                f"Chosen={chosen_score:.6f} | Reject={reject_score:.6f} | "
                f"Correct={is_correct} | RunningAcc={running_acc*100:.2f}%"
            )

    print("\n" + "-" * 70)
    if seen:
        print(f"Finished. Total={seen}, Correct={correct}, FinalAcc={correct/seen*100:.2f}%")
    else:
        # Empty dataset: the original divided by zero here.
        print("Finished. Total=0, Correct=0, FinalAcc=n/a")


if __name__ == "__main__":
    main()
| |
|