File size: 2,525 Bytes
12fd5f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""
Training callbacks for monitoring and checkpointing.
Integrates with Weights & Biases and TensorBoard.
"""

from transformers import TrainerCallback, TrainerState, TrainerControl, TrainingArguments
from loguru import logger

try:
    import wandb
    HAS_WANDB = True
except ImportError:
    HAS_WANDB = False


class StyleMetricsCallback(TrainerCallback):
    """Log evaluation metrics (such as style-similarity scores) after each evaluation.

    Every metric is written to the loguru logger. Numeric metrics are also
    sent to Weights & Biases when ``wandb`` imported successfully and a run
    is active.
    """

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Log the ``metrics`` dict supplied by the Trainer for this evaluation.

        Args:
            args: Training arguments (unused).
            state: Trainer state; ``state.global_step`` labels the output.
            control: Trainer control flags (unused, not modified).
            **kwargs: Trainer-supplied extras; only ``metrics`` is read.
        """
        # ``or {}`` also covers an explicit ``metrics=None``.
        metrics = kwargs.get("metrics") or {}
        if not metrics:
            return

        logger.info(f"Evaluation metrics at step {state.global_step}:")
        for key, value in metrics.items():
            logger.info(f"  {key}: {value:.4f}" if isinstance(value, float) else f"  {key}: {value}")

        # Log to W&B if available.
        if HAS_WANDB and wandb.run is not None:
            payload = {f"eval/{k}": v for k, v in metrics.items() if isinstance(v, (int, float))}
            # Record the trainer step as a metric rather than passing ``step=``.
            # wandb requires an explicit ``step`` to increase monotonically.
            # Other loggers (e.g. transformers' WandbCallback) call wandb.log
            # without ``step``, so wandb's internal counter can run ahead of
            # global_step, and an explicit, smaller step would be dropped.
            payload["train/global_step"] = state.global_step
            wandb.log(payload)


class EarlyStoppingOnStyleDrift(TrainerCallback):
    """Stop training when style similarity stays below a threshold.

    Reads ``eval_style_similarity`` from each evaluation's metrics.
    - Every evaluation scoring below ``min_style_similarity`` increments a
      counter (a NaN score also counts as below).
    - Any evaluation at or above the threshold resets the counter.
    - Once the counter reaches ``patience``, ``control.should_training_stop``
      is set.
    - Evaluations that do not report ``eval_style_similarity`` leave the
      counter unchanged.
    """

    def __init__(self, min_style_similarity: float = 0.75, patience: int = 3):
        """Configure the threshold and how many consecutive low scores to tolerate.

        Args:
            min_style_similarity: Scores strictly below this count as low.
            patience: Consecutive low evaluations that trigger a stop.

        Raises:
            ValueError: If ``patience`` is less than 1.
        """
        if patience < 1:
            raise ValueError(f"patience must be >= 1, got {patience}")
        self.min_style_similarity = min_style_similarity
        self.best_style_sim = 0.0  # highest score seen so far (informational only)
        self.patience_counter = 0  # consecutive evaluations below the threshold
        self.patience = patience

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Update the low-score streak and request a stop once it reaches ``patience``.

        Args:
            args: Training arguments (unused).
            state: Trainer state (unused).
            control: Trainer control flags; ``should_training_stop`` may be set.
            **kwargs: Trainer-supplied extras; only ``metrics`` is read.
        """
        # ``or {}`` also covers an explicit ``metrics=None``.
        metrics = kwargs.get("metrics") or {}
        style_sim = metrics.get("eval_style_similarity")
        if style_sim is None:
            return

        if style_sim > self.best_style_sim:
            self.best_style_sim = style_sim

        # A new best score must NOT reset the streak: scores that keep
        # improving but stay below the threshold are still "consecutive low
        # evaluations". Written as ``not >=`` so a NaN score counts as low
        # instead of falling into the reset branch.
        if not style_sim >= self.min_style_similarity:
            self.patience_counter += 1
            logger.warning(
                f"Style similarity {style_sim:.4f} below threshold {self.min_style_similarity}. "
                f"Patience: {self.patience_counter}/{self.patience}"
            )
            if self.patience_counter >= self.patience:
                logger.error(
                    f"Early stopping: style similarity consistently below {self.min_style_similarity}"
                )
                control.should_training_stop = True
        else:
            self.patience_counter = 0