Insurance_Pilot / scripts /run_ragas_eval.py
Shoaib-33's picture
added
8058e7e
Raw
History Blame Contribute Delete
5.05 kB
import argparse
import asyncio
import json
import sys
import time
from pathlib import Path
from typing import Any
# Directory one level above the folder that holds this script
# (parents[0] is the script's own folder, parents[1] is its parent).
ROOT = Path(__file__).resolve().parents[1]
# Prepend ROOT to the import search path once. Presumably this is what lets
# the `app.*` imports further down resolve when the file is run directly as a
# script, so this statement has to stay ahead of those imports.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
from datasets import Dataset
from langchain_groq import ChatGroq
from ragas import RunConfig, evaluate
from ragas.embeddings import LangchainEmbeddingsWrapper
from ragas.llms import LangchainLLMWrapper
from ragas.metrics import (
answer_relevancy,
context_precision,
context_recall,
faithfulness,
)
from app.core.config import settings
from app.db.sqlite import init_db
from app.rag.embeddings import get_embedding_model
from app.rag.graph import ClaimsRAGGraph
from app.rag.ingestion import DocumentIngestionService
from app.rag.qdrant_store import QdrantVectorStore
def load_jsonl(path: Path) -> list[dict[str, Any]]:
    """Load a JSON Lines file into a list of decoded records.

    Args:
        path: Location of a UTF-8 encoded file holding one JSON document per
            line. Blank and whitespace-only lines are ignored.

    Returns:
        The decoded value of every non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened or read.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Iterate the file handle instead of calling str.splitlines(): text-mode
    # iteration breaks only on \n, \r and \r\n, whereas splitlines() also
    # breaks on characters such as U+2028/U+2029/U+0085, which may appear
    # unescaped inside a JSON string and would cut a valid record in half.
    # Streaming also avoids holding the raw text and the parsed rows at once.
    with path.open(encoding="utf-8") as handle:
        return [json.loads(line) for line in handle if line.strip()]
def default_reference(case: dict[str, Any]) -> str:
    """Return the reference text for one evaluation case.

    A truthy ``reference`` entry on the case wins and is returned coerced to
    ``str``. Otherwise a generic reference sentence is composed around the
    case's ``expected_decision``.

    Args:
        case: One evaluation case record.

    Returns:
        The reference string for the case.

    Raises:
        KeyError: If the case has no truthy ``reference`` and also lacks
            ``expected_decision``.
    """
    explicit = case.get("reference")
    if not explicit:
        # No hand-written reference: fall back to a template built from the
        # expected decision.
        expected = case["expected_decision"]
        return (
            f"Decision: {expected}. The answer should use the retrieved insurance claim "
            "guidance to explain the coverage triage, identify missing evidence, and "
            "recommend the next action without inventing unsupported policy terms."
        )
    return str(explicit)
def build_eval_rows(dataset_path: Path, user_id: str, limit: int | None) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Run the evaluation cases through the RAG graph and collect the outputs.

    Before any case runs, the database is initialised, the Qdrant collections
    are ensured and the PDF directory is ingested.

    Args:
        dataset_path: JSONL file of cases; each case must provide ``query``,
            ``id`` and ``expected_decision`` (``reference`` is optional).
        user_id: Passed through to every ``graph.run`` call.
        limit: When truthy, only the leading ``limit`` cases are evaluated;
            ``None`` or ``0`` keeps every case.

    Returns:
        ``(ragas_rows, run_rows)``, two lists aligned by index. ``ragas_rows``
        holds the ``user_input`` / ``response`` / ``retrieved_contexts`` /
        ``reference`` record for each case; ``run_rows`` holds the case id,
        expected decision, number of contexts and wall-clock latency in
        milliseconds.
    """
    # One-off environment preparation, in the same order as always.
    init_db()
    QdrantVectorStore().ensure_collections()
    DocumentIngestionService().ingest_pdf_directory()
    graph = ClaimsRAGGraph()

    all_cases = load_jsonl(dataset_path)
    selected = all_cases[:limit] if limit else all_cases

    ragas_rows: list[dict[str, Any]] = []
    run_rows: list[dict[str, Any]] = []
    for case in selected:
        # Time only the graph invocation itself.
        t0 = time.perf_counter()
        state = graph.run(case["query"], user_id=user_id, use_cache=False)
        elapsed_ms = round((time.perf_counter() - t0) * 1000, 2)

        # Prefer the reranked sources; fall back to the plain source list.
        sources = state.get("reranked_sources") or state.get("sources", [])
        contexts = []
        for source in sources:
            text = source.get("text")
            if text:  # skip sources with missing or empty text
                contexts.append(str(text))

        ragas_row = {
            "user_input": case["query"],
            "response": state.get("answer", ""),
            "retrieved_contexts": contexts,
            "reference": default_reference(case),
        }
        ragas_rows.append(ragas_row)

        run_row = {
            "id": case["id"],
            "expected_decision": case["expected_decision"],
            "sources": len(contexts),
            "latency_ms": elapsed_ms,
        }
        run_rows.append(run_row)

    return ragas_rows, run_rows
