rewrite / src /training /callbacks.py
morpheuslord's picture
Add files using upload-large-folder tool
12fd5f2 verified
"""
Training callbacks for monitoring and checkpointing.
Integrates with Weights & Biases and TensorBoard.
"""
from transformers import TrainerCallback, TrainerState, TrainerControl, TrainingArguments
from loguru import logger
try:
import wandb
HAS_WANDB = True
except ImportError:
HAS_WANDB = False
class StyleMetricsCallback(TrainerCallback):
    """Report evaluation metrics to the log and, when a run is active, to W&B."""

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Log every entry of ``kwargs["metrics"]`` for the current step.

        Each metric is written through ``logger.info`` (floats with four
        decimals, everything else via its default formatting). When
        ``wandb`` was importable and ``wandb.run`` is set, the int/float
        metrics are also sent to ``wandb.log`` under an ``eval/`` prefix,
        keyed to ``state.global_step``.

        Does nothing when no metrics (or an empty mapping) are supplied.
        """
        eval_metrics = kwargs.get("metrics", {})
        if not eval_metrics:
            return

        step = state.global_step
        logger.info(f"Evaluation metrics at step {step}:")
        for name, score in eval_metrics.items():
            if isinstance(score, float):
                line = f" {name}: {score:.4f}"
            else:
                line = f" {name}: {score}"
            logger.info(line)

        # Mirror the numeric metrics to W&B only when a run is in progress.
        if not HAS_WANDB or wandb.run is None:
            return

        payload = {}
        for name, score in eval_metrics.items():
            if isinstance(score, (int, float)):
                payload[f"eval/{name}"] = score
        wandb.log(payload, step=step)
class EarlyStoppingOnStyleDrift(TrainerCallback):
    """Stops training when style similarity stays below a threshold.

    After each evaluation the ``eval_style_similarity`` metric is compared
    with ``min_style_similarity``. Every evaluation below the threshold
    increments a counter; any evaluation at or above it resets the counter.
    Once the counter reaches ``patience``, ``control.should_training_stop``
    is set to ``True``.

    Args:
        min_style_similarity: Lowest acceptable value of
            ``eval_style_similarity``.
        patience: Number of consecutive below-threshold evaluations
            tolerated before training is stopped.
    """

    def __init__(self, min_style_similarity: float = 0.75, patience: int = 3):
        self.min_style_similarity = min_style_similarity
        # Highest similarity observed so far; kept for inspection only.
        self.best_style_sim = 0.0
        # Count of consecutive evaluations that fell below the threshold.
        self.patience_counter = 0
        self.patience = patience

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Update the low-similarity streak and request a stop if it is too long.

        Evaluations that do not report ``eval_style_similarity`` are ignored
        and leave the streak unchanged.
        """
        # `or {}` also covers an explicit metrics=None, which would otherwise
        # fail on the .get() below.
        metrics = kwargs.get("metrics") or {}
        style_sim = metrics.get("eval_style_similarity")
        if style_sim is None:
            return

        # Record a new best, but do NOT reset the streak here: a value can be
        # the best seen so far and still be below the threshold, and resetting
        # in that case would keep the counter from ever reaching `patience`.
        if style_sim > self.best_style_sim:
            self.best_style_sim = style_sim

        if style_sim >= self.min_style_similarity:
            # Back at or above the threshold: the streak is broken.
            self.patience_counter = 0
            return

        self.patience_counter += 1
        logger.warning(
            f"Style similarity {style_sim:.4f} below threshold {self.min_style_similarity}. "
            f"Patience: {self.patience_counter}/{self.patience}"
        )
        if self.patience_counter >= self.patience:
            logger.error(
                f"Early stopping: style similarity consistently below {self.min_style_similarity}"
            )
            control.should_training_stop = True