Inframat-x committed on
Commit
eac1d4e
·
verified ·
1 Parent(s): 810d884

Update rag_eval_metrics.py

Browse files
Files changed (1) hide show
  1. rag_eval_metrics.py +157 -190
rag_eval_metrics.py CHANGED
@@ -1,36 +1,57 @@
1
- # ===================== rag_eval_metrics.py =====================
 
 
 
 
 
 
 
2
  import json
 
 
3
  from pathlib import Path
4
  from typing import Dict, List, Tuple, Any, Optional
5
 
6
- import numpy as np
7
  import pandas as pd
 
8
 
9
- # ---- small CLI-like helpers (colors are just for console logs) ----
10
- COLOR_TITLE = "\033[94m"
11
- COLOR_TEXT = "\033[34m"
12
- COLOR_ACCENT = "\033[36m"
13
- COLOR_RESET = "\033[0m"
14
 
15
  def filename_key(s: str) -> str:
16
  s = (s or "").strip().replace("\\", "/").split("/")[-1]
17
  return s.casefold()
18
 
 
19
  def re_split_sc(s: str) -> List[str]:
20
  import re
21
  return re.split(r"[;,]", s)
22
 
 
23
  def _pick_last_non_empty(hit_lists) -> List[dict]:
 
 
 
 
 
 
 
24
  try:
25
  values = list(hit_lists.tolist())
26
  except AttributeError:
27
  values = list(hit_lists)
 
 
28
  for lst in reversed(values):
29
  if isinstance(lst, (list, tuple)) and len(lst) > 0:
30
  return lst
 
 
31
  return []
32
 
33
- # ----------------------------- IO: logs ----------------------------- #
 
 
34
  def read_logs(jsonl_path: Path) -> pd.DataFrame:
35
  rows = []
36
  if (not jsonl_path.exists()) or jsonl_path.stat().st_size == 0:
@@ -46,96 +67,67 @@ def read_logs(jsonl_path: Path) -> pd.DataFrame:
46
  except Exception:
47
  continue
48
 
49
- inputs = (rec.get("inputs") or {})
50
- q = (inputs.get("question") or "").strip()
51
-
52
- w_tfidf = float(inputs.get("w_tfidf", 0.0))
53
- w_bm25 = float(inputs.get("w_bm25", 0.0))
54
- w_emb = float(inputs.get("w_emb", 0.0))
55
- top_k = int(inputs.get("top_k", rec.get("k", 8)))
56
- cfg_id = inputs.get("config_id") or ""
57
 
 
58
  retr = (rec.get("retrieval") or {})
59
  hits = retr.get("hits", [])
60
-
61
  norm_hits = []
62
  for h in hits or []:
63
  doc = (h.get("doc") or "").strip()
64
  page = str(h.get("page") or "").strip()
 
 
65
  try:
66
  page_int = int(page)
67
  except Exception:
68
  page_int = None
 
69
  norm_hits.append({"doc": doc, "page": page_int})
70
 
71
- rows.append(
72
- {
73
- "question": q,
74
- "hits": norm_hits,
75
- "w_tfidf": w_tfidf,
76
- "w_bm25": w_bm25,
77
- "w_emb": w_emb,
78
- "top_k": top_k,
79
- "config_id": cfg_id,
80
- }
81
- )
82
 
83
  df = pd.DataFrame(rows)
84
  if df.empty:
85
  return pd.DataFrame(columns=["question", "hits"])
86
 
87
- # group by normalized question + weights + config
88
  df = (
89
- df.groupby(
90
- [
91
- df["question"].astype(str).str.casefold().str.strip(),
92
- df["w_tfidf"].round(3),
93
- df["w_bm25"].round(3),
94
- df["w_emb"].round(3),
95
- df["top_k"],
96
- df["config_id"],
97
- ],
98
- as_index=False,
99
- )
100
- .agg(
101
- {
102
- "question": "last",
103
- "hits": _pick_last_non_empty,
104
- "w_tfidf": "last",
105
- "w_bm25": "last",
106
- "w_emb": "last",
107
- "top_k": "last",
108
- "config_id": "last",
109
- }
110
- )
111
  )
112
  return df
113
 
114
- # ----------------------------- IO: gold ----------------------------- #
115
  def read_gold(csv_path: Path) -> pd.DataFrame:
116
  df = pd.read_csv(csv_path)
117
  cols = {c.lower().strip(): c for c in df.columns}
118
 
 
119
  q_col = None
120
  for cand in ["question", "query", "q"]:
121
  if cand in cols:
122
  q_col = cols[cand]
123
  break
124
  if q_col is None:
125
- raise ValueError("Gold CSV must contain a 'question' column.")
126
 
 
127
  rel_list_col = None
