PinHsuan commited on
Commit
562c23b
·
verified ·
1 Parent(s): 95626bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -18
app.py CHANGED
@@ -5,7 +5,7 @@ import matplotlib.pyplot as plt
5
  import seaborn as sns
6
  import gradio as gr
7
  import os
8
-
9
  from model import DualStreamTransformer, ArcMarginProduct
10
 
11
  css = """
@@ -38,35 +38,34 @@ css = """
38
  }
39
  """
40
 
41
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
- MODEL_PATH = "./best_model_fold_5.pt"
43
 
 
44
  model = DualStreamTransformer(n_feat1=25, n_feat2=12, d_model=32).to(DEVICE)
45
  metric_fc = ArcMarginProduct(32, 2).to(DEVICE)
 
 
 
46
 
47
- if os.path.exists(MODEL_PATH):
48
- checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
49
- if isinstance(checkpoint, dict) and 'model' in checkpoint:
50
- model.load_state_dict(checkpoint['model'])
51
- metric_fc.load_state_dict(checkpoint['fc'])
52
- else:
53
- model.load_state_dict(checkpoint)
54
- model.eval()
55
- print("模型載入成功!")
56
 
 
 
57
 
58
  def analyze_and_predict(*all_answers):
59
- if any(a is None for a in all_answers):
60
- raise gr.Error("請完整填寫所有問卷題目!")
61
-
62
  ccmq_map = {"總是": 5, "經常": 4, "有時": 3, "很少": 2, "沒有": 1}
63
  osdi_map = {"總是": 4, "經常": 3, "一半一半": 2, "偶而": 1, "完全不曾": 0}
64
 
65
- x1 = torch.tensor([[ccmq_map[a] for a in all_answers[:25]]], dtype=torch.float32).to(DEVICE)
66
- x2 = torch.tensor([[osdi_map[a] for a in all_answers[25:]]], dtype=torch.float32).to(DEVICE)
67
 
 
 
 
 
 
 
68
  with torch.no_grad():
69
- feats = model(x1, x2)
70
  logits = metric_fc.predict(feats)
71
  probs = torch.softmax(logits, dim=1)
72
  pred_idx = torch.argmax(probs, dim=1).item()
 
5
  import seaborn as sns
6
  import gradio as gr
7
  import os
8
+ import joblib
9
  from model import DualStreamTransformer, ArcMarginProduct
10
 
11
  css = """
 
38
  }
39
  """
40
 
 
 
41
 
42
+ FOLD = 5
43
  model = DualStreamTransformer(n_feat1=25, n_feat2=12, d_model=32).to(DEVICE)
44
  metric_fc = ArcMarginProduct(32, 2).to(DEVICE)
45
+ checkpoint = torch.load(f"best_model_fold_{FOLD}.pt", map_location=DEVICE)
46
+ model.load_state_dict(checkpoint['model'])
47
+ metric_fc.load_state_dict(checkpoint['fc'])
48
 
 
 
 
 
 
 
 
 
 
49
 
50
+ scaler_ccmq = joblib.load(f"scaler_ccmq_fold_{FOLD}.pkl")
51
+ scaler_osdi = joblib.load(f"scaler_osdi_fold_{FOLD}.pkl")
52
 
53
  def analyze_and_predict(*all_answers):
54
+ # 1. 數值映射 (與訓練時的編碼一致)
 
 
55
  ccmq_map = {"總是": 5, "經常": 4, "有時": 3, "很少": 2, "沒有": 1}
56
  osdi_map = {"總是": 4, "經常": 3, "一半一半": 2, "偶而": 1, "完全不曾": 0}
57
 
58
+ x1_raw = np.array([[ccmq_map[a] for a in all_answers[:25]]])
59
+ x2_raw = np.array([[osdi_map[a] for a in all_answers[25:]]])
60
 
61
+ x1_scaled = scaler_ccmq.transform(x1_raw)
62
+ x2_scaled = scaler_osdi.transform(x2_raw)
63
+
64
+ sx1 = torch.tensor(x1_scaled, dtype=torch.float32).to(DEVICE)
65
+ sx2 = torch.tensor(x2_scaled, dtype=torch.float32).to(DEVICE)
66
+
67
  with torch.no_grad():
68
+ feats = model(sx1, sx2)
69
  logits = metric_fc.predict(feats)
70
  probs = torch.softmax(logits, dim=1)
71
  pred_idx = torch.argmax(probs, dim=1).item()