| |
| """Upload the latest stable F5 checkpoint to Hugging Face. |
| |
| By default this publishes a reduced EMA-only safetensors artifact. The full |
| training checkpoint remains local for resume, while the Hub upload is much |
| smaller and compatible with F5-TTS inference loaders. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import os |
| import re |
| import shutil |
| import sys |
| import time |
| from pathlib import Path |
|
|
|
|
# Hub repository that receives the published artifact (overridable via --repo-id).
DEFAULT_REPO_ID = "outlawmold/sinhala-f5-tts"
# Local directory scanned for model_<step>.pt training checkpoints (--checkpoint-dir).
DEFAULT_CKPT_DIR = Path(".venv310") / "Lib" / "ckpts" / "sinhala_tts_batch03"
# Scratch directory where the single artifact to upload is staged (--stage-dir).
DEFAULT_STAGE_DIR = Path(".hf_upload") / "sinhala-f5-tts-checkpoint"
|
|
|
|
def configure_env(args: argparse.Namespace) -> None:
    """Set Hugging Face Hub environment variables for the selected ``--backend``.

    ``HF_XET_HIGH_PERFORMANCE`` is always cleared.

    - ``hf-transfer`` / ``lfs-transfer``: Xet is disabled (``HF_HUB_DISABLE_XET=1``).
      ``lfs-transfer`` additionally sets ``HF_HUB_ENABLE_HF_TRANSFER=1``.
      ``hf-transfer`` clears that flag, because it patches the multipart uploader
      itself (see ``patch_lfs_upload_with_hf_transfer``).
    - any other backend (``large-folder``): both overrides are cleared.

    Finally, default values are filled in for the Xet cache directory, the Xet
    chunk cache size (0), and telemetry (disabled). Values the caller already
    exported are left untouched.

    ``main()`` calls this before importing ``huggingface_hub``. Presumably this
    is because the Hub client reads these variables at import time; verify when
    upgrading huggingface_hub.
    """
    env = os.environ
    env.pop("HF_XET_HIGH_PERFORMANCE", None)

    if args.backend not in {"hf-transfer", "lfs-transfer"}:
        # large-folder: drop both transfer overrides.
        for var in ("HF_HUB_DISABLE_XET", "HF_HUB_ENABLE_HF_TRANSFER"):
            env.pop(var, None)
    else:
        env["HF_HUB_DISABLE_XET"] = "1"
        if args.backend == "lfs-transfer":
            env["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
        else:
            # hf-transfer relies on the monkeypatched uploader, not on this flag.
            env.pop("HF_HUB_ENABLE_HF_TRANSFER", None)

    # Defaults only: anything already present in the environment wins.
    fallbacks = {
        "HF_XET_CACHE": str(Path(args.xet_cache).resolve()),
        "HF_XET_CHUNK_CACHE_SIZE_BYTES": "0",
        "HF_HUB_DISABLE_TELEMETRY": "1",
    }
    for var, value in fallbacks.items():
        env.setdefault(var, value)
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the checkpoint upload.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before. Passing a list lets tests or
            wrapper scripts drive the parser directly.

    Returns:
        The parsed options namespace.

    Exits with status 2 (through ``ArgumentParser.error``) when
    ``--num-workers``, ``--report-every`` or ``--transfer-workers`` is below 1.
    Such values are rejected here with a usage message instead of being passed
    through to the upload libraries.
    """
    p = argparse.ArgumentParser(description="Resumably upload the latest stable checkpoint to HF Hub")
    p.add_argument("--repo-id", default=DEFAULT_REPO_ID)
    p.add_argument("--repo-type", default="model", choices=["model", "dataset", "space"])
    p.add_argument("--checkpoint-dir", default=str(DEFAULT_CKPT_DIR))
    p.add_argument("--checkpoint", default="", help="Specific checkpoint path. Defaults to newest model_<step>.pt")
    p.add_argument("--stage-dir", default=str(DEFAULT_STAGE_DIR))
    p.add_argument(
        "--publish-format",
        default="ema-safetensors",
        choices=["ema-safetensors", "ema-pt", "training-pt"],
        help="ema-* uploads a reduced inference/finetune artifact; training-pt uploads the full resume checkpoint",
    )
    p.add_argument("--xet-cache", default=r"C:\tmp\hf_xet_cache")
    p.add_argument("--num-workers", type=int, default=1, help="Keep low on slow links for cleaner resume behavior")
    p.add_argument("--report-every", type=int, default=30)
    p.add_argument(
        "--backend",
        default="hf-transfer",
        choices=["large-folder", "hf-transfer", "lfs-transfer"],
        help="hf-transfer patches Hub LFS multipart uploads to use Rust parallel transfer",
    )
    p.add_argument("--transfer-workers", type=int, default=8, help="Parallel hf_transfer part uploads")
    p.add_argument("--dry-run", action="store_true")
    args = p.parse_args(argv)

    # Counts/intervals of zero or less are never usable by the uploaders;
    # report them as a CLI usage error instead of a failure mid-upload.
    for option, value in (
        ("--num-workers", args.num_workers),
        ("--report-every", args.report_every),
        ("--transfer-workers", args.transfer_workers),
    ):
        if value < 1:
            p.error(f"{option} must be >= 1 (got {value})")
    return args
