amanwithaplan committed on
Commit
e76af51
·
verified ·
1 Parent(s): a6f0d21

Add domain metrics callback + fix push_to_hub exist_ok

Browse files
Files changed (1) hide show
  1. train_reranker.py +37 -1
train_reranker.py CHANGED
@@ -29,6 +29,7 @@ from sentence_transformers.cross_encoder import (
29
  )
30
  from sentence_transformers.cross_encoder.evaluation import CrossEncoderNanoBEIREvaluator
31
  from scipy.stats import spearmanr, pearsonr
 
32
 
33
  logging.basicConfig(level=logging.INFO)
34
  logger = logging.getLogger(__name__)
@@ -68,6 +69,37 @@ def evaluate_correlation(model, eval_dataset):
68
  }
69
 
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  def evaluate_by_type(model, eval_dataset, type_column="type"):
72
  """Evaluate correlation per content type."""
73
  if type_column not in eval_dataset.column_names:
@@ -199,12 +231,16 @@ def main():
199
  run_name=RUN_NAME,
200
  )
201
 
 
 
 
202
  trainer = CrossEncoderTrainer(
203
  model=model,
204
  args=args,
205
  train_dataset=train_dataset,
206
  eval_dataset=eval_dataset,
207
  evaluator=evaluator,
 
208
  )
209
 
210
  logger.info("Starting training...")
@@ -231,7 +267,7 @@ def main():
231
  })
232
 
233
  logger.info(f"Pushing final model to {HUB_MODEL_ID}")
234
- model.push_to_hub(HUB_MODEL_ID)
235
 
236
  trackio.finish()
237
  logger.info("Done!")
 
29
  )
30
  from sentence_transformers.cross_encoder.evaluation import CrossEncoderNanoBEIREvaluator
31
  from scipy.stats import spearmanr, pearsonr
32
+ from transformers import TrainerCallback
33
 
34
  logging.basicConfig(level=logging.INFO)
35
  logger = logging.getLogger(__name__)
 
69
  }
70
 
71
 
72
class DomainEvalCallback(TrainerCallback):
    """Trainer callback that logs domain-specific correlation metrics.

    After every trainer evaluation, re-scores the full eval split with the
    cross-encoder and logs Spearman/Pearson correlations, MAE, and prediction
    distribution statistics to trackio under the ``domain/`` prefix.
    """

    def __init__(self, model, eval_dataset_full):
        """
        Args:
            model: the cross-encoder being trained; must expose ``predict``
                (assumed to return a numpy array — confirm against the
                sentence-transformers CrossEncoder API).
            eval_dataset_full: dataset with ``sentence1``, ``sentence2`` and
                ``label`` columns; scored in full on each evaluation.
        """
        self.model = model
        self.eval_dataset_full = eval_dataset_full

    def on_evaluate(self, args, state, control, **kwargs):
        """Run after each evaluation step: score the eval set and log metrics."""
        pairs = [(item["sentence1"], item["sentence2"]) for item in self.eval_dataset_full]
        labels = [item["label"] for item in self.eval_dataset_full]

        # Guard: an empty eval split would otherwise raise ZeroDivisionError
        # in the MAE computation below.
        if not labels:
            logger.warning("DomainEvalCallback: eval dataset is empty; skipping domain metrics")
            return

        predictions = self.model.predict(pairs, show_progress_bar=False)

        # Use `.statistic` uniformly for both scipy results (scipy >= 1.9);
        # the original mixed `.correlation` (spearmanr) and `.statistic`
        # (pearsonr) for no reason.
        spearman = float(spearmanr(predictions, labels).statistic)
        pearson_val = float(pearsonr(predictions, labels).statistic)
        mae = sum(abs(p - l) for p, l in zip(predictions, labels)) / len(labels)

        # Log to trackio (cast to plain floats so the logger never receives
        # numpy scalars).
        trackio.log({
            "domain/spearman": spearman,
            "domain/pearson": pearson_val,
            "domain/mae": float(mae),
            "domain/pred_mean": float(predictions.mean()),
            "domain/pred_std": float(predictions.std()),
        })

        # Lazy %-args: formatting only happens if INFO is enabled.
        logger.info(
            "Domain eval - Spearman: %.4f, Pearson: %.4f, MAE: %.4f",
            spearman, pearson_val, mae,
        )
103
  def evaluate_by_type(model, eval_dataset, type_column="type"):
104
  """Evaluate correlation per content type."""
105
  if type_column not in eval_dataset.column_names:
 
231
  run_name=RUN_NAME,
232
  )
233
 
234
+ # Custom callback to log domain-specific metrics during training
235
+ domain_callback = DomainEvalCallback(model, eval_dataset_full)
236
+
237
  trainer = CrossEncoderTrainer(
238
  model=model,
239
  args=args,
240
  train_dataset=train_dataset,
241
  eval_dataset=eval_dataset,
242
  evaluator=evaluator,
243
+ callbacks=[domain_callback],
244
  )
245
 
246
  logger.info("Starting training...")
 
267
  })
268
 
269
  logger.info(f"Pushing final model to {HUB_MODEL_ID}")
270
+ model.push_to_hub(HUB_MODEL_ID, exist_ok=True)
271
 
272
  trackio.finish()
273
  logger.info("Done!")