128
  for cand in ["relevant_docs", "relevant", "docs"]:
129
  if cand in cols:
130
  rel_list_col = cols[cand]
131
  break
132
 
 
133
  doc_col = None
134
  for cand in ["doc", "document", "file", "doc_name"]:
135
  if cand in cols:
136
  doc_col = cols[cand]
137
  break
138
 
 
139
  page_col = None
140
  for cand in ["page", "page_num", "page_number"]:
141
  if cand in cols:
@@ -144,61 +136,71 @@ def read_gold(csv_path: Path) -> pd.DataFrame:
144
 
145
  rows = []
146
 
 
147
  if rel_list_col and doc_col is None:
148
  for _, r in df.iterrows():
149
  q_raw = str(r[q_col]).strip()
150
  q_norm = q_raw.casefold().strip()
 
151
  rel_val = str(r[rel_list_col]) if pd.notna(r[rel_list_col]) else ""
152
  if not rel_val:
153
- rows.append(
154
- {
155
- "question_raw": q_raw,
156
- "question": q_norm,
157
- "doc": None,
158
- "page": np.nan,
159
- }
160
- )
161
  continue
 
162
  parts = [p.strip() for p in re_split_sc(rel_val)]
163
  for d in parts:
164
- rows.append(
165
- {
166
- "question_raw": q_raw,
167
- "question": q_norm,
168
- "doc": filename_key(d),
169
- "page": np.nan,
170
- }
171
- )
172
  elif doc_col:
173
  for _, r in df.iterrows():
174
  q_raw = str(r[q_col]).strip()
175
  q_norm = q_raw.casefold().strip()
 
176
  d = str(r[doc_col]).strip() if pd.notna(r[doc_col]) else ""
177
  p = r[page_col] if (page_col and pd.notna(r[page_col])) else np.nan
 
178
  try:
179
  p = int(p)
180
  except Exception:
181
  p = np.nan
182
- rows.append(
183
- {
184
- "question_raw": q_raw,
185
- "question": q_norm,
186
- "doc": filename_key(d),
187
- "page": p,
188
- }
189
- )
190
  else:
191
- raise ValueError("Gold CSV must contain either a 'doc' or 'relevant_docs' column.")
192
 
193
  gold = pd.DataFrame(rows)
 
 
194
  gold["has_doc"] = gold["doc"].apply(lambda x: isinstance(x, str) and len(x) > 0)
195
  if gold["has_doc"].any():
196
  gold = gold[gold["has_doc"]].copy()
197
  gold.drop(columns=["has_doc"], inplace=True, errors="ignore")
 
 
198
  gold = gold.drop_duplicates(subset=["question", "doc", "page"])
 
199
  return gold
200
 
201
- # ----------------------------- metrics ----------------------------- #
 
 
202
  def dcg_at_k(relevances: List[int]) -> float:
203
  dcg = 0.0
204
  for i, rel in enumerate(relevances, start=1):
@@ -206,6 +208,7 @@ def dcg_at_k(relevances: List[int]) -> float:
206
  dcg += 1.0 / np.log2(i + 1.0)
207
  return float(dcg)
208
 
 
209
  def ndcg_at_k(relevances: List[int]) -> float:
210
  dcg = dcg_at_k(relevances)
211
  ideal = sorted(relevances, reverse=True)
@@ -214,11 +217,13 @@ def ndcg_at_k(relevances: List[int]) -> float:
214
  return 0.0
215
  return float(dcg / idcg)
216
 
 
217
  def compute_metrics_for_question(gold_docs, gold_pages, hits, k):
218
  top = hits[:k] if hits else []
219
  pred_docs = [filename_key(h.get("doc", "")) for h in top]
220
  pred_pairs = [(filename_key(h.get("doc", "")), h.get("page", None)) for h in top]
221
 
 
222
  gold_doc_set = set([d for d in gold_docs if isinstance(d, str) and d])
223
 
224
  rel_bin_doc = [1 if d in gold_doc_set else 0 for d in pred_docs]
@@ -227,11 +232,10 @@ def compute_metrics_for_question(gold_docs, gold_pages, hits, k):
227
  rec_doc = (sum(rel_bin_doc) / max(1, len(gold_doc_set))) if gold_doc_set else 0.0
228
  ndcg_doc = ndcg_at_k(rel_bin_doc)
229
 
 
230
  gold_pairs = set()
231
  for d, p in zip(gold_docs, gold_pages):
232
- if isinstance(d, str) and d and (p is not None) and not (
233
- isinstance(p, float) and np.isnan(p)
234
- ):
235
  try:
236
  p_int = int(p)
237
  except Exception:
@@ -245,6 +249,7 @@ def compute_metrics_for_question(gold_docs, gold_pages, hits, k):
245
  rel_bin_page.append(0)
