amanwithaplan commited on
Commit
55c1fce
·
verified ·
1 Parent(s): 3517d13

Add trackio metrics and correlation evaluation

Browse files
Files changed (1) hide show
  1. train_reranker.py +110 -1
train_reranker.py CHANGED
@@ -6,6 +6,7 @@
6
  # "torch>=2.4",
7
  # "transformers>=4.48",
8
  # "trackio",
 
9
  # ]
10
  # ///
11
  """
@@ -18,6 +19,8 @@ Dataset format: {"query": "...", "text": "...", "score": 0.0-1.0}
18
  import logging
19
  import os
20
  from collections import defaultdict
 
 
21
  from datasets import load_dataset
22
  from sentence_transformers.cross_encoder import (
23
  CrossEncoder,
@@ -25,6 +28,7 @@ from sentence_transformers.cross_encoder import (
25
  CrossEncoderTrainingArguments,
26
  )
27
  from sentence_transformers.cross_encoder.evaluation import CrossEncoderNanoBEIREvaluator
 
28
 
29
  logging.basicConfig(level=logging.INFO)
30
  logger = logging.getLogger(__name__)
@@ -38,14 +42,81 @@ BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "16"))
38
  LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "2e-5"))
39
  MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", "512"))
40
  RUN_NAME = os.environ.get("RUN_NAME", "reranker-03130903")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
 
43
  def main():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  logger.info(f"Configuration:")
45
  logger.info(f" Dataset: {DATASET_NAME}")
46
  logger.info(f" Base model: {BASE_MODEL}")
47
  logger.info(f" Epochs: {NUM_EPOCHS}")
48
  logger.info(f" Run name: {RUN_NAME}")
 
49
 
50
  model = CrossEncoder(BASE_MODEL, max_length=MAX_SEQ_LENGTH)
51
 
@@ -53,12 +124,17 @@ def main():
53
  dataset = load_dataset(DATASET_NAME, split="train")
54
 
55
  # Log dataset composition
 
56
  if "type" in dataset.column_names:
57
- type_counts = defaultdict(int)
58
  for item in dataset:
59
  type_counts[item["type"]] += 1
60
  logger.info(f"Dataset composition: {dict(type_counts)}")
61
 
 
 
 
 
 
62
  logger.info(f"Total examples: {len(dataset)}")
63
 
64
  # Rename columns for CrossEncoderTrainer
@@ -74,8 +150,19 @@ def main():
74
  train_dataset = splits["train"]
75
  eval_dataset = splits["test"]
76
 
 
 
 
 
77
  logger.info(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
78
 
 
 
 
 
 
 
 
79
  # NanoBEIR for benchmark comparison
80
  evaluator = CrossEncoderNanoBEIREvaluator(
81
  dataset_names=["msmarco", "nfcorpus", "nq"],
@@ -118,8 +205,30 @@ def main():
118
  logger.info("Starting training...")
119
  trainer.train()
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  logger.info(f"Pushing final model to {HUB_MODEL_ID}")
122
  model.push_to_hub(HUB_MODEL_ID)
 
 
123
  logger.info("Done!")
124
 
125
 
 
6
  # "torch>=2.4",
7
  # "transformers>=4.48",
8
  # "trackio",
9
+ # "scipy",
10
  # ]
11
  # ///
12
  """
 
19
  import logging
20
  import os
21
  from collections import defaultdict
22
+ import trackio
23
+ import torch
24
  from datasets import load_dataset
25
  from sentence_transformers.cross_encoder import (
26
  CrossEncoder,
 
28
  CrossEncoderTrainingArguments,
29
  )
30
  from sentence_transformers.cross_encoder.evaluation import CrossEncoderNanoBEIREvaluator
31
+ from scipy.stats import spearmanr, pearsonr
32
 
33
  logging.basicConfig(level=logging.INFO)
34
  logger = logging.getLogger(__name__)
 
42
  LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "2e-5"))
43
  MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", "512"))
44
  RUN_NAME = os.environ.get("RUN_NAME", "reranker-03130903")
