# Source: sinhala-tts / scripts / upload_latest_checkpoint_large_folder.py
# Page header residue: "outlawmold's picture"
# Commit 3b37854: Use parallel hf_transfer for model uploads
#!/usr/bin/env python3
"""Upload the latest stable F5 checkpoint to Hugging Face.
By default this publishes a reduced EMA-only safetensors artifact. The full
training checkpoint remains local for resume, while the Hub upload is much
smaller and compatible with F5-TTS inference loaders.
"""
from __future__ import annotations
import argparse
import os
import re
import shutil
import sys
import time
from pathlib import Path
# Hub repository used when --repo-id is not given.
DEFAULT_REPO_ID = "outlawmold/sinhala-f5-tts"
# Directory searched for model_<step>.pt files when --checkpoint-dir is not given.
DEFAULT_CKPT_DIR = Path(".venv310/Lib/ckpts/sinhala_tts_batch03")
# Local staging directory used when --stage-dir is not given.
DEFAULT_STAGE_DIR = Path(".hf_upload/sinhala-f5-tts-checkpoint")
def configure_env(args: argparse.Namespace) -> None:
    """Set the Hugging Face environment variables implied by ``args.backend``.

    Must run before ``huggingface_hub`` is imported, because the library reads
    these variables at import time.

    Args:
        args: Parsed CLI namespace; ``backend`` and ``xet_cache`` are read.
    """
    env = os.environ
    backend = args.backend

    # Always cleared, regardless of backend.
    env.pop("HF_XET_HIGH_PERFORMANCE", None)

    # Both LFS-based backends bypass Xet; the large-folder backend leaves the
    # switch unset.
    if backend in {"hf-transfer", "lfs-transfer"}:
        env["HF_HUB_DISABLE_XET"] = "1"
    else:
        env.pop("HF_HUB_DISABLE_XET", None)

    # Only "lfs-transfer" sets the hf_transfer switch; every other backend
    # clears it.
    if backend == "lfs-transfer":
        env["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
    else:
        env.pop("HF_HUB_ENABLE_HF_TRANSFER", None)

    # Values already present in the environment win over these defaults.
    fallbacks = (
        ("HF_XET_CACHE", str(Path(args.xet_cache).resolve())),
        ("HF_XET_CHUNK_CACHE_SIZE_BYTES", "0"),
        ("HF_HUB_DISABLE_TELEMETRY", "1"),
    )
    for name, value in fallbacks:
        env.setdefault(name, value)
def _positive_int(text: str) -> int:
    """Argparse ``type`` callable: parse *text* as an integer that is >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the upload script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point uses.

    Returns:
        The populated namespace.

    Worker counts are validated to be at least 1, so a zero or negative value
    is rejected here instead of being passed on to the upload backends.
    """
    p = argparse.ArgumentParser(description="Resumably upload the latest stable checkpoint to HF Hub")
    p.add_argument("--repo-id", default=DEFAULT_REPO_ID)
    p.add_argument("--repo-type", default="model", choices=["model", "dataset", "space"])
    p.add_argument("--checkpoint-dir", default=str(DEFAULT_CKPT_DIR))
    p.add_argument("--checkpoint", default="", help="Specific checkpoint path. Defaults to newest model_<step>.pt")
    p.add_argument("--stage-dir", default=str(DEFAULT_STAGE_DIR))
    p.add_argument(
        "--publish-format",
        default="ema-safetensors",
        choices=["ema-safetensors", "ema-pt", "training-pt"],
        help="ema-* uploads a reduced inference/finetune artifact; training-pt uploads the full resume checkpoint",
    )
    p.add_argument("--xet-cache", default=r"C:\tmp\hf_xet_cache")
    p.add_argument(
        "--num-workers",
        type=_positive_int,
        default=1,
        help="Keep low on slow links for cleaner resume behavior",
    )
    p.add_argument("--report-every", type=int, default=30)
    p.add_argument(
        "--backend",
        default="hf-transfer",
        choices=["large-folder", "hf-transfer", "lfs-transfer"],
        help="hf-transfer patches Hub LFS multipart uploads to use Rust parallel transfer",
    )
    p.add_argument(
        "--transfer-workers",
        type=_positive_int,
        default=8,
        help="Parallel hf_transfer part uploads",
    )
    p.add_argument("--dry-run", action="store_true")
    return p.parse_args(argv)
def latest_stable_checkpoint(ckpt_dir: Path) -> Path:
    """Return the ``model_<step>.pt`` file in *ckpt_dir* with the highest step.

    Only names that consist of exactly ``model_``, digits and ``.pt`` are
    considered; other files matched by the glob (for example
    ``model_last.pt``) are ignored.

    Raises:
        FileNotFoundError: If no file in *ckpt_dir* has such a name.
    """
    step_pattern = re.compile(r"model_(\d+)\.pt")
    newest = None
    newest_step = -1
    for entry in ckpt_dir.glob("model_*.pt"):
        hit = step_pattern.fullmatch(entry.name)
        if hit is None:
            continue
        step = int(hit.group(1))
        # Strict comparison keeps the first file seen when two share a step.
        if step > newest_step:
            newest_step, newest = step, entry
    if newest is None:
        raise FileNotFoundError(f"No stable model_<step>.pt checkpoints found in {ckpt_dir}")
    return newest
def replace_link_or_copy(src: Path, dst: Path) -> str:
    """Make *dst* a hard link to *src*, falling back to a full copy.

    Any existing file or symlink at *dst* is replaced.

    Args:
        src: Existing file to expose at *dst*.
        dst: Destination path.

    Returns:
        ``"hardlink"`` when *dst* ends up as a hard link to *src* (including
        the case where it already was the same file), ``"copy"`` when the
        link could not be created and the file was copied instead.
    """
    if dst.is_symlink():
        # A symlink is always replaced, even when it points at src.
        dst.unlink()
    elif dst.exists():
        # dst may be src itself or an earlier hard link to it. Unlinking the
        # same path would delete the checkpoint before it could be linked or
        # copied, so leave an identical file alone.
        if os.path.samefile(src, dst):
            return "hardlink"
        dst.unlink()
    try:
        os.link(src, dst)
    except OSError:
        # Hard links are unavailable here (for example across volumes).
        shutil.copy2(src, dst)
        return "copy"
    return "hardlink"
def reduced_name(src: Path, suffix: str) -> str:
    """Build the reduced-artifact file name: ``<stem of src>_ema<suffix>``."""
    return src.stem + "_ema" + suffix
def prune_to_ema(src: Path, dst: Path, safetensors: bool) -> None:
import torch
ckpt = torch.load(src, map_location="cpu", weights_only=True, mmap=True)
if "ema_model_state_dict" not in ckpt:
raise KeyError(f"ema_model_state_dict not found in {src}")
ema = ckpt["ema_model_state_dict"]
if safetensors:
from safetensors.torch import save_file
save_file(ema, str(dst), metadata={"format": "pt", "source_checkpoint": src.name})
else:
torch.save({"ema_model_state_dict": ema, "source_checkpoint": src.name}, dst)
def patch_lfs_upload_with_hf_transfer(max_files: int) -> None:
    """Route Hub multipart LFS part uploads through ``hf_transfer``.

    Per the original author's note, huggingface_hub 1.13 no longer honors
    HF_HUB_ENABLE_HF_TRANSFER and uploads multipart LFS chunks serially, while
    hf_transfer can upload the signed part URLs concurrently and returns the
    response headers required by the completion request.

    The function replaces ``huggingface_hub.lfs._upload_parts_iteratively``
    for the rest of the process.

    Args:
        max_files: Value passed to ``hf_transfer.multipart_upload`` as
            ``max_files``.
    """
    import hf_transfer
    import huggingface_hub.lfs as lfs

    # Keep the stock implementation for sources that are not file paths.
    serial_uploader = lfs._upload_parts_iteratively

    def upload_parts_hf_transfer(operation, sorted_parts_urls: list[str], chunk_size: int) -> list[dict]:
        source = operation.path_or_fileobj
        if isinstance(source, (str, Path)):
            part_count = len(sorted_parts_urls)
            print(
                f"[transfer] hf_transfer multipart: parts={part_count} "
                f"chunk_size={chunk_size} workers={max_files}",
                flush=True,
            )
            return hf_transfer.multipart_upload(
                file_path=str(Path(source)),
                parts_urls=sorted_parts_urls,
                chunk_size=chunk_size,
                max_files=max_files,
                parallel_failures=3,
                max_retries=5,
            )
        # hf_transfer is only given a file path; anything else goes serially.
        return serial_uploader(operation, sorted_parts_urls, chunk_size)

    lfs._upload_parts_iteratively = upload_parts_hf_transfer
def stage_checkpoint(src: Path, stage_dir: Path, publish_format: str) -> Path:
    """Prepare the file to upload inside *stage_dir* and return its path.

    Args:
        src: Local training checkpoint.
        stage_dir: Staging directory; created if missing.
        publish_format: ``"ema-safetensors"`` or ``"ema-pt"`` produce a
            reduced EMA-only artifact; any other value (``"training-pt"``)
            stages the checkpoint itself via hard link or copy.

    Returns:
        Path of the staged artifact inside *stage_dir*.

    Raises:
        ValueError: If *stage_dir* is the directory that contains *src*. The
            stale-file cleanup below would otherwise delete the source
            checkpoint and every other ``model_*`` file stored next to it.
    """
    if stage_dir.resolve() == src.parent.resolve():
        raise ValueError(
            f"stage dir {stage_dir} is the checkpoint's own directory; "
            "staging there would delete local checkpoints"
        )
    stage_dir.mkdir(parents=True, exist_ok=True)

    if publish_format == "ema-safetensors":
        remote_name = reduced_name(src, ".safetensors")
    elif publish_format == "ema-pt":
        remote_name = reduced_name(src, ".pt")
    else:
        remote_name = src.name
    staged = stage_dir / remote_name

    # Remove stale checkpoint links/copies, but preserve .cache for resumability.
    for path in list(stage_dir.glob("model_*.pt")) + list(stage_dir.glob("model_*.safetensors")):
        if path.name != remote_name:
            path.unlink()

    if publish_format == "training-pt":
        mode = replace_link_or_copy(src, staged)
        print(f"[stage] {mode}: {src} -> {staged}", flush=True)
    else:
        # Rebuild only when the artifact is missing or older than its source.
        if not staged.exists() or staged.stat().st_mtime < src.stat().st_mtime:
            start = time.time()
            print(f"[prune] creating reduced EMA artifact: {staged}", flush=True)
            prune_to_ema(src, staged, safetensors=publish_format == "ema-safetensors")
            print(f"[prune] done in {time.time() - start:.1f}s", flush=True)
        else:
            print(f"[stage] reusing reduced artifact: {staged}", flush=True)
    return staged
def main() -> int:
    """Stage the selected checkpoint and upload it to the Hub.

    Returns:
        ``0`` on success, on ``--dry-run``, or when the remote already has the
        artifact; ``1`` when the upload call returned but the file is not
        found on the remote afterwards.

    Raises:
        FileNotFoundError: If the selected checkpoint does not exist.
    """
    args = parse_args()
    # Must run before huggingface_hub is imported further down.
    configure_env(args)

    if args.checkpoint:
        source = Path(args.checkpoint)
    else:
        source = latest_stable_checkpoint(Path(args.checkpoint_dir))
    source = source.resolve()
    if not source.exists():
        raise FileNotFoundError(source)

    stage_root = Path(args.stage_dir).resolve()
    artifact = stage_checkpoint(source, stage_root, args.publish_format)

    gib = 1024 ** 3
    source_gib = source.stat().st_size / gib
    artifact_gib = artifact.stat().st_size / gib
    env = os.environ
    summary = (
        f"[select] checkpoint={source.name} size={source_gib:.2f} GiB",
        f"[select] artifact={artifact.name} size={artifact_gib:.2f} GiB",
        f"[config] publish_format={args.publish_format}",
        f"[config] repo={args.repo_id} workers={args.num_workers}",
        f"[config] backend={args.backend}",
        f"[config] transfer_workers={args.transfer_workers}",
        f"[config] stage={stage_root}",
        f"[config] HF_XET_CACHE={env.get('HF_XET_CACHE')}",
        f"[config] HF_XET_CHUNK_CACHE_SIZE_BYTES={env.get('HF_XET_CHUNK_CACHE_SIZE_BYTES')}",
        f"[config] HF_HUB_DISABLE_XET={env.get('HF_HUB_DISABLE_XET')}",
        f"[config] HF_HUB_ENABLE_HF_TRANSFER={env.get('HF_HUB_ENABLE_HF_TRANSFER')}",
    )
    for line in summary:
        print(line, flush=True)

    if args.dry_run:
        print(f"[dry-run] staged {artifact.name}; upload not started", flush=True)
        return 0

    # Imported late so the environment configured above is already in place.
    from huggingface_hub import HfApi, upload_file

    if args.backend == "hf-transfer":
        patch_lfs_upload_with_hf_transfer(args.transfer_workers)

    api = HfApi()
    target = {"repo_id": args.repo_id, "repo_type": args.repo_type}

    if api.file_exists(filename=artifact.name, **target):
        print(f"[done] remote already has {artifact.name}", flush=True)
        return 0

    single_file_backend = args.backend in {"hf-transfer", "lfs-transfer"}
    if not single_file_backend:
        api.upload_large_folder(
            folder_path=stage_root,
            allow_patterns=artifact.name,
            num_workers=args.num_workers,
            print_report=True,
            print_report_every=args.report_every,
            **target,
        )
    else:
        upload_file(
            path_or_fileobj=str(artifact),
            path_in_repo=artifact.name,
            commit_message=f"Upload {artifact.name} via hf_transfer",
            **target,
        )

    # Confirm the artifact is on the remote before reporting success.
    if api.file_exists(filename=artifact.name, **target):
        print(f"[done] uploaded and verified {artifact.name}", flush=True)
        return 0
    print(f"[error] upload finished but remote file was not found: {artifact.name}", file=sys.stderr, flush=True)
    return 1
# Script entry point: the return value of main() becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())