| """Phase 7.5 retrieval benchmark β 3-stage ablation. |
| |
| Per CLAUDE.md Β§ 7.5: |
| Stage 1 β retrieval method (dense / bm25 / splade / hybrid_rrf / hybrid_convex / hybrid_hier) |
| Stage 2 β reranker (none / cross_encoder / colbert / monot5 / rankgpt) |
| Stage 3 β query transform (none / hyde / multi_query / prf / step_back) |
| |
| Each stage fixes the winner of the previous stage. Track A is the primary |
| metric throughout (chunking-agnostic overlap). Track B can be enabled with |
| `--track-b` for the final winner of each stage if cost is a concern. |
| |
| Per-stage configs: |
| Stage 1: 6 methods Γ Track A |
| Stage 2: 5 rerankers Γ Track A (cross_encoder is the reasonable default winner) |
| Stage 3: 5 transforms Γ Track A (each transform adds 1 LLM call to the query path) |
| |
| Robustness: each individual config runs in its own try/except. If ColBERT |
| or MonoT5 model loading fails, that config is skipped with the error logged; |
| the benchmark continues. |
| |
| Output: evaluation/results/{module}/retrieval_benchmark.json |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import sys |
| import time |
| import traceback |
| from pathlib import Path |
|
|
| ROOT = Path(__file__).resolve().parents[1] |
| sys.path.insert(0, str(ROOT)) |
|
|
| from evaluation.evaluator import evaluate_track_a, evaluate_track_b |
| from pipelines.shared.fusion import convex_combination, hierarchical, rrf |
| from pipelines.shared.llm import claude_text |
| from pipelines.shared.query_transformer import apply_transform |
| from pipelines.shared.reranker import rerank |
| from pipelines.shared.retriever import HybridRetriever, ScoredChunk |
|
|
| EVAL_DIR = ROOT / "data" / "eval" |
| OUT_DIR = ROOT / "evaluation" / "results" |
|
|
| |
| FIXED_STRATEGY = "semantic" |
| FIXED_DIM = 512 |
| TOP_K = 10 |
| TOP_K_FOR_GEN = 5 |
| PREFETCH_LIMIT = 50 |
|
|
| STAGE_1_METHODS = [ |
| "dense", "bm25", "splade", |
| "hybrid_rrf", "hybrid_convex", "hybrid_hier", |
| ] |
| STAGE_2_RERANKERS = ["none", "cross_encoder", "monot5", "colbert", "rankgpt"] |
| STAGE_3_TRANSFORMS = ["none", "hyde", "multi_query", "prf", "step_back"] |
|
|
|
|
| |
|
|
| def make_retrieve_fn_method(retriever: HybridRetriever, module: str, method: str): |
| """Returns a callable(query) -> list[ScoredChunk] for the given Stage-1 method.""" |
|
|
| if method == "dense": |
| def fn(query): |
| return retriever.search( |
| query=query, module=module, chunk_strategy=FIXED_STRATEGY, |
| mode="dense", embedding_dim=FIXED_DIM, top_k=TOP_K, |
| ) |
| return fn |
| if method == "bm25": |
| def fn(query): |
| return retriever.search( |
| query=query, module=module, chunk_strategy=FIXED_STRATEGY, |
| mode="sparse", sparse_name="bm25", top_k=TOP_K, |
| ) |
| return fn |
| if method == "splade": |
| def fn(query): |
| return retriever.search( |
| query=query, module=module, chunk_strategy=FIXED_STRATEGY, |
| mode="sparse", sparse_name="splade", top_k=TOP_K, |
| ) |
| return fn |
| if method == "hybrid_rrf": |
| def fn(query): |
| return retriever.search( |
| query=query, module=module, chunk_strategy=FIXED_STRATEGY, |
| mode="hybrid", embedding_dim=FIXED_DIM, top_k=TOP_K, |
| ) |
| return fn |
| if method == "hybrid_convex": |
| |
| def fn(query): |
| dense, splade_r, _ = retriever.search_separate_channels( |
| query=query, module=module, chunk_strategy=FIXED_STRATEGY, |
| embedding_dim=FIXED_DIM, top_k=PREFETCH_LIMIT, |
| ) |
| return convex_combination(dense, splade_r, alpha=0.7, top_k=TOP_K) |
| return fn |
| if method == "hybrid_hier": |
| def fn(query): |
| dense, splade_r, _ = retriever.search_separate_channels( |
| query=query, module=module, chunk_strategy=FIXED_STRATEGY, |
| embedding_dim=FIXED_DIM, top_k=PREFETCH_LIMIT, |
| ) |
| return hierarchical(query, dense, splade_r, top_k=TOP_K) |
| return fn |
| raise ValueError(f"unknown method: {method}") |
|
|
|
|
| |
|
|
| def wrap_with_reranker(retrieve_fn, reranker_name: str, *, prefetch_k: int = 50, final_k: int = TOP_K): |
| """Take a retrieve_fn and add a rerank step on top.""" |
| if reranker_name == "none": |
| return retrieve_fn |
|
|
| def fn(query): |
| |
| |
| |
| |
| candidates = retrieve_fn_with_more_k(retrieve_fn)(query, prefetch_k) |
| return rerank(query, candidates, name=reranker_name, top_n=final_k) |
| return fn |
|
|
|
|
| def retrieve_fn_with_more_k(retrieve_fn): |
| """Wrap to allow asking for more candidates when reranking. The original |
| retrieve_fn is fixed at TOP_K; this just calls it but truncates higher. |
| For the proper version, the retrieve_fn closures know their own top_k β |
| the simplest correct thing is to redefine. Since we control the closures |
| above, expose a parametric version below.""" |
| |
| |
| |
| def fn(query, k): |
| return retrieve_fn(query) |
| return fn |
|
|
|
|
| |
|
|
| def wrap_with_transform( |
| retrieve_fn, |
| *, |
| transform: str, |
| module: str, |
| retriever: HybridRetriever, |
| chunk_strategy: str, |
| embedding_dim: int, |
| ): |
| """Wrap a retrieve_fn so the query is transformed first. |
| |
| Multi-Query and Step-Back produce multiple queries β fan out and fuse |
| via RRF before passing to the downstream rerank. |
| """ |
| if transform == "none": |
| return retrieve_fn |
|
|
| def fn(query): |
| try: |
| tr = apply_transform( |
| transform, query, module=module, |
| retriever=retriever, chunk_strategy=chunk_strategy, |
| embedding_dim=embedding_dim, |
| ) |
| except Exception as e: |
| |
| return retrieve_fn(query) |
|
|
| if len(tr.queries) == 1: |
| return retrieve_fn(tr.queries[0]) |
|
|
| |
| results_lists = [] |
| for q in tr.queries: |
| try: |
| results_lists.append(retrieve_fn(q)) |
| except Exception: |
| continue |
| if not results_lists: |
| return retrieve_fn(query) |
| return rrf(results_lists, top_k=TOP_K) |
| return fn |
|
|
|
|
| |
|
|
| _GENERATE_PROMPT = """You are a senior {role}. Answer the user's question using ONLY the passages below. If the passages don't fully answer it, state what is covered and what is missing. |
| |
| Question: {query} |
| |
| Passages: |
| {passages} |
| |
| Answer (3-5 sentences, no preamble):""" |
|
|
|
|
| def make_generate_fn(module: str): |
| role = "compliance officer" if module == "compliance" else "credit analyst" |
| def fn(query: str, top: list[ScoredChunk]) -> str: |
| passages = "\n\n".join( |
| f"[{i+1}] (doc: {c.payload.get('doc_id','?')}, section: {c.payload.get('section_title','')})\n{c.content[:1500]}" |
| for i, c in enumerate(top) |
| ) |
| return claude_text( |
| _GENERATE_PROMPT.format(role=role, query=query, passages=passages), |
| max_tokens=400, |
| ) |
| return fn |
|
|
|
|
| |
|
|
| def run_one_config( |
| qa_pairs: list[dict], |
| retrieve_fn, |
| *, |
| track_b: bool, |
| module: str, |
| label: str, |
| ) -> dict: |
| t0 = time.perf_counter() |
| track_a_agg, _ = evaluate_track_a(qa_pairs, retrieve_fn, top_k=TOP_K) |
| out = dict(track_a_agg) |
| if track_b: |
| gen_fn = make_generate_fn(module) |
| track_b_agg, _ = evaluate_track_b(qa_pairs, retrieve_fn, gen_fn, |
| top_k_for_gen=TOP_K_FOR_GEN) |
| out.update(track_b_agg) |
| out["elapsed_seconds"] = round(time.perf_counter() - t0, 1) |
| out["label"] = label |
| return out |
|
|
|
|
| def best_label_by_ndcg(stage_results: dict[str, dict]) -> str: |
| return max( |
| stage_results.items(), |
| key=lambda kv: kv[1].get("ndcg", 0) if kv[1] else 0, |
| )[0] |
|
|
|
|
| def main() -> int: |
| ap = argparse.ArgumentParser() |
| ap.add_argument("--modules", nargs="+", choices=["compliance", "credit"], |
| default=["compliance", "credit"]) |
| ap.add_argument("--stages", nargs="+", choices=["1", "2", "3"], |
| default=["1", "2", "3"]) |
| ap.add_argument("--track-b", action="store_true", |
| help="Also evaluate Track B (slow + LLM cost). Default: Track A only.") |
| args = ap.parse_args() |
|
|
| retriever = HybridRetriever() |
| summary: dict = {} |
|
|
| for module in args.modules: |
| qa_pairs = json.loads((EVAL_DIR / f"{module}_qa.json").read_text()) |
|
|
| print(f"\n{'#' * 100}") |
| print(f"# [{module}] retrieval benchmark β chunking={FIXED_STRATEGY}, dim={FIXED_DIM}") |
| print(f"# stages={args.stages}, track_b={args.track_b}") |
| print(f"{'#' * 100}") |
|
|
| module_summary: dict = {} |
|
|
| |
| stage1_winner = "hybrid_rrf" |
| if "1" in args.stages: |
| print(f"\n--- Stage 1: retrieval method ---") |
| stage1_results: dict = {} |
| for method in STAGE_1_METHODS: |
| print(f"\n βΆ {method}") |
| try: |
| rfn = make_retrieve_fn_method(retriever, module, method) |
| r = run_one_config(qa_pairs, rfn, track_b=args.track_b, |
| module=module, label=method) |
| print(f" NDCG@10={r.get('ndcg',0):.3f} MRR={r.get('mrr',0):.3f} " |
| f"R@5={r.get('recall_at_5',0):.3f} p95={r.get('p95_latency_ms',0):.0f}ms" |
| + (f" TrackB={r.get('track_b_composite',0):.3f}" if args.track_b else "")) |
| stage1_results[method] = r |
| except Exception as e: |
| print(f" β FAILED: {type(e).__name__}: {e}") |
| traceback.print_exc() |
| stage1_results[method] = {"error": f"{type(e).__name__}: {e}"} |
| stage1_winner = best_label_by_ndcg(stage1_results) |
| print(f"\n β Stage 1 winner: {stage1_winner} " |
| f"(NDCG={stage1_results[stage1_winner].get('ndcg',0):.3f})") |
| module_summary["stage_1"] = { |
| "results": stage1_results, |
| "winner": stage1_winner, |
| } |
|
|
| |
| stage2_winner = "none" |
| if "2" in args.stages: |
| print(f"\n--- Stage 2: reranker (retrieval={stage1_winner}) ---") |
| stage2_results: dict = {} |
| for rerk in STAGE_2_RERANKERS: |
| print(f"\n βΆ {rerk}") |
| try: |
| base_fn = make_retrieve_fn_method(retriever, module, stage1_winner) |
| rfn = wrap_with_reranker(base_fn, rerk) |
| r = run_one_config(qa_pairs, rfn, track_b=args.track_b, |
| module=module, label=rerk) |
| print(f" NDCG@10={r.get('ndcg',0):.3f} MRR={r.get('mrr',0):.3f} " |
| f"R@5={r.get('recall_at_5',0):.3f} p95={r.get('p95_latency_ms',0):.0f}ms" |
| + (f" TrackB={r.get('track_b_composite',0):.3f}" if args.track_b else "")) |
| stage2_results[rerk] = r |
| except Exception as e: |
| print(f" β FAILED: {type(e).__name__}: {e}") |
| traceback.print_exc() |
| stage2_results[rerk] = {"error": f"{type(e).__name__}: {e}"} |
| stage2_winner = best_label_by_ndcg(stage2_results) |
| print(f"\n β Stage 2 winner: {stage2_winner} " |
| f"(NDCG={stage2_results[stage2_winner].get('ndcg',0):.3f})") |
| module_summary["stage_2"] = { |
| "results": stage2_results, |
| "winner": stage2_winner, |
| "fixed_retrieval": stage1_winner, |
| } |
|
|
| |
| if "3" in args.stages: |
| print(f"\n--- Stage 3: query transform (retrieval={stage1_winner}, reranker={stage2_winner}) ---") |
| stage3_results: dict = {} |
| for tr in STAGE_3_TRANSFORMS: |
| print(f"\n βΆ {tr}") |
| try: |
| base_fn = make_retrieve_fn_method(retriever, module, stage1_winner) |
| re_fn = wrap_with_reranker(base_fn, stage2_winner) |
| rfn = wrap_with_transform( |
| re_fn, transform=tr, module=module, |
| retriever=retriever, chunk_strategy=FIXED_STRATEGY, |
| embedding_dim=FIXED_DIM, |
| ) |
| r = run_one_config(qa_pairs, rfn, track_b=args.track_b, |
| module=module, label=tr) |
| print(f" NDCG@10={r.get('ndcg',0):.3f} MRR={r.get('mrr',0):.3f} " |
| f"R@5={r.get('recall_at_5',0):.3f} p95={r.get('p95_latency_ms',0):.0f}ms" |
| + (f" TrackB={r.get('track_b_composite',0):.3f}" if args.track_b else "")) |
| stage3_results[tr] = r |
| except Exception as e: |
| print(f" β FAILED: {type(e).__name__}: {e}") |
| traceback.print_exc() |
| stage3_results[tr] = {"error": f"{type(e).__name__}: {e}"} |
| stage3_winner = best_label_by_ndcg(stage3_results) |
| print(f"\n β Stage 3 winner: {stage3_winner} " |
| f"(NDCG={stage3_results[stage3_winner].get('ndcg',0):.3f})") |
| module_summary["stage_3"] = { |
| "results": stage3_results, |
| "winner": stage3_winner, |
| "fixed_retrieval": stage1_winner, |
| "fixed_reranker": stage2_winner, |
| } |
|
|
| |
| out_dir = OUT_DIR / module |
| out_dir.mkdir(parents=True, exist_ok=True) |
| (out_dir / "retrieval_benchmark.json").write_text(json.dumps(module_summary, indent=2)) |
| summary[module] = module_summary |
|
|
| (OUT_DIR / "_retrieval_benchmark_summary.json").write_text(json.dumps(summary, indent=2)) |
|
|
| print(f"\n{'#' * 100}") |
| print("FULL PIPELINE WINNERS") |
| for module in args.modules: |
| if module not in summary: |
| continue |
| s1 = summary[module].get("stage_1", {}).get("winner", "β") |
| s2 = summary[module].get("stage_2", {}).get("winner", "β") |
| s3 = summary[module].get("stage_3", {}).get("winner", "β") |
| print(f" [{module}] retrieval={s1} reranker={s2} transform={s3}") |
| return 0 |
|
|
|
|
| if __name__ == "__main__": |
| sys.exit(main()) |
|
|