Spaces:
Running
Running
| from __future__ import annotations | |
| import asyncio | |
| from contextlib import redirect_stdout | |
| import io | |
| import json | |
| from pathlib import Path | |
| import re as _re | |
| import sys | |
| import time | |
| from typing import Any, Awaitable, Callable, TextIO | |
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) | |
| from app.local_plant_db import get_plant_parameters | |
| from app.retrieval_eval import ( | |
| BenchmarkResult, | |
| build_report_summary, | |
| load_golden_retrieval_cases, | |
| score_benchmark_groups, | |
| score_benchmark_run, | |
| ) | |
| from app.retrieval_pipeline import build_dual_retrieval_queries, build_retrieval_query, plan_plant_retrieval | |
| from app.vector_store import ( | |
| VERIFIED_DENSE_THRESHOLD, | |
| merge_knowledge_results, | |
| search_knowledge, | |
| search_knowledge_hybrid, | |
| ) | |
# Async LLM hook injected into the benchmark: takes (prompt, json_mode,
# temperature) and resolves to the model's text output.
LlmCall = Callable[[str, bool, float], Awaitable[str]]

# Zero-argument clock. Differences between two readings are multiplied by 1000
# to obtain milliseconds, so readings are expected to be in seconds.
Timer = Callable[[], float]

# Value reported under the "benchmark_mode" key of the final report.
BENCHMARK_MODE = "dual_score_fixture_benchmark"

# Human-readable explanation of how ranks are scored; emitted with the report.
BENCHMARK_NOTES = (
    "Structured cases score rank=1 via planner/query-builder path (LLM stub). "
    "Vector-RAG cases report two ranks: strict (exact source+filename+page) and "
    "semantic (any chunk from the right document containing expected keywords). "
    "TOC/figure-index chunks are excluded from rank computation."
)

# Case-insensitive "figure <digit>" reference, e.g. "Figure 12".
_FIGURE_PATTERN = _re.compile(r"figure\s+\d", flags=_re.IGNORECASE)

# Dot leaders as found in tables of contents: either a run of four or more
# dots ("....") or whitespace followed by three or more spaced dots (" . . . ").
_DOT_LEADER_PATTERN = _re.compile(r"\.{4,}|\s(?:\.\s){3,}")
async def _empty_llm(
    prompt: str,
    json_mode: bool = False,
    temperature: float = 0.0,
) -> str:
    """Offline stand-in for the LLM hook; never contacts a model.

    ``prompt`` and ``temperature`` exist only to satisfy the ``LlmCall``
    signature and are ignored.

    Returns:
        When ``json_mode`` is true, a parseable JSON object whose ``plant``,
        ``variety`` and ``stage`` fields are all null (entity-extraction stub).
        Otherwise an empty string (HyDE / free-text stub, intended to make the
        caller fall back to the raw query).
    """
    if not json_mode:
        return ""
    # dict.fromkeys keeps insertion order and fills every value with None.
    return json.dumps(dict.fromkeys(("plant", "variety", "stage")))
def _chunk_matches_case(chunk: dict[str, Any], case: dict[str, Any]) -> bool:
    """Strict match: is ``chunk`` exactly the location the golden case expects?

    The case must define both ``expected_source`` and ``expected_filename``;
    a case missing either can never match. ``expected_page`` is optional: when
    it is absent (or None) any page of the right source/filename is accepted,
    otherwise the chunk's ``page_number`` must equal it.
    """
    want_source = case.get("expected_source")
    want_filename = case.get("expected_filename")
    if want_source is None or want_filename is None:
        return False

    if chunk.get("source") != want_source:
        return False
    if chunk.get("filename") != want_filename:
        return False

    want_page = case.get("expected_page")
    if want_page is None:
        return True
    return chunk.get("page_number") == want_page
