| |
| """LLM-as-a-judge semantic agreement over human/v5 answer pairs.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import re |
| import time |
| from collections import Counter, defaultdict |
| from datetime import datetime, timezone |
| from pathlib import Path |
| from typing import Any, Dict, List |
|
|
| from openai import OpenAI |
|
|
|
|
# Numeric score assigned to each judge label; also the set of labels accepted from the judge.
LABEL_SCORES: Dict[str, float] = {
    "full_agreement": 1.0,
    "partial_agreement": 0.5,
    "disagreement": 0.0,
    "absence_mismatch": 0.0,
}
|
|
|
|
# System message sent with every judge request: defines the four labels, their
# scores, and the JSON reply schema. The label names and scores listed here
# mirror LABEL_SCORES; keep the two in sync when editing either.
SYSTEM_PROMPT = """You are evaluating semantic agreement between two annotators' free-text answers to the same benchmark question.

You are not judging whether either answer is true against the original chat history. Judge only whether Annotator A and Annotator B give the same answer to the question.

Use these labels:
- full_agreement: Both answers express the same final answer. Minor wording differences, extra explanation, or extra session IDs/dates are okay if the answer content is equivalent.
- partial_agreement: The answers overlap on at least one key fact, but one answer is incomplete, has an extra unsupported detail, or differs on a secondary part.
- disagreement: The answers materially disagree on the main answer, count, entity, location, date, ordering, or conclusion.
- absence_mismatch: One answer says the information is unavailable/not yet/not mentioned/unsolvable while the other gives a concrete answer.

Scores:
- full_agreement = 1.0
- partial_agreement = 0.5
- disagreement = 0.0
- absence_mismatch = 0.0

Be strict about numbers, named entities, time/order constraints, and yes/no conclusions. If both answers say the information is unavailable, this is full_agreement even if one answer is much longer.

Return JSON only with this schema:
{
"label": "full_agreement|partial_agreement|disagreement|absence_mismatch",
"score": 1.0,
"rationale": "one short sentence",
"mismatch": "empty string if no mismatch; otherwise short mismatch description"
}
"""
|
|
|
|
def load_json(path: str | Path) -> Any:
    """Read the UTF-8 text file at ``path`` and return its decoded JSON value."""
    with open(path, mode="r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
|
|
|
|
def append_jsonl(path: Path, row: Dict[str, Any]) -> None:
    """Append ``row`` to ``path`` as a single JSON line.

    Missing parent directories are created first. Non-ASCII characters are
    written as-is (UTF-8) rather than escaped.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(row, ensure_ascii=False)
    with path.open("a", encoding="utf-8") as sink:
        sink.write(f"{serialized}\n")
|
|
|
|
def read_existing(path: Path) -> Dict[str, Dict[str, Any]]:
    """Load previously judged rows from a JSONL file, keyed by ``question_id``.

    Returns an empty dict when ``path`` does not exist. Blank lines are
    ignored, and when a ``question_id`` occurs more than once the last
    occurrence wins.
    """
    if not path.exists():
        return {}
    with path.open("r", encoding="utf-8") as source:
        records = (json.loads(raw) for raw in source if raw.strip())
        return {record["question_id"]: record for record in records}
|
|
|
|
def parse_json(text: str) -> Dict[str, Any]:
    """Extract the JSON object from a model reply.

    The reply is stripped, a surrounding Markdown code fence (```` ``` ```` or
    ```` ```json ````) is removed, and the remainder is decoded as JSON. If
    that fails, the text is scanned for the first ``{`` from which a complete
    JSON object can be decoded, so replies with leading/trailing prose or
    stray braces are still understood.

    Args:
        text: Raw reply text.

    Returns:
        The decoded JSON object.

    Raises:
        json.JSONDecodeError: If no JSON object can be decoded from ``text``.
        ValueError: If the reply is valid JSON but not an object (for
            example a list or a bare string).
    """
    text = text.strip()
    if text.startswith("```"):
        text = re.sub(r"^```(?:json)?", "", text).strip()
        text = re.sub(r"```$", "", text).strip()
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Fall back to the first embedded object. raw_decode stops at the end
        # of the object it parsed, so trailing text (even text containing
        # braces) no longer breaks extraction.
        decoder = json.JSONDecoder()
        start = text.find("{")
        while start >= 0:
            try:
                return decoder.raw_decode(text, start)[0]
            except json.JSONDecodeError:
                start = text.find("{", start + 1)
        # Nothing decodable: re-raise the error from the whole-text attempt.
        raise
    if not isinstance(parsed, dict):
        raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
    return parsed
