Update modeling_hubert_kmeans.py
Browse files
modeling_hubert_kmeans.py
CHANGED
|
@@ -64,7 +64,7 @@ class HubertKmeansModel(PreTrainedModel):
|
|
| 64 |
def _load_kmeans_from_repo(self):
|
| 65 |
"""Load k-means centers from km.pt in the repo root."""
|
| 66 |
import os
|
| 67 |
-
from huggingface_hub import
|
| 68 |
|
| 69 |
try:
|
| 70 |
# First try local file (for local testing)
|
|
@@ -75,8 +75,7 @@ class HubertKmeansModel(PreTrainedModel):
|
|
| 75 |
# Try to download from HF hub
|
| 76 |
repo_id = getattr(self.config, "_name_or_path", None)
|
| 77 |
if repo_id and "/" in repo_id: # Looks like a HF repo
|
| 78 |
-
|
| 79 |
-
local_path = cached_download(url)
|
| 80 |
centers = torch.load(local_path, map_location='cpu')
|
| 81 |
else:
|
| 82 |
raise FileNotFoundError("Could not find km.pt")
|
|
|
|
| 64 |
def _load_kmeans_from_repo(self):
|
| 65 |
"""Load k-means centers from km.pt in the repo root."""
|
| 66 |
import os
|
| 67 |
+
from huggingface_hub import hf_hub_download
|
| 68 |
|
| 69 |
try:
|
| 70 |
# First try local file (for local testing)
|
|
|
|
| 75 |
# Try to download from HF hub
|
| 76 |
repo_id = getattr(self.config, "_name_or_path", None)
|
| 77 |
if repo_id and "/" in repo_id: # Looks like a HF repo
|
| 78 |
+
local_path = hf_hub_download(repo_id=repo_id, filename="km.pt")
|
|
|
|
| 79 |
centers = torch.load(local_path, map_location='cpu')
|
| 80 |
else:
|
| 81 |
raise FileNotFoundError("Could not find km.pt")
|