Inframat-x committed on
Commit
a414fb0
Β·
verified Β·
1 Parent(s): f740959

Update rag_eval_metrics.py

Browse files
Files changed (1) hide show
  1. rag_eval_metrics.py +184 -157
rag_eval_metrics.py CHANGED
@@ -1,57 +1,37 @@
1
- #!/usr/bin/env python3
2
- """
3
- rag_eval_metrics.py
4
-
5
- Evaluate RAG retrieval quality by comparing app logs (JSONL) with a gold file (CSV).
6
- """
7
-
8
- import argparse
9
  import json
10
- import os
11
  import sys
12
  from pathlib import Path
13
  from typing import Dict, List, Tuple, Any, Optional
14
 
15
- import pandas as pd
16
  import numpy as np
 
17
 
18
-
19
- # ----------------------------- Small Utils ----------------------------- #
 
 
 
20
 
21
  def filename_key(s: str) -> str:
22
  s = (s or "").strip().replace("\\", "/").split("/")[-1]
23
  return s.casefold()
24
 
25
-
26
  def re_split_sc(s: str) -> List[str]:
27
  import re
28
  return re.split(r"[;,]", s)
29
 
30
-
31
  def _pick_last_non_empty(hit_lists) -> List[dict]:
32
- """
33
- Robustly select the last non-empty hits list from a pandas Series or iterable.
34
-
35
- This fixes the KeyError that happens when using reversed() directly on a Series
36
- with a non-range index.
37
- """
38
- # Convert pandas Series or other iterables to a plain Python list
39
  try:
40
  values = list(hit_lists.tolist())
41
  except AttributeError:
42
  values = list(hit_lists)
43
-
44
- # Walk from last to first, return first non-empty list-like
45
  for lst in reversed(values):
46
  if isinstance(lst, (list, tuple)) and len(lst) > 0:
47
  return lst
48
-
49
- # If everything was empty / NaN
50
  return []
51
 
52
-
53
- # ----------------------------- IO Helpers ----------------------------- #
54
-
55
  def read_logs(jsonl_path: Path) -> pd.DataFrame:
56
  rows = []
57
  if (not jsonl_path.exists()) or jsonl_path.stat().st_size == 0:
@@ -67,67 +47,96 @@ def read_logs(jsonl_path: Path) -> pd.DataFrame:
67
  except Exception:
68
  continue
69
 
70
- # Extract question
71
- q = (((rec.get("inputs") or {}).get("question")) or "").strip()
 
 
 
 
 
 
72
 
73
- # Extract retrieval hits (if present)
74
  retr = (rec.get("retrieval") or {})
75
  hits = retr.get("hits", [])
 
76
  norm_hits = []
77
  for h in hits or []:
78
  doc = (h.get("doc") or "").strip()
79
  page = str(h.get("page") or "").strip()
80
-
81
- # Normalize page to int or None
82
  try:
83
  page_int = int(page)
84
  except Exception:
85
  page_int = None
86
-
87
  norm_hits.append({"doc": doc, "page": page_int})
88
 
89
- rows.append({"question": q, "hits": norm_hits})
 
 
 
 
 
 
 
 
 
 
90
 
91
  df = pd.DataFrame(rows)
92
  if df.empty:
93
  return pd.DataFrame(columns=["question", "hits"])
94
 
95
- # Group by normalized question text and keep last non-empty hits list per question
96
  df = (
97
- df.groupby(df["question"].astype(str).str.casefold().str.strip(), as_index=False)
98
- .agg({"question": "last", "hits": _pick_last_non_empty})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  )
100
  return df
101
 
102
-
103
  def read_gold(csv_path: Path) -> pd.DataFrame:
104
  df = pd.read_csv(csv_path)
105
  cols = {c.lower().strip(): c for c in df.columns}
106
 
107
- # --- question column ---
108
  q_col = None
109
  for cand in ["question", "query", "q"]:
110
  if cand in cols:
111
  q_col = cols[cand]
112
  break
113
  if q_col is None:
114
- raise ValueError("Gold CSV must contain a 'question' column (case-insensitive).")
115
 
116
- # --- possible relevant_docs (list-in-cell) column ---
117
  rel_list_col = None
118
  for cand in ["relevant_docs", "relevant", "docs"]:
119
  if cand in cols:
120
  rel_list_col = cols[cand]
121
  break
122
 
123
- # --- single-doc-per-row column ---
124
  doc_col = None
125
  for cand in ["doc", "document", "file", "doc_name"]:
126
  if cand in cols:
127
  doc_col = cols[cand]
128
  break
129
 
130
- # --- optional page column ---
131
  page_col = None
132
  for cand in ["page", "page_num", "page_number"]:
133
  if cand in cols:
@@ -136,71 +145,61 @@ def read_gold(csv_path: Path) -> pd.DataFrame:
136
 
