"""Training callbacks for monitoring and checkpointing.

Integrates with Weights & Biases and TensorBoard.
"""

from transformers import TrainerCallback, TrainerState, TrainerControl, TrainingArguments
from loguru import logger

try:
    import wandb
    HAS_WANDB = True
except ImportError:
    # wandb is optional; logging to it is skipped when unavailable.
    HAS_WANDB = False


class StyleMetricsCallback(TrainerCallback):
    """Logs style similarity metrics during evaluation."""

    def on_evaluate(self, args: TrainingArguments, state: TrainerState,
                    control: TrainerControl, **kwargs):
        """Log evaluation metrics to the console and, if active, to W&B.

        Args:
            args: Trainer arguments (unused, required by the callback API).
            state: Trainer state; ``state.global_step`` tags the log entries.
            control: Trainer control flow object (unused here).
            **kwargs: Expected to contain ``metrics`` — a dict of evaluation
                results passed by the HF Trainer.
        """
        metrics = kwargs.get("metrics", {})
        if not metrics:
            return

        logger.info(f"Evaluation metrics at step {state.global_step}:")
        for key, value in metrics.items():
            # Format floats to 4 decimal places; pass other types through as-is.
            logger.info(f"  {key}: {value:.4f}" if isinstance(value, float) else f"  {key}: {value}")

        # Mirror numeric metrics to W&B under an "eval/" namespace.
        if HAS_WANDB and wandb.run is not None:
            wandb.log(
                {f"eval/{k}": v for k, v in metrics.items() if isinstance(v, (int, float))},
                step=state.global_step,
            )


class EarlyStoppingOnStyleDrift(TrainerCallback):
    """Stops training if style similarity drops below threshold."""

    def __init__(self, min_style_similarity: float = 0.75, patience: int = 3):
        """Configure the drift detector.

        Args:
            min_style_similarity: Lower bound on ``eval_style_similarity``;
                evaluations below this count toward early stopping.
            patience: Number of consecutive below-threshold evaluations
                tolerated before training is stopped.
        """
        self.min_style_similarity = min_style_similarity
        self.best_style_sim = 0.0       # best similarity seen so far
        self.patience_counter = 0       # consecutive below-threshold evals
        self.patience = patience

    def on_evaluate(self, args: TrainingArguments, state: TrainerState,
                    control: TrainerControl, **kwargs):
        """Track style similarity and request a stop after repeated drift.

        Reads ``eval_style_similarity`` from the ``metrics`` kwarg; does
        nothing when the metric is absent. Sets
        ``control.should_training_stop`` once the similarity has been below
        the threshold for ``patience`` consecutive evaluations.
        """
        metrics = kwargs.get("metrics", {})
        style_sim = metrics.get("eval_style_similarity")
        if style_sim is None:
            return

        # A new best resets the streak even if still below threshold;
        # the threshold check below may immediately restart it at 1.
        if style_sim > self.best_style_sim:
            self.best_style_sim = style_sim
            self.patience_counter = 0

        if style_sim < self.min_style_similarity:
            self.patience_counter += 1
            logger.warning(
                f"Style similarity {style_sim:.4f} below threshold {self.min_style_similarity}. "
                f"Patience: {self.patience_counter}/{self.patience}"
            )
            if self.patience_counter >= self.patience:
                logger.error(
                    f"Early stopping: style similarity consistently below {self.min_style_similarity}"
                )
                control.should_training_stop = True
        else:
            self.patience_counter = 0