| |
| """Inspect / recover GRPO training checkpoints from the Hugging Face Hub. |
| |
| Use this to answer "did Colab/Kaggle save anything before it died?" and to |
| pull the latest checkpoint locally for manual inspection or DGX resume. |
| |
| Examples |
| -------- |
| List what's on the Hub: |
| python scripts/check_hub_checkpoints.py --hub-model-id noanya/zombiee |
| |
| Show training progress (step, loss, learning rate from trainer_state.json): |
| python scripts/check_hub_checkpoints.py --hub-model-id noanya/zombiee --info |
| |
| Download the latest checkpoint to ./recovered/: |
| python scripts/check_hub_checkpoints.py --hub-model-id noanya/zombiee \\ |
| --download ./recovered |
| |
| Then resume training from it: |
| python -m training.train \\ |
| --resume-from-checkpoint ./recovered \\ |
| --push-to-hub --hub-model-id noanya/zombiee \\ |
| --max-steps 4000 --output-dir ./lora_v1 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import sys |
| from datetime import datetime, timezone |
|
|
|
|
| def parse_args(): |
| p = argparse.ArgumentParser( |
| description="List / download GRPO training checkpoints from HF Hub.", |
| formatter_class=argparse.RawDescriptionHelpFormatter, |
| epilog=__doc__, |
| ) |
| p.add_argument( |
| "--hub-model-id", default=os.environ.get("HUB_MODEL_ID", "noanya/zombiee"), |
| help="HF Hub repo id, e.g. 'noanya/zombiee' (default: $HUB_MODEL_ID or noanya/zombiee).", |
| ) |
| p.add_argument( |
| "--info", action="store_true", |
| help="Read trainer_state.json from the latest checkpoint and print training progress.", |
| ) |
| p.add_argument( |
| "--download", metavar="DIR", default=None, |
| help="Download the latest checkpoint to this directory.", |
| ) |
| p.add_argument( |
| "--checkpoint", metavar="N", type=int, default=None, |
| help="Operate on checkpoint-N specifically instead of the latest.", |
| ) |
| p.add_argument( |
| "--token", default=os.environ.get("HUGGINGFACE_TOKEN") or os.environ.get("HF_TOKEN"), |
| help="HF token (default: $HUGGINGFACE_TOKEN / $HF_TOKEN). Required for private repos.", |
| ) |
| return p.parse_args() |
|
|
|
|
| def list_checkpoints(api, repo_id, token): |
| """Return (sorted list of checkpoint step numbers, list of root files).""" |
| try: |
| files = api.list_repo_files(repo_id, token=token) |
| except Exception as e: |
| print(f"ERROR: could not list {repo_id}: {e}", file=sys.stderr) |
| print( |
| " If the repo is private, set HUGGINGFACE_TOKEN. If it doesn't exist yet,\n" |
| " no training run has pushed to it.", |
| file=sys.stderr, |
| ) |
| sys.exit(1) |
|
|
| steps = set() |
| root_files = [] |
| for f in files: |
| if f.startswith("checkpoint-"): |
| try: |
| steps.add(int(f.split("/", 1)[0].split("-", 1)[1])) |
| except ValueError: |
| pass |
| elif "/" not in f: |
| root_files.append(f) |
| return sorted(steps), sorted(root_files) |
|
|
|
|
| def fetch_trainer_state(api, repo_id, step, token, work_dir): |
| """Download trainer_state.json from checkpoint-step and return parsed dict.""" |
| from huggingface_hub import hf_hub_download |
|
|
| path = hf_hub_download( |
| repo_id=repo_id, |
| filename=f"checkpoint-{step}/trainer_state.json", |
| local_dir=work_dir, |
| token=token, |
| ) |
| with open(path) as f: |
| return json.load(f) |
|
|
|
|
| def fmt_age(iso_or_dt): |
| """Render 'X minutes/hours/days ago' from an HF datetime.""" |
| if isinstance(iso_or_dt, str): |
| try: |
| dt = datetime.fromisoformat(iso_or_dt.replace("Z", "+00:00")) |
| except ValueError: |
| return iso_or_dt |
| else: |
| dt = iso_or_dt |
| if dt.tzinfo is None: |
| dt = dt.replace(tzinfo=timezone.utc) |
| delta = datetime.now(timezone.utc) - dt |
| s = int(delta.total_seconds()) |
| if s < 60: |
| return f"{s}s ago" |
| if s < 3600: |
| return f"{s // 60}m ago" |
| if s < 86400: |
| return f"{s // 3600}h {(s % 3600) // 60}m ago" |
| return f"{s // 86400}d {(s % 86400) // 3600}h ago" |
|
|
|
|
| def cmd_list(api, repo_id, token): |
| steps, root_files = list_checkpoints(api, repo_id, token) |
|
|
| print(f"Repo: https://huggingface.co/{repo_id}") |
| try: |
| info = api.repo_info(repo_id, token=token) |
| print(f"Last commit: {info.sha[:8]} ({fmt_age(info.lastModified)})") |
| except Exception: |
| pass |
|
|
| print() |
| if not steps: |
| print("No checkpoint-* directories found.") |
| if root_files: |
| print(f"Root files present: {', '.join(root_files)}") |
| print("(Looks like only a final-model push, no intermediate checkpoints.)") |
| else: |
| print("Repo is empty — training never reached the first save.") |
| return |
|
|
| print(f"Found {len(steps)} checkpoint(s): {', '.join(f'checkpoint-{s}' for s in steps)}") |
| print(f"Latest: checkpoint-{steps[-1]}") |
| if root_files: |
| print(f"Root files: {', '.join(root_files)}") |
|
|
|
|
| def cmd_info(api, repo_id, token, step): |
| steps, _ = list_checkpoints(api, repo_id, token) |
| if not steps: |
| print("No checkpoints to inspect.", file=sys.stderr) |
| sys.exit(1) |
| target = step if step is not None else steps[-1] |
| if target not in steps: |
| print(f"checkpoint-{target} not on hub. Available: {steps}", file=sys.stderr) |
| sys.exit(1) |
|
|
| print(f"Inspecting checkpoint-{target}...") |
| state = fetch_trainer_state(api, repo_id, target, token, "/tmp/_hub_inspect") |
|
|
| print() |
| print(f" global_step : {state.get('global_step')}") |
| print(f" epoch : {state.get('epoch'):.4f}" if state.get("epoch") is not None else " epoch : ?") |
| print(f" max_steps : {state.get('max_steps')}") |
| print(f" best_metric : {state.get('best_metric')}") |
| print(f" total_flos : {state.get('total_flos')}") |
|
|
| log_history = state.get("log_history", []) |
| if log_history: |
| print(f" log entries : {len(log_history)}") |
| last = log_history[-1] |
| print() |
| print(" Most recent log entry:") |
| for k in ("loss", "learning_rate", "grad_norm", "reward", "kl", "step"): |
| if k in last: |
| v = last[k] |
| if isinstance(v, float): |
| print(f" {k:18}: {v:.6f}") |
| else: |
| print(f" {k:18}: {v}") |
|
|
| pct = (target / state["max_steps"] * 100) if state.get("max_steps") else None |
| if pct is not None: |
| print() |
| print(f"Progress: {target} / {state['max_steps']} steps ({pct:.1f}% done)") |
|
|
|
|
| def cmd_download(api, repo_id, token, target_dir, step): |
| from huggingface_hub import snapshot_download |
|
|
| steps, _ = list_checkpoints(api, repo_id, token) |
| if not steps: |
| print("Nothing to download — no checkpoints on hub.", file=sys.stderr) |
| sys.exit(1) |
| chosen = step if step is not None else steps[-1] |
| if chosen not in steps: |
| print(f"checkpoint-{chosen} not on hub. Available: {steps}", file=sys.stderr) |
| sys.exit(1) |
|
|
| os.makedirs(target_dir, exist_ok=True) |
| print(f"Downloading checkpoint-{chosen} from {repo_id} -> {target_dir}/") |
| local = snapshot_download( |
| repo_id=repo_id, |
| allow_patterns=[f"checkpoint-{chosen}/*"], |
| local_dir=target_dir, |
| token=token, |
| ) |
| final = os.path.join(local, f"checkpoint-{chosen}") |
| print() |
| print(f"Done. Local path: {final}") |
| print() |
| print("To resume training from this checkpoint:") |
| print(f" python -m training.train \\") |
| print(f" --resume-from-checkpoint {final} \\") |
| print(f" --push-to-hub --hub-model-id {repo_id} \\") |
| print(f" --output-dir ./lora_v1") |
|
|
|
|
| def main(): |
| args = parse_args() |
| try: |
| from huggingface_hub import HfApi |
| except ImportError: |
| print("pip install huggingface_hub", file=sys.stderr) |
| sys.exit(1) |
|
|
| api = HfApi() |
|
|
| if args.download: |
| cmd_download(api, args.hub_model_id, args.token, args.download, args.checkpoint) |
| elif args.info: |
| cmd_info(api, args.hub_model_id, args.token, args.checkpoint) |
| else: |
| cmd_list(api, args.hub_model_id, args.token) |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|