|
|
|
|
def build_user_prompt(row: Dict[str, Any]) -> str:
    """Render the user message for one answer pair.

    The v5 answer is presented as "Annotator A" and the human answer as
    "Annotator B". Each section is a heading line followed by its value, and
    sections are separated by a blank line.
    """
    sections = (
        ("Question", row["question"]),
        ("Question type", row["question_type"]),
        ("Annotator A answer", row["v5_answer"]),
        ("Annotator B answer", row["human_answer"]),
    )
    blocks = [f"{heading}:\n{value}" for heading, value in sections]
    blocks.append("Judge whether Annotator A and Annotator B agree semantically.")
    return "\n\n".join(blocks)
|
|
|
|
def judge_one(client: OpenAI, model: str, row: Dict[str, Any], max_retries: int = 5) -> Dict[str, Any]:
    """Ask the judge model to label one answer pair, retrying on any failure.

    Args:
        client: OpenAI-compatible client used for the chat completion call.
        model: Model identifier passed to the API.
        row: Answer-pair record; must contain the keys read by
            ``build_user_prompt`` and ``question_id`` (used in retry logs).
        max_retries: Maximum number of attempts; must be at least 1.

    Returns:
        A copy of ``row`` extended with ``judge_label``, ``judge_score``,
        ``judge_rationale``, ``judge_mismatch`` and ``judge_raw``.

    Raises:
        Exception: Whatever the final attempt raised (API error, unparseable
            reply, or ``ValueError`` for an unknown label).
        RuntimeError: If ``max_retries`` is less than 1, so no attempt is made.
    """
    prompt = build_user_prompt(row)
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": prompt},
                ],
                temperature=0,
                max_tokens=220,
            )
            content = response.choices[0].message.content or ""
            parsed = parse_json(content)
            if not isinstance(parsed, dict):
                raise ValueError(f"judge reply is not a JSON object; raw={content[:300]}")
            label = parsed.get("label")
            if isinstance(label, str):
                # Tolerate stray whitespace/capitalisation instead of burning a retry.
                label = label.strip().lower()
            if label not in LABEL_SCORES:
                raise ValueError(f"bad label: {label!r}; raw={content[:300]}")
            return {
                **row,
                "judge_label": label,
                # The score is fully determined by the label, so the model's own
                # "score" field is ignored: a null, non-numeric or NaN value can
                # neither trigger pointless retries nor leak into the results.
                "judge_score": LABEL_SCORES[label],
                "judge_rationale": str(parsed.get("rationale", "")),
                "judge_mismatch": str(parsed.get("mismatch", "")),
                "judge_raw": content,
            }
        except Exception as exc:
            # Deliberately broad: request failures and malformed replies are
            # all retried; the last attempt re-raises to the caller.
            if attempt == max_retries - 1:
                raise
            # Exponential backoff (1, 2, 4, ... seconds) capped at 30 seconds.
            wait = min(30, 2 ** attempt)
            print(f"[retry] {row['question_id']} attempt={attempt + 1}: {type(exc).__name__}: {exc}; sleep={wait}", flush=True)
            time.sleep(wait)
    # Only reached when the loop body never ran.
    raise RuntimeError(f"max_retries must be at least 1, got {max_retries}")
