Add domain metrics callback + fix push_to_hub exist_ok
Browse files- train_reranker.py +37 -1
train_reranker.py
CHANGED
|
@@ -29,6 +29,7 @@ from sentence_transformers.cross_encoder import (
|
|
| 29 |
)
|
| 30 |
from sentence_transformers.cross_encoder.evaluation import CrossEncoderNanoBEIREvaluator
|
| 31 |
from scipy.stats import spearmanr, pearsonr
|
|
|
|
| 32 |
|
| 33 |
logging.basicConfig(level=logging.INFO)
|
| 34 |
logger = logging.getLogger(__name__)
|
|
@@ -68,6 +69,37 @@ def evaluate_correlation(model, eval_dataset):
|
|
| 68 |
}
|
| 69 |
|
| 70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
def evaluate_by_type(model, eval_dataset, type_column="type"):
|
| 72 |
"""Evaluate correlation per content type."""
|
| 73 |
if type_column not in eval_dataset.column_names:
|
|
@@ -199,12 +231,16 @@ def main():
|
|
| 199 |
run_name=RUN_NAME,
|
| 200 |
)
|
| 201 |
|
|
|
|
|
|
|
|
|
|
| 202 |
trainer = CrossEncoderTrainer(
|
| 203 |
model=model,
|
| 204 |
args=args,
|
| 205 |
train_dataset=train_dataset,
|
| 206 |
eval_dataset=eval_dataset,
|
| 207 |
evaluator=evaluator,
|
|
|
|
| 208 |
)
|
| 209 |
|
| 210 |
logger.info("Starting training...")
|
|
@@ -231,7 +267,7 @@ def main():
|
|
| 231 |
})
|
| 232 |
|
| 233 |
logger.info(f"Pushing final model to {HUB_MODEL_ID}")
|
| 234 |
-
model.push_to_hub(HUB_MODEL_ID)
|
| 235 |
|
| 236 |
trackio.finish()
|
| 237 |
logger.info("Done!")
|
|
|
|
| 29 |
)
|
| 30 |
from sentence_transformers.cross_encoder.evaluation import CrossEncoderNanoBEIREvaluator
|
| 31 |
from scipy.stats import spearmanr, pearsonr
|
| 32 |
+
from transformers import TrainerCallback
|
| 33 |
|
| 34 |
logging.basicConfig(level=logging.INFO)
|
| 35 |
logger = logging.getLogger(__name__)
|
|
|
|
| 69 |
}
|
| 70 |
|
| 71 |
|
| 72 |
+
class DomainEvalCallback(TrainerCallback):
    """Callback to log our domain-specific correlation metrics during training.

    On every evaluation event the wrapped model scores the full evaluation
    set, and Spearman/Pearson correlation, mean absolute error and the
    mean/std of the predictions are sent to trackio under ``domain/*`` keys.
    """

    def __init__(self, model, eval_dataset_full):
        """Store the model and materialize the evaluation inputs.

        Args:
            model: Object exposing ``predict(pairs, show_progress_bar=...)``.
            eval_dataset_full: Iterable of mappings providing the
                ``"sentence1"``, ``"sentence2"`` and ``"label"`` keys.

        Raises:
            KeyError: If an item lacks one of the required keys. Raised here,
                before training starts, rather than at the first evaluation.
        """
        self.model = model
        self.eval_dataset_full = eval_dataset_full

        # The evaluation set is fixed for the lifetime of the callback, so
        # the model inputs are built once (single pass) instead of being
        # rebuilt with two passes on every evaluation.
        self._pairs = []
        self._labels = []
        for item in eval_dataset_full:
            self._pairs.append((item["sentence1"], item["sentence2"]))
            self._labels.append(item["label"])

    def on_evaluate(self, args, state, control, **kwargs):
        """Run after each evaluation step."""
        # Nothing to correlate against; also avoids dividing by zero below.
        if not self._labels:
            logger.warning("Domain eval skipped: evaluation dataset is empty")
            return

        labels = self._labels
        predictions = self.model.predict(self._pairs, show_progress_bar=False)

        # Get correlation metrics (cast to plain floats so every logged
        # value has the same type as the other metrics below).
        spearman = float(spearmanr(predictions, labels).correlation)
        pearson_val = float(pearsonr(predictions, labels).statistic)
        mae = float(sum(abs(p - l) for p, l in zip(predictions, labels)) / len(labels))

        # Log to trackio
        trackio.log({
            "domain/spearman": spearman,
            "domain/pearson": pearson_val,
            "domain/mae": mae,
            "domain/pred_mean": float(predictions.mean()),
            "domain/pred_std": float(predictions.std()),
        })

        logger.info(f"Domain eval - Spearman: {spearman:.4f}, Pearson: {pearson_val:.4f}, MAE: {mae:.4f}")
|
| 101 |
+
|
| 102 |
+
|
| 103 |
def evaluate_by_type(model, eval_dataset, type_column="type"):
|
| 104 |
"""Evaluate correlation per content type."""
|
| 105 |
if type_column not in eval_dataset.column_names:
|
|
|
|
| 231 |
run_name=RUN_NAME,
|
| 232 |
)
|
| 233 |
|
| 234 |
+
# Custom callback to log domain-specific metrics during training
|
| 235 |
+
domain_callback = DomainEvalCallback(model, eval_dataset_full)
|
| 236 |
+
|
| 237 |
trainer = CrossEncoderTrainer(
|
| 238 |
model=model,
|
| 239 |
args=args,
|
| 240 |
train_dataset=train_dataset,
|
| 241 |
eval_dataset=eval_dataset,
|
| 242 |
evaluator=evaluator,
|
| 243 |
+
callbacks=[domain_callback],
|
| 244 |
)
|
| 245 |
|
| 246 |
logger.info("Starting training...")
|
|
|
|
| 267 |
})
|
| 268 |
|
| 269 |
logger.info(f"Pushing final model to {HUB_MODEL_ID}")
|
| 270 |
+
model.push_to_hub(HUB_MODEL_ID, exist_ok=True)
|
| 271 |
|
| 272 |
trackio.finish()
|
| 273 |
logger.info("Done!")
|