# hub_trainer.py from __future__ import annotations import shutil from pathlib import Path from typing import Iterable, Optional from transformers import Trainer def _copy_path(src: Path, dst_dir: Path) -> None: dst = dst_dir / src.name if src.is_dir(): if dst.exists(): shutil.rmtree(dst) shutil.copytree(src, dst) else: shutil.copy2(src, dst) class HubReadyTrainer(Trainer): def __init__( self, *args, code_paths: Optional[Iterable[str]] = None, **kwargs, ): super().__init__(*args, **kwargs) self.code_paths = list(code_paths or []) def _checkpoint_dir(self) -> Path: return Path(self.args.output_dir) / f"checkpoint-{self.state.global_step}" def _save_extra_hub_artifacts(self, checkpoint_dir: Path) -> None: checkpoint_dir.mkdir(parents=True, exist_ok=True) # Save tokenizer / processor into this checkpoint processing_obj = getattr(self, "processing_class", None) if processing_obj is not None: processing_obj.save_pretrained(checkpoint_dir) # Copy custom code paths for src_str in self.code_paths: src = Path(src_str) if not src.exists(): raise FileNotFoundError(f"Custom code path not found: {src}") _copy_path(src, checkpoint_dir) def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): super().save_model(output_dir=output_dir, _internal_call=_internal_call) target_dir = Path(output_dir) if output_dir is not None else Path(self.args.output_dir) self._save_extra_hub_artifacts(target_dir) def _save_checkpoint(self, model, trial): super()._save_checkpoint(model, trial) checkpoint_dir = self._checkpoint_dir() self._save_extra_hub_artifacts(checkpoint_dir)