|
|
|
|
def summarize(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Aggregate judged rows into overall and per-group statistics.

    Each statistics entry holds the row count (``n``), the mean of
    ``judge_score`` and a count per ``judge_label``. Breakdowns are produced
    for the ``category``, ``annotator`` and ``question_type`` fields; rows
    lacking a field are grouped under the empty string. The overall mean is
    ``0.0`` when ``rows`` is empty.
    """

    def describe(items: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Statistics for one collection of rows.
        return {
            "n": len(items),
            "mean_score": sum(item["judge_score"] for item in items) / len(items) if items else 0.0,
            "label_counts": dict(Counter(item["judge_label"] for item in items)),
        }

    def breakdown(field: str) -> Dict[str, Dict[str, Any]]:
        # Bucket rows by the stringified field value, reported in sorted key order.
        buckets = defaultdict(list)
        for entry in rows:
            buckets[str(entry.get(field, ""))].append(entry)
        return {bucket: describe(buckets[bucket]) for bucket in sorted(buckets)}

    summary = describe(rows)
    summary["by_prior_category"] = breakdown("category")
    summary["by_annotator"] = breakdown("annotator")
    summary["by_question_type"] = breakdown("question_type")
    return summary
|
|
|
|
def main() -> None:
    """Judge every human/v5 answer pair and write a JSONL log plus a JSON report.

    Steps:
      1. Parse command-line options and read the API key from ``NV_API_KEY``.
      2. Join the agreement rows with the v5 dataset on ``question_id``.
      3. Judge each row whose ``question_id`` is not already in ``--out_jsonl``,
         appending every verdict to that file as soon as it is produced so an
         interrupted run can be resumed.
      4. Re-read the JSONL file and write the aggregated report to
         ``--out_report``.

    Raises:
        SystemExit: If ``NV_API_KEY`` is unset, or an agreement row references
            a ``question_id`` that is absent from the v5 file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--agreement_file", default="dataset/evolv_mem_v5_submitted_v3_96_agreement.json")
    parser.add_argument("--v5_file", default="dataset/evolv_mem_v5.json")
    parser.add_argument("--out_jsonl", default="dataset/evolv_mem_v5_submitted_v3_96_llm_judge_gpt41.jsonl")
    parser.add_argument("--out_report", default="dataset/evolv_mem_v5_submitted_v3_96_llm_judge_gpt41_report.json")
    parser.add_argument("--model", default="us/azure/openai/gpt-4.1")
    parser.add_argument("--base_url", default="https://inference-api.nvidia.com/v1")
    parser.add_argument("--limit", type=int, default=None)
    args = parser.parse_args()

    api_key = os.getenv("NV_API_KEY")
    if not api_key:
        raise SystemExit("NV_API_KEY is not set")

    agreement = load_json(args.agreement_file)
    v5_by_qid = {x["question_id"]: x for x in load_json(args.v5_file)}

    # Fail with a readable message instead of a bare KeyError during the join.
    missing = [row["question_id"] for row in agreement["rows"] if row["question_id"] not in v5_by_qid]
    if missing:
        raise SystemExit(
            f"{len(missing)} question_id(s) from {args.agreement_file} not found in {args.v5_file}: {missing[:10]}"
        )

    rows = [
        {
            "question_id": row["question_id"],
            # The question text comes from the v5 file; everything else from the agreement row.
            "question": v5_by_qid[row["question_id"]]["question"],
            "question_type": row["question_type"],
            "annotator": row["annotator"],
            "category": row["category"],
            "human_answer": row["human_answer"],
            "v5_answer": row["v5_answer"],
        }
        for row in agreement["rows"]
    ]
    if args.limit is not None:
        rows = rows[: args.limit]

    out_jsonl = Path(args.out_jsonl)
    existing = read_existing(out_jsonl)
    client = OpenAI(api_key=api_key, base_url=args.base_url)

    for idx, row in enumerate(rows, start=1):
        if row["question_id"] in existing:
            # Already judged in a previous run; reuse the stored verdict.
            continue
        print(f"[judge] {idx}/{len(rows)} {row['question_id']}", flush=True)
        append_jsonl(out_jsonl, judge_one(client, args.model, row))

    # Build the report from what is on disk, restricted to (and ordered by)
    # the rows selected for this run.
    completed_by_qid = read_existing(out_jsonl)
    report_rows = [completed_by_qid[row["question_id"]] for row in rows if row["question_id"] in completed_by_qid]
    report = {
        "created_at_utc": datetime.now(timezone.utc).isoformat(),
        "model": args.model,
        "base_url": args.base_url,
        "agreement_file": args.agreement_file,
        "v5_file": args.v5_file,
        # Single source of truth for the label -> score mapping.
        "judge_scheme": dict(LABEL_SCORES),
        "summary": summarize(report_rows),
        "rows": report_rows,
    }
    out_report = Path(args.out_report)
    out_report.parent.mkdir(parents=True, exist_ok=True)
    out_report.write_text(json.dumps(report, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
    print(json.dumps(report["summary"], indent=2), flush=True)
    print(f"[wrote] {out_jsonl}", flush=True)
    print(f"[wrote] {out_report}", flush=True)
|
|
|
|
# Run the judging pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|