Donlagon007 committed on
Commit
c64919d
·
verified ·
1 Parent(s): 9bb6b99

Upload 8 files

Browse files
Files changed (8) hide show
  1. BC_imputed_micerf_period13_fid_course_D4.csv +0 -0
  2. README.md +8 -18
  3. app.py +1060 -0
  4. bn_core.py +536 -0
  5. llm_assistant.py +360 -0
  6. packages.txt +1 -0
  7. requirements.txt +9 -3
  8. utils.py +313 -0
BC_imputed_micerf_period13_fid_course_D4.csv ADDED
The diff for this file is too large to render. See raw diff
 
README.md CHANGED
@@ -1,20 +1,10 @@
1
  ---
2
- title: BN Upload
3
- emoji: 🚀
4
- colorFrom: red
5
- colorTo: red
6
- sdk: docker
7
- app_port: 8501
8
- tags:
9
- - streamlit
10
  pinned: false
11
- short_description: Streamlit template space
12
- license: mit
13
- ---
14
-
15
- # Welcome to Streamlit!
16
-
17
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
18
-
19
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
20
- forums](https://discuss.streamlit.io).
 
1
  ---
2
+ title: Bayesian Network Analysis
3
+ emoji: 🔬
4
+ colorFrom: blue
5
+ colorTo: indigo
6
+ sdk: streamlit
7
+ sdk_version: 1.31.0
8
+ app_file: app.py
 
9
  pinned: false
10
+ ---
 
 
 
 
 
 
 
 
 
app.py ADDED
@@ -0,0 +1,1060 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import plotly.graph_objects as go
5
+ import plotly.express as px
6
+ from io import BytesIO
7
+ import base64
8
+ import json
9
+ from datetime import datetime
10
+ import uuid
11
+
12
+ # 頁面配置
13
+ st.set_page_config(
14
+ page_title="Bayesian Network Analysis System",
15
+ page_icon="🔬",
16
+ layout="wide",
17
+ initial_sidebar_state="expanded"
18
+ )
19
+
20
+ # 自定義 CSS - 讓介面更像 Django
21
+ st.markdown("""
22
+ <style>
23
+ /* Expander 樣式 - 類似 Django 的摺疊區域 */
24
+ .streamlit-expanderHeader {
25
+ background-color: #e8f1f8;
26
+ border: 1px solid #b0cfe8;
27
+ border-radius: 5px;
28
+ font-weight: 600;
29
+ color: #1b4f72;
30
+ }
31
+
32
+ .streamlit-expanderHeader:hover {
33
+ background-color: #d0e7f8;
34
+ }
35
+
36
+ /* Checkbox 樣式 */
37
+ .stCheckbox {
38
+ padding: 2px 0;
39
+ }
40
+
41
+ /* Radio button 樣式 */
42
+ .stRadio > label {
43
+ font-weight: 600;
44
+ color: #1b4f72;
45
+ }
46
+
47
+ /* 選擇框樣式 */
48
+ .stSelectbox > label, .stNumberInput > label {
49
+ font-weight: 600;
50
+ color: #1b4f72;
51
+ }
52
+
53
+ /* 分隔線 */
54
+ hr {
55
+ margin: 1rem 0;
56
+ border-top: 2px solid #b0cfe8;
57
+ }
58
+
59
+ /* 表單容器 */
60
+ .element-container {
61
+ margin-bottom: 0.5rem;
62
+ }
63
+
64
+ /* 摺疊內容區域 */
65
+ .streamlit-expanderContent {
66
+ background-color: #f8fbff;
67
+ border: 1px solid #d0e4f5;
68
+ border-top: none;
69
+ padding: 1rem;
70
+ }
71
+
72
+ /* 按鈕樣式 */
73
+ .stButton > button {
74
+ width: 100%;
75
+ border-radius: 20px;
76
+ font-weight: 600;
77
+ transition: all 0.3s ease;
78
+ }
79
+
80
+ .stButton > button:hover {
81
+ transform: translateY(-2px);
82
+ box-shadow: 0 4px 8px rgba(0,0,0,0.2);
83
+ }
84
+ </style>
85
+ """, unsafe_allow_html=True)
86
+
87
+ # 導入自定義模組
88
+ from bn_core import BayesianNetworkAnalyzer
89
+ from llm_assistant import LLMAssistant
90
+ from utils import (
91
+ plot_roc_curve,
92
+ plot_confusion_matrix,
93
+ plot_probability_distribution,
94
+ generate_network_graph,
95
+ create_cpd_table,
96
+ export_results_to_json
97
+ )
98
+
99
+ # 初始化 session state
100
+ if 'session_id' not in st.session_state:
101
+ st.session_state.session_id = str(uuid.uuid4())
102
+ if 'analysis_results' not in st.session_state:
103
+ st.session_state.analysis_results = None
104
+ if 'trained_model_results' not in st.session_state:
105
+ st.session_state.trained_model_results = None
106
+ if 'loaded_model_results' not in st.session_state:
107
+ st.session_state.loaded_model_results = None
108
+ if 'loaded_models' not in st.session_state:
109
+ st.session_state.loaded_models = [] # List to store multiple loaded models
110
+ if 'chat_history' not in st.session_state:
111
+ st.session_state.chat_history = []
112
+ if 'model_trained' not in st.session_state:
113
+ st.session_state.model_trained = False
114
+
115
+ # 標題
116
+ st.title("🔬 Bayesian Network Analysis System")
117
+ st.markdown("---")
118
+
119
+ # Sidebar - OpenAI API Key
120
+ with st.sidebar:
121
+ st.header("⚙️ Configuration")
122
+
123
+ api_key = st.text_input(
124
+ "OpenAI API Key",
125
+ type="password",
126
+ help="Enter your OpenAI API key to use the AI assistant"
127
+ )
128
+
129
+ if api_key:
130
+ st.session_state.api_key = api_key
131
+ st.success("✅ API Key loaded")
132
+
133
+ st.markdown("---")
134
+
135
+ # 資料來源選擇
136
+ st.subheader("📊 Data Source")
137
+ data_source = st.radio(
138
+ "Select data source:",
139
+ ["Use Default Dataset", "Upload Your Data"]
140
+ )
141
+
142
+ uploaded_file = None
143
+ if data_source == "Upload Your Data":
144
+ uploaded_file = st.file_uploader(
145
+ "Upload CSV file",
146
+ type=['csv'],
147
+ help="Upload your dataset in CSV format"
148
+ )
149
+
150
+ # 主要內容區
151
+ tab1, tab2, tab3 = st.tabs(["📈 Analysis", "💬 AI Assistant", "📂 Load Model"])
152
+
153
+ # Tab 1: 分析介面
154
+ with tab1:
155
+ col1, col2 = st.columns([2, 1])
156
+
157
+ with col1:
158
+ st.header("Model Configuration")
159
+
160
+ # 載入資料
161
+ if data_source == "Use Default Dataset":
162
+ # 使用預設資料集
163
+ @st.cache_data
164
+ def load_default_data():
165
+ # 這裡放入預設資料集的路徑
166
+ df = pd.read_csv("BC_imputed_micerf_period13_fid_course_D4.csv")
167
+ return df
168
+
169
+ try:
170
+ df = load_default_data()
171
+ st.success(f"✅ Default dataset loaded: {df.shape[0]} rows, {df.shape[1]} columns")
172
+ except:
173
+ st.error("❌ Default dataset not found. Please upload your own data.")
174
+ df = None
175
+ else:
176
+ if uploaded_file:
177
+ df = pd.read_csv(uploaded_file)
178
+ st.success(f"✅ Data loaded: {df.shape[0]} rows, {df.shape[1]} columns")
179
+ else:
180
+ st.info("👆 Please upload a CSV file to begin")
181
+ df = None
182
+
183
+ if df is not None:
184
+ # 特��選擇 - 使用 expander (可摺疊)
185
+ st.subheader("🎯 Input Features")
186
+
187
+ # 手動指定特徵類型 (針對預設乳癌資料集)
188
+ if data_source == "Use Default Dataset":
189
+ # 預設資料集的固定分類
190
+ numeric_cols = ['size', 'stime'] # 只有這兩個是連續變數
191
+ categorical_cols = [col for col in df.columns if col not in numeric_cols]
192
+ else:
193
+ # 上傳資料集才自動判斷
194
+ numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
195
+ categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
196
+
197
+ # 二元分類變數(用於目標變數)
198
+ binary_cols = [col for col in df.columns if df[col].nunique() == 2]
199
+
200
+ col_feat1, col_feat2 = st.columns(2)
201
+
202
+
203
+ with col_feat1:
204
+ with st.expander("**Continuous**", expanded=False):
205
+ st.caption("Select continuous features:")
206
+ con_features = []
207
+ for col in numeric_cols:
208
+ if st.checkbox(col, value=False, key=f"con_{col}"):
209
+ con_features.append(col)
210
+
211
+ with col_feat2:
212
+ with st.expander("**Categorical**", expanded=True):
213
+ st.caption("Select categorical features:")
214
+ cat_features = []
215
+ for col in categorical_cols:
216
+ # 預設勾選前幾個
217
+ default_checked = categorical_cols.index(col) < 5 if len(categorical_cols) > 5 else True
218
+ if st.checkbox(col, value=default_checked, key=f"cat_{col}"):
219
+ cat_features.append(col)
220
+
221
+ # 目標變數 - 放在特徵選擇下方
222
+ st.markdown("---")
223
+
224
+ col_target1, col_target2 = st.columns([1, 2])
225
+ with col_target1:
226
+ target_variable = st.selectbox(
227
+ "Target Variable (Y):",
228
+ options=binary_cols,
229
+ help="Must be a binary classification variable"
230
+ )
231
+
232
+ with col_target2:
233
+ test_fraction = st.number_input(
234
+ "Test Dataset Proportion:",
235
+ min_value=0.10,
236
+ max_value=0.50,
237
+ value=0.25,
238
+ step=0.05,
239
+ format="%.2f"
240
+ )
241
+
242
+ # 驗證選擇
243
+ selected_features = cat_features + con_features
244
+ if target_variable in selected_features:
245
+ st.error("❌ Target variable cannot be in feature list!")
246
+ st.stop()
247
+
248
+ st.markdown("---")
249
+
250
+ # 模型參數 - 使用更緊湊的佈局
251
+ st.subheader("⚙️ Model Configuration")
252
+
253
+ col_param1, col_param2 = st.columns(2)
254
+
255
+ with col_param1:
256
+ algorithm = st.radio(
257
+ "Network Structure:",
258
+ options=['NB', 'TAN', 'CL', 'HC', 'PC'],
259
+ format_func=lambda x: {
260
+ 'NB': 'Naive Bayes (NB)',
261
+ 'TAN': 'Tree-Augmented Naive Bayes (TAN)',
262
+ 'CL': 'Chow-Liu',
263
+ 'HC': 'Hill Climbing',
264
+ 'PC': 'PC'
265
+ }[x],
266
+ help="Select structure learning algorithm"
267
+ )
268
+
269
+ # 條件性參數 - HC
270
+ if algorithm == 'HC':
271
+ score_method = st.selectbox(
272
+ "Scoring Method:",
273
+ options=['BIC', 'AIC', 'K2', 'BDeu', 'BDs'],
274
+ help="Select scoring method for Hill Climbing"
275
+ )
276
+ else:
277
+ score_method = 'BIC'
278
+
279
+ # 條件性參數 - PC
280
+ if algorithm == 'PC':
281
+ sig_level = st.number_input(
282
+ "Significance Level:",
283
+ min_value=0.01,
284
+ max_value=1.0,
285
+ value=0.05,
286
+ step=0.01,
287
+ help="Significance level for PC algorithm"
288
+ )
289
+ else:
290
+ sig_level = 0.05
291
+
292
+ with col_param2:
293
+ estimator = st.radio(
294
+ "Parameter Estimator:",
295
+ options=['ml', 'bn'],
296
+ format_func=lambda x: {
297
+ 'ml': 'MaximumLikelihoodEstimator',
298
+ 'bn': 'BayesianEstimator'
299
+ }[x],
300
+ help="Select parameter estimation method"
301
+ )
302
+
303
+ if estimator == 'bn':
304
+ equivalent_sample_size = st.number_input(
305
+ "Equivalent Sample Size:",
306
+ min_value=1,
307
+ value=3,
308
+ step=1,
309
+ help="Prior strength for Bayesian estimation"
310
+ )
311
+ else:
312
+ equivalent_sample_size = 3
313
+
314
+ # Decision (如果是預設資料集才顯示)
315
+ if data_source == "Use Default Dataset":
316
+ decision = st.selectbox(
317
+ "Decision:",
318
+ options=['OverAll', 'Exposed', 'Unexposed'],
319
+ index=0,
320
+ help="Analysis subset selection"
321
+ )
322
+ else:
323
+ decision = 'OverAll'
324
+
325
+ # Provide Evidence - 可摺疊區域
326
+ st.markdown("---")
327
+ with st.expander("**Provide Evidence**", expanded=False):
328
+ st.caption("Enter evidence values for inference (optional):")
329
+
330
+ evidence_cols = st.columns(2)
331
+ evidence_dict = {}
332
+
333
+ # 為每個非目標變數創建輸入框
334
+ all_vars = [v for v in selected_features if v != target_variable]
335
+
336
+ for idx, var in enumerate(all_vars):
337
+ with evidence_cols[idx % 2]:
338
+ val = st.text_input(
339
+ f"{var}:",
340
+ value="",
341
+ key=f"evidence_{var}",
342
+ help=f"Enter value for {var} (leave empty to ignore)"
343
+ )
344
+ if val.strip():
345
+ evidence_dict[var] = val.strip()
346
+
347
+ # 進階參數 - 摺疊區域
348
+ with st.expander("**Advanced Parameters**", expanded=False):
349
+ n_bins = st.slider(
350
+ "Number of Bins (for continuous variables):",
351
+ min_value=3,
352
+ max_value=20,
353
+ value=10,
354
+ step=1,
355
+ help="Number of bins for discretizing continuous features"
356
+ )
357
+
358
+
359
+ # 執行分析按鈕
360
+ st.markdown("---")
361
+
362
+ col_btn1, col_btn2 = st.columns([3, 1])
363
+
364
+ with col_btn1:
365
+ run_button = st.button("🚀 Run Analysis", type="primary", width='stretch')
366
+
367
+ with col_btn2:
368
+ if st.button("🔄 Reset", width='stretch'):
369
+ st.session_state.analysis_results = None
370
+ st.session_state.trained_model_results = None
371
+ st.session_state.model_trained = False
372
+ st.session_state.chat_history = []
373
+ st.rerun()
374
+
375
+ # 分析步驟說明
376
+ with st.expander("ℹ️ Analysis Steps", expanded=False):
377
+ st.markdown("""
378
+ **Process:**
379
+ 1. Split data (train/test)
380
+ 2. Learn network structure
381
+ 3. Process features (bins from train)
382
+ 4. Estimate parameters
383
+ 5. Evaluate performance
384
+
385
+ **Note:** Test set bins are derived from training set to prevent data leakage.
386
+ """)
387
+
388
+ if run_button:
389
+ # 驗證
390
+ if not selected_features:
391
+ st.error("❌ Please select at least one feature!")
392
+ st.stop()
393
+
394
+ if target_variable in selected_features:
395
+ st.error("❌ Target variable cannot be in feature list!")
396
+ st.stop()
397
+
398
+ with st.spinner("🔄 Training Bayesian Network..."):
399
+ progress_bar = st.progress(0)
400
+ status_text = st.empty()
401
+
402
+ try:
403
+ # 初始化分析器
404
+ status_text.text("📊 Initializing analyzer...")
405
+ progress_bar.progress(10)
406
+
407
+ analyzer = BayesianNetworkAnalyzer(
408
+ session_id=st.session_state.session_id
409
+ )
410
+
411
+ status_text.text(f"📐 Learning {algorithm} structure...")
412
+ progress_bar.progress(30)
413
+
414
+ # 執行分析
415
+ results = analyzer.run_analysis(
416
+ df=df,
417
+ cat_features=cat_features,
418
+ con_features=con_features,
419
+ target_variable=target_variable,
420
+ test_fraction=test_fraction,
421
+ algorithm=algorithm,
422
+ estimator=estimator,
423
+ equivalent_sample_size=equivalent_sample_size,
424
+ score_method=score_method,
425
+ sig_level=sig_level,
426
+ n_bins=n_bins
427
+ )
428
+
429
+ status_text.text("✅ Analysis completed!")
430
+ progress_bar.progress(100)
431
+
432
+ # 儲存結果
433
+ st.session_state.trained_model_results = results # For Tab 1 display
434
+ st.session_state.analysis_results = results # For AI Assistant
435
+ st.session_state.model_trained = True
436
+ # 🆕 儲存 analyzer 到 session_state(用於個人化預測)
437
+ st.session_state.analyzer = analyzer
438
+
439
+ st.success("✅ Analysis completed successfully!")
440
+ st.balloons()
441
+
442
+
443
+ # 清空進度
444
+ import time
445
+ time.sleep(1)
446
+ progress_bar.empty()
447
+ status_text.empty()
448
+
449
+ st.rerun()
450
+
451
+ except Exception as e:
452
+ st.error(f"❌ Error during analysis: {str(e)}")
453
+ st.exception(e)
454
+ progress_bar.empty()
455
+ status_text.empty()
456
+
457
+ with col2:
458
+ st.header("Quick Stats")
459
+
460
+ if df is not None:
461
+ st.metric("Total Samples", df.shape[0])
462
+ st.metric("Total Features", df.shape[1])
463
+ st.metric("Selected Features", len(selected_features) if 'selected_features' in locals() else 0)
464
+
465
+ if st.session_state.model_trained:
466
+ st.success("✅ Model Trained")
467
+ else:
468
+ st.info("⏳ Awaiting Training")
469
+
470
+ # 顯示結果
471
+ if st.session_state.trained_model_results:
472
+ st.markdown("---")
473
+ st.header("📊 Analysis Results")
474
+
475
+ results = st.session_state.trained_model_results
476
+
477
+ # 使用 tabs 來組織結果
478
+ result_tabs = st.tabs([
479
+ "🕸️ Network Structure",
480
+ "📈 Performance Metrics",
481
+ "📋 CPD Tables",
482
+ "📊 Model Scores"
483
+ ])
484
+
485
+ # Tab 1: 網路結構
486
+ with result_tabs[0]:
487
+ network_base64 = generate_network_graph(results['model'])# Pi
488
+ st.image(f"data:image/png;base64,{network_base64}", width='stretch')# Pi
489
+
490
+ # 顯示邊的列表
491
+ with st.expander("View Network Edges", expanded=False):
492
+ edges = list(results['model'].edges())
493
+ st.write(f"Total edges: {len(edges)}")
494
+
495
+ # 每行顯示 3 個邊
496
+ for i in range(0, len(edges), 3):
497
+ cols = st.columns(3)
498
+ for j, col in enumerate(cols):
499
+ if i + j < len(edges):
500
+ edge = edges[i + j]
501
+ col.markdown(f"**{edge[0]}** → {edge[1]}")
502
+
503
+ # Tab 2: 效能指標
504
+ with result_tabs[1]:
505
+ # Check if metrics are available
506
+ if 'train_metrics' in results and 'test_metrics' in results:
507
+ col_m1, col_m2 = st.columns(2)
508
+
509
+ with col_m1:
510
+ st.markdown("### Training Set")
511
+ train_metrics = results['train_metrics']
512
+
513
+ # 使用 metrics 卡片
514
+ metric_cols = st.columns(4)
515
+ metric_cols[0].metric("Accuracy", f"{train_metrics['accuracy']:.2f}%")
516
+ metric_cols[1].metric("Precision", f"{train_metrics['precision']:.2f}%")
517
+ metric_cols[2].metric("Recall", f"{train_metrics['recall']:.2f}%")
518
+ metric_cols[3].metric("F1-Score", f"{train_metrics['f1']:.2f}%")
519
+
520
+ metric_cols2 = st.columns(4)
521
+ metric_cols2[0].metric("AUC", f"{train_metrics['auc']:.4f}")
522
+ metric_cols2[1].metric("G-mean", f"{train_metrics['g_mean']:.2f}%")
523
+ metric_cols2[2].metric("P-mean", f"{train_metrics['p_mean']:.2f}%")
524
+ metric_cols2[3].metric("Specificity", f"{train_metrics['specificity']:.2f}%")
525
+
526
+ # 混淆矩陣
527
+ with st.expander("Confusion Matrix", expanded=True):
528
+ conf_fig_train = plot_confusion_matrix(
529
+ train_metrics['confusion_matrix'],
530
+ title="Training Set"
531
+ )
532
+ st.plotly_chart(conf_fig_train, width='stretch')
533
+
534
+ # ROC Curve
535
+ with st.expander("ROC Curve", expanded=False):
536
+ roc_fig_train = plot_roc_curve(
537
+ train_metrics['fpr'],
538
+ train_metrics['tpr'],
539
+ train_metrics['auc'],
540
+ title="Training Set"
541
+ )
542
+ st.plotly_chart(roc_fig_train, width='stretch')
543
+
544
+ with col_m2:
545
+ st.markdown("### Test Set")
546
+ test_metrics = results['test_metrics']
547
+
548
+ metric_cols = st.columns(4)
549
+ metric_cols[0].metric("Accuracy", f"{test_metrics['accuracy']:.2f}%")
550
+ metric_cols[1].metric("Precision", f"{test_metrics['precision']:.2f}%")
551
+ metric_cols[2].metric("Recall", f"{test_metrics['recall']:.2f}%")
552
+ metric_cols[3].metric("F1-Score", f"{test_metrics['f1']:.2f}%")
553
+
554
+ metric_cols2 = st.columns(4)
555
+ metric_cols2[0].metric("AUC", f"{test_metrics['auc']:.4f}")
556
+ metric_cols2[1].metric("G-mean", f"{test_metrics['g_mean']:.2f}%")
557
+ metric_cols2[2].metric("P-mean", f"{test_metrics['p_mean']:.2f}%")
558
+ metric_cols2[3].metric("Specificity", f"{test_metrics['specificity']:.2f}%")
559
+
560
+ # 混淆矩陣
561
+ with st.expander("Confusion Matrix", expanded=True):
562
+ conf_fig_test = plot_confusion_matrix(
563
+ test_metrics['confusion_matrix'],
564
+ title="Test Set"
565
+ )
566
+ st.plotly_chart(conf_fig_test, width='stretch')
567
+
568
+ # ROC Curve
569
+ with st.expander("ROC Curve", expanded=False):
570
+ roc_fig_test = plot_roc_curve(
571
+ test_metrics['fpr'],
572
+ test_metrics['tpr'],
573
+ test_metrics['auc'],
574
+ title="Test Set"
575
+ )
576
+ st.plotly_chart(roc_fig_test, width='stretch')
577
+
578
+ # Tab 3: 條件機率表
579
+ with result_tabs[2]:
580
+ selected_node = st.selectbox(
581
+ "Select a node to view its CPD:",
582
+ options=list(results['cpds'].keys())
583
+ )
584
+
585
+ if selected_node:
586
+ cpd_df = create_cpd_table(results['cpds'][selected_node])
587
+ st.dataframe(cpd_df, width='stretch')
588
+
589
+ # 下載按鈕
590
+ csv = cpd_df.to_csv()
591
+ st.download_button(
592
+ label="📥 Download CPD as CSV",
593
+ data=csv,
594
+ file_name=f"cpd_{selected_node}.csv",
595
+ mime="text/csv"
596
+ )
597
+
598
+ # Tab 4: 模型評分
599
+ with result_tabs[3]:
600
+ scores = results['scores']
601
+
602
+ score_cols = st.columns(5)
603
+ score_cols[0].metric("Log-Likelihood", f"{scores['log_likelihood']:.2f}")
604
+ score_cols[1].metric("BIC Score", f"{scores['bic']:.2f}")
605
+ score_cols[2].metric("K2 Score", f"{scores['k2']:.2f}")
606
+ score_cols[3].metric("BDeu Score", f"{scores['bdeu']:.2f}")
607
+ score_cols[4].metric("BDs Score", f"{scores['bds']:.2f}")
608
+
609
+ # 參數摘要
610
+ with st.expander("Analysis Parameters", expanded=True):
611
+ params = results['parameters']
612
+
613
+ col1, col2, col3 = st.columns(3)
614
+
615
+ with col1:
616
+ st.markdown("**Algorithm Settings**")
617
+ st.write(f"- Algorithm: {params['algorithm']}")
618
+ st.write(f"- Estimator: {params['estimator']}")
619
+ st.write(f"- Test Fraction: {params['test_fraction']:.2%}")
620
+
621
+ with col2:
622
+ st.markdown("**Feature Information**")
623
+ st.write(f"- Total Features: {params['n_features']}")
624
+ st.write(f"- Categorical: {len(params['cat_features'])}")
625
+ st.write(f"- Continuous: {len(params['con_features'])}")
626
+ st.write(f"- Target: {params['target_variable']}")
627
+
628
+ with col3:
629
+ st.markdown("**Other Parameters**")
630
+ st.write(f"- Bins: {params['n_bins']}")
631
+ st.write(f"- Score Method: {params['score_method']}")
632
+ st.write(f"- Significance Level: {params['sig_level']}")
633
+ st.write(f"- Equivalent Sample Size: {params['equivalent_sample_size']}")
634
+
635
+ # 匯出結果
636
+ with st.expander("Export Results", expanded=False):
637
+ col1, col2 = st.columns(2)
638
+
639
+ with col1:
640
+ # 原本的 JSON 下載
641
+ result_json = export_results_to_json(results)
642
+ st.download_button(
643
+ label="📥 Download Full Results (JSON)",
644
+ data=result_json,
645
+ file_name=f"bn_analysis_{results['timestamp'][:10]}.json",
646
+ mime="application/json"
647
+ )
648
+
649
+ with col2:
650
+ # 🆕 新增:下載模型
651
+ if st.button("💾 Save Trained Model"):
652
+ if 'analyzer' in st.session_state:
653
+ import tempfile
654
+ import os
655
+
656
+ # 創建臨時文件
657
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.pkl') as tmp_file:
658
+ model_path = tmp_file.name
659
+ st.session_state.analyzer.save_model(model_path)
660
+
661
+ # 讀取並提供下載
662
+ with open(model_path, 'rb') as f:
663
+ st.download_button(
664
+ label="📥 Download Model File (.pkl)",
665
+ data=f,
666
+ file_name=f"bn_model_{results['timestamp'][:10]}.pkl",
667
+ mime="application/octet-stream",
668
+ key="download_model_btn"
669
+ )
670
+
671
+ # 清理臨時文件
672
+ os.unlink(model_path)
673
+ else:
674
+ st.error("❌ Analyzer not found in session state")
675
+
676
+
677
+ # Tab 2: AI 助手
678
+ with tab2:
679
+ st.header("💬 AI Analysis Assistant")
680
+
681
+ if not st.session_state.get('api_key'):
682
+ st.warning("⚠️ Please enter your OpenAI API Key in the sidebar to use the AI assistant.")
683
+ elif not st.session_state.model_trained:
684
+ st.info("ℹ️ Please train a model first in the Analysis tab to use the AI assistant.")
685
+ else:
686
+ # 初始化 LLM 助手
687
+ if 'llm_assistant' not in st.session_state:
688
+ st.session_state.llm_assistant = LLMAssistant(
689
+ api_key=st.session_state.api_key,
690
+ session_id=st.session_state.session_id
691
+ )
692
+
693
+ # 顯示聊天歷史
694
+ chat_container = st.container()
695
+
696
+ with chat_container:
697
+ for message in st.session_state.chat_history:
698
+ with st.chat_message(message["role"]):
699
+ st.markdown(message["content"])
700
+
701
+ # 聊天輸入
702
+ if prompt := st.chat_input("Ask me anything about your analysis results..."):
703
+ # 添加用戶訊息
704
+ st.session_state.chat_history.append({
705
+ "role": "user",
706
+ "content": prompt
707
+ })
708
+
709
+ with st.chat_message("user"):
710
+ st.markdown(prompt)
711
+
712
+ # 🆕 檢測是否為個人化預測請求
713
+ prediction_keywords = ['predict', 'risk', 'patient', 'case', 'my risk', 'calculate', 'probability', 'chance']
714
+ is_prediction_request = any(keyword in prompt.lower() for keyword in prediction_keywords)
715
+
716
+ # 獲取 AI 回應
717
+ with st.chat_message("assistant"):
718
+ with st.spinner("Analyzing..." if is_prediction_request else "Thinking..."):
719
+ try:
720
+ if is_prediction_request:
721
+ # 🆕 執行個人化預測
722
+ # 從 session_state 取得必要資訊
723
+ results = st.session_state.analysis_results
724
+
725
+ # 重建 analyzer(需要載入模型狀態)
726
+ # ⚠️ 這裡需要先把 analyzer 存在 session_state 中
727
+ if 'analyzer' not in st.session_state:
728
+ st.error("❌ Model not found. Please train a model first in the Analysis tab.")
729
+ response = "I cannot perform predictions because the model is not available. Please train a model first."
730
+ else:
731
+ response = st.session_state.llm_assistant.predict_from_text(
732
+ user_description=prompt,
733
+ analyzer=st.session_state.analyzer,
734
+ target_variable=results['parameters']['target_variable'],
735
+ feature_list=results['parameters']['cat_features'] + results['parameters']['con_features']
736
+ )
737
+ else:
738
+ # 原本的一般對話
739
+ response = st.session_state.llm_assistant.get_response(
740
+ user_message=prompt,
741
+ analysis_results=st.session_state.analysis_results
742
+ )
743
+
744
+ st.markdown(response)
745
+
746
+ except Exception as e:
747
+ error_msg = f"❌ Error: {str(e)}\n\nPlease try rephrasing your question or check the model status."
748
+ st.error(error_msg)
749
+ response = error_msg
750
+
751
+ # 添加助手訊息
752
+ st.session_state.chat_history.append({
753
+ "role": "assistant",
754
+ "content": response
755
+ })
756
+
757
+ # 快速問題按鈕
758
+ st.markdown("---")
759
+ st.subheader("💡 Quick Questions")
760
+
761
+ quick_questions = [
762
+ "📊 Give me a summary of the analysis results",
763
+ "🎯 What is the model's performance?",
764
+ "🔍 Explain the Bayesian Network structure",
765
+ "⚠️ What are the limitations of this model?",
766
+ "💡 How can I improve the model?"
767
+ ]
768
+
769
+ cols = st.columns(len(quick_questions))
770
+ for idx, (col, question) in enumerate(zip(cols, quick_questions)):
771
+ if col.button(question, key=f"quick_{idx}"):
772
+ st.session_state.chat_history.append({
773
+ "role": "user",
774
+ "content": question
775
+ })
776
+
777
+ response = st.session_state.llm_assistant.get_response(
778
+ user_message=question,
779
+ analysis_results=st.session_state.analysis_results
780
+ )
781
+
782
+ st.session_state.chat_history.append({
783
+ "role": "assistant",
784
+ "content": response
785
+ })
786
+
787
+ st.rerun()
788
+
789
+ # Tab 3: Load Model
790
+ with tab3:
791
+ st.header("📂 Load Pre-trained Models")
792
+
793
+ st.markdown("""
794
+ Load previously trained Bayesian Network models to view and compare their structures.
795
+
796
+ **Maximum: 2 models**
797
+
798
+ **Supported formats:**
799
+ - 📦 `.pkl` - Full model with all parameters
800
+ """)
801
+
802
+ st.markdown("---")
803
+
804
+ # Check if already loaded 2 models
805
+ if len(st.session_state.loaded_models) >= 2:
806
+ st.warning("⚠️ Maximum 2 models can be loaded. Please remove a model before loading another.")
807
+ uploaded_model = None
808
+ else:
809
+ # File uploader
810
+ uploaded_model = st.file_uploader(
811
+ "Upload model file",
812
+ type=['pkl', 'bif'],
813
+ help="Upload a .pkl file containing a Bayesian Network model"
814
+ )
815
+
816
+ if uploaded_model:
817
+ file_extension = uploaded_model.name.split('.')[-1].lower()
818
+
819
+ col_load1, col_load2 = st.columns([3, 1])
820
+
821
+ with col_load1:
822
+ st.info(f"📄 File: **{uploaded_model.name}** ({file_extension.upper()} format)")
823
+
824
+ with col_load2:
825
+ load_button = st.button("🔄 Load Model", type="primary", width='stretch')
826
+
827
+ if load_button:
828
+ with st.spinner(f"Loading {file_extension.upper()} model..."):
829
+ try:
830
+ if file_extension == 'pkl':
831
+ # Load .pkl file
832
+ import pickle
833
+ import tempfile
834
+ import os
835
+
836
+ # Save uploaded file to temp location
837
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.pkl') as tmp_file:
838
+ tmp_file.write(uploaded_model.read())
839
+ tmp_path = tmp_file.name
840
+
841
+ # Load model data
842
+ with open(tmp_path, 'rb') as f:
843
+ model_data = pickle.load(f)
844
+
845
+ # Clean up temp file
846
+ os.unlink(tmp_path)
847
+
848
+ # Extract model info - handle multiple formats
849
+ from pgmpy.models import BayesianNetwork
850
+
851
+ if isinstance(model_data, BayesianNetwork):
852
+ # Case 1: Direct BayesianNetwork object
853
+ model = model_data
854
+ bins_dict = None
855
+ train_columns = list(model.nodes())
856
+ timestamp = 'Unknown'
857
+ st.info("ℹ️ Loaded raw BayesianNetwork object (no metadata)")
858
+
859
+ elif isinstance(model_data, dict):
860
+ # Case 2: Dictionary format
861
+ if 'model' in model_data:
862
+ # Case 2a: Our format or similar
863
+ model = model_data['model']
864
+ bins_dict = model_data.get('bins_dict', None)
865
+ train_columns = model_data.get('train_columns', list(model.nodes()))
866
+ timestamp = model_data.get('timestamp', 'Unknown')
867
+ else:
868
+ # Case 2b: Try to find model in other common keys
869
+ possible_keys = ['network', 'bn', 'bayesian_network', 'graph']
870
+ model = None
871
+ found_key = None
872
+ for key in possible_keys:
873
+ if key in model_data and isinstance(model_data[key], BayesianNetwork):
874
+ model = model_data[key]
875
+ found_key = key
876
+ break
877
+
878
+ if model is None:
879
+ raise ValueError(f"Cannot find BayesianNetwork in pickle file. Available keys: {list(model_data.keys())}. Expected one of: {['model'] + possible_keys}")
880
+
881
+ bins_dict = model_data.get('bins_dict', None)
882
+ train_columns = list(model.nodes())
883
+ timestamp = 'Unknown'
884
+ st.info(f"ℹ️ Loaded model from key: '{found_key}'")
885
+ else:
886
+ raise TypeError(f"Unsupported pickle format. Expected BayesianNetwork or dict, got {type(model_data).__name__}")
887
+
888
+ # Store in session state - append to list (max 2)
889
+ if len(st.session_state.loaded_models) < 2:
890
+ model_info = {
891
+ 'model': model,
892
+ 'source': 'pkl',
893
+ 'bins_dict': bins_dict,
894
+ 'train_columns': train_columns,
895
+ 'timestamp': timestamp,
896
+ 'file_name': uploaded_model.name
897
+ }
898
+ st.session_state.loaded_models.append(model_info)
899
+
900
+ st.success(f"✅ Model #{len(st.session_state.loaded_models)} loaded successfully from .pkl file!")
901
+ st.info("ℹ️ This loaded model is displayed below. To use AI Assistant, please train a model in the Analysis tab.")
902
+ st.balloons()
903
+ else:
904
+ st.error("❌ Cannot load more than 2 models. Please remove a model first.")
905
+
906
+ elif file_extension == 'bif':
907
+ # Load .bif file
908
+ from pgmpy.readwrite import BIFReader
909
+ import tempfile
910
+ import os
911
+
912
+ # Save uploaded file to temp location
913
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.bif', mode='w') as tmp_file:
914
+ tmp_file.write(uploaded_model.read().decode('utf-8'))
915
+ tmp_path = tmp_file.name
916
+
917
+ # Load model
918
+ reader = BIFReader(tmp_path)
919
+ model = reader.get_model()
920
+
921
+ # Clean up temp file
922
+ os.unlink(tmp_path)
923
+
924
+ # Store in session state - append to list (max 2)
925
+ if len(st.session_state.loaded_models) < 2:
926
+ model_info = {
927
+ 'model': model,
928
+ 'source': 'bif',
929
+ 'bins_dict': None,
930
+ 'train_columns': list(model.nodes()),
931
+ 'timestamp': 'Unknown',
932
+ 'file_name': uploaded_model.name
933
+ }
934
+ st.session_state.loaded_models.append(model_info)
935
+
936
+ st.success(f"✅ Model #{len(st.session_state.loaded_models)} loaded successfully from .bif file!")
937
+ st.warning("⚠️ Note: .bif files do not contain bins_dict.")
938
+ st.info("ℹ️ This loaded model is displayed below. To use AI Assistant, please train a model in the Analysis tab.")
939
+ st.balloons()
940
+ else:
941
+ st.error("❌ Cannot load more than 2 models. Please remove a model first.")
942
+
943
+ except Exception as e:
944
+ st.error(f"❌ Error loading model: {str(e)}")
945
+ st.exception(e)
946
+
947
+ # Display loaded models information
948
+ if st.session_state.loaded_models:
949
+ st.markdown("---")
950
+
951
+ # Header with Clear All button
952
+ col_header, col_clear = st.columns([3, 1])
953
+ with col_header:
954
+ st.header(f"📊 Loaded Models ({len(st.session_state.loaded_models)})")
955
+ with col_clear:
956
+ if st.button("🗑️ Clear All", type="secondary", width='stretch'):
957
+ st.session_state.loaded_models = []
958
+ st.rerun()
959
+
960
+ # Loop through all loaded models
961
+ for idx, loaded_model in enumerate(st.session_state.loaded_models):
962
+ model = loaded_model['model']
963
+
964
+ # Model separator
965
+ st.markdown("---")
966
+
967
+ # Model header with Remove button
968
+ col_title, col_remove = st.columns([4, 1])
969
+ with col_title:
970
+ st.subheader(f"Model #{idx + 1}: {loaded_model['file_name']}")
971
+ with col_remove:
972
+ if st.button(f"❌ Remove", key=f"remove_model_{idx}", width='stretch'):
973
+ st.session_state.loaded_models.pop(idx)
974
+ st.rerun()
975
+
976
+ # Display network graph and basic info
977
+ col_graph, col_info = st.columns([2, 1])
978
+
979
+ with col_graph:
980
+ st.markdown("**🕸️ Network Structure**")
981
+ try:
982
+ network_base64 = generate_network_graph(model)
983
+ st.image(f"data:image/png;base64,{network_base64}", width='stretch')
984
+ except Exception as e:
985
+ st.error(f"Error generating network graph: {str(e)}")
986
+ st.info("Network structure visualization is not available.")
987
+
988
+ with col_info:
989
+ st.markdown("**ℹ️ Basic Information**")
990
+ st.metric("File Name", loaded_model['file_name'])
991
+ st.metric("Format", loaded_model['source'].upper())
992
+ st.metric("Total Nodes", len(model.nodes()))
993
+ st.metric("Total Edges", len(model.edges()))
994
+
995
+ if loaded_model['timestamp'] != 'Unknown':
996
+ st.metric("Timestamp", loaded_model['timestamp'][:19])
997
+
998
+ if loaded_model['bins_dict']:
999
+ st.metric("Bins Available", "✅ Yes")
1000
+ else:
1001
+ st.metric("Bins Available", "❌ No")
1002
+
1003
+ # Network structure details
1004
+ col_nodes, col_edges = st.columns(2)
1005
+
1006
+ with col_nodes:
1007
+ with st.expander("📋 Node List", expanded=False):
1008
+ nodes = list(model.nodes())
1009
+ st.write(f"**Total nodes:** {len(nodes)}")
1010
+ for i, node in enumerate(nodes, 1):
1011
+ st.write(f"{i}. {node}")
1012
+
1013
+ with col_edges:
1014
+ with st.expander("🔗 Edge List", expanded=False):
1015
+ edges = list(model.edges())
1016
+ st.write(f"**Total edges:** {len(edges)}")
1017
+ for i, edge in enumerate(edges, 1):
1018
+ st.write(f"{i}. **{edge[0]}** → {edge[1]}")
1019
+
1020
+ # CPD Tables
1021
+ st.markdown("**📋 Conditional Probability Distributions (CPDs)**")
1022
+
1023
+ selected_node = st.selectbox(
1024
+ "Select a node to view its CPD:",
1025
+ options=list(model.nodes()),
1026
+ key=f"load_model_cpd_select_{idx}"
1027
+ )
1028
+
1029
+ if selected_node:
1030
+ cpd = model.get_cpds(selected_node)
1031
+ cpd_df = create_cpd_table(cpd)
1032
+ st.dataframe(cpd_df, width='stretch')
1033
+
1034
+ # Download button
1035
+ csv = cpd_df.to_csv()
1036
+ st.download_button(
1037
+ label="📥 Download CPD as CSV",
1038
+ data=csv,
1039
+ file_name=f"cpd_{selected_node}_model{idx+1}.csv",
1040
+ mime="text/csv",
1041
+ key=f"load_model_cpd_download_{idx}"
1042
+ )
1043
+
1044
+ # Additional information for .pkl files
1045
+ if loaded_model['source'] == 'pkl' and loaded_model['bins_dict']:
1046
+ with st.expander("🔢 Binning Information", expanded=False):
1047
+ st.write("**Bins dictionary available for continuous variables:**")
1048
+ st.json(loaded_model['bins_dict'])
1049
+
1050
+ # Footer
1051
+ st.markdown("---")
1052
+ st.markdown(
1053
+ """
1054
+ <div style='text-align: center'>
1055
+ <p>🔬 Bayesian Network Analysis System | Built with Streamlit</p>
1056
+ <p>Powered by OpenAI GPT-4 | Session ID: {}</p>
1057
+ </div>
1058
+ """.format(st.session_state.session_id[:8]),
1059
+ unsafe_allow_html=True
1060
+ )
bn_core.py ADDED
@@ -0,0 +1,536 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ from pgmpy.models import BayesianNetwork
4
+ from pgmpy.estimators import (
5
+ TreeSearch, HillClimbSearch, PC,
6
+ MaximumLikelihoodEstimator, BayesianEstimator,
7
+ BicScore, AICScore, K2Score, BDeuScore, BDsScore
8
+ )
9
+ from pgmpy.inference import VariableElimination
10
+ from sklearn.model_selection import train_test_split
11
+ from sklearn.metrics import (
12
+ confusion_matrix, accuracy_score, precision_score,
13
+ recall_score, f1_score, roc_curve, roc_auc_score
14
+ )
15
+ from pgmpy.metrics import log_likelihood_score, structure_score
16
+ import threading
17
+ from datetime import datetime
18
+ from networkx import is_directed_acyclic_graph, DiGraph
19
+
20
class BayesianNetworkAnalyzer:
    """
    Bayesian Network analyzer.

    Learns a network structure, fits its parameters, evaluates the model on a
    train/test split and answers single-instance queries. The results of each
    run are also stored per ``session_id`` in a class-level dictionary.
    """

    # Class-level lock shared by every instance. ``run_analysis`` holds it for
    # the whole run, so analyses started from different sessions run one at a
    # time.
    _lock = threading.Lock()

    # Analysis results keyed by session id (shared by all instances).
    _session_results = {}

    def __init__(self, session_id):
        """
        Create an analyzer bound to one session.

        Args:
            session_id: Unique identifier of the session.
        """
        self.session_id = session_id
        self.model = None
        self.inference = None
        self.train_data = None
        self.test_data = None
        # {continuous column: bin edges computed on the training set}
        self.bins_dict = {}
        # {categorical column: {original label: integer code}}
        self.category_maps = {}

    def run_analysis(self, df, cat_features, con_features, target_variable,
                     test_fraction=0.25, algorithm='NB', estimator='ml',
                     equivalent_sample_size=3, score_method='BIC',
                     sig_level=0.05, n_bins=10):
        """
        Run the complete Bayesian Network analysis.

        The step order (structure learning first, then categorical encoding,
        then binning) is deliberate and must be kept: the original comments
        state it mirrors a reference Django implementation.

        Args:
            df: Raw data frame.
            cat_features: List of categorical feature names.
            con_features: List of continuous feature names.
            target_variable: Name of the target column.
            test_fraction: Fraction of rows used as the test set.
            algorithm: Structure learning algorithm
                ('NB', 'TAN', 'CL', 'HC' or 'PC').
            estimator: Parameter estimator; 'bn' selects BayesianEstimator,
                anything else MaximumLikelihoodEstimator.
            equivalent_sample_size: Equivalent sample size (Bayesian estimator).
            score_method: Scoring method (Hill Climbing only).
            sig_level: Significance level (PC algorithm only).
            n_bins: Number of bins for continuous features.

        Returns:
            dict: Model, inference engine, metrics, CPDs, scores, parameters
            and an ISO timestamp.

        Raises:
            Exception: Any failure, re-raised as "Analysis failed: ..." with
                the original exception attached as ``__cause__``.
        """
        with self._lock:
            try:
                # 1. Preprocessing: select columns and drop rows with NaNs.
                processed_df = self._preprocess_data(
                    df, cat_features, con_features, target_variable
                )

                # 2. Train/test split (fixed random_state=526 for reproducibility).
                self.train_data, self.test_data = train_test_split(
                    processed_df,
                    test_size=test_fraction,
                    random_state=526,
                    stratify=processed_df[target_variable] if target_variable in processed_df.columns else None
                )

                # 3. Structure learning (before binning and encoding).
                self.model = self._learn_structure(
                    algorithm, score_method, sig_level, target_variable
                )

                # 4. Encode categorical features (after structure learning).
                self._encode_categorical_features(cat_features)

                # 5. Bin continuous features (after encoding).
                self._bin_continuous_features(con_features, n_bins)

                # 6. Parameter estimation.
                self._fit_parameters(estimator, equivalent_sample_size)

                # 7. Inference engine.
                self.inference = VariableElimination(self.model)

                # 8. Evaluation.
                train_metrics = self._evaluate_model(
                    self.train_data, target_variable, "train"
                )
                test_metrics = self._evaluate_model(
                    self.test_data, target_variable, "test"
                )

                # 9. CPDs.
                cpds = self._get_all_cpds()

                # 10. Model scores.
                scores = self._calculate_scores()

                # 11. Assemble the result dictionary.
                results = {
                    'model': self.model,
                    'inference': self.inference,
                    'train_metrics': train_metrics,
                    'test_metrics': test_metrics,
                    'cpds': cpds,
                    'scores': scores,
                    'parameters': {
                        'algorithm': algorithm,
                        'estimator': estimator,
                        'test_fraction': test_fraction,
                        'n_features': len(cat_features) + len(con_features),
                        'cat_features': cat_features,
                        'con_features': con_features,
                        'target_variable': target_variable,
                        'n_bins': n_bins,
                        'score_method': score_method,
                        'sig_level': sig_level,
                        'equivalent_sample_size': equivalent_sample_size
                    },
                    'timestamp': datetime.now().isoformat()
                }

                # Remember the results for this session.
                self._session_results[self.session_id] = results

                return results

            except Exception as e:
                # Same exception type and message as before; ``from e`` keeps
                # the original traceback explicitly attached.
                raise Exception(f"Analysis failed: {str(e)}") from e

    def _preprocess_data(self, df, cat_features, con_features, target_variable):
        """
        Select the needed columns and drop rows containing missing values.

        Column names are de-duplicated (order preserved) so that a target that
        is also listed as a feature does not yield duplicate columns.
        """
        selected_columns = list(dict.fromkeys(
            [*cat_features, *con_features, target_variable]
        ))
        processed_df = df[selected_columns].copy()
        return processed_df.dropna()

    def _encode_categorical_features(self, cat_features):
        """
        Replace object-dtype categorical columns by integer category codes.

        The categories are derived from the training set only and the same
        mapping is applied to the test set, so a given label always gets the
        same code in both sets. Test labels never seen in training get code
        -1. The label -> code mapping is kept in ``self.category_maps``.
        """
        self.category_maps = {}

        for col in cat_features:
            if col not in self.train_data.columns:
                continue
            if self.train_data[col].dtype != 'object':
                continue

            train_as_category = self.train_data[col].astype('category')
            categories = train_as_category.cat.categories
            self.category_maps[col] = {
                label: code for code, label in enumerate(categories)
            }
            self.train_data[col] = train_as_category.cat.codes

            if col in self.test_data.columns:
                shared_dtype = pd.CategoricalDtype(categories=categories)
                self.test_data[col] = (
                    self.test_data[col].astype(shared_dtype).cat.codes
                )

    @staticmethod
    def _bin_label(lower, upper):
        """Return the label of the bin (lower, upper], edges rounded to 2 decimals."""
        return f"{round(lower, 2)}–{round(upper, 2)}"

    def _bin_continuous_features(self, con_features, n_bins):
        """
        Discretise continuous features.

        Bin edges are computed on the training set and then applied unchanged
        to the test set. Values that fall outside the edges become "Missing".
        """
        self.bins_dict = {}

        for col in con_features:
            if col in self.train_data.columns and self.train_data[col].notna().sum() > 0:
                # Edges come from the training set only.
                bin_edges = pd.cut(
                    self.train_data[col],
                    bins=n_bins,
                    retbins=True,
                    duplicates='drop'
                )[1]

                self.bins_dict[col] = bin_edges

                # Labels use an en dash ("–"), not a hyphen.
                bin_labels = [
                    self._bin_label(bin_edges[i], bin_edges[i + 1])
                    for i in range(len(bin_edges) - 1)
                ]

                # ordered=False: rounding to 2 decimals can yield identical
                # labels for narrow bins, which pd.cut rejects for ordered
                # categoricals. The result is cast to object right away, so
                # the ordering flag has no other effect.
                self.train_data[col] = pd.cut(
                    self.train_data[col],
                    bins=bin_edges,
                    labels=bin_labels,
                    include_lowest=True,
                    ordered=False
                ).astype(object).fillna("Missing")

                # Same edges for the test set.
                if col in self.test_data.columns:
                    self.test_data[col] = pd.cut(
                        self.test_data[col],
                        bins=bin_edges,
                        labels=bin_labels,
                        include_lowest=True,
                        ordered=False
                    ).astype(object).fillna("Missing")
            else:
                print(f"⚠️ Skipped binning column '{col}' – missing or all NaN")

    def _naive_bayes_model(self, target_variable):
        """Build a Naive Bayes network: target -> every other training column."""
        edges = [
            (target_variable, feature)
            for feature in self.train_data.columns
            if feature != target_variable
        ]
        return BayesianNetwork(edges)

    def _learn_structure(self, algorithm, score_method, sig_level, target_variable):
        """
        Learn the network structure from the training data.

        Raises:
            ValueError: Unknown ``algorithm`` or (for 'HC') ``score_method``.
        """
        if algorithm == 'NB':
            # Naive Bayes
            model = self._naive_bayes_model(target_variable)

        elif algorithm == 'TAN':
            # Tree-Augmented Naive Bayes.
            # Special case: when both 'asia' and 'either' columns exist and the
            # target is 'either', 'asia' is forced as the root node.
            if 'asia' in self.train_data.columns and 'either' in self.train_data.columns and target_variable == 'either':
                tan_search = TreeSearch(self.train_data, root_node='asia')
            else:
                tan_search = TreeSearch(self.train_data)

            structure = tan_search.estimate(
                estimator_type='tan',
                class_node=target_variable
            )
            model = BayesianNetwork(structure.edges())

        elif algorithm == 'CL':
            # Chow-Liu
            tan_search = TreeSearch(self.train_data)
            structure = tan_search.estimate(
                estimator_type='chow-liu',
                class_node=target_variable
            )
            model = BayesianNetwork(structure.edges())

        elif algorithm == 'HC':
            # Hill Climbing. Only the requested score object is built.
            score_classes = {
                'BIC': BicScore,
                'AIC': AICScore,
                'K2': K2Score,
                'BDeu': BDeuScore,
                'BDs': BDsScore
            }
            if score_method not in score_classes:
                raise ValueError(f"Unknown score method: {score_method}")

            hc = HillClimbSearch(self.train_data)
            structure = hc.estimate(
                scoring_method=score_classes[score_method](self.train_data)
            )
            model = BayesianNetwork(structure.edges())

        elif algorithm == 'PC':
            # PC algorithm with a fallback strategy: retry with a smaller
            # max_cond_vars until a usable structure is found.
            pc = PC(self.train_data)

            for max_cond in [5, 4, 3, 2, 1]:
                try:
                    structure = pc.estimate(
                        significance_level=sig_level,
                        max_cond_vars=max_cond,
                        ci_test='chi_square',
                        variant='stable',
                        n_jobs=1
                    )

                    # Usable = a DAG in which the target takes part in an edge.
                    edges = structure.edges()
                    if is_directed_acyclic_graph(DiGraph(edges)) and any(target_variable in edge for edge in edges):
                        model = BayesianNetwork(structure.edges())
                        break
                except Exception:
                    # Best effort: try the next, smaller conditioning set.
                    # (Not a bare ``except`` so KeyboardInterrupt/SystemExit
                    # are no longer swallowed.)
                    continue
            else:
                # Every attempt failed: fall back to Naive Bayes.
                model = self._naive_bayes_model(target_variable)

        else:
            raise ValueError(f"Unknown algorithm: {algorithm}")

        return model

    def _fit_parameters(self, estimator, equivalent_sample_size):
        """Fit the CPDs; 'bn' uses BayesianEstimator, otherwise maximum likelihood."""
        if estimator == 'bn':
            self.model.fit(
                self.train_data,
                estimator=BayesianEstimator,
                equivalent_sample_size=equivalent_sample_size
            )
        else:
            self.model.fit(
                self.train_data,
                estimator=MaximumLikelihoodEstimator
            )

    def _predict_probabilities(self, data, target_variable):
        """
        Query the target distribution for every row of ``data``.

        Rows whose inference fails, or whose result has fewer than two states,
        are dropped.

        Returns:
            tuple: (list of true labels, array of the probability stored at
            index 1 of each result, rounded to 4 decimals). ``([], [])`` when
            no row could be evaluated.
        """
        true_labels = []
        predicted_probs = []

        model_nodes = set(self.model.nodes())

        for idx, row in data.iterrows():
            # Evidence: every non-target column that is a node of the model.
            raw_evidence = row.drop(target_variable).to_dict()
            filtered_evidence = {k: v for k, v in raw_evidence.items() if k in model_nodes}

            true_labels.append(row[target_variable])

            try:
                result = self.inference.query(
                    variables=[target_variable],
                    evidence=filtered_evidence
                )
                predicted_probs.append(result.values)
            except Exception as e:
                print(f"⚠️ Inference failed at row {idx} | evidence keys: {list(filtered_evidence.keys())} | error: {e}")
                predicted_probs.append(None)

        # Keep only rows with a usable result.
        valid_data = [
            (label, prob)
            for label, prob in zip(true_labels, predicted_probs)
            if prob is not None and len(prob) > 1
        ]

        if not valid_data:
            return [], []

        valid_labels, valid_probs = zip(*valid_data)
        prob_array = np.round(np.array([prob[1] for prob in valid_probs]), 4)

        return list(valid_labels), prob_array

    def _evaluate_model(self, data, target_variable, dataset_name):
        """
        Compute classification metrics for ``data``.

        Predictions use a fixed probability threshold of 0.1. Percent-style
        metrics are scaled to 0-100; ``auc`` is left on its 0-1 scale.

        ``g_mean`` is sqrt(sensitivity * precision) and ``p_mean`` is
        sqrt(specificity * sensitivity).
        NOTE(review): the conventional G-mean is sqrt(sensitivity *
        specificity); the names are kept as they are for compatibility with
        existing consumers - confirm the intended definitions.

        Args:
            data: Data frame to evaluate (already encoded/binned).
            target_variable: Name of the target column.
            dataset_name: Label of the data set ("train"/"test"); unused here.

        Returns:
            dict: Metrics, confusion matrix, ROC points and predicted
            probabilities.
        """
        true_labels, pred_probs = self._predict_probabilities(
            data, target_variable
        )

        if len(true_labels) == 0:
            return {
                'accuracy': 0,
                'precision': 0,
                'recall': 0,
                'f1': 0,
                'auc': 0,
                'g_mean': 0,
                'p_mean': 0,
                'specificity': 0,
                'confusion_matrix': [[0, 0], [0, 0]],
                'fpr': [0],
                'tpr': [0],
                # Same key set as the regular return value.
                'predicted_probs': []
            }

        # Binary prediction with threshold 0.1.
        threshold = 0.1
        pred_labels = (pred_probs >= threshold).astype(int)

        accuracy = accuracy_score(true_labels, pred_labels) * 100
        precision = precision_score(true_labels, pred_labels, zero_division=0) * 100
        recall = recall_score(true_labels, pred_labels, zero_division=0) * 100
        f1 = f1_score(true_labels, pred_labels, zero_division=0) * 100

        # ROC curve / AUC are undefined when only one class is present;
        # roc_auc_score raises in that case, so report NaN instead.
        pred_probs_clean = np.nan_to_num(pred_probs, nan=0.0)
        if len(set(true_labels)) > 1:
            fpr_values, tpr_values, _ = roc_curve(true_labels, pred_probs_clean)
            auc = roc_auc_score(true_labels, pred_probs_clean)
            fpr = fpr_values.tolist()
            tpr = tpr_values.tolist()
        else:
            fpr, tpr, auc = [0], [0], float('nan')

        # labels=[0, 1] guarantees a 2x2 matrix even if a class is absent, so
        # the 4-way unpacking below cannot fail.
        cm_array = confusion_matrix(true_labels, pred_labels, labels=[0, 1])
        tn, fp, fn, tp = cm_array.ravel()
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        g_mean = np.sqrt(sensitivity * precision / 100) * 100
        p_mean = np.sqrt(specificity * sensitivity) * 100

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'auc': auc,
            'g_mean': g_mean,
            'p_mean': p_mean,
            'specificity': specificity * 100,
            'confusion_matrix': cm_array.tolist(),
            'fpr': fpr,
            'tpr': tpr,
            'predicted_probs': pred_probs.tolist()
        }

    def _get_all_cpds(self):
        """Return ``{node: CPD}`` for every node of the model."""
        return {node: self.model.get_cpds(node) for node in self.model.nodes()}

    def _calculate_scores(self):
        """Compute log-likelihood and structure scores on the training data."""
        scores = {
            'log_likelihood': log_likelihood_score(self.model, self.train_data),
            'bic': structure_score(self.model, self.train_data, scoring_method='bic'),
            'k2': structure_score(self.model, self.train_data, scoring_method='k2'),
            'bdeu': structure_score(self.model, self.train_data, scoring_method='bdeu'),
            'bds': structure_score(self.model, self.train_data, scoring_method='bds')
        }
        return scores

    def save_model(self, filepath):
        """
        Pickle the trained model to ``filepath``.

        Stored keys: 'model', 'bins_dict', 'category_maps', 'train_columns'
        and 'timestamp'. When no training data is attached (e.g. the model was
        loaded from disk), 'train_columns' falls back to the model's nodes.
        """
        import pickle

        if self.train_data is not None:
            train_columns = list(self.train_data.columns)
        else:
            train_columns = list(self.model.nodes())

        model_data = {
            'model': self.model,
            'bins_dict': self.bins_dict,
            'category_maps': getattr(self, 'category_maps', None) or {},
            'train_columns': train_columns,
            'timestamp': datetime.now().isoformat()
        }
        with open(filepath, 'wb') as f:
            pickle.dump(model_data, f)

    def load_model(self, filepath):
        """
        Load a model previously written by ``save_model``.

        SECURITY: ``pickle.load`` can execute arbitrary code contained in the
        file; only load files from trusted sources.

        Returns:
            dict: The unpickled model data.
        """
        import pickle

        with open(filepath, 'rb') as f:
            model_data = pickle.load(f)
        self.model = model_data['model']
        # Optional keys: tolerate files that lack them.
        self.bins_dict = model_data.get('bins_dict') or {}
        self.category_maps = model_data.get('category_maps') or {}
        self.inference = VariableElimination(self.model)
        return model_data

    def _value_to_bin_label(self, key, value):
        """
        Map a raw continuous value to the bin label used during training.

        Out-of-range values are clamped to the first/last bin. A value lying
        exactly on an inner edge belongs to the lower bin (right-closed
        intervals, as produced by ``pd.cut``). Returns ``None`` when no bin
        matches (e.g. NaN).
        """
        bins = self.bins_dict[key]
        # Accept numeric strings such as "65".
        value = float(value)

        if value < bins[0]:
            return self._bin_label(bins[0], bins[1])
        if value > bins[-1]:
            return self._bin_label(bins[-2], bins[-1])
        for i in range(len(bins) - 1):
            if bins[i] <= value <= bins[i + 1]:
                return self._bin_label(bins[i], bins[i + 1])
        return None

    def predict_single_instance(self, evidence_dict, target_variable):
        """
        Predict the target distribution for one case.

        Continuous features (keys of ``bins_dict``) are converted to their bin
        label. String values of encoded categorical features are translated to
        the integer code learnt in training; every other value is passed on
        unchanged.

        Args:
            evidence_dict: ``{feature name: raw value}``.
            target_variable: Name of the target node.

        Returns:
            dict: 'probability' (value at index 1 of the result, or index 0 if
            there is only one state), 'risk_level' ("High" >= 0.7,
            "Moderate" >= 0.3, else "Low"), 'all_probs' and
            'processed_evidence'.
        """
        category_maps = getattr(self, 'category_maps', None) or {}

        # 1. Convert raw values into model states.
        processed_evidence = {}
        for key, value in evidence_dict.items():
            if key in self.bins_dict:
                label = self._value_to_bin_label(key, value)
                if label is not None:
                    processed_evidence[key] = label
            elif isinstance(value, str) and value in category_maps.get(key, {}):
                # Only strings are translated: integer inputs are assumed to
                # already be category codes.
                processed_evidence[key] = category_maps[key][value]
            else:
                processed_evidence[key] = value

        # 2. Inference.
        result = self.inference.query(
            variables=[target_variable],
            evidence=processed_evidence
        )

        # 3. Package the result.
        probs = result.values
        death_prob = probs[1] if len(probs) > 1 else probs[0]

        if death_prob >= 0.7:
            risk_level = "High"
        elif death_prob >= 0.3:
            risk_level = "Moderate"
        else:
            risk_level = "Low"

        return {
            'probability': float(death_prob),
            'risk_level': risk_level,
            'all_probs': {i: float(p) for i, p in enumerate(probs)},
            'processed_evidence': processed_evidence
        }

    @classmethod
    def get_session_results(cls, session_id):
        """Return the stored results of ``session_id`` (``None`` if absent)."""
        return cls._session_results.get(session_id)

    @classmethod
    def clear_session_results(cls, session_id):
        """Remove the stored results of ``session_id`` if there are any."""
        cls._session_results.pop(session_id, None)
