WCNegentropy commited on
Commit
73b0816
·
verified ·
1 Parent(s): b19af44

Remove nested directory: BitTransformerLM/bit_transformer/hf_checkpoint.py

Browse files
BitTransformerLM/bit_transformer/hf_checkpoint.py DELETED
@@ -1,76 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import gzip
4
- import os
5
- import shutil
6
- import tempfile
7
- from typing import Optional
8
-
9
- import torch
10
- from huggingface_hub import HfApi, hf_hub_download, login
11
-
12
- REPO_ID = "architect/bittransformerlm"
13
- FILENAME = "model.pt.gz"
14
-
15
-
16
- def hf_login(token: Optional[str] = None) -> None:
17
- """Authenticate with Hugging Face.
18
-
19
- The ``token`` may be provided directly or via the ``HF_TOKEN`` environment
20
- variable. If omitted entirely, the library will attempt an interactive login.
21
- """
22
- login(token=token)
23
-
24
-
25
- def save_checkpoint(
26
- model: torch.nn.Module,
27
- *,
28
- repo_id: str = REPO_ID,
29
- filename: str = FILENAME,
30
- ) -> None:
31
- """Upload the model weights to ``repo_id`` under ``filename``.
32
-
33
- The file within the repository is overwritten each time to avoid
34
- accumulating checkpoints.
35
- """
36
- with tempfile.TemporaryDirectory() as tmp:
37
- tmp_pt = os.path.join(tmp, "model.pt")
38
- tmp_gz = os.path.join(tmp, filename)
39
- torch.save(model.state_dict(), tmp_pt)
40
- with open(tmp_pt, "rb") as src, gzip.open(tmp_gz, "wb") as dst:
41
- dst.write(src.read())
42
- HfApi().upload_file(
43
- path_or_fileobj=tmp_gz,
44
- path_in_repo=f"checkpoints/{filename}",
45
- repo_id=repo_id,
46
- repo_type="model",
47
- overwrite=True,
48
- )
49
-
50
-
51
- def download_checkpoint(
52
- dest_path: str,
53
- *,
54
- repo_id: str = REPO_ID,
55
- filename: str = FILENAME,
56
- ) -> bool:
57
- """Download the latest checkpoint to ``dest_path``.
58
-
59
- Returns ``True`` if the checkpoint was successfully retrieved.
60
- """
61
- try:
62
- buf = hf_hub_download(
63
- repo_id,
64
- f"checkpoints/{filename}",
65
- repo_type="model",
66
- force_download=True,
67
- )
68
- except Exception as exc: # pragma: no cover - network errors
69
- print("Failed to download checkpoint", exc)
70
- return False
71
- os.makedirs(os.path.dirname(dest_path), exist_ok=True)
72
- shutil.copyfile(buf, dest_path)
73
- return True
74
-
75
-
76
- __all__ = ["hf_login", "save_checkpoint", "download_checkpoint"]