# BlackBox/src/evaluation/model_comparison.py
# Repository page header: AbdullahKhanSherwani's picture | "Final updates" | 4afbbc2 (verified)
"""
Model comparison: DeepSeek-V3.1 vs GPT-oss-120b
Configurations: section/hybrid and md_recursive/hybrid
With and without cross-encoder reranking
Queries: 3 hardest questions from ablation (ones that caused faithfulness < 1)
Runs 2 models × 2 strategies × 2 rerank modes × 3 queries = 24 evaluations.
Results saved to data/model_comparison.csv
"""
import os
import sys
import time
import pandas as pd
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from src.retrieval.query import load_model, init_pinecone, retrieve
from src.retrieval.hybrid import build_bm25_index, load_reranker, hybrid_retrieve
from src.generation.generate import generate_answer
from src.evaluation.evaluate import (
compute_faithfulness, compute_relevancy, build_faithfulness_context,
_load_cache, _save_cache, CACHE_PATH,
)
# Repository root: walk three directory levels up from this file.
_THIS_FILE = os.path.abspath(__file__)
_PARENT_DIR = os.path.dirname(_THIS_FILE)
_GRANDPARENT_DIR = os.path.dirname(_PARENT_DIR)
BASE_DIR = os.path.dirname(_GRANDPARENT_DIR)

# Destination CSV for the per-evaluation result rows.
OUTPUT_PATH = os.path.join(BASE_DIR, "data", "model_comparison.csv")

# The 3 hardest questions: these caused faithfulness < 1 in the ablation study.
HARD_QUERIES = [
    "Based strictly on AAR-00/03 for TWA Flight 800, which debris field was the smallest and what were the exact fuselage station markers for the wreckage it contained?",
    "Based strictly on AAR-00/01 for Korean Air Flight 801, what was the exact decision height for the ILS approach at Guam and at what altitude did the crew first receive a GPWS warning?",
    "Based strictly on AAR-14/01 for Asiana Airlines Flight 214, state the exact CVR timestamp in hours, minutes, and seconds when the stick shaker first activated, the exact indicated airspeed at that moment, and the exact radio altitude recorded simultaneously by the FDR.",
]

# Chunking strategies under comparison.
STRATEGIES = [
    "section",
    "md_recursive",
]

# LLM provider labels under comparison.
MODELS = [
    "gpt",
    "deepseek",
]

# Rerank modes: True = with cross-encoder, False = RRF only.
RERANK_MODES = [
    True,
    False,
]

# Number of matches requested from retrieval for every query.
TOP_K = 15
def run_one(query, strategy, llm_model, use_reranker,
            jina_model, index, bm25_cache, reranker, cache):
    """Evaluate a single (query, strategy, LLM, rerank mode) combination.

    Retrieves matches with ``hybrid_retrieve``, generates an answer with
    ``generate_answer``, scores that answer for faithfulness and relevancy,
    and returns one flat result row.

    Args:
        query: Question text to evaluate.
        strategy: Strategy name; forwarded to retrieval and used as the key
            into ``bm25_cache``.
        llm_model: Label forwarded to ``generate_answer`` as ``llm_provider``.
        use_reranker: When truthy, ``reranker`` is forwarded to retrieval;
            otherwise retrieval receives ``reranker=None``.
        jina_model: Model object forwarded to retrieval and relevancy scoring.
        index: Index object forwarded to retrieval.
        bm25_cache: Mapping of strategy name to a ``(bm25, chunks)`` pair.
        reranker: Reranker object, only used when ``use_reranker`` is truthy.
        cache: Cache object forwarded to both scoring functions.

    Returns:
        dict with the keys ``query`` (truncated to 80 characters plus "..."
        when longer), ``strategy``, ``llm_model``, ``cross_encoder``,
        ``faithfulness``, ``relevancy``, ``retrieval_time``,
        ``generation_time``, ``total_time`` and ``answer_snippet`` (first 300
        characters of the answer with newlines replaced by spaces). Scores and
        timings are rounded to three decimals.
    """
    started_at = time.perf_counter()
    bm25, chunks = bm25_cache[strategy]

    # Both rerank modes share one retrieval call; only the reranker differs.
    retrieval_started_at = time.perf_counter()
    matches = hybrid_retrieve(
        query, strategy, top_k=TOP_K,
        bm25=bm25, chunks=chunks,
        reranker=reranker if use_reranker else None,
        model=jina_model, index=index,
    )
    retrieval_time = round(time.perf_counter() - retrieval_started_at, 3)

    faithfulness_context = build_faithfulness_context(query, matches)

    generation_started_at = time.perf_counter()
    answer = generate_answer(query, matches, llm_provider=llm_model)
    generation_time = round(time.perf_counter() - generation_started_at, 3)

    # The total is taken before scoring, so it covers the BM25 lookup through
    # generation and leaves out the faithfulness/relevancy evaluation.
    total_time = round(time.perf_counter() - started_at, 3)

    faith_score, _ = compute_faithfulness(
        answer, [faithfulness_context], query=query, cache=cache
    )
    rel_score, _ = compute_relevancy(query, answer, jina_model, cache=cache)

    if len(query) > 80:
        query_label = query[:80] + "..."
    else:
        query_label = query
    answer_snippet = answer[:300].replace("\n", " ")

    row = {}
    row["query"] = query_label
    row["strategy"] = strategy
    row["llm_model"] = llm_model
    row["cross_encoder"] = use_reranker
    row["faithfulness"] = round(faith_score, 3)
    row["relevancy"] = round(rel_score, 3)
    row["retrieval_time"] = retrieval_time
    row["generation_time"] = generation_time
    row["total_time"] = total_time
    row["answer_snippet"] = answer_snippet
    return row
