zombiee / scripts /check_hub_checkpoints.py
EeshanSingh's picture
commit file change
636ec72
#!/usr/bin/env python
"""Inspect / recover GRPO training checkpoints from the Hugging Face Hub.
Use this to answer "did Colab/Kaggle save anything before it died?" and to
pull the latest checkpoint locally for manual inspection or DGX resume.
Examples
--------
List what's on the Hub:
python scripts/check_hub_checkpoints.py --hub-model-id noanya/zombiee
Show training progress (step, loss, learning rate from trainer_state.json):
python scripts/check_hub_checkpoints.py --hub-model-id noanya/zombiee --info
Download the latest checkpoint to ./recovered/:
python scripts/check_hub_checkpoints.py --hub-model-id noanya/zombiee \\
--download ./recovered
Then resume training from it (the download lands in ./recovered/checkpoint-<N>):
python -m training.train \\
--resume-from-checkpoint ./recovered/checkpoint-<N> \\
--push-to-hub --hub-model-id noanya/zombiee \\
--max-steps 4000 --output-dir ./lora_v1
"""
from __future__ import annotations
import argparse
import json
import os
import sys
from datetime import datetime, timezone
def parse_args(argv=None):
    """Parse the command-line options of the checkpoint inspector.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process arguments. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is the behaviour callers had before
            this parameter existed.

    Returns:
        An ``argparse.Namespace`` with the attributes ``hub_model_id``,
        ``info``, ``download``, ``checkpoint`` and ``token``.
    """
    p = argparse.ArgumentParser(
        description="List / download GRPO training checkpoints from HF Hub.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        # The module docstring doubles as the usage examples shown by --help.
        epilog=__doc__,
    )
    p.add_argument(
        "--hub-model-id", default=os.environ.get("HUB_MODEL_ID", "noanya/zombiee"),
        help="HF Hub repo id, e.g. 'noanya/zombiee' (default: $HUB_MODEL_ID or noanya/zombiee).",
    )
    p.add_argument(
        "--info", action="store_true",
        help="Read trainer_state.json from the latest checkpoint and print training progress.",
    )
    p.add_argument(
        "--download", metavar="DIR", default=None,
        help="Download the latest checkpoint to this directory.",
    )
    p.add_argument(
        "--checkpoint", metavar="N", type=int, default=None,
        help="Operate on checkpoint-N specifically instead of the latest.",
    )
    p.add_argument(
        # HUGGINGFACE_TOKEN takes precedence; HF_TOKEN is the fallback.
        "--token", default=os.environ.get("HUGGINGFACE_TOKEN") or os.environ.get("HF_TOKEN"),
        help="HF token (default: $HUGGINGFACE_TOKEN / $HF_TOKEN). Required for private repos.",
    )
    return p.parse_args(argv)
def list_checkpoints(api, repo_id, token):
    """Return (sorted list of checkpoint step numbers, sorted list of root files).

    Args:
        api: Object exposing ``list_repo_files(repo_id, token=...)``.
        repo_id: Hub repository id to list.
        token: Access token forwarded to ``list_repo_files`` (may be None).

    Returns:
        A ``(steps, root_files)`` tuple. ``steps`` holds the integer N of every
        ``checkpoint-N/`` directory that contains at least one file;
        ``root_files`` holds every path with no directory component.

    Exits the process with status 1 (after printing to stderr) when the
    repository listing fails.
    """
    try:
        files = api.list_repo_files(repo_id, token=token)
    except Exception as e:  # CLI boundary: report the failure and stop
        print(f"ERROR: could not list {repo_id}: {e}", file=sys.stderr)
        print(
            " If the repo is private, set HUGGINGFACE_TOKEN. If it doesn't exist yet,\n"
            " no training run has pushed to it.",
            file=sys.stderr,
        )
        sys.exit(1)

    prefix = "checkpoint-"
    steps = set()
    root_files = []
    for f in files:
        head, sep, _ = f.partition("/")
        if not sep:
            # No directory component: always a root file, even when its name
            # happens to start with "checkpoint-".
            root_files.append(f)
            continue
        suffix = head[len(prefix):]
        # Only plain ASCII digits count, so the step round-trips to the exact
        # directory name "checkpoint-<step>" used for later downloads.
        if head.startswith(prefix) and suffix.isascii() and suffix.isdecimal():
            steps.add(int(suffix))
    return sorted(steps), sorted(root_files)
