neuralese_temp / src /hackable /permanent_checkpoint_callback.py
psidharth567's picture
Export neuralese codebase (cache and .env excluded).
dbc69f3
"""Save full model+tokenizer copies outside the rotating HF checkpoint window."""
from __future__ import annotations
from pathlib import Path
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
from transformers.utils import logging
logger = logging.get_logger(__name__)
class PermanentCheckpointCallback(TrainerCallback):
    """Periodically snapshot the model into a folder of its own.

    Whenever ``state.global_step`` is a positive multiple of ``every_n_steps``
    (and this is the world-zero process), the bound trainer's ``save_model``
    is called with ``<output_dir>/<subdir>/checkpoint-<step>`` as the target.
    A non-positive ``every_n_steps`` turns the callback into a no-op. The
    trainer has to be attached with :meth:`bind_trainer`; until then, due
    saves are skipped with a warning.
    """

    def __init__(self, every_n_steps: int, subdir: str = "checkpoints/permanent") -> None:
        """Record the save interval (coerced to ``int``) and the target subfolder."""
        self.every_n_steps = int(every_n_steps)
        self._trainer = None  # filled in later by bind_trainer()
        self.subdir = subdir

    def bind_trainer(self, trainer) -> None:
        """Remember the trainer whose ``save_model`` will be used for snapshots."""
        self._trainer = trainer

    def _is_save_step(self, state: TrainerState) -> bool:
        """Return True when the current step should trigger a permanent save."""
        interval = self.every_n_steps
        # Disabled callback or a non-zero rank: never save.
        if interval <= 0 or not state.is_world_process_zero:
            return False
        step = state.global_step
        return step > 0 and step % interval == 0

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ) -> TrainerControl:
        """Save a permanent checkpoint if this step is due; always return ``control`` unchanged."""
        if not self._is_save_step(state):
            return control

        step = state.global_step
        trainer = self._trainer
        if trainer is None:
            logger.warning("PermanentCheckpointCallback: trainer not bound; skipping save at step %s", step)
            return control

        target_dir = Path(args.output_dir, self.subdir, f"checkpoint-{step}")
        # An existing directory for this step is left alone rather than overwritten.
        if not target_dir.exists():
            target_dir.parent.mkdir(parents=True, exist_ok=True)
            trainer.save_model(str(target_dir))
            logger.info("Permanent checkpoint: %s", target_dir)
        return control