| | |
| | |
| |
|
| | from argparse import Namespace |
| | import pandas as pd |
| | from vllm import LLM, EngineArgs |
| | from vllm.utils import FlexibleArgumentParser |
| | import wandb |
| |
|
| | |
# Qwen3-Reranker chat-template scaffolding: the system turn instructs the model
# to answer strictly "yes"/"no"; the suffix opens the assistant turn with an
# empty <think> block so scoring starts immediately after it.
PREFIX = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

# Evaluation dataset (parquet with chosen_prompt / chosen / reject columns)
# and the wandb project/run used for logging results.
DATA_PATH = "/home/data/test_transformed_v1.parquet"
WANDB_PROJECT = "reranker_eval_wrong"
WANDB_RUN_NAME = "qwen3_seqcls_scoring"
| |
|
def format_query(chosen_prompt: str) -> str:
    """Embed *chosen_prompt* into the reranker query template.

    The fixed instruction tells the model how to judge roleplay replies;
    the result is the system prefix plus the instruct/query sections.
    """
    task = (
        "Given a roleplay prompt and recent context, score candidate replies higher when they stay in character, continue the scene coherently, and feel vivid and engaging."
    )
    parts = [PREFIX, "<Instruct>: ", task, "\n<Query>:", chosen_prompt, "\n"]
    return "".join(parts)
| |
|
def format_document(doc_text: str) -> str:
    """Embed *doc_text* into the reranker document template and close the turn."""
    return "<Document>: " + doc_text + SUFFIX
| |
|
def parse_args():
    """Parse vLLM engine CLI arguments with reranker-friendly defaults."""
    arg_parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    default_settings = {
        "model": "deeppin/Qwen3-Reranker-8B-SequenceClassification",
        "task": "score",
        "enforce_eager": True,
        "trust_remote_code": True,
    }
    arg_parser.set_defaults(**default_settings)
    return arg_parser.parse_args()
| |
|
def _is_valid_row(chosen_prompt, chosen, reject) -> bool:
    """Return True when all three fields are non-empty strings."""
    if not isinstance(chosen_prompt, str) or not isinstance(chosen, str) or not isinstance(reject, str):
        return False
    return chosen.strip() != "" and reject.strip() != ""


def _log_wrong_samples(wrong_samples: list) -> None:
    """Upload mis-ranked pairs to wandb as a Table and a CSV artifact.

    Best-effort: the artifact upload is allowed to fail (e.g. no write
    permission for the CSV) without aborting the run, but the failure is
    reported instead of silently swallowed.
    """
    if not wrong_samples:
        return
    table = wandb.Table(columns=[
        "index", "chosen_score", "reject_score", "margin",
        "chosen_prompt", "chosen", "reject"
    ])
    for r in wrong_samples:
        table.add_data(
            r["index"], r["chosen_score"], r["reject_score"], r["margin"],
            r["chosen_prompt"], r["chosen"], r["reject"]
        )
    wandb.log({"errors/wrong_samples": table})

    try:
        _df = pd.DataFrame(wrong_samples)
        _df.to_csv("wrong_samples.csv", index=False)
        art = wandb.Artifact("wrong_samples", type="dataset")
        art.add_file("wrong_samples.csv")
        wandb.log_artifact(art)
    except Exception as e:
        # Keep the best-effort contract, but surface why the upload failed.
        print(f"[Warn] failed to upload wrong_samples artifact: {e}")


def main(args: Namespace):
    """Evaluate a pairwise reranker on (prompt, chosen, reject) triples.

    For each valid row, score the chosen and rejected replies against the
    same query and count how often the chosen reply scores strictly higher.
    Running accuracy and per-pair scores go to wandb; mis-ranked pairs are
    collected and uploaded at the end.
    """
    df = pd.read_parquet(DATA_PATH)

    wandb.init(project=WANDB_PROJECT, name=WANDB_RUN_NAME)
    wandb.config.update({"model": args.model, "data_path": DATA_PATH})

    llm = LLM(**vars(args))

    correct = 0
    total = 0
    wrong_samples = []
    for i, row in df.iterrows():
        chosen_prompt = row["chosen_prompt"]
        chosen = row["chosen"]
        reject = row["reject"]

        # Skip rows with missing/non-string fields or empty candidates.
        if not _is_valid_row(chosen_prompt, chosen, reject):
            continue

        q = format_query(chosen_prompt)
        d1 = format_document(chosen)
        d2 = format_document(reject)

        try:
            # Score both (query, document) pairs in one batched call.
            outs = llm.score([q, q], [d1, d2])
            s1, s2 = (o.outputs.score for o in outs)
            chosen_better = (s1 > s2)
            total += 1
            if chosen_better:
                correct += 1
            # total was just incremented, so it is always > 0 here —
            # the previous `if total > 0 else 0.0` guard was dead code.
            running_acc = correct / total

            print({"chosen_score": s1, "reject_score": s2, "chosen_better": chosen_better},f"[RunningAcc] {correct}/{total} = {running_acc:.4f}")
            wandb.log({
                "metric/running_acc": running_acc,
                "score/chosen": float(s1),
                "score/reject": float(s2),
                "score/margin": float(s1 - s2),
            }, step=total)
            if not chosen_better:
                wrong_samples.append({
                    "index": int(i),
                    "chosen_score": float(s1),
                    "reject_score": float(s2),
                    "margin": float(s1 - s2),
                    "chosen_prompt": chosen_prompt,
                    "chosen": chosen,
                    "reject": reject,
                })
        except Exception as e:
            # Best-effort: one bad row should not abort the whole eval.
            print(f"[Error] index={i}: {e}")

    final_acc = correct / total if total > 0 else 0.0
    print(f"[FinalAcc] {correct}/{total} = {final_acc:.4f}")
    wandb.summary["final/accuracy"] = final_acc
    wandb.summary["final/total"] = total
    wandb.summary["final/correct"] = correct
    wandb.summary["final/wrong"] = len(wrong_samples)

    _log_wrong_samples(wrong_samples)
    wandb.finish()
| |
|
if __name__ == "__main__":
    # Script entry point: parse engine args, then run the evaluation.
    main(parse_args())
| |
|