Xsmos committed on
Commit
989ce21
·
verified ·
1 Parent(s): 18c6acf

Fix import error and add source_files to config

Browse files
Files changed (2) hide show
  1. foundation_bert.py +0 -4
  2. test_load.py +45 -0
foundation_bert.py CHANGED
@@ -143,17 +143,13 @@ class FoundationBert(BertModel):
143
  ):
144
  from huggingface_hub import hf_hub_download
145
 
146
- # 1. 如果是远程加载,pretrained_model_name_or_path 就是 REPO_ID
147
- # 我们显式地请求下载 train_config.yaml
148
  try:
149
- # 这一步会检查缓存,如果没有则从云端下载并返回本地绝对路径
150
  model_config = hf_hub_download(
151
  repo_id=pretrained_model_name_or_path,
152
  filename="train_config.yaml",
153
  revision=kwargs.get("revision", "main")
154
  )
155
  except Exception as e:
156
- # 备选方案:如果本地路径已存在(例如 Snigdaa 的用法)
157
  model_config = os.path.join(pretrained_model_name_or_path, "train_config.yaml")
158
 
159
  # print(f"✅ Successfully located config at: {model_config}")
 
143
  ):
144
  from huggingface_hub import hf_hub_download
145
 
 
 
146
  try:
 
147
  model_config = hf_hub_download(
148
  repo_id=pretrained_model_name_or_path,
149
  filename="train_config.yaml",
150
  revision=kwargs.get("revision", "main")
151
  )
152
  except Exception as e:
 
153
  model_config = os.path.join(pretrained_model_name_or_path, "train_config.yaml")
154
 
155
  # print(f"✅ Successfully located config at: {model_config}")
test_load.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Smoke-test that the MOSAIC model on the Hugging Face Hub loads end-to-end.

Verifies, in order:
  1. the repo's config.json is present and parseable (AutoConfig);
  2. the custom remote code (foundation_bert.py) executes and the weights
     in model.safetensors load into the architecture (AutoModel).

Exits with status 1 on the first failure. Requires network access and the
`transformers` package.
"""
import sys

import torch
from transformers import AutoModel, AutoConfig

# Repository ID on Hugging Face Hub
REPO_ID = "StarNetLaboratory/mosaic"

# Since custom modeling code (foundation_bert.py) is used,
# trust_remote_code must be set to True.
TRUST_CODE = True

# NOTE: plain strings here — the originals were f-strings with no
# placeholders (ruff F541); output is byte-identical.
print("--- 1. Attempting to load configuration ---")
try:
    # Attempt to load config to verify config.json is present and readable
    config = AutoConfig.from_pretrained(REPO_ID, trust_remote_code=TRUST_CODE)
    print(f"✅ Config loaded successfully: {config.architectures}")
except Exception as e:
    # Broad catch is deliberate: this is a top-level boundary; report and exit.
    print(f"❌ Config loading failed: {e}")
    sys.exit(1)

print("\n--- 2. Attempting to load model ---")
try:
    # Detect device (GPU if available, else CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # This call triggers transformers to download and execute foundation_bert.py
    # and load weights from model.safetensors
    model = AutoModel.from_pretrained(
        REPO_ID,
        config=config,
        trust_remote_code=TRUST_CODE,
        torch_dtype=torch.float32,  # Match the dtype used during training/local testing
    ).to(device)

    model.eval()

    # Calculate and print total parameters to verify the architecture
    total_params = sum(p.numel() for p in model.parameters())
    print(f"✅ Model loaded successfully! Total parameters: {total_params:,}")

except Exception as e:
    # Broad catch is deliberate: report the failure details and exit non-zero.
    print("❌ Model loading failed.")
    print("Check file integrity and remote code logic (foundation_bert.py).")
    print(f"Error details: {e}")
    sys.exit(1)