| |
|
|
| from __future__ import annotations |
|
|
| import shutil |
| from pathlib import Path |
| from typing import Iterable, Optional |
|
|
| from transformers import Trainer |
|
|
|
|
def _copy_path(src: Path, dst_dir: Path) -> None:
    """Copy *src* (a file or a directory) into *dst_dir*.

    A file is copied with metadata preserved. A directory replaces any
    previously copied directory of the same name wholesale.
    """
    target = dst_dir / src.name
    if not src.is_dir():
        shutil.copy2(src, target)
        return
    # copytree refuses to write into an existing destination, so a stale
    # directory left by an earlier save must be removed first.
    if target.exists():
        shutil.rmtree(target)
    shutil.copytree(src, target)
|
|
|
|
class HubReadyTrainer(Trainer):
    """Trainer that bundles extra Hub artifacts into every save/checkpoint.

    On top of the regular ``Trainer`` persistence, each save directory also
    receives the tokenizer/processor files and any user-supplied custom-code
    paths, so the resulting folder can be pushed to the Hub and loaded with
    ``trust_remote_code`` without manual post-processing.

    Parameters
    ----------
    code_paths:
        Optional iterable of filesystem paths (files or directories) copied
        verbatim into each save/checkpoint directory.
    """

    def __init__(
        self,
        *args,
        code_paths: Optional[Iterable[str]] = None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # Materialize once so a one-shot generator passed by the caller
        # survives repeated saves.
        self.code_paths = list(code_paths or [])

    def _checkpoint_dir(self, trial=None) -> Path:
        """Return the directory of the checkpoint for the current step.

        During a hyperparameter search the base Trainer nests checkpoints
        under a trial-specific run directory; ``_get_output_dir`` (when the
        installed transformers version provides it) resolves that, otherwise
        fall back to ``args.output_dir``. ``trial=None`` keeps the previous
        zero-argument call signature working.
        """
        get_run_dir = getattr(self, "_get_output_dir", None)
        if get_run_dir is not None:
            run_dir = Path(get_run_dir(trial))
        else:
            run_dir = Path(self.args.output_dir)
        return run_dir / f"checkpoint-{self.state.global_step}"

    def _save_extra_hub_artifacts(self, checkpoint_dir: Path) -> None:
        """Write tokenizer/processor files and custom code into *checkpoint_dir*.

        Raises
        ------
        FileNotFoundError
            If any entry in ``self.code_paths`` does not exist.
        """
        # Only the process that owns persistence writes artifacts. This
        # mirrors Trainer.save_model, which gates disk writes on
        # args.should_save; without the guard, every rank in a distributed
        # run would race on the same output files.
        if not self.args.should_save:
            return

        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Newer transformers exposes the tokenizer/processor as
        # `processing_class`; older releases used `tokenizer`.
        processing_obj = getattr(self, "processing_class", None)
        if processing_obj is None:
            processing_obj = getattr(self, "tokenizer", None)
        if processing_obj is not None:
            processing_obj.save_pretrained(checkpoint_dir)

        for src_str in self.code_paths:
            src = Path(src_str)
            if not src.exists():
                raise FileNotFoundError(f"Custom code path not found: {src}")
            _copy_path(src, checkpoint_dir)

    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        """Save the model as usual, then add the extra Hub artifacts."""
        super().save_model(output_dir=output_dir, _internal_call=_internal_call)
        target_dir = Path(output_dir) if output_dir is not None else Path(self.args.output_dir)
        self._save_extra_hub_artifacts(target_dir)

    def _save_checkpoint(self, model, trial):
        """Write the regular checkpoint, then mirror the extra artifacts into it."""
        super()._save_checkpoint(model, trial)
        # Pass the trial through so hyperparameter-search checkpoints receive
        # their artifacts in the trial-specific run directory rather than in
        # args.output_dir.
        self._save_extra_hub_artifacts(self._checkpoint_dir(trial))