|
|
|
|
def latest_stable_checkpoint(ckpt_dir: Path) -> Path:
    """Return the ``model_<step>.pt`` file in ``ckpt_dir`` with the largest step.

    Only names that are exactly ``model_`` + digits + ``.pt`` count, so files
    such as ``model_last.pt`` are ignored. When two names parse to the same
    step, the first one yielded by ``glob`` wins.

    Raises:
        FileNotFoundError: if no matching checkpoint exists.
    """
    step_pattern = re.compile(r"model_(\d+)\.pt")
    best_step: int | None = None
    best_path: Path | None = None
    for candidate in ckpt_dir.glob("model_*.pt"):
        parsed = step_pattern.fullmatch(candidate.name)
        if parsed is None:
            continue
        step = int(parsed.group(1))
        # Strict '>' keeps the first-seen file on ties.
        if best_step is None or step > best_step:
            best_step, best_path = step, candidate
    if best_path is None:
        raise FileNotFoundError(f"No stable model_<step>.pt checkpoints found in {ckpt_dir}")
    return best_path
|
|
|
|
def replace_link_or_copy(src: Path, dst: Path) -> str:
    """Make ``dst`` a hardlink to ``src``, falling back to a full copy.

    Any existing ``dst`` is removed first. The ``is_symlink`` check also catches
    dangling symlinks, for which ``exists()`` is False. If ``os.link`` fails with
    ``OSError`` (e.g. cross-device or unsupported), ``shutil.copy2`` is used
    instead, which also copies metadata such as mtime.

    Returns:
        ``"hardlink"`` or ``"copy"``, naming the method that was used.
    """
    if dst.exists() or dst.is_symlink():
        dst.unlink()
    try:
        os.link(src, dst)
    except OSError:
        shutil.copy2(src, dst)
        method = "copy"
    else:
        method = "hardlink"
    return method
|
|
|
|
def reduced_name(src: Path, suffix: str) -> str:
    """Return the staged filename for an EMA-only artifact derived from ``src``.

    The directory and original extension of ``src`` are dropped, e.g.
    ``model_1000.pt`` with suffix ``.safetensors`` becomes
    ``model_1000_ema.safetensors``.
    """
    return "".join((src.stem, "_ema", suffix))
