Xsmos commited on
Commit
fd499f6
·
verified ·
1 Parent(s): 520f1a9

Fix import error and add source_files to config

Browse files
Files changed (1) hide show
  1. foundation_bert.py +16 -7
foundation_bert.py CHANGED
@@ -141,13 +141,22 @@ class FoundationBert(BertModel):
141
  use_safetensors: bool = None,
142
  **kwargs,
143
  ):
144
- """
145
- Modification to correctly handle loading extraneous parameters for GBert
146
- """
147
-
148
- current_dir = os.path.dirname(os.path.abspath(__file__))
149
- model_config = os.path.join(current_dir, 'train_config.yaml')
150
- print(f"🆘 DEBUG: Attempting to load config from: {model_config}")
 
 
 
 
 
 
 
 
 
151
  with open(model_config, 'r') as f:
152
  config = yaml.load(f, Loader=MyLoader)
153
 
 
141
  use_safetensors: bool = None,
142
  **kwargs,
143
  ):
144
+ from huggingface_hub import hf_hub_download
145
+
146
+ # 1. 如果是远程加载,pretrained_model_name_or_path 就是 REPO_ID
147
+ # 我们显式地请求下载 train_config.yaml
148
+ try:
149
+ # 这一步会检查缓存,如果没有则从云端下载并返回本地绝对路径
150
+ config_file_path = hf_hub_download(
151
+ repo_id=pretrained_model_name_or_path,
152
+ filename="train_config.yaml",
153
+ revision=kwargs.get("revision", "main")
154
+ )
155
+ except Exception as e:
156
+ # 备选方案:如果本地路径已存在(例如 Snigdaa 的用法)
157
+ model_config = os.path.join(pretrained_model_name_or_path, "train_config.yaml")
158
+
159
+ print(f"✅ Successfully located config at: {config_file_path}")
160
  with open(model_config, 'r') as f:
161
  config = yaml.load(f, Loader=MyLoader)
162