Kh0128 commited on
Commit
3dd57da
·
verified ·
1 Parent(s): 8bd4a37

Update output.py

Browse files
Files changed (1) hide show
  1. output.py +16 -6
output.py CHANGED
@@ -288,12 +288,22 @@ class AphasiaInferenceSystem:
288
 
289
  self.model = StableAphasiaClassifier(self.config, self.num_labels)
290
  self.model.bert.resize_token_embeddings(len(self.tokenizer))
291
-
292
- model_path = os.path.join(self.model_dir, "pytorch_model.bin")
293
- if not os.path.exists(model_path):
294
- raise FileNotFoundError(f"模型權重文件不存在: {model_path}")
295
- state = torch.load(model_path, map_location=self.device)
296
- self.model.load_state_dict(state) # (once)
 
 
 
 
 
 
 
 
 
 
297
 
298
  self.model.to(self.device)
299
  self.model.eval()
 
288
 
289
  self.model = StableAphasiaClassifier(self.config, self.num_labels)
290
  self.model.bert.resize_token_embeddings(len(self.tokenizer))
291
+
292
+ sft_path = os.path.join(self.model_dir, "model.safetensors")
293
+ bin_path = os.path.join(self.model_dir, "pytorch_model.bin")
294
+
295
+ if os.path.exists(sft_path):
296
+ # safetensors 不使用 pickle,較安全
297
+ state = load_file(sft_path) # 讀成 state_dict(CPU)
298
+ missing, unexpected = self.model.load_state_dict(state, strict=False)
299
+ elif os.path.exists(bin_path):
300
+ state = torch.load(bin_path, map_location="cpu") # 先載到 CPU,再搬到裝置
301
+ missing, unexpected = self.model.load_state_dict(state, strict=False)
302
+ else:
303
+ raise FileNotFoundError(f"找不到模型權重:{sft_path} 或 {bin_path}")
304
+
305
+ if missing or unexpected:
306
+ print(f"[load_state_dict] missing={missing}, unexpected={unexpected}")
307
 
308
  self.model.to(self.device)
309
  self.model.eval()