def _chunk_matches_semantic(chunk: dict[str, Any], case: dict[str, Any]) -> bool:
    """Semantic match: right document and at least one expected keyword present.

    Returns True when ``chunk`` has the case's ``expected_source`` and
    ``expected_filename`` and its ``content`` contains, case-insensitively, at
    least one entry of ``expected_content_keywords``. The page number is not
    considered.

    Returns False when the case lacks a source, a filename, or any usable
    keyword. Blank keywords (empty or whitespace-only) are ignored.
    """
    expected_source = case.get("expected_source")
    expected_filename = case.get("expected_filename")
    if not expected_source or not expected_filename:
        return False

    # "" is a substring of every string, so a blank keyword would turn every
    # chunk from the right document into a hit; drop such entries up front.
    keywords: list[str] = [
        kw.lower()
        for kw in (case.get("expected_content_keywords") or [])
        if kw and kw.strip()
    ]
    if not keywords:
        return False

    if chunk.get("source") != expected_source:
        return False
    if chunk.get("filename") != expected_filename:
        return False

    content_lower = (chunk.get("content") or "").lower()
    return any(kw in content_lower for kw in keywords)
def _is_toc_chunk(chunk: dict[str, Any]) -> bool:
    """Guess whether a chunk is a table-of-contents or figure-index page.

    Heuristic over the chunk's ``content`` (missing/None content counts as
    empty): the chunk is flagged when it has at least three dot-leader
    sequences, or when it has at least one dot-leader sequence together with
    at least three figure references. Such pages match keyword queries but
    contain no real explanatory content.
    """
    text = chunk.get("content") or ""
    dot_leader_count = len(_DOT_LEADER_PATTERN.findall(text))

    # Many dot leaders alone are enough to call it a contents page.
    if dot_leader_count >= 3:
        return True
    # Figure references only count when at least one dot leader accompanies them.
    if dot_leader_count < 1:
        return False
    return len(_FIGURE_PATTERN.findall(text)) >= 3
async def _evaluate_vector_case(
    case: dict[str, Any],
    query: str,
    retrieval_queries: list[Any],
    retrieval_mode: str = "dense",
) -> tuple[dict[str, Any], BenchmarkResult]:
    """Run retrieval for one vector-RAG golden case and rank the expected chunk.

    Args:
        case: Golden case; ``case_id`` and ``query`` are required, the
            ``expected_*`` fields drive the strict and semantic matchers.
        query: Raw user query, passed to the hybrid search as ``raw_query``.
        retrieval_queries: Query objects exposing ``.text`` and ``.label``.
        retrieval_mode: ``"hybrid"`` uses ``search_knowledge_hybrid``; any
            other value runs one dense search per retrieval query and merges
            the result lists.

    Returns:
        ``(case_report, benchmark_result)``. Ranks are 1-based positions in
        the retrieved list after TOC/figure-index chunks are removed, or None
        when nothing matched. ``latency_ms`` is a 0 placeholder here; this
        function does no timing itself.
    """
    if retrieval_mode == "hybrid":
        chunks = await search_knowledge_hybrid(
            raw_query=query,
            dense_queries=retrieval_queries,
        )
    else:
        # Searches are awaited one at a time, in query order.
        dense_lists = [
            await search_knowledge(query=rq.text, query_label=rq.label)
            for rq in retrieval_queries
        ]
        chunks = merge_knowledge_results(dense_lists)

    filtered_chunks = [c for c in chunks if not _is_toc_chunk(c)]

    matched_rank, matched_chunk = None, None
    semantic_rank = None
    for index, chunk in enumerate(filtered_chunks, start=1):
        if matched_rank is None and _chunk_matches_case(chunk, case):
            matched_rank = index
            matched_chunk = chunk
        if semantic_rank is None and _chunk_matches_semantic(chunk, case):
            semantic_rank = index
        if matched_rank is not None and semantic_rank is not None:
            break  # both first-hit ranks are known; later chunks cannot change them

    query_label = "+".join(rq.label for rq in retrieval_queries) if retrieval_queries else "raw"

    # A chunk may carry "similarity": None (key present, no score). Comparing
    # None with the threshold would raise TypeError, so a missing or null
    # score is treated as 0.
    similarity = (matched_chunk.get("similarity") or 0) if matched_chunk else 0
    verified_hit = bool(matched_chunk and similarity >= VERIFIED_DENSE_THRESHOLD)

    result = BenchmarkResult(
        case_id=str(case["case_id"]),
        rank=matched_rank,
        semantic_rank=semantic_rank,
        expected_found=bool(case.get("expected_found")),
        verified_hit=verified_hit,
        latency_ms=0,
    )
    case_report = {
        "case_id": case["case_id"],
        "case_group": case.get("case_group"),
        "expected_mode": case.get("expected_mode"),
        "query": case["query"],
        "retrieval_query_label": query_label,
        "used_structured_params": False,
        "used_vector_rag": True,
        "matched_source": matched_chunk.get("source") if matched_chunk else None,
        "matched_filename": matched_chunk.get("filename") if matched_chunk else None,
        "matched_page": matched_chunk.get("page_number") if matched_chunk else None,
        "rank": result.rank,
        "semantic_rank": result.semantic_rank,
        "verified_hit": result.verified_hit,
        "latency_ms": result.latency_ms,
    }
    return case_report, result