def fetch_trainer_state(api, repo_id, step, token, work_dir):
    """Download trainer_state.json from checkpoint-step and return parsed dict.

    Args:
        api: Unused here; kept so the signature matches the other helpers.
        repo_id: Hub repository id to download from.
        step: Checkpoint step number; the file read is
            ``checkpoint-<step>/trainer_state.json``.
        token: Access token forwarded to ``hf_hub_download`` (may be None).
        work_dir: Local directory passed as ``local_dir`` for the download.

    Returns:
        The parsed JSON content of the downloaded file.
    """
    from huggingface_hub import hf_hub_download

    path = hf_hub_download(
        repo_id=repo_id,
        filename=f"checkpoint-{step}/trainer_state.json",
        local_dir=work_dir,
        token=token,
    )
    # JSON is UTF-8 by specification; do not depend on the locale's default
    # encoding (which is not UTF-8 on every platform).
    with open(path, encoding="utf-8") as f:
        return json.load(f)
def fmt_age(iso_or_dt):
    """Render 'X minutes/hours/days ago' from an ISO-8601 string or datetime.

    Args:
        iso_or_dt: A ``datetime``, an ISO-8601 string (a trailing ``Z`` is
            accepted), or ``None``.

    Returns:
        ``"unknown"`` for ``None``; the input string unchanged when it cannot
        be parsed; otherwise one of ``"<s>s ago"``, ``"<m>m ago"``,
        ``"<h>h <m>m ago"`` or ``"<d>d <h>h ago"``. Naive datetimes are
        interpreted as UTC. Timestamps in the future are reported as
        ``"0s ago"`` instead of a negative age.
    """
    if iso_or_dt is None:
        return "unknown"
    if isinstance(iso_or_dt, str):
        try:
            # fromisoformat() on older Pythons rejects the "Z" suffix.
            dt = datetime.fromisoformat(iso_or_dt.replace("Z", "+00:00"))
        except ValueError:
            return iso_or_dt
    else:
        dt = iso_or_dt
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)
    delta = datetime.now(timezone.utc) - dt
    # Clamp at zero so clock skew never produces output like "-5s ago".
    s = max(0, int(delta.total_seconds()))
    if s < 60:
        return f"{s}s ago"
    if s < 3600:
        return f"{s // 60}m ago"
    if s < 86400:
        return f"{s // 3600}h {(s % 3600) // 60}m ago"
    return f"{s // 86400}d {(s % 86400) // 3600}h ago"
def cmd_list(api, repo_id, token):
    """Print a summary of the checkpoints and root files found in the repo.

    Args:
        api: Object exposing ``list_repo_files`` and ``repo_info``.
        repo_id: Hub repository id to summarise.
        token: Access token forwarded to the API calls (may be None).

    The "Last commit" line is best-effort: if the metadata cannot be read,
    a note goes to stderr and the checkpoint listing is still printed.
    """
    steps, root_files = list_checkpoints(api, repo_id, token)
    print(f"Repo: https://huggingface.co/{repo_id}")
    try:
        info = api.repo_info(repo_id, token=token)
        sha = (getattr(info, "sha", None) or "unknown")[:8]
        # Prefer the snake_case attribute; fall back to the legacy camelCase
        # name so both spellings of the field are supported.
        modified = getattr(info, "last_modified", None)
        if modified is None:
            modified = getattr(info, "lastModified", None)
        if modified is None:
            print(f"Last commit: {sha}")
        else:
            print(f"Last commit: {sha} ({fmt_age(modified)})")
    except Exception as e:
        # Keep this non-fatal, but do not hide the reason from the user.
        print(f"(could not read last-commit info: {e})", file=sys.stderr)
    print()

    if not steps:
        print("No checkpoint-* directories found.")
        if root_files:
            print(f"Root files present: {', '.join(root_files)}")
            print("(Looks like only a final-model push, no intermediate checkpoints.)")
        else:
            print("Repo is empty — training never reached the first save.")
        return

    print(f"Found {len(steps)} checkpoint(s): {', '.join(f'checkpoint-{s}' for s in steps)}")
    print(f"Latest: checkpoint-{steps[-1]}")
    if root_files:
        print(f"Root files: {', '.join(root_files)}")