137
  rows = []
138
 
139
- # Case 1: relevant_docs list column (no explicit doc_col)
140
  if rel_list_col and doc_col is None:
141
  for _, r in df.iterrows():
142
  q_raw = str(r[q_col]).strip()
143
  q_norm = q_raw.casefold().strip()
144
-
145
  rel_val = str(r[rel_list_col]) if pd.notna(r[rel_list_col]) else ""
146
  if not rel_val:
147
- rows.append({
148
- "question_raw": q_raw,
149
- "question": q_norm,
150
- "doc": None,
151
- "page": np.nan
152
- })
 
 
153
  continue
154
-
155
  parts = [p.strip() for p in re_split_sc(rel_val)]
156
  for d in parts:
157
- rows.append({
158
- "question_raw": q_raw,
159
- "question": q_norm,
160
- "doc": filename_key(d),
161
- "page": np.nan
162
- })
163
-
164
- # Case 2: doc/page columns (one relevant doc per row)
165
  elif doc_col:
166
  for _, r in df.iterrows():
167
  q_raw = str(r[q_col]).strip()
168
  q_norm = q_raw.casefold().strip()
169
-
170
  d = str(r[doc_col]).strip() if pd.notna(r[doc_col]) else ""
171
  p = r[page_col] if (page_col and pd.notna(r[page_col])) else np.nan
172
-
173
  try:
174
  p = int(p)
175
  except Exception:
176
  p = np.nan
177
-
178
- rows.append({
179
- "question_raw": q_raw,
180
- "question": q_norm,
181
- "doc": filename_key(d),
182
- "page": p
183
- })
184
-
185
  else:
186
- raise ValueError("Gold CSV must contain either a 'doc' column or a 'relevant_docs' column.")
187
 
188
  gold = pd.DataFrame(rows)
189
-
190
- # Keep only rows with a valid doc (when docs exist)
191
  gold["has_doc"] = gold["doc"].apply(lambda x: isinstance(x, str) and len(x) > 0)
192
  if gold["has_doc"].any():
193
  gold = gold[gold["has_doc"]].copy()
194
  gold.drop(columns=["has_doc"], inplace=True, errors="ignore")
195
-
196
- # Remove duplicates
197
  gold = gold.drop_duplicates(subset=["question", "doc", "page"])
198
-
199
  return gold
200
 
201
-
202
- # ----------------------------- Metric Core ----------------------------- #
203
-
204
  def dcg_at_k(relevances: List[int]) -> float:
205
  dcg = 0.0
206
  for i, rel in enumerate(relevances, start=1):
@@ -208,7 +207,6 @@ def dcg_at_k(relevances: List[int]) -> float:
208
  dcg += 1.0 / np.log2(i + 1.0)
209
  return float(dcg)
210
 
211
-
212
  def ndcg_at_k(relevances: List[int]) -> float:
213
  dcg = dcg_at_k(relevances)
214
  ideal = sorted(relevances, reverse=True)
@@ -217,13 +215,11 @@ def ndcg_at_k(relevances: List[int]) -> float:
217
  return 0.0
218
  return float(dcg / idcg)
219
 
220
-
221
  def compute_metrics_for_question(gold_docs, gold_pages, hits, k):
222
  top = hits[:k] if hits else []
223
  pred_docs = [filename_key(h.get("doc", "")) for h in top]
224
  pred_pairs = [(filename_key(h.get("doc", "")), h.get("page", None)) for h in top]
225
 
226
- # --- Doc-level metrics ---
227
  gold_doc_set = set([d for d in gold_docs if isinstance(d, str) and d])
228
 
229
  rel_bin_doc = [1 if d in gold_doc_set else 0 for d in pred_docs]
@@ -232,10 +228,11 @@ def compute_metrics_for_question(gold_docs, gold_pages, hits, k):
232
  rec_doc = (sum(rel_bin_doc) / max(1, len(gold_doc_set))) if gold_doc_set else 0.0
233
  ndcg_doc = ndcg_at_k(rel_bin_doc)
234
 
235
- # --- Page-level metrics (only if gold has page labels) ---
236
  gold_pairs = set()
237
  for d, p in zip(gold_docs, gold_pages):
238
- if isinstance(d, str) and d and (p is not None) and (not (isinstance(p, float) and np.isnan(p))):
 
 
239
  try:
240
  p_int = int(p)
241
  except Exception:
@@ -249,7 +246,6 @@ def compute_metrics_for_question(gold_docs, gold_pages, hits, k):
249
  rel_bin_page.append(0)
250
  else:
251
  rel_bin_page.append(1 if (d, p) in gold_pairs else 0)
252
-
253
  hitk_page = 1 if any(rel_bin_page) else 0
