| """Save full model+tokenizer copies outside the rotating HF checkpoint window.""" |
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
|
|
| from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments |
| from transformers.utils import logging |
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
class PermanentCheckpointCallback(TrainerCallback):
    """Trainer callback that writes extra model saves into a dedicated subfolder.

    At the end of every step whose global step number is a positive multiple
    of ``every_n_steps``, the bound trainer's ``save_model`` is called with
    ``<args.output_dir>/<subdir>/checkpoint-<global_step>`` as its target.
    A trainer must be attached with :meth:`bind_trainer` beforehand; without
    one the callback only logs a warning and skips the save.

    NOTE(review): the save is triggered only where
    ``state.is_world_process_zero`` is true - confirm this is sufficient for
    the distributed setups this is used with.
    """

    def __init__(self, every_n_steps: int, subdir: str = "checkpoints/permanent") -> None:
        """Record the save interval and the destination subfolder.

        Args:
            every_n_steps: Interval, in global steps, between saves. Coerced
                with ``int``; a value ``<= 0`` makes ``on_step_end`` a no-op.
            subdir: Path joined onto ``args.output_dir`` to form the parent
                folder of the per-step checkpoint directories.
        """
        interval = int(every_n_steps)
        self.every_n_steps = interval
        self.subdir = subdir
        # Populated later through bind_trainer().
        self._trainer = None

    def bind_trainer(self, trainer) -> None:
        """Attach the trainer whose ``save_model`` is used for the saves."""
        self._trainer = trainer

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ) -> TrainerControl:
        """Save a checkpoint when the current global step is on the schedule.

        Args:
            args: Training arguments; only ``output_dir`` is read.
            state: Trainer state; ``is_world_process_zero`` and
                ``global_step`` are read.
            control: Control object, returned unchanged on every path.
            **kwargs: Ignored.

        Returns:
            The ``control`` object that was passed in.
        """
        # The interval check comes first so the modulo below never divides by zero.
        if self.every_n_steps <= 0 or not state.is_world_process_zero:
            return control

        on_schedule = state.global_step > 0 and state.global_step % self.every_n_steps == 0
        if not on_schedule:
            return control

        trainer = self._trainer
        if trainer is None:
            logger.warning("PermanentCheckpointCallback: trainer not bound; skipping save at step %s", state.global_step)
            return control

        target_dir = Path(args.output_dir, self.subdir, f"checkpoint-{state.global_step}")
        # An already existing directory for this step is left untouched.
        if not target_dir.exists():
            target_dir.parent.mkdir(parents=True, exist_ok=True)
            trainer.save_model(str(target_dir))
            logger.info("Permanent checkpoint: %s", target_dir)
        return control
|
|