Remove nested directory: BitTransformerLM/bit_transformer/hf_checkpoint.py
Browse files
BitTransformerLM/bit_transformer/hf_checkpoint.py
DELETED
|
@@ -1,76 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import gzip
|
| 4 |
-
import os
|
| 5 |
-
import shutil
|
| 6 |
-
import tempfile
|
| 7 |
-
from typing import Optional
|
| 8 |
-
|
| 9 |
-
import torch
|
| 10 |
-
from huggingface_hub import HfApi, hf_hub_download, login
|
| 11 |
-
|
| 12 |
-
REPO_ID = "architect/bittransformerlm"
|
| 13 |
-
FILENAME = "model.pt.gz"
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def hf_login(token: Optional[str] = None) -> None:
|
| 17 |
-
"""Authenticate with Hugging Face.
|
| 18 |
-
|
| 19 |
-
The ``token`` may be provided directly or via the ``HF_TOKEN`` environment
|
| 20 |
-
variable. If omitted entirely, the library will attempt an interactive login.
|
| 21 |
-
"""
|
| 22 |
-
login(token=token)
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def save_checkpoint(
|
| 26 |
-
model: torch.nn.Module,
|
| 27 |
-
*,
|
| 28 |
-
repo_id: str = REPO_ID,
|
| 29 |
-
filename: str = FILENAME,
|
| 30 |
-
) -> None:
|
| 31 |
-
"""Upload the model weights to ``repo_id`` under ``filename``.
|
| 32 |
-
|
| 33 |
-
The file within the repository is overwritten each time to avoid
|
| 34 |
-
accumulating checkpoints.
|
| 35 |
-
"""
|
| 36 |
-
with tempfile.TemporaryDirectory() as tmp:
|
| 37 |
-
tmp_pt = os.path.join(tmp, "model.pt")
|
| 38 |
-
tmp_gz = os.path.join(tmp, filename)
|
| 39 |
-
torch.save(model.state_dict(), tmp_pt)
|
| 40 |
-
with open(tmp_pt, "rb") as src, gzip.open(tmp_gz, "wb") as dst:
|
| 41 |
-
dst.write(src.read())
|
| 42 |
-
HfApi().upload_file(
|
| 43 |
-
path_or_fileobj=tmp_gz,
|
| 44 |
-
path_in_repo=f"checkpoints/{filename}",
|
| 45 |
-
repo_id=repo_id,
|
| 46 |
-
repo_type="model",
|
| 47 |
-
overwrite=True,
|
| 48 |
-
)
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
def download_checkpoint(
|
| 52 |
-
dest_path: str,
|
| 53 |
-
*,
|
| 54 |
-
repo_id: str = REPO_ID,
|
| 55 |
-
filename: str = FILENAME,
|
| 56 |
-
) -> bool:
|
| 57 |
-
"""Download the latest checkpoint to ``dest_path``.
|
| 58 |
-
|
| 59 |
-
Returns ``True`` if the checkpoint was successfully retrieved.
|
| 60 |
-
"""
|
| 61 |
-
try:
|
| 62 |
-
buf = hf_hub_download(
|
| 63 |
-
repo_id,
|
| 64 |
-
f"checkpoints/{filename}",
|
| 65 |
-
repo_type="model",
|
| 66 |
-
force_download=True,
|
| 67 |
-
)
|
| 68 |
-
except Exception as exc: # pragma: no cover - network errors
|
| 69 |
-
print("Failed to download checkpoint", exc)
|
| 70 |
-
return False
|
| 71 |
-
os.makedirs(os.path.dirname(dest_path), exist_ok=True)
|
| 72 |
-
shutil.copyfile(buf, dest_path)
|
| 73 |
-
return True
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
__all__ = ["hf_login", "save_checkpoint", "download_checkpoint"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|