"""Helpers for saving model checkpoints and mirroring them into a final-checkpoint tree.

All "final" locations live under ``<directory of this file>/final_checkpoint``.
"""
from __future__ import annotations

import os
import shutil
from typing import Any, Callable

import torch
from peft import get_peft_model_state_dict
from safetensors.torch import save_file as save_safetensors_file

FINAL_CHECKPOINT_DIRNAME = "final_checkpoint"

# Files whose presence means model/adapter weights were already written.
_WEIGHT_FILENAMES = (
    "adapter_model.safetensors",
    "adapter_model.bin",
    "model.safetensors",
    "pytorch_model.bin",
)


def _repo_root() -> str:
    """Return the absolute directory containing this source file."""
    return os.path.dirname(os.path.abspath(__file__))


def ensure_final_checkpoint_dir(output_dir: str) -> str:
    """Create and return the final-checkpoint directory mirroring *output_dir*.

    The path of *output_dir* relative to the repository root is re-rooted
    under ``<repo_root>/final_checkpoint``. A leading ``final_checkpoint``
    component and then a leading ``checkpoints`` component are stripped, so
    ``checkpoints/run1`` maps to ``final_checkpoint/run1``.

    Args:
        output_dir: Training output directory (relative or absolute).

    Returns:
        Absolute path of the (now existing) final-checkpoint directory.
    """
    repo_root = _repo_root()
    output_dir_abs = os.path.abspath(output_dir)
    fallback_name = os.path.basename(output_dir_abs.rstrip(os.sep))
    try:
        rel_output_dir = os.path.relpath(output_dir_abs, repo_root)
    except ValueError:
        # relpath raises ValueError when no relative path exists
        # (e.g. different drives on Windows).
        rel_output_dir = fallback_name

    # Drop ".." as well as "" and ".": for an output_dir outside the repo the
    # relative path starts with parent references, and joining those under
    # final_checkpoint would resolve to a directory outside the final tree,
    # whose contents a later _replace_dir_contents() call would delete.
    ignored = ("", os.curdir, os.pardir)
    rel_parts = [part for part in rel_output_dir.split(os.sep) if part not in ignored]

    if rel_parts and rel_parts[0] == FINAL_CHECKPOINT_DIRNAME:
        rel_parts = rel_parts[1:]
    if rel_parts and rel_parts[0] == "checkpoints":
        rel_parts = rel_parts[1:]
    if not rel_parts:
        rel_parts = [fallback_name or "run"]

    final_dir = os.path.join(repo_root, FINAL_CHECKPOINT_DIRNAME, *rel_parts)
    os.makedirs(final_dir, exist_ok=True)
    return final_dir


def final_checkpoint_root(*parts: str) -> str:
    """Create and return ``<repo_root>/final_checkpoint/<parts...>``.

    Args:
        *parts: Path components appended below the final-checkpoint root.

    Returns:
        The path of the directory, which is created if missing.
    """
    root = os.path.join(_repo_root(), FINAL_CHECKPOINT_DIRNAME, *parts)
    os.makedirs(root, exist_ok=True)
    return root


def normalize_to_final_checkpoint_root(path: str, *default_parts: str) -> str:
    """Map *path* into the final-checkpoint tree when it lives under ``checkpoints``.

    Args:
        path: Candidate path; empty or whitespace-only selects the default.
        *default_parts: Components passed to ``final_checkpoint_root`` when
            *path* is empty.

    Returns:
        * ``final_checkpoint_root(*default_parts)`` when *path* is empty;
        * the re-rooted (and created) final-checkpoint directory when *path*
          is ``<repo_root>/checkpoints/...``;
        * otherwise the absolute form of *path*, unchanged.
    """
    raw = str(path or "").strip()
    if not raw:
        return final_checkpoint_root(*default_parts)

    abs_path = os.path.abspath(raw)
    try:
        rel_path = os.path.relpath(abs_path, _repo_root())
    except ValueError:
        # No relative path exists (e.g. different Windows drive), so the path
        # cannot be under the repo's checkpoints directory.
        return abs_path

    # ".." is intentionally kept here: a path outside the repo must not be
    # mistaken for one that starts with "checkpoints".
    rel_parts = [part for part in rel_path.split(os.sep) if part not in ("", ".")]
    if rel_parts[:1] == ["checkpoints"]:
        return final_checkpoint_root(*rel_parts[1:])
    return abs_path


def _has_saved_weights(target_dir: str) -> bool:
    """Return True if *target_dir* contains any known weight file."""
    return any(
        os.path.exists(os.path.join(target_dir, name)) for name in _WEIGHT_FILENAMES
    )


def _fallback_save_adapter_weights(model: Any, target_dir: str) -> None:
    """Write ``adapter_model.safetensors`` if *target_dir* has no weight file yet.

    The state dict comes from ``get_peft_model_state_dict``; only tensor
    values are kept, detached and moved to contiguous CPU tensors before
    saving. Nothing is written when no tensors remain.
    """
    if _has_saved_weights(target_dir):
        return
    state = get_peft_model_state_dict(model)
    cpu_state = {
        key: value.detach().cpu().contiguous()
        for key, value in state.items()
        if torch.is_tensor(value)
    }
    if cpu_state:
        save_safetensors_file(
            cpu_state, os.path.join(target_dir, "adapter_model.safetensors")
        )


def save_model_artifacts(
    model: Any,
    tokenizer: Any,
    target_dir: str,
    *,
    extra_save_fn: Callable[[Any, str], None] | None = None,
) -> str:
    """Save the model (and optionally the tokenizer) into *target_dir*.

    Args:
        model: Object exposing ``save_pretrained(directory)``.
        tokenizer: Object exposing ``save_pretrained(directory)``, or None to skip.
        target_dir: Destination directory; created if missing.
        extra_save_fn: Optional hook called as ``extra_save_fn(model, target_dir)``
            after the standard artifacts are written.

    Returns:
        *target_dir*, unchanged.
    """
    os.makedirs(target_dir, exist_ok=True)
    model.save_pretrained(target_dir)
    if tokenizer is not None:
        tokenizer.save_pretrained(target_dir)
    # Covers the case where save_pretrained produced none of the known weight files.
    _fallback_save_adapter_weights(model, target_dir)
    if extra_save_fn is not None:
        extra_save_fn(model, target_dir)
    return target_dir


def _replace_dir_contents(src_dir: str, dst_dir: str) -> None:
    """Make *dst_dir* contain a copy of the entries of *src_dir*.

    Existing entries of *dst_dir* are removed first, except the entry that
    is (or contains) *src_dir* itself when the source lives inside the
    destination. If both paths are the same directory nothing is done.
    """
    os.makedirs(dst_dir, exist_ok=True)
    src_dir_abs = os.path.abspath(src_dir)
    if src_dir_abs == os.path.abspath(dst_dir):
        # Clearing the destination would delete the source and leave nothing to copy.
        return

    for name in os.listdir(dst_dir):
        path = os.path.join(dst_dir, name)
        path_abs = os.path.abspath(path)
        # Keep the source directory and any ancestor of it that sits in dst_dir.
        if src_dir_abs == path_abs or src_dir_abs.startswith(path_abs + os.sep):
            continue
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path)
        else:
            os.unlink(path)

    for name in os.listdir(src_dir):
        src_path = os.path.join(src_dir, name)
        dst_path = os.path.join(dst_dir, name)
        if os.path.isdir(src_path) and not os.path.islink(src_path):
            shutil.copytree(src_path, dst_path)
        else:
            shutil.copy2(src_path, dst_path)


def save_checkpoint_and_update_final(
    model: Any,
    tokenizer: Any,
    output_dir: str,
    checkpoint_name: str,
    *,
    extra_save_fn: Callable[[Any, str], None] | None = None,
) -> str:
    """Save a named checkpoint and mirror it into the final-checkpoint directory.

    Args:
        model: Model passed to ``save_model_artifacts``.
        tokenizer: Tokenizer passed to ``save_model_artifacts`` (may be None).
        output_dir: Directory under which the checkpoint directory is created.
        checkpoint_name: Name of the checkpoint subdirectory.
        extra_save_fn: Optional hook forwarded to ``save_model_artifacts``.

    Returns:
        Path of the checkpoint directory (``output_dir/checkpoint_name``).
    """
    checkpoint_dir = os.path.join(output_dir, checkpoint_name)
    save_model_artifacts(model, tokenizer, checkpoint_dir, extra_save_fn=extra_save_fn)
    _replace_dir_contents(checkpoint_dir, ensure_final_checkpoint_dir(output_dir))
    return checkpoint_dir