File size: 1,898 Bytes
c2e579a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# hub_trainer.py

from __future__ import annotations

import shutil
from pathlib import Path
from typing import Iterable, Optional

from transformers import Trainer


def _copy_path(src: Path, dst_dir: Path) -> None:
    dst = dst_dir / src.name
    if src.is_dir():
        if dst.exists():
            shutil.rmtree(dst)
        shutil.copytree(src, dst)
    else:
        shutil.copy2(src, dst)


class HubReadyTrainer(Trainer):
    """A ``transformers.Trainer`` that makes every save directory Hub-ready.

    On every model save — both the final ``save_model`` call and each
    intermediate checkpoint — it additionally writes the tokenizer/processor
    (``self.processing_class``, if set) and copies user-supplied custom code
    files into the save directory, so the result can be pushed to the
    Hugging Face Hub as a self-contained model repository.
    """

    def __init__(
        self,
        *args,
        code_paths: Optional[Iterable[str]] = None,
        **kwargs,
    ):
        """
        Args:
            code_paths: Paths (files or directories) of custom code to copy
                into every saved model directory. May be ``None``/empty.

        Raises:
            FileNotFoundError: If any entry of ``code_paths`` does not exist.
        """
        super().__init__(*args, **kwargs)
        self.code_paths = list(code_paths or [])
        # Fail fast on a typo: without this check a missing path would only
        # surface as an error mid-training, at the first checkpoint save.
        for src_str in self.code_paths:
            if not Path(src_str).exists():
                raise FileNotFoundError(f"Custom code path not found: {src_str}")

    def _checkpoint_dir(self) -> Path:
        """Path of the checkpoint dir for the current global step.

        NOTE(review): matches Trainer's default ``checkpoint-<step>`` layout;
        during hyper-parameter search the actual checkpoint lives under a
        trial-specific run dir instead, so callers must not assume this
        directory exists.
        """
        return Path(self.args.output_dir) / f"checkpoint-{self.state.global_step}"

    def _save_extra_hub_artifacts(self, checkpoint_dir: Path) -> None:
        """Write tokenizer/processor and custom code into *checkpoint_dir*.

        Idempotent: re-running against the same directory overwrites the
        previous artifacts rather than erroring.
        """
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Save tokenizer / processor into this checkpoint
        processing_obj = getattr(self, "processing_class", None)
        if processing_obj is not None:
            processing_obj.save_pretrained(checkpoint_dir)

        # Copy custom code paths; re-check existence (validated in __init__)
        # in case a path was removed after construction.
        for src_str in self.code_paths:
            src = Path(src_str)
            if not src.exists():
                raise FileNotFoundError(f"Custom code path not found: {src}")
            _copy_path(src, checkpoint_dir)

    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        """Save the model as usual, then append the extra Hub artifacts."""
        super().save_model(output_dir=output_dir, _internal_call=_internal_call)
        target_dir = Path(output_dir) if output_dir is not None else Path(self.args.output_dir)
        self._save_extra_hub_artifacts(target_dir)

    def _save_checkpoint(self, model, trial):
        # Trainer._save_checkpoint already routes through ``save_model``
        # (which our override augments), so this pass is belt-and-braces;
        # the artifact writes are idempotent. Guard on ``exists()`` because
        # during hyper-parameter search the checkpoint is written under a
        # trial run dir — unconditionally mkdir'ing the recomputed path
        # (as the code previously did) would pollute ``output_dir`` with a
        # stray empty ``checkpoint-<step>`` directory.
        super()._save_checkpoint(model, trial)
        checkpoint_dir = self._checkpoint_dir()
        if checkpoint_dir.exists():
            self._save_extra_hub_artifacts(checkpoint_dir)