Inframat-x committed on
Commit
81fd4b5
·
verified ·
1 Parent(s): eed2512

Update rag_eval_metrics.py

Browse files
Files changed (1) hide show
  1. rag_eval_metrics.py +162 -49
rag_eval_metrics.py CHANGED
@@ -16,10 +16,47 @@ import pandas as pd
16
  import numpy as np
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  # ----------------------------- IO Helpers ----------------------------- #
20
 
21
  def read_logs(jsonl_path: Path) -> pd.DataFrame:
22
  rows = []
 
 
 
23
  with open(jsonl_path, "r", encoding="utf-8") as f:
24
  for line in f:
25
  line = line.strip()
@@ -29,29 +66,35 @@ def read_logs(jsonl_path: Path) -> pd.DataFrame:
29
  rec = json.loads(line)
30
  except Exception:
31
  continue
 
 
32
  q = (((rec.get("inputs") or {}).get("question")) or "").strip()
 
 
33
  retr = (rec.get("retrieval") or {})
34
  hits = retr.get("hits", [])
35
  norm_hits = []
36
  for h in hits or []:
37
  doc = (h.get("doc") or "").strip()
38
  page = str(h.get("page") or "").strip()
 
 
39
  try:
40
  page_int = int(page)
41
  except Exception:
42
  page_int = None
 
43
  norm_hits.append({"doc": doc, "page": page_int})
 
44
  rows.append({"question": q, "hits": norm_hits})
 
45
  df = pd.DataFrame(rows)
46
  if df.empty:
47
  return pd.DataFrame(columns=["question", "hits"])
48
- def _pick_last_non_empty(hit_lists: List[List[dict]]) -> List[dict]:
49
- for lst in reversed(hit_lists):
50
- if lst:
51
- return lst
52
- return []
53
  df = (
54
- df.groupby(df["question"].str.casefold().str.strip(), as_index=False)
55
  .agg({"question": "last", "hits": _pick_last_non_empty})
56
  )
57
  return df
@@ -60,6 +103,8 @@ def read_logs(jsonl_path: Path) -> pd.DataFrame:
60
  def read_gold(csv_path: Path) -> pd.DataFrame:
61
  df = pd.read_csv(csv_path)
62
  cols = {c.lower().strip(): c for c in df.columns}
 
 
63
  q_col = None
64
  for cand in ["question", "query", "q"]:
65
  if cand in cols:
@@ -68,18 +113,21 @@ def read_gold(csv_path: Path) -> pd.DataFrame:
68
  if q_col is None:
69
  raise ValueError("Gold CSV must contain a 'question' column (case-insensitive).")
70
 
 
71
  rel_list_col = None
72
  for cand in ["relevant_docs", "relevant", "docs"]:
73
  if cand in cols:
74
  rel_list_col = cols[cand]
75
  break
76
 
 
77
  doc_col = None
78
  for cand in ["doc", "document", "file", "doc_name"]:
79
  if cand in cols:
80
  doc_col = cols[cand]
81
  break
82
 
 
83
  page_col = None
84
  for cand in ["page", "page_num", "page_number"]:
85
  if cand in cols:
@@ -87,48 +135,68 @@ def read_gold(csv_path: Path) -> pd.DataFrame:
87
  break
88
 
89
  rows = []
 
 
90
  if rel_list_col and doc_col is None:
91
  for _, r in df.iterrows():
92
  q_raw = str(r[q_col]).strip()
93
  q_norm = q_raw.casefold().strip()
 
94
  rel_val = str(r[rel_list_col]) if pd.notna(r[rel_list_col]) else ""
95
  if not rel_val:
96
- rows.append({"question_raw": q_raw, "question": q_norm, "doc": None, "page": np.nan})
 
 
 
 
 
97
  continue
 
98
  parts = [p.strip() for p in re_split_sc(rel_val)]
99
  for d in parts:
100
- rows.append({"question_raw": q_raw, "question": q_norm, "doc": filename_key(d), "page": np.nan})
 
 
 
 
 
 
 
101
  elif doc_col:
102
  for _, r in df.iterrows():
103
  q_raw = str(r[q_col]).strip()
