mumble-cleanup / scripts /push_to_hub.py
adikuma's picture
initial upload: cleanup code and 688-pair seed dataset
fd0b01f verified
Raw
History Blame Contribute Delete
3.66 kB
# push the trained cleanup model to the huggingface hub.
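# uploads the merged transformers model + fp32 onnx + (optional) int8 onnx +
# model card in one commit. needs HF_TOKEN, either already set in the
# environment or in .env.local (gitignored).
#
# usage: uv run python scripts/push_to_hub.py --run-id r1
# a newly created repo is private by default; pass --public to create it
# public instead. an already-existing repo keeps its current visibility.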
import argparse
import os
import sys
from pathlib import Path
from huggingface_hub import CommitOperationAdd, HfApi
# files copied from runs/<run-id>/merged into the repo root. every entry is
# required: main() aborts if any of them is missing.
MODEL_FILES = [
    # model architecture + weights
    "config.json",
    "model.safetensors",
] + [
    # tokenizer artifacts
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "vocab.json",
    "merges.txt",
]
def _load_dotenv(path: Path) -> None:
    """load simple KEY=VALUE lines from a dotenv file into os.environ.

    a missing file is ignored. blank lines, lines starting with '#', and lines
    without '=' are skipped. a leading 'export ' on the key is tolerated, and
    one matching pair of surrounding single or double quotes is removed from
    the value. variables already present in os.environ are never overwritten,
    so the real environment wins over the file.

    :param path: location of the dotenv file.
    """
    if not path.exists():
        return
    # utf-8-sig drops a leading BOM (common from windows editors) that would
    # otherwise be glued onto the first key and make it unfindable.
    for raw in path.read_text(encoding="utf-8-sig").splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, value = line.split("=", 1)
        key = key.strip()
        # accept shell-style "export KEY=value" lines
        if key.startswith("export "):
            key = key[len("export "):].strip()
        value = value.strip()
        # strip exactly one matching pair of quotes; mismatched or nested
        # quotes are part of the value.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in ("'", '"'):
            value = value[1:-1]
        if key and key not in os.environ:
            os.environ[key] = value
def main() -> None:
    """validate a run's artifacts and publish them to the hub in one commit.

    reads HF_TOKEN from the environment (after loading .env.local from the
    current directory), checks that every MODEL_FILES entry under
    <runs-dir>/<run-id>/merged, the fp32 onnx export and the model card exist,
    then creates the repo if needed and uploads everything with a single
    create_commit. the int8 onnx export is optional and only uploaded when
    present. exits with status 1 if the token or any required file is missing.
    """
    parser = argparse.ArgumentParser(description="publish cleanup model to the hf hub")
    parser.add_argument("--runs-dir", default="runs", help="directory holding training runs")
    parser.add_argument("--run-id", required=True, help="run subdirectory to publish")
    parser.add_argument("--repo", default="adikuma/mumble-cleanup", help="target hub repo id")
    parser.add_argument(
        "--card", default="docs/model_card.md", help="markdown file uploaded as README.md"
    )
    parser.add_argument("--public", action="store_true", help="make the repo public")
    args = parser.parse_args()

    _load_dotenv(Path(".env.local"))
    token = os.environ.get("HF_TOKEN")
    if not token:
        print("error: HF_TOKEN not set. add it to .env.local", file=sys.stderr)
        sys.exit(1)

    run_dir = Path(args.runs_dir) / args.run_id
    merged_dir = run_dir / "merged"
    fp32_onnx = run_dir / "onnx" / "model.onnx"
    int8_onnx = run_dir / "onnx" / "int8" / "model.onnx"
    card_path = Path(args.card)

    # everything except the int8 export is required
    required = [merged_dir / name for name in MODEL_FILES] + [fp32_onnx, card_path]
    missing = [str(p) for p in required if not p.exists()]
    if missing:
        print("error: missing files:", file=sys.stderr)
        for path in missing:
            print(f" {path}", file=sys.stderr)
        sys.exit(1)

    has_int8 = int8_onnx.exists()
    if not has_int8:
        # int8 is optional but warn loudly, on stderr next to the errors
        print(f"warn: no int8 onnx at {int8_onnx}; uploading fp32 only", file=sys.stderr)

    api = HfApi(token=token)
    private = not args.public
    print(f"[hub] creating repo {args.repo} (private={private})")
    # NOTE(review): with exist_ok=True an already-existing repo is presumably
    # returned untouched, so `private` only applies to a freshly created repo.
    # confirm against the huggingface_hub version in use.
    api.create_repo(args.repo, repo_type="model", private=private, exist_ok=True)

    # repo path -> local file; the model card becomes the repo README. all of
    # these were verified to exist above.
    uploads = {"README.md": card_path}
    uploads.update({name: merged_dir / name for name in MODEL_FILES})
    uploads["onnx/model.onnx"] = fp32_onnx
    if has_int8:
        uploads["onnx/int8/model.onnx"] = int8_onnx

    total_mb = sum(p.stat().st_size for p in uploads.values()) / 1e6
    print(f"[hub] uploading {len(uploads)} files, {total_mb:.0f} MB")

    operations = [
        CommitOperationAdd(path_in_repo=repo_path, path_or_fileobj=str(local_path))
        for repo_path, local_path in uploads.items()
    ]
    # describe what was actually uploaded rather than always claiming int8
    onnx_desc = "fp32 + int8 onnx" if has_int8 else "fp32 onnx"
    api.create_commit(
        repo_id=args.repo,
        repo_type="model",
        operations=operations,
        commit_message=f"add mumble cleanup model, {onnx_desc}, model card",
    )
    print()
    print(f"[hub] done: https://huggingface.co/{args.repo}")
if __name__ == "__main__":
    # script entry point; main() returns None, so the exit status is 0 unless
    # main() itself calls sys.exit with an error code.
    sys.exit(main())