| """Judge stack: batch LLM evaluation. |
| |
| Session: 2 calls (recall + correctness) + per-item calls for update/interference. |
| QA: 2 calls (answer + evidence). |
| """ |
|
|
| from __future__ import annotations |
|
|
| from eval_framework.judges.llm_client import llm_request_for_json |
| from eval_framework.judges.prompts import ( |
| CORRECTNESS_BATCH_PROMPT, |
| EVIDENCE_BATCH_PROMPT, |
| INTERFERENCE_EVAL_PROMPT, |
| QA_EVALUATION_PROMPT, |
| RECALL_BATCH_PROMPT, |
| UPDATE_EVAL_PROMPT, |
| ) |
|
|
| __all__ = [ |
| "evaluate_recall_batch", |
| "evaluate_correctness_batch", |
| "evaluate_update_single", |
| "evaluate_interference_single", |
| "evaluate_evidence_batch", |
| "evaluate_qa_llm", |
| "llm_request_for_json", |
| ] |
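# How the calls compose (illustrative sketch; callers wire this up elsewhere):
#   session eval -> one evaluate_recall_batch call + one evaluate_correctness_batch
#                   call, plus evaluate_update_single / evaluate_interference_single
#                   once per update/interference item
#   QA eval      -> one evaluate_qa_llm call + one evaluate_evidence_batch call
# Every function returns a plain dict; a None score or label signals a judge-side
# error, with the exception text carried in "reasoning".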


def evaluate_recall_batch(
    extracted_memories_str: str,
    gold_points_tagged: list[str],
) -> dict[str, object]:
    """One LLM call: how many gold points are covered? Distinguishes update sub-score.

    gold_points_tagged: list of "[normal] content" or "[update] content" strings.
    Returns {covered_count, update_covered_count, total, update_total, reasoning}.
    """
    if not extracted_memories_str.strip():
        update_total = sum(1 for p in gold_points_tagged if p.startswith("[update]"))
        return {
            "covered_count": 0, "update_covered_count": 0,
            "total": len(gold_points_tagged), "update_total": update_total,
            "reasoning": "No extracted memories.",
        }
    if not gold_points_tagged:
        return {
            "covered_count": 0, "update_covered_count": 0,
            "total": 0, "update_total": 0, "reasoning": "No gold points.",
        }

    numbered = "\n".join(f"[{i+1}] {p}" for i, p in enumerate(gold_points_tagged))
    update_total = sum(1 for p in gold_points_tagged if p.startswith("[update]"))

    prompt = RECALL_BATCH_PROMPT.format(memories=extracted_memories_str, gold_points=numbered)
    try:
        result = llm_request_for_json(prompt)
        covered = int(result.get("covered_count", 0))
        upd_covered = int(result.get("update_covered_count", 0))
        return {
            "covered_count": min(covered, len(gold_points_tagged)),
            "update_covered_count": min(upd_covered, update_total),
            "total": len(gold_points_tagged),
            "update_total": update_total,
            "reasoning": result.get("reasoning", ""),
        }
    except Exception as e:
        return {
            "covered_count": None, "update_covered_count": None,
            "total": len(gold_points_tagged), "update_total": update_total,
            "reasoning": f"LLM error: {e}",
        }
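# Usage sketch for evaluate_recall_batch (inputs are made-up examples; the real
# counts come from the judge model, so the numbers below are only illustrative):
#
#   evaluate_recall_batch(
#       "- user prefers green tea\n- user moved to Berlin",
#       ["[normal] prefers green tea", "[update] moved to Berlin", "[normal] has a dog"],
#   )
#   # -> {"covered_count": 2, "update_covered_count": 1, "total": 3,
#   #     "update_total": 1, "reasoning": "..."}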


def evaluate_correctness_batch(
    snapshot_memories: list[str],
    gold_points_tagged: list[str],
    interference_total: int,
) -> dict[str, object]:
    """One LLM call: is each snapshot memory correct? Includes interference detection.

    gold_points_tagged: list of "[normal] content", "[update] content", "[interference] content".
    Returns {results: [{id, label}], interference_memorized_count, interference_total, reasoning}.
    """
    if not snapshot_memories:
        return {
            "results": [],
            "interference_memorized_count": 0,
            "interference_total": interference_total,
            "reasoning": "No snapshot memories.",
        }

    numbered_memories = "\n".join(f"[{i+1}] {m}" for i, m in enumerate(snapshot_memories))
    numbered_golds = "\n".join(f"- {p}" for p in gold_points_tagged) if gold_points_tagged else "(no ground-truth)"

    prompt = CORRECTNESS_BATCH_PROMPT.format(memories=numbered_memories, gold_points=numbered_golds)
    try:
        result = llm_request_for_json(prompt)
        raw_results = result.get("results", [])
        valid_labels = {"correct", "hallucination", "irrelevant"}
        cleaned = []
        for r in raw_results:
            label = str(r.get("label", "irrelevant")).lower().strip()
            if label not in valid_labels:
                label = "irrelevant"
            cleaned.append({"id": r.get("id"), "label": label})
        interf_mem = int(result.get("interference_memorized_count", 0))
        return {
            "results": cleaned,
            "interference_memorized_count": min(interf_mem, interference_total),
            "interference_total": interference_total,
            "reasoning": result.get("reasoning", ""),
        }
    except Exception as e:
        return {
            "results": [],
            "interference_memorized_count": None,
            "interference_total": interference_total,
            "reasoning": f"LLM error: {e}",
        }
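# Usage sketch for evaluate_correctness_batch (illustrative; labels and counts
# are assigned by the judge model, so the output shows only the expected shape):
#
#   evaluate_correctness_batch(
#       snapshot_memories=["user prefers green tea", "user moved to Berlin"],
#       gold_points_tagged=["[normal] prefers green tea", "[update] moved to Berlin",
#                           "[interference] owns a yacht"],
#       interference_total=1,
#   )
#   # -> {"results": [{"id": 1, "label": "correct"}, {"id": 2, "label": "correct"}],
#   #     "interference_memorized_count": 0, "interference_total": 1, "reasoning": "..."}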


def evaluate_update_single(
    delta_memories_str: str,
    new_content: str,
    old_contents: list[str],
) -> dict[str, object]:
    """One LLM call: how did the system handle a single memory update?

    Returns {label: "updated"|"both"|"outdated", reasoning}.
    """
    old_str = "\n".join(f"- {o}" for o in old_contents) if old_contents else "(none)"
    prompt = UPDATE_EVAL_PROMPT.format(
        memories=delta_memories_str,
        new_content=new_content,
        old_contents=old_str,
    )
    try:
        result = llm_request_for_json(prompt)
        label = str(result.get("label", "outdated")).lower().strip()
        if label not in ("updated", "both", "outdated"):
            label = "outdated"
        return {"label": label, "reasoning": result.get("reasoning", "")}
    except Exception as e:
        return {"label": None, "reasoning": f"LLM error: {e}"}


def evaluate_interference_single(
    delta_memories_str: str,
    interference_content: str,
) -> dict[str, object]:
    """One LLM call: did the system incorrectly memorize an interference point?

    Returns {label: "rejected"|"memorized", reasoning}.
    """
    prompt = INTERFERENCE_EVAL_PROMPT.format(
        memories=delta_memories_str,
        interference_content=interference_content,
    )
    try:
        result = llm_request_for_json(prompt)
        label = str(result.get("label", "memorized")).lower().strip()
        if label not in ("rejected", "memorized"):
            label = "memorized"
        return {"label": label, "reasoning": result.get("reasoning", "")}
    except Exception as e:
        return {"label": None, "reasoning": f"LLM error: {e}"}


def evaluate_evidence_batch(
    retrieved_memories_str: str,
    evidence_points: list[str],
) -> dict[str, object]:
    """One LLM call: how many gold evidence points are covered by retrieval?"""
    if not retrieved_memories_str.strip():
        return {"covered_count": 0, "total": len(evidence_points), "reasoning": "No retrieved memories."}
    if not evidence_points:
        return {"covered_count": 0, "total": 0, "reasoning": "No evidence points."}

    numbered = "\n".join(f"[{i+1}] {p}" for i, p in enumerate(evidence_points))
    prompt = EVIDENCE_BATCH_PROMPT.format(retrieved_memories=retrieved_memories_str, gold_evidence_points=numbered)
    try:
        result = llm_request_for_json(prompt)
        covered = int(result.get("covered_count", 0))
        return {
            "covered_count": min(covered, len(evidence_points)),
            "total": len(evidence_points),
            "reasoning": result.get("reasoning", ""),
        }
    except Exception as e:
        return {"covered_count": None, "total": len(evidence_points), "reasoning": f"LLM error: {e}"}


def evaluate_qa_llm(
    question: str,
    reference_answer: str,
    key_memory_points: str,
    system_response: str,
) -> dict[str, object]:
    """LLM judge: classify the QA response as Correct/Hallucination/Omission."""
    if not system_response.strip():
        return {"evaluation_result": "Omission", "reasoning": "Empty system response."}

    prompt = QA_EVALUATION_PROMPT.format(
        question=question, reference_answer=reference_answer,
        key_memory_points=key_memory_points, response=system_response,
    )
    try:
        result = llm_request_for_json(prompt)
        # Normalize case so e.g. "correct" from the judge still validates, mirroring
        # the lowercase normalization used by the other judges in this module.
        label = str(result.get("evaluation_result", "Omission")).strip().capitalize()
        if label not in ("Correct", "Hallucination", "Omission"):
            label = "Omission"
        return {"evaluation_result": label, "reasoning": result.get("reasoning", "")}
    except Exception as e:
        return {"evaluation_result": None, "reasoning": f"LLM judge error: {e}"}