# curriculum-cot-code / checkpoint_utils.py
# Last commit by Avra98: "Initial code dump (rebuttal-ready snapshot)" (76de008, verified)
from __future__ import annotations
import os
import shutil
from typing import Any, Callable
import torch
from peft import get_peft_model_state_dict
from safetensors.torch import save_file as save_safetensors_file
# Name of the directory, created next to this module, under which the
# final-checkpoint mirrors of run output directories are kept.
FINAL_CHECKPOINT_DIRNAME = "final_checkpoint"

# Entries whose presence in a checkpoint directory means ``save_pretrained``
# already wrote model weights, so the adapter fallback save is skipped.
# The two ``*.index.json`` names cover sharded full-model saves (numbered
# shard files plus an index, no single weight file); without them a sharded
# non-PEFT model would fall through to the PEFT-only fallback.
_WEIGHT_FILENAMES = (
    "adapter_model.safetensors",
    "adapter_model.bin",
    "model.safetensors",
    "pytorch_model.bin",
    "model.safetensors.index.json",
    "pytorch_model.bin.index.json",
)
def ensure_final_checkpoint_dir(output_dir: str) -> str:
    """Create and return the final-checkpoint mirror directory for *output_dir*.

    The location of *output_dir* relative to the directory containing this
    module is reproduced below ``<module dir>/FINAL_CHECKPOINT_DIRNAME``.
    Before that, one leading ``FINAL_CHECKPOINT_DIRNAME`` component and then
    one leading ``checkpoints`` component are stripped.

    If *output_dir* is not below the module directory, only its basename is
    used. The same applies when no relative path exists at all (e.g. another
    drive on Windows). This guarantees the result always stays inside the
    final-checkpoint root. If nothing is left after stripping, the basename
    is used, or ``"run"`` when the basename is empty.

    Args:
        output_dir: Run output directory, absolute or relative to the CWD.

    Returns:
        Path of the final-checkpoint directory, which exists on return.
    """
    repo_root = os.path.dirname(os.path.abspath(__file__))
    output_dir_abs = os.path.abspath(output_dir)
    basename = os.path.basename(output_dir_abs.rstrip(os.sep))
    try:
        rel_output_dir = os.path.relpath(output_dir_abs, repo_root)
    except ValueError:
        # relpath raises ValueError when no relative path exists
        # (e.g. paths on different drives on Windows).
        rel_output_dir = basename
    if rel_output_dir == os.pardir or rel_output_dir.startswith(os.pardir + os.sep):
        # output_dir is outside the repo. Joining its "../" components below
        # FINAL_CHECKPOINT_DIRNAME would escape that root (for the repo's
        # parent it resolves to the repo itself), and the caller then empties
        # whatever directory is returned.
        rel_output_dir = basename
    rel_parts = [part for part in rel_output_dir.split(os.sep) if part not in ("", ".")]
    if rel_parts[:1] == [FINAL_CHECKPOINT_DIRNAME]:
        rel_parts = rel_parts[1:]
    if rel_parts[:1] == ["checkpoints"]:
        rel_parts = rel_parts[1:]
    if not rel_parts:
        rel_parts = [basename or "run"]
    final_dir = os.path.join(repo_root, FINAL_CHECKPOINT_DIRNAME, *rel_parts)
    os.makedirs(final_dir, exist_ok=True)
    return final_dir
def final_checkpoint_root(*parts: str) -> str:
    """Return ``<module dir>/FINAL_CHECKPOINT_DIRNAME/<parts...>``, creating it if needed.

    Args:
        *parts: Optional path components appended below the final-checkpoint
            root. They are combined with ``os.path.join``, so an absolute
            component replaces everything before it.

    Returns:
        The joined path, which exists as a directory on return.
    """
    base = os.path.join(os.path.dirname(os.path.abspath(__file__)), FINAL_CHECKPOINT_DIRNAME)
    target = os.path.join(base, *parts)
    os.makedirs(target, exist_ok=True)
    return target
def normalize_to_final_checkpoint_root(path: str, *default_parts: str) -> str:
    """Map *path* into the final-checkpoint tree when it points below ``checkpoints/``.

    Rules, evaluated relative to the directory containing this module:

    * Empty, ``None`` or whitespace-only *path*: returns
      ``final_checkpoint_root(*default_parts)``.
    * A path already below ``FINAL_CHECKPOINT_DIRNAME``: returned as an
      absolute path.
    * A path below ``checkpoints/``: the same sub-path under the
      final-checkpoint root, created by ``final_checkpoint_root``.
    * Anything else, including paths with no relative form to the module
      directory (e.g. another drive on Windows): returned as an absolute path.

    A relative *path* is resolved against the current working directory.
    """
    raw = str(path or "").strip()
    if not raw:
        return final_checkpoint_root(*default_parts)
    abs_path = os.path.abspath(raw)
    repo_root = os.path.dirname(os.path.abspath(__file__))
    try:
        rel_path = os.path.relpath(abs_path, repo_root)
    except ValueError:
        # relpath raises when no relative path exists (different Windows
        # drive), so the path cannot be under the repo's checkpoint dirs.
        return abs_path
    rel_parts = [part for part in rel_path.split(os.sep) if part not in ("", ".")]
    if rel_parts[:1] == [FINAL_CHECKPOINT_DIRNAME]:
        return abs_path
    if rel_parts[:1] == ["checkpoints"]:
        return final_checkpoint_root(*rel_parts[1:])
    return abs_path