|
|
|
|
| def prune_to_ema(src: Path, dst: Path, safetensors: bool) -> None: |
| import torch |
|
|
| ckpt = torch.load(src, map_location="cpu", weights_only=True, mmap=True) |
| if "ema_model_state_dict" not in ckpt: |
| raise KeyError(f"ema_model_state_dict not found in {src}") |
| ema = ckpt["ema_model_state_dict"] |
| if safetensors: |
| from safetensors.torch import save_file |
|
|
| save_file(ema, str(dst), metadata={"format": "pt", "source_checkpoint": src.name}) |
| else: |
| torch.save({"ema_model_state_dict": ema, "source_checkpoint": src.name}, dst) |
|
|
|
|
def patch_lfs_upload_with_hf_transfer(max_files: int) -> None:
    """Patch Hub multipart LFS uploads to use hf_transfer concurrency.

    huggingface_hub 1.13 no longer honors HF_HUB_ENABLE_HF_TRANSFER and uploads
    multipart LFS chunks serially. hf_transfer can upload the signed part URLs
    concurrently and returns the response headers required by the completion
    request.

    This replaces the private ``huggingface_hub.lfs._upload_parts_iteratively``
    for the rest of the process. Uploads whose source is not a filesystem path
    (bytes or file objects) are delegated to the original implementation.

    Args:
        max_files: Number of parts hf_transfer uploads concurrently.

    Raises:
        ValueError: if ``max_files`` is less than 1.
    """
    if max_files < 1:
        raise ValueError(f"max_files must be >= 1, got {max_files}")

    import hf_transfer
    import huggingface_hub.lfs as lfs

    # hf_transfer rejects parallel_failures > max_files, so the previous
    # hard-coded 3 broke --transfer-workers 1 or 2. Keep 3 when possible.
    parallel_failures = min(3, max_files)

    original = lfs._upload_parts_iteratively

    def upload_parts_hf_transfer(operation, sorted_parts_urls: list[str], chunk_size: int) -> list[dict]:
        file_path = operation.path_or_fileobj
        if not isinstance(file_path, (str, Path)):
            # hf_transfer needs a path on disk; keep the stock serial uploader otherwise.
            return original(operation, sorted_parts_urls, chunk_size)

        print(
            f"[transfer] hf_transfer multipart: parts={len(sorted_parts_urls)} "
            f"chunk_size={chunk_size} workers={max_files}",
            flush=True,
        )
        return hf_transfer.multipart_upload(
            file_path=str(Path(file_path)),
            parts_urls=sorted_parts_urls,
            chunk_size=chunk_size,
            max_files=max_files,
            parallel_failures=parallel_failures,
            max_retries=5,
        )

    # NOTE(review): assumes huggingface_hub resolves this helper through the
    # module global at call time; re-verify when upgrading huggingface_hub.
    lfs._upload_parts_iteratively = upload_parts_hf_transfer
|
|
|
|
def stage_checkpoint(src: Path, stage_dir: Path, publish_format: str) -> Path:
    """Place the artifact to publish for ``src`` in ``stage_dir`` and return its path.

    - ``ema-safetensors`` / ``ema-pt``: an EMA-only file named
      ``<stem>_ema.safetensors`` / ``<stem>_ema.pt`` is produced by
      ``prune_to_ema``. An existing one is reused when its mtime is not older
      than ``src``.
    - ``training-pt``: ``src`` is hardlinked (or copied) under its own name.

    Every other ``model_*.pt`` / ``model_*.safetensors`` file in ``stage_dir``
    is deleted first, so only the current artifact remains staged.
    """
    stage_dir.mkdir(parents=True, exist_ok=True)

    suffix_for_format = {"ema-safetensors": ".safetensors", "ema-pt": ".pt"}
    suffix = suffix_for_format.get(publish_format)
    remote_name = src.name if suffix is None else reduced_name(src, suffix)
    staged = stage_dir / remote_name

    # Drop artifacts left over from earlier steps/formats.
    leftovers = [*stage_dir.glob("model_*.pt"), *stage_dir.glob("model_*.safetensors")]
    for old in leftovers:
        if old.name == remote_name:
            continue
        old.unlink()

    if publish_format == "training-pt":
        how = replace_link_or_copy(src, staged)
        print(f"[stage] {how}: {src} -> {staged}", flush=True)
        return staged

    up_to_date = staged.exists() and staged.stat().st_mtime >= src.stat().st_mtime
    if up_to_date:
        print(f"[stage] reusing reduced artifact: {staged}", flush=True)
        return staged

    t0 = time.time()
    print(f"[prune] creating reduced EMA artifact: {staged}", flush=True)
    prune_to_ema(src, staged, safetensors=publish_format == "ema-safetensors")
    print(f"[prune] done in {time.time() - t0:.1f}s", flush=True)
    return staged