104
  q_norm = q_raw.casefold().strip()
 
105
  d = str(r[doc_col]).strip() if pd.notna(r[doc_col]) else ""
106
- p = r[page_col] if page_col and pd.notna(r[page_col]) else np.nan
 
107
  try:
108
  p = int(p)
109
  except Exception:
110
  p = np.nan
111
- rows.append({"question_raw": q_raw, "question": q_norm, "doc": filename_key(d), "page": p})
 
 
 
 
 
 
 
112
  else:
113
  raise ValueError("Gold CSV must contain either a 'doc' column or a 'relevant_docs' column.")
114
 
115
  gold = pd.DataFrame(rows)
 
 
116
  gold["has_doc"] = gold["doc"].apply(lambda x: isinstance(x, str) and len(x) > 0)
117
  if gold["has_doc"].any():
118
  gold = gold[gold["has_doc"]].copy()
119
  gold.drop(columns=["has_doc"], inplace=True, errors="ignore")
120
- gold = gold.drop_duplicates(subset=["question", "doc", "page"])
121
- return gold
122
-
123
-
124
- def filename_key(s: str) -> str:
125
- s = (s or "").strip().replace("\\", "/").split("/")[-1]
126
- return s.casefold()
127
 
 
 
128
 
129
- def re_split_sc(s: str) -> List[str]:
130
- import re
131
- return re.split(r"[;,]", s)
132
 
133
 
134
  # ----------------------------- Metric Core ----------------------------- #
@@ -155,13 +223,16 @@ def compute_metrics_for_question(gold_docs, gold_pages, hits, k):
155
  pred_docs = [filename_key(h.get("doc", "")) for h in top]
156
  pred_pairs = [(filename_key(h.get("doc", "")), h.get("page", None)) for h in top]
157
 
 
158
  gold_doc_set = set([d for d in gold_docs if isinstance(d, str) and d])
 
159
  rel_bin_doc = [1 if d in gold_doc_set else 0 for d in pred_docs]
160
  hitk_doc = 1 if any(rel_bin_doc) else 0
161
  prec_doc = (sum(rel_bin_doc) / max(1, len(pred_docs))) if pred_docs else 0.0
162
  rec_doc = (sum(rel_bin_doc) / max(1, len(gold_doc_set))) if gold_doc_set else 0.0
163
  ndcg_doc = ndcg_at_k(rel_bin_doc)
164
 
 
165
  gold_pairs = set()
166
  for d, p in zip(gold_docs, gold_pages):
167
  if isinstance(d, str) and d and (p is not None) and (not (isinstance(p, float) and np.isnan(p))):
@@ -172,8 +243,13 @@ def compute_metrics_for_question(gold_docs, gold_pages, hits, k):
172
  gold_pairs.add((d, p_int))
173
 
174
  if gold_pairs:
175
- rel_bin_page = [1 if ((d, (p if p is not None else -1)) in gold_pairs) else 0
176
- for (d, p) in [(d, (p if isinstance(p, int) else -1)) for (d, p) in pred_pairs]]
 
 
 
 
 
177
  hitk_page = 1 if any(rel_bin_page) else 0
178
  prec_page = (sum(rel_bin_page) / max(1, len(pred_pairs))) if pred_pairs else 0.0
179
  rec_page = (sum(rel_bin_page) / max(1, len(gold_pairs))) if gold_pairs else 0.0
@@ -204,6 +280,14 @@ COLOR_TEXT = "\033[34m" # dark blue
204
  COLOR_ACCENT = "\033[36m" # cyan for metrics
205
  COLOR_RESET = "\033[0m"
206
 
 
 
 
 
 
 
 
 
207
  def main():
208
  ap = argparse.ArgumentParser()
209
  ap.add_argument("--gold_csv", required=True, type=str)
@@ -225,12 +309,19 @@ def main():
225
  print(f"{COLOR_TEXT}❌ logs JSONL not found or empty at {logs_path}{COLOR_RESET}", file=sys.stderr)
226
  sys.exit(0)
227
 
 
228
  try:
229
  gold = read_gold(gold_path)
230
  except Exception as e:
