# PGC-AI-Chatbot / scripts/run_retrieval_benchmark.py
# Jacooo's picture
# Deploy from GitHub: e13a058
# d0dada0 verified
from __future__ import annotations
import asyncio
from contextlib import redirect_stdout
import io
import json
from pathlib import Path
import re as _re
import sys
import time
from typing import Any, Awaitable, Callable, TextIO
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from app.local_plant_db import get_plant_parameters
from app.retrieval_eval import (
BenchmarkResult,
build_report_summary,
load_golden_retrieval_cases,
score_benchmark_groups,
score_benchmark_run,
)
from app.retrieval_pipeline import build_dual_retrieval_queries, build_retrieval_query, plan_plant_retrieval
from app.vector_store import (
VERIFIED_DENSE_THRESHOLD,
merge_knowledge_results,
search_knowledge,
search_knowledge_hybrid,
)
# Async LLM hook passed to the planner and query builders. `_empty_llm` below
# shows the positional order: (prompt, json_mode, temperature) -> str.
LlmCall = Callable[[str, bool, float], Awaitable[str]]
# Zero-argument clock. Two readings are subtracted and multiplied by 1000 to
# get milliseconds, so it must return seconds (default: time.perf_counter).
Timer = Callable[[], float]
# Stored verbatim under the report's "benchmark_mode" key.
BENCHMARK_MODE = "dual_score_fixture_benchmark"
# Human-readable scoring caveats; stored under the report's "notes" key and
# printed to stderr by emit_report.
BENCHMARK_NOTES = (
"Structured cases score rank=1 via planner/query-builder path (LLM stub). "
"Vector-RAG cases report two ranks: strict (exact source+filename+page) and "
"semantic (any chunk from the right document containing expected keywords). "
"TOC/figure-index chunks are excluded from rank computation."
)
# Patterns used by _is_toc_chunk.
# Matches "figure" + whitespace + a digit, case-insensitively (e.g. "Figure 12").
_FIGURE_PATTERN = _re.compile(r"figure\s+\d", _re.IGNORECASE)
# Matches a dot leader: four or more consecutive dots, or a whitespace character
# followed by three or more "dot + whitespace" pairs (e.g. " . . . ").
_DOT_LEADER_PATTERN = _re.compile(r"\.{4,}|\s(?:\.\s){3,}")
async def _empty_llm(
    prompt: str,
    json_mode: bool = False,
    temperature: float = 0.0,
) -> str:
    """Offline LLM stub so the benchmark never calls a real model.

    Free-text requests (``json_mode=False``, e.g. HyDE) get an empty string,
    which forces the raw-query fallback. JSON requests get a parseable entity
    payload whose fields are all null. ``prompt`` and ``temperature`` are
    accepted only to satisfy the ``LlmCall`` signature and are ignored.
    """
    if not json_mode:
        return ""
    # Entity-extraction stub: every field present, nothing recognised.
    blank_entities = dict.fromkeys(("plant", "variety", "stage"))
    return json.dumps(blank_entities)
def _chunk_matches_case(chunk: dict[str, Any], case: dict[str, Any]) -> bool:
    """Strict match: the chunk is the exact document (and page) the case expects.

    A case without both ``expected_source`` and ``expected_filename`` can never
    match. ``expected_page`` is optional; when it is None any page of the
    expected document counts.
    """
    wanted_source = case.get("expected_source")
    wanted_filename = case.get("expected_filename")
    if wanted_source is None or wanted_filename is None:
        return False

    if chunk.get("source") != wanted_source:
        return False
    if chunk.get("filename") != wanted_filename:
        return False

    wanted_page = case.get("expected_page")
    if wanted_page is None:
        return True
    return chunk.get("page_number") == wanted_page