|
|
|
|
def _print_run_summary(args: argparse.Namespace, ckpt: Path, staged: Path, stage_dir: Path) -> None:
    """Print the chosen checkpoint, the staged artifact, and the effective settings."""
    gib = 1024 ** 3
    source_gb = ckpt.stat().st_size / gib
    artifact_gb = staged.stat().st_size / gib
    lines = [
        f"[select] checkpoint={ckpt.name} size={source_gb:.2f} GiB",
        f"[select] artifact={staged.name} size={artifact_gb:.2f} GiB",
        f"[config] publish_format={args.publish_format}",
        f"[config] repo={args.repo_id} workers={args.num_workers}",
        f"[config] backend={args.backend}",
        f"[config] transfer_workers={args.transfer_workers}",
        f"[config] stage={stage_dir}",
    ]
    # Environment as left by configure_env (None when a variable is unset).
    for var in (
        "HF_XET_CACHE",
        "HF_XET_CHUNK_CACHE_SIZE_BYTES",
        "HF_HUB_DISABLE_XET",
        "HF_HUB_ENABLE_HF_TRANSFER",
    ):
        lines.append(f"[config] {var}={os.environ.get(var)}")
    for line in lines:
        print(line, flush=True)


def main() -> int:
    """Stage the selected checkpoint artifact and upload it to the Hub.

    The checkpoint is either ``--checkpoint`` or the newest
    ``model_<step>.pt`` in ``--checkpoint-dir``. Upload is skipped on
    ``--dry-run`` or when the repo already contains the artifact. After
    uploading, the remote file's existence is re-checked.

    Returns:
        0 on success (including dry runs and already-present artifacts);
        1 if the artifact is missing remotely after the upload.

    Raises:
        FileNotFoundError: if the selected checkpoint does not exist.
    """
    args = parse_args()
    configure_env(args)

    if args.checkpoint:
        ckpt = Path(args.checkpoint)
    else:
        ckpt = latest_stable_checkpoint(Path(args.checkpoint_dir))
    ckpt = ckpt.resolve()
    if not ckpt.exists():
        raise FileNotFoundError(ckpt)

    stage_dir = Path(args.stage_dir).resolve()
    staged = stage_checkpoint(ckpt, stage_dir, args.publish_format)
    _print_run_summary(args, ckpt, staged, stage_dir)

    if args.dry_run:
        print(f"[dry-run] staged {staged.name}; upload not started", flush=True)
        return 0

    # Imported late so configure_env's variables are in place first.
    from huggingface_hub import HfApi, upload_file

    if args.backend == "hf-transfer":
        patch_lfs_upload_with_hf_transfer(args.transfer_workers)

    api = HfApi()
    target = {"repo_id": args.repo_id, "repo_type": args.repo_type}
    artifact = staged.name

    if api.file_exists(filename=artifact, **target):
        print(f"[done] remote already has {artifact}", flush=True)
        return 0

    if args.backend not in {"hf-transfer", "lfs-transfer"}:
        # large-folder: upload only the staged file out of the stage directory.
        api.upload_large_folder(
            folder_path=stage_dir,
            allow_patterns=artifact,
            num_workers=args.num_workers,
            print_report=True,
            print_report_every=args.report_every,
            **target,
        )
    else:
        upload_file(
            path_or_fileobj=str(staged),
            path_in_repo=artifact,
            commit_message=f"Upload {artifact} via hf_transfer",
            **target,
        )

    if not api.file_exists(filename=artifact, **target):
        print(f"[error] upload finished but remote file was not found: {artifact}", file=sys.stderr, flush=True)
        return 1

    print(f"[done] uploaded and verified {artifact}", flush=True)
    return 0
|
|
|
|
if __name__ == "__main__":
    # main()'s integer return value becomes the process exit status.
    sys.exit(main())
|
|