254
  prec_page = (sum(rel_bin_page) / max(1, len(pred_pairs))) if pred_pairs else 0.0
255
  rec_page = (sum(rel_bin_page) / max(1, len(gold_pairs))) if gold_pairs else 0.0
@@ -268,69 +264,56 @@ def compute_metrics_for_question(gold_docs, gold_pages, hits, k):
268
  "ndcg@k_page": ndcg_page,
269
  "n_gold_docs": int(len(gold_doc_set)),
270
  "n_gold_doc_pages": int(len(gold_pairs)),
271
- "n_pred": int(len(pred_docs))
272
  }
273
 
274
-
275
- # ----------------------------- Orchestration ----------------------------- #
276
-
277
- # === Dark blue and accent colors ===
278
- COLOR_TITLE = "\033[94m" # light blue for titles
279
- COLOR_TEXT = "\033[34m" # dark blue
280
- COLOR_ACCENT = "\033[36m" # cyan for metrics
281
- COLOR_RESET = "\033[0m"
282
-
283
-
284
  def _fmt(x: Any) -> str:
285
  try:
286
  return f"{float(x):.3f}"
287
  except Exception:
288
  return "-"
289
 
290
-
291
- def main():
292
- ap = argparse.ArgumentParser()
293
- ap.add_argument("--gold_csv", required=True, type=str)
294
- ap.add_argument("--logs_jsonl", required=True, type=str)
295
- ap.add_argument("--k", type=int, default=8)
296
- ap.add_argument("--out_dir", type=str, default="rag_artifacts")
297
- args = ap.parse_args()
298
-
299
- out_dir = Path(args.out_dir)
300
  out_dir.mkdir(parents=True, exist_ok=True)
301
 
302
- gold_path = Path(args.gold_csv)
303
- logs_path = Path(args.logs_jsonl)
304
 
305
  if not gold_path.exists():
306
- print(f"{COLOR_TEXT}❌ gold.csv not found at {gold_path}{COLOR_RESET}", file=sys.stderr)
307
- sys.exit(0)
308
  if not logs_path.exists() or logs_path.stat().st_size == 0:
309
- print(f"{COLOR_TEXT}❌ logs JSONL not found or empty at {logs_path}{COLOR_RESET}", file=sys.stderr)
310
- sys.exit(0)
311
 
312
- # Read gold
313
  try:
314
  gold = read_gold(gold_path)
315
  except Exception as e:
316
- print(f"{COLOR_TEXT}❌ Failed to read gold: {e}{COLOR_RESET}", file=sys.stderr)
317
- sys.exit(0)
318
 
319
- # Read logs (with robust aggregation)
320
  try:
321
  logs = read_logs(logs_path)
322
  except Exception as e:
323
- print(f"{COLOR_TEXT}❌ Failed to read logs: {e}{COLOR_RESET}", file=sys.stderr)
324
- sys.exit(0)
325
 
326
  if gold.empty:
327
- print(f"{COLOR_TEXT}❌ Gold file contains no usable rows.{COLOR_RESET}", file=sys.stderr)
328
- sys.exit(0)
329
  if logs.empty:
330
- print(f"{COLOR_TEXT}❌ Logs file contains no usable entries.{COLOR_RESET}", file=sys.stderr)
331
- sys.exit(0)
332
 
333
- # Build gold dict: normalized_question -> list of (doc, page)
334
  gdict: Dict[str, List[Tuple[str, Optional[int]]]] = {}
335
  for _, r in gold.iterrows():
336
  q = str(r["question"]).strip()
@@ -338,20 +321,18 @@ def main():
338
  p = r["page"] if "page" in r else np.nan
339
  gdict.setdefault(q, []).append((d, p))
340
 
341
- # Normalize log questions for join
342
  logs["q_norm"] = logs["question"].astype(str).str.casefold().str.strip()
343
 
344
  perq_rows = []
345
  not_in_logs, not_in_gold = [], []
346
 
347
- # For each gold question, compute metrics using logs
348
  for q_norm, pairs in gdict.items():
349
  row = logs[logs["q_norm"] == q_norm]
350
  gdocs = [d for (d, _) in pairs]
351
  gpages = [p for (_, p) in pairs]
352
 
353
  if row.empty:
354
- # No logs for this gold question β†’ zero retrieval
355
  not_in_logs.append(q_norm)
356
  metrics = {
357
  "hit@k_doc": 0,
@@ -362,30 +343,27 @@ def main():
362
  "precision@k_page": np.nan,
363
  "recall@k_page": np.nan,
364
  "ndcg@k_page": np.nan,
365
- "n_gold_docs": int(len(set([d for d in gdocs if isinstance(d, str) and d]))),
366
- "n_gold_doc_pages": int(len([
367
- (d, p) for (d, p) in zip(gdocs, gpages)
368
- if isinstance(d, str) and d and pd.notna(p)
369
- ])),
370
- "n_pred": 0
 
 
 
 
 
 
 
371
  }
