"""Save full model+tokenizer copies outside the rotating HF checkpoint window."""

from __future__ import annotations

from pathlib import Path

from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
from transformers.utils import logging

logger = logging.get_logger(__name__)


class PermanentCheckpointCallback(TrainerCallback):
    """On each multiple of ``every_n_steps``, call ``trainer.save_model`` into a fixed subfolder.

    Checkpoints are written to
    ``<args.output_dir>/<subdir>/checkpoint-<global_step>``. A step whose
    target directory already exists is skipped, so existing copies are never
    overwritten.

    The callback needs a reference to the trainer, which callback hooks do not
    receive; call :meth:`bind_trainer` after constructing the trainer.
    """

    def __init__(self, every_n_steps: int, subdir: str = "checkpoints/permanent") -> None:
        """Configure the save interval and destination.

        Args:
            every_n_steps: Save whenever ``global_step`` is a positive multiple
                of this value. A value ``<= 0`` disables the callback.
            subdir: Path, relative to ``args.output_dir``, under which the
                ``checkpoint-<step>`` folders are created.
        """
        self.every_n_steps = int(every_n_steps)
        self.subdir = subdir
        # Set later through bind_trainer(); saving is skipped while it is None.
        self._trainer = None

    def bind_trainer(self, trainer) -> None:
        """Attach the trainer whose ``save_model`` method performs the save."""
        self._trainer = trainer

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ) -> TrainerControl:
        """Save a permanent checkpoint when ``global_step`` hits the interval.

        Args:
            args: Training arguments; only ``output_dir`` is read.
            state: Trainer state; ``global_step`` and
                ``is_world_process_zero`` are read.
            control: Returned unchanged in every case.
            **kwargs: Ignored.

        Returns:
            The ``control`` object that was passed in.
        """
        step = state.global_step
        if self.every_n_steps <= 0 or step <= 0 or step % self.every_n_steps != 0:
            return control

        is_main = state.is_world_process_zero
        trainer = self._trainer
        if trainer is None:
            # Warn from one process only to avoid duplicated log lines.
            if is_main:
                logger.warning(
                    "PermanentCheckpointCallback: trainer not bound; skipping save at step %s",
                    step,
                )
            return control

        out = Path(args.output_dir) / self.subdir / f"checkpoint-{step}"
        # NOTE(review): with several processes this check assumes every rank
        # sees the same filesystem state for ``out`` - confirm for multi-node
        # runs that do not share storage.
        if out.exists():
            return control

        if is_main:
            out.parent.mkdir(parents=True, exist_ok=True)

        # Deliberately invoked on every process, not just rank 0:
        # Trainer.save_model restricts the actual write to the main process
        # itself, and with sharded weights (DeepSpeed ZeRO-3 / FSDP) it has to
        # gather the state dict collectively, so a rank-0-only call can hang.
        trainer.save_model(str(out))

        if is_main:
            logger.info("Permanent checkpoint: %s", out)
        return control