Spaces:
Sleeping
Sleeping
XQ commited on
Commit ·
eef89e0
1
Parent(s): cc7b6b4
Add eval script
Browse files- scripts/evaluate.py +214 -0
scripts/evaluate.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""RAGAS evaluation script: runs a fixed QA test set through the RAG pipeline
|
| 2 |
+
and reports retrieval + generation quality metrics.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python -m scripts.evaluate [--top-k 5] [--retrieval-only]
|
| 6 |
+
|
| 7 |
+
Output:
|
| 8 |
+
A table of RAGAS scores printed to stdout.
|
| 9 |
+
- Full mode: faithfulness, answer_relevancy, context_precision, context_recall
|
| 10 |
+
- Retrieval-only: context_precision, context_recall
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import logging
|
| 15 |
+
import os
|
| 16 |
+
import shutil
|
| 17 |
+
import sys
|
| 18 |
+
import tempfile
|
| 19 |
+
|
| 20 |
+
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
| 21 |
+
sys.path.insert(0, PROJECT_ROOT)
|
| 22 |
+
|
| 23 |
+
from langchain_core.output_parsers import StrOutputParser
|
| 24 |
+
|
| 25 |
+
from src.config import load_settings
|
| 26 |
+
from src.evaluation.evaluator import RAGEvaluator
|
| 27 |
+
from src.models import ChunkStrategy
|
| 28 |
+
from src.provider import create_embeddings, create_llm, create_reranker
|
| 29 |
+
from src.ingestion.pipeline import IngestionPipeline
|
| 30 |
+
from src.retrieval.embedder import Embedder
|
| 31 |
+
from src.retrieval.vector_store import VectorStore
|
| 32 |
+
from src.retrieval.bm25_search import BM25Search
|
| 33 |
+
from src.retrieval.hybrid import HybridRetriever
|
| 34 |
+
from src.retrieval.reranker import Reranker
|
| 35 |
+
from src.agent.intent_classifier import IntentClassifier
|
| 36 |
+
from src.agent.router import QueryRouter
|
| 37 |
+
|
| 38 |
+
logger = logging.getLogger(__name__)
|
| 39 |
+
|
| 40 |
+
DOCS_DIR = os.path.join(PROJECT_ROOT, "docs")
|
| 41 |
+
|
| 42 |
+
# ---------------------------------------------------------------------------
|
| 43 |
+
# Test set: (question, ground_truth)
|
| 44 |
+
# Questions are in Danish to match the document language.
|
| 45 |
+
# Ground truths are short reference answers used by RAGAS context_recall.
|
| 46 |
+
# ---------------------------------------------------------------------------
|
| 47 |
+
TEST_SET: list[tuple[str, str]] = [
|
| 48 |
+
(
|
| 49 |
+
"Hvad er reglerne for brug af AI på KU?",
|
| 50 |
+
"KU har retningslinjer for ansvarlig brug af AI-værktøjer, "
|
| 51 |
+
"herunder krav om gennemsigtighed og akademisk integritet.",
|
| 52 |
+
),
|
| 53 |
+
(
|
| 54 |
+
"Hvilke regler gælder for behandling af personoplysninger?",
|
| 55 |
+
"Behandling af personoplysninger på KU skal ske i overensstemmelse "
|
| 56 |
+
"med GDPR og universitetets databeskyttelsespolitik.",
|
| 57 |
+
),
|
| 58 |
+
(
|
| 59 |
+
"Hvad er KUs politik for informationssikkerhed?",
|
| 60 |
+
"KU kræver, at medarbejdere følger informationssikkerhedspolitikken, "
|
| 61 |
+
"herunder adgangskontrol og beskyttelse af følsomme data.",
|
| 62 |
+
),
|
| 63 |
+
(
|
| 64 |
+
"Hvordan håndteres brud på datasikkerheden?",
|
| 65 |
+
"Sikkerhedsbrud skal indberettes til IT-sikkerhedsteamet inden for "
|
| 66 |
+
"72 timer i overensstemmelse med GDPR-kravene.",
|
| 67 |
+
),
|
| 68 |
+
(
|
| 69 |
+
"Hvad er reglerne for eksamen og snyd?",
|
| 70 |
+
"KU har regler for eksamenssnyd, herunder konsekvenser som bortvisning "
|
| 71 |
+
"fra eksamen og i alvorlige tilfælde bortvisning fra universitetet.",
|
| 72 |
+
),
|
| 73 |
+
]
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def parse_args() -> argparse.Namespace:
|
| 77 |
+
"""Parse command-line arguments.
|
| 78 |
+
|
| 79 |
+
Returns:
|
| 80 |
+
Parsed argument namespace.
|
| 81 |
+
"""
|
| 82 |
+
parser = argparse.ArgumentParser(
|
| 83 |
+
description="Run RAGAS evaluation over a fixed QA test set.",
|
| 84 |
+
)
|
| 85 |
+
parser.add_argument(
|
| 86 |
+
"--top-k",
|
| 87 |
+
type=int,
|
| 88 |
+
default=5,
|
| 89 |
+
help="Number of retrieved chunks per query (default: 5).",
|
| 90 |
+
)
|
| 91 |
+
parser.add_argument(
|
| 92 |
+
"--retrieval-only",
|
| 93 |
+
action="store_true",
|
| 94 |
+
help="Only measure context_precision and context_recall (no generation).",
|
| 95 |
+
)
|
| 96 |
+
return parser.parse_args()
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def main() -> None:
|
| 100 |
+
"""Build the RAG pipeline, run the test set, and print RAGAS scores."""
|
| 101 |
+
args = parse_args()
|
| 102 |
+
settings = load_settings()
|
| 103 |
+
|
| 104 |
+
logging.basicConfig(
|
| 105 |
+
level=getattr(logging, settings.log_level.upper(), logging.INFO),
|
| 106 |
+
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
logger.info("=== RAGAS Evaluation Start ===")
|
| 110 |
+
logger.info("Test set size: %d | top_k: %d | retrieval_only: %s",
|
| 111 |
+
len(TEST_SET), args.top_k, args.retrieval_only)
|
| 112 |
+
|
| 113 |
+
qdrant_tmp = tempfile.mkdtemp(prefix="eval_qdrant_")
|
| 114 |
+
logger.info("Qdrant temp path: %s", qdrant_tmp)
|
| 115 |
+
|
| 116 |
+
try:
|
| 117 |
+
# --- 1) Create providers ---
|
| 118 |
+
llm = create_llm(settings)
|
| 119 |
+
embeddings = create_embeddings(settings)
|
| 120 |
+
|
| 121 |
+
# --- 2) Ingest docs ---
|
| 122 |
+
logger.info("Ingesting PDFs from %s ...", DOCS_DIR)
|
| 123 |
+
pipeline = IngestionPipeline(
|
| 124 |
+
strategy=ChunkStrategy.RECURSIVE,
|
| 125 |
+
chunk_size=settings.chunk_size,
|
| 126 |
+
chunk_overlap=settings.chunk_overlap,
|
| 127 |
+
)
|
| 128 |
+
chunks = pipeline.ingest_directory(DOCS_DIR)
|
| 129 |
+
logger.info("Total chunks: %d", len(chunks))
|
| 130 |
+
|
| 131 |
+
if not chunks:
|
| 132 |
+
logger.error("No chunks produced. Place PDFs in docs/ and retry.")
|
| 133 |
+
sys.exit(1)
|
| 134 |
+
|
| 135 |
+
# --- 3) Embed and index ---
|
| 136 |
+
embedder = Embedder(embeddings)
|
| 137 |
+
vectors = embedder.embed_batch([c.text for c in chunks])
|
| 138 |
+
|
| 139 |
+
vector_store = VectorStore(
|
| 140 |
+
path=qdrant_tmp,
|
| 141 |
+
collection_name="eval",
|
| 142 |
+
dimension=settings.embedding_dimension,
|
| 143 |
+
)
|
| 144 |
+
vector_store.add_chunks(chunks, vectors)
|
| 145 |
+
|
| 146 |
+
bm25 = BM25Search()
|
| 147 |
+
bm25.index(chunks)
|
| 148 |
+
|
| 149 |
+
# --- 4) Build router ---
|
| 150 |
+
hybrid = HybridRetriever(
|
| 151 |
+
vector_store=vector_store,
|
| 152 |
+
bm25_search=bm25,
|
| 153 |
+
embedder=embedder,
|
| 154 |
+
dense_weight=settings.dense_weight,
|
| 155 |
+
bm25_weight=settings.bm25_weight,
|
| 156 |
+
)
|
| 157 |
+
reranker = Reranker(model=create_reranker(settings.reranker_model))
|
| 158 |
+
classifier = IntentClassifier(llm=llm, model_name=settings.generation_model)
|
| 159 |
+
generator = llm | StrOutputParser()
|
| 160 |
+
router = QueryRouter(
|
| 161 |
+
intent_classifier=classifier,
|
| 162 |
+
hybrid_retriever=hybrid,
|
| 163 |
+
reranker=reranker,
|
| 164 |
+
generator=generator,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
# --- 5) Run test set ---
|
| 168 |
+
questions: list[str] = []
|
| 169 |
+
answers: list[str] = []
|
| 170 |
+
contexts: list[list[str]] = []
|
| 171 |
+
ground_truths: list[str] = []
|
| 172 |
+
|
| 173 |
+
for question, ground_truth in TEST_SET:
|
| 174 |
+
logger.info("Running query: %s", question)
|
| 175 |
+
response = router.route(query=question, top_k=args.top_k)
|
| 176 |
+
questions.append(question)
|
| 177 |
+
answers.append(response.answer)
|
| 178 |
+
contexts.append([r.chunk.text for r in response.sources])
|
| 179 |
+
ground_truths.append(ground_truth)
|
| 180 |
+
|
| 181 |
+
# --- 6) Evaluate ---
|
| 182 |
+
evaluator = RAGEvaluator(llm=llm)
|
| 183 |
+
|
| 184 |
+
if args.retrieval_only:
|
| 185 |
+
scores = evaluator.evaluate_retrieval(
|
| 186 |
+
questions=questions,
|
| 187 |
+
contexts=contexts,
|
| 188 |
+
ground_truths=ground_truths,
|
| 189 |
+
)
|
| 190 |
+
else:
|
| 191 |
+
scores = evaluator.evaluate(
|
| 192 |
+
questions=questions,
|
| 193 |
+
answers=answers,
|
| 194 |
+
contexts=contexts,
|
| 195 |
+
ground_truths=ground_truths,
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
# --- 7) Print results ---
|
| 199 |
+
print("\n" + "=" * 50)
|
| 200 |
+
print("RAGAS EVALUATION RESULTS")
|
| 201 |
+
print("=" * 50)
|
| 202 |
+
for metric, score in scores.items():
|
| 203 |
+
print(f" {metric:<30} {score:.4f}")
|
| 204 |
+
print("=" * 50)
|
| 205 |
+
|
| 206 |
+
logger.info("=== RAGAS Evaluation Complete ===")
|
| 207 |
+
|
| 208 |
+
finally:
|
| 209 |
+
shutil.rmtree(qdrant_tmp, ignore_errors=True)
|
| 210 |
+
logger.info("Cleaned up temp Qdrant at %s", qdrant_tmp)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
if __name__ == "__main__":
|
| 214 |
+
main()
|