Xsmos committed on
Commit
86ec786
·
verified ·
1 Parent(s): 20aa86a

Fix remote train_config loading

Browse files
Files changed (1) hide show
  1. foundation_bert.py +16 -4
foundation_bert.py CHANGED
@@ -7,6 +7,7 @@ from pathlib import Path
7
  # from ..utils.yaml_util import MyLoader
8
  from dataclasses import dataclass
9
  from transformers import ModernBertModel, ModernBertConfig, PretrainedConfig
 
10
  from typing import Optional, Union
11
 
12
  # import yaml
@@ -176,11 +177,22 @@ class FoundationBert(ModernBertModel):
176
  """
177
  Modification to correctly handle loading extraneous parameters for GBert
178
  """
179
- if 'checkpoint' in pretrained_model_name_or_path:
180
- model_config = Path(pretrained_model_name_or_path).parent / 'train_config.yaml'
 
 
 
181
  else:
182
- model_config = Path(pretrained_model_name_or_path) / 'train_config.yaml'
183
-
 
 
 
 
 
 
 
 
184
  with open(model_config, 'r') as f:
185
  config = yaml.load(f, Loader=MyLoader)
186
  kwargs['modalities'] = config['modalities']
 
7
  # from ..utils.yaml_util import MyLoader
8
  from dataclasses import dataclass
9
  from transformers import ModernBertModel, ModernBertConfig, PretrainedConfig
10
+ from transformers.utils import cached_file
11
  from typing import Optional, Union
12
 
13
  # import yaml
 
177
  """
178
  Modification to correctly handle loading extraneous parameters for GBert
179
  """
180
+ path = Path(pretrained_model_name_or_path)
181
+ if 'checkpoint' in str(pretrained_model_name_or_path):
182
+ model_config = path.parent / 'train_config.yaml'
183
+ elif path.is_dir():
184
+ model_config = path / 'train_config.yaml'
185
  else:
186
+ model_config = cached_file(
187
+ pretrained_model_name_or_path,
188
+ 'train_config.yaml',
189
+ cache_dir=cache_dir,
190
+ force_download=force_download,
191
+ local_files_only=local_files_only,
192
+ token=token,
193
+ revision=revision,
194
+ )
195
+
196
  with open(model_config, 'r') as f:
197
  config = yaml.load(f, Loader=MyLoader)
198
  kwargs['modalities'] = config['modalities']