def _has_saved_weights(target_dir: str) -> bool:
    """Report whether *target_dir* contains any entry named in ``_WEIGHT_FILENAMES``."""
    for filename in _WEIGHT_FILENAMES:
        if os.path.exists(os.path.join(target_dir, filename)):
            return True
    return False
def _fallback_save_adapter_weights(model: Any, target_dir: str) -> None:
    """Save the PEFT adapter tensors as ``adapter_model.safetensors`` if no weights exist yet.

    Does nothing when *target_dir* already contains a file listed in
    ``_WEIGHT_FILENAMES``. Otherwise it processes every tensor value from
    ``get_peft_model_state_dict(model)``: the value is detached, moved to
    CPU and made contiguous, then all of them are written with safetensors.
    Non-tensor values are ignored, and no file is written if no tensors
    remain.
    """
    if not _has_saved_weights(target_dir):
        tensors = {}
        for name, entry in get_peft_model_state_dict(model).items():
            if not torch.is_tensor(entry):
                continue
            tensors[name] = entry.detach().cpu().contiguous()
        if tensors:
            save_safetensors_file(tensors, os.path.join(target_dir, "adapter_model.safetensors"))
def save_model_artifacts(
    model: Any,
    tokenizer: Any,
    target_dir: str,
    *,
    extra_save_fn: Callable[[Any, str], None] | None = None,
) -> str:
    """Write the model (and optional tokenizer) into *target_dir* and return it.

    The steps run in this order:

    1. Create *target_dir*.
    2. Call ``model.save_pretrained``.
    3. Call ``tokenizer.save_pretrained`` unless *tokenizer* is ``None``.
    4. Run the safetensors adapter fallback.
    5. Call ``extra_save_fn(model, target_dir)`` when a hook is given.
    """
    os.makedirs(target_dir, exist_ok=True)
    pretrained_objects = [model] if tokenizer is None else [model, tokenizer]
    for obj in pretrained_objects:
        obj.save_pretrained(target_dir)
    _fallback_save_adapter_weights(model, target_dir)
    if extra_save_fn is not None:
        extra_save_fn(model, target_dir)
    return target_dir
def _path_is_within(path: str, root: str) -> bool:
    """Return True if absolute *path* equals *root* or lies below it."""
    try:
        return os.path.commonpath([path, root]) == root
    except ValueError:
        # commonpath raises ValueError for paths on different drives (Windows).
        return False


def _replace_dir_contents(src_dir: str, dst_dir: str) -> None:
    """Replace the contents of *dst_dir* with copies of the entries of *src_dir*.

    *dst_dir* is created if missing. Every existing entry in it is deleted,
    except one that is *src_dir* itself or an ancestor of it, so a source
    nested anywhere inside *dst_dir* survives. Each top-level entry of
    *src_dir* is then copied into *dst_dir*: directories with
    ``shutil.copytree``, everything else with ``shutil.copy2``.

    When *src_dir* and *dst_dir* are the same directory, nothing is changed.

    Raises:
        ValueError: if *dst_dir* lies inside *src_dir*. Clearing it would
            delete part of the source.
    """
    os.makedirs(dst_dir, exist_ok=True)
    src_real = os.path.realpath(src_dir)
    dst_real = os.path.realpath(dst_dir)
    if src_real == dst_real:
        # The contents are already in place; clearing dst would destroy src.
        return
    if _path_is_within(dst_real, src_real):
        raise ValueError(f"dst_dir {dst_dir!r} lies inside src_dir {src_dir!r}")
    for name in os.listdir(dst_real):
        path = os.path.join(dst_real, name)
        if _path_is_within(src_real, path):
            # Keep the source, or the directory that contains it, alive.
            continue
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path)
        else:
            os.unlink(path)
    for name in os.listdir(src_real):
        src_path = os.path.join(src_real, name)
        dst_path = os.path.join(dst_real, name)
        if os.path.isdir(src_path) and not os.path.islink(src_path):
            shutil.copytree(src_path, dst_path)
        else:
            shutil.copy2(src_path, dst_path)
def save_checkpoint_and_update_final(
    model: Any,
    tokenizer: Any,
    output_dir: str,
    checkpoint_name: str,
    *,
    extra_save_fn: Callable[[Any, str], None] | None = None,
) -> str:
    """Save a checkpoint under *output_dir* and refresh the run's final-checkpoint copy.

    Artifacts are written to ``os.path.join(output_dir, checkpoint_name)``.
    The final-checkpoint directory for *output_dir* is then emptied and
    refilled with that checkpoint's contents.

    Returns:
        The checkpoint directory that was written.
    """
    saved_dir = save_model_artifacts(
        model,
        tokenizer,
        os.path.join(output_dir, checkpoint_name),
        extra_save_fn=extra_save_fn,
    )
    final_dir = ensure_final_checkpoint_dir(output_dir)
    _replace_dir_contents(saved_dir, final_dir)
    return saved_dir