372
- perq_rows.append({
373
- "question": q_norm,
374
- "covered_in_logs": 0,
375
- **metrics
376
- })
377
  continue
378
 
379
- # Use aggregated hits from read_logs
380
  hits = row.iloc[0]["hits"] or []
381
- metrics = compute_metrics_for_question(gdocs, gpages, hits, args.k)
382
- perq_rows.append({
383
- "question": q_norm,
384
- "covered_in_logs": 1,
385
- **metrics
386
- })
387
-
388
- # Any log questions not in gold
389
  gold_qs = set(gdict.keys())
390
  for qn in logs["q_norm"].tolist():
391
  if qn not in gold_qs:
@@ -399,15 +377,23 @@ def main():
399
  "questions_covered_in_logs": int(covered.shape[0]),
400
  "questions_missing_in_logs": int(len(not_in_logs)),
401
  "questions_in_logs_not_in_gold": int(len(set(not_in_gold))),
402
- "k": int(args.k),
403
  "mean_hit@k_doc": float(covered["hit@k_doc"].mean()) if not covered.empty else 0.0,
404
  "mean_precision@k_doc": float(covered["precision@k_doc"].mean()) if not covered.empty else 0.0,
405
  "mean_recall@k_doc": float(covered["recall@k_doc"].mean()) if not covered.empty else 0.0,
406
  "mean_ndcg@k_doc": float(covered["ndcg@k_doc"].mean()) if not covered.empty else 0.0,
407
- "mean_hit@k_page": float(covered["hit@k_page"].dropna().mean()) if covered["hit@k_page"].notna().any() else None,
408
- "mean_precision@k_page": float(covered["precision@k_page"].dropna().mean()) if covered["precision@k_page"].notna().any() else None,
409
- "mean_recall@k_page": float(covered["recall@k_page"].dropna().mean()) if covered["recall@k_page"].notna().any() else None,
410
- "mean_ndcg@k_page": float(covered["ndcg@k_page"].dropna().mean()) if covered["ndcg@k_page"].notna().any() else None,
 
 
 
 
 
 
 
 
411
  "avg_gold_docs_per_q": float(perq["n_gold_docs"].mean()) if not perq.empty else 0.0,
412
  "avg_preds_per_q": float(perq["n_pred"].mean()) if not perq.empty else 0.0,
413
  "examples_missing_in_logs": list(not_in_logs[:10]),
@@ -415,13 +401,12 @@ def main():
415
  }
416
 
417
  perq_path = out_dir / "metrics_per_question.csv"
418
- agg_path = out_dir / "metrics_aggregate.json"
419
 
420
  perq.to_csv(perq_path, index=False)
421
  with open(agg_path, "w", encoding="utf-8") as f:
422
  json.dump(agg, f, ensure_ascii=False, indent=2)
423
 
424
- # === Console summary with color ===
425
  print(f"{COLOR_TITLE}RAG Evaluation Summary{COLOR_RESET}")
426
  print(f"{COLOR_TITLE}----------------------{COLOR_RESET}")
427
  print(f"{COLOR_TEXT}Gold questions: {COLOR_ACCENT}{agg['questions_total_gold']}{COLOR_RESET}")
@@ -438,7 +423,7 @@ def main():
438
  f"nDCG@k={_fmt(agg['mean_ndcg@k_doc'])}{COLOR_RESET}"
439
  )
440
 