def _summarize_scores(
    run_rows: list[dict[str, Any]],
    scores: list[dict[str, Any]],
    metric_names: list[str],
) -> dict[str, Any]:
    """Attach each case's RAGAS scores to its run metadata and average every metric."""
    import math

    # Rows are paired by position; strict=False keeps the original tolerance
    # for a length mismatch (extra entries on either side are dropped).
    rows = [
        {**run_row, "ragas": score_row}
        for run_row, score_row in zip(run_rows, scores, strict=False)
    ]
    summary: dict[str, Any] = {"total": len(rows), "metrics": {}, "results": rows}
    for metric in metric_names:
        values = []
        for row in rows:
            raw = row["ragas"].get(metric)
            if raw is None:
                continue
            value = float(raw)
            # Missing scores show up as NaN; test numerically rather than by
            # comparing the string form against "nan".
            if not math.isnan(value):
                values.append(value)
        # A metric with no usable score is reported as None instead of dividing by zero.
        summary["metrics"][metric] = round(sum(values) / len(values), 3) if values else None
    return summary


async def run_ragas(dataset_path: Path, user_id: str, limit: int | None) -> dict[str, Any]:
    """Evaluate the RAG pipeline with RAGAS and return an aggregated summary.

    The function is declared ``async`` but awaits nothing: every step in the
    body is an ordinary synchronous call.

    Args:
        dataset_path: JSONL file of evaluation cases, forwarded to
            ``build_eval_rows``.
        user_id: User id forwarded to ``build_eval_rows``.
        limit: Optional cap on the number of cases, forwarded to
            ``build_eval_rows``.

    Returns:
        A dict with ``total`` (number of scored rows), ``metrics`` (mean of
        each metric rounded to 3 decimals, or ``None`` when no row has a usable
        score) and ``results`` (per-case run metadata with the RAGAS record
        under ``ragas``).

    Raises:
        RuntimeError: If ``settings.groq_api_key`` is not set.
    """
    if not settings.groq_api_key:
        raise RuntimeError("GROQ_API_KEY is required for RAGAS LLM-judge metrics.")

    ragas_rows, run_rows = build_eval_rows(dataset_path, user_id, limit)
    dataset = Dataset.from_list(ragas_rows)

    # Judge model used by the LLM-based metrics; temperature 0 is requested.
    judge_llm = ChatGroq(
        model=settings.groq_model,
        temperature=0,
        max_retries=2,
        api_key=settings.groq_api_key,
    )
    ragas_llm = LangchainLLMWrapper(judge_llm)
    ragas_embeddings = LangchainEmbeddingsWrapper(get_embedding_model().model)

    # NOTE(review): this mutates the imported module-level metric object.
    # Presumably strictness controls how many questions are generated per
    # answer; confirm against the RAGAS documentation.
    answer_relevancy.strictness = 1
    metrics = [faithfulness, answer_relevancy, context_precision, context_recall]
    result = evaluate(
        dataset=dataset,
        metrics=metrics,
        llm=ragas_llm,
        embeddings=ragas_embeddings,
        run_config=RunConfig(timeout=180, max_workers=2, max_retries=2),
    )

    scores = result.to_pandas().to_dict(orient="records")
    metric_names = ["faithfulness", "answer_relevancy", "context_precision", "context_recall"]
    return _summarize_scores(run_rows, scores, metric_names)
def main() -> None:
    """Command-line entry point: run the evaluation and emit the JSON summary.

    Parses ``--dataset``, ``--user-id``, ``--limit`` and ``--output``, runs
    ``run_ragas`` to completion, prints the summary as indented JSON and, when
    ``--output`` is given, also writes it to that file (UTF-8, trailing
    newline).
    """
    parser = argparse.ArgumentParser(description="Run the RAGAS evaluation over a JSONL case file.")
    parser.add_argument(
        "--dataset",
        default="data/eval/golden_claim_scenarios.jsonl",
        help="Path to the JSONL file of evaluation cases.",
    )
    parser.add_argument(
        "--user-id",
        default="ragas_eval_user",
        help="User id passed to the RAG graph for every case.",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Evaluate only the first N cases (omitted or 0 evaluates all).",
    )
    parser.add_argument(
        "--output",
        default=None,
        help="Optional file to write the JSON summary to, in addition to stdout.",
    )
    args = parser.parse_args()

    summary = asyncio.run(run_ragas(Path(args.dataset), args.user_id, args.limit))
    text = json.dumps(summary, indent=2)
    print(text)

    if args.output:
        output_path = Path(args.output)
        # Create missing parent folders first, so a finished evaluation is not
        # lost to a FileNotFoundError at the final write.
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(text + "\n", encoding="utf-8")
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()