"""
Training callbacks for monitoring and checkpointing.
Integrates with Weights & Biases and TensorBoard.
"""
from transformers import TrainerCallback, TrainerState, TrainerControl, TrainingArguments
from loguru import logger
try:
import wandb
HAS_WANDB = True
except ImportError:
HAS_WANDB = False
class StyleMetricsCallback(TrainerCallback):
    """Reports evaluation metrics to the log and, when a run is active, to W&B."""

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Log every metric from the ``metrics`` keyword argument.

        Float values are written with four decimals; everything else is
        written as-is. When ``wandb`` was imported successfully and
        ``wandb.run`` is set, the int/float metrics are also sent to W&B
        under an ``eval/`` prefix at ``state.global_step``.

        Does nothing when ``metrics`` is absent or empty. ``args`` and
        ``control`` are not used.
        """
        eval_metrics = kwargs.get("metrics", {})
        if not eval_metrics:
            return

        step = state.global_step
        logger.info(f"Evaluation metrics at step {step}:")
        for name, val in eval_metrics.items():
            if isinstance(val, float):
                line = f" {name}: {val:.4f}"
            else:
                line = f" {name}: {val}"
            logger.info(line)

        # Mirror to W&B only when the package is installed and a run is open.
        if not HAS_WANDB or wandb.run is None:
            return

        numeric_metrics = {}
        for name, val in eval_metrics.items():
            if isinstance(val, (int, float)):
                numeric_metrics[f"eval/{name}"] = val
        wandb.log(numeric_metrics, step=step)
class EarlyStoppingOnStyleDrift(TrainerCallback):
    """Stops training if style similarity drops below threshold.

    Every evaluation that reports ``eval_style_similarity`` below
    ``min_style_similarity`` increments ``patience_counter``; once the
    counter reaches ``patience``, ``control.should_training_stop`` is set.
    An evaluation at or above the threshold resets the counter, and so does
    any evaluation that sets a new best value.
    """

    def __init__(self, min_style_similarity: float = 0.75, patience: int = 3):
        """Configure the threshold and the tolerated number of low evaluations.

        Args:
            min_style_similarity: Values strictly below this count as low.
            patience: Counter value at which training is stopped
                (default 3, the previously hard-coded value).
        """
        self.min_style_similarity = min_style_similarity
        self.best_style_sim = 0.0  # highest similarity observed so far
        self.patience_counter = 0  # low evaluations since the last reset
        self.patience = patience

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Update the patience counter from ``metrics["eval_style_similarity"]``.

        Does nothing when the metric (or the ``metrics`` argument itself) is
        missing. ``args`` and ``state`` are not used.
        """
        # `or {}` also covers an explicit metrics=None, which the plain
        # default of kwargs.get would not.
        metrics = kwargs.get("metrics") or {}
        style_sim = metrics.get("eval_style_similarity")
        if style_sim is None:
            return

        # A new best resets the counter *before* the threshold check, so a
        # below-threshold value that is still a new best leaves it at 1.
        if style_sim > self.best_style_sim:
            self.best_style_sim = style_sim
            self.patience_counter = 0

        if style_sim < self.min_style_similarity:
            self.patience_counter += 1
            logger.warning(
                f"Style similarity {style_sim:.4f} below threshold {self.min_style_similarity}. "
                f"Patience: {self.patience_counter}/{self.patience}"
            )
            if self.patience_counter >= self.patience:
                logger.error(
                    f"Early stopping: style similarity consistently below {self.min_style_similarity}"
                )
                control.should_training_stop = True
        else:
            # NOTE(review): a NaN value fails both comparisons and therefore
            # lands here, resetting the counter — confirm that is acceptable.
            self.patience_counter = 0
|