def main():
    """Run every query/strategy/model/rerank combination and save the results.

    Loads the embedding model, the Pinecone index, the reranker and the score
    cache, builds one BM25 index per strategy, then calls ``run_one`` for each
    combination of ``HARD_QUERIES`` x ``STRATEGIES`` x ``MODELS`` x
    ``RERANK_MODES``. A failing evaluation is reported and skipped so the rest
    of the run continues. Successful rows are written to ``OUTPUT_PATH`` and a
    per-configuration mean summary is printed.

    If no evaluation succeeds, nothing is written and no summary is printed
    (an empty frame has no columns to group by, and an existing results file
    should not be overwritten with an empty one).
    """
    # Function-scope import keeps the change local to this block. product()
    # iterates the rightmost iterable fastest, i.e. in the same order as the
    # equivalent nested for-loops.
    from itertools import product

    print("Loading models...")
    jina_model = load_model()
    index = init_pinecone()
    reranker = load_reranker()
    cache = _load_cache(CACHE_PATH)

    bm25_cache = {}
    for s in STRATEGIES:
        print(f"Building BM25 index for {s}...")
        bm25_cache[s] = build_bm25_index(s)

    combos = list(product(HARD_QUERIES, STRATEGIES, MODELS, RERANK_MODES))
    total = len(combos)
    results = []
    try:
        for done, (query, strategy, llm_model, use_reranker) in enumerate(combos, start=1):
            rerank_label = "w/ cross-encoder" if use_reranker else "RRF only"
            print(f" [{done:>2}/{total}] {strategy} | {llm_model} | {rerank_label}")
            try:
                row = run_one(query, strategy, llm_model, use_reranker,
                              jina_model, index, bm25_cache, reranker, cache)
            except Exception as e:
                # Deliberate best-effort: one failing evaluation must not
                # abort the remaining ones.
                print(f" ERROR: {e}")
            else:
                results.append(row)
    finally:
        # Save the score cache even when the run is interrupted, so scores
        # that were already computed are not lost.
        _save_cache(cache, CACHE_PATH)

    failed = total - len(results)
    if failed:
        print(f"\n{failed} of {total} evaluations failed and are missing from the results.")

    if not results:
        print("\nNo evaluations succeeded; nothing to save.")
        return

    os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
    df = pd.DataFrame(results)
    df.to_csv(OUTPUT_PATH, index=False)
    print(f"\nResults saved to {OUTPUT_PATH}")

    # Summary table: mean scores and timings per configuration.
    summary = df.groupby(["strategy", "llm_model", "cross_encoder"])[
        ["faithfulness", "relevancy", "retrieval_time", "generation_time"]
    ].mean().round(3)
    print("\n=== MODEL COMPARISON SUMMARY ===")
    print(summary.to_string())


# Run the comparison only when this file is executed as a script.
if __name__ == "__main__":
    main()