# ML-Chatbot / rag_eval_metrics.py
# (Hosting-page header lines kept as a comment so the file parses as Python.)
#!/usr/bin/env python3
"""
rag_eval_metrics.py
Evaluate RAG retrieval quality by comparing app logs (JSONL) with a gold file (CSV).
Extended to also evaluate answer quality using:
- Lexical similarity: BLEU, ROUGE-1/2/L
- Semantic similarity: BERTScore (Recall, F1)
If nltk / rouge-score / bert-score are missing, the script still runs and
returns NaN for these metrics instead of crashing.
Also uses robust CSV reading to handle non-UTF8 encodings (cp1252/latin1).
"""
import argparse
import json
import os
import re
import sys
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional

import numpy as np
import pandas as pd
# ----------------------------- Small Utils ----------------------------- #
def filename_key(s: str) -> str:
    """Reduce a (possibly Windows-style) path to a casefolded bare file name.

    Used to compare document names from logs and gold labels regardless of
    directory prefixes, path-separator style, or letter case.
    """
    tail = (s or "").strip().replace("\\", "/").rsplit("/", 1)[-1]
    return tail.casefold()
# Compiled once at import time; the original re-imported `re` and re-parsed the
# pattern on every call.
_SC_SPLIT_RE = re.compile(r"[;,]")


def re_split_sc(s: str) -> List[str]:
    """Split *s* on ';' and ',' separators.

    Empty pieces are preserved, exactly as ``re.split`` produces them
    (e.g. ``"a;;b"`` -> ``["a", "", "b"]``).
    """
    return _SC_SPLIT_RE.split(s)
