"""
Author : Fabien FURFARO
"""

from typing import Optional, Union

from transformers import PreTrainedModel, TrainerCallback

from .modeling_tptt import LiZAttention
|
|
class LiZACallback(TrainerCallback):
    """
    TrainerCallback that schedules ``mag_weight`` or toggles linear attention
    on every ``LiZAttention`` module of the wrapped model during training.

    Modes:
        - "constant": keep mag_weight fixed at final_weight.
        - "gradual": linear interpolation from initial_weight to final_weight
          over the first `transition_step` steps, then hold final_weight.
        - "cyclic": cycle mag_weight through weight_list, one entry per step.
        - "switch": alternately disable/enable linear attention every
          `switch_period` steps.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        mode: str = "gradual",
        initial_weight: float = 0.0,
        final_weight: float = 0.5,
        transition_step: Union[int, tuple, list] = 100,
        weight_list: Optional[list] = None,
        switch_period: int = 1,
    ):
        """
        Args:
            model: model whose LiZAttention submodules are updated in place.
            mode: one of "constant", "gradual", "cyclic", "switch".
            initial_weight: starting mag_weight (scalar, or 1-element
                tuple/list from a hyper-parameter sweep).
            final_weight: target mag_weight (same wrapping accepted).
            transition_step: number of steps for the "gradual" ramp.
            weight_list: explicit weights for "cyclic" mode; defaults to
                [initial_weight, final_weight] when omitted or empty.
            switch_period: steps between on/off flips in "switch" mode.
        """
        self.model = model
        self.mode = mode

        # Sweep configs sometimes wrap scalars in a 1-element tuple/list;
        # unwrap before casting to float.
        if isinstance(initial_weight, (tuple, list)):
            initial_weight = initial_weight[0]
        if isinstance(final_weight, (tuple, list)):
            final_weight = final_weight[0]
        self.initial_weight = float(initial_weight)
        self.final_weight = float(final_weight)

        self.transition_step = ensure_int(transition_step)
        if self.mode == "constant":
            # Constant mode pins the weight at final_weight from step 0.
            self.initial_weight = self.final_weight

        # A falsy weight_list (None or []) falls back to the two endpoint
        # weights; an empty list would otherwise cause a modulo-by-zero in
        # "cyclic" mode.
        if weight_list:
            self.weight_list = [
                float(w[0]) if isinstance(w, (tuple, list)) else float(w)
                for w in weight_list
            ]
        else:
            self.weight_list = [self.initial_weight, self.final_weight]

        # Clamp to >= 1 so "switch" mode never divides by zero.
        self.switch_period = max(1, int(switch_period))

    def _set_mag_weight(self, weight: float) -> None:
        """Propagate `weight` to every LiZAttention module of the model."""
        for _, module in self.model.named_modules():
            if isinstance(module, LiZAttention):
                module.mag_weight = weight

    def on_step_end(self, args, state, control, **kwargs):
        """Update the attention modules after each optimizer step."""
        current_step = ensure_int(state.global_step)
        transition_step = ensure_int(self.transition_step)

        if self.mode == "constant":
            self._set_mag_weight(self.final_weight)
        elif self.mode == "gradual":
            # transition_step <= 0 would divide by zero; treat it as an
            # instantaneous ramp and apply final_weight immediately.
            if transition_step > 0 and current_step <= transition_step:
                progress = current_step / transition_step
                weight = (
                    self.initial_weight
                    + (self.final_weight - self.initial_weight) * progress
                )
            else:
                weight = self.final_weight
            self._set_mag_weight(weight)
        elif self.mode == "cyclic":
            self._set_mag_weight(
                self.weight_list[current_step % len(self.weight_list)]
            )
        elif self.mode == "switch":
            # Even half-periods disable linear attention, odd ones re-enable.
            disable = (current_step // self.switch_period) % 2 == 0
            for _, module in self.model.named_modules():
                if isinstance(module, LiZAttention):
                    module.disable_linear_attn = disable
        else:
            raise ValueError(f"Unknown mode: {self.mode}")

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Mirror the first LiZAttention module's state into the trainer logs."""
        mag_weight = None
        disable_linear_attn = None
        # All modules are kept in sync by on_step_end, so the first
        # LiZAttention found is representative.
        for _, module in self.model.named_modules():
            if isinstance(module, LiZAttention):
                mag_weight = getattr(module, "mag_weight", None)
                disable_linear_attn = getattr(module, "disable_linear_attn", None)
                break
        if mag_weight is not None and logs is not None:
            logs["mag_weight"] = float(mag_weight)
        if disable_linear_attn is not None and logs is not None:
            # NOTE(review): the flag is negated here, so this key actually
            # logs "linear attention enabled" despite its name — confirm
            # intent before renaming the key (it may appear in dashboards).
            logs["disable_linear_attn"] = not bool(disable_linear_attn)
|
|
|
|
def ensure_int(value: Union[int, tuple, list]) -> int:
    """Coerce *value* to a plain Python ``int``.

    Accepts a bare number, a 1-element tuple/list (hyper-parameter sweeps
    often wrap scalars), or a 0-d tensor/ndarray exposing ``.item()``.

    Args:
        value: the value (possibly wrapped) to convert.

    Returns:
        The value as a built-in ``int``.
    """
    if isinstance(value, (tuple, list)):
        value = value[0]
    if hasattr(value, "item"):
        # 0-d numpy / torch scalars unwrap to a Python number via .item().
        value = value.item()
    # Bug fix: a plain float previously slipped through unconverted despite
    # the documented contract; always cast so callers truly get an int.
    return int(value)
|
|
|
|
class SaveBestModelCallback(TrainerCallback):
    """Request a checkpoint save whenever evaluation loss reaches a new best."""

    def __init__(self):
        # Lowest eval_loss seen so far; any finite loss beats +inf.
        self.best_metric = float("inf")

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """After each evaluation, flag a save only on strict improvement."""
        if metrics is None or "eval_loss" not in metrics:
            return
        loss = metrics["eval_loss"]
        if loss < self.best_metric:
            self.best_metric = loss
            control.should_save = True
        else:
            control.should_save = False
|
|