Spaces:
Sleeping
Sleeping
Update output.py
Browse files
output.py
CHANGED
|
@@ -288,12 +288,22 @@ class AphasiaInferenceSystem:
|
|
| 288 |
|
| 289 |
self.model = StableAphasiaClassifier(self.config, self.num_labels)
|
| 290 |
self.model.bert.resize_token_embeddings(len(self.tokenizer))
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
|
| 298 |
self.model.to(self.device)
|
| 299 |
self.model.eval()
|
|
|
|
| 288 |
|
| 289 |
self.model = StableAphasiaClassifier(self.config, self.num_labels)
|
| 290 |
self.model.bert.resize_token_embeddings(len(self.tokenizer))
|
| 291 |
+
|
| 292 |
+
sft_path = os.path.join(self.model_dir, "model.safetensors")
|
| 293 |
+
bin_path = os.path.join(self.model_dir, "pytorch_model.bin")
|
| 294 |
+
|
| 295 |
+
if os.path.exists(sft_path):
|
| 296 |
+
# safetensors 不使用 pickle,較安全
|
| 297 |
+
state = load_file(sft_path) # 讀成 state_dict(CPU)
|
| 298 |
+
missing, unexpected = self.model.load_state_dict(state, strict=False)
|
| 299 |
+
elif os.path.exists(bin_path):
|
| 300 |
+
state = torch.load(bin_path, map_location="cpu") # 先載到 CPU,再搬到裝置
|
| 301 |
+
missing, unexpected = self.model.load_state_dict(state, strict=False)
|
| 302 |
+
else:
|
| 303 |
+
raise FileNotFoundError(f"找不到模型權重:{sft_path} 或 {bin_path}")
|
| 304 |
+
|
| 305 |
+
if missing or unexpected:
|
| 306 |
+
print(f"[load_state_dict] missing={missing}, unexpected={unexpected}")
|
| 307 |
|
| 308 |
self.model.to(self.device)
|
| 309 |
self.model.eval()
|