"""Judge stack: batch LLM evaluation. Session: 2 calls (recall + correctness) + per-item calls for update/interference. QA: 2 calls (answer + evidence). """ from __future__ import annotations from eval_framework.judges.llm_client import llm_request_for_json from eval_framework.judges.prompts import ( CORRECTNESS_BATCH_PROMPT, EVIDENCE_BATCH_PROMPT, INTERFERENCE_EVAL_PROMPT, QA_EVALUATION_PROMPT, RECALL_BATCH_PROMPT, UPDATE_EVAL_PROMPT, ) __all__ = [ "evaluate_recall_batch", "evaluate_correctness_batch", "evaluate_update_single", "evaluate_interference_single", "evaluate_evidence_batch", "evaluate_qa_llm", "llm_request_for_json", ] def evaluate_recall_batch( extracted_memories_str: str, gold_points_tagged: list[str], ) -> dict[str, object]: """One LLM call: how many gold points are covered? Distinguishes update sub-score. gold_points_tagged: list of "[normal] content" or "[update] content" strings. Returns {covered_count, update_covered_count, total, update_total, reasoning}. """ if not extracted_memories_str.strip(): update_total = sum(1 for p in gold_points_tagged if p.startswith("[update]")) return { "covered_count": 0, "update_covered_count": 0, "total": len(gold_points_tagged), "update_total": update_total, "reasoning": "No extracted memories.", } if not gold_points_tagged: return { "covered_count": 0, "update_covered_count": 0, "total": 0, "update_total": 0, "reasoning": "No gold points.", } numbered = "\n".join(f"[{i+1}] {p}" for i, p in enumerate(gold_points_tagged)) update_total = sum(1 for p in gold_points_tagged if p.startswith("[update]")) prompt = RECALL_BATCH_PROMPT.format(memories=extracted_memories_str, gold_points=numbered) try: result = llm_request_for_json(prompt) covered = int(result.get("covered_count", 0)) upd_covered = int(result.get("update_covered_count", 0)) return { "covered_count": min(covered, len(gold_points_tagged)), "update_covered_count": min(upd_covered, update_total), "total": len(gold_points_tagged), "update_total": update_total, "reasoning": result.get("reasoning", ""), } except Exception as e: return { "covered_count": None, "update_covered_count": None, "total": len(gold_points_tagged), "update_total": update_total, "reasoning": f"LLM error: {e}", } def evaluate_correctness_batch( snapshot_memories: list[str], gold_points_tagged: list[str], interference_total: int, ) -> dict[str, object]: """One LLM call: is each snapshot memory correct? Includes interference detection. gold_points_tagged: list of "[normal] content", "[update] content", "[interference] content". Returns {results: [{id, label}], interference_memorized_count, interference_total, reasoning}. 
""" if not snapshot_memories: return { "results": [], "interference_memorized_count": 0, "interference_total": interference_total, "reasoning": "No snapshot memories.", } numbered_memories = "\n".join(f"[{i+1}] {m}" for i, m in enumerate(snapshot_memories)) numbered_golds = "\n".join(f"- {p}" for p in gold_points_tagged) if gold_points_tagged else "(no ground-truth)" prompt = CORRECTNESS_BATCH_PROMPT.format(memories=numbered_memories, gold_points=numbered_golds) try: result = llm_request_for_json(prompt) raw_results = result.get("results", []) valid_labels = {"correct", "hallucination", "irrelevant"} cleaned = [] for r in raw_results: label = str(r.get("label", "irrelevant")).lower().strip() if label not in valid_labels: label = "irrelevant" cleaned.append({"id": r.get("id"), "label": label}) interf_mem = int(result.get("interference_memorized_count", 0)) return { "results": cleaned, "interference_memorized_count": min(interf_mem, interference_total), "interference_total": interference_total, "reasoning": result.get("reasoning", ""), } except Exception as e: return { "results": [], "interference_memorized_count": None, "interference_total": interference_total, "reasoning": f"LLM error: {e}", } def evaluate_update_single( delta_memories_str: str, new_content: str, old_contents: list[str], ) -> dict[str, object]: """One LLM call: how did the system handle a single memory update? Returns {label: "updated"|"both"|"outdated", reasoning}. """ old_str = "\n".join(f"- {o}" for o in old_contents) if old_contents else "(none)" prompt = UPDATE_EVAL_PROMPT.format( memories=delta_memories_str, new_content=new_content, old_contents=old_str, ) try: result = llm_request_for_json(prompt) label = str(result.get("label", "outdated")).lower().strip() if label not in ("updated", "both", "outdated"): label = "outdated" return {"label": label, "reasoning": result.get("reasoning", "")} except Exception as e: return {"label": None, "reasoning": f"LLM error: {e}"} def evaluate_interference_single( delta_memories_str: str, interference_content: str, ) -> dict[str, object]: """One LLM call: did the system incorrectly memorize an interference point? Returns {label: "rejected"|"memorized", reasoning}. 
""" prompt = INTERFERENCE_EVAL_PROMPT.format( memories=delta_memories_str, interference_content=interference_content, ) try: result = llm_request_for_json(prompt) label = str(result.get("label", "memorized")).lower().strip() if label not in ("rejected", "memorized"): label = "memorized" return {"label": label, "reasoning": result.get("reasoning", "")} except Exception as e: return {"label": None, "reasoning": f"LLM error: {e}"} def evaluate_evidence_batch( retrieved_memories_str: str, evidence_points: list[str], ) -> dict[str, object]: """One LLM call: how many gold evidence points are covered by retrieval?""" if not retrieved_memories_str.strip(): return {"covered_count": 0, "total": len(evidence_points), "reasoning": "No retrieved memories."} if not evidence_points: return {"covered_count": 0, "total": 0, "reasoning": "No evidence points."} numbered = "\n".join(f"[{i+1}] {p}" for i, p in enumerate(evidence_points)) prompt = EVIDENCE_BATCH_PROMPT.format(retrieved_memories=retrieved_memories_str, gold_evidence_points=numbered) try: result = llm_request_for_json(prompt) covered = int(result.get("covered_count", 0)) return { "covered_count": min(covered, len(evidence_points)), "total": len(evidence_points), "reasoning": result.get("reasoning", ""), } except Exception as e: return {"covered_count": None, "total": len(evidence_points), "reasoning": f"LLM error: {e}"} def evaluate_qa_llm( question: str, reference_answer: str, key_memory_points: str, system_response: str, ) -> dict[str, object]: """LLM judge: classify the QA response as Correct/Hallucination/Omission.""" if not system_response.strip(): return {"evaluation_result": "Omission", "reasoning": "Empty system response."} prompt = QA_EVALUATION_PROMPT.format( question=question, reference_answer=reference_answer, key_memory_points=key_memory_points, response=system_response, ) try: result = llm_request_for_json(prompt) label = result.get("evaluation_result", "Omission") if label not in ("Correct", "Hallucination", "Omission"): label = "Omission" return {"evaluation_result": label, "reasoning": result.get("reasoning", "")} except Exception as e: return {"evaluation_result": None, "reasoning": f"LLM judge error: {e}"}