File size: 1,660 Bytes
dbc69f3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 | """Save full model+tokenizer copies outside the rotating HF checkpoint window."""
from __future__ import annotations
from pathlib import Path
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
from transformers.utils import logging
logger = logging.get_logger(__name__)
class PermanentCheckpointCallback(TrainerCallback):
    """Periodically save a ``trainer.save_model`` copy into a fixed subfolder.

    Each time ``state.global_step`` reaches a positive multiple of
    ``every_n_steps``, the bound Trainer's ``save_model`` is called with
    ``<args.output_dir>/<subdir>/checkpoint-<global_step>``. These copies are
    intended to live outside the Trainer's rotating checkpoint window. An
    existing copy for a given step is never overwritten.

    Args:
        every_n_steps: Save interval in global steps (coerced with ``int()``).
            Values ``<= 0`` disable the callback.
        subdir: Folder, relative to ``args.output_dir``, that receives the
            ``checkpoint-<step>`` directories.

    The Trainer must be attached with :meth:`bind_trainer`. Until then, every
    scheduled save is skipped with a warning.
    """

    def __init__(self, every_n_steps: int, subdir: str = "checkpoints/permanent") -> None:
        self.every_n_steps = int(every_n_steps)
        self.subdir = subdir
        # Filled in by bind_trainer(); callback events do not hand us the Trainer.
        self._trainer = None

    def bind_trainer(self, trainer) -> None:
        """Attach the Trainer whose ``save_model`` will be used for saving.

        Callback events receive objects such as the model and optimizer via
        ``kwargs``, but not the Trainer itself. It therefore has to be supplied
        explicitly, presumably right after the Trainer is constructed with this
        callback registered.
        """
        self._trainer = trainer

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ) -> TrainerControl:
        """Save a permanent copy when the global step hits the configured interval.

        This runs on *every* process, not only world process zero.
        ``Trainer.save_model`` may perform collective operations before writing:
        gathering a sharded state dict under FSDP or DeepSpeed ZeRO-3, or a TPU
        rendezvous. It also restricts the actual file writes to the process(es)
        that should save. Calling it from rank 0 alone would hang rank 0 in
        those setups.

        Returns:
            ``control`` unchanged; this callback never alters the training flow.
        """
        if self.every_n_steps <= 0:
            return control
        step = state.global_step
        # global_step is identical on every rank, so all processes take the same
        # branch here: either all of them go on to save, or none do.
        if step <= 0 or step % self.every_n_steps != 0:
            return control

        is_main = state.is_world_process_zero
        trainer = self._trainer
        if trainer is None:
            if is_main:
                logger.warning(
                    "PermanentCheckpointCallback: trainer not bound; skipping save at step %s",
                    step,
                )
            return control

        out = Path(args.output_dir) / self.subdir / f"checkpoint-{step}"
        already_saved = out.exists()
        # Every rank must finish its exists() check before any rank can start
        # creating ``out``. Otherwise a slow rank could observe the directory
        # being written, skip the save, and leave the other ranks blocked inside
        # save_model's collectives.
        accelerator = getattr(trainer, "accelerator", None)
        if accelerator is not None:
            accelerator.wait_for_everyone()
        if already_saved:
            # Never overwrite a permanent copy (e.g. a resumed run reaching the
            # same step a second time).
            # NOTE(review): a directory left by an interrupted save is also
            # treated as complete. Without a shared filesystem, ranks on
            # different nodes may disagree here. Confirm neither case matters
            # for this deployment.
            if is_main:
                logger.info("PermanentCheckpointCallback: %s already exists; not overwriting", out)
            return control

        if is_main:
            out.parent.mkdir(parents=True, exist_ok=True)
        trainer.save_model(str(out))
        if is_main:
            logger.info("Permanent checkpoint: %s", out)
        return control
|