def _chunk_matches_semantic(chunk: dict[str, Any], case: dict[str, Any]) -> bool:
    """Loose match: right document, and the text mentions an expected keyword.

    The page number is ignored. The keyword test is a case-insensitive
    substring search over the chunk's ``content``. Cases that lack a source,
    a filename, or any ``expected_content_keywords`` never match.
    """
    wanted_source = case.get("expected_source")
    wanted_filename = case.get("expected_filename")
    needles: list[str] = case.get("expected_content_keywords") or []
    if not (wanted_source and wanted_filename and needles):
        return False

    from_expected_document = (
        chunk.get("source") == wanted_source
        and chunk.get("filename") == wanted_filename
    )
    if not from_expected_document:
        return False

    haystack = (chunk.get("content") or "").lower()
    for needle in needles:
        if needle.lower() in haystack:
            return True
    return False
def _is_toc_chunk(chunk: dict[str, Any]) -> bool:
    """Heuristically detect table-of-contents / figure-index chunks.

    A chunk is flagged when its content has three or more dot-leader runs, or
    when it has three or more "Figure N" references together with at least one
    dot-leader run. Such pages match keyword queries but contain no real
    explanatory content, so callers drop them before computing ranks.
    """
    text = chunk.get("content") or ""
    leader_count = sum(1 for _ in _DOT_LEADER_PATTERN.finditer(text))
    if leader_count >= 3:
        return True
    if leader_count == 0:
        # Figure references alone are not enough; body text cites figures too.
        return False
    figure_count = sum(1 for _ in _FIGURE_PATTERN.finditer(text))
    return figure_count >= 3
async def _evaluate_vector_case(
    case: dict[str, Any],
    query: str,
    retrieval_queries: list[Any],
    retrieval_mode: str = "dense",
) -> tuple[dict[str, Any], BenchmarkResult]:
    """Run vector retrieval for one golden case and score the returned chunks.

    Args:
        case: Golden case; must contain "case_id" and "query". The optional
            "expected_*" keys drive the strict and semantic matchers.
        query: Raw user query, passed to the hybrid search as ``raw_query``.
        retrieval_queries: Query objects exposing ``.text`` and ``.label``.
        retrieval_mode: "hybrid" uses ``search_knowledge_hybrid``; any other
            value takes the dense path (one ``search_knowledge`` call per
            retrieval query, merged with ``merge_knowledge_results``).

    Returns:
        ``(case_report, result)``. Both carry ``latency_ms=0``; the caller is
        expected to overwrite it with the measured latency.
    """
    if retrieval_mode == "hybrid":
        chunks = await search_knowledge_hybrid(
            raw_query=query,
            dense_queries=retrieval_queries,
        )
    else:
        # Awaited one at a time, in order, exactly like the original loop.
        dense_lists = [
            await search_knowledge(query=rq.text, query_label=rq.label)
            for rq in retrieval_queries
        ]
        chunks = merge_knowledge_results(dense_lists)

    # Ranks are positions in the list *after* TOC/figure-index chunks are removed.
    filtered_chunks = [c for c in chunks if not _is_toc_chunk(c)]

    matched_rank, matched_chunk = None, None
    semantic_rank = None
    for index, chunk in enumerate(filtered_chunks, start=1):
        if matched_rank is None and _chunk_matches_case(chunk, case):
            matched_rank = index
            matched_chunk = chunk
        if semantic_rank is None and _chunk_matches_semantic(chunk, case):
            semantic_rank = index
        if matched_rank is not None and semantic_rank is not None:
            # Only the first hit of each kind matters; stop scanning.
            break

    # `.get("similarity", 0)` would still return None for a present-but-null
    # key and make the comparison raise TypeError; treat missing and null
    # alike as 0.
    # NOTE(review): presumably hybrid hits can lack a dense similarity — confirm.
    verified_hit = False
    if matched_chunk:
        similarity = matched_chunk.get("similarity") or 0
        verified_hit = bool(similarity >= VERIFIED_DENSE_THRESHOLD)

    query_label = "+".join(rq.label for rq in retrieval_queries) if retrieval_queries else "raw"
    result = BenchmarkResult(
        case_id=str(case["case_id"]),
        rank=matched_rank,
        semantic_rank=semantic_rank,
        expected_found=bool(case.get("expected_found")),
        verified_hit=verified_hit,
        latency_ms=0,
    )
    case_report = {
        "case_id": case["case_id"],
        "case_group": case.get("case_group"),
        "expected_mode": case.get("expected_mode"),
        "query": case["query"],
        "retrieval_query_label": query_label,
        "used_structured_params": False,
        "used_vector_rag": True,
        "matched_source": matched_chunk.get("source") if matched_chunk else None,
        "matched_filename": matched_chunk.get("filename") if matched_chunk else None,
        "matched_page": matched_chunk.get("page_number") if matched_chunk else None,
        "rank": result.rank,
        "semantic_rank": result.semantic_rank,
        "verified_hit": result.verified_hit,
        "latency_ms": result.latency_ms,
    }
    return case_report, result