45
+ SPACE_ID = os.environ.get("TRACKIO_SPACE_ID", "amanwithaplan/trackio")
46
+
47
+
48
+ def evaluate_correlation(model, eval_dataset):
49
+ """Evaluate correlation between predicted scores and labels."""
50
+ pairs = [(item["sentence1"], item["sentence2"]) for item in eval_dataset]
51
+ labels = [item["label"] for item in eval_dataset]
52
+
53
+ predictions = model.predict(pairs, show_progress_bar=True)
54
+
55
+ spearman = spearmanr(predictions, labels).correlation
56
+ pearson = pearsonr(predictions, labels).statistic
57
+
58
+ # Mean absolute error
59
+ mae = sum(abs(p - l) for p, l in zip(predictions, labels)) / len(labels)
60
+
61
+ return {
62
+ "spearman": spearman,
63
+ "pearson": pearson,
64
+ "mae": mae,
65
+ "pred_mean": float(predictions.mean()),
66
+ "pred_std": float(predictions.std()),
67
+ "label_mean": sum(labels) / len(labels),
68
+ }
69
+
70
+
71
+ def evaluate_by_type(model, eval_dataset, type_column="type"):
72
+ """Evaluate correlation per content type."""
73
+ if type_column not in eval_dataset.column_names:
74
+ return {}
75
+
76
+ # Group by type
77
+ by_type = defaultdict(list)
78
+ for item in eval_dataset:
79
+ by_type[item[type_column]].append(item)
80
+
81
+ results = {}
82
+ for content_type, items in by_type.items():
83
+ if len(items) < 5:
84
+ continue
85
+
86
+ pairs = [(item["sentence1"], item["sentence2"]) for item in items]
87
+ labels = [item["label"] for item in items]
88
+ predictions = model.predict(pairs)
89
+
90
+ if len(set(labels)) > 1: # Need variance for correlation
91
+ results[f"{content_type}_spearman"] = spearmanr(predictions, labels).correlation
92
+ results[f"{content_type}_mae"] = sum(abs(p - l) for p, l in zip(predictions, labels)) / len(labels)
93
+ results[f"{content_type}_n"] = len(items)
94
+
95
+ return results
96
 
97
 
98
  def main():
99
+ # Initialize trackio with full config
100
+ trackio.init(
101
+ project="arcade-reranker",
102
+ name=RUN_NAME,
103
+ space_id=SPACE_ID,
104
+ config={
105
+ "model": BASE_MODEL,
106
+ "dataset": DATASET_NAME,
107
+ "learning_rate": LEARNING_RATE,
108
+ "num_epochs": NUM_EPOCHS,
109
+ "batch_size": BATCH_SIZE,
110
+ "max_seq_length": MAX_SEQ_LENGTH,
111
+ }
112
+ )
113
+
114
  logger.info(f"Configuration:")
115
  logger.info(f" Dataset: {DATASET_NAME}")
116
  logger.info(f" Base model: {BASE_MODEL}")
117
  logger.info(f" Epochs: {NUM_EPOCHS}")
118
  logger.info(f" Run name: {RUN_NAME}")
119
+ logger.info(f" Trackio space: {SPACE_ID}")
120
 
121
  model = CrossEncoder(BASE_MODEL, max_length=MAX_SEQ_LENGTH)
122
 
 
124
  dataset = load_dataset(DATASET_NAME, split="train")
125
 
126
  # Log dataset composition
127
+ type_counts = defaultdict(int)
128
  if "type" in dataset.column_names:
 
129
  for item in dataset:
130
  type_counts[item["type"]] += 1
131
  logger.info(f"Dataset composition: {dict(type_counts)}")
132
 
133
+ # Log to trackio
134
+ for content_type, count in type_counts.items():
135
+ trackio.log({f"data/{content_type}_count": count})
136
+
137
+ trackio.log({"data/total_examples": len(dataset)})
138
  logger.info(f"Total examples: {len(dataset)}")
139
 
140
  # Rename columns for CrossEncoderTrainer
 
150
  train_dataset = splits["train"]
151
  eval_dataset = splits["test"]
152
 
153
+ trackio.log({
154
+ "data/train_size": len(train_dataset),
155
+ "data/eval_size": len(eval_dataset),
156
+ })
157
  logger.info(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
158
 
159
+ # Evaluate base model before training
160
+ logger.info("Evaluating base model on eval set...")
161
+ base_metrics = evaluate_correlation(model, eval_dataset)
162
+ for key, value in base_metrics.items():
163
+ trackio.log({f"base_model/{key}": value})
164
+ logger.info(f"Base model metrics: {base_metrics}")
165
+
166
  # NanoBEIR for benchmark comparison
167
  evaluator = CrossEncoderNanoBEIREvaluator(
168
  dataset_names=["msmarco", "nfcorpus", "nq"],
 
205
  logger.info("Starting training...")
206
  trainer.train()
207
 
208
+ # Final evaluation on our eval set
209
+ logger.info("Running final correlation evaluation...")
210
+ final_metrics = evaluate_correlation(model, eval_dataset)
211
+ for key, value in final_metrics.items():
212
+ trackio.log({f"final/{key}": value})
213
+ logger.info(f"Final metrics: {final_metrics}")
214
+
215
+ # Per-type evaluation
216
+ logger.info("Evaluating by content type...")
217
+ type_metrics = evaluate_by_type(model, eval_dataset)
218
+ for key, value in type_metrics.items():
219
+ trackio.log({f"final/by_type/{key}": value})
220
+ logger.info(f"Per-type metrics: {type_metrics}")
221
+
222
+ # Log improvement
223
+ trackio.log({
224
+ "improvement/spearman_delta": final_metrics["spearman"] - base_metrics["spearman"],
225
+ "improvement/mae_delta": base_metrics["mae"] - final_metrics["mae"], # Lower is better
226
+ })
227
+
228
  logger.info(f"Pushing final model to {HUB_MODEL_ID}")
229
  model.push_to_hub(HUB_MODEL_ID)
230
+
231
+ trackio.finish()
232
  logger.info("Done!")
233
 
234