llm_assistant.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openai import OpenAI
2
+ import json
3
+ import numpy as np
4
+
5
+
6
class LLMAssistant:
    """
    LLM question-answering assistant.

    Helps users understand the results of a Bayesian Network analysis and can
    turn a free-text case description into a model prediction.
    """

    # Chat model used for every request.
    MODEL_NAME = "gpt-4o-mini"

    def __init__(self, api_key, session_id):
        """
        Initialise the assistant.

        Args:
            api_key: OpenAI API key.
            session_id: Unique identifier of the session.
        """
        self.client = OpenAI(api_key=api_key)
        self.session_id = session_id
        self.conversation_history = []

        # System prompt sent with every chat request.
        self.system_prompt = """You are an expert data scientist specializing in Bayesian Networks and machine learning.
Your role is to help users understand their Bayesian Network analysis results.

You should:
1. Explain complex statistical concepts in simple terms
2. Provide insights about model performance metrics
3. Suggest improvements when asked
4. Explain the structure and relationships in the Bayesian Network
5. Help interpret conditional probability tables (CPTs)
6. Discuss limitations and assumptions of the model
7. Perform personalized risk predictions from patient descriptions
8. Provide empathetic, evidence-based interpretations of risk levels

When performing predictions:
- Extract relevant medical features from natural language descriptions
- Clearly communicate risk levels (High/Moderate/Low) with probabilities
- Explain key risk factors in understandable terms
- Always emphasize limitations and the need for professional medical consultation

Always be clear, concise, and educational. Use examples when helpful.
Format your responses with proper markdown for better readability."""

    def get_response(self, user_message, analysis_results):
        """
        Send a user message to the model and return its reply.

        The exchange is appended to ``conversation_history`` only when the API
        call succeeds; on failure the history is left as it was, so a retry
        does not send an orphaned user message.

        Args:
            user_message: The user's message.
            analysis_results: Analysis result dictionary (may be empty/None).

        Returns:
            str: The assistant reply, or an error message starting with "❌".
        """
        # Context describing the current analysis.
        context = self._prepare_context(analysis_results)

        self.conversation_history.append({
            "role": "user",
            "content": user_message
        })

        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "system", "content": f"Analysis Context:\n{context}"}
        ] + self.conversation_history

        try:
            response = self.client.chat.completions.create(
                model=self.MODEL_NAME,
                messages=messages,
                temperature=0.7,
                max_tokens=1500
            )
            assistant_message = response.choices[0].message.content
        except Exception as e:
            # Roll back the user message that got no answer.
            self.conversation_history.pop()
            return f"❌ Error: {str(e)}\n\nPlease check your API key and try again."

        self.conversation_history.append({
            "role": "assistant",
            "content": assistant_message
        })

        return assistant_message

    def _prepare_context(self, results):
        """Render the analysis results as a markdown context block."""
        if not results:
            return "No analysis results available yet."

        params = results['parameters']
        train_metrics = results['train_metrics']
        test_metrics = results['test_metrics']
        scores = results['scores']

        context = f"""
## Model Configuration
- Algorithm: {params['algorithm']}
- Estimator: {params['estimator']}
- Number of Features: {params['n_features']}
  - Categorical: {len(params['cat_features'])}
  - Continuous: {len(params['con_features'])}
- Target Variable: {params['target_variable']}
- Test Set Proportion: {params['test_fraction']:.0%}

## Training Set Performance
- Accuracy: {train_metrics['accuracy']:.2f}%
- Precision: {train_metrics['precision']:.2f}%
- Recall: {train_metrics['recall']:.2f}%
- F1-Score: {train_metrics['f1']:.2f}%
- AUC: {train_metrics['auc']:.4f}
- G-mean: {train_metrics['g_mean']:.2f}%
- P-mean: {train_metrics['p_mean']:.2f}%
- Specificity: {train_metrics['specificity']:.2f}%

## Test Set Performance
- Accuracy: {test_metrics['accuracy']:.2f}%
- Precision: {test_metrics['precision']:.2f}%
- Recall: {test_metrics['recall']:.2f}%
- F1-Score: {test_metrics['f1']:.2f}%
- AUC: {test_metrics['auc']:.4f}
- G-mean: {test_metrics['g_mean']:.2f}%
- P-mean: {test_metrics['p_mean']:.2f}%
- Specificity: {test_metrics['specificity']:.2f}%

## Model Scores
- Log-Likelihood: {scores['log_likelihood']:.2f}
- BIC Score: {scores['bic']:.2f}
- K2 Score: {scores['k2']:.2f}
- BDeu Score: {scores['bdeu']:.2f}
- BDs Score: {scores['bds']:.2f}

## Network Structure
- Total Nodes: {len(results['model'].nodes())}
- Total Edges: {len(results['model'].edges())}
- Network Edges: {list(results['model'].edges())[:10]}... (showing first 10)
"""

        return context

    def generate_summary(self, analysis_results):
        """
        Ask the model for a summary of the analysis results.

        Args:
            analysis_results: Analysis result dictionary.

        Returns:
            str: Summary text.
        """
        summary_prompt = """Based on the analysis results provided in the context, please generate a comprehensive summary that includes:

1. **Model Overview**: Brief description of the model type and configuration
2. **Performance Analysis**:
   - Overall model performance on both training and test sets
   - Comparison between training and test performance (overfitting/underfitting)
   - Key strengths and weaknesses
3. **Network Structure Insights**: What the learned structure tells us about variable relationships
4. **Recommendations**: Specific suggestions for improvement
5. **Limitations**: Important caveats and limitations to consider

Format the summary in clear markdown with appropriate sections and bullet points."""

        return self.get_response(summary_prompt, analysis_results)

    def explain_metric(self, metric_name, analysis_results):
        """
        Ask the model to explain one metric.

        Args:
            metric_name: Name of the metric.
            analysis_results: Analysis result dictionary.

        Returns:
            str: Explanation of the metric.
        """
        explain_prompt = f"""Please explain the following metric in the context of this analysis:

Metric: {metric_name}

Include:
1. What this metric measures
2. The value obtained in this analysis (training and test)
3. How to interpret this value
4. What it tells us about model performance
5. How it relates to other metrics in the analysis"""

        return self.get_response(explain_prompt, analysis_results)

    def suggest_improvements(self, analysis_results):
        """
        Ask the model for improvement suggestions.

        Args:
            analysis_results: Analysis result dictionary.

        Returns:
            str: Suggested improvements.
        """
        improve_prompt = """Based on the current model performance and configuration, please provide specific, actionable recommendations for improvement.

Consider:
1. Feature engineering opportunities
2. Algorithm selection
3. Hyperparameter tuning
4. Data quality issues
5. Model complexity trade-offs

Prioritize recommendations by potential impact."""

        return self.get_response(improve_prompt, analysis_results)

    def explain_network_structure(self, analysis_results):
        """
        Ask the model to explain the learned network structure.

        Args:
            analysis_results: Analysis result dictionary.

        Returns:
            str: Explanation of the network structure.
        """
        structure_prompt = """Please explain the learned Bayesian Network structure:

1. What are the key relationships (edges) discovered?
2. What do these relationships tell us about the domain?
3. Are there any surprising or interesting patterns?
4. How does the structure relate to the target variable?
5. What are the implications for prediction and inference?"""

        return self.get_response(structure_prompt, analysis_results)

    def compare_algorithms(self, analysis_results):
        """
        Ask the model to compare the current algorithm with the alternatives.

        Works without results too: the algorithm is then reported as
        "unknown" instead of raising.

        Args:
            analysis_results: Analysis result dictionary.

        Returns:
            str: Algorithm comparison.
        """
        parameters = (analysis_results or {}).get('parameters') or {}
        algorithm = parameters.get('algorithm', 'unknown')

        compare_prompt = f"""The current model uses the {algorithm} algorithm.

Please:
1. Explain the characteristics of this algorithm
2. Compare it with other available algorithms (NB, TAN, CL, HC, PC)
3. Discuss when this algorithm is most appropriate
4. Suggest if a different algorithm might be better for this dataset
5. Explain the trade-offs involved"""

        return self.get_response(compare_prompt, analysis_results)

    @staticmethod
    def _parse_json_object(text):
        """
        Parse the JSON object contained in an LLM reply.

        Tolerates a surrounding Markdown code fence and text around the
        object.

        Raises:
            json.JSONDecodeError: No valid JSON could be found.
            ValueError: The JSON value is not an object.
        """
        cleaned = (text or "").strip()

        if cleaned.startswith("```"):
            # Drop the opening fence line ("```" or "```json") ...
            newline_pos = cleaned.find("\n")
            cleaned = cleaned[newline_pos + 1:] if newline_pos != -1 else cleaned[3:]
            # ... and the closing fence.
            cleaned = cleaned.rstrip()
            if cleaned.endswith("```"):
                cleaned = cleaned[:-3]
            cleaned = cleaned.strip()

        try:
            parsed = json.loads(cleaned)
        except json.JSONDecodeError:
            # Last resort: parse the outermost {...} span.
            start, end = cleaned.find("{"), cleaned.rfind("}")
            if start == -1 or end <= start:
                raise
            parsed = json.loads(cleaned[start:end + 1])

        if not isinstance(parsed, dict):
            raise ValueError("Expected a JSON object with the extracted features.")
        return parsed

    @staticmethod
    def _build_evidence(extracted_features, target_variable, feature_list):
        """
        Keep only usable evidence from the extracted features.

        Drops the target variable, ``None`` values, "unknown" markers and -
        when ``feature_list`` is non-empty - keys that are not in it.
        """
        allowed = set(feature_list) if feature_list else None
        evidence = {}
        for name, value in extracted_features.items():
            if name == target_variable or value is None:
                continue
            if isinstance(value, str) and value.strip().lower() == "unknown":
                continue
            if allowed is not None and name not in allowed:
                continue
            evidence[name] = value
        return evidence

    def predict_from_text(self, user_description, analyzer, target_variable, feature_list):
        """
        Extract features from a text description and run a prediction.

        Args:
            user_description: Free-text description written by the user.
            analyzer: BayesianNetworkAnalyzer instance.
            target_variable: Target variable.
            feature_list: Features used by the model.

        Returns:
            str: Assistant reply containing the prediction.

        Raises:
            json.JSONDecodeError: The extraction reply contained no valid JSON.
            ValueError: The extraction reply was not a JSON object.
        """
        # Step 1: let the LLM extract structured features from the text.
        extraction_prompt = f"""
You are a medical data analyst. Extract the following patient features from the description:

Required features: {', '.join(feature_list)}

User description: "{user_description}"

Please extract the values in JSON format. If a feature is not mentioned, use "unknown".
Return ONLY the JSON object, no other text.

Example format:
{{
"age": 65,
"size": 25,
"grade": 2,
"nodes": 1,
...
}}
"""

        response = self.client.chat.completions.create(
            model=self.MODEL_NAME,
            messages=[
                {"role": "system", "content": "You are a precise medical data extractor. Return only valid JSON."},
                {"role": "user", "content": extraction_prompt}
            ],
            temperature=0.1
        )

        extracted_features = self._parse_json_object(
            response.choices[0].message.content
        )

        # Step 2: discard unknown / unusable values.
        evidence_dict = self._build_evidence(
            extracted_features, target_variable, feature_list
        )

        # Step 3: run the model.
        prediction = analyzer.predict_single_instance(evidence_dict, target_variable)

        # Step 4: let the LLM phrase an accessible explanation.
        interpretation_prompt = f"""
Based on the Bayesian Network model analysis:

Patient features: {evidence_dict}
Predicted death probability: {prediction['probability']:.2%}
Risk level: {prediction['risk_level']}

Please provide a clear, empathetic explanation including:
1. A summary of the patient's key risk factors
2. The predicted risk level and what it means
3. Important considerations and limitations
4. Recommendations for next steps

Be professional but accessible. Use markdown formatting.
"""

        final_response = self.client.chat.completions.create(
            model=self.MODEL_NAME,
            messages=[
                {"role": "system", "content": "You are a compassionate medical AI assistant."},
                {"role": "user", "content": interpretation_prompt}
            ],
            temperature=0.7
        )

        return final_response.choices[0].message.content

    def reset_conversation(self):
        """Clear the conversation history."""
        self.conversation_history = []
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ graphviz
requirements.txt CHANGED
@@ -1,3 +1,9 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
 
 
1
+ streamlit>=1.37.0
2
+ pandas>=2.2.0
3
+ plotly>=5.20.0
4
+ scikit-learn>=1.5.0
5
+ networkx>=3.3
6
+ openai>=1.30.0
7
+ graphviz>=0.20.3
8
+ pgmpy==0.1.26
9
+ numpy>=2.1.0,<3.0.0
utils.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import plotly.graph_objects as go
2
+ import plotly.express as px
3
+ import pandas as pd
4
+ import numpy as np
5
+ import networkx as nx
6
+ from plotly.subplots import make_subplots
7
+ from graphviz import Digraph
8
+ import base64
9
+
10
def plot_roc_curve(fpr, tpr, auc, title="ROC Curve"):
    """
    Build a Plotly figure with a ROC curve and the chance diagonal.

    Args:
        fpr: False positive rates (x values of the curve)
        tpr: True positive rates (y values of the curve)
        auc: Area under the curve; shown in the legend with 4 decimals
        title: Figure title

    Returns:
        plotly figure
    """
    roc_trace = go.Scatter(
        x=fpr,
        y=tpr,
        mode='lines',
        name=f'ROC Curve (AUC = {auc:.4f})',
        line=dict(color='#2d6ca2', width=2),
    )

    # Diagonal reference line: the expected curve of a random classifier.
    chance_trace = go.Scatter(
        x=[0, 1],
        y=[0, 1],
        mode='lines',
        name='Random Classifier',
        line=dict(color='gray', width=1, dash='dash'),
    )

    fig = go.Figure()
    for trace in (roc_trace, chance_trace):
        fig.add_trace(trace)

    layout_options = dict(
        title=title,
        xaxis_title='False Positive Rate',
        yaxis_title='True Positive Rate',
        width=600,
        height=500,
        template='plotly_white',
        legend=dict(x=0.6, y=0.1),
    )
    fig.update_layout(**layout_options)

    return fig