async def _evaluate_case(
    case: dict[str, Any],
    llm_call: LlmCall,
    timer: Timer,
    retrieval_mode: str = "dense",
) -> tuple[dict[str, Any], BenchmarkResult]:
    """Evaluate one golden case and time it with ``timer``.

    Cases whose ``expected_mode`` is "vector_rag" are handed to
    ``_evaluate_vector_case``. Every other case goes through the planner and
    query-builder path and scores rank 1 when the planned plant (and stage, if
    the case specifies one) matches and structured parameters were found.

    Returns:
        ``(case_report, benchmark_result)``, both carrying the measured
        latency in whole milliseconds, clamped to be non-negative.
    """
    query = str(case["query"])
    query_type = str(case.get("query_type", "plant_specific"))
    expected_mode = case.get("expected_mode")
    started_at = timer()

    def elapsed_ms() -> int:
        # Clamp at zero in case the injected timer is not monotonic.
        return max(int(round((timer() - started_at) * 1000)), 0)

    if expected_mode == "vector_rag":
        dual_queries = await build_dual_retrieval_queries(
            query=query,
            query_type=query_type,
            needs_vector_rag=True,
            llm_call=llm_call,
        )
        vector_report, untimed = await _evaluate_vector_case(
            case=case,
            query=query,
            retrieval_queries=dual_queries,
            retrieval_mode=retrieval_mode,
        )
        latency_ms = elapsed_ms()
        vector_report["latency_ms"] = latency_ms
        # Rebuild the result so it carries the latency measured here.
        timed = BenchmarkResult(
            case_id=untimed.case_id,
            rank=untimed.rank,
            semantic_rank=untimed.semantic_rank,
            expected_found=untimed.expected_found,
            verified_hit=untimed.verified_hit,
            latency_ms=latency_ms,
        )
        return vector_report, timed

    expected_plant = case.get("expected_plant")
    expected_stage = case.get("expected_stage")
    expected_found = bool(case.get("expected_found"))

    plan = await plan_plant_retrieval(
        query=query,
        llm_call=llm_call,
    )
    retrieval_query = await build_retrieval_query(
        query=query,
        query_type=query_type,
        needs_vector_rag=plan.use_vector_rag,
        llm_call=llm_call,
    )

    structured_params = None
    if plan.plant_name and plan.use_structured_params:
        structured_params = get_plant_parameters(plan.plant_name, plan.stage)

    latency_ms = elapsed_ms()

    plant_ok = plan.plant_name == expected_plant
    stage_ok = expected_stage is None or plan.stage == expected_stage
    is_hit = bool(expected_found and structured_params and plant_ok and stage_ok)

    benchmark_result = BenchmarkResult(
        case_id=str(case["case_id"]),
        rank=1 if is_hit else None,
        semantic_rank=None,
        expected_found=expected_found,
        verified_hit=bool(is_hit and structured_params.get("from_local_db")),
        latency_ms=latency_ms,
    )
    case_report = {
        "case_id": case["case_id"],
        "case_group": case.get("case_group"),
        "expected_mode": expected_mode,
        "query": query,
        "expected_plant": expected_plant,
        "expected_stage": expected_stage,
        "expected_found": expected_found,
        "actual_plant": plan.plant_name,
        "actual_stage": plan.stage,
        "retrieval_query_label": retrieval_query.label,
        "used_structured_params": plan.use_structured_params,
        "used_vector_rag": plan.use_vector_rag,
        "rank": benchmark_result.rank,
        "semantic_rank": benchmark_result.semantic_rank,
        "verified_hit": benchmark_result.verified_hit,
        "latency_ms": benchmark_result.latency_ms,
    }
    return case_report, benchmark_result
