MediRAG-API / src /evaluate.py
joytheslothh's picture
deploy: clean build
b6f9fa8
"""
FR-22: src/evaluate.py — MediRAG Evaluation Orchestrator
=========================================================
Top-level entry point for the evaluation pipeline.
Runs all 4 evaluation modules + RAGAS + aggregator for a given
(question, answer, context_docs) triple, returning a fully structured
composite EvalResult.
Usage as a module:
from src.evaluate import run_evaluation
result = run_evaluation(question, answer, context_docs)
print(f"Score: {result.score:.3f} ({result.details['confidence_level']})")
Usage from CLI:
python -m src.evaluate \\
--question "What is the recommended dosage of Metformin for Type 2 Diabetes?" \\
--answer "Metformin is typically started at 500mg twice daily..." \\
--context-file data/processed/chunks.jsonl \\
--top-k 5
SRS reference: FR-22, Section 7 (Evaluation Pipeline Overview)
"""
from __future__ import annotations
import argparse
import json
import logging
import sys
import time
from pathlib import Path
from typing import Optional
from src.modules.base import EvalResult
from src.modules.faithfulness import score_faithfulness
from src.modules.entity_verifier import verify_entities
from src.modules.source_credibility import score_source_credibility
from src.modules.contradiction import score_contradiction
from src.evaluation.ragas_eval import score_ragas
from src.evaluation.aggregator import aggregate
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Main evaluation function
# ---------------------------------------------------------------------------
def run_evaluation(
question: str,
answer: str,
context_chunks: list[dict],
rxnorm_cache_path: str = "data/rxnorm_cache.csv",
run_ragas: bool = True,
weights: Optional[dict[str, float]] = None,
config: Optional[dict] = None,
) -> EvalResult:
"""
Run the full MediRAG evaluation pipeline for a single QA pair.
Args:
question : Original user question.
answer : LLM-generated answer to evaluate.
context_chunks : List of retrieved chunk dicts (from retriever.retrieve()).
Each chunk must have at minimum {'text': str}.
rxnorm_cache_path : Path to rxnorm_cache.csv for entity verification.
run_ragas : Whether to run the RAGAS module (requires LLM backend).
weights : Override default aggregation weights (optional).
Returns:
EvalResult for the "aggregator" module containing:
.score → composite score in [0, 1]
.details → full breakdown per module
.latency_ms → total wall-clock time
"""
t_start = time.perf_counter()
logger.info("=== MediRAG Evaluation START ===")
logger.info("Question: %s", question[:120])
logger.info("Answer : %s", answer[:120])
logger.info("Chunks : %d context documents", len(context_chunks))
# Extract text and metadata for modules that need it
context_texts: list[str] = [c.get("text", "") for c in context_chunks]
chunk_ids: list[str] = [
c.get("chunk_id") or c.get("metadata", {}).get("chunk_id") or f"chunk_{i}"
for i, c in enumerate(context_chunks)
]
# -------------------------------------------------------------------------
# Retrieval Quality Gate
# If the retriever's absolute RRF score is too low, the chunks are likely
# unrelated to the question — evaluation against them produces false HRS spikes.
# Threshold: max raw RRF for top-1 in both sources = 2/(60+1) ≈ 0.0328
# We flag as insufficient if max_rrf < 0.008 (only very weak BM25 or FAISS match)
# -------------------------------------------------------------------------
RETRIEVAL_CONFIDENCE_THRESHOLD = 0.008
retrieval_confidence = context_chunks[0].get("_retrieval_confidence", 1.0) if context_chunks else 0.0
if context_chunks and retrieval_confidence < RETRIEVAL_CONFIDENCE_THRESHOLD:
logger.warning(
"Retrieval confidence %.6f below threshold %.3f — context likely irrelevant to question.",
retrieval_confidence, RETRIEVAL_CONFIDENCE_THRESHOLD,
)
total_ms = int((time.perf_counter() - t_start) * 1000)
return EvalResult(
module_name="aggregator",
score=0.5,
details={
"retrieval_insufficient": True,
"retrieval_confidence": retrieval_confidence,
"hrs": 50,
"risk_band": "MODERATE",
"confidence_level": "LOW",
"total_pipeline_ms": total_ms,
"module_results": {},
"warning": (
"Retrieved context has very low relevance to the question "
f"(retrieval_confidence={retrieval_confidence:.4f}). "
"Evaluation scores would be meaningless. "
"Consider rephrasing the question or expanding the index."
),
},
latency_ms=total_ms,
)
# -------------------------------------------------------------------------
# Module 1: Faithfulness (DeBERTa NLI)
# -------------------------------------------------------------------------
logger.info("--- Module 1: Faithfulness ---")
faith_result = score_faithfulness(
answer=answer,
context_docs=context_texts,
chunk_ids=chunk_ids,
config=config,
)
# -------------------------------------------------------------------------
# Module 2: Entity Verification (SciSpaCy + RxNorm)
# -------------------------------------------------------------------------
logger.info("--- Module 2: Entity Verification ---")
entity_result = verify_entities(
answer=answer,
question=question,
context_docs=context_texts,
rxnorm_cache_path=rxnorm_cache_path,
)
# -------------------------------------------------------------------------
# Module 3: Source Credibility (Evidence Tier)
# -------------------------------------------------------------------------
logger.info("--- Module 3: Source Credibility ---")
source_result = score_source_credibility(retrieved_chunks=context_chunks)
# -------------------------------------------------------------------------
# Module 4: Contradiction Detection (DeBERTa NLI cross-check)
# -------------------------------------------------------------------------
logger.info("--- Module 4: Contradiction Detection ---")
contra_result = score_contradiction(
answer=answer,
context_docs=context_texts,
)
# -------------------------------------------------------------------------
# RAGAS (optional — requires LLM backend)
# -------------------------------------------------------------------------
ragas_result: Optional[EvalResult] = None
if run_ragas:
logger.info("--- RAGAS Evaluation ---")
ragas_result = score_ragas(
question=question,
answer=answer,
context_docs=context_texts,
)
# -------------------------------------------------------------------------
# Aggregator: weighted composite
# -------------------------------------------------------------------------
logger.info("--- Aggregator ---")
agg_result = aggregate(
faithfulness_result=faith_result,
entity_result=entity_result,
source_result=source_result,
contradiction_result=contra_result,
ragas_result=ragas_result,
weights=weights,
)
total_ms = int((time.perf_counter() - t_start) * 1000)
agg_result.details["total_pipeline_ms"] = total_ms
# Attach per-module results for API/dashboard access
agg_result.details["module_results"] = {
"faithfulness": {"score": faith_result.score, "details": faith_result.details},
"entity_verifier": {"score": entity_result.score, "details": entity_result.details},
"source_credibility": {"score": source_result.score, "details": source_result.details},
"contradiction": {"score": contra_result.score, "details": contra_result.details},
"ragas": {"score": ragas_result.score, "details": ragas_result.details} if ragas_result else None,
}
logger.info(
"=== MediRAG Evaluation DONE: score=%.3f (%s) in %d ms ===",
agg_result.score,
agg_result.details.get("confidence_level", "?"),
total_ms,
)
return agg_result
# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------
def _build_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(
description="MediRAG evaluation pipeline (FR-22)"
)
p.add_argument("--question", required=True, help="User question")
p.add_argument("--answer", required=True, help="LLM answer to evaluate")
p.add_argument("--context-file", default="data/processed/chunks.jsonl",
help="JSONL file of chunks (output of ingest.py)")
p.add_argument("--top-k", type=int, default=5,
help="Number of context chunks to use")
p.add_argument("--rxnorm-cache", default="data/rxnorm_cache.csv",
help="Path to rxnorm_cache.csv")
p.add_argument("--no-ragas", action="store_true",
help="Skip RAGAS evaluation (no LLM backend needed)")
p.add_argument("--json", action="store_true",
help="Output result as JSON")
return p
def _load_context_from_file(path: str, top_k: int) -> list[dict]:
"""Load top-k chunks from a JSONL file as simple dicts."""
chunks = []
try:
with open(path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
chunks.append(json.loads(line))
if len(chunks) >= top_k:
break
except FileNotFoundError:
logger.error("Context file not found: %s", path)
sys.exit(1)
return chunks
if __name__ == "__main__":
import yaml
# Load config.yaml for logging setup
try:
cfg = yaml.safe_load(Path("config.yaml").read_text())
log_level = cfg.get("logging", {}).get("level", "INFO")
except Exception:
log_level = "INFO"
logging.basicConfig(
level=getattr(logging, log_level, logging.INFO),
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
args = _build_parser().parse_args()
chunks = _load_context_from_file(args.context_file, args.top_k)
result = run_evaluation(
question=args.question,
answer=args.answer,
context_chunks=chunks,
rxnorm_cache_path=args.rxnorm_cache,
run_ragas=not args.no_ragas,
)
if args.json:
import dataclasses
print(json.dumps(dataclasses.asdict(result), indent=2))
else:
print(f"\n{'='*60}")
print(f" MediRAG Evaluation Result")
print(f"{'='*60}")
print(f" Score : {result.score:.3f}")
print(f" Confidence : {result.details.get('confidence_level', 'N/A')}")
print(f" Pipeline time : {result.details.get('total_pipeline_ms', 0)} ms")
print(f"\n Module Breakdown:")
for mod, res in (result.details.get("module_results") or {}).items():
if res:
print(f" {mod:22s}: {res['score']:.3f}")
print(f"{'='*60}\n")