# File size: 4,360 Bytes
from __future__ import annotations
import os
import shutil
from typing import Any, Callable
import torch
from peft import get_peft_model_state_dict
from safetensors.torch import save_file as save_safetensors_file
# Name of the directory, located next to this module, under which the helpers
# below build their "final checkpoint" paths.
FINAL_CHECKPOINT_DIRNAME = "final_checkpoint"
# File names whose presence in a directory is treated as "weights have already
# been saved here" (see _has_saved_weights).
_WEIGHT_FILENAMES = (
    "adapter_model.safetensors",
    "adapter_model.bin",
    "model.safetensors",
    "pytorch_model.bin",
)
def ensure_final_checkpoint_dir(output_dir: str) -> str:
    """Create and return the final-checkpoint directory that mirrors *output_dir*.

    The result is always located below ``<module dir>/final_checkpoint``. The
    sub-path is *output_dir* expressed relative to the directory containing
    this module, with a leading ``final_checkpoint`` and/or ``checkpoints``
    component removed. If nothing remains, the base name of *output_dir*
    (or ``"run"``) is used.

    Args:
        output_dir: Training output directory (absolute or relative).

    Returns:
        Path of the (now existing) final-checkpoint directory.
    """
    repo_root = os.path.dirname(os.path.abspath(__file__))
    output_dir_abs = os.path.abspath(output_dir)
    fallback_name = os.path.basename(output_dir_abs.rstrip(os.sep))

    try:
        rel_output_dir = os.path.relpath(output_dir_abs, repo_root)
    except ValueError:
        # os.path.relpath raises ValueError when no relative path exists
        # (e.g. different drives on Windows); fall back to the base name.
        rel_output_dir = fallback_name

    # ".." components appear when output_dir lies outside repo_root. Joining
    # them back would let the result escape the final_checkpoint directory
    # (and callers may wipe the returned directory), so they are dropped.
    ignored = ("", ".", "..")
    rel_parts = [part for part in rel_output_dir.split(os.sep) if part not in ignored]

    if rel_parts[:1] == [FINAL_CHECKPOINT_DIRNAME]:
        rel_parts = rel_parts[1:]
    if rel_parts[:1] == ["checkpoints"]:
        rel_parts = rel_parts[1:]
    if not rel_parts:
        rel_parts = [fallback_name or "run"]

    final_dir = os.path.join(repo_root, FINAL_CHECKPOINT_DIRNAME, *rel_parts)
    os.makedirs(final_dir, exist_ok=True)
    return final_dir
def final_checkpoint_root(*parts: str) -> str:
    """Return ``<module dir>/final_checkpoint/<parts...>``, creating it if needed.

    Args:
        *parts: Optional path components appended below the final-checkpoint
            directory.

    Returns:
        The path of the directory, which exists when this function returns.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    target = os.path.join(module_dir, FINAL_CHECKPOINT_DIRNAME, *parts)
    os.makedirs(target, exist_ok=True)
    return target
def normalize_to_final_checkpoint_root(path: str, *default_parts: str) -> str:
    """Map *path* onto the final-checkpoint tree where applicable.

    Rules, in order:

    * Empty / whitespace-only / falsy *path*: return
      ``final_checkpoint_root(*default_parts)``.
    * *path* already below ``<module dir>/final_checkpoint``: return its
      absolute form unchanged.
    * *path* below ``<module dir>/checkpoints``: return the matching directory
      below ``final_checkpoint`` (created on demand).
    * Anything else: return the absolute form of *path*.

    Args:
        path: Path to normalize; may be empty or ``None``-like.
        *default_parts: Components used below the final-checkpoint root when
            *path* is empty.

    Returns:
        The normalized absolute path.
    """
    raw = str(path or "").strip()
    if not raw:
        return final_checkpoint_root(*default_parts)

    abs_path = os.path.abspath(raw)
    repo_root = os.path.dirname(os.path.abspath(__file__))
    try:
        rel_path = os.path.relpath(abs_path, repo_root)
    except ValueError:
        # No relative path exists (e.g. different drive on Windows), so the
        # path cannot be inside the module directory: leave it untouched.
        return abs_path

    rel_parts = [part for part in rel_path.split(os.sep) if part not in ("", ".")]
    if rel_parts[:1] == [FINAL_CHECKPOINT_DIRNAME]:
        return abs_path
    if rel_parts[:1] == ["checkpoints"]:
        return final_checkpoint_root(*rel_parts[1:])
    return abs_path
def _has_saved_weights(target_dir: str) -> bool:
    """Return True if *target_dir* contains any file named in ``_WEIGHT_FILENAMES``."""
    for filename in _WEIGHT_FILENAMES:
        candidate = os.path.join(target_dir, filename)
        if os.path.exists(candidate):
            return True
    return False
def _fallback_save_adapter_weights(model: Any, target_dir: str) -> None:
    """Write ``adapter_model.safetensors`` if *target_dir* has no weight file yet.

    Does nothing when a known weight file is already present. Otherwise the
    PEFT state dict of *model* is collected, every tensor entry is detached,
    moved to CPU and made contiguous, and the result is saved with
    safetensors. Non-tensor entries are skipped; nothing is written when no
    tensors remain.
    """
    if not _has_saved_weights(target_dir):
        tensors = {}
        for name, tensor in get_peft_model_state_dict(model).items():
            if torch.is_tensor(tensor):
                tensors[name] = tensor.detach().cpu().contiguous()
        if tensors:
            destination = os.path.join(target_dir, "adapter_model.safetensors")
            save_safetensors_file(tensors, destination)
def save_model_artifacts(
    model: Any,
    tokenizer: Any,
    target_dir: str,
    *,
    extra_save_fn: Callable[[Any, str], None] | None = None,
) -> str:
    """Save *model* (and optionally *tokenizer*) into *target_dir*.

    Steps, in order: create *target_dir*, call ``model.save_pretrained``,
    call ``tokenizer.save_pretrained`` when a tokenizer is given, run the
    adapter-weight fallback, then invoke *extra_save_fn* if supplied.

    Args:
        model: Object exposing ``save_pretrained(dir)``.
        tokenizer: Object exposing ``save_pretrained(dir)``, or ``None``.
        target_dir: Directory that receives the artifacts.
        extra_save_fn: Optional hook called as ``extra_save_fn(model, target_dir)``
            after everything else has been written.

    Returns:
        *target_dir*, unchanged.
    """
    os.makedirs(target_dir, exist_ok=True)

    model.save_pretrained(target_dir)
    if tokenizer is not None:
        tokenizer.save_pretrained(target_dir)

    # Covers the case where save_pretrained left no recognised weight file.
    _fallback_save_adapter_weights(model, target_dir)

    if extra_save_fn is None:
        return target_dir
    extra_save_fn(model, target_dir)
    return target_dir
def _replace_dir_contents(src_dir: str, dst_dir: str) -> None:
os.makedirs(dst_dir, exist_ok=True)
src_dir_abs = os.path.abspath(src_dir)
for name in os.listdir(dst_dir):
path = os.path.join(dst_dir, name)
if os.path.abspath(path) == src_dir_abs:
continue
if os.path.isdir(path) and not os.path.islink(path):
shutil.rmtree(path)
else:
os.unlink(path)
for name in os.listdir(src_dir):
src_path = os.path.join(src_dir, name)
dst_path = os.path.join(dst_dir, name)
if os.path.isdir(src_path) and not os.path.islink(src_path):
shutil.copytree(src_path, dst_path)
else:
shutil.copy2(src_path, dst_path)
def save_checkpoint_and_update_final(
    model: Any,
    tokenizer: Any,
    output_dir: str,
    checkpoint_name: str,
    *,
    extra_save_fn: Callable[[Any, str], None] | None = None,
) -> str:
    """Save a named checkpoint and mirror it into the final-checkpoint directory.

    The artifacts are written to ``output_dir/checkpoint_name``. Afterwards
    the final-checkpoint directory derived from *output_dir* has its contents
    replaced by a copy of that checkpoint.

    Args:
        model: Model passed through to ``save_model_artifacts``.
        tokenizer: Tokenizer passed through to ``save_model_artifacts``.
        output_dir: Directory that holds the run's checkpoints.
        checkpoint_name: Name of the checkpoint sub-directory.
        extra_save_fn: Optional hook forwarded to ``save_model_artifacts``.

    Returns:
        Path of the checkpoint directory that was written.
    """
    checkpoint_dir = os.path.join(output_dir, checkpoint_name)
    save_model_artifacts(
        model,
        tokenizer,
        checkpoint_dir,
        extra_save_fn=extra_save_fn,
    )

    # Resolved only after saving succeeded, matching the original call order.
    final_dir = ensure_final_checkpoint_dir(output_dir)
    _replace_dir_contents(checkpoint_dir, final_dir)
    return checkpoint_dir
|