Spaces:
Sleeping
Sleeping
Update rag_eval_metrics.py
Browse files- rag_eval_metrics.py +189 -29
rag_eval_metrics.py
CHANGED
|
@@ -3,6 +3,10 @@
|
|
| 3 |
rag_eval_metrics.py
|
| 4 |
|
| 5 |
Evaluate RAG retrieval quality by comparing app logs (JSONL) with a gold file (CSV).
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
import argparse
|
|
@@ -15,19 +19,16 @@ from typing import Dict, List, Tuple, Any, Optional
|
|
| 15 |
import pandas as pd
|
| 16 |
import numpy as np
|
| 17 |
|
| 18 |
-
|
| 19 |
# ----------------------------- Small Utils ----------------------------- #
|
| 20 |
|
| 21 |
def filename_key(s: str) -> str:
    """
    Return a case-insensitive key for the file-name part of a path.

    Backslashes are treated as path separators, so Windows-style and
    POSIX-style paths ending in the same file name produce the same key.
    A falsy argument (``None`` or ``""``) yields ``""``; a path that ends
    with a separator also yields ``""`` because the last component is empty.
    """
    # Unify separators first so a single split handles both path styles.
    basename = (s or "").strip().replace("\\", "/").split("/")[-1]
    return basename.casefold()
|
| 24 |
|
| 25 |
-
|
| 26 |
def re_split_sc(s: str) -> List[str]:
    """
    Split *s* on every semicolon or comma.

    The pieces are returned exactly as they appear between separators:
    surrounding whitespace is NOT stripped and empty pieces are kept, so
    callers that need clean tokens must strip/filter the result themselves.
    """
    import re

    return re.split(r"[;,]", s)
