substrate / eval /eval_retrieval.py
Syed Taha
refactor: update eval_retrieval.py to support chunking strategy argument
eb951a8
"""
eval/eval_retrieval.py
Evaluates retrieval quality WITHOUT any LLM involvement.
Scoring approach (v2 - keyword-in-context):
- Retrieve top-K chunks for each query
- Concatenate retrieved chunk text into a context string
- Check what % of context_keywords appear in that context
- Pass if keyword coverage >= threshold (default 0.5)
- must_retrieve kept as optional strict secondary check
This measures "does the retrieved context contain the right concepts?"
rather than "did we find these exact function names?" - a much more
meaningful signal for a corpus of 81k chunks.
Usage:
python eval/eval_retrieval.py # all queries, all methods
python eval/eval_retrieval.py --strategy fixed # chunking strategy: function | fixed | recursive
python eval/eval_retrieval.py --method hybrid # one method only
python eval/eval_retrieval.py --tier 2 # cross-repo queries only
python eval/eval_retrieval.py --query T2-001 # single query
python eval/eval_retrieval.py --top-k 5 --kw-threshold 0.6 # retrieval depth / pass threshold
python eval/eval_retrieval.py --verbose # show retrieved chunks
"""
import argparse
import json
import logging
import pickle
import re
import sys
import time
from pathlib import Path
import numpy as np
import yaml
from sentence_transformers import SentenceTransformer
# Root logging for this script: INFO and above, rendered as
# "HH:MM:SS LEVEL    message" with the level name padded to 8 columns.
logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
log = logging.getLogger(__name__)
# Config
def load_config(path: str = "config.yaml") -> dict:
    """Parse the project YAML configuration file.

    Args:
        path: Location of the config file. Defaults to ``config.yaml``, resolved
            relative to the current working directory (the previously hard-coded
            location, so existing callers are unaffected).

    Returns:
        Whatever ``yaml.safe_load`` produces for the document; callers index it
        as a mapping (presumably the file's top level is a dict - confirm).
    """
    # Explicit UTF-8 so the result does not depend on the platform's locale
    # encoding (e.g. cp1252 on Windows would mis-decode non-ASCII values).
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f)
def load_test_queries(path: str = "eval/test_queries.yaml") -> list[dict]:
    """Load the evaluation query set from a YAML file.

    Args:
        path: YAML file whose top level contains a ``queries`` list.

    Returns:
        The list stored under the top-level ``queries`` key.

    Raises:
        KeyError: if the parsed document has no ``queries`` key.
    """
    # Explicit UTF-8 so query text decodes the same on every platform,
    # independent of the locale's default encoding.
    with open(path, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    return data["queries"]
# Load indexes
def load_bm25(cfg: dict, strategy: str = "function") -> tuple:
    """Unpickle the BM25 index built for a chunking ``strategy``.

    The file location is ``cfg["bm25"]["index_path"]`` with its ``{chunking}``
    placeholder replaced by ``strategy``. The pickle holds a dict whose
    ``"bm25"`` entry is the index object and whose ``"chunks"`` entry is the
    list of chunk dicts.

    Returns:
        ``(bm25, chunks)``.
    """
    index_file = Path(cfg["bm25"]["index_path"].format(chunking=strategy))
    log.info("Loading BM25 index from %s...", index_file)
    # NOTE(security): unpickling can execute arbitrary code; only point this at
    # index files you built yourself, never at untrusted downloads.
    with index_file.open("rb") as handle:
        payload = pickle.load(handle)
    chunk_list = payload["chunks"]
    log.info(" BM25 index ready (%d documents)", len(chunk_list))
    return payload["bm25"], chunk_list
def load_chroma(cfg: dict, strategy: str = "function"):
    """Open the persisted ChromaDB collection for a chunking ``strategy``.

    The client is a ``chromadb.PersistentClient`` rooted at
    ``cfg["vector_store"]["chroma"]["persist_directory"]``. The collection name
    is the configured ``collection_name`` template with ``{chunking}`` filled in.
    ``chromadb`` is imported inside the function, so it is only required when
    this loader is actually called.
    """
    import chromadb

    chroma_cfg = cfg["vector_store"]["chroma"]
    storage_dir = chroma_cfg["persist_directory"]
    name = chroma_cfg["collection_name"].format(chunking=strategy)
    collection = chromadb.PersistentClient(path=storage_dir).get_collection(name)
    log.info("ChromaDB collection '%s' loaded (%d vectors)", name, collection.count())
    return collection
def load_embed_model(cfg: dict) -> SentenceTransformer:
    """Instantiate the SentenceTransformer named by ``cfg["embedding"]["model"]``."""
    name = cfg["embedding"]["model"]
    log.info("Loading embedding model: %s", name)
    encoder = SentenceTransformer(name)
    return encoder
# Tokenizer (must match build_bm25.py)
# Separator class: whitespace plus common punctuation / operator characters.
# Must stay identical to the tokenizer in build_bm25.py, otherwise query tokens
# will not line up with the tokens the BM25 index was built from.
_TOKEN_SEPARATORS = re.compile(r"[\s\(\)\[\]\{\}\.,;:\"'=\+\-\*/<>!@#\$%\^&\|\\`~]+")


def tokenize(text: str) -> list[str]:
    """Lower-case ``text``, split on separators, and drop tokens of length <= 1."""
    pieces = _TOKEN_SEPARATORS.split(text.lower())
    return list(filter(lambda piece: len(piece) > 1, pieces))
# Retrieval methods
def retrieve_bm25(
    query: str,
    bm25,
    chunks: list[dict],
    top_k: int = 10,
) -> list[dict]:
    """Score every chunk against ``query`` with BM25 and return the best ``top_k``.

    Args:
        query: Natural-language query; tokenized with ``tokenize`` so it matches
            the index's tokenization.
        bm25: Index object exposing ``get_scores(tokens)`` with one score per
            entry of ``chunks``.
        chunks: Chunk dicts aligned index-for-index with the BM25 corpus.
        top_k: Number of results to return; values <= 0 yield an empty list.

    Returns:
        Shallow copies of the top chunks, highest score first, each annotated
        with ``_score`` (float BM25 score) and ``_method`` = ``"bm25"``.
    """
    if top_k <= 0:
        # Without this guard ``argsort()[-0:]`` slices the *entire* array, so a
        # top_k of 0 would return every chunk instead of none.
        return []
    tokens = tokenize(query)
    scores = np.asarray(bm25.get_scores(tokens))
    # Ascending argsort -> take the last top_k indices -> reverse to descending.
    top_idx = scores.argsort()[-top_k:][::-1]
    results = []
    for idx in top_idx:
        chunk = chunks[idx].copy()  # copy so annotations don't leak into the index
        chunk["_score"] = float(scores[idx])
        chunk["_method"] = "bm25"
        results.append(chunk)
    return results
def retrieve_dense(
    query: str,
    collection,
    model: SentenceTransformer,
    top_k: int = 10,
) -> list[dict]:
    """Embed ``query`` and fetch its ``top_k`` nearest chunks from ChromaDB.

    The query vector is encoded with ``normalize_embeddings=True``. Each hit is
    built from its stored metadata plus:
      - ``chunk_id``: the Chroma record id (so ``rrf_fusion`` can key on it),
      - ``_score``: ``1 - distance`` (Chroma distances are lower-is-better),
      - ``_method``: ``"dense"``,
      - ``_text``: the stored document text.

    NOTE(review): ``1 - distance`` is a cosine similarity only if the collection
    was created with cosine space - confirm against the indexing script.
    """
    vector = model.encode(
        query,
        normalize_embeddings=True,
        show_progress_bar=False,
        convert_to_numpy=True,
    ).tolist()
    response = collection.query(
        query_embeddings=[vector],
        n_results=top_k,
        include=["metadatas", "distances", "documents"],
    )
    hits = []
    for pos, meta in enumerate(response["metadatas"][0]):
        hits.append({
            **meta,
            "chunk_id": response["ids"][0][pos],
            "_score": 1.0 - response["distances"][0][pos],
            "_method": "dense",
            "_text": response["documents"][0][pos],
        })
    return hits
def get_id(chunk: dict) -> str:
    """Return a stable identity key for ``chunk``.

    Uses ``chunk_id`` when it is present and truthy. Otherwise it joins
    ``repo``, ``filepath``, ``function_name`` and ``line_start`` with ``"::"``,
    with missing fields rendered as empty strings.
    """
    explicit = chunk.get("chunk_id")
    if explicit:
        return explicit
    fields = ("repo", "filepath", "function_name", "line_start")
    return "::".join(f"{chunk.get(name, '')}" for name in fields)
def rrf_fusion(
    bm25_results: list[dict],
    dense_results: list[dict],
    k: int = 60,
) -> list[dict]:
    """
    Reciprocal Rank Fusion - combines two ranked lists.
    score(d) = sum(1 / (k + rank_i(d))) for each list i, with ranks starting at 1.

    Chunks are deduplicated on get_id(): the chunk_id when present and truthy,
    otherwise the key repo::filepath::function_name::line_start.
    When a chunk appears in both lists, the dense copy is the one returned.
    Ties keep first-seen order (BM25 list first) because sorted() is stable.

    Returns every distinct chunk (untruncated), as copies annotated with
    _score (the fused score) and _method = "hybrid_rrf", best first.
    """
    scores: dict[str, float] = {}
    chunk_map: dict[str, dict] = {}
    # Identical accumulation for each ranked list. Dense is processed last, so
    # its copy of a shared chunk overwrites the BM25 copy in chunk_map.
    for ranked in (bm25_results, dense_results):
        for rank, chunk in enumerate(ranked, 1):
            cid = get_id(chunk)
            scores[cid] = scores.get(cid, 0.0) + 1.0 / (k + rank)
            chunk_map[cid] = chunk
    fused = []
    for cid in sorted(scores, key=scores.get, reverse=True):
        chunk = chunk_map[cid].copy()
        chunk["_score"] = scores[cid]
        chunk["_method"] = "hybrid_rrf"
        fused.append(chunk)
    return fused
def retrieve_hybrid(
    query: str,
    bm25,
    bm25_chunks: list[dict],
    collection,
    model: SentenceTransformer,
    cfg: dict,
    top_k: int = 10,
) -> list[dict]:
    """Run BM25 and dense retrieval, fuse them with RRF, and keep the best ``top_k``.

    The per-retriever candidate depths (``bm25_top_k``, ``dense_top_k``) and the
    RRF constant (``rrf_k``) are read from ``cfg["retrieval"]["hybrid"]``.
    """
    hybrid = cfg["retrieval"]["hybrid"]
    lexical = retrieve_bm25(query, bm25, bm25_chunks, top_k=hybrid["bm25_top_k"])
    semantic = retrieve_dense(query, collection, model, top_k=hybrid["dense_top_k"])
    fused = rrf_fusion(lexical, semantic, k=hybrid["rrf_k"])
    return fused[:top_k]
# Evaluation logic
def build_context(chunks: list[dict]) -> str:
    """Concatenate retrieved chunk text into one lower-cased searchable string.

    Each chunk contributes ``"<function_name> <docstring> <code>"``. Here
    ``<code>`` is ``raw_code`` when it is present and non-empty; otherwise it is
    ``_text`` (the document text that ``retrieve_dense`` attaches). Chunks are
    joined with a single space.

    Missing or ``None`` fields contribute an empty string. Formatting ``None``
    directly would insert the literal word "none", which could falsely satisfy
    a keyword check.
    """
    parts = []
    for c in chunks:
        name = c.get("function_name") or ""
        doc = c.get("docstring") or ""
        code = c.get("raw_code") or c.get("_text") or ""
        parts.append(f"{name} {doc} {code}")
    return " ".join(parts).lower()
def evaluate_query(
    query_obj: dict,
    chunks: list[dict],
    top_k: int,
    kw_threshold: float = 0.5,
) -> dict:
    """
    Score one query's retrieved chunks (only the first ``top_k`` are considered).

    Primary metric - keyword-in-context:
    Does the retrieved context contain the query's expected concepts?
    context_keywords field in test_queries.yaml defines what to look for.
    Falls back to keywords field if context_keywords is absent or empty.
    Score = found / total. Pass if score >= kw_threshold.
    With no keywords at all, kw_score and kw_passed are None (unverifiable).

    Secondary metric - must_retrieve (strict):
    Are specific function names present in retrieved function names?
    Reported but NOT the pass/fail criterion.

    Keywords are compared via str(), so YAML scalars parsed as int/bool
    (e.g. 404, true) are matched textually instead of raising AttributeError.
    Null list fields in the YAML are treated as empty lists.
    """
    top = chunks[:top_k]
    context = build_context(top)
    retrieved_fns = [c.get("function_name", "") for c in top]
    retrieved_repos = [c.get("repo", "") for c in top]

    # Primary: keyword-in-context (context_keywords preferred over keywords).
    kws = query_obj.get("context_keywords") or query_obj.get("keywords") or []
    kw_found, kw_missed = [], []
    for kw in kws:
        if str(kw).lower() in context:
            kw_found.append(kw)
        else:
            kw_missed.append(kw)
    if kws:
        kw_score = len(kw_found) / len(kws)
        kw_passed = kw_score >= kw_threshold
    else:
        kw_score = None
        kw_passed = None

    # Secondary: must_retrieve (exact function-name match).
    must = query_obj.get("must_retrieve") or []
    mr_hits = [fn for fn in must if fn in retrieved_fns]
    mr_misses = [fn for fn in must if fn not in retrieved_fns]
    mr_score = len(mr_hits) / len(must) if must else None

    # Anti-keywords (hallucination check, used in LLM eval later)
    anti = query_obj.get("anti_keywords") or []
    anti_hits = [kw for kw in anti if str(kw).lower() in context]

    return {
        "query_id": query_obj["id"],
        "tier": query_obj["tier"],
        "repos": query_obj["repos"],
        "query": query_obj["query"],
        # Primary
        "kw_score": kw_score,
        "kw_passed": kw_passed,
        "kw_found": kw_found,
        "kw_missed": kw_missed,
        # Secondary
        "mr_score": mr_score,
        "mr_hits": mr_hits,
        "mr_misses": mr_misses,
        # Metadata
        "anti_hits": anti_hits,
        "retrieved_functions": retrieved_fns,
        "retrieved_repos": retrieved_repos,
    }
def print_report(method, query_results, verbose=False, retrieval_details=None) -> dict:
    """Log a human-readable report for one retrieval method and return its summary.

    Args:
        method: Method label (upper-cased in the banner).
        query_results: Per-query dicts as produced by ``evaluate_query``.
        verbose: When true and ``retrieval_details`` is non-empty, also log the
            top 5 retrieved chunks per query.
        retrieval_details: Optional mapping of query id -> retrieved chunk list.
            Ids without a matching entry in ``query_results`` are skipped.

    Returns:
        Summary dict: total / verifiable / passed / failed / unverifiable counts,
        ``pass_rate`` (fraction 0-1), ``avg_kw_score`` and ``avg_mr_score``.
    """
    log.info("")
    log.info("=" * 70)
    log.info("Retrieval Evaluation - Method: %s (keyword-in-context)", method.upper())
    log.info("=" * 70)

    # kw_passed is None when a query defines no keywords (can't be scored here).
    verifiable = [r for r in query_results if r["kw_passed"] is not None]
    unverifiable = [r for r in query_results if r["kw_passed"] is None]
    passed = [r for r in verifiable if r["kw_passed"]]
    failed = [r for r in verifiable if not r["kw_passed"]]
    # Computed once; reused by the log lines and the returned summary.
    pass_rate = len(passed) / len(verifiable) if verifiable else 0.0

    for tier in sorted({r["tier"] for r in query_results}):
        tier_rows = [r for r in verifiable if r["tier"] == tier]
        if not tier_rows:
            continue
        tier_passed = sum(1 for r in tier_rows if r["kw_passed"])
        tier_avg = np.mean([r["kw_score"] for r in tier_rows])
        # %s (not %d) so non-integer tier labels don't break log formatting.
        log.info(" Tier %s: %d/%d passed (%.0f%%) avg coverage %.2f",
                 tier, tier_passed, len(tier_rows),
                 100 * tier_passed / len(tier_rows), tier_avg)

    log.info("")
    log.info(" Verifiable : %d", len(verifiable))
    log.info(" Passed : %d (%.1f%%)", len(passed), 100 * pass_rate)
    log.info(" Failed : %d", len(failed))
    log.info(" Unverifiable: %d (no keywords - needs LLM judge)", len(unverifiable))

    if failed:
        log.info("")
        log.info(" Failed:")
        for r in failed:
            log.info(" [%s] T%s score=%.2f missed=%s",
                     r["query_id"], r["tier"], r["kw_score"] or 0, r["kw_missed"])

    if verbose and retrieval_details:
        log.info("")
        log.info(" Per-query chunks:")
        # Index results by id once (first occurrence wins, like a linear scan)
        # instead of re-scanning query_results for every query.
        by_id = {}
        for r in query_results:
            by_id.setdefault(r["query_id"], r)
        for qid, chunks in retrieval_details.items():
            r = by_id.get(qid)
            if r is None:
                continue
            log.info(" %s kw=%.2f found=%s",
                     qid, r["kw_score"] or 0, r["kw_found"])
            for i, c in enumerate(chunks[:5], 1):
                log.info(" %d. [%.4f] %s::%s::%s",
                         i, c.get("_score", 0),
                         c.get("repo", ""), (c.get("filepath") or "")[:35],
                         c.get("function_name", ""))

    avg_kw = np.mean([r["kw_score"] for r in verifiable
                      if r["kw_score"] is not None]) if verifiable else 0.0
    mr_v = [r for r in query_results if r["mr_score"] is not None]
    avg_mr = np.mean([r["mr_score"] for r in mr_v]) if mr_v else 0.0

    log.info("")
    log.info(" Avg keyword coverage : %.3f (primary)", avg_kw)
    log.info(" Avg must_retrieve : %.3f (secondary)", avg_mr)
    log.info(" Pass rate : %.1f%%", 100 * pass_rate)
    log.info("=" * 70)

    return {
        "method": method,
        "total": len(query_results),
        "verifiable": len(verifiable),
        "passed": len(passed),
        "failed": len(failed),
        "unverifiable": len(unverifiable),
        "pass_rate": pass_rate,
        "avg_kw_score": float(avg_kw),
        "avg_mr_score": float(avg_mr),
    }
def save_results(results, output_dir, strategy: str = "function"):
    """Write ``results`` as indented JSON to ``<output_dir>/retrieval_eval_<strategy>.json``.

    Args:
        results: Mapping to serialize. Any value ``json`` cannot encode natively
            is stringified via ``default=str``.
        output_dir: Destination directory as a ``Path`` or a path string; it is
            created (with parents) if missing.
        strategy: Chunking strategy name, embedded in the file name.
    """
    # Accept plain strings as well as Path objects (Path(Path) is a no-op).
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    path = output_dir / f"retrieval_eval_{strategy}.json"
    with path.open("w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, default=str)
    log.info("Saved to %s", path)
def _parse_args() -> argparse.Namespace:
    """Build the CLI parser, parse ``sys.argv``, and validate numeric options."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--strategy",
        type=str,
        choices=["function", "fixed", "recursive"],
        default="function",
        help="Chunking strategy (default: function)",
    )
    parser.add_argument("--method", choices=["bm25", "dense", "hybrid", "all"], default="all")
    parser.add_argument("--tier", type=int, default=None)
    parser.add_argument("--query", type=str, default=None)
    parser.add_argument("--top-k", type=int, default=10)
    parser.add_argument("--kw-threshold", type=float, default=0.5)
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()
    # A retrieval depth below 1 is meaningless, and keyword coverage is a
    # fraction in [0, 1], so thresholds outside that range are user errors.
    if args.top_k < 1:
        parser.error("--top-k must be >= 1")
    if not 0.0 <= args.kw_threshold <= 1.0:
        parser.error("--kw-threshold must be between 0 and 1")
    return args


def _retrieve(method, query, bm25, bm25_chunks, collection, model, cfg, top_k):
    """Dispatch one query to the retrieval function selected by ``method``."""
    if method == "bm25":
        return retrieve_bm25(query, bm25, bm25_chunks, top_k)
    if method == "dense":
        return retrieve_dense(query, collection, model, top_k)
    return retrieve_hybrid(query, bm25, bm25_chunks, collection, model, cfg, top_k)


def _print_comparison(summaries: list[dict]) -> None:
    """Log a side-by-side table of the per-method summaries."""
    log.info("")
    log.info("=" * 70)
    log.info("Comparison")
    log.info("=" * 70)
    log.info(" %-10s %8s %10s %12s %10s",
             "Method", "Passed", "Pass Rate", "KW Coverage", "MR Score")
    log.info(" " + "-" * 55)
    for s in summaries:
        log.info(" %-10s %8d %9.1f%% %11.3f %9.3f",
                 s["method"], s["passed"],
                 100 * s["pass_rate"], s["avg_kw_score"], s["avg_mr_score"])
    log.info("=" * 70)


def main():
    """CLI entry point: run the selected retrieval methods over the filtered query
    set, log a report per method (plus a comparison table when several methods
    run), and save all results as JSON under cfg["evaluation"]["results_dir"].
    """
    args = _parse_args()

    from dotenv import load_dotenv
    load_dotenv()

    cfg = load_config()
    queries = load_test_queries(cfg["evaluation"]["test_queries_path"])
    # ``is not None`` so an explicit ``--tier 0`` still filters.
    if args.tier is not None:
        queries = [q for q in queries if q["tier"] == args.tier]
    if args.query:
        queries = [q for q in queries if q["id"] == args.query]
    if not queries:
        log.error("No queries matched.")
        sys.exit(1)

    # %.2f so thresholds like 0.25 are not displayed rounded to one decimal.
    log.info("Eval: %d queries | strategy=%s | top_k=%d | kw_threshold=%.2f",
             len(queries), args.strategy, args.top_k, args.kw_threshold)

    methods = ["bm25", "dense", "hybrid"] if args.method == "all" else [args.method]
    selected = set(methods)

    # Load only what the selected methods use: BM25-only runs never touch the
    # embedding model / Chroma collection, and dense-only runs never need BM25.
    bm25 = bm25_chunks = collection = model = None
    if selected & {"bm25", "hybrid"}:
        bm25, bm25_chunks = load_bm25(cfg, args.strategy)
    if selected & {"dense", "hybrid"}:
        collection = load_chroma(cfg, args.strategy)
        model = load_embed_model(cfg)

    summaries = []
    output = {}
    for method in methods:
        log.info("Running: %s", method)
        results = []
        details = {}
        t0 = time.perf_counter()  # monotonic clock for elapsed-time measurement
        for q in queries:
            chunks = _retrieve(method, q["query"], bm25, bm25_chunks,
                               collection, model, cfg, args.top_k)
            results.append(evaluate_query(q, chunks, args.top_k, args.kw_threshold))
            if args.verbose:
                # Raw chunks are only reported in verbose mode; skip storing otherwise.
                details[q["id"]] = chunks
        dur = time.perf_counter() - t0
        log.info(" %.1fs (%.2fs/query)", dur, dur / len(queries))
        s = print_report(method, results, verbose=args.verbose,
                         retrieval_details=details if args.verbose else None)
        summaries.append(s)
        output[method] = {"summary": s, "per_query": results}

    if len(summaries) > 1:
        _print_comparison(summaries)

    save_results(output, Path(cfg["evaluation"]["results_dir"]), args.strategy)


if __name__ == "__main__":
    main()