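"""Offline evaluation harness for the RAG Assistant.

Loads a JSONL dataset of questions, runs the retrieve -> rerank -> generate
pipeline, computes retrieval metrics (recall@10, MRR) and LLM-judge scores
(grounding, correctness), then writes JSON and Markdown reports.

Example invocation (paths are illustrative):
    python run_eval.py --data eval_set.jsonl --report reports
"""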
import os
import sys
import json
import argparse
from typing import List, Dict
from dotenv import load_dotenv
load_dotenv()
# Add the project root to sys.path so the services/ and eval/ imports below resolve
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from services.rag.retrieve import get_retriever
from services.rag.rerank import get_reranker
from services.rag.generate import get_generator
from eval.metrics import calculate_recall, calculate_mrr
from eval.judge import Judge
def load_dataset(path: str) -> List[Dict]:
    """Load a JSONL dataset: one JSON object per non-empty line."""
    data = []
    with open(path, 'r') as f:
        for line in f:
            if line.strip():
                data.append(json.loads(line))
    return data
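
# Expected dataset row shape, inferred from the fields read in run_eval below
# (the values here are purely illustrative):
#   {"id": "q1", "question": "...", "gold_sources": ["chunk_or_doc_id", ...]}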
def run_eval(data_path: str, report_dir: str):
    print(f"Loading dataset from {data_path}...")
    dataset = load_dataset(data_path)

    # Init services
    retriever = get_retriever()  # Assumes default data/index
    reranker = get_reranker()
    generator = get_generator()
    judge = Judge()

    results = []
    total_recall = 0
    total_mrr = 0
    total_grounding = 0
    total_correctness = 0
    count = 0

    print(f"Running eval on {len(dataset)} examples...")
    for item in dataset:
        qid = item.get('id')
        question = item['question']
        gold_sources = item.get('gold_sources', [])
        # 1. Retrieve
        retrieved = retriever.retrieve(question, top_k=10)
        # metrics.py matches gold sources by substring, so the full chunk IDs
        # (rather than bare doc IDs) are passed to the retrieval metrics below.
        retrieved_ids = [c['metadata']['chunk_id'] for c in retrieved]

        # 2. Rerank
        reranked = reranker.rerank(question, retrieved, top_k=5)

        # 3. Generate
        answer = generator.generate(question, reranked)

        # 4. Compute retrieval metrics
        recall = calculate_recall(retrieved_ids, gold_sources)
        mrr = calculate_mrr(retrieved_ids, gold_sources)
        # 5. Judge: concatenate the reranked context for the judge prompt
        context_text = "\n".join([c['content'] for c in reranked])
        # Only run the judge if we have an API key, else skip
        if os.getenv("OPENAI_API_KEY"):
            eval_res = judge.evaluate(question, context_text, answer)
        else:
            eval_res = {"grounding": 0, "correctness": 0, "reasoning": "No API Key"}
        result_entry = {
            "id": qid,
            "question": question,
            "answer": answer,
            "metrics": {
                "recall@10": recall,
                "mrr": mrr,
                "grounding": eval_res.get('grounding'),
                "correctness": eval_res.get('correctness')
            },
            "judge_reasoning": eval_res.get('reasoning')
        }
        results.append(result_entry)

        total_recall += recall
        total_mrr += mrr
        total_grounding += eval_res.get('grounding', 0)
        total_correctness += eval_res.get('correctness', 0)
        count += 1
        print(f"Eval {qid}: Recall={recall:.2f}, MRR={mrr:.2f}")
    # Aggregate
    if count > 0:
        avg_results = {
            "avg_recall@10": total_recall / count,
            "avg_mrr": total_mrr / count,
            "avg_grounding": total_grounding / count,
            "avg_correctness": total_correctness / count
        }
    else:
        avg_results = {}

    print("\nResults:", avg_results)

    # Save Report
    os.makedirs(report_dir, exist_ok=True)
    with open(os.path.join(report_dir, "eval_report.json"), 'w') as f:
        json.dump({"summary": avg_results, "details": results}, f, indent=2)
with open(os.path.join(report_dir, "eval_report.md"), 'w') as f:
f.write("# Evaluation Report\n\n")
f.write("## Summary\n")
for k, v in avg_results.items():
f.write(f"- **{k}**: {v:.4f}\n")
f.write("\n## Details\n")
# Write top 5 failures?
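        # A minimal sketch of one way to fill the Details section (an assumption,
        # not part of the original script): list the five lowest-recall examples
        # so failures are easy to triage. Keys come from result_entry above.
        worst = sorted(results, key=lambda r: r["metrics"]["recall@10"])[:5]
        for r in worst:
            f.write(
                f"- **{r['id']}** (recall@10={r['metrics']['recall@10']:.2f}): "
                f"{r['question']}\n"
            )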

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", required=True, help="Path to the JSONL eval dataset")
    parser.add_argument("--report", default="reports", help="Directory for the eval reports")
    args = parser.parse_args()
    run_eval(args.data, args.report)