#!/usr/bin/env python3
"""Upload the latest stable F5 checkpoint to Hugging Face.

By default this publishes a reduced EMA-only safetensors artifact. The full
training checkpoint remains local for resume, while the Hub upload is much
smaller and compatible with F5-TTS inference loaders.
"""

from __future__ import annotations

import argparse
import os
import re
import shutil
import sys
import time
from pathlib import Path

DEFAULT_REPO_ID = "outlawmold/sinhala-f5-tts"
DEFAULT_CKPT_DIR = Path(".venv310/Lib/ckpts/sinhala_tts_batch03")
DEFAULT_STAGE_DIR = Path(".hf_upload/sinhala-f5-tts-checkpoint")


def configure_env(args: argparse.Namespace) -> None:
    """Set the Hugging Face environment variables for the chosen backend.

    Must run before ``huggingface_hub`` is imported, which is why ``main``
    imports it lazily.

    Args:
        args: Parsed CLI arguments; ``backend`` and ``xet_cache`` are read.
    """
    # Environment variables are read when huggingface_hub is imported.
    os.environ.pop("HF_XET_HIGH_PERFORMANCE", None)
    if args.backend in {"hf-transfer", "lfs-transfer"}:
        # Both LFS-based backends turn Xet off.
        os.environ["HF_HUB_DISABLE_XET"] = "1"
        if args.backend == "lfs-transfer":
            os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
        else:
            # "hf-transfer" relies on the explicit monkey patch instead of
            # the environment switch.
            os.environ.pop("HF_HUB_ENABLE_HF_TRANSFER", None)
    else:
        os.environ.pop("HF_HUB_DISABLE_XET", None)
        os.environ.pop("HF_HUB_ENABLE_HF_TRANSFER", None)
    # setdefault: values already exported by the caller take precedence.
    os.environ.setdefault("HF_XET_CACHE", str(Path(args.xet_cache).resolve()))
    os.environ.setdefault("HF_XET_CHUNK_CACHE_SIZE_BYTES", "0")
    os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the uploader."""
    p = argparse.ArgumentParser(description="Resumably upload the latest stable checkpoint to HF Hub")
    p.add_argument("--repo-id", default=DEFAULT_REPO_ID)
    p.add_argument("--repo-type", default="model", choices=["model", "dataset", "space"])
    p.add_argument("--checkpoint-dir", default=str(DEFAULT_CKPT_DIR))
    p.add_argument("--checkpoint", default="", help="Specific checkpoint path. Defaults to newest model_.pt")
    p.add_argument("--stage-dir", default=str(DEFAULT_STAGE_DIR))
    p.add_argument(
        "--publish-format",
        default="ema-safetensors",
        choices=["ema-safetensors", "ema-pt", "training-pt"],
        help="ema-* uploads a reduced inference/finetune artifact; training-pt uploads the full resume checkpoint",
    )
    p.add_argument("--xet-cache", default=r"C:\tmp\hf_xet_cache")
    p.add_argument("--num-workers", type=int, default=1, help="Keep low on slow links for cleaner resume behavior")
    p.add_argument("--report-every", type=int, default=30)
    p.add_argument(
        "--backend",
        default="hf-transfer",
        choices=["large-folder", "hf-transfer", "lfs-transfer"],
        help="hf-transfer patches Hub LFS multipart uploads to use Rust parallel transfer",
    )
    p.add_argument("--transfer-workers", type=int, default=8, help="Parallel hf_transfer part uploads")
    p.add_argument("--dry-run", action="store_true")
    return p.parse_args()


def latest_stable_checkpoint(ckpt_dir: Path) -> Path:
    """Return the ``model_<digits>.pt`` file with the highest number.

    Files such as ``model_last.pt`` match the glob but not the regex and are
    ignored. Numbers are compared as integers, so ``model_1000.pt`` beats
    ``model_900.pt``.

    Raises:
        FileNotFoundError: If no numbered checkpoint exists in ``ckpt_dir``.
    """
    candidates: list[tuple[int, Path]] = []
    for path in ckpt_dir.glob("model_*.pt"):
        match = re.fullmatch(r"model_(\d+)\.pt", path.name)
        if match:
            candidates.append((int(match.group(1)), path))
    if not candidates:
        raise FileNotFoundError(f"No stable model_.pt checkpoints found in {ckpt_dir}")
    return max(candidates, key=lambda item: item[0])[1]


def replace_link_or_copy(src: Path, dst: Path) -> str:
    """Make ``dst`` a hard link to ``src``, falling back to a full copy.

    Any existing ``dst`` (including a dangling symlink) is replaced. When
    ``dst`` already resolves to the very same file as ``src`` nothing is
    touched: unlinking it first would be pointless for an existing link and
    would delete the source outright if both paths are the same entry.

    Returns:
        ``"hardlink"`` or ``"copy"``, describing how ``dst`` is backed.
    """
    if dst.exists() and os.path.samefile(src, dst):
        return "hardlink"
    if dst.exists() or dst.is_symlink():
        dst.unlink()
    try:
        os.link(src, dst)
    except OSError:
        # Hard links are unavailable (e.g. across volumes); copy instead.
        shutil.copy2(src, dst)
        return "copy"
    return "hardlink"


def reduced_name(src: Path, suffix: str) -> str:
    """Return the artifact file name for ``src``: ``<stem>_ema<suffix>``."""
    return f"{src.stem}_ema{suffix}"


def prune_to_ema(src: Path, dst: Path, safetensors: bool) -> None:
    """Write only the ``ema_model_state_dict`` of checkpoint ``src`` to ``dst``.

    The artifact is written to a ``.partial`` sibling and moved into place
    with ``os.replace`` once complete. An interrupted run therefore never
    leaves a truncated file at ``dst`` that a later run could mistake for a
    finished artifact and upload.

    Args:
        src: Training checkpoint readable by ``torch.load``.
        dst: Destination path of the reduced artifact.
        safetensors: Write safetensors when true, otherwise a ``torch.save``
            file holding ``ema_model_state_dict`` and ``source_checkpoint``.

    Raises:
        KeyError: If ``src`` has no ``ema_model_state_dict`` entry.
    """
    import torch

    ckpt = torch.load(src, map_location="cpu", weights_only=True, mmap=True)
    if "ema_model_state_dict" not in ckpt:
        raise KeyError(f"ema_model_state_dict not found in {src}")
    ema = ckpt["ema_model_state_dict"]

    partial = dst.with_name(dst.name + ".partial")
    try:
        if safetensors:
            from safetensors.torch import save_file

            save_file(ema, str(partial), metadata={"format": "pt", "source_checkpoint": src.name})
        else:
            torch.save({"ema_model_state_dict": ema, "source_checkpoint": src.name}, partial)
        os.replace(partial, dst)
    finally:
        # No-op after a successful replace; removes leftovers on failure.
        partial.unlink(missing_ok=True)


def patch_lfs_upload_with_hf_transfer(max_files: int) -> None:
    """Patch Hub multipart LFS uploads to use hf_transfer concurrency.

    huggingface_hub 1.13 no longer honors HF_HUB_ENABLE_HF_TRANSFER and
    uploads multipart LFS chunks serially. hf_transfer can upload the signed
    part URLs concurrently and returns the response headers required by the
    completion request.

    Args:
        max_files: Number of parts hf_transfer uploads in parallel.
    """
    import hf_transfer
    import huggingface_hub.lfs as lfs

    original = lfs._upload_parts_iteratively

    def upload_parts_hf_transfer(operation, sorted_parts_urls: list[str], chunk_size: int) -> list[dict]:
        file_path = operation.path_or_fileobj
        if not isinstance(file_path, (str, Path)):
            # hf_transfer needs a file on disk; other sources keep the
            # library's original upload path.
            return original(operation, sorted_parts_urls, chunk_size)
        print(
            f"[transfer] hf_transfer multipart: parts={len(sorted_parts_urls)} "
            f"chunk_size={chunk_size} workers={max_files}",
            flush=True,
        )
        return hf_transfer.multipart_upload(
            file_path=str(Path(file_path)),
            parts_urls=sorted_parts_urls,
            chunk_size=chunk_size,
            max_files=max_files,
            parallel_failures=3,
            max_retries=5,
        )

    lfs._upload_parts_iteratively = upload_parts_hf_transfer


def stage_checkpoint(src: Path, stage_dir: Path, publish_format: str) -> Path:
    """Prepare the file to upload inside ``stage_dir`` and return its path.

    ``training-pt`` links or copies ``src`` unchanged. The ``ema-*`` formats
    build a reduced artifact, reusing an existing one unless it is older than
    ``src``. Other ``model_*.pt`` / ``model_*.safetensors`` files in the stage
    directory are deleted.

    Raises:
        ValueError: If ``stage_dir`` is the directory that contains ``src``.
            The stale-file sweep below would otherwise delete the training
            checkpoints themselves.
    """
    stage_dir.mkdir(parents=True, exist_ok=True)
    if os.path.samefile(stage_dir, src.parent):
        raise ValueError(f"stage dir must not be the checkpoint directory: {stage_dir}")

    if publish_format == "ema-safetensors":
        remote_name = reduced_name(src, ".safetensors")
    elif publish_format == "ema-pt":
        remote_name = reduced_name(src, ".pt")
    else:
        remote_name = src.name
    staged = stage_dir / remote_name

    # Remove stale checkpoint links/copies, but preserve .cache for resumability.
    for path in list(stage_dir.glob("model_*.pt")) + list(stage_dir.glob("model_*.safetensors")):
        if path.name != remote_name:
            path.unlink()

    if publish_format == "training-pt":
        mode = replace_link_or_copy(src, staged)
        print(f"[stage] {mode}: {src} -> {staged}", flush=True)
    else:
        if not staged.exists() or staged.stat().st_mtime < src.stat().st_mtime:
            start = time.time()
            print(f"[prune] creating reduced EMA artifact: {staged}", flush=True)
            prune_to_ema(src, staged, safetensors=publish_format == "ema-safetensors")
            print(f"[prune] done in {time.time() - start:.1f}s", flush=True)
        else:
            print(f"[stage] reusing reduced artifact: {staged}", flush=True)
    return staged


def main() -> int:
    """Stage the selected checkpoint and upload it to the Hub.

    Returns:
        ``0`` after a dry run, when the remote already has the artifact, or
        after a verified upload; ``1`` when the upload call returned but the
        file cannot be found on the remote.

    Raises:
        FileNotFoundError: If the selected checkpoint does not exist.
    """
    args = parse_args()
    configure_env(args)

    ckpt = Path(args.checkpoint) if args.checkpoint else latest_stable_checkpoint(Path(args.checkpoint_dir))
    ckpt = ckpt.resolve()
    if not ckpt.exists():
        raise FileNotFoundError(ckpt)

    stage_dir = Path(args.stage_dir).resolve()
    staged = stage_checkpoint(ckpt, stage_dir, args.publish_format)

    size_gb = ckpt.stat().st_size / (1024 ** 3)
    staged_gb = staged.stat().st_size / (1024 ** 3)
    print(f"[select] checkpoint={ckpt.name} size={size_gb:.2f} GiB", flush=True)
    print(f"[select] artifact={staged.name} size={staged_gb:.2f} GiB", flush=True)
    print(f"[config] publish_format={args.publish_format}", flush=True)
    print(f"[config] repo={args.repo_id} workers={args.num_workers}", flush=True)
    print(f"[config] backend={args.backend}", flush=True)
    print(f"[config] transfer_workers={args.transfer_workers}", flush=True)
    print(f"[config] stage={stage_dir}", flush=True)
    print(f"[config] HF_XET_CACHE={os.environ.get('HF_XET_CACHE')}", flush=True)
    print(f"[config] HF_XET_CHUNK_CACHE_SIZE_BYTES={os.environ.get('HF_XET_CHUNK_CACHE_SIZE_BYTES')}", flush=True)
    print(f"[config] HF_HUB_DISABLE_XET={os.environ.get('HF_HUB_DISABLE_XET')}", flush=True)
    print(f"[config] HF_HUB_ENABLE_HF_TRANSFER={os.environ.get('HF_HUB_ENABLE_HF_TRANSFER')}", flush=True)

    if args.dry_run:
        print(f"[dry-run] staged {staged.name}; upload not started", flush=True)
        return 0

    # Imported only now so the variables set by configure_env are in place.
    from huggingface_hub import HfApi, upload_file

    if args.backend == "hf-transfer":
        patch_lfs_upload_with_hf_transfer(args.transfer_workers)

    api = HfApi()
    if api.file_exists(repo_id=args.repo_id, repo_type=args.repo_type, filename=staged.name):
        print(f"[done] remote already has {staged.name}", flush=True)
        return 0

    if args.backend in {"hf-transfer", "lfs-transfer"}:
        upload_file(
            path_or_fileobj=str(staged),
            path_in_repo=staged.name,
            repo_id=args.repo_id,
            repo_type=args.repo_type,
            commit_message=f"Upload {staged.name} via hf_transfer",
        )
    else:
        api.upload_large_folder(
            repo_id=args.repo_id,
            repo_type=args.repo_type,
            folder_path=stage_dir,
            allow_patterns=staged.name,
            num_workers=args.num_workers,
            print_report=True,
            print_report_every=args.report_every,
        )

    # Confirm the artifact is really visible on the Hub before reporting success.
    if not api.file_exists(repo_id=args.repo_id, repo_type=args.repo_type, filename=staged.name):
        print(f"[error] upload finished but remote file was not found: {staged.name}", file=sys.stderr, flush=True)
        return 1

    print(f"[done] uploaded and verified {staged.name}", flush=True)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())