|
| 29 |
|
| 30 |
-
|
| 31 |
def _pick_last_non_empty(hit_lists) -> List[dict]:
|
| 32 |
"""
|
| 33 |
Robustly select the last non-empty hits list from a pandas Series or iterable.
|
|
@@ -49,13 +50,20 @@ def _pick_last_non_empty(hit_lists) -> List[dict]:
|
|
| 49 |
# If everything was empty / NaN
|
| 50 |
return []
|
| 51 |
|
| 52 |
-
|
| 53 |
# ----------------------------- IO Helpers ----------------------------- #
|
| 54 |
|
| 55 |
def read_logs(jsonl_path: Path) -> pd.DataFrame:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
rows = []
|
| 57 |
if (not jsonl_path.exists()) or jsonl_path.stat().st_size == 0:
|
| 58 |
-
return pd.DataFrame(columns=["question", "hits"])
|
| 59 |
|
| 60 |
with open(jsonl_path, "r", encoding="utf-8") as f:
|
| 61 |
for line in f:
|
|
@@ -86,21 +94,36 @@ def read_logs(jsonl_path: Path) -> pd.DataFrame:
|
|
| 86 |
|
| 87 |
norm_hits.append({"doc": doc, "page": page_int})
|
| 88 |
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
df = pd.DataFrame(rows)
|
| 92 |
if df.empty:
|
| 93 |
-
return pd.DataFrame(columns=["question", "hits"])
|
| 94 |
|
| 95 |
-
# Group by normalized question text and keep last non-empty hits list per question
|
| 96 |
df = (
|
| 97 |
df.groupby(df["question"].astype(str).str.casefold().str.strip(), as_index=False)
|
| 98 |
-
.agg({
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
)
|
| 100 |
return df
|
| 101 |
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
df = pd.read_csv(csv_path)
|
| 105 |
cols = {c.lower().strip(): c for c in df.columns}
|
| 106 |
|
|
@@ -134,6 +157,13 @@ def read_gold(csv_path: Path) -> pd.DataFrame:
|
|
| 134 |
page_col = cols[cand]
|
| 135 |
break
|
| 136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
rows = []
|
| 138 |
|
| 139 |
# Case 1: relevant_docs list column (no explicit doc_col)
|
|
@@ -141,6 +171,7 @@ def read_gold(csv_path: Path) -> pd.DataFrame:
|
|
| 141 |
for _, r in df.iterrows():
|
| 142 |
q_raw = str(r[q_col]).strip()
|
| 143 |
q_norm = q_raw.casefold().strip()
|
|
|
|
| 144 |
|
| 145 |
rel_val = str(r[rel_list_col]) if pd.notna(r[rel_list_col]) else ""
|
| 146 |
if not rel_val:
|
|
@@ -148,7 +179,8 @@ def read_gold(csv_path: Path) -> pd.DataFrame:
|
|
| 148 |
"question_raw": q_raw,
|
| 149 |
"question": q_norm,
|
| 150 |
"doc": None,
|
| 151 |
-
"page": np.nan
|
|
|
|
| 152 |
})
|
| 153 |
continue
|
| 154 |
|
|
@@ -158,7 +190,8 @@ def read_gold(csv_path: Path) -> pd.DataFrame:
|
|
| 158 |
"question_raw": q_raw,
|
| 159 |
"question": q_norm,
|
| 160 |
"doc": filename_key(d),
|
| 161 |
-
"page": np.nan
|
|
|
|
| 162 |
})
|
| 163 |
|
| 164 |
# Case 2: doc/page columns (one relevant doc per row)
|
|
@@ -166,6 +199,7 @@ def read_gold(csv_path: Path) -> pd.DataFrame:
|
|
| 166 |
for _, r in df.iterrows():
|
| 167 |
q_raw = str(r[q_col]).strip()
|
| 168 |
q_norm = q_raw.casefold().strip()
|
|
|
|
| 169 |
|
| 170 |
d = str(r[doc_col]).strip() if pd.notna(r[doc_col]) else ""
|
| 171 |
p = r[page_col] if (page_col and pd.notna(r[page_col])) else np.nan
|
|
@@ -179,7 +213,8 @@ def read_gold(csv_path: Path) -> pd.DataFrame:
|
|
| 179 |
"question_raw": q_raw,
|
| 180 |
"question": q_norm,
|
| 181 |
"doc": filename_key(d),
|
| 182 |
-
"page": p
|
|
|
|
| 183 |
})
|
| 184 |
|
| 185 |
else:
|
|
@@ -196,8 +231,17 @@ def read_gold(csv_path: Path) -> pd.DataFrame:
|
|
| 196 |
# Remove duplicates
|
| 197 |
gold = gold.drop_duplicates(subset=["question", "doc", "page"])
|
| 198 |
|
| 199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
|
|
|
|
| 201 |
|
| 202 |
# ----------------------------- Metric Core ----------------------------- #
|
| 203 |
|
|
@@ -208,7 +252,6 @@ def dcg_at_k(relevances: List[int]) -> float:
|
|
| 208 |
dcg += 1.0 / np.log2(i + 1.0)
|
| 209 |
return float(dcg)
|
| 210 |
|
| 211 |
-
|
| 212 |
def ndcg_at_k(relevances: List[int]) -> float:
|
| 213 |
dcg = dcg_at_k(relevances)
|
| 214 |
ideal = sorted(relevances, reverse=True)
|
|
@@ -217,7 +260,6 @@ def ndcg_at_k(relevances: List[int]) -> float:
|
|
| 217 |
return 0.0
|
| 218 |
return float(dcg / idcg)
|
| 219 |
|
| 220 |
-
|
| 221 |
def compute_metrics_for_question(gold_docs, gold_pages, hits, k):
|
| 222 |
top = hits[:k] if hits else []
|
| 223 |
pred_docs = [filename_key(h.get("doc", "")) for h in top]
|
|
@@ -271,6 +313,70 @@ def compute_metrics_for_question(gold_docs, gold_pages, hits, k):
|
|
| 271 |
"n_pred": int(len(pred_docs))
|
| 272 |
}
|
| 273 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
|
| 275 |
# ----------------------------- Orchestration ----------------------------- #
|
| 276 |
|
|
@@ -280,14 +386,12 @@ COLOR_TEXT = "\033[34m" # dark blue
|
|
| 280 |
COLOR_ACCENT = "\033[36m" # cyan for metrics
|
| 281 |
COLOR_RESET = "\033[0m"
|
| 282 |
|
| 283 |
-
|
| 284 |
def _fmt(x: Any) -> str:
|
| 285 |
try:
|
| 286 |
return f"{float(x):.3f}"
|
| 287 |
except Exception:
|
| 288 |
return "-"
|
| 289 |
|
| 290 |
-
|
| 291 |
def main():
|
| 292 |
ap = argparse.ArgumentParser()
|
| 293 |
ap.add_argument("--gold_csv", required=True, type=str)
|
|
@@ -309,9 +413,9 @@ def main():
|
|
| 309 |
print(f"{COLOR_TEXT}❌ logs JSONL not found or empty at {logs_path}{COLOR_RESET}", file=sys.stderr)
|
| 310 |
sys.exit(0)
|
| 311 |
|
| 312 |
-
# Read gold
|
| 313 |
try:
|
| 314 |
-
gold = read_gold(gold_path)
|
| 315 |
except Exception as e:
|
| 316 |
print(f"{COLOR_TEXT}❌ Failed to read gold: {e}{COLOR_RESET}", file=sys.stderr)
|
| 317 |
sys.exit(0)
|
|
@@ -333,7 +437,7 @@ def main():
|
|
| 333 |
# Build gold dict: normalized_question -> list of (doc, page)
|
| 334 |
gdict: Dict[str, List[Tuple[str, Optional[int]]]] = {}
|
| 335 |
for _, r in gold.iterrows():
|
| 336 |
-
q = str(r["question"]).strip()
|
| 337 |
d = r["doc"]
|
| 338 |
p = r["page"] if "page" in r else np.nan
|
| 339 |
gdict.setdefault(q, []).append((d, p))
|
|
@@ -351,9 +455,9 @@ def main():
|
|
| 351 |
gpages = [p for (_, p) in pairs]
|
| 352 |
|
| 353 |
if row.empty:
|
| 354 |
-
# No logs for this gold question → zero retrieval
|
| 355 |
not_in_logs.append(q_norm)
|
| 356 |
-
|
| 357 |
"hit@k_doc": 0,
|
| 358 |
"precision@k_doc": 0.0,
|
| 359 |
"recall@k_doc": 0.0,
|
|
@@ -369,20 +473,49 @@ def main():
|
|
| 369 |
])),
|
| 370 |
"n_pred": 0
|
| 371 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
perq_rows.append({
|
| 373 |
"question": q_norm,
|
| 374 |
"covered_in_logs": 0,
|
| 375 |
-
**
|
|
|
|
| 376 |
})
|
| 377 |
continue
|
| 378 |
|
| 379 |
# Use aggregated hits from read_logs
|
| 380 |
hits = row.iloc[0]["hits"] or []
|
| 381 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 382 |
perq_rows.append({
|
| 383 |
"question": q_norm,
|
| 384 |
"covered_in_logs": 1,
|
| 385 |
-
**
|
|
|
|
| 386 |
})
|
| 387 |
|
| 388 |
# Any log questions not in gold
|
|
@@ -414,6 +547,15 @@ def main():
|
|
| 414 |
"examples_in_logs_not_in_gold": list(dict.fromkeys(not_in_gold))[:10],
|
| 415 |
}
|
| 416 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 417 |
perq_path = out_dir / "metrics_per_question.csv"
|
| 418 |
agg_path = out_dir / "metrics_aggregate.json"
|
| 419 |
|
|
@@ -438,21 +580,39 @@ def main():
|
|
| 438 |
f"nDCG@k={_fmt(agg['mean_ndcg@k_doc'])}{COLOR_RESET}"
|
| 439 |
)
|
| 440 |
|
| 441 |
-
if agg
|
| 442 |
print(
|
| 443 |
f"{COLOR_TEXT}Page-level:{COLOR_RESET} "
|
| 444 |
f"{COLOR_ACCENT}Hit@k={_fmt(agg['mean_hit@k_page'])} "
|
| 445 |
f"Precision@k={_fmt(agg['mean_precision@k_page'])} "
|
| 446 |
-
f"Recall
|
| 447 |
f"nDCG@k={_fmt(agg['mean_ndcg@k_page'])}{COLOR_RESET}"
|
| 448 |
)
|
| 449 |
else:
|
| 450 |
print(f"{COLOR_TEXT}Page-level: (no page labels in gold){COLOR_RESET}")
|
| 451 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 452 |
print()
|
| 453 |
print(f"{COLOR_TEXT}Wrote per-question CSV → {COLOR_ACCENT}{perq_path}{COLOR_RESET}")
|
| 454 |
print(f"{COLOR_TEXT}Wrote aggregate JSON → {COLOR_ACCENT}{agg_path}{COLOR_RESET}")
|
| 455 |
|
| 456 |
-
|
| 457 |
if __name__ == "__main__":
|
| 458 |
main()
|
|
|
|
|
|
| 3 |
rag_eval_metrics.py
|
| 4 |
|
| 5 |
Evaluate RAG retrieval quality by comparing app logs (JSONL) with a gold file (CSV).
|
| 6 |
+
|
| 7 |
+
Now extended to also evaluate answer quality using:
|
| 8 |
+
- Lexical similarity: BLEU, ROUGE-1/2/L
|
| 9 |
+
- Semantic similarity: BERTScore (Recall, F1)
|
| 10 |
"""
|
| 11 |
|
| 12 |
import argparse
|
|
|
|
| 19 |
import pandas as pd
|
| 20 |
import numpy as np
|
| 21 |
|
|
|
|
| 22 |
# ----------------------------- Small Utils ----------------------------- #
|
| 23 |
|
| 24 |
def filename_key(s: str) -> str:
    """
    Return a case-insensitive key for the file-name part of a path.

    Backslashes are treated as path separators, so Windows-style and
    POSIX-style paths ending in the same file name produce the same key.
    A falsy argument (``None`` or ``""``) yields ``""``; a path that ends
    with a separator also yields ``""`` because the last component is empty.
    """
    # Unify separators first so a single split handles both path styles.
    basename = (s or "").strip().replace("\\", "/").split("/")[-1]
    return basename.casefold()
