File size: 1,660 Bytes
dbc69f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""Save full model+tokenizer copies outside the rotating HF checkpoint window."""

from __future__ import annotations

from pathlib import Path

from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
from transformers.utils import logging

# Module-level logger created through transformers' logging helper.
logger = logging.get_logger(__name__)


class PermanentCheckpointCallback(TrainerCallback):
    """On each multiple of ``every_n_steps``, call ``trainer.save_model`` into a fixed subfolder.

    Copies are written to ``<args.output_dir>/<subdir>/checkpoint-<global_step>``.
    This callback never deletes them. If the target folder for a step already
    exists, the save is skipped rather than overwritten.

    The trainer must be attached with :meth:`bind_trainer`; until then, saves are
    skipped with a warning.

    The callback is expected to be registered on every process.
    ``Trainer.save_model`` has to be entered by all ranks, because gathering the
    state dict is collective under FSDP, DeepSpeed ZeRO-3, etc. The method itself
    only writes files from the saving process(es).

    Args:
        every_n_steps: Save interval in ``state.global_step`` units; ``<= 0``
            disables saving.
        subdir: Folder, relative to ``args.output_dir``, that holds the copies.
    """

    def __init__(self, every_n_steps: int, subdir: str = "checkpoints/permanent") -> None:
        self.every_n_steps = int(every_n_steps)
        self.subdir = subdir
        self._trainer = None  # set by bind_trainer(); used to call save_model

    def bind_trainer(self, trainer) -> None:
        """Attach the trainer whose ``save_model`` will be used for permanent copies."""
        self._trainer = trainer

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ) -> TrainerControl:
        """Save a permanent model copy when ``state.global_step`` hits the interval.

        Runs on every rank. Each early return depends only on values shared by
        all ranks. The filesystem check is followed by a barrier. Together these
        mean that either every rank enters ``save_model``, which may run
        collective ops, or none does.

        Returns:
            ``control`` unchanged; this callback never alters the training flow.
        """
        step = state.global_step
        if self.every_n_steps <= 0 or step <= 0 or step % self.every_n_steps != 0:
            return control

        trainer = self._trainer
        if trainer is None:
            if state.is_world_process_zero:
                logger.warning("PermanentCheckpointCallback: trainer not bound; skipping save at step %s", step)
            return control

        out = Path(args.output_dir) / self.subdir / f"checkpoint-{step}"
        # Decide before any rank can create ``out``. Some backends create the
        # directory inside save_model before their collectives. Without the
        # barrier, a fast rank could make a slow rank skip, and the others would hang.
        # NOTE(review): assumes output_dir is on storage visible to all ranks;
        # otherwise ranks may disagree on exists() and deadlock in save_model.
        already_saved = out.exists()
        accelerator = getattr(trainer, "accelerator", None)
        if accelerator is not None:
            accelerator.wait_for_everyone()
        if already_saved:
            if state.is_world_process_zero:
                logger.info("Permanent checkpoint already exists, skipping: %s", out)
            return control

        if state.is_world_process_zero:
            out.parent.mkdir(parents=True, exist_ok=True)
        # Deliberately called on ALL ranks. save_model gathers the state dict
        # collectively where required, and only the saving process(es) write files.
        trainer.save_model(str(out))
        if state.is_world_process_zero:
            logger.info("Permanent checkpoint: %s", out)
        return control