Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -5,7 +5,7 @@ import matplotlib.pyplot as plt
|
|
| 5 |
import seaborn as sns
|
| 6 |
import gradio as gr
|
| 7 |
import os
|
| 8 |
-
|
| 9 |
from model import DualStreamTransformer, ArcMarginProduct
|
| 10 |
|
| 11 |
css = """
|
|
@@ -38,35 +38,34 @@ css = """
|
|
| 38 |
}
|
| 39 |
"""
|
| 40 |
|
| 41 |
-
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 42 |
-
MODEL_PATH = "./best_model_fold_5.pt"
|
| 43 |
|
|
|
|
| 44 |
model = DualStreamTransformer(n_feat1=25, n_feat2=12, d_model=32).to(DEVICE)
|
| 45 |
metric_fc = ArcMarginProduct(32, 2).to(DEVICE)
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
-
if os.path.exists(MODEL_PATH):
|
| 48 |
-
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
|
| 49 |
-
if isinstance(checkpoint, dict) and 'model' in checkpoint:
|
| 50 |
-
model.load_state_dict(checkpoint['model'])
|
| 51 |
-
metric_fc.load_state_dict(checkpoint['fc'])
|
| 52 |
-
else:
|
| 53 |
-
model.load_state_dict(checkpoint)
|
| 54 |
-
model.eval()
|
| 55 |
-
print("模型載入成功!")
|
| 56 |
|
|
|
|
|
|
|
| 57 |
|
| 58 |
def analyze_and_predict(*all_answers):
|
| 59 |
-
|
| 60 |
-
raise gr.Error("請完整填寫所有問卷題目!")
|
| 61 |
-
|
| 62 |
ccmq_map = {"總是": 5, "經常": 4, "有時": 3, "很少": 2, "沒有": 1}
|
| 63 |
osdi_map = {"總是": 4, "經常": 3, "一半一半": 2, "偶而": 1, "完全不曾": 0}
|
| 64 |
|
| 65 |
-
|
| 66 |
-
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
with torch.no_grad():
|
| 69 |
-
feats = model(
|
| 70 |
logits = metric_fc.predict(feats)
|
| 71 |
probs = torch.softmax(logits, dim=1)
|
| 72 |
pred_idx = torch.argmax(probs, dim=1).item()
|
|
|
|
| 5 |
import seaborn as sns
|
| 6 |
import gradio as gr
|
| 7 |
import os
|
| 8 |
+
import joblib
|
| 9 |
from model import DualStreamTransformer, ArcMarginProduct
|
| 10 |
|
| 11 |
css = """
|
|
|
|
| 38 |
}
|
| 39 |
"""
|
| 40 |
|
|
|
|
|
|
|
| 41 |
|
| 42 |
+
FOLD = 5
|
| 43 |
model = DualStreamTransformer(n_feat1=25, n_feat2=12, d_model=32).to(DEVICE)
|
| 44 |
metric_fc = ArcMarginProduct(32, 2).to(DEVICE)
|
| 45 |
+
checkpoint = torch.load(f"best_model_fold_{FOLD}.pt", map_location=DEVICE)
|
| 46 |
+
model.load_state_dict(checkpoint['model'])
|
| 47 |
+
metric_fc.load_state_dict(checkpoint['fc'])
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
+
scaler_ccmq = joblib.load(f"scaler_ccmq_fold_{FOLD}.pkl")
|
| 51 |
+
scaler_osdi = joblib.load(f"scaler_osdi_fold_{FOLD}.pkl")
|
| 52 |
|
| 53 |
def analyze_and_predict(*all_answers):
|
| 54 |
+
# 1. 數值映射 (與訓練時的編碼一致)
|
|
|
|
|
|
|
| 55 |
ccmq_map = {"總是": 5, "經常": 4, "有時": 3, "很少": 2, "沒有": 1}
|
| 56 |
osdi_map = {"總是": 4, "經常": 3, "一半一半": 2, "偶而": 1, "完全不曾": 0}
|
| 57 |
|
| 58 |
+
x1_raw = np.array([[ccmq_map[a] for a in all_answers[:25]]])
|
| 59 |
+
x2_raw = np.array([[osdi_map[a] for a in all_answers[25:]]])
|
| 60 |
|
| 61 |
+
x1_scaled = scaler_ccmq.transform(x1_raw)
|
| 62 |
+
x2_scaled = scaler_osdi.transform(x2_raw)
|
| 63 |
+
|
| 64 |
+
sx1 = torch.tensor(x1_scaled, dtype=torch.float32).to(DEVICE)
|
| 65 |
+
sx2 = torch.tensor(x2_scaled, dtype=torch.float32).to(DEVICE)
|
| 66 |
+
|
| 67 |
with torch.no_grad():
|
| 68 |
+
feats = model(sx1, sx2)
|
| 69 |
logits = metric_fc.predict(feats)
|
| 70 |
probs = torch.softmax(logits, dim=1)
|
| 71 |
pred_idx = torch.argmax(probs, dim=1).item()
|