54
+
55
def plot_confusion_matrix(cm, title="Confusion Matrix"):
    """
    Draw a confusion matrix as an annotated Plotly heatmap.

    Each cell shows its raw count and its share of all samples. The matrix
    size is taken from the input, so binary (2x2) and multi-class (n x n)
    matrices are both labelled correctly.

    Args:
        cm: Confusion matrix as a 2-D nested list or array; row i is the
            actual class i, column j is the predicted class j
        title: Figure title

    Returns:
        plotly figure
    """
    cm_array = np.array(cm)
    n_rows, n_cols = cm_array.shape

    # Share of the grand total per cell. An all-zero matrix would divide by
    # zero and render "nan%", so fall back to 0% everywhere in that case.
    total = cm_array.sum()
    if total:
        cm_percent = cm_array / total * 100
    else:
        cm_percent = np.zeros(cm_array.shape)

    # Cell annotations: "<count><br>(<percent>%)".
    labels = [
        [f'{cm_array[i][j]}<br>({cm_percent[i][j]:.1f}%)'
         for j in range(n_cols)]
        for i in range(n_rows)
    ]

    fig = go.Figure(data=go.Heatmap(
        z=cm_array,
        x=[f'Predicted: {j}' for j in range(n_cols)],
        y=[f'Actual: {i}' for i in range(n_rows)],
        text=labels,
        texttemplate='%{text}',
        textfont={"size": 14},
        colorscale='Blues',
        showscale=True
    ))

    fig.update_layout(
        title=title,
        width=500,
        height=450,
        template='plotly_white'
    )

    return fig
