amanwithaplan committed
Commit cddbc43 · verified · 1 Parent(s): e76af51

Replace correlation with proper ranking metrics (NDCG, MRR)
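NDCG and MRR here follow the standard definitions: DCG discounts each document's relevance by log2(rank + 1), NDCG normalizes that by the ideal ordering, and MRR is the reciprocal rank of the first relevant hit. A quick hand-check of those formulas on made-up numbers (illustrative only, not output from this run):

import math

true_rel = [1.0, 0.0, 0.5]   # ground-truth relevances for three docs
predicted_order = [2, 0, 1]  # model ranks doc 2 first, then doc 0, then doc 1

ranked = [true_rel[i] for i in predicted_order]               # [0.5, 1.0, 0.0]
dcg = sum(r / math.log2(i + 2) for i, r in enumerate(ranked))
ideal = sorted(true_rel, reverse=True)
idcg = sum(r / math.log2(i + 2) for i, r in enumerate(ideal))
print(round(dcg / idcg, 3))  # NDCG@3 ≈ 0.86

# MRR: reciprocal rank of the first doc whose relevance exceeds the 0.5 threshold.
# Doc 2 (rel 0.5) misses the strict cutoff; doc 0 (rel 1.0) clears it at rank 2.
print(1.0 / 2)               # MRR = 0.5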

Files changed (1)
  1. train_reranker.py +135 -55
train_reranker.py CHANGED
@@ -7,6 +7,7 @@
 # "transformers>=4.48",
 # "trackio",
 # "scipy",
+# "numpy",
 # ]
 # ///
 """
@@ -18,9 +19,10 @@ Dataset format: {"query": "...", "text": "...", "score": 0.0-1.0}
 
 import logging
 import os
+import math
 from collections import defaultdict
 import trackio
-import torch
+import numpy as np
 from datasets import load_dataset
 from sentence_transformers.cross_encoder import (
     CrossEncoder,
@@ -28,7 +30,7 @@ from sentence_transformers.cross_encoder import (
     CrossEncoderTrainingArguments,
 )
 from sentence_transformers.cross_encoder.evaluation import CrossEncoderNanoBEIREvaluator
-from scipy.stats import spearmanr, pearsonr
+from scipy.stats import spearmanr
 from transformers import TrainerCallback
 
 logging.basicConfig(level=logging.INFO)
@@ -46,31 +48,107 @@ RUN_NAME = os.environ.get("RUN_NAME", "reranker-03130903")
 SPACE_ID = os.environ.get("TRACKIO_SPACE_ID", "amanwithaplan/trackio")
 
 
-def evaluate_correlation(model, eval_dataset):
-    """Evaluate correlation between predicted scores and labels."""
-    pairs = [(item["sentence1"], item["sentence2"]) for item in eval_dataset]
-    labels = [item["label"] for item in eval_dataset]
-
-    predictions = model.predict(pairs, show_progress_bar=True)
-
-    spearman = spearmanr(predictions, labels).correlation
-    pearson = pearsonr(predictions, labels).statistic
-
-    # Mean absolute error
-    mae = sum(abs(p - l) for p, l in zip(predictions, labels)) / len(labels)
+def dcg_at_k(relevances, k):
+    """Compute DCG@k."""
+    relevances = np.array(relevances)[:k]
+    if len(relevances) == 0:
+        return 0.0
+    # DCG = sum of rel_i / log2(i+2) for i in 0..k-1
+    discounts = np.log2(np.arange(len(relevances)) + 2)
+    return np.sum(relevances / discounts)
+
+
+def ndcg_at_k(predicted_order, true_relevances, k):
+    """
+    Compute NDCG@k.
+
+    predicted_order: indices of docs sorted by model score (descending)
+    true_relevances: ground truth relevance scores for each doc
+    """
+    # Get relevances in predicted order
+    predicted_relevances = [true_relevances[i] for i in predicted_order]
+
+    # Ideal order: sort by true relevance descending
+    ideal_relevances = sorted(true_relevances, reverse=True)
+
+    dcg = dcg_at_k(predicted_relevances, k)
+    idcg = dcg_at_k(ideal_relevances, k)
+
+    if idcg == 0:
+        return 0.0
+    return dcg / idcg
+
+
+def mrr(predicted_order, true_relevances, threshold=0.5):
+    """
+    Compute MRR (Mean Reciprocal Rank).
+
+    Returns 1/rank of first relevant doc (relevance > threshold).
+    """
+    for rank, idx in enumerate(predicted_order, start=1):
+        if true_relevances[idx] > threshold:
+            return 1.0 / rank
+    return 0.0
+
+
+def evaluate_ranking(model, eval_dataset):
+    """
+    Proper ranking evaluation: group by query, compute NDCG and MRR.
+
+    This measures what we actually care about:
+    "Given a query with multiple docs, does the model rank them correctly?"
+    """
+    # Group samples by query
+    query_groups = defaultdict(list)
+    for item in eval_dataset:
+        query_groups[item["sentence1"]].append({
+            "text": item["sentence2"],
+            "label": item["label"]
+        })
+
+    # Filter to queries with multiple docs (need at least 2 to rank)
+    query_groups = {q: docs for q, docs in query_groups.items() if len(docs) >= 2}
+
+    if not query_groups:
+        return {"ndcg@3": 0.0, "ndcg@5": 0.0, "mrr": 0.0, "n_queries": 0}
+
+    ndcg_3_scores = []
+    ndcg_5_scores = []
+    mrr_scores = []
+    rank_correlations = []
+
+    for query, docs in query_groups.items():
+        # Get model predictions for this query's docs
+        pairs = [(query, d["text"]) for d in docs]
+        predictions = model.predict(pairs, show_progress_bar=False)
+
+        true_relevances = [d["label"] for d in docs]
+
+        # Get predicted order: indices sorted by prediction descending
+        predicted_order = np.argsort(predictions)[::-1].tolist()
+
+        # Compute metrics
+        ndcg_3_scores.append(ndcg_at_k(predicted_order, true_relevances, k=3))
+        ndcg_5_scores.append(ndcg_at_k(predicted_order, true_relevances, k=5))
+        mrr_scores.append(mrr(predicted_order, true_relevances, threshold=0.5))
+
+        # Rank correlation within this query
+        if len(set(true_relevances)) > 1:  # Need variance
+            corr = spearmanr(predictions, true_relevances).correlation
+            if not math.isnan(corr):
+                rank_correlations.append(corr)
 
     return {
-        "spearman": spearman,
-        "pearson": pearson,
-        "mae": mae,
-        "pred_mean": float(predictions.mean()),
-        "pred_std": float(predictions.std()),
-        "label_mean": sum(labels) / len(labels),
+        "ndcg@3": np.mean(ndcg_3_scores),
+        "ndcg@5": np.mean(ndcg_5_scores),
+        "mrr": np.mean(mrr_scores),
+        "rank_corr": np.mean(rank_correlations) if rank_correlations else 0.0,
+        "n_queries": len(query_groups),
     }
 
 
 class DomainEvalCallback(TrainerCallback):
-    """Callback to log our domain-specific correlation metrics during training."""
+    """Callback to log proper ranking metrics during training."""
 
     def __init__(self, model, eval_dataset_full):
         self.model = model
@@ -78,51 +156,53 @@ class DomainEvalCallback(TrainerCallback):
 
     def on_evaluate(self, args, state, control, **kwargs):
         """Run after each evaluation step."""
-        # Get correlation metrics
-        pairs = [(item["sentence1"], item["sentence2"]) for item in self.eval_dataset_full]
-        labels = [item["label"] for item in self.eval_dataset_full]
-
-        predictions = self.model.predict(pairs, show_progress_bar=False)
-
-        spearman = spearmanr(predictions, labels).correlation
-        pearson_val = pearsonr(predictions, labels).statistic
-        mae = sum(abs(p - l) for p, l in zip(predictions, labels)) / len(labels)
+        metrics = evaluate_ranking(self.model, self.eval_dataset_full)
 
         # Log to trackio
         trackio.log({
-            "domain/spearman": spearman,
-            "domain/pearson": pearson_val,
-            "domain/mae": float(mae),
-            "domain/pred_mean": float(predictions.mean()),
-            "domain/pred_std": float(predictions.std()),
+            "domain/ndcg@3": metrics["ndcg@3"],
+            "domain/ndcg@5": metrics["ndcg@5"],
+            "domain/mrr": metrics["mrr"],
+            "domain/rank_corr": metrics["rank_corr"],
        })
 
-        logger.info(f"Domain eval - Spearman: {spearman:.4f}, Pearson: {pearson_val:.4f}, MAE: {mae:.4f}")
+        logger.info(
+            f"Domain eval - NDCG@3: {metrics['ndcg@3']:.4f}, "
+            f"NDCG@5: {metrics['ndcg@5']:.4f}, "
+            f"MRR: {metrics['mrr']:.4f}, "
+            f"RankCorr: {metrics['rank_corr']:.4f} "
+            f"(n={metrics['n_queries']} queries)"
+        )
 
 
 def evaluate_by_type(model, eval_dataset, type_column="type"):
-    """Evaluate correlation per content type."""
+    """Evaluate ranking metrics per content type."""
     if type_column not in eval_dataset.column_names:
         return {}
 
-    # Group by type
+    # Group by type first
     by_type = defaultdict(list)
     for item in eval_dataset:
         by_type[item[type_column]].append(item)
 
     results = {}
     for content_type, items in by_type.items():
-        if len(items) < 5:
-            continue
-
-        pairs = [(item["sentence1"], item["sentence2"]) for item in items]
-        labels = [item["label"] for item in items]
-        predictions = model.predict(pairs)
-
-        if len(set(labels)) > 1:  # Need variance for correlation
-            results[f"{content_type}_spearman"] = spearmanr(predictions, labels).correlation
-        results[f"{content_type}_mae"] = sum(abs(p - l) for p, l in zip(predictions, labels)) / len(labels)
-        results[f"{content_type}_n"] = len(items)
+        # Create a mini dataset for this type
+        class TypeDataset:
+            def __init__(self, items):
+                self.items = items
+            def __iter__(self):
+                return iter(self.items)
+            @property
+            def column_names(self):
+                return ["sentence1", "sentence2", "label"]
+
+        type_metrics = evaluate_ranking(model, TypeDataset(items))
+
+        if type_metrics["n_queries"] >= 2:
+            results[f"{content_type}_ndcg@5"] = type_metrics["ndcg@5"]
+            results[f"{content_type}_mrr"] = type_metrics["mrr"]
+            results[f"{content_type}_n_queries"] = type_metrics["n_queries"]
 
     return results
 
@@ -193,9 +273,9 @@ def main():
     })
     logger.info(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
 
-    # Evaluate base model before training
+    # Evaluate base model before training with proper ranking metrics
     logger.info("Evaluating base model on eval set...")
-    base_metrics = evaluate_correlation(model, eval_dataset)
+    base_metrics = evaluate_ranking(model, eval_dataset_full)
     for key, value in base_metrics.items():
         trackio.log({f"base_model/{key}": value})
     logger.info(f"Base model metrics: {base_metrics}")
@@ -231,7 +311,7 @@
         run_name=RUN_NAME,
     )
 
-    # Custom callback to log domain-specific metrics during training
+    # Custom callback to log domain-specific ranking metrics during training
    domain_callback = DomainEvalCallback(model, eval_dataset_full)
 
    trainer = CrossEncoderTrainer(
@@ -246,14 +326,14 @@
     logger.info("Starting training...")
     trainer.train()
 
-    # Final evaluation on our eval set
-    logger.info("Running final correlation evaluation...")
-    final_metrics = evaluate_correlation(model, eval_dataset)
+    # Final evaluation with proper ranking metrics
+    logger.info("Running final ranking evaluation...")
+    final_metrics = evaluate_ranking(model, eval_dataset_full)
     for key, value in final_metrics.items():
         trackio.log({f"final/{key}": value})
     logger.info(f"Final metrics: {final_metrics}")
 
-    # Per-type evaluation (use full eval dataset with type column)
+    # Per-type evaluation
     logger.info("Evaluating by content type...")
     type_metrics = evaluate_by_type(model, eval_dataset_full)
     for key, value in type_metrics.items():
@@ -262,8 +342,8 @@
 
     # Log improvement
     trackio.log({
-        "improvement/spearman_delta": final_metrics["spearman"] - base_metrics["spearman"],
-        "improvement/mae_delta": base_metrics["mae"] - final_metrics["mae"],  # Lower is better
+        "improvement/ndcg5_delta": final_metrics["ndcg@5"] - base_metrics["ndcg@5"],
+        "improvement/mrr_delta": final_metrics["mrr"] - base_metrics["mrr"],
     })
 
     logger.info(f"Pushing final model to {HUB_MODEL_ID}")
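Below is a minimal smoke test for the new evaluate_ranking helper, sketched under two assumptions: that train_reranker.py keeps its main() behind an if __name__ == "__main__" guard so the module imports without kicking off training, and that a plain list of dicts is acceptable input because evaluate_ranking only iterates items and reads the sentence1/sentence2/label keys. StubModel is a hypothetical stand-in for the CrossEncoder.

import numpy as np

from train_reranker import evaluate_ranking


class StubModel:
    """Scores (query, text) pairs by text length; stands in for a CrossEncoder."""

    def predict(self, pairs, show_progress_bar=False):
        return np.array([float(len(text)) for _, text in pairs])


# Two queries: q1 has three docs and gets ranked, q2 has one doc and is filtered out.
toy_eval = [
    {"sentence1": "q1", "sentence2": "long relevant passage", "label": 1.0},
    {"sentence1": "q1", "sentence2": "short", "label": 0.0},
    {"sentence1": "q1", "sentence2": "medium passage", "label": 0.5},
    {"sentence1": "q2", "sentence2": "lone doc", "label": 1.0},
]

print(evaluate_ranking(StubModel(), toy_eval))

With this stub the longest text also carries the highest label, so q1 comes back in the ideal order: ndcg@3, ndcg@5, mrr and rank_corr should all be 1.0, with n_queries equal to 1.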