async def _evaluate_case(
    case: dict[str, Any],
    llm_call: LlmCall,
    timer: Timer,
    retrieval_mode: str = "dense",
) -> tuple[dict[str, Any], BenchmarkResult]:
    """Evaluate a single golden case and time it.

    Cases whose ``expected_mode`` is ``"vector_rag"`` go through dual query
    building plus ``_evaluate_vector_case``. Every other case is scored on the
    structured path: the planner must pick the expected plant (and stage, when
    one is expected) and the local plant database must return parameters.

    Args:
        case: Golden case; ``query`` and ``case_id`` are required.
        llm_call: Async LLM hook handed to the planner and query builders.
        timer: Clock read once at the start and once when the work is done.
        retrieval_mode: Forwarded to the vector path only.

    Returns:
        ``(case_report, benchmark_result)`` with ``latency_ms`` set to the
        elapsed time in whole milliseconds, never negative.
    """
    query = str(case["query"])
    query_type = str(case.get("query_type", "plant_specific"))
    expected_mode = case.get("expected_mode")
    started_at = timer()

    def elapsed_ms() -> int:
        # Clamp at 0 so a non-monotonic injected timer cannot yield a negative latency.
        return max(int(round((timer() - started_at) * 1000)), 0)

    if expected_mode == "vector_rag":
        dual_queries = await build_dual_retrieval_queries(
            query=query,
            query_type=query_type,
            needs_vector_rag=True,
            llm_call=llm_call,
        )
        vector_report, untimed = await _evaluate_vector_case(
            case=case,
            query=query,
            retrieval_queries=dual_queries,
            retrieval_mode=retrieval_mode,
        )
        latency_ms = elapsed_ms()
        vector_report["latency_ms"] = latency_ms
        # Rebuild the result so it carries the measured latency instead of
        # the placeholder produced by the vector evaluation.
        timed = BenchmarkResult(
            case_id=untimed.case_id,
            rank=untimed.rank,
            semantic_rank=untimed.semantic_rank,
            expected_found=untimed.expected_found,
            verified_hit=untimed.verified_hit,
            latency_ms=latency_ms,
        )
        return vector_report, timed

    expected_plant = case.get("expected_plant")
    expected_stage = case.get("expected_stage")
    expected_found = bool(case.get("expected_found"))

    plan = await plan_plant_retrieval(query=query, llm_call=llm_call)
    retrieval_query = await build_retrieval_query(
        query=query,
        query_type=query_type,
        needs_vector_rag=plan.use_vector_rag,
        llm_call=llm_call,
    )

    structured_params = None
    if plan.plant_name and plan.use_structured_params:
        structured_params = get_plant_parameters(plan.plant_name, plan.stage)

    latency_ms = elapsed_ms()

    # A structured hit needs: the case expects a result, the database returned
    # parameters, the planner chose the expected plant, and the stage agrees
    # whenever the case pins one.
    matched_expected_result = bool(
        expected_found
        and structured_params
        and plan.plant_name == expected_plant
        and (expected_stage is None or plan.stage == expected_stage)
    )
    verified_hit = bool(matched_expected_result and structured_params.get("from_local_db"))

    benchmark_result = BenchmarkResult(
        case_id=str(case["case_id"]),
        rank=1 if matched_expected_result else None,
        semantic_rank=None,
        expected_found=expected_found,
        verified_hit=verified_hit,
        latency_ms=latency_ms,
    )
    case_report = {
        "case_id": case["case_id"],
        "case_group": case.get("case_group"),
        "expected_mode": expected_mode,
        "query": query,
        "expected_plant": expected_plant,
        "expected_stage": expected_stage,
        "expected_found": expected_found,
        "actual_plant": plan.plant_name,
        "actual_stage": plan.stage,
        "retrieval_query_label": retrieval_query.label,
        "used_structured_params": plan.use_structured_params,
        "used_vector_rag": plan.use_vector_rag,
        "rank": benchmark_result.rank,
        "semantic_rank": benchmark_result.semantic_rank,
        "verified_hit": benchmark_result.verified_hit,
        "latency_ms": benchmark_result.latency_ms,
    }
    return case_report, benchmark_result