98
+
99
def plot_probability_distribution(probs, title="Probability Distribution"):
    """
    Plot a histogram of predicted probabilities.

    Args:
        probs: Sequence of predicted probabilities
        title: Figure title

    Returns:
        plotly figure
    """
    bar_style = dict(
        color='#2d6ca2',
        line=dict(color='white', width=1),
    )
    histogram = go.Histogram(
        x=probs,
        nbinsx=20,
        name='Predicted Probabilities',
        marker=bar_style,
    )

    fig = go.Figure()
    fig.add_trace(histogram)

    layout_options = dict(
        title=title,
        xaxis_title='Predicted Probability for Class 1',
        yaxis_title='Frequency',
        width=700,
        height=400,
        template='plotly_white',
        showlegend=False,
    )
    fig.update_layout(**layout_options)

    # Pin the x axis to [0, 1] so plots are comparable across runs.
    fig.update_xaxes(range=[0, 1])

    return fig
135
+
136
def generate_network_graph(model):
    """
    Render a Bayesian network's structure with Graphviz and return the
    image as a Base64-encoded PNG string.

    Args:
        model: Network object exposing ``nodes()`` and ``edges()``; each
            edge is a ``(source, target)`` pair

    Returns:
        Base64-encoded PNG string
    """
    dot = Digraph(format='png', engine='dot')
    dot.attr('node', style='filled', color='lightblue', shape='ellipse')
    dot.attr(dpi='300')

    # Graphviz requires string identifiers, so coerce node names in case
    # the model uses integers or other hashables as variable names.
    for node in model.nodes():
        dot.node(str(node))

    # Digraph.edge(tail, head) draws an arrow tail -> head. Pass each edge
    # in its own (source, target) order so the picture matches the model's
    # direction; the previous code swapped the two and drew every arrow
    # backwards.
    for source, target in model.edges():
        dot.edge(str(source), str(target))

    # Render in memory (no temporary file) and encode for embedding.
    png_data = dot.pipe(format='png')
    return base64.b64encode(png_data).decode('utf-8')