def _pick_last_non_empty(hit_lists) -> List[dict]:
"""
Robustly select the last non-empty hits list from a pandas Series or iterable.
This fixes the KeyError that happens when using reversed() directly on a Series
with a non-range index.
"""
# Convert pandas Series or other iterables to a plain Python list
try:
values = list(hit_lists.tolist())
except AttributeError:
values = list(hit_lists)
# Walk from last to first, return first non-empty list-like
for lst in reversed(values):
if isinstance(lst, (list, tuple)) and len(lst) > 0:
return lst
# If everything was empty / NaN
return []
def _read_csv_robust(path: Path) -> pd.DataFrame:
"""
Try multiple encodings so we don't crash on Windows-1252 / Latin-1 CSVs.
"""
encodings = ["utf-8", "utf-8-sig", "cp1252", "latin1"]
last_err = None
for enc in encodings:
try:
return pd.read_csv(path, encoding=enc)
except UnicodeDecodeError as e:
last_err = e
continue
# If all fail, re-raise the last error
raise last_err if last_err is not None else ValueError(
"Failed to read CSV with fallback encodings."
)
# ----------------------------- IO Helpers ----------------------------- #
def read_logs(jsonl_path: Path) -> pd.DataFrame:
    """Read RAG JSONL logs and aggregate them per question.

    Returns a DataFrame with columns:
      - question: original question text (last occurrence wins)
      - hits: list of {doc, page} dicts from the last non-empty retrieval
      - answer: last logged final answer for that question
    A missing/empty file or a file with no parseable lines yields an empty
    frame with those columns.
    """
    columns = ["question", "hits", "answer"]
    if (not jsonl_path.exists()) or jsonl_path.stat().st_size == 0:
        return pd.DataFrame(columns=columns)

    def _last_filled(cells) -> List[dict]:
        # Last non-empty list-like in a Series/iterable; [] when all are empty/NaN.
        items = list(cells.tolist()) if hasattr(cells, "tolist") else list(cells)
        for item in reversed(items):
            if isinstance(item, (list, tuple)) and len(item) > 0:
                return item
        return []

    records = []
    with open(jsonl_path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            raw_line = raw_line.strip()
            if not raw_line:
                continue
            try:
                entry = json.loads(raw_line)
            except Exception:
                continue  # tolerate malformed JSON lines
            question = (((entry.get("inputs") or {}).get("question")) or "").strip()
            cleaned_hits = []
            for hit in ((entry.get("retrieval") or {}).get("hits", []) or []):
                doc_name = (hit.get("doc") or "").strip()
                page_text = str(hit.get("page") or "").strip()
                # Pages normalize to int; anything unparseable becomes None.
                try:
                    page_num = int(page_text)
                except Exception:
                    page_num = None
                cleaned_hits.append({"doc": doc_name, "page": page_num})
            answer = ((entry.get("output") or {}).get("final_answer") or "").strip()
            records.append({"question": question, "hits": cleaned_hits, "answer": answer})

    frame = pd.DataFrame(records)
    if frame.empty:
        return pd.DataFrame(columns=columns)
    # Group on the normalized question text; keep the latest question/answer
    # and the last non-empty hits list per group.
    group_key = frame["question"].astype(str).str.casefold().str.strip()
    return frame.groupby(group_key, as_index=False).agg(
        {"question": "last", "hits": _last_filled, "answer": "last"}
    )
def read_gold(csv_path: Path) -> Tuple[pd.DataFrame, Dict[str, str]]:
    """
    Read gold CSV with retrieval labels and optional reference answers.

    Two layouts are supported:
      1. a 'relevant_docs' column holding a ';'/',' separated list per row
      2. one (doc, page) pair per row via 'doc' / 'page' columns

    Returns:
        - gold_df: rows with columns ['question_raw', 'question', 'doc',
          'page', 'answer'] where 'question' is normalized (casefold+strip)
        - gold_answers: dict mapping normalized question -> reference answer

    Raises:
        ValueError: if no question column, or neither doc layout, is present.
    """
    df = _read_csv_robust(csv_path)
    # Map lower-cased/stripped header -> original column name for tolerant lookup.
    cols = {c.lower().strip(): c for c in df.columns}

    def _find_col(candidates: List[str]) -> Optional[str]:
        # Return the original name of the first candidate present, else None.
        for cand in candidates:
            if cand in cols:
                return cols[cand]
        return None

    q_col = _find_col(["question", "query", "q"])
    if q_col is None:
        raise ValueError("Gold CSV must contain a 'question' column (case-insensitive).")
    rel_list_col = _find_col(["relevant_docs", "relevant", "docs"])      # list-in-cell layout
    doc_col = _find_col(["doc", "document", "file", "doc_name"])         # one-doc-per-row layout
    page_col = _find_col(["page", "page_num", "page_number"])            # optional page labels
    ans_col = _find_col(["answer", "reference_answer", "gold_answer"])   # optional QA reference

    rows = []
    # Case 1: relevant_docs list column (no explicit doc_col)
    if rel_list_col and doc_col is None:
        for _, r in df.iterrows():
            q_raw = str(r[q_col]).strip()
            q_norm = q_raw.casefold().strip()
            ans_raw = str(r[ans_col]).strip() if (ans_col and pd.notna(r[ans_col])) else ""
            rel_val = str(r[rel_list_col]) if pd.notna(r[rel_list_col]) else ""
            if not rel_val:
                # Keep the question even without docs so answers are still evaluable.
                rows.append({
                    "question_raw": q_raw,
                    "question": q_norm,
                    "doc": None,
                    "page": np.nan,
                    "answer": ans_raw
                })
                continue
            parts = [p.strip() for p in re_split_sc(rel_val)]
            for d in parts:
                rows.append({
                    "question_raw": q_raw,
                    "question": q_norm,
                    "doc": filename_key(d),
                    "page": np.nan,
                    "answer": ans_raw
                })
    # Case 2: doc/page columns (one relevant doc per row)
    elif doc_col:
        for _, r in df.iterrows():
            q_raw = str(r[q_col]).strip()
            q_norm = q_raw.casefold().strip()
            ans_raw = str(r[ans_col]).strip() if (ans_col and pd.notna(r[ans_col])) else ""
            d = str(r[doc_col]).strip() if pd.notna(r[doc_col]) else ""
            p = r[page_col] if (page_col and pd.notna(r[page_col])) else np.nan
            try:
                p = int(p)
            except Exception:
                p = np.nan
            rows.append({
                "question_raw": q_raw,
                "question": q_norm,
                "doc": filename_key(d),
                "page": p,
                "answer": ans_raw
            })
    else:
        raise ValueError("Gold CSV must contain either a 'doc' column or a 'relevant_docs' column.")

    gold = pd.DataFrame(rows)
    if gold.empty:
        # A data-less CSV used to KeyError on gold["doc"] below; return an
        # empty frame with the expected schema instead (caller treats empty
        # gold as "no usable rows").
        gold = pd.DataFrame(columns=["question_raw", "question", "doc", "page", "answer"])
    # Keep only rows with a valid doc (when docs exist)
    gold["has_doc"] = gold["doc"].apply(lambda x: isinstance(x, str) and len(x) > 0)
    if gold["has_doc"].any():
        gold = gold[gold["has_doc"]].copy()
    gold.drop(columns=["has_doc"], inplace=True, errors="ignore")
    # Remove duplicates
    gold = gold.drop_duplicates(subset=["question", "doc", "page"])

    # Build question -> gold_answer map (normalized questions); the first
    # answer seen per question wins.
    gold_answers: Dict[str, str] = {}
    if "answer" in gold.columns:
        tmp = (
            gold[["question", "answer"]]
            .dropna(subset=["answer"])
            .drop_duplicates(subset=["question"])
        )
        gold_answers = dict(zip(tmp["question"], tmp["answer"]))
    return gold, gold_answers
# ----------------------------- Retrieval Metric Core ----------------------------- #
def dcg_at_k(relevances: List[int]) -> float:
    """Binary discounted cumulative gain of an already-truncated ranking.

    Each relevant position (rel > 0) contributes 1/log2(rank + 1), rank
    starting at 1.
    """
    return float(
        sum(
            1.0 / np.log2(rank + 1.0)
            for rank, rel in enumerate(relevances, start=1)
            if rel > 0
        )
    )
def ndcg_at_k(relevances: List[int]) -> float:
    """Normalized DCG of a binary relevance list against its ideal ordering.

    Returns 0.0 when there are no relevant items at all (ideal DCG is zero).
    """
    def _dcg(rels: List[int]) -> float:
        # Same binary DCG as dcg_at_k, computed inline.
        return float(
            sum(1.0 / np.log2(i + 1.0) for i, r in enumerate(rels, 1) if r > 0)
        )

    ideal = _dcg(sorted(relevances, reverse=True))
    if ideal == 0.0:
        return 0.0
    return float(_dcg(relevances) / ideal)
def compute_metrics_for_question(gold_docs, gold_pages, hits, k):
    """Score one question's top-k retrieved hits against the gold labels.

    Doc-level metrics are always computed; page-level metrics come back as
    NaN unless the gold labels contain at least one (doc, page) pair.
    Returns a flat dict of hit@k / precision@k / recall@k / ndcg@k at both
    levels plus the counts n_gold_docs, n_gold_doc_pages, n_pred.
    """
    top_hits = hits[:k] if hits else []
    pred_docs = [filename_key(h.get("doc", "")) for h in top_hits]
    pred_pairs = [(filename_key(h.get("doc", "")), h.get("page", None)) for h in top_hits]

    # --- Doc-level metrics ---
    relevant_docs = {d for d in gold_docs if isinstance(d, str) and d}
    doc_flags = [1 if d in relevant_docs else 0 for d in pred_docs]
    doc_hit = 1 if any(doc_flags) else 0
    doc_precision = sum(doc_flags) / len(pred_docs) if pred_docs else 0.0
    doc_recall = sum(doc_flags) / len(relevant_docs) if relevant_docs else 0.0
    doc_ndcg = ndcg_at_k(doc_flags)

    # --- Page-level metrics (only meaningful when gold carries page labels) ---
    labeled_pairs = set()
    for doc_name, page in zip(gold_docs, gold_pages):
        if not (isinstance(doc_name, str) and doc_name):
            continue
        if page is None or (isinstance(page, float) and np.isnan(page)):
            continue
        try:
            labeled_pairs.add((doc_name, int(page)))
        except Exception:
            continue  # unparseable page label -> skip the pair

    if labeled_pairs:
        page_flags = []
        for doc_name, page in pred_pairs:
            # Predictions without an int page can never match a labeled pair.
            matched = isinstance(page, int) and (doc_name, page) in labeled_pairs
            page_flags.append(1 if matched else 0)
        page_hit = 1 if any(page_flags) else 0
        page_precision = sum(page_flags) / len(pred_pairs) if pred_pairs else 0.0
        page_recall = sum(page_flags) / len(labeled_pairs)
        page_ndcg = ndcg_at_k(page_flags)
    else:
        page_hit = page_precision = page_recall = page_ndcg = np.nan

    return {
        "hit@k_doc": doc_hit,
        "precision@k_doc": doc_precision,
        "recall@k_doc": doc_recall,
        "ndcg@k_doc": doc_ndcg,
        "hit@k_page": page_hit,
        "precision@k_page": page_precision,
        "recall@k_page": page_recall,
        "ndcg@k_page": page_ndcg,
        "n_gold_docs": int(len(relevant_docs)),
        "n_gold_doc_pages": int(len(labeled_pairs)),
        "n_pred": int(len(pred_docs)),
    }
# ---------------------- Answer Quality Metrics (with fallbacks) ---------------------- #
# Try to import optional libraries; if missing, we fall back to NaN metrics.
# Each probe sets a HAVE_* flag and leaves the names bound (possibly to None)
# so later module-level references never raise NameError.
try:
    from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
    HAVE_NLTK = True
except Exception:
    sentence_bleu = None
    SmoothingFunction = None
    HAVE_NLTK = False
try:
    from rouge_score import rouge_scorer
    HAVE_ROUGE = True
except Exception:
    rouge_scorer = None
    HAVE_ROUGE = False
try:
    from bert_score import score as bert_score
    HAVE_BERT = True
except Exception:
    bert_score = None
    HAVE_BERT = False
# Build shared scorer objects once at import time so per-question scoring
# doesn't reconstruct them.
if HAVE_NLTK:
    _SMOOTH = SmoothingFunction().method1
else:
    _SMOOTH = None
if HAVE_ROUGE:
    _ROUGE_SCORER = rouge_scorer.RougeScorer(
        ["rouge1", "rouge2", "rougeL"], use_stemmer=True
    )
else:
    _ROUGE_SCORER = None
def _normalize_text_for_metrics(s: str) -> str:
import re
s = (s or "").strip().lower()
# remove simple markdown markers
s = re.sub(r"\*\*|\*", "", s)
# drop inline citations like (Doc.pdf, p.X)
s = re.sub(r"\([^)]*\)", " ", s)
s = re.sub(r"\s+", " ", s)
return s.strip()
def compute_text_metrics(pred: str, ref: str) -> Dict[str, float]:
    """
    Compute lexical and semantic similarity metrics between prediction and reference:
    - BLEU (sentence-level, smoothed)
    - ROUGE-1/2/L (F-measure)
    - BERTScore Recall, F1

    Returns NaN for every metric when the optional libraries (nltk,
    rouge-score, bert-score) are not installed, or when either text
    normalizes to an empty string.
    """
    # Single NaN template instead of two duplicated literals (the original
    # repeated this six-key dict verbatim in both early-return branches).
    nan_metrics: Dict[str, float] = {
        "bleu": np.nan,
        "rouge1": np.nan,
        "rouge2": np.nan,
        "rougeL": np.nan,
        "bert_recall": np.nan,
        "bert_f1": np.nan,
    }
    # If any of the libraries is missing, skip answer metrics entirely.
    if not (HAVE_NLTK and HAVE_ROUGE and HAVE_BERT):
        return nan_metrics
    pred_n = _normalize_text_for_metrics(pred)
    ref_n = _normalize_text_for_metrics(ref)
    if not pred_n or not ref_n:
        return nan_metrics
    pred_tokens = pred_n.split()
    ref_tokens = ref_n.split()
    # BLEU (sentence-level with smoothing; smoothing avoids hard zeros on short texts)
    bleu = float(
        sentence_bleu([ref_tokens], pred_tokens, smoothing_function=_SMOOTH)
    )
    # ROUGE via rouge-score (F-measure)
    rs = _ROUGE_SCORER.score(ref_n, pred_n)
    rouge1 = float(rs["rouge1"].fmeasure)
    rouge2 = float(rs["rouge2"].fmeasure)
    rougeL = float(rs["rougeL"].fmeasure)
    # BERTScore (semantic similarity), rescaled against the published baseline
    P, R, F1 = bert_score([pred_n], [ref_n], lang="en", rescale_with_baseline=True)
    bert_recall = float(R.mean().item())
    bert_f1 = float(F1.mean().item())
    return {
        "bleu": bleu,
        "rouge1": rouge1,
        "rouge2": rouge2,
        "rougeL": rougeL,
        "bert_recall": bert_recall,
        "bert_f1": bert_f1,
    }
# ----------------------------- Orchestration ----------------------------- #
# === Dark blue and accent colors ===
# ANSI escape sequences for the console summary. NOTE(review): output is not
# auto-disabled when stdout is not a TTY — piped output will contain the raw
# escape codes; confirm that's acceptable.
COLOR_TITLE = "\033[94m"  # light blue for titles
COLOR_TEXT = "\033[34m"  # dark blue
COLOR_ACCENT = "\033[36m"  # cyan for metrics
COLOR_RESET = "\033[0m"
def _fmt(x: Any) -> str:
try:
return f"{float(x):.3f}"
except Exception:
return "-"
def main():
    """CLI entry point: evaluate retrieval and answer quality, write artifacts.

    Reads --gold_csv and --logs_jsonl, joins them on normalized question text,
    computes per-question and aggregate metrics at cutoff --k, writes
    metrics_per_question.csv and metrics_aggregate.json into --out_dir, and
    prints a colored console summary.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--gold_csv", required=True, type=str)
    ap.add_argument("--logs_jsonl", required=True, type=str)
    ap.add_argument("--k", type=int, default=8)
    ap.add_argument("--out_dir", type=str, default="rag_artifacts")
    args = ap.parse_args()
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    gold_path = Path(args.gold_csv)
    logs_path = Path(args.logs_jsonl)
    # NOTE(review): every early-failure path below prints to stderr but exits
    # with status 0 — presumably so a surrounding pipeline doesn't hard-fail;
    # confirm this is intentional.
    if not gold_path.exists():
        print(
            f"{COLOR_TEXT}❌ gold.csv not found at {gold_path}{COLOR_RESET}",
            file=sys.stderr,
        )
        sys.exit(0)
    if not logs_path.exists() or logs_path.stat().st_size == 0:
        print(
            f"{COLOR_TEXT}❌ logs JSONL not found or empty at {logs_path}{COLOR_RESET}",
            file=sys.stderr,
        )
        sys.exit(0)
    # Read gold (retrieval + QA answers)
    try:
        gold, gold_answers = read_gold(gold_path)
    except Exception as e:
        print(
            f"{COLOR_TEXT}❌ Failed to read gold: {e}{COLOR_RESET}",
            file=sys.stderr,
        )
        sys.exit(0)
    # Read logs (with robust aggregation)
    try:
        logs = read_logs(logs_path)
    except Exception as e:
        print(
            f"{COLOR_TEXT}❌ Failed to read logs: {e}{COLOR_RESET}",
            file=sys.stderr,
        )
        sys.exit(0)
    if gold.empty:
        print(
            f"{COLOR_TEXT}❌ Gold file contains no usable rows.{COLOR_RESET}",
            file=sys.stderr,
        )
        sys.exit(0)
    if logs.empty:
        print(
            f"{COLOR_TEXT}❌ Logs file contains no usable entries.{COLOR_RESET}",
            file=sys.stderr,
        )
        sys.exit(0)
    # Build gold dict: normalized_question -> list of (doc, page)
    gdict: Dict[str, List[Tuple[str, Optional[int]]]] = {}
    for _, r in gold.iterrows():
        q = str(r["question"]).strip()  # already normalized in read_gold
        d = r["doc"]
        p = r["page"] if "page" in r else np.nan
        gdict.setdefault(q, []).append((d, p))
    # Normalize log questions for join
    logs["q_norm"] = logs["question"].astype(str).str.casefold().str.strip()
    perq_rows = []
    not_in_logs, not_in_gold = [], []
    # For each gold question, compute metrics using logs
    for q_norm, pairs in gdict.items():
        row = logs[logs["q_norm"] == q_norm]
        gdocs = [d for (d, _) in pairs]
        gpages = [p for (_, p) in pairs]
        if row.empty:
            # No logs for this gold question → zero retrieval and no answer metrics
            not_in_logs.append(q_norm)
            # Zero doc-level scores (the question counts against averages only
            # via the covered_in_logs filter below); page-level stays NaN.
            base_metrics = {
                "hit@k_doc": 0,
                "precision@k_doc": 0.0,
                "recall@k_doc": 0.0,
                "ndcg@k_doc": 0.0,
                "hit@k_page": np.nan,
                "precision@k_page": np.nan,
                "recall@k_page": np.nan,
                "ndcg@k_page": np.nan,
                "n_gold_docs": int(len(set([d for d in gdocs if isinstance(d, str) and d]))),
                "n_gold_doc_pages": int(
                    len(
                        [
                            (d, p)
                            for (d, p) in zip(gdocs, gpages)
                            if isinstance(d, str) and d and pd.notna(p)
                        ]
                    )
                ),
                "n_pred": 0,
            }
            txt_metrics = {
                "bleu": np.nan,
                "rouge1": np.nan,
                "rouge2": np.nan,
                "rougeL": np.nan,
                "bert_recall": np.nan,
                "bert_f1": np.nan,
            }
            perq_rows.append(
                {
                    "question": q_norm,
                    "covered_in_logs": 0,
                    **base_metrics,
                    **txt_metrics,
                }
            )
            continue
        # Use aggregated hits from read_logs
        hits = row.iloc[0]["hits"] or []
        base_metrics = compute_metrics_for_question(gdocs, gpages, hits, args.k)
        # Answer text: predicted vs. gold
        pred_answer = str(row.iloc[0].get("answer", "")).strip()
        gold_answer = str(gold_answers.get(q_norm, "")).strip()
        if gold_answer and pred_answer:
            txt_metrics = compute_text_metrics(pred_answer, gold_answer)
        else:
            # Either side missing → answer-quality metrics are undefined.
            txt_metrics = {
                "bleu": np.nan,
                "rouge1": np.nan,
                "rouge2": np.nan,
                "rougeL": np.nan,
                "bert_recall": np.nan,
                "bert_f1": np.nan,
            }
        perq_rows.append(
            {
                "question": q_norm,
                "covered_in_logs": 1,
                **base_metrics,
                **txt_metrics,
            }
        )
    # Any log questions not in gold
    gold_qs = set(gdict.keys())
    for qn in logs["q_norm"].tolist():
        if qn not in gold_qs:
            not_in_gold.append(qn)
    perq = pd.DataFrame(perq_rows)
    # Aggregates below average only over questions that appeared in the logs.
    covered = perq[perq["covered_in_logs"] == 1].copy()
    agg = {
        "questions_total_gold": int(len(gdict)),
        "questions_covered_in_logs": int(covered.shape[0]),
        "questions_missing_in_logs": int(len(not_in_logs)),
        "questions_in_logs_not_in_gold": int(len(set(not_in_gold))),
        "k": int(args.k),
        "mean_hit@k_doc": float(covered["hit@k_doc"].mean()) if not covered.empty else 0.0,
        "mean_precision@k_doc": float(covered["precision@k_doc"].mean()) if not covered.empty else 0.0,
        "mean_recall@k_doc": float(covered["recall@k_doc"].mean()) if not covered.empty else 0.0,
        "mean_ndcg@k_doc": float(covered["ndcg@k_doc"].mean()) if not covered.empty else 0.0,
        # Page-level means are None (null in JSON) when gold had no page labels.
        "mean_hit@k_page": float(covered["hit@k_page"].dropna().mean())
        if covered["hit@k_page"].notna().any()
        else None,
        "mean_precision@k_page": float(covered["precision@k_page"].dropna().mean())
        if covered["precision@k_page"].notna().any()
        else None,
        "mean_recall@k_page": float(covered["recall@k_page"].dropna().mean())
        if covered["recall@k_page"].notna().any()
        else None,
        "mean_ndcg@k_page": float(covered["ndcg@k_page"].dropna().mean())
        if covered["ndcg@k_page"].notna().any()
        else None,
        "avg_gold_docs_per_q": float(perq["n_gold_docs"].mean()) if not perq.empty else 0.0,
        "avg_preds_per_q": float(perq["n_pred"].mean()) if not perq.empty else 0.0,
        "examples_missing_in_logs": list(not_in_logs[:10]),
        "examples_in_logs_not_in_gold": list(dict.fromkeys(not_in_gold))[:10],
    }
    # Aggregate answer-quality metrics (lexical + semantic)
    # NOTE(review): all-NaN columns yield NaN here, which json.dump emits as
    # the non-standard literal NaN — strict JSON parsers will reject it.
    if "bleu" in covered.columns:
        agg["mean_bleu"] = float(covered["bleu"].mean(skipna=True))
        agg["mean_rouge1"] = float(covered["rouge1"].mean(skipna=True))
        agg["mean_rouge2"] = float(covered["rouge2"].mean(skipna=True))
        agg["mean_rougeL"] = float(covered["rougeL"].mean(skipna=True))
        agg["mean_bert_recall"] = float(covered["bert_recall"].mean(skipna=True))
        agg["mean_bert_f1"] = float(covered["bert_f1"].mean(skipna=True))
    perq_path = out_dir / "metrics_per_question.csv"
    agg_path = out_dir / "metrics_aggregate.json"
    perq.to_csv(perq_path, index=False)
    with open(agg_path, "w", encoding="utf-8") as f:
        json.dump(agg, f, ensure_ascii=False, indent=2)
    # === Console summary with color ===
    print(f"{COLOR_TITLE}RAG Evaluation Summary{COLOR_RESET}")
    print(f"{COLOR_TITLE}----------------------{COLOR_RESET}")
    print(f"{COLOR_TEXT}Gold questions: {COLOR_ACCENT}{agg['questions_total_gold']}{COLOR_RESET}")
    print(f"{COLOR_TEXT}Covered in logs: {COLOR_ACCENT}{agg['questions_covered_in_logs']}{COLOR_RESET}")
    print(f"{COLOR_TEXT}Missing in logs: {COLOR_ACCENT}{agg['questions_missing_in_logs']}{COLOR_RESET}")
    print(
        f"{COLOR_TEXT}In logs but not in gold: "
        f"{COLOR_ACCENT}{agg['questions_in_logs_not_in_gold']}{COLOR_RESET}"
    )
    print(f"{COLOR_TEXT}k = {COLOR_ACCENT}{agg['k']}{COLOR_RESET}\n")
    print(
        f"{COLOR_TEXT}Doc-level:{COLOR_RESET} "
        f"{COLOR_ACCENT}Hit@k={_fmt(agg['mean_hit@k_doc'])} "
        f"Precision@k={_fmt(agg['mean_precision@k_doc'])} "
        f"Recall@k={_fmt(agg['mean_recall@k_doc'])} "
        f"nDCG@k={_fmt(agg['mean_ndcg@k_doc'])}{COLOR_RESET}"
    )
    if agg.get("mean_hit@k_page") is not None:
        print(
            f"{COLOR_TEXT}Page-level:{COLOR_RESET} "
            f"{COLOR_ACCENT}Hit@k={_fmt(agg['mean_hit@k_page'])} "
            f"Precision@k={_fmt(agg['mean_precision@k_page'])} "
            f"Recall={_fmt(agg['mean_recall@k_page'])} "
            f"nDCG@k={_fmt(agg['mean_ndcg@k_page'])}{COLOR_RESET}"
        )
    else:
        print(f"{COLOR_TEXT}Page-level: (no page labels in gold){COLOR_RESET}")
    # Lexical metrics summary
    if "mean_bleu" in agg:
        print(
            f"{COLOR_TEXT}Lexical (answer quality):{COLOR_RESET} "
            f"{COLOR_ACCENT}BLEU={_fmt(agg.get('mean_bleu'))} "
            f"ROUGE-1={_fmt(agg.get('mean_rouge1'))} "
            f"ROUGE-2={_fmt(agg.get('mean_rouge2'))} "
            f"ROUGE-L={_fmt(agg.get('mean_rougeL'))}{COLOR_RESET}"
        )
    # Semantic metrics summary
    if "mean_bert_f1" in agg:
        print(
            f"{COLOR_TEXT}Semantic (BERTScore):{COLOR_RESET} "
            f"{COLOR_ACCENT}Recall={_fmt(agg.get('mean_bert_recall'))} "
            f"F1={_fmt(agg.get('mean_bert_f1'))}{COLOR_RESET}"
        )
    print()
    print(
        f"{COLOR_TEXT}Wrote per-question CSV → "
        f"{COLOR_ACCENT}{perq_path}{COLOR_RESET}"
    )
    print(
        f"{COLOR_TEXT}Wrote aggregate JSON → "
        f"{COLOR_ACCENT}{agg_path}{COLOR_RESET}"
    )
# Run only when executed as a script; keeps the module importable for tests.
if __name__ == "__main__":
    main()