async def run_benchmark(
    cases: list[dict[str, Any]],
    llm_call: LlmCall = _empty_llm,
    timer: Timer = time.perf_counter,
    retrieval_mode: str = "dense",
) -> dict[str, Any]:
    """Evaluate every golden case in order and assemble the benchmark report.

    Args:
        cases: Golden cases; each needs a ``case_id``. A case without a
            ``case_group`` key is grouped under ``"structured_routing"``.
        llm_call: Async LLM hook; defaults to the offline stub.
        timer: Clock used for per-case latency.
        retrieval_mode: Forwarded to each case evaluation and echoed in the
            report.

    Returns:
        A dict with the keys ``benchmark_mode``, ``retrieval_mode``, ``notes``,
        ``summary``, ``metrics``, ``metrics_by_group`` and ``cases``.
    """
    reports: list[dict[str, Any]] = []
    results: list[BenchmarkResult] = []
    group_by_case_id: dict[str, str] = {}

    # Cases are awaited sequentially, one at a time.
    for case in cases:
        group_by_case_id[str(case["case_id"])] = str(case.get("case_group", "structured_routing"))
        report, result = await _evaluate_case(
            case, llm_call=llm_call, timer=timer, retrieval_mode=retrieval_mode
        )
        reports.append(report)
        results.append(result)

    metrics = score_benchmark_run(results)
    metrics_by_group = score_benchmark_groups(results, group_by_case_id)
    summary = build_report_summary(metrics)

    return {
        "benchmark_mode": BENCHMARK_MODE,
        "retrieval_mode": retrieval_mode,
        "notes": BENCHMARK_NOTES,
        "summary": summary,
        "metrics": metrics,
        "metrics_by_group": metrics_by_group,
        "cases": reports,
    }
def emit_report(
    report: dict[str, Any],
    stdout: TextIO | None = None,
    stderr: TextIO | None = None,
) -> None:
    """Write the human-readable lines to stderr and the JSON report to stdout.

    Args:
        report: Benchmark report; must contain ``summary`` and ``notes`` and
            be JSON-serialisable as a whole.
        stdout: Stream for the indented JSON document. None (the default)
            means the current ``sys.stdout``.
        stderr: Stream for the summary and notes lines. None (the default)
            means the current ``sys.stderr``.

    Raises:
        KeyError: If ``summary`` or ``notes`` is missing from ``report``.
    """
    # Resolve the streams at call time. Defaults written as `= sys.stdout` are
    # evaluated once when the function is defined, so they would keep pointing
    # at the old stream after any redirection (redirect_stdout, test capture).
    out = sys.stdout if stdout is None else stdout
    err = sys.stderr if stderr is None else stderr

    print(report["summary"], file=err)
    print(report["notes"], file=err)
    print(json.dumps(report, indent=2), file=out)
def main() -> int:
    """CLI entry point: run the golden-case benchmark and print the report.

    Accepts ``--mode dense|hybrid`` (default ``dense``). Anything printed to
    stdout while the benchmark runs is captured and replayed on stderr, so
    that stdout carries only the JSON report.

    Returns:
        0 once the report has been emitted.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        choices=["dense", "hybrid"],
        default="dense",
        help="Retrieval mode: dense (default) or hybrid",
    )
    args = parser.parse_args()

    cases = load_golden_retrieval_cases()
    benchmark_logs = io.StringIO()
    try:
        with redirect_stdout(benchmark_logs):
            report = asyncio.run(run_benchmark(cases, retrieval_mode=args.mode))
    finally:
        # Replay captured output even when the benchmark raises; otherwise
        # the logs that explain the failure would be discarded.
        log_output = benchmark_logs.getvalue()
        if log_output:
            print(log_output, file=sys.stderr, end="")

    emit_report(report)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())