161
+
162
def create_cpd_table(cpd):
    """
    Convert a conditional probability distribution into a DataFrame.

    Rows are the states of the CPD's own variable. Columns are either a
    single column (no parent variables) or a MultiIndex with one entry per
    combination of parent states. States are labelled by integer position,
    e.g. ``"X(0)"``. Probabilities are rounded to 4 decimals.

    Args:
        cpd: CPD object exposing ``variable``, ``variables``, ``values`` and
            ``cardinality``; ``None`` yields an empty DataFrame

    Returns:
        pandas DataFrame
    """
    from itertools import product

    if cpd is None:
        return pd.DataFrame()

    target = cpd.variable
    # The first entry of cpd.variables is the target itself; the rest are
    # its parents.
    parents = cpd.variables[1:] if len(cpd.variables) > 1 else []

    # Root node: a plain one-column table of marginal probabilities.
    if not parents:
        marginals = np.round(cpd.values.flatten(), 4)
        row_labels = [f"{target}({state})" for state in range(len(marginals))]
        return pd.DataFrame({target: marginals}, index=row_labels)

    # One column per combination of parent states, enumerated in the same
    # order in which the reshape below flattens the parent axes.
    parent_cards = cpd.cardinality[1:]
    state_combos = list(product(*[range(card) for card in parent_cards]))
    header_tuples = [
        tuple(f"{var}({val})" for var, val in zip(parents, combo))
        for combo in state_combos
    ]
    header = pd.MultiIndex.from_tuples(header_tuples, names=parents)

    # Collapse all parent axes into one: (target states, parent combinations).
    n_states = len(cpd.values)
    table = np.round(cpd.values.reshape(n_states, -1), 4)

    row_labels = [f"{target}({state})" for state in range(n_states)]
    return pd.DataFrame(table, index=row_labels, columns=header)