231
  print(f"{COLOR_TEXT}❌ Failed to read gold: {e}{COLOR_RESET}", file=sys.stderr)
232
  sys.exit(0)
233
- logs = read_logs(logs_path)
 
 
 
 
 
 
234
 
235
  if gold.empty:
236
  print(f"{COLOR_TEXT}❌ Gold file contains no usable rows.{COLOR_RESET}", file=sys.stderr)
@@ -239,6 +330,7 @@ def main():
239
  print(f"{COLOR_TEXT}❌ Logs file contains no usable entries.{COLOR_RESET}", file=sys.stderr)
240
  sys.exit(0)
241
 
 
242
  gdict: Dict[str, List[Tuple[str, Optional[int]]]] = {}
243
  for _, r in gold.iterrows():
244
  q = str(r["question"]).strip()
@@ -246,32 +338,54 @@ def main():
246
  p = r["page"] if "page" in r else np.nan
247
  gdict.setdefault(q, []).append((d, p))
248
 
 
249
  logs["q_norm"] = logs["question"].astype(str).str.casefold().str.strip()
 
250
  perq_rows = []
251
  not_in_logs, not_in_gold = [], []
252
 
 
253
  for q_norm, pairs in gdict.items():
254
  row = logs[logs["q_norm"] == q_norm]
 
 
 
255
  if row.empty:
 
256
  not_in_logs.append(q_norm)
257
- gdocs = [d for (d, _) in pairs]
258
- gpages = [p for (_, p) in pairs]
259
  metrics = {
260
- "hit@k_doc": 0, "precision@k_doc": 0.0, "recall@k_doc": 0.0, "ndcg@k_doc": 0.0,
261
- "hit@k_page": np.nan, "precision@k_page": np.nan, "recall@k_page": np.nan, "ndcg@k_page": np.nan,
 
 
 
 
 
 
262
  "n_gold_docs": int(len(set([d for d in gdocs if isinstance(d, str) and d]))),
263
- "n_gold_doc_pages": int(len([(d, p) for (d, p) in zip(gdocs, gpages) if isinstance(d, str) and d and pd.notna(p)])),
 
 
 
264
  "n_pred": 0
265
  }
266
- perq_rows.append({"question": q_norm, "covered_in_logs": 0, **metrics})
 
 
 
 
267
  continue
268
 
 
269
  hits = row.iloc[0]["hits"] or []
270
- gdocs = [d for (d, _) in pairs]
271
- gpages = [p for (_, p) in pairs]
272
  metrics = compute_metrics_for_question(gdocs, gpages, hits, args.k)
273
- perq_rows.append({"question": q_norm, "covered_in_logs": 1, **metrics})
 
 
 
 
274
 
 
275
  gold_qs = set(gdict.keys())
276
  for qn in logs["q_norm"].tolist():
277
  if qn not in gold_qs:
@@ -279,6 +393,7 @@ def main():
279
 
280
  perq = pd.DataFrame(perq_rows)
281
  covered = perq[perq["covered_in_logs"] == 1].copy()
 