def cmd_info(api, repo_id, token, step):
    """Print training progress read from a checkpoint's trainer_state.json.

    Args:
        api: Object exposing ``list_repo_files``.
        repo_id: Hub repository id to inspect.
        token: Access token forwarded to the Hub calls (may be None).
        step: Checkpoint step to inspect, or None for the highest step found.

    Exits the process with status 1 when the repo has no checkpoints, when
    the requested checkpoint is absent, or when trainer_state.json cannot be
    downloaded or parsed.
    """
    import tempfile

    steps, _ = list_checkpoints(api, repo_id, token)
    if not steps:
        print("No checkpoints to inspect.", file=sys.stderr)
        sys.exit(1)
    target = step if step is not None else steps[-1]
    if target not in steps:
        print(f"checkpoint-{target} not on hub. Available: {steps}", file=sys.stderr)
        sys.exit(1)

    print(f"Inspecting checkpoint-{target}...")
    # Use the platform's temp directory instead of a hard-coded /tmp path so
    # the command also works where /tmp does not exist.
    work_dir = os.path.join(tempfile.gettempdir(), "_hub_inspect")
    try:
        state = fetch_trainer_state(api, repo_id, target, token, work_dir)
    except Exception as e:  # CLI boundary: report the failure and stop
        print(
            f"ERROR: could not read trainer_state.json from checkpoint-{target}: {e}",
            file=sys.stderr,
        )
        sys.exit(1)

    print()
    print(f" global_step : {state.get('global_step')}")
    epoch = state.get("epoch")
    if epoch is not None:
        print(f" epoch : {epoch:.4f}")
    else:
        print(" epoch : ?")
    print(f" max_steps : {state.get('max_steps')}")
    print(f" best_metric : {state.get('best_metric')}")
    print(f" total_flos : {state.get('total_flos')}")

    log_history = state.get("log_history", [])
    if log_history:
        print(f" log entries : {len(log_history)}")
        last = log_history[-1]
        print()
        print(" Most recent log entry:")
        # Only a fixed set of keys is shown, and only those actually present.
        for k in ("loss", "learning_rate", "grad_norm", "reward", "kl", "step"):
            if k in last:
                v = last[k]
                if isinstance(v, float):
                    print(f" {k:18}: {v:.6f}")
                else:
                    print(f" {k:18}: {v}")

    # Progress is relative to the checkpoint's step number; skipped when
    # max_steps is missing or zero.
    max_steps = state.get("max_steps")
    if max_steps:
        pct = target / max_steps * 100
        print()
        print(f"Progress: {target} / {max_steps} steps ({pct:.1f}% done)")
def cmd_download(api, repo_id, token, target_dir, step):
    """Download one checkpoint directory from the Hub and print resume hints.

    Args:
        api: Object exposing ``list_repo_files``.
        repo_id: Hub repository id to download from.
        token: Access token forwarded to the Hub calls (may be None).
        target_dir: Local directory to download into; created if missing.
        step: Checkpoint step to fetch, or None for the highest step found.

    Exits the process with status 1 when the repo has no checkpoints or the
    requested checkpoint is not among them.
    """
    from huggingface_hub import snapshot_download

    available, _root_files = list_checkpoints(api, repo_id, token)
    if len(available) == 0:
        print("Nothing to download — no checkpoints on hub.", file=sys.stderr)
        sys.exit(1)

    if step is None:
        chosen = available[-1]
    else:
        chosen = step
    if chosen not in available:
        print(f"checkpoint-{chosen} not on hub. Available: {available}", file=sys.stderr)
        sys.exit(1)

    os.makedirs(target_dir, exist_ok=True)
    ckpt_name = f"checkpoint-{chosen}"
    print(f"Downloading {ckpt_name} from {repo_id} -> {target_dir}/")
    # Restrict the snapshot to the files under the chosen checkpoint directory.
    snapshot_root = snapshot_download(
        repo_id=repo_id,
        allow_patterns=[f"{ckpt_name}/*"],
        local_dir=target_dir,
        token=token,
    )
    final = os.path.join(snapshot_root, ckpt_name)

    # Empty entries produce the blank separator lines in the report.
    report = (
        "",
        f"Done. Local path: {final}",
        "",
        "To resume training from this checkpoint:",
        " python -m training.train \\",
        f" --resume-from-checkpoint {final} \\",
        f" --push-to-hub --hub-model-id {repo_id} \\",
        " --output-dir ./lora_v1",
    )
    for line in report:
        print(line)
def main():
    """CLI entry point: parse options and run exactly one sub-command.

    ``--download`` takes precedence over ``--info``; with neither flag the
    checkpoints are listed. Exits with status 1 (after printing an install
    hint to stderr) when ``huggingface_hub`` cannot be imported.
    """
    args = parse_args()
    try:
        from huggingface_hub import HfApi
    except ImportError:
        print("pip install huggingface_hub", file=sys.stderr)
        sys.exit(1)

    client = HfApi()
    repo = args.hub_model_id
    auth = args.token

    if args.download:
        cmd_download(client, repo, auth, args.download, args.checkpoint)
        return
    if args.info:
        cmd_info(client, repo, auth, args.checkpoint)
        return
    cmd_list(client, repo, auth)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()