vijusudhi committed on
Commit
db67351
·
verified ·
1 Parent(s): 80e5611

Update gptx_tokenizer.py

Browse files
Files changed (1) hide show
  1. gptx_tokenizer.py +22 -22
gptx_tokenizer.py CHANGED
@@ -103,28 +103,28 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
103
  ValueError: If repo_id is not provided when model_file_or_name is not a file.
104
  OSError: If the model file cannot be loaded or downloaded.
105
  """
106
- # if not os.path.isfile(model_file_or_name):
107
- # if repo_id is None:
108
- # raise ValueError("repo_id must be provided if model_file_or_name is not a local file")
109
 
110
- # try:
111
- # # List all files in the repo
112
- # repo_files = list_repo_files(repo_id)
113
 
114
- # # Find the tokenizer model file
115
- # tokenizer_files = [f for f in repo_files if f.endswith('.model')]
116
- # if not tokenizer_files:
117
- # raise FileNotFoundError(f"No .model file found in repository {repo_id}")
118
 
119
- # # Use the first .model file found
120
- # model_file = tokenizer_files[0]
121
- # print(f"Found tokenizer model file: {model_file}")
122
 
123
- # # Download the file
124
- # model_file_or_name = hf_hub_download(repo_id=repo_id, filename=model_file)
125
- # print(f"Downloaded tokenizer model to: {model_file_or_name}")
126
- # except Exception as e:
127
- # raise OSError(f"Failed to download tokenizer model: {str(e)}")
128
 
129
  try:
130
  return spm.SentencePieceProcessor(model_file=model_file_or_name)
@@ -182,10 +182,10 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
182
  if config_path is None:
183
  config_path = str(Path(cp_path) / TOKENIZER_CONFIG_FILE)
184
 
185
- # if os.path.isfile(config_path):
186
- self.tokenizer_config = self.load_json(Path(config_path))
187
- # else: # Load from repo
188
- # self.tokenizer_config = self.load_json(Path(self.find_tokenizer_config(Path(config_path), repo_id=REPO_ID)))
189
 
190
  @property
191
  def vocab_size(self) -> int:
 
103
  ValueError: If repo_id is not provided when model_file_or_name is not a file.
104
  OSError: If the model file cannot be loaded or downloaded.
105
  """
106
+ if not os.path.isfile(model_file_or_name):
107
+ if repo_id is None:
108
+ raise ValueError("repo_id must be provided if model_file_or_name is not a local file")
109
 
110
+ try:
111
+ # List all files in the repo
112
+ repo_files = list_repo_files(repo_id)
113
 
114
+ # Find the tokenizer model file
115
+ tokenizer_files = [f for f in repo_files if f.endswith('.model')]
116
+ if not tokenizer_files:
117
+ raise FileNotFoundError(f"No .model file found in repository {repo_id}")
118
 
119
+ # Use the first .model file found
120
+ model_file = tokenizer_files[0]
121
+ print(f"Found tokenizer model file: {model_file}")
122
 
123
+ # Download the file
124
+ model_file_or_name = hf_hub_download(repo_id=repo_id, filename=model_file)
125
+ print(f"Downloaded tokenizer model to: {model_file_or_name}")
126
+ except Exception as e:
127
+ raise OSError(f"Failed to download tokenizer model: {str(e)}")
128
 
129
  try:
130
  return spm.SentencePieceProcessor(model_file=model_file_or_name)
 
182
  if config_path is None:
183
  config_path = str(Path(cp_path) / TOKENIZER_CONFIG_FILE)
184
 
185
+ if os.path.isfile(config_path):
186
+ self.tokenizer_config = self.load_json(Path(config_path))
187
+ else: # Load from repo
188
+ self.tokenizer_config = self.load_json(Path(self.find_tokenizer_config(Path(config_path), repo_id=REPO_ID)))
189
 
190
  @property
191
  def vocab_size(self) -> int: