XQ committed on
Commit
eef89e0
·
1 Parent(s): cc7b6b4

Add eval script

Browse files
Files changed (1) hide show
  1. scripts/evaluate.py +214 -0
scripts/evaluate.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """RAGAS evaluation script: runs a fixed QA test set through the RAG pipeline
2
+ and reports retrieval + generation quality metrics.
3
+
4
+ Usage:
5
+ python -m scripts.evaluate [--top-k 5] [--retrieval-only]
6
+
7
+ Output:
8
+ A table of RAGAS scores printed to stdout.
9
+ - Full mode: faithfulness, answer_relevancy, context_precision, context_recall
10
+ - Retrieval-only: context_precision, context_recall
11
+ """
12
+
13
+ import argparse
14
+ import logging
15
+ import os
16
+ import shutil
17
+ import sys
18
+ import tempfile
19
+
20
+ PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
21
+ sys.path.insert(0, PROJECT_ROOT)
22
+
23
+ from langchain_core.output_parsers import StrOutputParser
24
+
25
+ from src.config import load_settings
26
+ from src.evaluation.evaluator import RAGEvaluator
27
+ from src.models import ChunkStrategy
28
+ from src.provider import create_embeddings, create_llm, create_reranker
29
+ from src.ingestion.pipeline import IngestionPipeline
30
+ from src.retrieval.embedder import Embedder
31
+ from src.retrieval.vector_store import VectorStore
32
+ from src.retrieval.bm25_search import BM25Search
33
+ from src.retrieval.hybrid import HybridRetriever
34
+ from src.retrieval.reranker import Reranker
35
+ from src.agent.intent_classifier import IntentClassifier
36
+ from src.agent.router import QueryRouter
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+ DOCS_DIR = os.path.join(PROJECT_ROOT, "docs")
41
+
42
# ---------------------------------------------------------------------------
# Fixed evaluation set of (question, ground_truth) pairs.
# The questions are written in Danish, the language of the documents.
# Each ground truth is a short reference answer; RAGAS context_recall
# uses it as the reference.
# ---------------------------------------------------------------------------
TEST_SET: list[tuple[str, str]] = [
    (
        "Hvad er reglerne for brug af AI på KU?",
        (
            "KU har retningslinjer for ansvarlig brug af AI-værktøjer, herunder "
            "krav om gennemsigtighed og akademisk integritet."
        ),
    ),
    (
        "Hvilke regler gælder for behandling af personoplysninger?",
        (
            "Behandling af personoplysninger på KU skal ske i "
            "overensstemmelse med GDPR og universitetets "
            "databeskyttelsespolitik."
        ),
    ),
    (
        "Hvad er KUs politik for informationssikkerhed?",
        (
            "KU kræver, at medarbejdere følger "
            "informationssikkerhedspolitikken, herunder adgangskontrol "
            "og beskyttelse af følsomme data."
        ),
    ),
    (
        "Hvordan håndteres brud på datasikkerheden?",
        (
            "Sikkerhedsbrud skal indberettes til IT-sikkerhedsteamet "
            "inden for 72 timer i overensstemmelse med GDPR-kravene."
        ),
    ),
    (
        "Hvad er reglerne for eksamen og snyd?",
        (
            "KU har regler for eksamenssnyd, herunder konsekvenser som "
            "bortvisning fra eksamen og i alvorlige tilfælde bortvisning "
            "fra universitetet."
        ),
    ),
]
74
+
75
+
76
def _positive_int(value: str) -> int:
    """Parse *value* as an integer >= 1 for use as an argparse ``type``.

    Args:
        value: Raw command-line token.

    Returns:
        The parsed integer.

    Raises:
        argparse.ArgumentTypeError: If *value* is not an integer or is < 1.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"invalid int value: {value!r}"
        ) from None
    if number < 1:
        raise argparse.ArgumentTypeError(
            f"must be a positive integer, got {number}"
        )
    return number


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments.

    Recognised options:
        --top-k N          Number of retrieved chunks per query (default 5).
                           Must be a positive integer; zero or negative values
                           are rejected with a usage error instead of being
                           forwarded to retrieval.
        --retrieval-only   Flag; when set, only context_precision and
                           context_recall are measured.

    Returns:
        Parsed argument namespace with ``top_k`` (int) and
        ``retrieval_only`` (bool).
    """
    parser = argparse.ArgumentParser(
        description="Run RAGAS evaluation over a fixed QA test set.",
    )
    parser.add_argument(
        "--top-k",
        type=_positive_int,
        default=5,
        help="Number of retrieved chunks per query (default: 5).",
    )
    parser.add_argument(
        "--retrieval-only",
        action="store_true",
        help="Only measure context_precision and context_recall (no generation).",
    )
    return parser.parse_args()
97
+
98
+
99
def main() -> None:
    """Run the end-to-end RAGAS evaluation and print the score table.

    The function parses the CLI arguments, loads settings, configures
    logging, ingests the documents under ``DOCS_DIR``, embeds the chunks and
    indexes them in a temporary vector store plus a BM25 index, assembles the
    query router, routes every ``TEST_SET`` question through it, scores the
    collected outputs with ``RAGEvaluator`` and prints one line per metric.

    Exits with status 1 if ingestion produces no chunks. The temporary
    vector-store directory is removed on every exit path.
    """
    args = parse_args()
    settings = load_settings()

    log_level = getattr(logging, settings.log_level.upper(), logging.INFO)
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )

    logger.info("=== RAGAS Evaluation Start ===")
    logger.info(
        "Test set size: %d | top_k: %d | retrieval_only: %s",
        len(TEST_SET),
        args.top_k,
        args.retrieval_only,
    )

    # Throwaway on-disk location for the vector store; removed in `finally`.
    qdrant_tmp = tempfile.mkdtemp(prefix="eval_qdrant_")
    logger.info("Qdrant temp path: %s", qdrant_tmp)

    try:
        # 1) Providers
        llm = create_llm(settings)
        embeddings = create_embeddings(settings)

        # 2) Ingestion
        logger.info("Ingesting PDFs from %s ...", DOCS_DIR)
        ingestion = IngestionPipeline(
            strategy=ChunkStrategy.RECURSIVE,
            chunk_size=settings.chunk_size,
            chunk_overlap=settings.chunk_overlap,
        )
        chunks = ingestion.ingest_directory(DOCS_DIR)
        logger.info("Total chunks: %d", len(chunks))

        if not chunks:
            logger.error("No chunks produced. Place PDFs in docs/ and retry.")
            sys.exit(1)

        # 3) Dense + sparse indexes
        embedder = Embedder(embeddings)
        chunk_texts = [chunk.text for chunk in chunks]
        vectors = embedder.embed_batch(chunk_texts)

        vector_store = VectorStore(
            path=qdrant_tmp,
            collection_name="eval",
            dimension=settings.embedding_dimension,
        )
        vector_store.add_chunks(chunks, vectors)

        bm25 = BM25Search()
        bm25.index(chunks)

        # 4) Router assembly
        hybrid_retriever = HybridRetriever(
            vector_store=vector_store,
            bm25_search=bm25,
            embedder=embedder,
            dense_weight=settings.dense_weight,
            bm25_weight=settings.bm25_weight,
        )
        reranker = Reranker(model=create_reranker(settings.reranker_model))
        intent_classifier = IntentClassifier(
            llm=llm,
            model_name=settings.generation_model,
        )
        answer_chain = llm | StrOutputParser()
        router = QueryRouter(
            intent_classifier=intent_classifier,
            hybrid_retriever=hybrid_retriever,
            reranker=reranker,
            generator=answer_chain,
        )

        # 5) Route every test question and collect the evaluation columns
        questions: list[str] = []
        answers: list[str] = []
        contexts: list[list[str]] = []
        ground_truths: list[str] = []

        for question, ground_truth in TEST_SET:
            logger.info("Running query: %s", question)
            response = router.route(query=question, top_k=args.top_k)
            questions.append(question)
            answers.append(response.answer)
            contexts.append([source.chunk.text for source in response.sources])
            ground_truths.append(ground_truth)

        # 6) Scoring: both modes share the retrieval inputs; the full mode
        #    additionally passes the generated answers.
        evaluator = RAGEvaluator(llm=llm)
        retrieval_inputs = {
            "questions": questions,
            "contexts": contexts,
            "ground_truths": ground_truths,
        }
        if args.retrieval_only:
            scores = evaluator.evaluate_retrieval(**retrieval_inputs)
        else:
            scores = evaluator.evaluate(answers=answers, **retrieval_inputs)

        # 7) Report
        rule = "=" * 50
        print("\n" + rule)
        print("RAGAS EVALUATION RESULTS")
        print(rule)
        for metric, score in scores.items():
            print(f" {metric:<30} {score:.4f}")
        print(rule)

        logger.info("=== RAGAS Evaluation Complete ===")

    finally:
        shutil.rmtree(qdrant_tmp, ignore_errors=True)
        logger.info("Cleaned up temp Qdrant at %s", qdrant_tmp)


if __name__ == "__main__":
    main()