214
+
215
def create_metrics_comparison_table(train_metrics, test_metrics):
    """
    Build a side-by-side table of training-set and test-set metrics.

    Args:
        train_metrics: Metrics dict for the training set
        test_metrics: Metrics dict for the test set

    Both dicts must contain the keys ``accuracy``, ``precision``,
    ``recall``, ``f1``, ``auc``, ``g_mean``, ``p_mean`` and ``specificity``.
    AUC is printed with 4 decimals; every other metric is printed with
    2 decimals and a trailing ``%``.

    Returns:
        pandas DataFrame with columns 'Metric', 'Training Set', 'Test Set'
    """
    # (display label, dict key, rendered as a percentage?)
    metric_specs = [
        ('Accuracy', 'accuracy', True),
        ('Precision', 'precision', True),
        ('Recall', 'recall', True),
        ('F1-Score', 'f1', True),
        ('AUC', 'auc', False),
        ('G-mean', 'g_mean', True),
        ('P-mean', 'p_mean', True),
        ('Specificity', 'specificity', True),
    ]

    def _render_column(metrics):
        rendered = []
        for _, key, as_percent in metric_specs:
            value = metrics[key]
            if as_percent:
                rendered.append(f"{value:.2f}%")
            else:
                rendered.append(f"{value:.4f}")
        return rendered

    table = {
        'Metric': [label for label, _, _ in metric_specs],
        'Training Set': _render_column(train_metrics),
        'Test Set': _render_column(test_metrics),
    }
    return pd.DataFrame(table)
