raghub-sevima-test / scripts /rag_eval_loop.py
lifedebugger's picture
Deploy files from GitHub repository with LFS
9b788cc
#!/usr/bin/env python3
"""
Small, repeatable RAG end-to-end evaluation loop.
This script mirrors real student usage through HTTP APIs:
1. Register/login lecturer and student users
2. Upload one or more course documents as lecturer
3. Poll until ingestion status is ready/failed
4. Create student chat session
5. Send student questions and score answers
Usage:
python scripts/rag_eval_loop.py --eval-file scripts/rag_eval_sample.json
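Expected eval file shape (document paths are resolved relative to the
repository root):
{
  "course_id": "course-demo-001",
  "documents": [
    {"path": "doc/samples/intro.txt"}
  ],
  "questions": [
    {
      "question": "What is ...?",
      "expected_all": ["..."],
      "expected_any": ["...", "..."],
      "forbidden": ["..."],
      "expect_abstain": false,
      "retrieval_expected_all": ["..."],
      "retrieval_expected_any": ["..."]
    }
  ]
}
All question keys other than "question" are optional. The
retrieval_expected_* lists are matched against the returned sources and
default to the corresponding expected_* lists when omitted.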
"""
from __future__ import annotations
import argparse
import json
import os
import re
import sys
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
import requests
from requests.exceptions import ReadTimeout
# Repository root: the parent of the directory holding this script. It is
# placed on sys.path so ``app.*`` modules can be imported for the optional
# local user-activation step.
REPO_ROOT = Path(__file__).resolve().parents[1]
_REPO_ROOT_STR = str(REPO_ROOT)
if _REPO_ROOT_STR not in sys.path:
    sys.path.insert(0, _REPO_ROOT_STR)

# Default ingestion wait timeout and polling interval, in seconds.
DEFAULT_TIMEOUT_SECONDS = 300
DEFAULT_POLL_INTERVAL_SECONDS = 2
# Default value for --abstain-phrase (the expected abstention wording).
DEFAULT_ABSTAIN_PHRASE = "I don't know based on the available materials"
@dataclass()
class EvalCaseResult:
    """Outcome of scoring a single evaluation question.

    ``passed`` is the overall verdict derived from the per-metric flags, and
    ``reasons`` holds the human-readable failure messages.
    """

    # Case number as reported in the run output.
    index: int
    question: str
    # Overall verdict followed by the per-metric verdicts.
    passed: bool
    answer_passed: bool
    # False when the case had nothing to check for retrieval.
    retrieval_checked: bool
    retrieval_passed: bool
    faithfulness_passed: bool
    # Failure messages collected from all metrics.
    reasons: List[str]
    # The answer text and the source records returned by the chat API.
    answer_text: str
    sources: List[Dict[str, Any]]
class EvalRunner:
    """HTTP client that drives the backend through the same endpoints a
    lecturer and a student would use.

    Every endpoint is addressed as ``base_url + api_prefix + path``.
    """

    def __init__(self, base_url: str, api_prefix: str, request_timeout: int) -> None:
        """Store connection settings.

        Args:
            base_url: Server origin; a trailing slash is stripped.
            api_prefix: Path prefix for all API routes; a trailing slash is stripped.
            request_timeout: Default per-request timeout in seconds.
        """
        self.base_url = base_url.rstrip("/")
        self.api_prefix = api_prefix.rstrip("/")
        self.request_timeout = request_timeout

    def _url(self, path: str) -> str:
        """Return the absolute URL for ``path`` (expected to start with ``/``)."""
        return f"{self.base_url}{self.api_prefix}{path}"

    def _request(
        self,
        method: str,
        path: str,
        *,
        token: Optional[str] = None,
        expected: Optional[List[int]] = None,
        timeout_seconds: Optional[int] = None,
        **kwargs: Any,
    ) -> requests.Response:
        """Send an HTTP request and optionally enforce the response status.

        Args:
            method: HTTP verb.
            path: API path appended to the base URL and prefix.
            token: Optional bearer token for the ``Authorization`` header.
            expected: When given, the status codes treated as success.
            timeout_seconds: Per-call timeout; defaults to ``request_timeout``.
            **kwargs: Forwarded to ``requests.request`` (``json``, ``files``, ...).

        Raises:
            RuntimeError: ``expected`` is given and the status is not in it.
        """
        # Copy so a caller-supplied headers dict is never mutated with our token.
        headers = dict(kwargs.pop("headers", None) or {})
        if token:
            headers["Authorization"] = f"Bearer {token}"
        response = requests.request(
            method,
            self._url(path),
            headers=headers,
            timeout=timeout_seconds if timeout_seconds is not None else self.request_timeout,
            **kwargs,
        )
        if expected is not None and response.status_code not in expected:
            raise RuntimeError(
                f"{method} {path} failed: status={response.status_code} body={response.text}"
            )
        return response

    def ensure_user(self, *, role: str, email: str, password: str, name: str, identity_number: str) -> str:
        """Register (or reuse) an account and return its login token.

        A 201 registration is followed by local activation. A 400 is assumed
        to mean the account already exists.

        Raises:
            RuntimeError: registration returned 422 (or another unexpected
                status), or login did not return 200.
        """
        payload = {
            "email": email,
            "password": password,
            "name": name,
            "role": role,
            "identity_number": identity_number,
        }
        reg = self._request("POST", "/auth/register", json=payload, expected=[201, 400, 422])
        if reg.status_code == 201:
            print(f"[INFO] Registered {role}: {email}")
            self._activate_registered_user(email)
        elif reg.status_code == 400:
            print(f"[INFO] Using existing {role}: {email}")
        else:
            raise RuntimeError(
                f"POST /auth/register failed: status={reg.status_code} body={reg.text}"
            )
        login = self._request(
            "POST",
            "/auth/login",
            json={"email": email, "password": password},
            expected=[200],
        )
        token = login.json()["token"]
        return token

    @staticmethod
    def _activate_registered_user(email: str) -> None:
        """Best-effort: set ``is_active`` on a just-registered user via the app DB.

        Uses the application's own config and ``User`` model (importable
        because the repo root is on ``sys.path``). NOTE(review): presumably
        needed because new registrations start inactive; confirm against the
        auth service. Import failures only print a warning.
        """
        try:
            from sqlmodel import Session, create_engine, select
            from app.core.config import configs
            from app.model.user import User
        except Exception as exc:  # app code / sqlmodel may be unavailable
            print(f"[WARN] Could not import DB activation dependencies: {exc}")
            return
        engine_options: Dict[str, Any] = {}
        if configs.DATABASE_URI.startswith("sqlite"):
            # Allow the SQLite connection to be used outside its creating thread.
            engine_options["connect_args"] = {"check_same_thread": False}
        engine = create_engine(configs.DATABASE_URI, **engine_options)
        try:
            with Session(engine) as session:
                user = session.exec(select(User).where(User.email == email)).first()
                if not user:
                    print(f"[WARN] Registered user not found for activation: {email}")
                    return
                if user.is_active:
                    return
                user.is_active = True
                session.add(user)
                session.commit()
                print(f"[INFO] Activated eval user locally: {email}")
        finally:
            # The engine is created per call; release its connection pool.
            engine.dispose()

    def upload_documents(
        self,
        *,
        token: str,
        course_id: str,
        document_paths: List[Path],
    ) -> List[str]:
        """Upload each file to ``course_id``.

        Returns:
            The ``document_id`` of each 202 response, in upload order.
        """
        uploaded_doc_ids: List[str] = []
        for doc_path in document_paths:
            with doc_path.open("rb") as fh:
                files = {"file": (doc_path.name, fh)}
                data = {"course_id": course_id}
                resp = self._request(
                    "POST",
                    "/documents/upload",
                    token=token,
                    files=files,
                    data=data,
                    expected=[202],
                )
            payload = resp.json()
            uploaded_doc_ids.append(payload["document_id"])
            print(f"[INFO] Uploaded {doc_path} -> document_id={payload['document_id']}")
        return uploaded_doc_ids

    def wait_for_ingestion(
        self,
        *,
        token: str,
        course_id: str,
        uploaded_doc_ids: List[str],
        timeout_seconds: int,
        poll_interval_seconds: int,
    ) -> Dict[str, Dict[str, Any]]:
        """Poll the course's documents until every uploaded one is ``ready`` or ``failed``.

        At least one poll is always made. The last sleep is clamped to the
        deadline and followed by a final poll, so a document that finishes
        during that sleep is still seen.

        Returns:
            Mapping of document id to its ``status``, ``error``, ``filename``
            and ``chunk_count``, for every id in ``uploaded_doc_ids``.

        Raises:
            TimeoutError: Not all documents reached a terminal status in
                ``timeout_seconds``. The message lists the last seen statuses.
        """
        # Monotonic clock: unaffected by wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout_seconds
        target = set(uploaded_doc_ids)
        state: Dict[str, Dict[str, Any]] = {}
        while True:
            resp = self._request(
                "GET",
                f"/courses/{course_id}/documents",
                token=token,
                expected=[200],
            )
            items = resp.json().get("data") or []
            state = {}
            for item in items:
                doc_id = item.get("id")
                if doc_id not in target:
                    continue
                state[doc_id] = {
                    "status": item.get("status"),
                    "error": item.get("error"),
                    "filename": item.get("filename"),
                    "chunk_count": item.get("chunk_count"),
                }
            ready_or_failed = {
                doc_id: details
                for doc_id, details in state.items()
                if details.get("status") in {"ready", "failed"}
            }
            if len(ready_or_failed) == len(target):
                return ready_or_failed
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            time.sleep(min(poll_interval_seconds, remaining))
        last_seen = {doc_id: state.get(doc_id, {}).get("status") for doc_id in sorted(target)}
        raise TimeoutError(
            f"Timed out waiting for ingestion statuses for document_ids={sorted(target)} "
            f"(last seen statuses: {last_seen})"
        )

    def create_student_session(self, *, token: str, course_id: str, title: str) -> str:
        """Create a chat session for ``course_id`` and return its id.

        Raises:
            RuntimeError: The response has neither an ``id`` nor a ``uuid_id``.
        """
        resp = self._request(
            "POST",
            "/chats/sessions",
            token=token,
            json={"course_id": course_id, "title": title},
            expected=[201],
        )
        payload = resp.json()
        session_id = payload.get("id") or payload.get("uuid_id")
        if not session_id:
            raise RuntimeError(
                "Could not parse chat session id from response payload. "
                f"Expected key 'id' or 'uuid_id', got: {payload}"
            )
        print(f"[INFO] Created student session id={session_id}")
        return session_id

    def ask_question(
        self,
        *,
        token: str,
        session_id: str,
        question: str,
        question_timeout_seconds: int,
    ) -> Dict[str, Any]:
        """Post ``question`` to the session and return its answer and sources.

        Returns:
            ``{"content": str, "sources": list}``. Missing or ``null`` fields
            become ``""`` / ``[]`` so scorers can iterate safely.

        Raises:
            RuntimeError: The answer did not arrive within the timeout.
        """
        try:
            resp = self._request(
                "POST",
                f"/chats/sessions/{session_id}/messages",
                token=token,
                json={"content": question},
                expected=[200],
                timeout_seconds=question_timeout_seconds,
            )
        except ReadTimeout as exc:
            raise RuntimeError(
                "Timed out waiting for chat answer. "
                f"Increase --question-timeout (current={question_timeout_seconds}s) "
                "or warm up model/downloads before running eval."
            ) from exc
        payload = resp.json()
        return {
            "content": payload.get("content") or "",
            "sources": payload.get("sources") or [],
        }
# Typographic quotes that generated answers may use instead of ASCII ones.
_QUOTE_FOLD = str.maketrans({"\u2018": "'", "\u2019": "'", "\u201c": '"', "\u201d": '"'})


def normalize(text: str) -> str:
    """Lower-case ``text``, fold typographic quotes to ASCII and collapse whitespace.

    ``None`` or empty input yields ``""``. Folding quotes lets an answer that
    writes "don\u2019t" match an ASCII phrase such as "don't". This matters
    because the abstention phrase and the expected phrases are compared as
    substrings of the normalised answer.
    """
    return " ".join((text or "").translate(_QUOTE_FOLD).lower().split())
def _tokens(text: str) -> List[str]:
    """Return the lower-case ASCII alphanumeric words of ``normalize(text)``.

    Punctuation and non-ASCII characters act as separators, so ``"don't"``
    yields ``["don", "t"]``.
    """
    word_pattern = r"[a-z0-9]+"
    normalized = normalize(text)
    return re.findall(word_pattern, normalized)
def _phrase_match(answer_norm: str, phrase_norm: str) -> bool:
    """Decide whether an already-normalised phrase is present in normalised text.

    An exact substring hit matches. Failing that, the phrase matches when every
    one of its alphanumeric tokens occurs somewhere in the text, in any order.
    A phrase with no tokens (for example, only punctuation) always matches.
    """
    if phrase_norm in answer_norm:
        return True
    required_tokens = set(_tokens(phrase_norm))
    if not required_tokens:
        return True
    return required_tokens.issubset(_tokens(answer_norm))
def _source_text(sources: List[Dict[str, Any]]) -> str:
parts: List[str] = []
for source in sources:
parts.extend(
str(source.get(key, ""))
for key in ("document_id", "filename", "excerpt")
if source.get(key)
)
return "\n".join(parts)
# English and Indonesian function words ignored when measuring content overlap.
# Tokens of length <= 2 are dropped separately, so the short entries here are
# redundant but harmless.
_CONTENT_STOPWORDS = frozenset({
    "a", "an", "and", "are", "as", "at", "based", "be", "by", "for", "from",
    "in", "is", "it", "of", "on", "or", "that", "the", "this", "to", "with",
    "yang", "dan", "di", "ke", "dari", "ini", "itu", "adalah", "untuk",
    "pada", "dengan", "atau", "sebagai", "dalam",
})


def _content_tokens(text: str) -> List[str]:
    """Return the content-bearing tokens of ``text``.

    Keeps the tokens from ``_tokens`` that are longer than two characters and
    are not stopwords. Order and duplicates are preserved.
    """
    # The stopword set is built once at import time instead of on every call,
    # since this runs once per answer sentence.
    return [
        token
        for token in _tokens(text)
        if len(token) > 2 and token not in _CONTENT_STOPWORDS
    ]
def _strip_list_markers(text: str) -> str:
lines = (text or "").splitlines()
cleaned_lines: List[str] = []
for idx, line in enumerate(lines):
stripped = line.strip()
marker_match = re.fullmatch(r"(\d{1,2})[.)]", stripped)
if marker_match:
marker = int(marker_match.group(1))
has_following_content = any(next_line.strip() for next_line in lines[idx + 1 :])
if 1 <= marker <= 20 and has_following_content:
continue
line = re.sub(r"^\s*[-*+]\s+", "", line)
line = re.sub(r"^\s*(?:[1-9]|1\d|20)[.)]\s+", "", line)
cleaned_lines.append(line)
return "\n".join(cleaned_lines)
def _sentences(text: str) -> List[str]:
    """Split ``text`` into trimmed, non-empty sentences.

    List markers are removed first. A sentence ends after ``.``, ``!`` or ``?``
    followed by whitespace, and at every run of newlines.
    """
    pieces = re.split(r"(?<=[.!?])\s+|\n+", _strip_list_markers(text))
    result: List[str] = []
    for piece in pieces:
        trimmed = piece.strip()
        if trimmed:
            result.append(trimmed)
    return result
def _normalize_number_claim(value: str) -> str:
if value.endswith("%"):
return f"{_normalize_number_claim(value[:-1])}%"
if re.fullmatch(r"0,\d+", value):
return value.replace(",", ".")
return value
def _number_claims(text: str) -> set[str]:
    """Collect the distinct numeric tokens of ``normalize(text)``, canonicalised.

    A token is a digit run optionally continued by ``.``, ``,`` or ``:`` digit
    groups (e.g. ``3.5``, ``1,000``, ``10:30``) and an optional trailing ``%``.
    """
    claims: set[str] = set()
    for found in re.findall(r"(?<!\d)\d+(?:[.,:]\d+)*%?", normalize(text)):
        claims.add(_normalize_number_claim(found))
    return claims
def score_answer_correctness(
    answer: str,
    case: Dict[str, Any],
    abstain_phrase: str,
) -> List[str]:
    """Check an answer against the case's phrase expectations.

    Args:
        answer: Answer text returned by the chat API.
        case: Question entry with optional ``expect_abstain``,
            ``expected_all``, ``expected_any`` and ``forbidden`` keys.
        abstain_phrase: Already-normalised abstention phrase.

    Returns:
        Human-readable failure reasons; an empty list means the answer passed.
    """

    def _normalized_phrases(key: str) -> List[str]:
        return [normalize(item) for item in case.get(key, []) if item]

    normalized_answer = normalize(answer)
    wants_abstain = bool(case.get("expect_abstain", False))
    required = _normalized_phrases("expected_all")
    alternatives = _normalized_phrases("expected_any")
    banned = _normalized_phrases("forbidden")
    abstained = abstain_phrase in normalized_answer

    failures: List[str] = []
    if wants_abstain and not abstained:
        failures.append("expected abstention phrase was not found")
    elif not wants_abstain and abstained:
        failures.append("unexpected abstention for an in-scope question")
    failures.extend(
        f"missing expected_all phrase: {phrase}"
        for phrase in required
        if not _phrase_match(normalized_answer, phrase)
    )
    if alternatives and not any(_phrase_match(normalized_answer, phrase) for phrase in alternatives):
        failures.append("none of expected_any phrases were found")
    # Forbidden phrases use a strict substring test (no token fallback).
    failures.extend(
        f"forbidden phrase found: {phrase}"
        for phrase in banned
        if phrase in normalized_answer
    )
    return failures
def score_retrieval_recall(
    sources: List[Dict[str, Any]],
    case: Dict[str, Any],
) -> tuple[bool, bool, List[str]]:
    """Check that the retrieved sources contain the expected phrases.

    Phrases come from ``retrieval_expected_all`` / ``retrieval_expected_any``.
    When those keys are absent, ``expected_all`` / ``expected_any`` are used
    instead.

    Returns:
        ``(checked, passed, reasons)``. ``checked`` is False (and ``passed``
        True) when there is nothing to check. That happens for an abstention
        case without retrieval-specific expectations, or when no expected
        phrases exist.
    """
    retrieval_specific = case.get("retrieval_expected_all") or case.get("retrieval_expected_any")
    if case.get("expect_abstain") and not retrieval_specific:
        return False, True, []

    def _normalized(key: str, fallback_key: str) -> List[str]:
        phrases = case.get(key, case.get(fallback_key, []))
        return [normalize(phrase) for phrase in phrases if phrase]

    all_phrases = _normalized("retrieval_expected_all", "expected_all")
    any_phrases = _normalized("retrieval_expected_any", "expected_any")
    if not all_phrases and not any_phrases:
        return False, True, []

    retrieved = normalize(_source_text(sources))
    if not retrieved:
        return True, False, ["no retrieved sources returned by chat response"]

    reasons = [
        f"retrieval missing expected_all phrase: {phrase}"
        for phrase in all_phrases
        if not _phrase_match(retrieved, phrase)
    ]
    if any_phrases and not any(_phrase_match(retrieved, phrase) for phrase in any_phrases):
        reasons.append("retrieval found none of expected_any phrases")
    return True, not reasons, reasons
def score_faithfulness(
    answer: str,
    sources: List[Dict[str, Any]],
    abstain_phrase: str,
    min_overlap: float,
) -> tuple[bool, List[str]]:
    """Heuristically check that each answer sentence is supported by the sources.

    A sentence is skipped when it has fewer than 4 content tokens and no
    numbers. Otherwise it fails in either of two cases:

    * it contains a number that does not appear in the source text;
    * the fraction of its content tokens found in the sources is below
      ``min_overlap``.

    Args:
        answer: Answer text returned by the chat API.
        sources: Source records returned with the answer.
        abstain_phrase: Already-normalised abstention phrase. Empty or
            abstaining answers pass trivially.
        min_overlap: Minimum fraction (0..1) of a sentence's content tokens
            that must appear in the sources.

    Returns:
        ``(passed, reasons)``. Reasons carry no "faithfulness:" prefix;
        ``score_case`` adds that label itself.
    """
    answer_norm = normalize(answer)
    if not answer_norm or abstain_phrase in answer_norm:
        return True, []
    source_norm = normalize(_source_text(sources))
    if not source_norm:
        return False, ["faithfulness cannot be checked because no sources were returned"]
    source_tokens = set(_content_tokens(source_norm))
    source_numbers = _number_claims(source_norm)
    reasons: List[str] = []
    for sentence in _sentences(answer):
        sentence_tokens = _content_tokens(sentence)
        sentence_numbers = _number_claims(sentence)
        # Short filler sentences without numbers carry too little signal.
        if len(sentence_tokens) < 4 and not sentence_numbers:
            continue
        missing_numbers = sorted(sentence_numbers - source_numbers)
        if missing_numbers:
            # The "faithfulness: " label is added by score_case; adding it here
            # produced "faithfulness: faithfulness: ..." in the report.
            reasons.append(
                f"sentence has numbers not found in sources: {', '.join(missing_numbers)}"
            )
            continue
        overlap = sum(1 for token in sentence_tokens if token in source_tokens)
        ratio = overlap / max(len(sentence_tokens), 1)
        if ratio < min_overlap:
            reasons.append(
                f"low source overlap ({ratio:.2f}) for sentence: {sentence}"
            )
    return len(reasons) == 0, reasons
def score_case(
    answer: str,
    sources: List[Dict[str, Any]],
    case: Dict[str, Any],
    abstain_phrase: str,
    idx: int,
    faithfulness_min_overlap: float,
) -> EvalCaseResult:
    """Score one answer on correctness, retrieval recall and faithfulness.

    Args:
        answer: Answer text returned by the chat API.
        sources: Source records returned with the answer.
        case: The question entry from the eval file.
        abstain_phrase: Already-normalised abstention phrase.
        idx: Case number stored on the result for reporting.
        faithfulness_min_overlap: Overlap threshold for the faithfulness check.

    Returns:
        An ``EvalCaseResult``. Its ``passed`` flag requires all three checks to
        pass, and each entry in ``reasons`` is labelled with its metric.
    """
    answer_reasons = score_answer_correctness(answer, case, abstain_phrase)
    retrieval_checked, retrieval_passed, retrieval_reasons = score_retrieval_recall(sources, case)
    faithfulness_passed, faithfulness_reasons = score_faithfulness(
        answer, sources, abstain_phrase, faithfulness_min_overlap
    )

    labelled_groups = (
        ("answer", answer_reasons),
        ("retrieval", retrieval_reasons),
        ("faithfulness", faithfulness_reasons),
    )
    reasons = [f"{label}: {reason}" for label, group in labelled_groups for reason in group]

    answer_passed = not answer_reasons
    overall = answer_passed and retrieval_passed and faithfulness_passed
    return EvalCaseResult(
        index=idx,
        question=case.get("question", ""),
        passed=overall,
        answer_passed=answer_passed,
        retrieval_checked=retrieval_checked,
        retrieval_passed=retrieval_passed,
        faithfulness_passed=faithfulness_passed,
        reasons=reasons,
        answer_text=answer or "",
        sources=sources,
    )
def load_eval_file(eval_file: Path) -> Dict[str, Any]:
    """Read an evaluation JSON file and validate its top-level structure.

    Args:
        eval_file: Path to the JSON file, in plain UTF-8 or UTF-8 with a BOM.

    Returns:
        The parsed JSON object.

    Raises:
        FileNotFoundError: ``eval_file`` does not exist.
        ValueError: Raised in any of these cases:
            * the JSON is malformed (``json.JSONDecodeError``);
            * the top level is not an object;
            * ``course_id`` is missing;
            * ``documents`` or ``questions`` is missing, empty, or not a list.
    """
    if not eval_file.exists():
        raise FileNotFoundError(f"Eval file not found: {eval_file}")
    # utf-8-sig handles both plain UTF-8 and UTF-8 files that include BOM.
    with eval_file.open("r", encoding="utf-8-sig") as fh:
        payload = json.load(fh)
    # Reject arrays/scalars up front. Otherwise the membership test below
    # gives a misleading error for arrays and a TypeError for numbers.
    if not isinstance(payload, dict):
        raise ValueError("eval file must contain a JSON object at the top level")
    if "course_id" not in payload:
        raise ValueError("eval file must include course_id")
    for key, empty_message in (
        ("documents", "eval file must include at least one document entry"),
        ("questions", "eval file must include at least one question"),
    ):
        entries = payload.get(key)
        if not entries:
            raise ValueError(empty_message)
        if not isinstance(entries, list):
            raise ValueError(f"eval file {key} must be a JSON list")
    return payload
def resolve_doc_paths(base_dir: Path, doc_entries: List[Dict[str, Any]]) -> List[Path]:
    """Turn eval-file document entries into absolute, existing file paths.

    Each entry's ``path`` is joined onto ``base_dir`` and resolved. Entries
    are processed in order, and the first problem raises.

    Raises:
        ValueError: An entry lacks a non-empty ``path``.
        FileNotFoundError: A resolved path does not exist.
    """

    def _resolve_entry(entry: Dict[str, Any]) -> Path:
        relative = entry.get("path")
        if not relative:
            raise ValueError("each document entry must include path")
        absolute = base_dir.joinpath(relative).resolve()
        if not absolute.exists():
            raise FileNotFoundError(f"document file not found: {absolute}")
        return absolute

    return [_resolve_entry(entry) for entry in doc_entries]
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the evaluation loop."""
    parser = argparse.ArgumentParser(description="Run repeatable RAG end-to-end evaluation loop")
    parser.add_argument("--eval-file", required=True, help="Path to evaluation JSON file")
    parser.add_argument("--base-url", default="http://127.0.0.1:8000", help="Server base URL")
    parser.add_argument("--api-prefix", default="/api/v1", help="API prefix")
    parser.add_argument("--password", default="Pass1234!", help="Password for generated test users")
    parser.add_argument("--request-timeout", type=int, default=60, help="Per-request timeout seconds")
    parser.add_argument(
        "--question-timeout",
        type=int,
        default=240,
        help="Timeout seconds for each chat answer request",
    )
    parser.add_argument("--ingest-timeout", type=int, default=DEFAULT_TIMEOUT_SECONDS, help="Ingestion wait timeout seconds")
    parser.add_argument("--poll-interval", type=int, default=DEFAULT_POLL_INTERVAL_SECONDS, help="Polling interval seconds")
    parser.add_argument("--abstain-phrase", default=DEFAULT_ABSTAIN_PHRASE, help="Expected abstention phrase")
    parser.add_argument(
        "--faithfulness-min-overlap",
        type=float,
        default=0.35,
        help=(
            "Minimum content-token overlap between each answer sentence and retrieved sources "
            "for the heuristic faithfulness check."
        ),
    )
    parser.add_argument(
        "--max-preview-chars",
        type=int,
        default=240,
        help="Max characters shown per answer when --show-full-answers is not set",
    )
    parser.add_argument(
        "--show-full-answers",
        action="store_true",
        help="Print full answer text for every case",
    )
    parser.add_argument(
        "--reuse-course-id",
        action="store_true",
        help=(
            "Use the exact course_id from the eval file. By default, the runner "
            "appends the run timestamp so repeated eval runs do not reuse old "
            "indexed chunks from the same course."
        ),
    )
    parser.add_argument(
        "--email-domain",
        default="sevima.co.id",
        help=(
            "Domain for temporary eval users. The backend rejects placeholder "
            "domains such as example.com."
        ),
    )
    return parser


def _print_case_result(
    case_result: EvalCaseResult,
    question: str,
    *,
    show_full_answers: bool,
    max_preview_chars: int,
) -> None:
    """Print the verdict, metric breakdown, failure reasons and answer for one case."""
    verdict = "PASS" if case_result.passed else "FAIL"
    print(f"[{verdict}] Q{case_result.index}: {question}")
    print(
        "  metrics: "
        f"answer={'PASS' if case_result.answer_passed else 'FAIL'}, "
        f"retrieval={'PASS' if case_result.retrieval_passed else 'FAIL'}"
        f"{'' if case_result.retrieval_checked else ' (not checked)'}, "
        f"faithfulness={'PASS' if case_result.faithfulness_passed else 'FAIL'}, "
        f"sources={len(case_result.sources)}"
    )
    if not case_result.passed:
        for reason in case_result.reasons:
            print(f"  - {reason}")
    if show_full_answers:
        shown_answer = case_result.answer_text
    else:
        shown_answer = case_result.answer_text[:max_preview_chars]
    print(f"  answer: {shown_answer}")


def _print_summary(results: List[EvalCaseResult]) -> int:
    """Print aggregate metrics for all cases and return the number of failed cases."""
    total = len(results)
    passed = sum(1 for r in results if r.passed)
    failed_count = total - passed
    score = (passed / total) * 100 if total else 0.0
    answer_passed = sum(1 for r in results if r.answer_passed)
    retrieval_checked = sum(1 for r in results if r.retrieval_checked)
    retrieval_passed = sum(1 for r in results if r.retrieval_checked and r.retrieval_passed)
    faithfulness_passed = sum(1 for r in results if r.faithfulness_passed)
    answer_score = (answer_passed / total) * 100 if total else 0.0
    # Recall is computed only over cases that actually had retrieval expectations.
    retrieval_recall = (retrieval_passed / retrieval_checked) * 100 if retrieval_checked else 0.0
    faithfulness_score = (faithfulness_passed / total) * 100 if total else 0.0
    print("\n=== RAG E2E SUMMARY ===")
    print(f"total_cases: {total}")
    print(f"overall_passed: {passed}")
    print(f"overall_failed: {failed_count}")
    print(f"overall_score_percent: {score:.1f}")
    print(f"answer_correctness_passed: {answer_passed}")
    print(f"answer_correctness_percent: {answer_score:.1f}")
    print(f"retrieval_recall_checked_cases: {retrieval_checked}")
    print(f"retrieval_recall_passed: {retrieval_passed}")
    print(f"retrieval_recall_percent: {retrieval_recall:.1f}")
    print(f"faithfulness_passed: {faithfulness_passed}")
    print(f"faithfulness_percent: {faithfulness_score:.1f}")
    return failed_count


def main() -> int:
    """Run the end-to-end evaluation and return a process exit code.

    The run creates a lecturer and a student account, and the lecturer
    uploads the documents. After waiting for ingestion, the student asks
    every question and each answer is scored.

    Returns:
        0 when every case passes and 1 when any case fails. Returns 2 when
        ingestion of any uploaded document failed; no questions are asked in
        that case.
    """
    # Only the ``datetime`` class is imported at module level.
    from datetime import timezone

    args = _build_arg_parser().parse_args()
    eval_file = Path(args.eval_file).resolve()
    payload = load_eval_file(eval_file)
    # datetime.utcnow() is deprecated since Python 3.12; an aware UTC "now"
    # formats to the same timestamp string.
    ts = datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S")
    email_domain = args.email_domain.strip().lower().lstrip("@")
    lecturer_email = f"rag_eval_lecturer_{ts}@{email_domain}"
    student_email = f"rag_eval_student_{ts}@{email_domain}"
    runner = EvalRunner(
        base_url=args.base_url,
        api_prefix=args.api_prefix,
        request_timeout=args.request_timeout,
    )
    print("[INFO] Starting RAG E2E evaluation loop")
    print(f"[INFO] Base URL: {args.base_url}{args.api_prefix}")
    # Identity numbers are derived from the timestamp so each run gets fresh ones.
    lecturer_token = runner.ensure_user(
        role="lecturer",
        email=lecturer_email,
        password=args.password,
        name="RAG Eval Lecturer",
        identity_number=f"198501{ts[-12:]}",
    )
    student_token = runner.ensure_user(
        role="student",
        email=student_email,
        password=args.password,
        name="RAG Eval Student",
        identity_number=f"{ts[-10:]}",
    )
    eval_course_id = payload["course_id"]
    course_id = eval_course_id if args.reuse_course_id else f"{eval_course_id}-{ts}"
    print(f"[INFO] Eval course_id: {eval_course_id}")
    print(f"[INFO] Run course_id: {course_id}")
    docs = resolve_doc_paths(REPO_ROOT, payload["documents"])
    uploaded_doc_ids = runner.upload_documents(
        token=lecturer_token,
        course_id=course_id,
        document_paths=docs,
    )
    statuses = runner.wait_for_ingestion(
        token=student_token,
        course_id=course_id,
        uploaded_doc_ids=uploaded_doc_ids,
        timeout_seconds=args.ingest_timeout,
        poll_interval_seconds=args.poll_interval,
    )
    failed = [doc_id for doc_id, details in statuses.items() if details.get("status") == "failed"]
    if failed:
        print(f"[ERROR] Ingestion failed for document_ids: {failed}")
        for doc_id in failed:
            details = statuses.get(doc_id, {})
            filename = details.get("filename") or "<unknown>"
            error = details.get("error") or "<no error message provided by backend>"
            print(f"  - doc_id={doc_id} filename={filename}")
            print(f"    error={error}")
        print("[HINT] Common causes: missing optional RAG deps, invalid API keys/env vars, parser/runtime errors.")
        print("[HINT] If needed, install optional deps: uv pip install -r requirements.txt -r requirements-rag.txt")
        return 2
    print("[INFO] All uploaded documents are ready")
    session_id = runner.create_student_session(
        token=student_token,
        course_id=course_id,
        title=f"RAG Eval Session {ts}",
    )
    abstain_phrase_norm = normalize(args.abstain_phrase)
    results: List[EvalCaseResult] = []
    for idx, case in enumerate(payload["questions"], start=1):
        question = case.get("question", "")
        answer_payload = runner.ask_question(
            token=student_token,
            session_id=session_id,
            question=question,
            question_timeout_seconds=args.question_timeout,
        )
        case_result = score_case(
            answer_payload["content"],
            answer_payload["sources"],
            case,
            abstain_phrase_norm,
            idx,
            args.faithfulness_min_overlap,
        )
        results.append(case_result)
        _print_case_result(
            case_result,
            question,
            show_full_answers=args.show_full_answers,
            max_preview_chars=args.max_preview_chars,
        )
    failed_count = _print_summary(results)
    return 0 if failed_count == 0 else 1
if __name__ == "__main__":
    # On Ctrl+C, print a short notice and exit with status 130 (128 + SIGINT)
    # instead of a traceback.
    try:
        sys.exit(main())
    except KeyboardInterrupt:
        print("\n[INFO] Interrupted by user")
        sys.exit(130)