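"""Evaluation runner for the RAG pipeline.

Loads a JSONL dataset of questions, runs retrieve -> rerank -> generate for
each one, scores retrieval with recall/MRR, optionally scores answers with an
LLM judge (when OPENAI_API_KEY is set), and writes JSON and Markdown reports.
"""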
import os
import sys
import json
import argparse
from typing import List, Dict
from dotenv import load_dotenv

load_dotenv()

# Add root to path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from services.rag.retrieve import get_retriever
from services.rag.rerank import get_reranker
from services.rag.generate import get_generator
from eval.metrics import calculate_recall, calculate_mrr
from eval.judge import Judge

def load_dataset(path: str) -> List[Dict]:
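    """Load a JSONL dataset: one JSON object per non-empty line, each with a
    'question' field and optionally 'id' and 'gold_sources'."""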
    data = []
    with open(path, 'r') as f:
        for line in f:
            if line.strip():
                data.append(json.loads(line))
    return data

def run_eval(data_path: str, report_dir: str):
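    """Run retrieval, reranking, generation, and scoring over the dataset,
    then write eval_report.json and eval_report.md under report_dir."""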
    print(f"Loading dataset from {data_path}...")
    dataset = load_dataset(data_path)
    
    # Init services
    retriever = get_retriever() # Assumes default data/index
    reranker = get_reranker()
    generator = get_generator()
    judge = Judge()
    
    results = []
    
    total_recall = 0
    total_mrr = 0
    total_grounding = 0
    total_correctness = 0
    count = 0
    
    print(f"Running eval on {len(dataset)} examples...")
    
    for item in dataset:
        qid = item.get('id')
        question = item['question']
        gold_sources = item.get('gold_sources', [])
        
        # 1. Retrieve
        retrieved = retriever.retrieve(question, top_k=10)
        # metrics.py matches retrieved ids against gold sources by substring,
        # so the full chunk_id is passed rather than just the doc_id.
        retrieved_ids_full = [c['metadata']['chunk_id'] for c in retrieved]

        # 2. Rerank
        reranked = reranker.rerank(question, retrieved, top_k=5)
        
        # 3. Generate
        answer = generator.generate(question, reranked)
        
        # 4. Compute Metrics
        recall = calculate_recall(retrieved_ids_full, gold_sources)
        mrr = calculate_mrr(retrieved_ids_full, gold_sources)
        
        # 5. Judge
        # Concatenate context for judge
        context_text = "\n".join([c['content'] for c in reranked])
        # Only run judge if we have an API Key, else skip
        if os.getenv("OPENAI_API_KEY"):
            eval_res = judge.evaluate(question, context_text, answer)
        else:
            eval_res = {"grounding": 0, "correctness": 0, "reasoning": "No API Key"}
            
        result_entry = {
            "id": qid,
            "question": question,
            "answer": answer,
            "metrics": {
                "recall@10": recall,
                "mrr": mrr,
                "grounding": eval_res.get('grounding'),
                "correctness": eval_res.get('correctness')
            },
            "judge_reasoning": eval_res.get('reasoning')
        }
        results.append(result_entry)
        
        total_recall += recall
        total_mrr += mrr
        total_grounding += eval_res.get('grounding', 0)
        total_correctness += eval_res.get('correctness', 0)
        count += 1
        
        print(f"Eval {qid}: Recall={recall:.2f}, MRR={mrr:.2f}")

    # Aggregate
    if count > 0:
        avg_results = {
            "avg_recall@10": total_recall / count,
            "avg_mrr": total_mrr / count,
            "avg_grounding": total_grounding / count,
            "avg_correctness": total_correctness / count
        }
    else:
        avg_results = {}
        
    print("\nResults:", avg_results)
    
    # Save Report
    os.makedirs(report_dir, exist_ok=True)
    with open(os.path.join(report_dir, "eval_report.json"), 'w') as f:
        json.dump({"summary": avg_results, "details": results}, f, indent=2)
        
    with open(os.path.join(report_dir, "eval_report.md"), 'w') as f:
        f.write("# Evaluation Report\n\n")
        f.write("## Summary\n")
        for k, v in avg_results.items():
            f.write(f"- **{k}**: {v:.4f}\n")
        f.write("\n## Details\n")
        # Write top 5 failures? 
        
if __name__ == "__main__":
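    # Example invocation (illustrative paths):
    #   python run_eval.py --data eval/dataset.jsonl --report reports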
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", required=True)
    parser.add_argument("--report", default="reports")
    args = parser.parse_args()
    
    run_eval(args.data, args.report)