255
+
256
def export_results_to_json(results, filename="analysis_results.json"):
    """
    Serialize analysis results to a JSON string.

    The large per-sample arrays (``fpr``, ``tpr``, ``predicted_probs``) are
    dropped from both metric dicts, and the model is reduced to its edge
    list. NumPy scalars and arrays remaining in the payload are converted
    to plain Python values so they do not make ``json.dumps`` fail.

    Args:
        results: Analysis results dict with the keys ``parameters``,
            ``train_metrics``, ``test_metrics``, ``scores``, ``model``
            (must expose ``edges()``) and ``timestamp``
        filename: Unused by this function; nothing is written to disk.
            Kept for backward compatibility with existing callers.

    Returns:
        JSON string (indented by 2 spaces)

    Raises:
        KeyError: If a required key is missing from ``results``
        TypeError: If the payload contains a value that is neither
            JSON-serializable nor a NumPy scalar/array
    """
    import json

    def _to_builtin(obj):
        # Called by json.dumps only for objects it cannot encode itself.
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    # Per-sample arrays are bulky and only needed for plotting.
    excluded_keys = ('fpr', 'tpr', 'predicted_probs')

    def _strip(metrics):
        return {k: v for k, v in metrics.items() if k not in excluded_keys}

    exportable_results = {
        'parameters': results['parameters'],
        'train_metrics': _strip(results['train_metrics']),
        'test_metrics': _strip(results['test_metrics']),
        'scores': results['scores'],
        # The model object itself is not serializable; export its structure.
        'network_edges': list(results['model'].edges()),
        'timestamp': results['timestamp']
    }

    return json.dumps(exportable_results, indent=2, default=_to_builtin)