async def run_benchmark(
    cases: list[dict[str, Any]],
    llm_call: LlmCall = _empty_llm,
    timer: Timer = time.perf_counter,
    retrieval_mode: str = "dense",
) -> dict[str, Any]:
    """Evaluate every case in order and assemble the benchmark report.

    Cases are awaited one at a time. A case without "case_group" is scored
    under the "structured_routing" group.

    Returns:
        A dict with the benchmark mode, retrieval mode, notes, summary,
        overall metrics, per-group metrics, and the per-case reports.
    """
    group_by_case_id: dict[str, str] = {}
    reports: list[dict[str, Any]] = []
    results: list[BenchmarkResult] = []

    for case in cases:
        case_id = str(case["case_id"])
        group_by_case_id[case_id] = str(case.get("case_group", "structured_routing"))
        report, result = await _evaluate_case(
            case,
            llm_call=llm_call,
            timer=timer,
            retrieval_mode=retrieval_mode,
        )
        reports.append(report)
        results.append(result)

    overall_metrics = score_benchmark_run(results)
    grouped_metrics = score_benchmark_groups(results, group_by_case_id)
    summary = build_report_summary(overall_metrics)

    return {
        "benchmark_mode": BENCHMARK_MODE,
        "retrieval_mode": retrieval_mode,
        "notes": BENCHMARK_NOTES,
        "summary": summary,
        "metrics": overall_metrics,
        "metrics_by_group": grouped_metrics,
        "cases": reports,
    }
def emit_report(
    report: dict[str, Any],
    stdout: TextIO | None = None,
    stderr: TextIO | None = None,
) -> None:
    """Write the human-readable lines to stderr and the JSON report to stdout.

    Args:
        report: Benchmark report; must contain "summary" and "notes" and be
            JSON-serialisable.
        stdout: Stream for the indented JSON document. Defaults to the
            ``sys.stdout`` in effect when the function is called.
        stderr: Stream for the summary and notes lines. Defaults to the
            ``sys.stderr`` in effect when the function is called.
    """
    # Resolve defaults at call time: a default of `sys.stdout` in the signature
    # is bound once at definition time and would ignore later redirection
    # (contextlib.redirect_stdout, test capture, ...).
    out = sys.stdout if stdout is None else stdout
    err = sys.stderr if stderr is None else stderr
    print(report["summary"], file=err)
    print(report["notes"], file=err)
    print(json.dumps(report, indent=2), file=out)
def main(argv: list[str] | None = None) -> int:
    """CLI entry point: run the golden-case benchmark and print the report.

    Anything printed to stdout while the benchmark runs is captured and
    replayed on stderr, so stdout carries only the JSON report.

    Args:
        argv: Command-line arguments without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        Process exit status; 0 on success.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", choices=["dense", "hybrid"], default="dense",
                        help="Retrieval mode: dense (default) or hybrid")
    args = parser.parse_args(argv)

    cases = load_golden_retrieval_cases()
    benchmark_logs = io.StringIO()
    try:
        with redirect_stdout(benchmark_logs):
            report = asyncio.run(run_benchmark(cases, retrieval_mode=args.mode))
    finally:
        # Replay captured output even when the benchmark raises; otherwise
        # the logs leading up to the failure would be lost.
        log_output = benchmark_logs.getvalue()
        if log_output:
            print(log_output, file=sys.stderr, end="")
    emit_report(report)
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())