Update gptx_tokenizer.py
Browse files — gptx_tokenizer.py (+6 −9)
gptx_tokenizer.py
CHANGED
|
@@ -66,9 +66,9 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
|
|
| 66 |
|
| 67 |
def find_tokenizer_config(self, config_path: Path, repo_id: str = None) -> Optional[Path]:
|
| 68 |
if not os.path.isfile(config_path):
|
| 69 |
-
config_path = try_to_load_from_cache(repo_id=
|
| 70 |
if not config_path:
|
| 71 |
-
config_path = self._download_config_from_hub()
|
| 72 |
|
| 73 |
return config_path
|
| 74 |
|
|
@@ -89,19 +89,16 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
|
|
| 89 |
OSError: If the model file cannot be loaded or downloaded.
|
| 90 |
"""
|
| 91 |
if not os.path.isfile(model_file_or_name):
|
| 92 |
-
model_file_or_name = try_to_load_from_cache(repo_id=
|
| 93 |
if not model_file_or_name:
|
| 94 |
-
model_file_or_name = self._download_model_from_hub()
|
| 95 |
|
| 96 |
try:
|
| 97 |
return spm.SentencePieceProcessor(model_file=model_file_or_name)
|
| 98 |
except Exception as e:
|
| 99 |
raise OSError(f"Failed to load tokenizer model: {str(e)}")
|
| 100 |
|
| 101 |
-
def _download_model_from_hub(self) -> Optional[str]:
|
| 102 |
-
if repo_id is None:
|
| 103 |
-
raise ValueError("repo_id must be provided if model_file_or_name is not a local file")
|
| 104 |
-
|
| 105 |
try:
|
| 106 |
# List all files in the repo
|
| 107 |
repo_files = list_repo_files(repo_id)
|
|
@@ -123,7 +120,7 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
|
|
| 123 |
|
| 124 |
return model_file_or_name
|
| 125 |
|
| 126 |
-
def _download_config_from_hub(self):
|
| 127 |
if repo_id is None:
|
| 128 |
raise ValueError("repo_id must be provided if config_path is not a local file")
|
| 129 |
|
|
|
|
| 66 |
|
| 67 |
def find_tokenizer_config(self, config_path: Path, repo_id: Optional[str] = None) -> Optional[Path]:
    """Resolve the tokenizer config file, falling back to the HF cache, then the Hub.

    Resolution order:
      1. If ``config_path`` is an existing local file, return it unchanged.
      2. Otherwise, look in the local Hugging Face cache for a cached file
         with the same basename under ``repo_id``.
      3. If not cached, download the config from the Hub.

    Args:
        config_path: Candidate local path to the tokenizer config file.
        repo_id: Hugging Face Hub repository id used for the cache lookup
            and the download fallback when ``config_path`` is not a local
            file. Required in that case (the helpers raise otherwise).

    Returns:
        A path to the resolved config file, or ``None`` if it could not be
        located or downloaded.
    """
    if not os.path.isfile(config_path):
        # Only the basename is meaningful for a cache/Hub lookup — the
        # caller's directory component refers to a path that doesn't exist.
        config_path = try_to_load_from_cache(repo_id=repo_id, filename=Path(config_path).name)
        if not config_path:
            # Cache miss (falsy return): fetch the file from the Hub.
            config_path = self._download_config_from_hub(repo_id=repo_id)

    return config_path
|
| 74 |
|
|
|
|
| 89 |
OSError: If the model file cannot be loaded or downloaded.
|
| 90 |
"""
|
| 91 |
if not os.path.isfile(model_file_or_name):
|
| 92 |
+
model_file_or_name = try_to_load_from_cache(repo_id=repo_id, filename=Path(model_file_or_name).name)
|
| 93 |
if not model_file_or_name:
|
| 94 |
+
model_file_or_name = self._download_model_from_hub(repo_id=repo_id)
|
| 95 |
|
| 96 |
try:
|
| 97 |
return spm.SentencePieceProcessor(model_file=model_file_or_name)
|
| 98 |
except Exception as e:
|
| 99 |
raise OSError(f"Failed to load tokenizer model: {str(e)}")
|
| 100 |
|
| 101 |
+
def _download_model_from_hub(self, repo_id: str) -> Optional[str]:
|
|
|
|
|
|
|
|
|
|
| 102 |
try:
|
| 103 |
# List all files in the repo
|
| 104 |
repo_files = list_repo_files(repo_id)
|
|
|
|
| 120 |
|
| 121 |
return model_file_or_name
|
| 122 |
|
| 123 |
+
def _download_config_from_hub(self, repo_id: str):
|
| 124 |
if repo_id is None:
|
| 125 |
raise ValueError("repo_id must be provided if config_path is not a local file")
|
| 126 |
|