441
- if agg['mean_hit@k_page'] is not None:
442
  print(
443
  f"{COLOR_TEXT}Page-level:{COLOR_RESET} "
444
  f"{COLOR_ACCENT}Hit@k={_fmt(agg['mean_hit@k_page'])} "
@@ -453,6 +438,48 @@ def main():
453
  print(f"{COLOR_TEXT}Wrote per-question CSV β†’ {COLOR_ACCENT}{perq_path}{COLOR_RESET}")
454
  print(f"{COLOR_TEXT}Wrote aggregate JSON β†’ {COLOR_ACCENT}{agg_path}{COLOR_RESET}")
455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
456
 
457
- if __name__ == "__main__":
458
- main()
 
 
1
+ # ===================== rag_eval_metrics.py (notebook version) =====================
 
 
 
 
 
 
 
2
import json
import re
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
9
 
10
# ---- small CLI-like helpers (you can wrap in main() if you want) ----
# ANSI escape sequences used to color the console summary output.
COLOR_TITLE = "\033[94m"   # bright blue: section titles
COLOR_TEXT = "\033[34m"    # dark blue: body text
COLOR_ACCENT = "\033[36m"  # cyan: highlighted metric values
COLOR_RESET = "\033[0m"    # reset all terminal attributes
15
 
16
def filename_key(s: str) -> str:
    """Normalize a document path to a case-insensitive bare filename.

    Backslashes are treated as path separators, only the final path
    component is kept, and the result is casefolded so comparisons
    ignore case. ``None`` or empty input yields an empty string.
    """
    cleaned = (s or "").strip()
    unified = cleaned.replace("\\", "/")
    basename = unified.rsplit("/", 1)[-1]
    return basename.casefold()
19
 
 
20
def re_split_sc(s: str) -> List[str]:
    """Split *s* on semicolons and commas.

    Delegates to ``re.split`` with the character class ``[;,]``.
    Per ``re.split`` semantics, an empty input yields ``[""]`` and
    adjacent delimiters produce empty fields; callers strip parts
    themselves.
    """
    # `re` is imported at module level; the previous function-scope
    # import re-executed the import machinery on every call.
    return re.split(r"[;,]", s)
23
 
 
24
  def _pick_last_non_empty(hit_lists) -> List[dict]:
 
 
 
 
 
 
 
25
  try:
26
  values = list(hit_lists.tolist())
27
  except AttributeError:
28
  values = list(hit_lists)
 
 
29
  for lst in reversed(values):
30
  if isinstance(lst, (list, tuple)) and len(lst) > 0:
31
  return lst
 
 
32
  return []
33
 
34
+ # ----------------------------- IO: logs ----------------------------- #
 
 
35
  def read_logs(jsonl_path: Path) -> pd.DataFrame:
36
  rows = []
37
  if (not jsonl_path.exists()) or jsonl_path.stat().st_size == 0:
 
47
  except Exception:
48
  continue
49
 
50
+ inputs = (rec.get("inputs") or {})
51
+ q = (inputs.get("question") or "").strip()
52
+
53
+ w_tfidf = float(inputs.get("w_tfidf", 0.0))
54
+ w_bm25 = float(inputs.get("w_bm25", 0.0))
55
+ w_emb = float(inputs.get("w_emb", 0.0))
56
+ top_k = int(inputs.get("top_k", rec.get("k", 8)))
57
+ cfg_id = inputs.get("config_id") or ""
58
 
 
59
  retr = (rec.get("retrieval") or {})
60
  hits = retr.get("hits", [])
61
+
62
  norm_hits = []
63
  for h in hits or []:
64
  doc = (h.get("doc") or "").strip()
65
  page = str(h.get("page") or "").strip()
 
 
66
  try:
67
  page_int = int(page)
68
  except Exception:
69
  page_int = None
 
70
  norm_hits.append({"doc": doc, "page": page_int})
71
 
72
+ rows.append(
73
+ {
74
+ "question": q,
75
+ "hits": norm_hits,
76
+ "w_tfidf": w_tfidf,
77
+ "w_bm25": w_bm25,
78
+ "w_emb": w_emb,
79
+ "top_k": top_k,
80
+ "config_id": cfg_id,
81
+ }
82
+ )
83
 
84
  df = pd.DataFrame(rows)
85
  if df.empty:
86
  return pd.DataFrame(columns=["question", "hits"])
87
 
88
+ # group by normalized question + weights + config
89
  df = (
90
+ df.groupby(
91
+ [
92
+ df["question"].astype(str).str.casefold().str.strip(),
93
+ df["w_tfidf"].round(3),
94
+ df["w_bm25"].round(3),
95
+ df["w_emb"].round(3),
96
+ df["top_k"],
97
+ df["config_id"],
98
+ ],
99
+ as_index=False,
100
+ )
101
+ .agg(
102
+ {
103
+ "question": "last",
104
+ "hits": _pick_last_non_empty,
105
+ "w_tfidf": "last",
106
+ "w_bm25": "last",
107
+ "w_emb": "last",
108
+ "top_k": "last",
109
+ "config_id": "last",
110
+ }
111
+ )
112
  )
113
  return df
114
 
115
+ # ----------------------------- IO: gold ----------------------------- #
116
  def read_gold(csv_path: Path) -> pd.DataFrame:
117
  df = pd.read_csv(csv_path)
118
  cols = {c.lower().strip(): c for c in df.columns}
119
 
 
120
  q_col = None
121
  for cand in ["question", "query", "q"]:
122
  if cand in cols:
123
  q_col = cols[cand]
124
  break
125
  if q_col is None:
126
+ raise ValueError("Gold CSV must contain a 'question' column.")
127
 
 
128
  rel_list_col = None
129
  for cand in ["relevant_docs", "relevant", "docs"]:
130
  if cand in cols:
131
  rel_list_col = cols[cand]
132
  break
133
 
 
134
  doc_col = None
135
  for cand in ["doc", "document", "file", "doc_name"]:
136
  if cand in cols:
137
  doc_col = cols[cand]
138
  break
139
 
 
140
  page_col = None
141
  for cand in ["page", "page_num", "page_number"]:
142
  if cand in cols:
 
145
 
146
  rows = []
147
 
 
148
  if rel_list_col and doc_col is None:
149
  for _, r in df.iterrows():
150
  q_raw = str(r[q_col]).strip()
151
  q_norm = q_raw.casefold().strip()
 
152
  rel_val = str(r[rel_list_col]) if pd.notna(r[rel_list_col]) else ""
153
  if not rel_val:
154
+ rows.append(
155
+ {
156
+ "question_raw": q_raw,
157
+ "question": q_norm,
158
+ "doc": None,
159
+ "page": np.nan,
160
+ }
161
+ )
162
  continue
 
163
  parts = [p.strip() for p in re_split_sc(rel_val)]
164
  for d in parts:
165
+ rows.append(
166
+ {
167
+ "question_raw": q_raw,
168
+ "question": q_norm,
169
+ "doc": filename_key(d),
170
+ "page": np.nan,
171
+ }
172
+ )
173
  elif doc_col:
174
  for _, r in df.iterrows():
175
  q_raw = str(r[q_col]).strip()
176
  q_norm = q_raw.casefold().strip()
 
177
  d = str(r[doc_col]).strip() if pd.notna(r[doc_col]) else ""
178
  p = r[page_col] if (page_col and pd.notna(r[page_col])) else np.nan
 
179
  try:
180
  p = int(p)
181
  except Exception:
182
  p = np.nan
183
+ rows.append(
184
+ {
185
+ "question_raw": q_raw,
186
+ "question": q_norm,
187
+ "doc": filename_key(d),
188
+ "page": p,
189
+ }
190
+ )
191
  else:
192
+ raise ValueError("Gold CSV must contain either a 'doc' or 'relevant_docs' column.")
193
 
194
  gold = pd.DataFrame(rows)
 
 
195
  gold["has_doc"] = gold["doc"].apply(lambda x: isinstance(x, str) and len(x) > 0)
196
  if gold["has_doc"].any():
197
  gold = gold[gold["has_doc"]].copy()
198
  gold.drop(columns=["has_doc"], inplace=True, errors="ignore")
 
 
199
  gold = gold.drop_duplicates(subset=["question", "doc", "page"])
 
200
  return gold
201
 
202
+ # ----------------------------- metrics ----------------------------- #
 
 
203
  def dcg_at_k(relevances: List[int]) -> float:
204
  dcg = 0.0
205
  for i, rel in enumerate(relevances, start=1):
 
207
  dcg += 1.0 / np.log2(i + 1.0)
208
  return float(dcg)
209
 
 
210
  def ndcg_at_k(relevances: List[int]) -> float:
211
  dcg = dcg_at_k(relevances)
212
  ideal = sorted(relevances, reverse=True)
 
215
  return 0.0
216
  return float(dcg / idcg)
217
 
 
218
  def compute_metrics_for_question(gold_docs, gold_pages, hits, k):
219
  top = hits[:k] if hits else []
220
  pred_docs = [filename_key(h.get("doc", "")) for h in top]
221
  pred_pairs = [(filename_key(h.get("doc", "")), h.get("page", None)) for h in top]
222
 
 
223
  gold_doc_set = set([d for d in gold_docs if isinstance(d, str) and d])
224
 
225
  rel_bin_doc = [1 if d in gold_doc_set else 0 for d in pred_docs]
 
228
  rec_doc = (sum(rel_bin_doc) / max(1, len(gold_doc_set))) if gold_doc_set else 0.0
229
  ndcg_doc = ndcg_at_k(rel_bin_doc)
230
 
 
231
  gold_pairs = set()
232
  for d, p in zip(gold_docs, gold_pages):
233
+ if isinstance(d, str) and d and (p is not None) and not (
234
+ isinstance(p, float) and np.isnan(p)
235
+ ):
236
  try:
237
  p_int = int(p)
238
  except Exception:
 
246
  rel_bin_page.append(0)
247
  else:
248
  rel_bin_page.append(1 if (d, p) in gold_pairs else 0)
 
249
  hitk_page = 1 if any(rel_bin_page) else 0
250
  prec_page = (sum(rel_bin_page) / max(1, len(pred_pairs))) if pred_pairs else 0.0
251
  rec_page = (sum(rel_bin_page) / max(1, len(gold_pairs))) if gold_pairs else 0.0
 
264
  "ndcg@k_page": ndcg_page,
265
  "n_gold_docs": int(len(gold_doc_set)),
266
  "n_gold_doc_pages": int(len(gold_pairs)),
267
+ "n_pred": int(len(pred_docs)),
268
  }
