File size: 3,952 Bytes
a3c924f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import json
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
from src.pipeline.query_pipeline import QueryPipeline
from src.eval.retrieval_metrics import recall_at_k, mrr_score, precision_at_k
from src.eval.hallucination import HallucinationGrader
from src.eval.relevancy import RelevancyGrader

def _first_hit_rank(relevant, contents):
    """Return the 1-based rank of the first chunk in *contents* containing *relevant*.

    Matching is a plain substring test because the ingestion pipeline may wrap
    the stored text with extra metadata (e.g. a "Source: ..." prefix), so an
    exact equality check would miss genuine hits. Returns 0 when no chunk
    contains *relevant*.
    """
    return next(
        (rank for rank, content in enumerate(contents, start=1) if relevant in content),
        0,
    )


def _mean(values):
    """Return the mean of *values* as a float, or NaN when *values* is empty.

    ``np.mean([])`` also produces NaN but emits a RuntimeWarning; checking
    first keeps the report clean when no query was evaluated.
    """
    return float(np.mean(values)) if values else float("nan")


def main():
    """Evaluate the query pipeline on a WikiQA slice and print aggregate metrics.

    Every WikiQA row labelled as a correct question/answer pair (``label == 1``)
    is sent through ``QueryPipeline.run``. The retrieved context is scored
    against the gold answer sentence for retrieval quality (recall, precision,
    MRR). The generated answer is scored by the LLM-backed graders for factual
    support (``HallucinationGrader``) and answer relevancy (``RelevancyGrader``).

    Results are printed to stdout and nothing is returned. If the dataset cannot
    be loaded, the error is printed and the function returns early.
    """
    top_k_retrieval = 10
    top_k_rerank = 3

    print("Initializing Pipeline...")
    pipeline = QueryPipeline()
    grader = HallucinationGrader(pipeline.llm)
    relevancy_grader = RelevancyGrader(pipeline.llm)

    print("Loading Evaluation Data (WikiQA Train Split)...")
    # For meaningful retrieval evaluation the questions must have answers that
    # are actually present in the index. The index is assumed to have been built
    # from a prefix of WikiQA's 'train' split (the first 100 rows, per the
    # original note -- confirm against the ingestion script), so we evaluate on
    # the first 20 rows of that same split. A real benchmark would use a held-out
    # split, which only makes sense once the whole knowledge base is indexed.
    try:
        ds = load_dataset("microsoft/wiki_qa", split="train[:20]", trust_remote_code=True)
    except Exception as e:  # script boundary: report the failure and stop
        print(f"Error loading dataset: {e}")
        return

    recalls = []
    precisions = []
    mrrs = []
    hallucination_scores = []
    relevancy_scores = []

    print("Running Evaluation...")
    for row in tqdm(ds, total=len(ds)):
        # WikiQA lists every candidate sentence per question; only rows marked as
        # the correct answer give a positive query whose recall is measurable.
        if row['label'] != 1:
            continue

        query = row['question']
        relevant_doc_content = row['answer']  # the gold answer sentence

        result = pipeline.run(query, top_k_retrieval=top_k_retrieval, top_k_rerank=top_k_rerank)

        # Context entries are (doc, score) pairs; doc is either raw text or a
        # mapping with a 'content' field.
        retrieved_contents = [
            doc if isinstance(doc, str) else doc['content']
            for doc, _score in result['context']
        ]

        # --- Retrieval metrics (one scan yields recall, precision and MRR) ---
        rank = _first_hit_rank(relevant_doc_content, retrieved_contents)
        is_hit = rank > 0
        recalls.append(1.0 if is_hit else 0.0)
        # WikiQA usually has a single gold sentence per row, so precision@k is at
        # most 1/k. NOTE(review): if result['context'] only holds the
        # top_k_rerank reranked docs, k should be len(retrieved_contents) --
        # confirm against QueryPipeline.run.
        precisions.append(1.0 / top_k_retrieval if is_hit else 0.0)
        mrrs.append(1.0 / rank if is_hit else 0.0)

        # --- Generation metrics (LLM-graded) ---
        # Factuality: is the answer supported by the retrieved context?
        grade = grader.grade(
            context="\n".join(retrieved_contents),
            answer=result['answer'],
        )
        hallucination_scores.append(grade.get('score', 0.0))

        # Relevancy: does the answer address the question?
        rel_grade = relevancy_grader.grade(query=query, answer=result['answer'])
        relevancy_scores.append(rel_grade.get('score', 0.0))

    n_evaluated = len(recalls)
    print("\nXXX Evaluation Results XXX")
    print(f"Evaluated queries: {n_evaluated}")
    if n_evaluated == 0:
        print("Warning: no rows with label == 1 in the evaluated slice; metrics are undefined (nan).")
    print(f"Average Recall@{top_k_retrieval}: {_mean(recalls):.4f}")
    print(f"Average Precision@{top_k_retrieval}: {_mean(precisions):.4f}")
    print(f"Average MRR: {_mean(mrrs):.4f}")
    print(f"Average Factuality Score: {_mean(hallucination_scores):.4f}")
    print(f"Average Answer Relevancy: {_mean(relevancy_scores):.4f}")


if __name__ == "__main__":
    main()