282
  agg = {
283
  "questions_total_gold": int(len(gdict)),
284
  "questions_covered_in_logs": int(covered.shape[0]),
@@ -301,6 +416,7 @@ def main():
301
 
302
  perq_path = out_dir / "metrics_per_question.csv"
303
  agg_path = out_dir / "metrics_aggregate.json"
 
304
  perq.to_csv(perq_path, index=False)
305
  with open(agg_path, "w", encoding="utf-8") as f:
306
  json.dump(agg, f, ensure_ascii=False, indent=2)
@@ -314,18 +430,22 @@ def main():
314
  print(f"{COLOR_TEXT}In logs but not in gold: {COLOR_ACCENT}{agg['questions_in_logs_not_in_gold']}{COLOR_RESET}")
315
  print(f"{COLOR_TEXT}k = {COLOR_ACCENT}{agg['k']}{COLOR_RESET}\n")
316
 
317
- print(f"{COLOR_TEXT}Doc-level:{COLOR_RESET} "
318
- f"{COLOR_ACCENT}Hit@k={_fmt(agg['mean_hit@k_doc'])} "
319
- f"Precision@k={_fmt(agg['mean_precision@k_doc'])} "
320
- f"Recall@k={_fmt(agg['mean_recall@k_doc'])} "
321
- f"nDCG@k={_fmt(agg['mean_ndcg@k_doc'])}{COLOR_RESET}")
 
 
322
 
323
  if agg['mean_hit@k_page'] is not None:
324
- print(f"{COLOR_TEXT}Page-level:{COLOR_RESET} "
325
- f"{COLOR_ACCENT}Hit@k={_fmt(agg['mean_hit@k_page'])} "
326
- f"Precision@k={_fmt(agg['mean_precision@k_page'])} "
327
- f"Recall@k={_fmt(agg['mean_recall@k_page'])} "
328
- f"nDCG@k={_fmt(agg['mean_ndcg@k_page'])}{COLOR_RESET}")
 
 
329
  else:
330
  print(f"{COLOR_TEXT}Page-level: (no page labels in gold){COLOR_RESET}")
331
 
@@ -334,12 +454,5 @@ def main():
334
  print(f"{COLOR_TEXT}Wrote aggregate JSON → {COLOR_ACCENT}{agg_path}{COLOR_RESET}")
335
 
336
 
337
- def _fmt(x: Any) -> str:
338
- try:
339
- return f"{float(x):.3f}"
340
- except Exception:
341
- return "-"
342
-
343
-
344
  if __name__ == "__main__":
345
  main()
 
16
  import numpy as np
17
 
18
 
19
+ # ----------------------------- Small Utils ----------------------------- #
20
+
21
+ def filename_key(s: str) -> str:
22
+ s = (s or "").strip().replace("\\", "/").split("/")[-1]
23
+ return s.casefold()
24
+
25
+
26
+ def re_split_sc(s: str) -> List[str]:
27
+ import re
28
+ return re.split(r"[;,]", s)
29
+
30
+
31
+ def _pick_last_non_empty(hit_lists) -> List[dict]:
32
+ """
33
+ Robustly select the last non-empty hits list from a pandas Series or iterable.
34
+
35
+ This fixes the KeyError that happens when using reversed() directly on a Series
36
+ with a non-range index.
37
+ """
38
+ # Convert pandas Series or other iterables to a plain Python list
39
+ try:
40
+ values = list(hit_lists.tolist())
41
+ except AttributeError:
42
+ values = list(hit_lists)
43
+
44
+ # Walk from last to first, return first non-empty list-like
45
+ for lst in reversed(values):
46
+ if isinstance(lst, (list, tuple)) and len(lst) > 0:
47
+ return lst
48
+
49
+ # If everything was empty / NaN
50
+ return []
51
+
52
+
53
  # ----------------------------- IO Helpers ----------------------------- #
54
 
55
  def read_logs(jsonl_path: Path) -> pd.DataFrame:
56
  rows = []
57
+ if (not jsonl_path.exists()) or jsonl_path.stat().st_size == 0:
58
+ return pd.DataFrame(columns=["question", "hits"])
59
+
60
  with open(jsonl_path, "r", encoding="utf-8") as f:
61
  for line in f:
62
  line = line.strip()
 
66
  rec = json.loads(line)
67
  except Exception:
68
  continue
69
+
70
+ # Extract question
71
  q = (((rec.get("inputs") or {}).get("question")) or "").strip()
72
+
73
+ # Extract retrieval hits (if present)
74
  retr = (rec.get("retrieval") or {})
75
  hits = retr.get("hits", [])
76
  norm_hits = []
77
  for h in hits or []:
78
  doc = (h.get("doc") or "").strip()
79
  page = str(h.get("page") or "").strip()
80
+
81
+ # Normalize page to int or None
82
  try:
83
  page_int = int(page)
84
  except Exception:
85
  page_int = None
86
+
87
  norm_hits.append({"doc": doc, "page": page_int})
88
+
89
  rows.append({"question": q, "hits": norm_hits})
90
+
91
  df = pd.DataFrame(rows)
92
  if df.empty:
93
  return pd.DataFrame(columns=["question", "hits"])
94
+
95
+ # Group by normalized question text and keep last non-empty hits list per question
 
 
 
96
  df = (
97
+ df.groupby(df["question"].astype(str).str.casefold().str.strip(), as_index=False)
98
  .agg({"question": "last", "hits": _pick_last_non_empty})
99
  )
100
  return df
 
103
  def read_gold(csv_path: Path) -> pd.DataFrame:
104
  df = pd.read_csv(csv_path)
105
  cols = {c.lower().strip(): c for c in df.columns}
106
+
107
+ # --- question column ---
108
  q_col = None
109
  for cand in ["question", "query", "q"]:
110
  if cand in cols:
 
113
  if q_col is None:
114
  raise ValueError("Gold CSV must contain a 'question' column (case-insensitive).")
115
 
116
+ # --- possible relevant_docs (list-in-cell) column ---
117
  rel_list_col = None
118
  for cand in ["relevant_docs", "relevant", "docs"]:
119
  if cand in cols:
120
  rel_list_col = cols[cand]
121
  break
122
 
123
+ # --- single-doc-per-row column ---
124
  doc_col = None
125
  for cand in ["doc", "document", "file", "doc_name"]:
126
  if cand in cols:
127
  doc_col = cols[cand]
128
  break
129
 
130
+ # --- optional page column ---
131
  page_col = None
132
  for cand in ["page", "page_num", "page_number"]:
133
  if cand in cols:
 
135
  break
136
 
137
  rows = []
138
+
139
+ # Case 1: relevant_docs list column (no explicit doc_col)
140
  if rel_list_col and doc_col is None:
141
  for _, r in df.iterrows():
142
  q_raw = str(r[q_col]).strip()
143
  q_norm = q_raw.casefold().strip()
144
+
145
  rel_val = str(r[rel_list_col]) if pd.notna(r[rel_list_col]) else ""
146
  if not rel_val:
147
+ rows.append({
148
+ "question_raw": q_raw,
149
+ "question": q_norm,
150
+ "doc": None,
151
+ "page": np.nan
152
+ })
153
  continue
154
+
155
  parts = [p.strip() for p in re_split_sc(rel_val)]
156
  for d in parts:
157
+ rows.append({
158
+ "question_raw": q_raw,
159
+ "question": q_norm,
160
+ "doc": filename_key(d),
161
+ "page": np.nan
162
+ })
163
+
164
+ # Case 2: doc/page columns (one relevant doc per row)
165
  elif doc_col:
166
  for _, r in df.iterrows():
167
  q_raw = str(r[q_col]).strip()
168
  q_norm = q_raw.casefold().strip()
169
+
170
  d = str(r[doc_col]).strip() if pd.notna(r[doc_col]) else ""
171
+ p = r[page_col] if (page_col and pd.notna(r[page_col])) else np.nan
172
+
173
  try:
174
  p = int(p)
175
  except Exception:
176
  p = np.nan
177
+
178
+ rows.append({
179
+ "question_raw": q_raw,
180
+ "question": q_norm,
181
+ "doc": filename_key(d),
182
+ "page": p
183
+ })
184
+
185
  else:
186
  raise ValueError("Gold CSV must contain either a 'doc' column or a 'relevant_docs' column.")
187
 
188
  gold = pd.DataFrame(rows)
189
+
190
+ # Keep only rows with a valid doc (when docs exist)
191
  gold["has_doc"] = gold["doc"].apply(lambda x: isinstance(x, str) and len(x) > 0)
192
  if gold["has_doc"].any():
193
  gold = gold[gold["has_doc"]].copy()
194
  gold.drop(columns=["has_doc"], inplace=True, errors="ignore")
 
 
 
 
 
 
 
195
 
196
+ # Remove duplicates
197
+ gold = gold.drop_duplicates(subset=["question", "doc", "page"])
198
 
199
+ return gold
 
 
200
 
201
 
202
  # ----------------------------- Metric Core ----------------------------- #
 
223
  pred_docs = [filename_key(h.get("doc", "")) for h in top]
224
  pred_pairs = [(filename_key(h.get("doc", "")), h.get("page", None)) for h in top]
225
 
226
+ # --- Doc-level metrics ---
227
  gold_doc_set = set([d for d in gold_docs if isinstance(d, str) and d])
228
+
229
  rel_bin_doc = [1 if d in gold_doc_set else 0 for d in pred_docs]
230
  hitk_doc = 1 if any(rel_bin_doc) else 0
231
  prec_doc = (sum(rel_bin_doc) / max(1, len(pred_docs))) if pred_docs else 0.0
232
  rec_doc = (sum(rel_bin_doc) / max(1, len(gold_doc_set))) if gold_doc_set else 0.0
233
  ndcg_doc = ndcg_at_k(rel_bin_doc)
234
 
235
+ # --- Page-level metrics (only if gold has page labels) ---
236
  gold_pairs = set()
237
  for d, p in zip(gold_docs, gold_pages):
238
  if isinstance(d, str) and d and (p is not None) and (not (isinstance(p, float) and np.isnan(p))):
 
243
  gold_pairs.add((d, p_int))
244
 
245
  if gold_pairs:
246
+ rel_bin_page = []
247
+ for (d, p) in pred_pairs:
248
+ if p is None or not isinstance(p, int):
249
+ rel_bin_page.append(0)
250
+ else:
251
+ rel_bin_page.append(1 if (d, p) in gold_pairs else 0)
252
+
253
  hitk_page = 1 if any(rel_bin_page) else 0
254
  prec_page = (sum(rel_bin_page) / max(1, len(pred_pairs))) if pred_pairs else 0.0
255
  rec_page = (sum(rel_bin_page) / max(1, len(gold_pairs))) if gold_pairs else 0.0
 
280
  COLOR_ACCENT = "\033[36m" # cyan for metrics
281
  COLOR_RESET = "\033[0m"
282
 
283
+
284
+ def _fmt(x: Any) -> str:
285
+ try:
286
+ return f"{float(x):.3f}"
287
+ except Exception:
288
+ return "-"
289
+
290
+
291
  def main():
292
  ap = argparse.ArgumentParser()
293
  ap.add_argument("--gold_csv", required=True, type=str)
 
309
  print(f"{COLOR_TEXT}❌ logs JSONL not found or empty at {logs_path}{COLOR_RESET}", file=sys.stderr)
310
  sys.exit(0)
311
 
312
+ # Read gold
313
  try:
314
  gold = read_gold(gold_path)
315
  except Exception as e:
316
  print(f"{COLOR_TEXT}❌ Failed to read gold: {e}{COLOR_RESET}", file=sys.stderr)
317
  sys.exit(0)
318
+
319
+ # Read logs (with robust aggregation)
320
+ try:
321
+ logs = read_logs(logs_path)
322
+ except Exception as e:
323
+ print(f"{COLOR_TEXT}❌ Failed to read logs: {e}{COLOR_RESET}", file=sys.stderr)
324
+ sys.exit(0)
325
 
326
  if gold.empty:
327
  print(f"{COLOR_TEXT}❌ Gold file contains no usable rows.{COLOR_RESET}", file=sys.stderr)
 
330
  print(f"{COLOR_TEXT}❌ Logs file contains no usable entries.{COLOR_RESET}", file=sys.stderr)
331
  sys.exit(0)
332
 
333
+ # Build gold dict: normalized_question -> list of (doc, page)
334
  gdict: Dict[str, List[Tuple[str, Optional[int]]]] = {}
335
  for _, r in gold.iterrows():
336
  q = str(r["question"]).strip()
 
338
  p = r["page"] if "page" in r else np.nan
339
  gdict.setdefault(q, []).append((d, p))
340
 
341
+ # Normalize log questions for join
342
  logs["q_norm"] = logs["question"].astype(str).str.casefold().str.strip()
343
+
344
  perq_rows = []
345
  not_in_logs, not_in_gold = [], []
346
 
347
+ # For each gold question, compute metrics using logs
348
  for q_norm, pairs in gdict.items():
349
  row = logs[logs["q_norm"] == q_norm]
350
+ gdocs = [d for (d, _) in pairs]
351
+ gpages = [p for (_, p) in pairs]
352
+
353
  if row.empty:
354
+ # No logs for this gold question → zero retrieval
355
  not_in_logs.append(q_norm)
 
 
356
  metrics = {
357
+ "hit@k_doc": 0,
358
+ "precision@k_doc": 0.0,
359
+ "recall@k_doc": 0.0,
360
+ "ndcg@k_doc": 0.0,
361
+ "hit@k_page": np.nan,
362
+ "precision@k_page": np.nan,
363
+ "recall@k_page": np.nan,
364
+ "ndcg@k_page": np.nan,
365
  "n_gold_docs": int(len(set([d for d in gdocs if isinstance(d, str) and d]))),
366
+ "n_gold_doc_pages": int(len([
367
+ (d, p) for (d, p) in zip(gdocs, gpages)
368
+ if isinstance(d, str) and d and pd.notna(p)
369
+ ])),
370
  "n_pred": 0
371
  }
372
+ perq_rows.append({
373
+ "question": q_norm,
374
+ "covered_in_logs": 0,
375
+ **metrics
376
+ })
377
  continue
378
 
379
+ # Use aggregated hits from read_logs
380
  hits = row.iloc[0]["hits"] or []
 
 
381
  metrics = compute_metrics_for_question(gdocs, gpages, hits, args.k)
382
+ perq_rows.append({
383
+ "question": q_norm,
384
+ "covered_in_logs": 1,
385
+ **metrics
386
+ })
387
 
388
+ # Any log questions not in gold
389
  gold_qs = set(gdict.keys())
390
  for qn in logs["q_norm"].tolist():
391
  if qn not in gold_qs:
 
393
 
394
  perq = pd.DataFrame(perq_rows)
395
  covered = perq[perq["covered_in_logs"] == 1].copy()
396
+
397
  agg = {
398
  "questions_total_gold": int(len(gdict)),
399
  "questions_covered_in_logs": int(covered.shape[0]),
 
416
 
417
  perq_path = out_dir / "metrics_per_question.csv"
418
  agg_path = out_dir / "metrics_aggregate.json"
419
+
420
  perq.to_csv(perq_path, index=False)
421
  with open(agg_path, "w", encoding="utf-8") as f:
422
  json.dump(agg, f, ensure_ascii=False, indent=2)
 
430
  print(f"{COLOR_TEXT}In logs but not in gold: {COLOR_ACCENT}{agg['questions_in_logs_not_in_gold']}{COLOR_RESET}")
431
  print(f"{COLOR_TEXT}k = {COLOR_ACCENT}{agg['k']}{COLOR_RESET}\n")
432
 
433
+ print(
434
+ f"{COLOR_TEXT}Doc-level:{COLOR_RESET} "
435
+ f"{COLOR_ACCENT}Hit@k={_fmt(agg['mean_hit@k_doc'])} "
436
+ f"Precision@k={_fmt(agg['mean_precision@k_doc'])} "
437
+ f"Recall@k={_fmt(agg['mean_recall@k_doc'])} "
438
+ f"nDCG@k={_fmt(agg['mean_ndcg@k_doc'])}{COLOR_RESET}"
439
+ )
440
 
441
  if agg['mean_hit@k_page'] is not None:
442
+ print(
443
+ f"{COLOR_TEXT}Page-level:{COLOR_RESET} "
444
+ f"{COLOR_ACCENT}Hit@k={_fmt(agg['mean_hit@k_page'])} "
445
+ f"Precision@k={_fmt(agg['mean_precision@k_page'])} "
446
+ f"Recall@k={_fmt(agg['mean_recall@k_page'])} "
447
+ f"nDCG@k={_fmt(agg['mean_ndcg@k_page'])}{COLOR_RESET}"
448
+ )
449
  else:
450
  print(f"{COLOR_TEXT}Page-level: (no page labels in gold){COLOR_RESET}")
451
 
 
454
  print(f"{COLOR_TEXT}Wrote aggregate JSON → {COLOR_ACCENT}{agg_path}{COLOR_RESET}")
455
 
456
 
 
 
 
 
 
 
 
457
  if __name__ == "__main__":
458
  main()