246
  else:
247
  rel_bin_page.append(1 if (d, p) in gold_pairs else 0)
 
248
  hitk_page = 1 if any(rel_bin_page) else 0
249
  prec_page = (sum(rel_bin_page) / max(1, len(pred_pairs))) if pred_pairs else 0.0
250
  rec_page = (sum(rel_bin_page) / max(1, len(gold_pairs))) if gold_pairs else 0.0
@@ -263,65 +268,69 @@ def compute_metrics_for_question(gold_docs, gold_pages, hits, k):
263
  "ndcg@k_page": ndcg_page,
264
  "n_gold_docs": int(len(gold_doc_set)),
265
  "n_gold_doc_pages": int(len(gold_pairs)),
266
- "n_pred": int(len(pred_docs)),
267
  }
268
 
 
 
 
 
 
 
 
 
 
 
269
  def _fmt(x: Any) -> str:
270
  try:
271
  return f"{float(x):.3f}"
272
  except Exception:
273
  return "-"
274
 
275
- # ----------------------------- main evaluation ----------------------------- #
276
- def evaluate_rag(
277
- gold_csv: str,
278
- logs_jsonl: str,
279
- k: int = 8,
280
- out_dir: str = "rag_artifacts",
281
- group_by_weights: bool = True,
282
- ):
283
- """
284
- Main entry point used by app.py.
285
 
286
- - gold_csv: path to gold CSV
287
- - logs_jsonl: path to rag_logs.jsonl
288
- - k: cutoff (top-k)
289
- - out_dir: directory to write metrics files
290
- - group_by_weights: if True, also write metrics_by_weights.csv
291
- """
292
- out_dir = Path(out_dir)
 
 
293
  out_dir.mkdir(parents=True, exist_ok=True)
294
 
295
- gold_path = Path(gold_csv)
296
- logs_path = Path(logs_jsonl)
297
 
298
  if not gold_path.exists():
299
- print(f"{COLOR_TEXT}❌ gold.csv not found at {gold_path}{COLOR_RESET}")
300
- return
301
  if not logs_path.exists() or logs_path.stat().st_size == 0:
302
- print(f"{COLOR_TEXT}❌ logs JSONL not found or empty at {logs_path}{COLOR_RESET}")
303
- return
304
 
 
305
  try:
306
  gold = read_gold(gold_path)
307
  except Exception as e:
308
- print(f"{COLOR_TEXT}❌ Failed to read gold: {e}{COLOR_RESET}")
309
- return
310
 
 
311
  try:
312
  logs = read_logs(logs_path)
313
  except Exception as e:
314
- print(f"{COLOR_TEXT}❌ Failed to read logs: {e}{COLOR_RESET}")
315
- return
316
 
317
  if gold.empty:
318
- print(f"{COLOR_TEXT}❌ Gold file contains no usable rows.{COLOR_RESET}")
319
- return
320
  if logs.empty:
321
- print(f"{COLOR_TEXT}❌ Logs file contains no usable entries.{COLOR_RESET}")
322
- return
323
 
324
- # build gold dict
325
  gdict: Dict[str, List[Tuple[str, Optional[int]]]] = {}
326
  for _, r in gold.iterrows():
327
  q = str(r["question"]).strip()