|
| 27 |
|
|
|
|
| 28 |
def re_split_sc(s: str) -> List[str]:
    """
    Split *s* on every semicolon or comma.

    The pieces are returned exactly as they appear between separators:
    surrounding whitespace is NOT stripped and empty pieces are kept, so
    callers that need clean tokens must strip/filter the result themselves.
    """
    import re

    return re.split(r"[;,]", s)
|
| 31 |
|
|
|
|
| 32 |
def _pick_last_non_empty(hit_lists) -> List[dict]:
|
| 33 |
"""
|
| 34 |
Robustly select the last non-empty hits list from a pandas Series or iterable.
|
|
|
|
| 50 |
# If everything was empty / NaN
|
| 51 |
return []
|
| 52 |
|
|
|
|
| 53 |
# ----------------------------- IO Helpers ----------------------------- #
|
| 54 |
|
| 55 |
def read_logs(jsonl_path: Path) -> pd.DataFrame:
|
| 56 |
+
"""
|
| 57 |
+
Read RAG JSONL logs and aggregate by question.
|
| 58 |
+
|
| 59 |
+
Returns a DataFrame with columns:
|
| 60 |
+
- question: original question text (last occurrence)
|
| 61 |
+
- hits: list of dicts {doc, page} for retrieval
|
| 62 |
+
- answer: final answer text logged for that question
|
| 63 |
+
"""
|
| 64 |
rows = []
|
| 65 |
if (not jsonl_path.exists()) or jsonl_path.stat().st_size == 0:
|
| 66 |
+
return pd.DataFrame(columns=["question", "hits", "answer"])
|
| 67 |
|
| 68 |
with open(jsonl_path, "r", encoding="utf-8") as f:
|
| 69 |
for line in f:
|
|
|
|
| 94 |
|
| 95 |
norm_hits.append({"doc": doc, "page": page_int})
|
| 96 |
|
| 97 |
+
# Extract final answer text (if present)
|
| 98 |
+
out = (rec.get("output") or {})
|
| 99 |
+
ans = ((out.get("final_answer") or "")).strip()
|
| 100 |
+
|
| 101 |
+
rows.append({"question": q, "hits": norm_hits, "answer": ans})
|
| 102 |
|
| 103 |
df = pd.DataFrame(rows)
|
| 104 |
if df.empty:
|
| 105 |
+
return pd.DataFrame(columns=["question", "hits", "answer"])
|
| 106 |
|
| 107 |
+
# Group by normalized question text and keep last non-empty hits list and answer per question
|
| 108 |
df = (
|
| 109 |
df.groupby(df["question"].astype(str).str.casefold().str.strip(), as_index=False)
|
| 110 |
+
.agg({
|
| 111 |
+
"question": "last",
|
| 112 |
+
"hits": _pick_last_non_empty,
|
| 113 |
+
"answer": "last"
|
| 114 |
+
})
|
| 115 |
)
|
| 116 |
return df
|
| 117 |
|
| 118 |
+
def read_gold(csv_path: Path) -> Tuple[pd.DataFrame, Dict[str, str]]:
|
| 119 |
+
"""
|
| 120 |
+
Read gold CSV with retrieval labels and optional reference answers.
|
| 121 |
|
| 122 |
+
Returns:
|
| 123 |
+
- gold_df: rows with columns ['question', 'doc', 'page', 'answer', ...]
|
| 124 |
+
where 'question' is normalized (casefold+strip)
|
| 125 |
+
- gold_answers: dict mapping normalized question -> reference answer text
|
| 126 |
+
"""
|
| 127 |
df = pd.read_csv(csv_path)
|
| 128 |
cols = {c.lower().strip(): c for c in df.columns}
|
| 129 |
|
|
|
|
| 157 |
page_col = cols[cand]
|
| 158 |
break
|
| 159 |
|
| 160 |
+
# --- optional answer column (for QA metrics) ---
|
| 161 |
+
ans_col = None
|
| 162 |
+
for cand in ["answer", "reference_answer", "gold_answer"]:
|
| 163 |
+
if cand in cols:
|
| 164 |
+
ans_col = cols[cand]
|
| 165 |
+
break
|
| 166 |
+
|
| 167 |
rows = []
|
| 168 |
|
| 169 |
# Case 1: relevant_docs list column (no explicit doc_col)
|
|
|
|
| 171 |
for _, r in df.iterrows():
|
| 172 |
q_raw = str(r[q_col]).strip()
|
| 173 |
q_norm = q_raw.casefold().strip()
|
| 174 |
+
ans_raw = str(r[ans_col]).strip() if (ans_col and pd.notna(r[ans_col])) else ""
|
| 175 |
|
| 176 |
rel_val = str(r[rel_list_col]) if pd.notna(r[rel_list_col]) else ""
|
| 177 |
if not rel_val:
|
|
|
|
| 179 |
"question_raw": q_raw,
|
| 180 |
"question": q_norm,
|
| 181 |
"doc": None,
|
| 182 |
+
"page": np.nan,
|
| 183 |
+
"answer": ans_raw
|
| 184 |
})
|
| 185 |
continue
|
| 186 |
|
|
|
|
| 190 |
"question_raw": q_raw,
|
| 191 |
"question": q_norm,
|
| 192 |
"doc": filename_key(d),
|
| 193 |
+
"page": np.nan,
|
| 194 |
+
"answer": ans_raw
|
| 195 |
})
|
| 196 |
|
| 197 |
# Case 2: doc/page columns (one relevant doc per row)
|
|
|
|
| 199 |
for _, r in df.iterrows():
|
| 200 |
q_raw = str(r[q_col]).strip()
|
| 201 |
q_norm = q_raw.casefold().strip()
|
| 202 |
+
ans_raw = str(r[ans_col]).strip() if (ans_col and pd.notna(r[ans_col])) else ""
|
| 203 |
|
| 204 |
d = str(r[doc_col]).strip() if pd.notna(r[doc_col]) else ""
|
| 205 |
p = r[page_col] if (page_col and pd.notna(r[page_col])) else np.nan
|
|
|
|
| 213 |
"question_raw": q_raw,
|
| 214 |
"question": q_norm,
|
| 215 |
"doc": filename_key(d),
|
| 216 |
+
"page": p,
|
| 217 |
+
"answer": ans_raw
|
| 218 |
})
|
| 219 |
|
| 220 |
else:
|
|
|
|
| 231 |
# Remove duplicates
|
| 232 |
gold = gold.drop_duplicates(subset=["question", "doc", "page"])
|
| 233 |
|
| 234 |
+
# Build question -> gold_answer map (normalized questions)
|
| 235 |
+
gold_answers: Dict[str, str] = {}
|
| 236 |
+
if "answer" in gold.columns:
|
| 237 |
+
tmp = (
|
| 238 |
+
gold[["question", "answer"]]
|
| 239 |
+
.dropna(subset=["answer"])
|
| 240 |
+
.drop_duplicates(subset=["question"])
|
| 241 |
+
)
|
| 242 |
+
gold_answers = dict(zip(tmp["question"], tmp["answer"]))
|
| 243 |
|
| 244 |
+
return gold, gold_answers
|
| 245 |
|
| 246 |
# ----------------------------- Metric Core ----------------------------- #
|
| 247 |
|
|
|
|
| 252 |
dcg += 1.0 / np.log2(i + 1.0)
|
| 253 |
return float(dcg)
|
| 254 |
|
|
|
|
| 255 |
def ndcg_at_k(relevances: List[int]) -> float:
|
| 256 |
dcg = dcg_at_k(relevances)
|
| 257 |
ideal = sorted(relevances, reverse=True)
|
|
|
|
| 260 |
return 0.0
|
| 261 |
return float(dcg / idcg)
|
| 262 |
|
|
|
|
| 263 |
def compute_metrics_for_question(gold_docs, gold_pages, hits, k):
|
| 264 |
top = hits[:k] if hits else []
|
| 265 |
pred_docs = [filename_key(h.get("doc", "")) for h in top]
|
|
|
|
| 313 |
"n_pred": int(len(pred_docs))
|
| 314 |
}
|
| 315 |
|
| 316 |
+
# ---------------------- Answer Quality Metrics ---------------------- #
|
| 317 |
+
|
| 318 |
+
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
|
| 319 |
+
from rouge_score import rouge_scorer
|
| 320 |
+
from bert_score import score as bert_score
|
| 321 |
+
|
| 322 |
+
_SMOOTH = SmoothingFunction().method1
|
| 323 |
+
_ROUGE_SCORER = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
|
| 324 |
+
|
| 325 |
+
def _normalize_text_for_metrics(s: str) -> str:
|
| 326 |
+
import re
|
| 327 |
+
s = (s or "").strip().lower()
|
| 328 |
+
# remove simple markdown markers
|
| 329 |
+
s = re.sub(r"\*\*|\*", "", s)
|
| 330 |
+
# drop inline citations like (Doc.pdf, p.X)
|
| 331 |
+
s = re.sub(r"\([^)]*\)", " ", s)
|
| 332 |
+
s = re.sub(r"\s+", " ", s)
|
| 333 |
+
return s.strip()
|
| 334 |
+
|
| 335 |
+
def compute_text_metrics(pred: str, ref: str) -> Dict[str, float]:
    """
    Compute lexical and semantic similarity between a prediction and a reference.

    Both texts are first passed through ``_normalize_text_for_metrics``.

    Returns a dict with keys:
      - bleu: sentence-level BLEU on whitespace tokens, smoothed with ``_SMOOTH``
      - rouge1 / rouge2 / rougeL: F-measure from ``_ROUGE_SCORER``
      - bert_recall / bert_f1: BERTScore recall and F1 (lang="en",
        rescaled with baseline)

    Every value is NaN when either normalised text is empty.
    """
    pred_n = _normalize_text_for_metrics(pred)
    ref_n = _normalize_text_for_metrics(ref)

    if not pred_n or not ref_n:
        return _nan_text_metrics()

    # BLEU (sentence-level with smoothing); the reference goes in as a list
    # of token lists, the prediction as a single token list.
    bleu = float(
        sentence_bleu([ref_n.split()], pred_n.split(), smoothing_function=_SMOOTH)
    )

    # ROUGE via rouge-score: argument order is (target, prediction).
    rouge = _ROUGE_SCORER.score(ref_n, pred_n)

    # BERTScore (semantic similarity). The scorer object is reused so the
    # underlying model is not rebuilt for every question.
    _, recall, f1 = _get_bert_scorer().score([pred_n], [ref_n])

    return {
        "bleu": bleu,
        "rouge1": float(rouge["rouge1"].fmeasure),
        "rouge2": float(rouge["rouge2"].fmeasure),
        "rougeL": float(rouge["rougeL"].fmeasure),
        "bert_recall": float(recall.mean().item()),
        "bert_f1": float(f1.mean().item()),
    }


def _nan_text_metrics() -> Dict[str, float]:
    """Return the all-NaN metric dict used when a pair cannot be scored."""
    return {
        "bleu": np.nan,
        "rouge1": np.nan,
        "rouge2": np.nan,
        "rougeL": np.nan,
        "bert_recall": np.nan,
        "bert_f1": np.nan,
    }


# Lazily created BERTScorer shared by all compute_text_metrics() calls.
_BERT_SCORER = None


def _get_bert_scorer():
    """Create the shared BERTScorer on first use and return it afterwards."""
    global _BERT_SCORER
    if _BERT_SCORER is None:
        # Local import: only needed once answer metrics are actually computed.
        from bert_score import BERTScorer

        # Same settings as the former per-call score(..., lang="en",
        # rescale_with_baseline=True) invocation.
        _BERT_SCORER = BERTScorer(lang="en", rescale_with_baseline=True)
    return _BERT_SCORER
|
| 380 |
|
| 381 |
# ----------------------------- Orchestration ----------------------------- #
|
| 382 |
|
|
|
|
| 386 |
COLOR_ACCENT = "\033[36m" # cyan for metrics
|
| 387 |
COLOR_RESET = "\033[0m"
|
| 388 |
|
|
|
|
| 389 |
def _fmt(x: Any) -> str:
|
| 390 |
try:
|
| 391 |
return f"{float(x):.3f}"
|
| 392 |
except Exception:
|
| 393 |
return "-"
|
| 394 |
|
|
|
|
| 395 |
def main():
|
| 396 |
ap = argparse.ArgumentParser()
|
| 397 |
ap.add_argument("--gold_csv", required=True, type=str)
|
|
|
|
| 413 |
print(f"{COLOR_TEXT}❌ logs JSONL not found or empty at {logs_path}{COLOR_RESET}", file=sys.stderr)
|
| 414 |
sys.exit(0)
|
| 415 |
|
| 416 |
+
# Read gold (retrieval + QA answers)
|
| 417 |
try:
|
| 418 |
+
gold, gold_answers = read_gold(gold_path)
|
| 419 |
except Exception as e:
|
| 420 |
print(f"{COLOR_TEXT}❌ Failed to read gold: {e}{COLOR_RESET}", file=sys.stderr)
|
| 421 |
sys.exit(0)
|
|
|
|
| 437 |
# Build gold dict: normalized_question -> list of (doc, page)
|
| 438 |
gdict: Dict[str, List[Tuple[str, Optional[int]]]] = {}
|
| 439 |
for _, r in gold.iterrows():
|
| 440 |
+
q = str(r["question"]).strip() # already normalized in read_gold
|
| 441 |
d = r["doc"]
|
| 442 |
p = r["page"] if "page" in r else np.nan
|
| 443 |
gdict.setdefault(q, []).append((d, p))
|
|
|
|
| 455 |
gpages = [p for (_, p) in pairs]
|
| 456 |
|
| 457 |
if row.empty:
|
| 458 |
+
# No logs for this gold question → zero retrieval and no answer metrics
|
| 459 |
not_in_logs.append(q_norm)
|
| 460 |
+
base_metrics = {
|
| 461 |
"hit@k_doc": 0,
|
| 462 |
"precision@k_doc": 0.0,
|
| 463 |
"recall@k_doc": 0.0,
|
|
|
|
| 473 |
])),
|
| 474 |
"n_pred": 0
|
| 475 |
}
|
| 476 |
+
|
| 477 |
+
txt_metrics = {
|
| 478 |
+
"bleu": np.nan,
|
| 479 |
+
"rouge1": np.nan,
|
| 480 |
+
"rouge2": np.nan,
|
| 481 |
+
"rougeL": np.nan,
|
| 482 |
+
"bert_recall": np.nan,
|
| 483 |
+
"bert_f1": np.nan,
|
| 484 |
+
}
|
| 485 |
+
|
| 486 |
perq_rows.append({
|
| 487 |
"question": q_norm,
|
| 488 |
"covered_in_logs": 0,
|
| 489 |
+
**base_metrics,
|
| 490 |
+
**txt_metrics,
|
| 491 |
})
|
| 492 |
continue
|
| 493 |
|
| 494 |
# Use aggregated hits from read_logs
|
| 495 |
hits = row.iloc[0]["hits"] or []
|
| 496 |
+
base_metrics = compute_metrics_for_question(gdocs, gpages, hits, args.k)
|
| 497 |
+
|
| 498 |
+
# Answer text: predicted vs. gold
|
| 499 |
+
pred_answer = str(row.iloc[0].get("answer", "")).strip()
|
| 500 |
+
gold_answer = str(gold_answers.get(q_norm, "")).strip()
|
| 501 |
+
|
| 502 |
+
if gold_answer and pred_answer:
|
| 503 |
+
txt_metrics = compute_text_metrics(pred_answer, gold_answer)
|
| 504 |
+
else:
|
| 505 |
+
txt_metrics = {
|
| 506 |
+
"bleu": np.nan,
|
| 507 |
+
"rouge1": np.nan,
|
| 508 |
+
"rouge2": np.nan,
|
| 509 |
+
"rougeL": np.nan,
|
| 510 |
+
"bert_recall": np.nan,
|
| 511 |
+
"bert_f1": np.nan,
|
| 512 |
+
}
|
| 513 |
+
|
| 514 |
perq_rows.append({
|
| 515 |
"question": q_norm,
|
| 516 |
"covered_in_logs": 1,
|
| 517 |
+
**base_metrics,
|
| 518 |
+
**txt_metrics,
|
| 519 |
})
|
| 520 |
|
| 521 |
# Any log questions not in gold
|
|
|
|
| 547 |
"examples_in_logs_not_in_gold": list(dict.fromkeys(not_in_gold))[:10],
|
| 548 |
}
|
| 549 |
|
| 550 |
+
# Aggregate answer-quality metrics (lexical + semantic)
|
| 551 |
+
if "bleu" in covered.columns:
|
| 552 |
+
agg["mean_bleu"] = float(covered["bleu"].mean(skipna=True))
|
| 553 |
+
agg["mean_rouge1"] = float(covered["rouge1"].mean(skipna=True))
|
| 554 |
+
agg["mean_rouge2"] = float(covered["rouge2"].mean(skipna=True))
|
| 555 |
+
agg["mean_rougeL"] = float(covered["rougeL"].mean(skipna=True))
|
| 556 |
+
agg["mean_bert_recall"] = float(covered["bert_recall"].mean(skipna=True))
|
| 557 |
+
agg["mean_bert_f1"] = float(covered["bert_f1"].mean(skipna=True))
|
| 558 |
+
|
| 559 |
perq_path = out_dir / "metrics_per_question.csv"
|
| 560 |
agg_path = out_dir / "metrics_aggregate.json"
|
| 561 |
|
|
|
|
| 580 |
f"nDCG@k={_fmt(agg['mean_ndcg@k_doc'])}{COLOR_RESET}"
|
| 581 |
)
|
| 582 |
|
| 583 |
+
if agg.get("mean_hit@k_page") is not None:
|
| 584 |
print(
|
| 585 |
f"{COLOR_TEXT}Page-level:{COLOR_RESET} "
|
| 586 |
f"{COLOR_ACCENT}Hit@k={_fmt(agg['mean_hit@k_page'])} "
|
| 587 |
f"Precision@k={_fmt(agg['mean_precision@k_page'])} "
|
| 588 |
+
f"Recall={_fmt(agg['mean_recall@k_page'])} "
|
| 589 |
f"nDCG@k={_fmt(agg['mean_ndcg@k_page'])}{COLOR_RESET}"
|
| 590 |
)
|
| 591 |
else:
|
| 592 |
print(f"{COLOR_TEXT}Page-level: (no page labels in gold){COLOR_RESET}")
|
| 593 |
|
| 594 |
+
# Lexical metrics summary
|
| 595 |
+
if "mean_bleu" in agg:
|
| 596 |
+
print(
|
| 597 |
+
f"{COLOR_TEXT}Lexical (answer quality):{COLOR_RESET} "
|
| 598 |
+
f"{COLOR_ACCENT}BLEU={_fmt(agg['mean_bleu'])} "
|
| 599 |
+
f"ROUGE-1={_fmt(agg['mean_rouge1'])} "
|
| 600 |
+
f"ROUGE-2={_fmt(agg['mean_rouge2'])} "
|
| 601 |
+
f"ROUGE-L={_fmt(agg['mean_rougeL'])}{COLOR_RESET}"
|
| 602 |
+
)
|
| 603 |
+
|
| 604 |
+
# Semantic metrics summary
|
| 605 |
+
if "mean_bert_f1" in agg:
|
| 606 |
+
print(
|
| 607 |
+
f"{COLOR_TEXT}Semantic (BERTScore):{COLOR_RESET} "
|
| 608 |
+
f"{COLOR_ACCENT}Recall={_fmt(agg['mean_bert_recall'])} "
|
| 609 |
+
f"F1={_fmt(agg['mean_bert_f1'])}{COLOR_RESET}"
|
| 610 |
+
)
|
| 611 |
+
|
| 612 |
print()
|
| 613 |
print(f"{COLOR_TEXT}Wrote per-question CSV → {COLOR_ACCENT}{perq_path}{COLOR_RESET}")
|
| 614 |
print(f"{COLOR_TEXT}Wrote aggregate JSON → {COLOR_ACCENT}{agg_path}{COLOR_RESET}")
|
| 615 |
|
|
|
|
| 616 |
if __name__ == "__main__":
|
| 617 |
main()
|
| 618 |
+
|