269
 
 
 
 
 
 
 
 
 
 
 
270
  def _fmt(x: Any) -> str:
271
  try:
272
  return f"{float(x):.3f}"
273
  except Exception:
274
  return "-"
275
 
276
+ # ----------------------------- main evaluation ----------------------------- #
277
+ def evaluate_rag(
278
+ gold_csv: str,
279
+ logs_jsonl: str,
280
+ k: int = 8,
281
+ out_dir: str = "rag_artifacts",
282
+ group_by_weights: bool = True,
283
+ ):
284
+ out_dir = Path(out_dir)
 
285
  out_dir.mkdir(parents=True, exist_ok=True)
286
 
287
+ gold_path = Path(gold_csv)
288
+ logs_path = Path(logs_jsonl)
289
 
290
  if not gold_path.exists():
291
+ print(f"{COLOR_TEXT}❌ gold.csv not found at {gold_path}{COLOR_RESET}")
292
+ return
293
  if not logs_path.exists() or logs_path.stat().st_size == 0:
294
+ print(f"{COLOR_TEXT}❌ logs JSONL not found or empty at {logs_path}{COLOR_RESET}")
295
+ return
296
 
 
297
  try:
298
  gold = read_gold(gold_path)
299
  except Exception as e:
300
+ print(f"{COLOR_TEXT}❌ Failed to read gold: {e}{COLOR_RESET}")
301
+ return
302
 
 
303
  try:
304
  logs = read_logs(logs_path)
305
  except Exception as e:
306
+ print(f"{COLOR_TEXT}❌ Failed to read logs: {e}{COLOR_RESET}")
307
+ return
308
 
309
  if gold.empty:
310
+ print(f"{COLOR_TEXT}❌ Gold file contains no usable rows.{COLOR_RESET}")
311
+ return
312
  if logs.empty:
313
+ print(f"{COLOR_TEXT}❌ Logs file contains no usable entries.{COLOR_RESET}")
314
+ return
315
 
316
+ # build gold dict
317
  gdict: Dict[str, List[Tuple[str, Optional[int]]]] = {}
318
  for _, r in gold.iterrows():
319
  q = str(r["question"]).strip()
 
321
  p = r["page"] if "page" in r else np.nan
322
  gdict.setdefault(q, []).append((d, p))
323
 
324
+ # normalize questions in logs
325
  logs["q_norm"] = logs["question"].astype(str).str.casefold().str.strip()
326
 
327
  perq_rows = []
328
  not_in_logs, not_in_gold = [], []
329
 
 
330
  for q_norm, pairs in gdict.items():
331
  row = logs[logs["q_norm"] == q_norm]
332
  gdocs = [d for (d, _) in pairs]
333
  gpages = [p for (_, p) in pairs]
334
 
335
  if row.empty:
 
336
  not_in_logs.append(q_norm)
337
  metrics = {
338
  "hit@k_doc": 0,
 
343
  "precision@k_page": np.nan,
344
  "recall@k_page": np.nan,
345
  "ndcg@k_page": np.nan,
346
+ "n_gold_docs": int(
347
+ len(set([d for d in gdocs if isinstance(d, str) and d]))
348
+ ),
349
+ "n_gold_doc_pages": int(
350
+ len(
351
+ [
352
+ (d, p)
353
+ for (d, p) in zip(gdocs, gpages)
354
+ if isinstance(d, str) and d and pd.notna(p)
355
+ ]
356
+ )
357
+ ),
358
+ "n_pred": 0,
359
  }
360
+ perq_rows.append({"question": q_norm, "covered_in_logs": 0, **metrics})
 
 
 
 
361
  continue
362
 
 
363
  hits = row.iloc[0]["hits"] or []
364
+ metrics = compute_metrics_for_question(gdocs, gpages, hits, k)
365
+ perq_rows.append({"question": q_norm, "covered_in_logs": 1, **metrics})
366
+
 
 
 
 
 
367
  gold_qs = set(gdict.keys())
368
  for qn in logs["q_norm"].tolist():
369
  if qn not in gold_qs:
 
377
  "questions_covered_in_logs": int(covered.shape[0]),
378
  "questions_missing_in_logs": int(len(not_in_logs)),
379
  "questions_in_logs_not_in_gold": int(len(set(not_in_gold))),
380
+ "k": int(k),
381
  "mean_hit@k_doc": float(covered["hit@k_doc"].mean()) if not covered.empty else 0.0,
382
  "mean_precision@k_doc": float(covered["precision@k_doc"].mean()) if not covered.empty else 0.0,
383
  "mean_recall@k_doc": float(covered["recall@k_doc"].mean()) if not covered.empty else 0.0,
384
  "mean_ndcg@k_doc": float(covered["ndcg@k_doc"].mean()) if not covered.empty else 0.0,
