Upload 8 files

- .streamlitconfig.toml.txt  +14 -0
- BC_imputed_micerf_period13_fid_course_D4.csv  +0 -0
- README.md  +238 -10
- app.py  +462 -0
- bn_core.py  +410 -0
- llm_assistant.py  +265 -0
- requirements.txt  +9 -0
- utils.py  +387 -0
.streamlitconfig.toml.txt
ADDED
@@ -0,0 +1,14 @@

[theme]
primaryColor = "#2d6ca2"
backgroundColor = "#e8f1f8"
secondaryBackgroundColor = "#f0f7fc"
textColor = "#2b3a67"
font = "sans serif"

[server]
maxUploadSize = 200
enableCORS = false
enableXsrfProtection = true

[browser]
gatherUsageStats = false
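Note that Streamlit only reads these settings from `.streamlit/config.toml` relative to the app's working directory; a file committed flat as `.streamlitconfig.toml.txt` will not be picked up. A minimal sketch of moving the settings into place (the file name and move are assumptions about how this repo would be fixed up at deploy time):

```shell
# Streamlit looks for .streamlit/config.toml, not a flat
# .streamlitconfig.toml.txt at the repo root.
mkdir -p .streamlit
printf '[browser]\ngatherUsageStats = false\n' > .streamlit/config.toml
cat .streamlit/config.toml
```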
BC_imputed_micerf_period13_fid_course_D4.csv
ADDED
The diff for this file is too large to render. See raw diff.
README.md
CHANGED
@@ -1,12 +1,240 @@

# 🔬 Bayesian Network Analysis System

A complete Bayesian network analysis system with an integrated AI assistant that helps interpret the analysis results.

## ✨ Key Features

### 1. Bayesian Network Analysis
- ✅ Multiple structure-learning algorithms: NB, TAN, CL, Hill Climbing, PC
- ✅ Automatic feature detection and handling (categorical/continuous)
- ✅ Comprehensive model evaluation metrics
- ✅ Interactive network structure visualization
- ✅ Conditional probability table lookup
- ✅ ROC curves and confusion matrices

### 2. AI Q&A Assistant
- ✅ Auto-generated analysis summaries
- ✅ Explanations of model metrics
- ✅ Improvement suggestions
- ✅ Network structure interpretation
- ✅ Multi-turn conversation

### 3. Multi-User Support
- ✅ Session isolation
- ✅ Thread safety
- ✅ Per-session result storage

## 🚀 Quick Start

### Run locally

```bash
# 1. Clone the project
git clone <your-repo-url>
cd bayesian-network-app

# 2. Install dependencies
pip install -r requirements.txt

# 3. Place the default dataset
# Put BC_imputed_micerf_period13_fid_course_D4.csv in the project root

# 4. Run the app
streamlit run app.py
```

### Deploy to Hugging Face Spaces

1. **Create a new Space**
   - Go to https://huggingface.co/spaces
   - Click "Create new Space"
   - Choose the Streamlit SDK
   - Set a Space name

2. **Upload the files**
   ```
   your-space/
   ├── app.py
   ├── bn_core.py
   ├── llm_assistant.py
   ├── utils.py
   ├── requirements.txt
   ├── BC_imputed_micerf_period13_fid_course_D4.csv (optional)
   └── README.md
   ```

3. **Configure the Space settings**
   - SDK: Streamlit
   - Python version: 3.10
   - Hardware: CPU Basic (free), or upgrade for better performance

4. **Push to Hugging Face**
   ```bash
   git add .
   git commit -m "Initial commit"
   git push
   ```

## 📋 File Structure

```
├── app.py               # Main application (Streamlit UI)
├── bn_core.py           # Bayesian network core logic
├── llm_assistant.py     # LLM Q&A assistant
├── utils.py             # Utilities (visualization, data processing)
├── requirements.txt     # Python dependencies
├── README.md            # Documentation
└── BC_imputed_micerf_period13_fid_course_D4.csv  # Default dataset (optional)
```

## 🎯 Usage

### 1. Set the API key
- Enter your OpenAI API key in the sidebar
- The key is valid only for the current session and is never stored

### 2. Choose a data source
- **Use the default dataset**: the built-in breast cancer dataset
- **Upload your own data**: CSV format is supported

### 3. Configure the model
- Select categorical and continuous features
- Select the target variable (must be binary)
- Set the model parameters:
  - Test set proportion
  - Structure-learning algorithm
  - Parameter estimation method
  - Other hyperparameters

### 4. Run the analysis
- Click "Run Analysis" to start training
- Wait for the analysis to finish (typically 10-60 seconds)
- Review the results:
  - Network structure graph
  - Performance metrics
  - Confusion matrix
  - ROC curve
  - Conditional probability tables

### 5. Use the AI assistant
- Switch to the "AI Assistant" tab
- Ask questions about the analysis results
- Use the quick-question buttons for common topics:
  - 📊 Analysis summary
  - 🎯 Performance assessment
  - 🔍 Structure explanation
  - ⚠️ Limitations
  - 💡 Improvement suggestions

## 🔧 Architecture

### Backend
- **pgmpy**: Bayesian network modeling and inference
- **scikit-learn**: data splitting and evaluation metrics
- **pandas/numpy**: data processing

### Frontend
- **Streamlit**: web application framework
- **Plotly**: interactive visualization

### AI Integration
- **OpenAI GPT-4**: Q&A assistant
- Custom prompt engineering
- Context management

### Multi-User Support
- Session isolation
- Thread locks to keep data consistent
- Per-session result storage

## 📊 Supported Algorithms

| Algorithm | Description | When to Use |
|-----------|-------------|-------------|
| **NB** | Naive Bayes | Fast and simple; good for a first pass |
| **TAN** | Tree-Augmented Naive Bayes | More flexible than NB while keeping a tree structure |
| **CL** | Chow-Liu Tree | Learns the optimal tree structure |
| **HC** | Hill Climbing | Explores more complex structures; requires a scoring method |
| **PC** | PC Algorithm | Based on conditional-independence tests; requires a significance level |

## 📈 Evaluation Metrics

- **Accuracy**: overall accuracy
- **Precision**: fraction of predicted positives that are actually positive
- **Recall**: fraction of actual positives that are correctly predicted
- **F1-Score**: harmonic mean of precision and recall
- **AUC**: area under the ROC curve
- **G-mean**: geometric mean (suited to imbalanced data)
- **P-mean**: another balance-oriented metric
- **Specificity**: fraction of actual negatives that are correctly predicted

## ⚠️ Notes

### Data requirements
1. CSV format
2. The target variable must be binary (0/1 or similar)
3. Categorical features may not have more than 40 unique values
4. Avoid excessive missing values

### API key safety
- The API key is stored only in session state
- It is never logged or uploaded
- Each user supplies their own key

### Performance
- Large datasets (>10,000 rows) may take longer
- The PC algorithm is slower than the others
- Test with a small sample first

## 🐛 FAQ

### Q1: Training fails — what should I do?
**A**: Check:
- whether there are too many missing values
- whether a categorical feature has too many unique values
- whether the target variable is binary
- try a different algorithm

### Q2: The AI assistant does not respond?
**A**: Confirm:
- the OpenAI API key is correct
- you have network connectivity
- the key has remaining quota

### Q3: How do I improve model performance?
**A**: Try:
- feature engineering (create new features)
- adjusting the number of bins for continuous variables
- different algorithms
- the Bayesian estimator with a tuned equivalent_sample_size

### Q4: Will concurrent users conflict?
**A**: No. The system uses:
- a unique session_id per user
- thread locks to protect shared resources
- per-session result storage

## 🔄 Changelog

### v1.0.0 (2025-01)
- ✅ Initial release
- ✅ Five structure-learning algorithms
- ✅ OpenAI GPT-4 assistant integration
- ✅ Full visualization suite
- ✅ Multi-user support

## 📝 License

This project is a rewrite of the original Django project; it keeps the original functionality and adds the AI assistant.

## 🤝 Contributing

Issues and pull requests are welcome!

## 📧 Contact

For questions, please reach out via GitHub Issues.

---

**Enjoy! 🎉**
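G-mean and specificity in the metric list above are not part of scikit-learn's standard scorer set; a minimal sketch of how they can be derived from a binary confusion matrix (the exact formulas used by this repo's `bn_core.py` are an assumption):

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Toy binary labels: 4 positives, 4 negatives
y_true = [0, 0, 1, 1, 1, 0, 1, 0]
y_pred = [0, 1, 1, 1, 0, 0, 1, 0]

tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
sensitivity = tp / (tp + fn)                  # recall on the positive class
specificity = tn / (tn + fp)                  # recall on the negative class
g_mean = np.sqrt(sensitivity * specificity)   # geometric mean of the two

print(round(specificity, 3), round(g_mean, 3))  # 0.75 0.75
```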
app.py
ADDED
@@ -0,0 +1,462 @@

```python
import streamlit as st
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from io import BytesIO
import base64
import json
from datetime import datetime
import uuid

# Page configuration
st.set_page_config(
    page_title="Bayesian Network Analysis System",
    page_icon="🔬",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Import custom modules
from bn_core import BayesianNetworkAnalyzer
from llm_assistant import LLMAssistant
from utils import (
    plot_roc_curve,
    plot_confusion_matrix,
    plot_probability_distribution,
    generate_network_graph,
    create_cpd_table
)

# Initialize session state
if 'session_id' not in st.session_state:
    st.session_state.session_id = str(uuid.uuid4())
if 'analysis_results' not in st.session_state:
    st.session_state.analysis_results = None
if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []
if 'model_trained' not in st.session_state:
    st.session_state.model_trained = False

# Title
st.title("🔬 Bayesian Network Analysis System")
st.markdown("---")

# Sidebar - OpenAI API key
with st.sidebar:
    st.header("⚙️ Configuration")

    api_key = st.text_input(
        "OpenAI API Key",
        type="password",
        help="Enter your OpenAI API key to use the AI assistant"
    )

    if api_key:
        st.session_state.api_key = api_key
        st.success("✅ API Key loaded")

    st.markdown("---")

    # Data source selection
    st.subheader("📊 Data Source")
    data_source = st.radio(
        "Select data source:",
        ["Use Default Dataset", "Upload Your Data"]
    )

    uploaded_file = None
    if data_source == "Upload Your Data":
        uploaded_file = st.file_uploader(
            "Upload CSV file",
            type=['csv'],
            help="Upload your dataset in CSV format"
        )

# Main content area
tab1, tab2 = st.tabs(["📈 Analysis", "💬 AI Assistant"])

# Tab 1: analysis UI
with tab1:
    col1, col2 = st.columns([2, 1])

    with col1:
        st.header("Model Configuration")

        # Load the data
        if data_source == "Use Default Dataset":
            @st.cache_data
            def load_default_data():
                # Path to the default dataset
                df = pd.read_csv("BC_imputed_micerf_period13_fid_course_D4.csv")
                return df

            try:
                df = load_default_data()
                st.success(f"✅ Default dataset loaded: {df.shape[0]} rows, {df.shape[1]} columns")
            except FileNotFoundError:
                st.error("❌ Default dataset not found. Please upload your own data.")
                df = None
        else:
            if uploaded_file:
                df = pd.read_csv(uploaded_file)
                st.success(f"✅ Data loaded: {df.shape[0]} rows, {df.shape[1]} columns")
            else:
                st.info("👆 Please upload a CSV file to begin")
                df = None

        if df is not None:
            # Feature selection
            st.subheader("🎯 Feature Selection")

            # Automatically detect feature types
            numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
            categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()

            # Binary columns (candidates for the target variable)
            binary_cols = [col for col in df.columns if df[col].nunique() == 2]

            col_feat1, col_feat2 = st.columns(2)

            with col_feat1:
                st.markdown("**Categorical Features**")
                cat_features = st.multiselect(
                    "Select categorical features:",
                    options=categorical_cols,
                    default=categorical_cols[:5] if len(categorical_cols) > 0 else []
                )

            with col_feat2:
                st.markdown("**Continuous Features**")
                con_features = st.multiselect(
                    "Select continuous features:",
                    options=numeric_cols,
                    default=numeric_cols[:3] if len(numeric_cols) > 0 else []
                )

            # Target variable
            target_variable = st.selectbox(
                "🎯 Target Variable (Y):",
                options=binary_cols,
                help="Must be a binary classification variable"
            )

            # Validate the selection
            selected_features = cat_features + con_features
            if target_variable in selected_features:
                st.error("❌ Target variable cannot be in feature list!")
                st.stop()

            st.markdown("---")

            # Model parameters
            st.subheader("⚙️ Model Parameters")

            col_param1, col_param2, col_param3 = st.columns(3)

            with col_param1:
                test_fraction = st.slider(
                    "Test Dataset Proportion:",
                    min_value=0.1,
                    max_value=0.5,
                    value=0.25,
                    step=0.05
                )

                algorithm = st.selectbox(
                    "Network Structure:",
                    options=['NB', 'TAN', 'CL', 'HC', 'PC'],
                    format_func=lambda x: {
                        'NB': 'Naive Bayes',
                        'TAN': 'Tree-Augmented Naive Bayes',
                        'CL': 'Chow-Liu',
                        'HC': 'Hill Climbing',
                        'PC': 'PC Algorithm'
                    }[x]
                )

            with col_param2:
                estimator = st.selectbox(
                    "Parameter Estimator:",
                    options=['ml', 'bn'],
                    format_func=lambda x: {
                        'ml': 'Maximum Likelihood',
                        'bn': 'Bayesian Estimator'
                    }[x]
                )

                if estimator == 'bn':
                    equivalent_sample_size = st.number_input(
                        "Equivalent Sample Size:",
                        min_value=1,
                        value=3,
                        step=1
                    )
                else:
                    equivalent_sample_size = 3

                # Algorithm-specific parameters
                if algorithm == 'HC':
                    score_method = st.selectbox(
                        "Scoring Method:",
                        options=['BIC', 'AIC', 'K2', 'BDeu', 'BDs']
                    )
                else:
                    score_method = 'BIC'

            with col_param3:
                if algorithm == 'PC':
                    sig_level = st.number_input(
                        "Significance Level:",
                        min_value=0.01,
                        max_value=1.0,
                        value=0.05,
                        step=0.01
                    )
                else:
                    sig_level = 0.05

                n_bins = st.number_input(
                    "Number of Bins (for continuous):",
                    min_value=3,
                    max_value=20,
                    value=10,
                    step=1
                )

            # Run-analysis button
            st.markdown("---")

            if st.button("🚀 Run Analysis", type="primary", use_container_width=True):
                with st.spinner("🔄 Training Bayesian Network..."):
                    try:
                        # Initialize the analyzer
                        analyzer = BayesianNetworkAnalyzer(
                            session_id=st.session_state.session_id
                        )

                        # Run the analysis
                        results = analyzer.run_analysis(
                            df=df,
                            cat_features=cat_features,
                            con_features=con_features,
                            target_variable=target_variable,
                            test_fraction=test_fraction,
                            algorithm=algorithm,
                            estimator=estimator,
                            equivalent_sample_size=equivalent_sample_size,
                            score_method=score_method,
                            sig_level=sig_level,
                            n_bins=n_bins
                        )

                        # Store the results
                        st.session_state.analysis_results = results
                        st.session_state.model_trained = True

                        st.success("✅ Analysis completed!")
                        st.rerun()

                    except Exception as e:
                        st.error(f"❌ Error during analysis: {str(e)}")
                        st.exception(e)

    with col2:
        st.header("Quick Stats")

        if df is not None:
            st.metric("Total Samples", df.shape[0])
            st.metric("Total Features", df.shape[1])
            st.metric("Selected Features", len(selected_features) if 'selected_features' in locals() else 0)

        if st.session_state.model_trained:
            st.success("✅ Model Trained")
        else:
            st.info("⏳ Awaiting Training")

    # Display the results
    if st.session_state.analysis_results:
        st.markdown("---")
        st.header("📊 Analysis Results")

        results = st.session_state.analysis_results

        # Network structure
        st.subheader("🕸️ Bayesian Network Structure")
        network_fig = generate_network_graph(results['model'])
        st.plotly_chart(network_fig, use_container_width=True)

        # Performance metrics
        st.subheader("📈 Performance Metrics")

        col_m1, col_m2 = st.columns(2)

        with col_m1:
            st.markdown("**Training Set**")
            train_metrics = results['train_metrics']

            metric_cols = st.columns(4)
            metric_cols[0].metric("Accuracy", f"{train_metrics['accuracy']:.2f}%")
            metric_cols[1].metric("Precision", f"{train_metrics['precision']:.2f}%")
            metric_cols[2].metric("Recall", f"{train_metrics['recall']:.2f}%")
            metric_cols[3].metric("F1-Score", f"{train_metrics['f1']:.2f}%")

            # Confusion matrix
            conf_fig_train = plot_confusion_matrix(
                train_metrics['confusion_matrix'],
                title="Training Set Confusion Matrix"
            )
            st.plotly_chart(conf_fig_train, use_container_width=True)

            # ROC curve
            roc_fig_train = plot_roc_curve(
                train_metrics['fpr'],
                train_metrics['tpr'],
                train_metrics['auc'],
                title="Training Set ROC Curve"
            )
            st.plotly_chart(roc_fig_train, use_container_width=True)

        with col_m2:
            st.markdown("**Test Set**")
            test_metrics = results['test_metrics']

            metric_cols = st.columns(4)
            metric_cols[0].metric("Accuracy", f"{test_metrics['accuracy']:.2f}%")
            metric_cols[1].metric("Precision", f"{test_metrics['precision']:.2f}%")
            metric_cols[2].metric("Recall", f"{test_metrics['recall']:.2f}%")
            metric_cols[3].metric("F1-Score", f"{test_metrics['f1']:.2f}%")

            # Confusion matrix
            conf_fig_test = plot_confusion_matrix(
                test_metrics['confusion_matrix'],
                title="Test Set Confusion Matrix"
            )
            st.plotly_chart(conf_fig_test, use_container_width=True)

            # ROC curve
            roc_fig_test = plot_roc_curve(
                test_metrics['fpr'],
                test_metrics['tpr'],
                test_metrics['auc'],
                title="Test Set ROC Curve"
            )
            st.plotly_chart(roc_fig_test, use_container_width=True)

        # Conditional probability tables
        st.subheader("📋 Conditional Probability Tables")

        selected_node = st.selectbox(
            "Select a node to view its CPD:",
            options=list(results['cpds'].keys())
        )

        if selected_node:
            cpd_df = create_cpd_table(results['cpds'][selected_node])
            st.dataframe(cpd_df, use_container_width=True)

        # Structure scores
        st.subheader("📊 Model Scores")

        score_cols = st.columns(5)
        scores = results['scores']
        score_cols[0].metric("Log-Likelihood", f"{scores['log_likelihood']:.2f}")
        score_cols[1].metric("BIC Score", f"{scores['bic']:.2f}")
        score_cols[2].metric("K2 Score", f"{scores['k2']:.2f}")
        score_cols[3].metric("BDeu Score", f"{scores['bdeu']:.2f}")
        score_cols[4].metric("BDs Score", f"{scores['bds']:.2f}")

# Tab 2: AI assistant
with tab2:
    st.header("💬 AI Analysis Assistant")

    if not st.session_state.get('api_key'):
        st.warning("⚠️ Please enter your OpenAI API Key in the sidebar to use the AI assistant.")
    elif not st.session_state.model_trained:
        st.info("ℹ️ Please train a model first in the Analysis tab to use the AI assistant.")
    else:
        # Initialize the LLM assistant
        if 'llm_assistant' not in st.session_state:
            st.session_state.llm_assistant = LLMAssistant(
                api_key=st.session_state.api_key,
                session_id=st.session_state.session_id
            )

        # Display the chat history
        chat_container = st.container()

        with chat_container:
            for message in st.session_state.chat_history:
                with st.chat_message(message["role"]):
                    st.markdown(message["content"])

        # Chat input
        if prompt := st.chat_input("Ask me anything about your analysis results..."):
            # Append the user message
            st.session_state.chat_history.append({
                "role": "user",
                "content": prompt
            })

            with st.chat_message("user"):
                st.markdown(prompt)

            # Get the AI response
            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    response = st.session_state.llm_assistant.get_response(
                        user_message=prompt,
                        analysis_results=st.session_state.analysis_results
                    )
                    st.markdown(response)

            # Append the assistant message
            st.session_state.chat_history.append({
                "role": "assistant",
                "content": response
            })

        # Quick-question buttons
        st.markdown("---")
        st.subheader("💡 Quick Questions")

        quick_questions = [
            "📊 Give me a summary of the analysis results",
            "🎯 What is the model's performance?",
            "🔍 Explain the Bayesian Network structure",
            "⚠️ What are the limitations of this model?",
            "💡 How can I improve the model?"
        ]

        cols = st.columns(len(quick_questions))
        for idx, (col, question) in enumerate(zip(cols, quick_questions)):
            if col.button(question, key=f"quick_{idx}"):
                st.session_state.chat_history.append({
                    "role": "user",
                    "content": question
                })

                response = st.session_state.llm_assistant.get_response(
                    user_message=question,
                    analysis_results=st.session_state.analysis_results
                )

                st.session_state.chat_history.append({
                    "role": "assistant",
                    "content": response
                })

                st.rerun()

# Footer
st.markdown("---")
st.markdown(
    """
    <div style='text-align: center'>
        <p>🔬 Bayesian Network Analysis System | Built with Streamlit</p>
        <p>Powered by OpenAI GPT-4 | Session ID: {}</p>
    </div>
    """.format(st.session_state.session_id[:8]),
    unsafe_allow_html=True
)
```
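app.py passes `n_bins` through to the analyzer, but this chunk does not show how `_preprocess_data` discretizes continuous features. A typical equal-width approach with `pd.cut` looks like the following sketch; the column name and the choice of equal-width (rather than quantile) bins are illustrative assumptions, not the repo's actual implementation:

```python
import pandas as pd

# Toy continuous feature
df = pd.DataFrame({"age": [23, 35, 41, 52, 60, 67, 29, 48]})

# Equal-width binning into n_bins intervals; retbins=True returns the bin
# edges so the same cut points can be reused on the test set.
n_bins = 4
binned, edges = pd.cut(df["age"], bins=n_bins, labels=False, retbins=True)
df["age_binned"] = binned

print(df["age_binned"].tolist())  # [0, 1, 1, 2, 3, 3, 0, 2]
```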
bn_core.py
ADDED
@@ -0,0 +1,410 @@

```python
import pandas as pd
import numpy as np
from pgmpy.models import BayesianNetwork
from pgmpy.estimators import (
    TreeSearch, HillClimbSearch, PC,
    MaximumLikelihoodEstimator, BayesianEstimator,
    BicScore, AICScore, K2Score, BDeuScore, BDsScore
)
from pgmpy.inference import VariableElimination
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    confusion_matrix, accuracy_score, precision_score,
    recall_score, f1_score, roc_curve, roc_auc_score
)
from pgmpy.metrics import log_likelihood_score, structure_score
import threading
from datetime import datetime

class BayesianNetworkAnalyzer:
    """
    Bayesian network analyzer.
    Supports concurrent users; each session is handled independently.
    """

    # Class-level lock for thread safety
    _lock = threading.Lock()

    # Per-session analysis results
    _session_results = {}

    def __init__(self, session_id):
        """
        Initialize the analyzer.

        Args:
            session_id: unique session identifier
        """
        self.session_id = session_id
        self.model = None
        self.inference = None
        self.train_data = None
        self.test_data = None
        self.bins_dict = {}

    def run_analysis(self, df, cat_features, con_features, target_variable,
                     test_fraction=0.25, algorithm='NB', estimator='ml',
                     equivalent_sample_size=3, score_method='BIC',
                     sig_level=0.05, n_bins=10):
        """
        Run the full Bayesian network analysis.

        Args:
            df: raw data frame
            cat_features: list of categorical features
            con_features: list of continuous features
            target_variable: name of the target variable
            test_fraction: test set proportion
            algorithm: structure-learning algorithm
            estimator: parameter estimation method
            equivalent_sample_size: equivalent sample size (for Bayesian estimation)
            score_method: scoring method (for Hill Climbing)
            sig_level: significance level (for the PC algorithm)
            n_bins: number of bins for continuous variables

        Returns:
            dict: all analysis results
        """

        with self._lock:
            try:
                # 1. Preprocess the data
                processed_df = self._preprocess_data(
                    df, cat_features, con_features, target_variable, n_bins
                )

                # 2. Train/test split
                self.train_data, self.test_data = train_test_split(
                    processed_df,
                    test_size=test_fraction,
                    random_state=42,
                    stratify=processed_df[target_variable]
                )

                # 3. Learn the network structure
                self.model = self._learn_structure(
                    algorithm, score_method, sig_level, target_variable
                )

                # 4. Estimate the parameters
                self._fit_parameters(estimator, equivalent_sample_size)

                # 5. Initialize the inference engine
                self.inference = VariableElimination(self.model)

                # 6. Evaluate the model
                train_metrics = self._evaluate_model(
                    self.train_data, target_variable, "train"
                )
                test_metrics = self._evaluate_model(
                    self.test_data, target_variable, "test"
                )

                # 7. Collect the CPDs
                cpds = self._get_all_cpds()

                # 8. Compute the structure scores
                scores = self._calculate_scores()

                # 9. Assemble the results
                results = {
                    'model': self.model,
                    'inference': self.inference,
                    'train_metrics': train_metrics,
                    'test_metrics': test_metrics,
                    'cpds': cpds,
                    'scores': scores,
                    'parameters': {
                        'algorithm': algorithm,
                        'estimator': estimator,
                        'test_fraction': test_fraction,
                        'n_features': len(cat_features) + len(con_features),
                        'cat_features': cat_features,
                        'con_features': con_features,
                        'target_variable': target_variable,
```
|
| 125 |
+
'n_bins': n_bins
|
| 126 |
+
},
|
| 127 |
+
'timestamp': datetime.now().isoformat()
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
# 儲存到 session results
|
| 131 |
+
self._session_results[self.session_id] = results
|
| 132 |
+
|
| 133 |
+
return results
|
| 134 |
+
|
| 135 |
+
except Exception as e:
|
| 136 |
+
raise Exception(f"Analysis failed: {str(e)}")
|
| 137 |
+
|
| 138 |
+
def _preprocess_data(self, df, cat_features, con_features,
|
| 139 |
+
target_variable, n_bins):
|
| 140 |
+
"""資料預處理"""
|
| 141 |
+
# 選擇需要的欄位
|
| 142 |
+
selected_columns = cat_features + con_features + [target_variable]
|
| 143 |
+
processed_df = df[selected_columns].copy()
|
| 144 |
+
|
| 145 |
+
# 處理缺失值
|
| 146 |
+
processed_df = processed_df.dropna()
|
| 147 |
+
|
| 148 |
+
# 處理連續變數 - 分箱
|
| 149 |
+
for col in con_features:
|
| 150 |
+
if col in processed_df.columns:
|
| 151 |
+
# 記錄分箱邊界
|
| 152 |
+
bin_edges = pd.cut(
|
| 153 |
+
processed_df[col],
|
| 154 |
+
bins=n_bins,
|
| 155 |
+
retbins=True,
|
| 156 |
+
duplicates='drop'
|
| 157 |
+
)[1]
|
| 158 |
+
|
| 159 |
+
self.bins_dict[col] = bin_edges
|
| 160 |
+
|
| 161 |
+
# 創建分箱標籤
|
| 162 |
+
bin_labels = [
|
| 163 |
+
f"{round(bin_edges[i], 2)}-{round(bin_edges[i+1], 2)}"
|
| 164 |
+
for i in range(len(bin_edges) - 1)
|
| 165 |
+
]
|
| 166 |
+
|
| 167 |
+
# 應用分箱
|
| 168 |
+
processed_df[col] = pd.cut(
|
| 169 |
+
processed_df[col],
|
| 170 |
+
bins=bin_edges,
|
| 171 |
+
labels=bin_labels,
|
| 172 |
+
include_lowest=True
|
| 173 |
+
).astype(str)
|
| 174 |
+
|
| 175 |
+
# 確保分類變數為字串類型
|
| 176 |
+
for col in cat_features:
|
| 177 |
+
if col in processed_df.columns:
|
| 178 |
+
processed_df[col] = processed_df[col].astype(str)
|
| 179 |
+
|
| 180 |
+
# 確保目標變數為整數
|
| 181 |
+
if target_variable in processed_df.columns:
|
| 182 |
+
processed_df[target_variable] = processed_df[target_variable].astype(int)
|
| 183 |
+
|
| 184 |
+
return processed_df
|
| 185 |
+
|
| 186 |
+
def _learn_structure(self, algorithm, score_method, sig_level, target_variable):
|
| 187 |
+
"""學習網路結構"""
|
| 188 |
+
|
| 189 |
+
if algorithm == 'NB':
|
| 190 |
+
# Naive Bayes
|
| 191 |
+
edges = [
|
| 192 |
+
(target_variable, feature)
|
| 193 |
+
for feature in self.train_data.columns
|
| 194 |
+
if feature != target_variable
|
| 195 |
+
]
|
| 196 |
+
model = BayesianNetwork(edges)
|
| 197 |
+
|
| 198 |
+
elif algorithm == 'TAN':
|
| 199 |
+
# Tree-Augmented Naive Bayes
|
| 200 |
+
tan_search = TreeSearch(self.train_data)
|
| 201 |
+
structure = tan_search.estimate(
|
| 202 |
+
estimator_type='tan',
|
| 203 |
+
class_node=target_variable
|
| 204 |
+
)
|
| 205 |
+
model = BayesianNetwork(structure.edges())
|
| 206 |
+
|
| 207 |
+
elif algorithm == 'CL':
|
| 208 |
+
# Chow-Liu
|
| 209 |
+
tan_search = TreeSearch(self.train_data)
|
| 210 |
+
structure = tan_search.estimate(
|
| 211 |
+
estimator_type='chow-liu',
|
| 212 |
+
class_node=target_variable
|
| 213 |
+
)
|
| 214 |
+
model = BayesianNetwork(structure.edges())
|
| 215 |
+
|
| 216 |
+
elif algorithm == 'HC':
|
| 217 |
+
# Hill Climbing
|
| 218 |
+
hc = HillClimbSearch(self.train_data)
|
| 219 |
+
|
| 220 |
+
# 選擇評分方法
|
| 221 |
+
scoring_methods = {
|
| 222 |
+
'BIC': BicScore(self.train_data),
|
| 223 |
+
'AIC': AICScore(self.train_data),
|
| 224 |
+
'K2': K2Score(self.train_data),
|
| 225 |
+
'BDeu': BDeuScore(self.train_data),
|
| 226 |
+
'BDs': BDsScore(self.train_data)
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
structure = hc.estimate(
|
| 230 |
+
scoring_method=scoring_methods[score_method]
|
| 231 |
+
)
|
| 232 |
+
model = BayesianNetwork(structure.edges())
|
| 233 |
+
|
| 234 |
+
elif algorithm == 'PC':
|
| 235 |
+
# PC Algorithm
|
| 236 |
+
pc = PC(self.train_data)
|
| 237 |
+
|
| 238 |
+
# 嘗試不同的 max_cond_vars 直到成功
|
| 239 |
+
for max_cond in [5, 4, 3, 2, 1]:
|
| 240 |
+
try:
|
| 241 |
+
structure = pc.estimate(
|
| 242 |
+
significance_level=sig_level,
|
| 243 |
+
max_cond_vars=max_cond,
|
| 244 |
+
ci_test='chi_square',
|
| 245 |
+
variant='stable'
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
# 檢查是否有效
|
| 249 |
+
if structure.edges():
|
| 250 |
+
model = BayesianNetwork(structure.edges())
|
| 251 |
+
break
|
| 252 |
+
except:
|
| 253 |
+
continue
|
| 254 |
+
else:
|
| 255 |
+
# 如果都失敗,使用 Naive Bayes
|
| 256 |
+
edges = [
|
| 257 |
+
(target_variable, feature)
|
| 258 |
+
for feature in self.train_data.columns
|
| 259 |
+
if feature != target_variable
|
| 260 |
+
]
|
| 261 |
+
model = BayesianNetwork(edges)
|
| 262 |
+
|
| 263 |
+
else:
|
| 264 |
+
raise ValueError(f"Unknown algorithm: {algorithm}")
|
| 265 |
+
|
| 266 |
+
return model
|
| 267 |
+
|
| 268 |
+
def _fit_parameters(self, estimator, equivalent_sample_size):
|
| 269 |
+
"""參數估計"""
|
| 270 |
+
if estimator == 'bn':
|
| 271 |
+
self.model.fit(
|
| 272 |
+
self.train_data,
|
| 273 |
+
estimator=BayesianEstimator,
|
| 274 |
+
equivalent_sample_size=equivalent_sample_size
|
| 275 |
+
)
|
| 276 |
+
else:
|
| 277 |
+
self.model.fit(
|
| 278 |
+
self.train_data,
|
| 279 |
+
estimator=MaximumLikelihoodEstimator
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
def _predict_probabilities(self, data, target_variable):
|
| 283 |
+
"""預測機率"""
|
| 284 |
+
true_labels = []
|
| 285 |
+
predicted_probs = []
|
| 286 |
+
|
| 287 |
+
model_nodes = set(self.model.nodes())
|
| 288 |
+
|
| 289 |
+
for idx, row in data.iterrows():
|
| 290 |
+
# 準備 evidence
|
| 291 |
+
evidence = {
|
| 292 |
+
k: v for k, v in row.drop(target_variable).to_dict().items()
|
| 293 |
+
if k in model_nodes
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
true_label = row[target_variable]
|
| 297 |
+
true_labels.append(true_label)
|
| 298 |
+
|
| 299 |
+
try:
|
| 300 |
+
result = self.inference.query(
|
| 301 |
+
variables=[target_variable],
|
| 302 |
+
evidence=evidence
|
| 303 |
+
)
|
| 304 |
+
probs = result.values
|
| 305 |
+
predicted_probs.append(probs)
|
| 306 |
+
except:
|
| 307 |
+
# 如果推論失敗,使用邊際機率
|
| 308 |
+
predicted_probs.append(None)
|
| 309 |
+
|
| 310 |
+
# 過濾有效的結果
|
| 311 |
+
valid_data = [
|
| 312 |
+
(label, prob)
|
| 313 |
+
for label, prob in zip(true_labels, predicted_probs)
|
| 314 |
+
if prob is not None and len(prob) > 1
|
| 315 |
+
]
|
| 316 |
+
|
| 317 |
+
if not valid_data:
|
| 318 |
+
return [], []
|
| 319 |
+
|
| 320 |
+
valid_labels, valid_probs = zip(*valid_data)
|
| 321 |
+
prob_array = np.array([prob[1] for prob in valid_probs])
|
| 322 |
+
|
| 323 |
+
return list(valid_labels), prob_array
|
| 324 |
+
|
| 325 |
+
def _evaluate_model(self, data, target_variable, dataset_name):
|
| 326 |
+
"""評估模型效能"""
|
| 327 |
+
# 預測
|
| 328 |
+
true_labels, pred_probs = self._predict_probabilities(
|
| 329 |
+
data, target_variable
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
if len(true_labels) == 0:
|
| 333 |
+
return {
|
| 334 |
+
'accuracy': 0,
|
| 335 |
+
'precision': 0,
|
| 336 |
+
'recall': 0,
|
| 337 |
+
'f1': 0,
|
| 338 |
+
'auc': 0,
|
| 339 |
+
'confusion_matrix': [[0, 0], [0, 0]],
|
| 340 |
+
'fpr': [0],
|
| 341 |
+
'tpr': [0]
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
# 二元預測
|
| 345 |
+
pred_labels = (pred_probs >= 0.5).astype(int)
|
| 346 |
+
|
| 347 |
+
# 計算指標
|
| 348 |
+
accuracy = accuracy_score(true_labels, pred_labels) * 100
|
| 349 |
+
precision = precision_score(true_labels, pred_labels, zero_division=0) * 100
|
| 350 |
+
recall = recall_score(true_labels, pred_labels, zero_division=0) * 100
|
| 351 |
+
f1 = f1_score(true_labels, pred_labels, zero_division=0) * 100
|
| 352 |
+
|
| 353 |
+
# ROC 曲線
|
| 354 |
+
fpr, tpr, _ = roc_curve(true_labels, pred_probs)
|
| 355 |
+
auc = roc_auc_score(true_labels, pred_probs)
|
| 356 |
+
|
| 357 |
+
# 混淆矩陣
|
| 358 |
+
cm = confusion_matrix(true_labels, pred_labels).tolist()
|
| 359 |
+
|
| 360 |
+
# G-mean 和 P-mean
|
| 361 |
+
tn, fp, fn, tp = confusion_matrix(true_labels, pred_labels).ravel()
|
| 362 |
+
sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
|
| 363 |
+
specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
|
| 364 |
+
g_mean = np.sqrt(sensitivity * precision / 100) * 100
|
| 365 |
+
p_mean = np.sqrt(specificity * sensitivity) * 100
|
| 366 |
+
|
| 367 |
+
return {
|
| 368 |
+
'accuracy': accuracy,
|
| 369 |
+
'precision': precision,
|
| 370 |
+
'recall': recall,
|
| 371 |
+
'f1': f1,
|
| 372 |
+
'auc': auc,
|
| 373 |
+
'g_mean': g_mean,
|
| 374 |
+
'p_mean': p_mean,
|
| 375 |
+
'specificity': specificity * 100,
|
| 376 |
+
'confusion_matrix': cm,
|
| 377 |
+
'fpr': fpr.tolist(),
|
| 378 |
+
'tpr': tpr.tolist(),
|
| 379 |
+
'predicted_probs': pred_probs.tolist()
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
def _get_all_cpds(self):
|
| 383 |
+
"""獲取所有條件機率表"""
|
| 384 |
+
cpds = {}
|
| 385 |
+
for node in self.model.nodes():
|
| 386 |
+
cpd = self.model.get_cpds(node)
|
| 387 |
+
cpds[node] = cpd
|
| 388 |
+
return cpds
|
| 389 |
+
|
| 390 |
+
def _calculate_scores(self):
|
| 391 |
+
"""計算模型評分"""
|
| 392 |
+
scores = {
|
| 393 |
+
'log_likelihood': log_likelihood_score(self.model, self.train_data),
|
| 394 |
+
'bic': structure_score(self.model, self.train_data, scoring_method='bic'),
|
| 395 |
+
'k2': structure_score(self.model, self.train_data, scoring_method='k2'),
|
| 396 |
+
'bdeu': structure_score(self.model, self.train_data, scoring_method='bdeu'),
|
| 397 |
+
'bds': structure_score(self.model, self.train_data, scoring_method='bds')
|
| 398 |
+
}
|
| 399 |
+
return scores
|
| 400 |
+
|
| 401 |
+
@classmethod
|
| 402 |
+
def get_session_results(cls, session_id):
|
| 403 |
+
"""獲取特定 session 的結果"""
|
| 404 |
+
return cls._session_results.get(session_id)
|
| 405 |
+
|
| 406 |
+
@classmethod
|
| 407 |
+
def clear_session_results(cls, session_id):
|
| 408 |
+
"""清除特定 session 的結果"""
|
| 409 |
+
if session_id in cls._session_results:
|
| 410 |
+
del cls._session_results[session_id]
|
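The sensitivity/specificity arithmetic in `_evaluate_model` can be exercised standalone, without pgmpy or sklearn. A minimal stdlib-only sketch of that step (the helper name `balance_metrics` is illustrative, not part of the module):

```python
import math

def balance_metrics(tn, fp, fn, tp):
    """Sensitivity, specificity, and their geometric mean (P-mean), as percentages."""
    sensitivity = tp / (tp + fn) if (tp + fn) else 0.0
    specificity = tn / (tn + fp) if (tn + fp) else 0.0
    # P-mean: geometric mean of specificity and sensitivity, scaled to percent
    p_mean = math.sqrt(sensitivity * specificity) * 100
    return sensitivity * 100, specificity * 100, p_mean

print(balance_metrics(tn=40, fp=10, fn=5, tp=45))
# → sensitivity 90.0, specificity 80.0, P-mean ≈ 84.85
```

The geometric mean rewards balanced performance on both classes: a classifier that nails one class but misses the other scores near zero, which a plain accuracy average would hide.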
llm_assistant.py
ADDED
@@ -0,0 +1,265 @@
from openai import OpenAI
import json
import numpy as np

class LLMAssistant:
    """
    LLM Q&A assistant.
    Helps users understand their Bayesian network analysis results.
    """

    def __init__(self, api_key, session_id):
        """
        Initialize the LLM assistant.

        Args:
            api_key: OpenAI API key
            session_id: unique session identifier
        """
        self.client = OpenAI(api_key=api_key)
        self.session_id = session_id
        self.conversation_history = []

        # System prompt
        self.system_prompt = """You are an expert data scientist specializing in Bayesian Networks and machine learning.
Your role is to help users understand their Bayesian Network analysis results.

You should:
1. Explain complex statistical concepts in simple terms
2. Provide insights about model performance metrics
3. Suggest improvements when asked
4. Explain the structure and relationships in the Bayesian Network
5. Help interpret conditional probability tables (CPTs)
6. Discuss limitations and assumptions of the model

Always be clear, concise, and educational. Use examples when helpful.
Format your responses with proper markdown for better readability."""

    def get_response(self, user_message, analysis_results):
        """
        Get an AI response.

        Args:
            user_message: the user's message
            analysis_results: analysis results dictionary

        Returns:
            str: the AI response
        """

        # Prepare the context information
        context = self._prepare_context(analysis_results)

        # Append the user message to the history
        self.conversation_history.append({
            "role": "user",
            "content": user_message
        })

        # Build the message list
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "system", "content": f"Analysis Context:\n{context}"}
        ] + self.conversation_history

        try:
            # Call the OpenAI API
            response = self.client.chat.completions.create(
                model="gpt-4o-mini",
                messages=messages,
                temperature=0.7,
                max_tokens=1500
            )

            assistant_message = response.choices[0].message.content

            # Append the assistant reply to the history
            self.conversation_history.append({
                "role": "assistant",
                "content": assistant_message
            })

            return assistant_message

        except Exception as e:
            return f"❌ Error: {str(e)}\n\nPlease check your API key and try again."

    def _prepare_context(self, results):
        """Prepare context information from the analysis results."""

        if not results:
            return "No analysis results available yet."

        # Extract the key information
        params = results['parameters']
        train_metrics = results['train_metrics']
        test_metrics = results['test_metrics']
        scores = results['scores']

        # Build the context string
        context = f"""
## Model Configuration
- Algorithm: {params['algorithm']}
- Estimator: {params['estimator']}
- Number of Features: {params['n_features']}
- Categorical: {len(params['cat_features'])}
- Continuous: {len(params['con_features'])}
- Target Variable: {params['target_variable']}
- Test Set Proportion: {params['test_fraction']:.0%}

## Training Set Performance
- Accuracy: {train_metrics['accuracy']:.2f}%
- Precision: {train_metrics['precision']:.2f}%
- Recall: {train_metrics['recall']:.2f}%
- F1-Score: {train_metrics['f1']:.2f}%
- AUC: {train_metrics['auc']:.4f}
- G-mean: {train_metrics['g_mean']:.2f}%
- P-mean: {train_metrics['p_mean']:.2f}%
- Specificity: {train_metrics['specificity']:.2f}%

## Test Set Performance
- Accuracy: {test_metrics['accuracy']:.2f}%
- Precision: {test_metrics['precision']:.2f}%
- Recall: {test_metrics['recall']:.2f}%
- F1-Score: {test_metrics['f1']:.2f}%
- AUC: {test_metrics['auc']:.4f}
- G-mean: {test_metrics['g_mean']:.2f}%
- P-mean: {test_metrics['p_mean']:.2f}%
- Specificity: {test_metrics['specificity']:.2f}%

## Model Scores
- Log-Likelihood: {scores['log_likelihood']:.2f}
- BIC Score: {scores['bic']:.2f}
- K2 Score: {scores['k2']:.2f}
- BDeu Score: {scores['bdeu']:.2f}
- BDs Score: {scores['bds']:.2f}

## Network Structure
- Total Nodes: {len(results['model'].nodes())}
- Total Edges: {len(results['model'].edges())}
- Network Edges: {list(results['model'].edges())[:10]}... (showing first 10)
"""

        return context

    def generate_summary(self, analysis_results):
        """
        Automatically generate a summary of the analysis results.

        Args:
            analysis_results: analysis results dictionary

        Returns:
            str: the summary text
        """

        summary_prompt = """Based on the analysis results provided in the context, please generate a comprehensive summary that includes:

1. **Model Overview**: Brief description of the model type and configuration
2. **Performance Analysis**:
   - Overall model performance on both training and test sets
   - Comparison between training and test performance (overfitting/underfitting)
   - Key strengths and weaknesses
3. **Network Structure Insights**: What the learned structure tells us about variable relationships
4. **Recommendations**: Specific suggestions for improvement
5. **Limitations**: Important caveats and limitations to consider

Format the summary in clear markdown with appropriate sections and bullet points."""

        return self.get_response(summary_prompt, analysis_results)

    def explain_metric(self, metric_name, analysis_results):
        """
        Explain a specific metric.

        Args:
            metric_name: name of the metric
            analysis_results: analysis results dictionary

        Returns:
            str: explanation of the metric
        """

        explain_prompt = f"""Please explain the following metric in the context of this analysis:

Metric: {metric_name}

Include:
1. What this metric measures
2. The value obtained in this analysis (training and test)
3. How to interpret this value
4. What it tells us about model performance
5. How it relates to other metrics in the analysis"""

        return self.get_response(explain_prompt, analysis_results)

    def suggest_improvements(self, analysis_results):
        """
        Provide improvement suggestions.

        Args:
            analysis_results: analysis results dictionary

        Returns:
            str: improvement suggestions
        """

        improve_prompt = """Based on the current model performance and configuration, please provide specific, actionable recommendations for improvement.

Consider:
1. Feature engineering opportunities
2. Algorithm selection
3. Hyperparameter tuning
4. Data quality issues
5. Model complexity trade-offs

Prioritize recommendations by potential impact."""

        return self.get_response(improve_prompt, analysis_results)

    def explain_network_structure(self, analysis_results):
        """
        Explain the network structure.

        Args:
            analysis_results: analysis results dictionary

        Returns:
            str: explanation of the network structure
        """

        structure_prompt = """Please explain the learned Bayesian Network structure:

1. What are the key relationships (edges) discovered?
2. What do these relationships tell us about the domain?
3. Are there any surprising or interesting patterns?
4. How does the structure relate to the target variable?
5. What are the implications for prediction and inference?"""

        return self.get_response(structure_prompt, analysis_results)

    def compare_algorithms(self, analysis_results):
        """
        Compare the available algorithms.

        Args:
            analysis_results: analysis results dictionary

        Returns:
            str: algorithm comparison
        """

        compare_prompt = f"""The current model uses the {analysis_results['parameters']['algorithm']} algorithm.

Please:
1. Explain the characteristics of this algorithm
2. Compare it with other available algorithms (NB, TAN, CL, HC, PC)
3. Discuss when this algorithm is most appropriate
4. Suggest if a different algorithm might be better for this dataset
5. Explain the trade-offs involved"""

        return self.get_response(compare_prompt, analysis_results)

    def reset_conversation(self):
        """Reset the conversation history."""
        self.conversation_history = []
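The payload assembly in `get_response` — two system turns (the prompt plus the analysis context) followed by the running conversation history — can be exercised without an API key. A minimal sketch of just that assembly step (`build_messages` is an illustrative helper, not part of the module):

```python
def build_messages(system_prompt, context, history, user_message):
    """Assemble the chat payload: system prompt, context turn, then history + new message."""
    turns = history + [{"role": "user", "content": user_message}]
    return [
        {"role": "system", "content": system_prompt},
        {"role": "system", "content": f"Analysis Context:\n{context}"},
    ] + turns

msgs = build_messages("You are an expert.", "AUC: 0.91", [], "Is this good?")
print([m["role"] for m in msgs])
# → ['system', 'system', 'user']
```

Keeping the context in its own system turn, rebuilt on every call, means stale metrics are never carried forward in the history — only the user/assistant exchange accumulates.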
requirements.txt
ADDED
@@ -0,0 +1,9 @@
streamlit==1.31.0
pandas==2.1.4
numpy==1.26.3
plotly==5.18.0
scikit-learn==1.4.0
pgmpy==0.1.25
networkx==3.2.1
openai==1.12.0
graphviz==0.20.1
utils.py
ADDED
@@ -0,0 +1,387 @@
| 1 |
+
import plotly.graph_objects as go
|
| 2 |
+
import plotly.express as px
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import numpy as np
|
| 5 |
+
import networkx as nx
|
| 6 |
+
from plotly.subplots import make_subplots
|
| 7 |
+
|
| 8 |
+
def plot_roc_curve(fpr, tpr, auc, title="ROC Curve"):
|
| 9 |
+
"""
|
| 10 |
+
繪製 ROC 曲線
|
| 11 |
+
|
| 12 |
+
Args:
|
| 13 |
+
fpr: False positive rate
|
| 14 |
+
tpr: True positive rate
|
| 15 |
+
auc: Area under curve
|
| 16 |
+
title: 圖表標題
|
| 17 |
+
|
| 18 |
+
Returns:
|
| 19 |
+
plotly figure
|
| 20 |
+
"""
|
| 21 |
+
fig = go.Figure()
|
| 22 |
+
|
| 23 |
+
# ROC 曲線
|
| 24 |
+
fig.add_trace(go.Scatter(
|
| 25 |
+
x=fpr,
|
| 26 |
+
y=tpr,
|
| 27 |
+
mode='lines',
|
| 28 |
+
name=f'ROC Curve (AUC = {auc:.4f})',
|
| 29 |
+
line=dict(color='#2d6ca2', width=2)
|
| 30 |
+
))
|
| 31 |
+
|
| 32 |
+
# 對角線(隨機分類器)
|
| 33 |
+
fig.add_trace(go.Scatter(
|
| 34 |
+
x=[0, 1],
|
| 35 |
+
y=[0, 1],
|
| 36 |
+
mode='lines',
|
| 37 |
+
name='Random Classifier',
|
| 38 |
+
line=dict(color='gray', width=1, dash='dash')
|
| 39 |
+
))
|
| 40 |
+
|
| 41 |
+
fig.update_layout(
|
| 42 |
+
title=title,
|
| 43 |
+
xaxis_title='False Positive Rate',
|
| 44 |
+
yaxis_title='True Positive Rate',
|
| 45 |
+
width=600,
|
| 46 |
+
height=500,
|
| 47 |
+
template='plotly_white',
|
| 48 |
+
legend=dict(x=0.6, y=0.1)
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
return fig
|
| 52 |
+
|
| 53 |
+
def plot_confusion_matrix(cm, title="Confusion Matrix"):
|
| 54 |
+
"""
|
| 55 |
+
繪製混淆矩陣
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
cm: 混淆矩陣 (2x2 list)
|
| 59 |
+
title: 圖表標題
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
plotly figure
|
| 63 |
+
"""
|
| 64 |
+
# 轉換為 numpy array
|
| 65 |
+
cm_array = np.array(cm)
|
| 66 |
+
|
| 67 |
+
# 計算百分比
|
| 68 |
+
cm_percent = cm_array / cm_array.sum() * 100
|
| 69 |
+
|
| 70 |
+
# 創建標籤
|
| 71 |
+
labels = [
|
| 72 |
+
[f'{cm_array[i][j]}<br>({cm_percent[i][j]:.1f}%)'
|
| 73 |
+
for j in range(2)]
|
| 74 |
+
for i in range(2)
|
| 75 |
+
]
|
| 76 |
+
|
| 77 |
+
fig = go.Figure(data=go.Heatmap(
|
| 78 |
+
z=cm_array,
|
| 79 |
+
x=['Predicted: 0', 'Predicted: 1'],
|
| 80 |
+
y=['Actual: 0', 'Actual: 1'],
|
| 81 |
+
text=labels,
|
| 82 |
+
texttemplate='%{text}',
|
| 83 |
+
textfont={"size": 14},
|
| 84 |
+
colorscale='Blues',
|
| 85 |
+
showscale=True
|
| 86 |
+
))
|
| 87 |
+
|
| 88 |
+
fig.update_layout(
|
| 89 |
+
title=title,
|
| 90 |
+
width=500,
|
| 91 |
+
height=450,
|
| 92 |
+
template='plotly_white'
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
return fig
|
| 96 |
+
|
| 97 |
+
def plot_probability_distribution(probs, title="Probability Distribution"):
|
| 98 |
+
"""
|
| 99 |
+
繪製機率分佈圖
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
probs: 預測機率列表
|
| 103 |
+
title: 圖表標題
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
plotly figure
|
| 107 |
+
"""
|
| 108 |
+
fig = go.Figure()
|
| 109 |
+
|
| 110 |
+
fig.add_trace(go.Histogram(
|
| 111 |
+
x=probs,
|
| 112 |
+
nbinsx=20,
|
| 113 |
+
name='Predicted Probabilities',
|
| 114 |
+
marker=dict(
|
| 115 |
+
color='#2d6ca2',
|
| 116 |
+
line=dict(color='white', width=1)
|
| 117 |
+
)
|
| 118 |
+
))
|
| 119 |
+
|
| 120 |
+
fig.update_layout(
|
| 121 |
+
title=title,
|
| 122 |
+
xaxis_title='Predicted Probability for Class 1',
|
| 123 |
+
yaxis_title='Frequency',
|
| 124 |
+
width=700,
|
| 125 |
+
height=400,
|
| 126 |
+
template='plotly_white',
|
| 127 |
+
showlegend=False
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
fig.update_xaxes(range=[0, 1])
|
| 131 |
+
|
| 132 |
+
return fig
|
| 133 |
+
|
| 134 |
+
def generate_network_graph(model):
|
| 135 |
+
"""
|
| 136 |
+
生成貝葉斯網路結構圖
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
model: BayesianNetwork 模型
|
| 140 |
+
|
| 141 |
+
Returns:
|
| 142 |
+
plotly figure
|
| 143 |
+
"""
|
| 144 |
+
# 創建 NetworkX 圖
|
| 145 |
+
G = nx.DiGraph()
|
| 146 |
+
G.add_edges_from(model.edges())
|
| 147 |
+
|
| 148 |
+
# 使用層次佈局
|
| 149 |
+
try:
|
| 150 |
+
pos = nx.spring_layout(G, k=2, iterations=50, seed=42)
|
| 151 |
+
except:
|
| 152 |
+
pos = nx.circular_layout(G)
|
| 153 |
+
|
| 154 |
+
# 提取節點和邊的座標
|
| 155 |
+
edge_x = []
|
| 156 |
+
edge_y = []
|
| 157 |
+
for edge in G.edges():
|
| 158 |
+
x0, y0 = pos[edge[0]]
|
        x1, y1 = pos[edge[1]]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=2, color='#888'),
        hoverinfo='none',
        mode='lines',
        showlegend=False
    )

    node_x = []
    node_y = []
    node_text = []
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_text.append(node)

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        hoverinfo='text',
        text=node_text,
        textposition="top center",
        showlegend=False,
        marker=dict(
            size=30,
            color='#2d6ca2',
            line=dict(width=2, color='white')
        )
    )

    # Add directed-edge arrows as annotations, drawn from parent to child
    annotations = []
    for edge in G.edges():
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]

        annotations.append(
            dict(
                ax=x0, ay=y0,
                axref='x', ayref='y',
                x=x1, y=y1,
                xref='x', yref='y',
                showarrow=True,
                arrowhead=2,
                arrowsize=1,
                arrowwidth=2,
                arrowcolor='#888'
            )
        )

    fig = go.Figure(data=[edge_trace, node_trace])

    fig.update_layout(
        title='Bayesian Network Structure',
        title_font_size=16,  # titlefont_size is deprecated in Plotly 5+
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        annotations=annotations,
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=900,
        height=700,
        template='plotly_white'
    )

    return fig

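The edge trace above relies on Plotly's convention that a `None` entry in a trace's `x`/`y` lists breaks the line into disconnected segments, so every edge can share a single `Scatter`. A minimal sketch of that coordinate-building step with a hypothetical three-node layout (no Plotly needed to see the pattern):

```python
# Sketch of the None-separator pattern used by the edge trace above.
# `pos` is a hypothetical layout: node -> (x, y).
pos = {"A": (0.0, 0.0), "B": (1.0, 0.0), "C": (0.5, 1.0)}
edges = [("A", "B"), ("A", "C")]

edge_x, edge_y = [], []
for src, dst in edges:
    x0, y0 = pos[src]
    x1, y1 = pos[dst]
    # None terminates the current segment so edges stay disconnected
    edge_x.extend([x0, x1, None])
    edge_y.extend([y0, y1, None])

print(edge_x)  # [0.0, 1.0, None, 0.0, 0.5, None]
```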
def create_cpd_table(cpd):
    """
    Build a DataFrame view of a conditional probability distribution.

    Args:
        cpd: CPD object (e.g., a pgmpy TabularCPD)

    Returns:
        pandas DataFrame
    """
    if cpd is None:
        return pd.DataFrame()

    # Variable and evidence (parent) information
    variable = cpd.variable
    evidence_vars = cpd.variables[1:] if len(cpd.variables) > 1 else []

    # Root node (no parents): a single probability column
    if not evidence_vars:
        values = np.round(cpd.values.flatten(), 4)
        df = pd.DataFrame(
            {variable: values},
            index=[f"{variable}({i})" for i in range(len(values))]
        )
        return df

    # Node with parents
    evidence_card = cpd.cardinality[1:]

    # Enumerate every combination of parent states
    from itertools import product
    column_values = list(product(*[range(card) for card in evidence_card]))

    # Build multi-level column labels, one level per parent
    columns = pd.MultiIndex.from_tuples(
        [tuple(f"{var}({val})" for var, val in zip(evidence_vars, vals))
         for vals in column_values],
        names=evidence_vars
    )

    # Reshape the CPD to (variable states) x (parent-state combinations)
    reshaped_values = cpd.values.reshape(len(cpd.values), -1)
    reshaped_values = np.round(reshaped_values, 4)

    df = pd.DataFrame(
        reshaped_values,
        index=[f"{variable}({i})" for i in range(len(cpd.values))],
        columns=columns
    )

    return df

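To illustrate the reshape step in `create_cpd_table`: a CPD for a binary variable with two binary parents is stored as a (2, 2, 2) array, and flattening the parent axes yields one column per combination of parent states, in the same order `itertools.product` enumerates them. A small sketch with made-up probabilities (the variable name `D4` and parents `P1`, `P2` are hypothetical):

```python
import numpy as np
import pandas as pd
from itertools import product

# Hypothetical CPD P(D4 | P1, P2): axis 0 = D4 states, axes 1-2 = parent states
values = np.array([[[0.9, 0.7], [0.6, 0.2]],
                   [[0.1, 0.3], [0.4, 0.8]]])

evidence_vars = ["P1", "P2"]
evidence_card = [2, 2]

# One column per combination of parent states, as in create_cpd_table
combos = list(product(*[range(c) for c in evidence_card]))
columns = pd.MultiIndex.from_tuples(
    [tuple(f"{v}({s})" for v, s in zip(evidence_vars, c)) for c in combos],
    names=evidence_vars
)

df = pd.DataFrame(values.reshape(len(values), -1),
                  index=[f"D4({i})" for i in range(len(values))],
                  columns=columns)
print(df.shape)  # (2, 4); each column sums to 1
```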
def create_metrics_comparison_table(train_metrics, test_metrics):
    """
    Build a comparison table of training-set and test-set metrics.

    Args:
        train_metrics: dict of training-set metrics
        test_metrics: dict of test-set metrics

    Returns:
        pandas DataFrame
    """
    metrics_data = {
        'Metric': [
            'Accuracy', 'Precision', 'Recall', 'F1-Score',
            'AUC', 'G-mean', 'P-mean', 'Specificity'
        ],
        'Training Set': [
            f"{train_metrics['accuracy']:.2f}%",
            f"{train_metrics['precision']:.2f}%",
            f"{train_metrics['recall']:.2f}%",
            f"{train_metrics['f1']:.2f}%",
            f"{train_metrics['auc']:.4f}",
            f"{train_metrics['g_mean']:.2f}%",
            f"{train_metrics['p_mean']:.2f}%",
            f"{train_metrics['specificity']:.2f}%"
        ],
        'Test Set': [
            f"{test_metrics['accuracy']:.2f}%",
            f"{test_metrics['precision']:.2f}%",
            f"{test_metrics['recall']:.2f}%",
            f"{test_metrics['f1']:.2f}%",
            f"{test_metrics['auc']:.4f}",
            f"{test_metrics['g_mean']:.2f}%",
            f"{test_metrics['p_mean']:.2f}%",
            f"{test_metrics['specificity']:.2f}%"
        ]
    }

    df = pd.DataFrame(metrics_data)
    return df

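Two of the less common metrics in the table can be derived directly from a confusion matrix: specificity is the true-negative rate, and G-mean is the geometric mean of recall and specificity, which penalizes models that do well on only one class. (P-mean is project-specific and not reproduced here.) A quick worked example with hypothetical counts:

```python
import math

# Hypothetical confusion-matrix counts
tp, fn, tn, fp = 40, 10, 45, 5

recall = tp / (tp + fn)        # sensitivity / true-positive rate
specificity = tn / (tn + fp)   # true-negative rate
g_mean = math.sqrt(recall * specificity)

print(f"{recall:.2%} {specificity:.2%} {g_mean:.2%}")  # 80.00% 90.00% 84.85%
```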
def export_results_to_json(results, filename="analysis_results.json"):
    """
    Export analysis results as JSON.

    Args:
        results: analysis-results dict
        filename: output file name (currently unused; the JSON string is returned)

    Returns:
        JSON string
    """
    import json

    # Drop objects that cannot be serialized (ROC arrays, per-sample probabilities)
    exportable_results = {
        'parameters': results['parameters'],
        'train_metrics': {
            k: v for k, v in results['train_metrics'].items()
            if k not in ['fpr', 'tpr', 'predicted_probs']
        },
        'test_metrics': {
            k: v for k, v in results['test_metrics'].items()
            if k not in ['fpr', 'tpr', 'predicted_probs']
        },
        'scores': results['scores'],
        'network_edges': list(results['model'].edges()),
        'timestamp': results['timestamp']
    }

    # default=str guards against residual non-serializable values (e.g. numpy scalars)
    return json.dumps(exportable_results, indent=2, default=str)

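The export works by filtering out keys whose values are not JSON-serializable before dumping. A minimal sketch of that filtering step with a dummy metrics dict (the key names mirror those excluded in the function above):

```python
import json

# Dummy metrics; fpr/tpr stand in for the non-serializable ROC arrays
metrics = {"accuracy": 91.2, "auc": 0.87,
           "fpr": [0.0, 0.1, 1.0], "tpr": [0.0, 0.8, 1.0]}

exportable = {k: v for k, v in metrics.items()
              if k not in ["fpr", "tpr", "predicted_probs"]}

payload = json.dumps(exportable, indent=2)
print(sorted(json.loads(payload)))  # ['accuracy', 'auc']
```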
def calculate_performance_gap(train_metrics, test_metrics):
    """
    Compute the performance gap between training and test sets.

    Args:
        train_metrics: training-set metrics
        test_metrics: test-set metrics

    Returns:
        dict: performance-gap dictionary
    """
    gaps = {
        'accuracy_gap': train_metrics['accuracy'] - test_metrics['accuracy'],
        'precision_gap': train_metrics['precision'] - test_metrics['precision'],
        'recall_gap': train_metrics['recall'] - test_metrics['recall'],
        'f1_gap': train_metrics['f1'] - test_metrics['f1'],
        'auc_gap': train_metrics['auc'] - test_metrics['auc']
    }

    # Flag possible overfitting. AUC is on a 0-1 scale while the other
    # metrics are percentages, so exclude it from the averaged gap.
    percent_gaps = [abs(v) for k, v in gaps.items() if k != 'auc_gap']
    avg_gap = np.mean(percent_gaps)
    overfitting_status = "High" if avg_gap > 10 else "Moderate" if avg_gap > 5 else "Low"

    gaps['average_gap'] = avg_gap
    gaps['overfitting_risk'] = overfitting_status

    return gaps
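As a usage sketch, the gap-to-risk mapping can be exercised with dummy metrics; the 5% / 10% thresholds are the ones hard-coded above, and the dictionaries here are hypothetical:

```python
# Hypothetical train/test metrics on a percent scale
train = {"accuracy": 92.0, "precision": 90.0, "recall": 88.0, "f1": 89.0}
test = {"accuracy": 85.0, "precision": 83.0, "recall": 81.0, "f1": 82.0}

gaps = {k: train[k] - test[k] for k in train}
avg_gap = sum(abs(v) for v in gaps.values()) / len(gaps)

# Same threshold scheme as calculate_performance_gap
risk = "High" if avg_gap > 10 else "Moderate" if avg_gap > 5 else "Low"
print(avg_gap, risk)  # 7.0 Moderate
```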