@@ -329,18 +338,20 @@ def evaluate_rag(
329
  p = r["page"] if "page" in r else np.nan
330
  gdict.setdefault(q, []).append((d, p))
331
 
332
- # normalize questions in logs
333
  logs["q_norm"] = logs["question"].astype(str).str.casefold().str.strip()
334
 
335
  perq_rows = []
336
  not_in_logs, not_in_gold = [], []
337
 
 
338
  for q_norm, pairs in gdict.items():
339
  row = logs[logs["q_norm"] == q_norm]
340
  gdocs = [d for (d, _) in pairs]
341
  gpages = [p for (_, p) in pairs]
342
 
343
  if row.empty:
 
344
  not_in_logs.append(q_norm)
345
  metrics = {
346
  "hit@k_doc": 0,
@@ -351,27 +362,30 @@ def evaluate_rag(
351
  "precision@k_page": np.nan,
352
  "recall@k_page": np.nan,
353
  "ndcg@k_page": np.nan,
354
- "n_gold_docs": int(
355
- len(set([d for d in gdocs if isinstance(d, str) and d]))
356
- ),
357
- "n_gold_doc_pages": int(
358
- len(
359
- [
360
- (d, p)
361
- for (d, p) in zip(gdocs, gpages)
362
- if isinstance(d, str) and d and pd.notna(p)
363
- ]
364
- )
365
- ),
366
- "n_pred": 0,
367
  }
368
- perq_rows.append({"question": q_norm, "covered_in_logs": 0, **metrics})
 
 
 
 
369
  continue
370
 
 
371
  hits = row.iloc[0]["hits"] or []
372
- metrics = compute_metrics_for_question(gdocs, gpages, hits, k)
373
- perq_rows.append({"question": q_norm, "covered_in_logs": 1, **metrics})
374
-
 
 
 
 
 
375
  gold_qs = set(gdict.keys())
376
  for qn in logs["q_norm"].tolist():
377
  if qn not in gold_qs:
@@ -385,23 +399,15 @@ def evaluate_rag(
385
  "questions_covered_in_logs": int(covered.shape[0]),
386
  "questions_missing_in_logs": int(len(not_in_logs)),
387
  "questions_in_logs_not_in_gold": int(len(set(not_in_gold))),
388
- "k": int(k),
389
  "mean_hit@k_doc": float(covered["hit@k_doc"].mean()) if not covered.empty else 0.0,
390
  "mean_precision@k_doc": float(covered["precision@k_doc"].mean()) if not covered.empty else 0.0,
391
  "mean_recall@k_doc": float(covered["recall@k_doc"].mean()) if not covered.empty else 0.0,
392
  "mean_ndcg@k_doc": float(covered["ndcg@k_doc"].mean()) if not covered.empty else 0.0,
393
- "mean_hit@k_page": float(covered["hit@k_page"].dropna().mean())
394
- if covered["hit@k_page"].notna().any()
395
- else None,
396
- "mean_precision@k_page": float(covered["precision@k_page"].dropna().mean())
397
- if covered["precision@k_page"].notna().any()
398
- else None,
399
- "mean_recall@k_page": float(covered["recall@k_page"].dropna().mean())
400
- if covered["recall@k_page"].notna().any()
401
- else None,
402
- "mean_ndcg@k_page": float(covered["ndcg@k_page"].dropna().mean())
403
- if covered["ndcg@k_page"].notna().any()
404
- else None,
405
  "avg_gold_docs_per_q": float(perq["n_gold_docs"].mean()) if not perq.empty else 0.0,
406
  "avg_preds_per_q": float(perq["n_pred"].mean()) if not perq.empty else 0.0,
407
  "examples_missing_in_logs": list(not_in_logs[:10]),
@@ -409,12 +415,13 @@ def evaluate_rag(
409
  }
410
 
411
  perq_path = out_dir / "metrics_per_question.csv"
412
- agg_path = out_dir / "metrics_aggregate.json"
413
 
414
  perq.to_csv(perq_path, index=False)
415
  with open(agg_path, "w", encoding="utf-8") as f:
416
  json.dump(agg, f, ensure_ascii=False, indent=2)
417
 
 
418
  print(f"{COLOR_TITLE}RAG Evaluation Summary{COLOR_RESET}")
419
  print(f"{COLOR_TITLE}----------------------{COLOR_RESET}")
420
  print(f"{COLOR_TEXT}Gold questions: {COLOR_ACCENT}{agg['questions_total_gold']}{COLOR_RESET}")
@@ -431,7 +438,7 @@ def evaluate_rag(
431
  f"nDCG@k={_fmt(agg['mean_ndcg@k_doc'])}{COLOR_RESET}"
432
  )
433
 
434
- if agg["mean_hit@k_page"] is not None:
435
  print(
436
  f"{COLOR_TEXT}Page-level:{COLOR_RESET} "
437
  f"{COLOR_ACCENT}Hit@k={_fmt(agg['mean_hit@k_page'])} "
@@ -446,46 +453,6 @@ def evaluate_rag(
446
  print(f"{COLOR_TEXT}Wrote per-question CSV β†’ {COLOR_ACCENT}{perq_path}{COLOR_RESET}")
447
  print(f"{COLOR_TEXT}Wrote aggregate JSON β†’ {COLOR_ACCENT}{agg_path}{COLOR_RESET}")
448
 
449
- # --------- optional: configuration-wise metrics by weights ---------
450
- if group_by_weights:
451
- logs_short = logs[
452
- ["q_norm", "w_tfidf", "w_bm25", "w_emb", "top_k", "config_id"]
453
- ].drop_duplicates()
454
-
455
- perq_for_merge = perq.copy()
456
- perq_for_merge["q_norm"] = perq_for_merge["question"]
457
- perq_for_merge = perq_for_merge.merge(
458
- logs_short, on="q_norm", how="left"
459
- )
460
- perq_for_merge.drop(columns=["q_norm"], inplace=True, errors="ignore")
461
-
462
- gb_cols = ["w_tfidf", "w_bm25", "w_emb", "top_k", "config_id"]
463
- surf_rows = []
464
- for key, grp in perq_for_merge.groupby(gb_cols):
465
- wt, wb, we, tk, cid = key
466
- cov = grp[grp["covered_in_logs"] == 1]
467
- if cov.empty:
468
- continue
469
- surf_rows.append(
470
- {
471
- "w_tfidf": float(wt),
472
- "w_bm25": float(wb),
473
- "w_emb": float(we),
474
- "top_k": int(tk),
475
- "config_id": cid,
476
- "mean_hit@k_doc": float(cov["hit@k_doc"].mean()),
477
- "mean_precision@k_doc": float(cov["precision@k_doc"].mean()),
478
- "mean_recall@k_doc": float(cov["recall@k_doc"].mean()),
479
- "mean_ndcg@k_doc": float(cov["ndcg@k_doc"].mean()),
480
- }
481
- )
482
-
483
- surf_df = pd.DataFrame(surf_rows)
484
- surf_path = out_dir / "metrics_by_weights.csv"
485
- surf_df.to_csv(surf_path, index=False)
486
- print(
487
- f"{COLOR_TEXT}Wrote config surface metrics β†’ "
488
- f"{COLOR_ACCENT}{surf_path}{COLOR_RESET}"
489
- )
490
 
491
- print("βœ… rag_eval_metrics helpers loaded.")
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ rag_eval_metrics.py
4
+
5
+ Evaluate RAG retrieval quality by comparing app logs (JSONL) with a gold file (CSV).
6
+ """
7
+
8
+ import argparse
9
  import json
10
+ import os
11
+ import sys
12
  from pathlib import Path
13
  from typing import Dict, List, Tuple, Any, Optional
14
 
 
15
  import pandas as pd
16
+ import numpy as np
17
 
18
+
19
+ # ----------------------------- Small Utils ----------------------------- #
 
 
 
20
 
21
def filename_key(s: str) -> str:
    """Normalize a path or filename to its case-folded basename for matching."""
    normalized = (s or "").strip().replace("\\", "/")
    basename = normalized.rsplit("/", 1)[-1]
    return basename.casefold()
24
 
25
+
26
def re_split_sc(s: str) -> List[str]:
    """Split *s* on semicolons and commas (both treated as separators)."""
    # Mapping ';' onto ',' first lets a single str.split do the work of
    # the character-class regex split.
    return s.replace(";", ",").split(",")
29
 
30
+
31
  def _pick_last_non_empty(hit_lists) -> List[dict]:
32
+ """
33
+ Robustly select the last non-empty hits list from a pandas Series or iterable.
34
+
35
+ This fixes the KeyError that happens when using reversed() directly on a Series
36
+ with a non-range index.
37
+ """
38
+ # Convert pandas Series or other iterables to a plain Python list
39
  try:
40
  values = list(hit_lists.tolist())
41
  except AttributeError:
42
  values = list(hit_lists)
43
+
44
+ # Walk from last to first, return first non-empty list-like
45
  for lst in reversed(values):
46
  if isinstance(lst, (list, tuple)) and len(lst) > 0:
47
  return lst
48
+
49
+ # If everything was empty / NaN
50
  return []
51
 
52
+
53
+ # ----------------------------- IO Helpers ----------------------------- #
54
+
55
  def read_logs(jsonl_path: Path) -> pd.DataFrame:
56
  rows = []
57
  if (not jsonl_path.exists()) or jsonl_path.stat().st_size == 0:
 
67
  except Exception:
68
  continue
69
 
70
+ # Extract question
71
+ q = (((rec.get("inputs") or {}).get("question")) or "").strip()
 
 
 
 
 
 
72
 
73
+ # Extract retrieval hits (if present)
74
  retr = (rec.get("retrieval") or {})
75
  hits = retr.get("hits", [])
 
76
  norm_hits = []
77
  for h in hits or []:
78
  doc = (h.get("doc") or "").strip()
79
  page = str(h.get("page") or "").strip()
80
+
81
+ # Normalize page to int or None
82
  try:
83
  page_int = int(page)
84
  except Exception:
85
  page_int = None
86
+
87
  norm_hits.append({"doc": doc, "page": page_int})
88
 
89
+ rows.append({"question": q, "hits": norm_hits})
 
 
 
 
 
 
 
 
 
 
90
 
91
  df = pd.DataFrame(rows)
92
  if df.empty:
93
  return pd.DataFrame(columns=["question", "hits"])
94
 
95
+ # Group by normalized question text and keep last non-empty hits list per question
96
  df = (
97
+ df.groupby(df["question"].astype(str).str.casefold().str.strip(), as_index=False)
98
+ .agg({"question": "last", "hits": _pick_last_non_empty})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  )
100
  return df
101
 
102
+
103
  def read_gold(csv_path: Path) -> pd.DataFrame:
104
  df = pd.read_csv(csv_path)
105
  cols = {c.lower().strip(): c for c in df.columns}
106
 
107
+ # --- question column ---
108
  q_col = None
109
  for cand in ["question", "query", "q"]:
110
  if cand in cols:
111
  q_col = cols[cand]
112
  break
113
  if q_col is None:
114
+ raise ValueError("Gold CSV must contain a 'question' column (case-insensitive).")
115
 
116
+ # --- possible relevant_docs (list-in-cell) column ---
117
  rel_list_col = None
118
  for cand in ["relevant_docs", "relevant", "docs"]:
119
  if cand in cols:
120
  rel_list_col = cols[cand]
121
  break
122
 
123
+ # --- single-doc-per-row column ---
124
  doc_col = None
125
  for cand in ["doc", "document", "file", "doc_name"]:
126
  if cand in cols:
127
  doc_col = cols[cand]
128
  break
129
 
130
+ # --- optional page column ---
131
  page_col = None
132
  for cand in ["page", "page_num", "page_number"]:
133
  if cand in cols:
 
136
 
137
  rows = []
138
 
139
+ # Case 1: relevant_docs list column (no explicit doc_col)
140
  if rel_list_col and doc_col is None:
141
  for _, r in df.iterrows():
142
  q_raw = str(r[q_col]).strip()
143
  q_norm = q_raw.casefold().strip()
144
+
145
  rel_val = str(r[rel_list_col]) if pd.notna(r[rel_list_col]) else ""
146
  if not rel_val:
147
+ rows.append({
148
+ "question_raw": q_raw,
149
+ "question": q_norm,
150
+ "doc": None,
151
+ "page": np.nan
152
+ })
 
 
153
  continue
154
+
155
  parts = [p.strip() for p in re_split_sc(rel_val)]
156
  for d in parts:
157
+ rows.append({
158
+ "question_raw": q_raw,
159
+ "question": q_norm,
160
+ "doc": filename_key(d),
161
+ "page": np.nan
162
+ })
163
+
164
+ # Case 2: doc/page columns (one relevant doc per row)
165
  elif doc_col:
166
  for _, r in df.iterrows():
167
  q_raw = str(r[q_col]).strip()
168
  q_norm = q_raw.casefold().strip()
169
+
170
  d = str(r[doc_col]).strip() if pd.notna(r[doc_col]) else ""
171
  p = r[page_col] if (page_col and pd.notna(r[page_col])) else np.nan
172
+
173
  try:
174
  p = int(p)
175
  except Exception:
176
  p = np.nan
177
+
178
+ rows.append({
179
+ "question_raw": q_raw,
180
+ "question": q_norm,
181
+ "doc": filename_key(d),
182
+ "page": p
183
+ })
184
+
185
  else:
186
+ raise ValueError("Gold CSV must contain either a 'doc' column or a 'relevant_docs' column.")
187
 
188
  gold = pd.DataFrame(rows)
189
+
190
+ # Keep only rows with a valid doc (when docs exist)
191
  gold["has_doc"] = gold["doc"].apply(lambda x: isinstance(x, str) and len(x) > 0)
192
  if gold["has_doc"].any():
193
  gold = gold[gold["has_doc"]].copy()
194
  gold.drop(columns=["has_doc"], inplace=True, errors="ignore")
195
+
196
+ # Remove duplicates
197
  gold = gold.drop_duplicates(subset=["question", "doc", "page"])
198
+
199
  return gold
200
 
201
+
202
+ # ----------------------------- Metric Core ----------------------------- #
203
+
204
  def dcg_at_k(relevances: List[int]) -> float:
205
  dcg = 0.0
206
  for i, rel in enumerate(relevances, start=1):
 
208
  dcg += 1.0 / np.log2(i + 1.0)
209
  return float(dcg)
210
 
211
+
212
  def ndcg_at_k(relevances: List[int]) -> float:
213
  dcg = dcg_at_k(relevances)
214
  ideal = sorted(relevances, reverse=True)
 
217
  return 0.0
218
  return float(dcg / idcg)
219
 
220
+
221
  def compute_metrics_for_question(gold_docs, gold_pages, hits, k):
222
  top = hits[:k] if hits else []
223
  pred_docs = [filename_key(h.get("doc", "")) for h in top]
224
  pred_pairs = [(filename_key(h.get("doc", "")), h.get("page", None)) for h in top]
225
 
226
+ # --- Doc-level metrics ---
227
  gold_doc_set = set([d for d in gold_docs if isinstance(d, str) and d])
228
 
229
  rel_bin_doc = [1 if d in gold_doc_set else 0 for d in pred_docs]
 
232
  rec_doc = (sum(rel_bin_doc) / max(1, len(gold_doc_set))) if gold_doc_set else 0.0
233
  ndcg_doc = ndcg_at_k(rel_bin_doc)
234
 
235
+ # --- Page-level metrics (only if gold has page labels) ---
236
  gold_pairs = set()
237
  for d, p in zip(gold_docs, gold_pages):
238
+ if isinstance(d, str) and d and (p is not None) and (not (isinstance(p, float) and np.isnan(p))):
 
 
239
  try:
240
  p_int = int(p)
241
  except Exception:
 
249
  rel_bin_page.append(0)
250
  else:
251
  rel_bin_page.append(1 if (d, p) in gold_pairs else 0)
252
+
253
  hitk_page = 1 if any(rel_bin_page) else 0
254
  prec_page = (sum(rel_bin_page) / max(1, len(pred_pairs))) if pred_pairs else 0.0
255
  rec_page = (sum(rel_bin_page) / max(1, len(gold_pairs))) if gold_pairs else 0.0
 
268
  "ndcg@k_page": ndcg_page,
269
  "n_gold_docs": int(len(gold_doc_set)),
270
  "n_gold_doc_pages": int(len(gold_pairs)),
271
+ "n_pred": int(len(pred_docs))
272
  }
273
 
274
+
275
+ # ----------------------------- Orchestration ----------------------------- #
276
+
277
# === Dark blue and accent colors (ANSI escape codes for console output) ===
COLOR_TITLE = "\033[94m"  # light blue for titles
COLOR_TEXT = "\033[34m"  # dark blue for body text
COLOR_ACCENT = "\033[36m"  # cyan for metric values
COLOR_RESET = "\033[0m"  # reset back to the terminal default
282
+
283
+
284
  def _fmt(x: Any) -> str:
285
  try:
286
  return f"{float(x):.3f}"
287
  except Exception:
288
  return "-"
289
 
 
 
 
 
 
 
 
 
 
 
290
 
291
+ def main():
292
+ ap = argparse.ArgumentParser()
293
+ ap.add_argument("--gold_csv", required=True, type=str)
294
+ ap.add_argument("--logs_jsonl", required=True, type=str)
295
+ ap.add_argument("--k", type=int, default=8)
296
+ ap.add_argument("--out_dir", type=str, default="rag_artifacts")
297
+ args = ap.parse_args()
298
+
299
+ out_dir = Path(args.out_dir)
300
  out_dir.mkdir(parents=True, exist_ok=True)
301
 
302
+ gold_path = Path(args.gold_csv)
303
+ logs_path = Path(args.logs_jsonl)
304
 
305
  if not gold_path.exists():
306
+ print(f"{COLOR_TEXT}❌ gold.csv not found at {gold_path}{COLOR_RESET}", file=sys.stderr)
307
+ sys.exit(0)
308
  if not logs_path.exists() or logs_path.stat().st_size == 0:
309
+ print(f"{COLOR_TEXT}❌ logs JSONL not found or empty at {logs_path}{COLOR_RESET}", file=sys.stderr)
310
+ sys.exit(0)
311
 
312
+ # Read gold
313
  try:
314
  gold = read_gold(gold_path)
315
  except Exception as e:
316
+ print(f"{COLOR_TEXT}❌ Failed to read gold: {e}{COLOR_RESET}", file=sys.stderr)
317
+ sys.exit(0)
318
 
319
+ # Read logs (with robust aggregation)
320
  try:
321
  logs = read_logs(logs_path)
322
  except Exception as e:
323
+ print(f"{COLOR_TEXT}❌ Failed to read logs: {e}{COLOR_RESET}", file=sys.stderr)
324
+ sys.exit(0)
325
 
326
  if gold.empty:
327
+ print(f"{COLOR_TEXT}❌ Gold file contains no usable rows.{COLOR_RESET}", file=sys.stderr)
328
+ sys.exit(0)
329
  if logs.empty:
330
+ print(f"{COLOR_TEXT}❌ Logs file contains no usable entries.{COLOR_RESET}", file=sys.stderr)
331
+ sys.exit(0)
332
 
333
+ # Build gold dict: normalized_question -> list of (doc, page)
334
  gdict: Dict[str, List[Tuple[str, Optional[int]]]] = {}
335
  for _, r in gold.iterrows():
336
  q = str(r["question"]).strip()
 
338
  p = r["page"] if "page" in r else np.nan
339
  gdict.setdefault(q, []).append((d, p))
340
 
341
+ # Normalize log questions for join
342
  logs["q_norm"] = logs["question"].astype(str).str.casefold().str.strip()
343
 
344
  perq_rows = []
345
  not_in_logs, not_in_gold = [], []
346
 
347
+ # For each gold question, compute metrics using logs
348
  for q_norm, pairs in gdict.items():
349
  row = logs[logs["q_norm"] == q_norm]
350
  gdocs = [d for (d, _) in pairs]
351
  gpages = [p for (_, p) in pairs]
352
 
353
  if row.empty:
354
+ # No logs for this gold question β†’ zero retrieval
355
  not_in_logs.append(q_norm)
356
  metrics = {
357
  "hit@k_doc": 0,
 
362
  "precision@k_page": np.nan,
363
  "recall@k_page": np.nan,
364
  "ndcg@k_page": np.nan,
365
+ "n_gold_docs": int(len(set([d for d in gdocs if isinstance(d, str) and d]))),
366
+ "n_gold_doc_pages": int(len([
367
+ (d, p) for (d, p) in zip(gdocs, gpages)
368
+ if isinstance(d, str) and d and pd.notna(p)
369
+ ])),
370
+ "n_pred": 0
 
 
 
 
 
 
 
371
  }
372
+ perq_rows.append({
373
+ "question": q_norm,
374
+ "covered_in_logs": 0,
375
+ **metrics
376
+ })
377
  continue
378
 
379
+ # Use aggregated hits from read_logs
380
  hits = row.iloc[0]["hits"] or []
381
+ metrics = compute_metrics_for_question(gdocs, gpages, hits, args.k)
382
+ perq_rows.append({
383
+ "question": q_norm,
384
+ "covered_in_logs": 1,
385
+ **metrics
386
+ })
387
+
388
+ # Any log questions not in gold
389
  gold_qs = set(gdict.keys())
390
  for qn in logs["q_norm"].tolist():
391
  if qn not in gold_qs:
 
399
  "questions_covered_in_logs": int(covered.shape[0]),
400
  "questions_missing_in_logs": int(len(not_in_logs)),
401
  "questions_in_logs_not_in_gold": int(len(set(not_in_gold))),
402
+ "k": int(args.k),
403
  "mean_hit@k_doc": float(covered["hit@k_doc"].mean()) if not covered.empty else 0.0,
404
  "mean_precision@k_doc": float(covered["precision@k_doc"].mean()) if not covered.empty else 0.0,
405
  "mean_recall@k_doc": float(covered["recall@k_doc"].mean()) if not covered.empty else 0.0,
406
  "mean_ndcg@k_doc": float(covered["ndcg@k_doc"].mean()) if not covered.empty else 0.0,
407
+ "mean_hit@k_page": float(covered["hit@k_page"].dropna().mean()) if covered["hit@k_page"].notna().any() else None,
408
+ "mean_precision@k_page": float(covered["precision@k_page"].dropna().mean()) if covered["precision@k_page"].notna().any() else None,
409
+ "mean_recall@k_page": float(covered["recall@k_page"].dropna().mean()) if covered["recall@k_page"].notna().any() else None,
410
+ "mean_ndcg@k_page": float(covered["ndcg@k_page"].dropna().mean()) if covered["ndcg@k_page"].notna().any() else None,
 
 
 
 
 
 
 
 
411
  "avg_gold_docs_per_q": float(perq["n_gold_docs"].mean()) if not perq.empty else 0.0,
412
  "avg_preds_per_q": float(perq["n_pred"].mean()) if not perq.empty else 0.0,
413
  "examples_missing_in_logs": list(not_in_logs[:10]),
 
415
  }
416
 
417
  perq_path = out_dir / "metrics_per_question.csv"
418
+ agg_path = out_dir / "metrics_aggregate.json"
419
 
420
  perq.to_csv(perq_path, index=False)
421
  with open(agg_path, "w", encoding="utf-8") as f:
422
  json.dump(agg, f, ensure_ascii=False, indent=2)
423
 
424
+ # === Console summary with color ===
425
  print(f"{COLOR_TITLE}RAG Evaluation Summary{COLOR_RESET}")
426
  print(f"{COLOR_TITLE}----------------------{COLOR_RESET}")
427
  print(f"{COLOR_TEXT}Gold questions: {COLOR_ACCENT}{agg['questions_total_gold']}{COLOR_RESET}")
 
438
  f"nDCG@k={_fmt(agg['mean_ndcg@k_doc'])}{COLOR_RESET}"
439
  )
440
 
441
+ if agg['mean_hit@k_page'] is not None:
442
  print(
443
  f"{COLOR_TEXT}Page-level:{COLOR_RESET} "
444
  f"{COLOR_ACCENT}Hit@k={_fmt(agg['mean_hit@k_page'])} "
 
453
  print(f"{COLOR_TEXT}Wrote per-question CSV β†’ {COLOR_ACCENT}{perq_path}{COLOR_RESET}")
454
  print(f"{COLOR_TEXT}Wrote aggregate JSON β†’ {COLOR_ACCENT}{agg_path}{COLOR_RESET}")
455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
456
 
457
+ if __name__ == "__main__":
458
+ main()