385
+ "mean_hit@k_page": float(covered["hit@k_page"].dropna().mean())
386
+ if covered["hit@k_page"].notna().any()
387
+ else None,
388
+ "mean_precision@k_page": float(covered["precision@k_page"].dropna().mean())
389
+ if covered["precision@k_page"].notna().any()
390
+ else None,
391
+ "mean_recall@k_page": float(covered["recall@k_page"].dropna().mean())
392
+ if covered["recall@k_page"].notna().any()
393
+ else None,
394
+ "mean_ndcg@k_page": float(covered["ndcg@k_page"].dropna().mean())
395
+ if covered["ndcg@k_page"].notna().any()
396
+ else None,
397
  "avg_gold_docs_per_q": float(perq["n_gold_docs"].mean()) if not perq.empty else 0.0,
398
  "avg_preds_per_q": float(perq["n_pred"].mean()) if not perq.empty else 0.0,
399
  "examples_missing_in_logs": list(not_in_logs[:10]),
 
401
  }
402
 
403
  perq_path = out_dir / "metrics_per_question.csv"
404
+ agg_path = out_dir / "metrics_aggregate.json"
405
 
406
  perq.to_csv(perq_path, index=False)
407
  with open(agg_path, "w", encoding="utf-8") as f:
408
  json.dump(agg, f, ensure_ascii=False, indent=2)
409
 
 
410
  print(f"{COLOR_TITLE}RAG Evaluation Summary{COLOR_RESET}")
411
  print(f"{COLOR_TITLE}----------------------{COLOR_RESET}")
412
  print(f"{COLOR_TEXT}Gold questions: {COLOR_ACCENT}{agg['questions_total_gold']}{COLOR_RESET}")
 
423
  f"nDCG@k={_fmt(agg['mean_ndcg@k_doc'])}{COLOR_RESET}"
424
  )
425
 
426
+ if agg["mean_hit@k_page"] is not None:
427
  print(
428
  f"{COLOR_TEXT}Page-level:{COLOR_RESET} "
429
  f"{COLOR_ACCENT}Hit@k={_fmt(agg['mean_hit@k_page'])} "
 
438
  print(f"{COLOR_TEXT}Wrote per-question CSV β†’ {COLOR_ACCENT}{perq_path}{COLOR_RESET}")
439
  print(f"{COLOR_TEXT}Wrote aggregate JSON β†’ {COLOR_ACCENT}{agg_path}{COLOR_RESET}")
440
 
441
+ # --------- NEW: configuration-wise surface metrics ---------
442
+ if group_by_weights:
443
+ logs_short = logs[
444
+ ["q_norm", "w_tfidf", "w_bm25", "w_emb", "top_k", "config_id"]
445
+ ].drop_duplicates()
446
+
447
+ perq_for_merge = perq.copy()
448
+ perq_for_merge["q_norm"] = perq_for_merge["question"]
449
+ perq_for_merge = perq_for_merge.merge(
450
+ logs_short, on="q_norm", how="left"
451
+ )
452
+ perq_for_merge.drop(columns=["q_norm"], inplace=True, errors="ignore")
453
+
454
+ gb_cols = ["w_tfidf", "w_bm25", "w_emb", "top_k", "config_id"]
455
+ surf_rows = []
456
+ for key, grp in perq_for_merge.groupby(gb_cols):
457
+ wt, wb, we, tk, cid = key
458
+ cov = grp[grp["covered_in_logs"] == 1]
459
+ if cov.empty:
460
+ continue
461
+ surf_rows.append(
462
+ {
463
+ "w_tfidf": float(wt),
464
+ "w_bm25": float(wb),
465
+ "w_emb": float(we),
466
+ "top_k": int(tk),
467
+ "config_id": cid,
468
+ "mean_hit@k_doc": float(cov["hit@k_doc"].mean()),
469
+ "mean_precision@k_doc": float(cov["precision@k_doc"].mean()),
470
+ "mean_recall@k_doc": float(cov["recall@k_doc"].mean()),
471
+ "mean_ndcg@k_doc": float(cov["ndcg@k_doc"].mean()),
472
+ }
473
+ )
474
+
475
+ surf_df = pd.DataFrame(surf_rows)
476
+ surf_path = out_dir / "metrics_by_weights.csv"
477
+ surf_df.to_csv(surf_path, index=False)
478
+ print(
479
+ f"{COLOR_TEXT}Wrote config surface metrics β†’ "
480
+ f"{COLOR_ACCENT}{surf_path}{COLOR_RESET}"
481
+ )
482
 
483
# Example call from a notebook:
# evaluate_rag("gold.csv", "rag_artifacts/rag_logs.jsonl", k=8, out_dir="rag_artifacts")
# NOTE: the original literal contained mojibake ("βœ…" — UTF-8 "✅"
# decoded with the wrong codec); restored to the intended check mark.
print("✅ rag_eval_metrics helpers loaded.")