Kh0128 commited on
Commit
8e76b9f
·
verified ·
1 Parent(s): 9973cbf

Update output.py

Browse files
Files changed (1) hide show
  1. output.py +6 -16
output.py CHANGED
@@ -288,22 +288,12 @@ class AphasiaInferenceSystem:
288
 
289
  self.model = StableAphasiaClassifier(self.config, self.num_labels)
290
  self.model.bert.resize_token_embeddings(len(self.tokenizer))
291
-
292
- sft_path = os.path.join(self.model_dir, "model.safetensors")
293
- bin_path = os.path.join(self.model_dir, "pytorch_model.bin")
294
-
295
- if os.path.exists(sft_path):
296
- # safetensors 不使用 pickle,較安全
297
- state = load_file(sft_path) # 讀成 state_dict(CPU)
298
- missing, unexpected = self.model.load_state_dict(state, strict=False)
299
- elif os.path.exists(bin_path):
300
- state = torch.load(bin_path, map_location="cpu") # 先載到 CPU,再搬到裝置
301
- missing, unexpected = self.model.load_state_dict(state, strict=False)
302
- else:
303
- raise FileNotFoundError(f"找不到模型權重:{sft_path} 或 {bin_path}")
304
-
305
- if missing or unexpected:
306
- print(f"[load_state_dict] missing={missing}, unexpected={unexpected}")
307
 
308
  self.model.to(self.device)
309
  self.model.eval()
 
288
 
289
  self.model = StableAphasiaClassifier(self.config, self.num_labels)
290
  self.model.bert.resize_token_embeddings(len(self.tokenizer))
291
+
292
+ model_path = os.path.join(self.model_dir, "pytorch_model.safetensors")
293
+ if not os.path.exists(model_path):
294
+ raise FileNotFoundError(f"模型權重文件不存在: {model_path}")
295
+ state = torch.load(model_path, map_location=self.device)
296
+ self.model.load_state_dict(state) # (once)
 
 
 
 
 
 
 
 
 
 
297
 
298
  self.model.to(self.device)
299
  self.model.eval()