286
+
287
def calculate_performance_gap(train_metrics, test_metrics):
    """
    Measure how far training-set metrics are from test-set metrics.

    Args:
        train_metrics: Training-set metrics dict
        test_metrics: Test-set metrics dict

    Both dicts must contain ``accuracy``, ``precision``, ``recall``, ``f1``
    and ``auc``.

    Returns:
        dict: one ``<metric>_gap`` entry (train minus test) per metric, plus
        ``average_gap`` (mean of the absolute gaps) and ``overfitting_risk``
        ("High" if the average exceeds 10, "Moderate" if it exceeds 5,
        otherwise "Low")
    """
    compared = ('accuracy', 'precision', 'recall', 'f1', 'auc')
    gaps = {
        f'{metric}_gap': train_metrics[metric] - test_metrics[metric]
        for metric in compared
    }

    # NOTE(review): every gap is averaged with equal weight; if 'auc' is on
    # a 0-1 scale while the others are percentages, it contributes almost
    # nothing to the mean -- confirm the intended scales with the producer.
    mean_abs_gap = np.mean([abs(gap) for gap in gaps.values()])

    # Grade the overfitting risk from the mean absolute gap.
    if mean_abs_gap > 10:
        risk = "High"
    elif mean_abs_gap > 5:
        risk = "Moderate"
    else:
        risk = "Low"

    gaps['average_gap'] = mean_abs_gap
    gaps['overfitting_risk'] = risk

    return gaps