# Uploaded by Wen1201 ("Upload 3 files", commit 0ee744a, verified).
import streamlit as st
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from io import BytesIO
import base64
import json
from datetime import datetime
import uuid
# Page configuration. Kept as the first Streamlit call in the script, which is
# where Streamlit expects set_page_config to run.
_PAGE_CONFIG = {
    "page_title": "Bayesian Network Analysis System",
    "page_icon": "🔬",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_CONFIG)
# Custom CSS injected once per run to restyle expanders, inputs, separators and
# buttons (intended to give the UI a Django-admin-like look).
# NOTE(review): selectors such as .streamlit-expanderHeader target Streamlit's
# internal DOM class names; confirm they still match the installed version.
_CUSTOM_CSS = """
<style>
/* Expander 樣式 - 類似 Django 的摺疊區域 */
.streamlit-expanderHeader {
background-color: #e8f1f8;
border: 1px solid #b0cfe8;
border-radius: 5px;
font-weight: 600;
color: #1b4f72;
}
.streamlit-expanderHeader:hover {
background-color: #d0e7f8;
}
/* Checkbox 樣式 */
.stCheckbox {
padding: 2px 0;
}
/* Radio button 樣式 */
.stRadio > label {
font-weight: 600;
color: #1b4f72;
}
/* 選擇框樣式 */
.stSelectbox > label, .stNumberInput > label {
font-weight: 600;
color: #1b4f72;
}
/* 分隔線 */
hr {
margin: 1rem 0;
border-top: 2px solid #b0cfe8;
}
/* 表單容器 */
.element-container {
margin-bottom: 0.5rem;
}
/* 摺疊內容區域 */
.streamlit-expanderContent {
background-color: #f8fbff;
border: 1px solid #d0e4f5;
border-top: none;
padding: 1rem;
}
/* 按鈕樣式 */
.stButton > button {
width: 100%;
border-radius: 20px;
font-weight: 600;
transition: all 0.3s ease;
}
.stButton > button:hover {
transform: translateY(-2px);
box-shadow: 0 4px 8px rgba(0,0,0,0.2);
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# 導入自定義模組
from bn_core import BayesianNetworkAnalyzer
from llm_assistant import LLMAssistant
from utils import (
plot_roc_curve,
plot_confusion_matrix,
plot_probability_distribution,
generate_network_graph,
create_cpd_table,
export_results_to_json
)
# Per-session defaults, applied only on the first run of a browser session.
# Each value is a zero-argument factory so mutable defaults (the chat history
# list) are freshly created rather than shared.
_SESSION_DEFAULTS = {
    'session_id': lambda: str(uuid.uuid4()),
    'analysis_results': lambda: None,
    'chat_history': list,
    'model_trained': lambda: False,
}
for _state_key, _make_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# Page title
st.title("🔬 Bayesian Network Analysis System")
st.markdown("---")

# Sidebar: OpenAI API key and data-source selection
with st.sidebar:
    st.header("⚙️ Configuration")
    api_key = st.text_input(
        "OpenAI API Key",
        type="password",
        help="Enter your OpenAI API key to use the AI assistant"
    ).strip()  # pasted keys often carry stray whitespace
    if api_key:
        st.session_state.api_key = api_key
        st.success("✅ API Key loaded")
    elif 'api_key' in st.session_state:
        # The field was cleared: forget the previously stored key so the AI
        # assistant tab stops using it.
        del st.session_state['api_key']

    st.markdown("---")

    # Data source selection
    st.subheader("📊 Data Source")
    data_source = st.radio(
        "Select data source:",
        ["Use Default Dataset", "Upload Your Data"]
    )
    # Stays None unless the user chose to upload and provided a file.
    uploaded_file = None
    if data_source == "Upload Your Data":
        uploaded_file = st.file_uploader(
            "Upload CSV file",
            type=['csv'],
            help="Upload your dataset in CSV format"
        )

# Main content area: analysis tab and AI assistant tab
tab1, tab2 = st.tabs(["📈 Analysis", "💬 AI Assistant"])
# ---------------------------------------------------------------------------
# Tab 1: analysis interface (data loading, model configuration, results)
# ---------------------------------------------------------------------------
import contextlib
import os
import tempfile
import time


@st.cache_data
def load_default_data():
    """Load the bundled default dataset.

    The path is relative, so it is resolved against the process working
    directory. Raises FileNotFoundError when the CSV is missing.
    """
    return pd.read_csv("BC_imputed_micerf_period13_fid_course_D4.csv")


def _render_split_metrics(metrics, title):
    """Render metric cards, confusion matrix and ROC curve for one data split.

    Args:
        metrics: per-split metrics dict taken from the analyzer results
            (``results['train_metrics']`` or ``results['test_metrics']``).
        title: heading shown above the cards; also used as the plot title.
    """
    st.markdown(f"### {title}")

    # Most metrics are displayed as percentages; AUC is formatted without a
    # percent sign and with four decimals.
    row1 = st.columns(4)
    row1[0].metric("Accuracy", f"{metrics['accuracy']:.2f}%")
    row1[1].metric("Precision", f"{metrics['precision']:.2f}%")
    row1[2].metric("Recall", f"{metrics['recall']:.2f}%")
    row1[3].metric("F1-Score", f"{metrics['f1']:.2f}%")

    row2 = st.columns(4)
    row2[0].metric("AUC", f"{metrics['auc']:.4f}")
    row2[1].metric("G-mean", f"{metrics['g_mean']:.2f}%")
    row2[2].metric("P-mean", f"{metrics['p_mean']:.2f}%")
    row2[3].metric("Specificity", f"{metrics['specificity']:.2f}%")

    with st.expander("Confusion Matrix", expanded=True):
        conf_fig = plot_confusion_matrix(
            metrics['confusion_matrix'],
            title=title
        )
        st.plotly_chart(conf_fig, use_container_width=True)

    with st.expander("ROC Curve", expanded=False):
        roc_fig = plot_roc_curve(
            metrics['fpr'],
            metrics['tpr'],
            metrics['auc'],
            title=title
        )
        st.plotly_chart(roc_fig, use_container_width=True)


with tab1:
    col1, col2 = st.columns([2, 1])

    # Defined up front so the Quick Stats column can render even when no data
    # has been loaded.
    df = None
    selected_features = []

    with col1:
        st.header("Model Configuration")

        # --- Load data -------------------------------------------------------
        if data_source == "Use Default Dataset":
            try:
                df = load_default_data()
            except FileNotFoundError:
                st.error("❌ Default dataset not found. Please upload your own data.")
            except (ValueError, OSError) as exc:
                # pandas parse/empty-data errors and decode errors are ValueErrors.
                st.error(f"❌ Could not read the default dataset: {exc}")
            else:
                st.success(f"✅ Default dataset loaded: {df.shape[0]} rows, {df.shape[1]} columns")
        elif uploaded_file is not None:
            try:
                df = pd.read_csv(uploaded_file)
            except (ValueError, OSError) as exc:
                st.error(f"❌ Could not read the uploaded CSV: {exc}")
            else:
                st.success(f"✅ Data loaded: {df.shape[0]} rows, {df.shape[1]} columns")
        else:
            st.info("👆 Please upload a CSV file to begin")

        if df is not None:
            # --- Feature selection (collapsible) -----------------------------
            st.subheader("🎯 Input Features")

            if data_source == "Use Default Dataset":
                # Fixed schema for the default breast-cancer dataset: only these
                # two columns are continuous. Filtered against df.columns so a
                # changed file cannot offer checkboxes for missing columns.
                numeric_cols = [c for c in ('size', 'stime') if c in df.columns]
                categorical_cols = [c for c in df.columns if c not in numeric_cols]
            else:
                # Uploaded data: infer feature types from dtypes.
                numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
                categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()

            # Only columns with exactly two distinct values can be the target.
            binary_cols = [c for c in df.columns if df[c].nunique() == 2]

            col_feat1, col_feat2 = st.columns(2)
            with col_feat1:
                with st.expander("**Continuous**", expanded=False):
                    st.caption("Select continuous features:")
                    con_features = []
                    for col in numeric_cols:
                        if st.checkbox(col, value=False, key=f"con_{col}"):
                            con_features.append(col)
            with col_feat2:
                with st.expander("**Categorical**", expanded=True):
                    st.caption("Select categorical features:")
                    cat_features = []
                    for position, col in enumerate(categorical_cols):
                        # Pre-check the first five columns (all of them when
                        # there are five or fewer).
                        if st.checkbox(col, value=position < 5, key=f"cat_{col}"):
                            cat_features.append(col)

            # --- Target variable and test split -------------------------------
            st.markdown("---")
            col_target1, col_target2 = st.columns([1, 2])
            with col_target1:
                target_variable = st.selectbox(
                    "Target Variable (Y):",
                    options=binary_cols,
                    help="Must be a binary classification variable"
                )
                if target_variable is None:
                    # selectbox returns None when it has no options.
                    st.warning("⚠️ No binary column found; a binary target variable is required.")
            with col_target2:
                test_fraction = st.number_input(
                    "Test Dataset Proportion:",
                    min_value=0.10,
                    max_value=0.50,
                    value=0.25,
                    step=0.05,
                    format="%.2f"
                )

            # Validate the selection. Report a target/feature conflict but keep
            # rendering the page (the Run button is disabled instead): calling
            # st.stop() here would also hide Quick Stats, results and the AI tab.
            selected_features = cat_features + con_features
            target_in_features = target_variable in selected_features
            if target_in_features:
                st.error("❌ Target variable cannot be in feature list!")

            st.markdown("---")

            # --- Model parameters ---------------------------------------------
            st.subheader("⚙️ Model Configuration")
            col_param1, col_param2 = st.columns(2)
            with col_param1:
                algorithm = st.radio(
                    "Network Structure:",
                    options=['NB', 'TAN', 'CL', 'HC', 'PC'],
                    format_func=lambda x: {
                        'NB': 'Naive Bayes (NB)',
                        'TAN': 'Tree-Augmented Naive Bayes (TAN)',
                        'CL': 'Chow-Liu',
                        'HC': 'Hill Climbing',
                        'PC': 'PC'
                    }[x],
                    help="Select structure learning algorithm"
                )
                # Scoring method only applies to Hill Climbing.
                if algorithm == 'HC':
                    score_method = st.selectbox(
                        "Scoring Method:",
                        options=['BIC', 'AIC', 'K2', 'BDeu', 'BDs'],
                        help="Select scoring method for Hill Climbing"
                    )
                else:
                    score_method = 'BIC'
                # Significance level only applies to the PC algorithm.
                if algorithm == 'PC':
                    sig_level = st.number_input(
                        "Significance Level:",
                        min_value=0.01,
                        max_value=1.0,
                        value=0.05,
                        step=0.01,
                        help="Significance level for PC algorithm"
                    )
                else:
                    sig_level = 0.05
            with col_param2:
                estimator = st.radio(
                    "Parameter Estimator:",
                    options=['ml', 'bn'],
                    format_func=lambda x: {
                        'ml': 'MaximumLikelihoodEstimator',
                        'bn': 'BayesianEstimator'
                    }[x],
                    help="Select parameter estimation method"
                )
                if estimator == 'bn':
                    equivalent_sample_size = st.number_input(
                        "Equivalent Sample Size:",
                        min_value=1,
                        value=3,
                        step=1,
                        help="Prior strength for Bayesian estimation"
                    )
                else:
                    equivalent_sample_size = 3
                # Decision subset selector is only shown for the default dataset.
                if data_source == "Use Default Dataset":
                    decision = st.selectbox(
                        "Decision:",
                        options=['OverAll', 'Exposed', 'Unexposed'],
                        index=0,
                        help="Analysis subset selection"
                    )
                else:
                    decision = 'OverAll'

            # --- Evidence (collapsible) ----------------------------------------
            st.markdown("---")
            with st.expander("**Provide Evidence**", expanded=False):
                st.caption("Enter evidence values for inference (optional):")
                evidence_cols = st.columns(2)
                evidence_dict = {}
                # One text input per selected non-target feature, in two columns.
                all_vars = [v for v in selected_features if v != target_variable]
                for idx, var in enumerate(all_vars):
                    with evidence_cols[idx % 2]:
                        val = st.text_input(
                            f"{var}:",
                            value="",
                            key=f"evidence_{var}",
                            help=f"Enter value for {var} (leave empty to ignore)"
                        )
                        if val.strip():
                            evidence_dict[var] = val.strip()
            # NOTE(review): evidence_dict and `decision` are collected from the
            # UI but are not passed to analyzer.run_analysis below, so they
            # currently have no effect on the results — confirm intent.

            # --- Advanced parameters (collapsible) -----------------------------
            with st.expander("**Advanced Parameters**", expanded=False):
                n_bins = st.slider(
                    "Number of Bins (for continuous variables):",
                    min_value=3,
                    max_value=20,
                    value=10,
                    step=1,
                    help="Number of bins for discretizing continuous features"
                )

            # --- Run / reset buttons -------------------------------------------
            st.markdown("---")
            col_btn1, col_btn2 = st.columns([3, 1])
            with col_btn1:
                run_button = st.button(
                    "🚀 Run Analysis",
                    type="primary",
                    use_container_width=True,
                    disabled=target_in_features or target_variable is None,
                )
            with col_btn2:
                if st.button("🔄 Reset", use_container_width=True):
                    st.session_state.analysis_results = None
                    st.session_state.model_trained = False
                    st.session_state.chat_history = []
                    # Drop the fitted analyzer too so it cannot be exported or
                    # used for predictions after a reset.
                    if 'analyzer' in st.session_state:
                        del st.session_state['analyzer']
                    st.rerun()

            # Description of the analysis steps
            with st.expander("ℹ️ Analysis Steps", expanded=False):
                st.markdown("""
**Process:**
1. Split data (train/test)
2. Learn network structure
3. Process features (bins from train)
4. Estimate parameters
5. Evaluate performance
**Note:** Test set bins are derived from training set to prevent data leakage.
""")

            if run_button:
                # Defensive re-validation (the button is already disabled on a
                # target/feature conflict).
                if not selected_features:
                    st.error("❌ Please select at least one feature!")
                    st.stop()
                if target_variable in selected_features:
                    st.error("❌ Target variable cannot be in feature list!")
                    st.stop()

                with st.spinner("🔄 Training Bayesian Network..."):
                    progress_bar = st.progress(0)
                    status_text = st.empty()
                    try:
                        status_text.text("📊 Initializing analyzer...")
                        progress_bar.progress(10)
                        analyzer = BayesianNetworkAnalyzer(
                            session_id=st.session_state.session_id
                        )

                        status_text.text(f"📐 Learning {algorithm} structure...")
                        progress_bar.progress(30)
                        results = analyzer.run_analysis(
                            df=df,
                            cat_features=cat_features,
                            con_features=con_features,
                            target_variable=target_variable,
                            test_fraction=test_fraction,
                            algorithm=algorithm,
                            estimator=estimator,
                            equivalent_sample_size=equivalent_sample_size,
                            score_method=score_method,
                            sig_level=sig_level,
                            n_bins=n_bins
                        )

                        status_text.text("✅ Analysis completed!")
                        progress_bar.progress(100)

                        st.session_state.analysis_results = results
                        st.session_state.model_trained = True
                        # Keep the fitted analyzer for personalised predictions
                        # in the AI Assistant tab and for model export.
                        st.session_state.analyzer = analyzer
                    except Exception as e:
                        st.error(f"❌ Error during analysis: {str(e)}")
                        st.exception(e)
                        progress_bar.empty()
                        status_text.empty()
                    else:
                        st.success("✅ Analysis completed successfully!")
                        st.balloons()
                        # Let the success feedback show briefly, then clear the
                        # progress widgets and rerun so the results render.
                        time.sleep(1)
                        progress_bar.empty()
                        status_text.empty()
                        st.rerun()

    with col2:
        st.header("Quick Stats")
        if df is not None:
            st.metric("Total Samples", df.shape[0])
            st.metric("Total Features", df.shape[1])
            st.metric("Selected Features", len(selected_features))
        if st.session_state.model_trained:
            st.success("✅ Model Trained")
        else:
            st.info("⏳ Awaiting Training")

    # --- Results ---------------------------------------------------------------
    if st.session_state.analysis_results:
        st.markdown("---")
        st.header("📊 Analysis Results")
        results = st.session_state.analysis_results

        result_tabs = st.tabs([
            "🕸️ Network Structure",
            "📈 Performance Metrics",
            "📋 CPD Tables",
            "📊 Model Scores"
        ])

        # Network structure
        with result_tabs[0]:
            network_fig = generate_network_graph(results['model'])
            st.plotly_chart(network_fig, use_container_width=True)
            with st.expander("View Network Edges", expanded=False):
                edges = list(results['model'].edges())
                st.write(f"Total edges: {len(edges)}")
                # Lay the edges out three per row.
                for start in range(0, len(edges), 3):
                    row = st.columns(3)
                    for cell, edge in zip(row, edges[start:start + 3]):
                        cell.markdown(f"**{edge[0]}** → {edge[1]}")

        # Performance metrics, train vs. test side by side
        with result_tabs[1]:
            col_m1, col_m2 = st.columns(2)
            with col_m1:
                _render_split_metrics(results['train_metrics'], "Training Set")
            with col_m2:
                _render_split_metrics(results['test_metrics'], "Test Set")

        # Conditional probability tables
        with result_tabs[2]:
            selected_node = st.selectbox(
                "Select a node to view its CPD:",
                options=list(results['cpds'].keys())
            )
            if selected_node:
                cpd_df = create_cpd_table(results['cpds'][selected_node])
                st.dataframe(cpd_df, use_container_width=True)
                csv = cpd_df.to_csv()
                st.download_button(
                    label="📥 Download CPD as CSV",
                    data=csv,
                    file_name=f"cpd_{selected_node}.csv",
                    mime="text/csv"
                )

        # Model scores and parameter summary
        with result_tabs[3]:
            scores = results['scores']
            score_cols = st.columns(5)
            score_cols[0].metric("Log-Likelihood", f"{scores['log_likelihood']:.2f}")
            score_cols[1].metric("BIC Score", f"{scores['bic']:.2f}")
            score_cols[2].metric("K2 Score", f"{scores['k2']:.2f}")
            score_cols[3].metric("BDeu Score", f"{scores['bdeu']:.2f}")
            score_cols[4].metric("BDs Score", f"{scores['bds']:.2f}")

            with st.expander("Analysis Parameters", expanded=True):
                params = results['parameters']
                # Distinct names so the outer col1/col2 are not shadowed.
                pcol1, pcol2, pcol3 = st.columns(3)
                with pcol1:
                    st.markdown("**Algorithm Settings**")
                    st.write(f"- Algorithm: {params['algorithm']}")
                    st.write(f"- Estimator: {params['estimator']}")
                    st.write(f"- Test Fraction: {params['test_fraction']:.2%}")
                with pcol2:
                    st.markdown("**Feature Information**")
                    st.write(f"- Total Features: {params['n_features']}")
                    st.write(f"- Categorical: {len(params['cat_features'])}")
                    st.write(f"- Continuous: {len(params['con_features'])}")
                    st.write(f"- Target: {params['target_variable']}")
                with pcol3:
                    st.markdown("**Other Parameters**")
                    st.write(f"- Bins: {params['n_bins']}")
                    st.write(f"- Score Method: {params['score_method']}")
                    st.write(f"- Significance Level: {params['sig_level']}")
                    st.write(f"- Equivalent Sample Size: {params['equivalent_sample_size']}")

        # Export results / trained model
        with st.expander("Export Results", expanded=False):
            export_col1, export_col2 = st.columns(2)
            with export_col1:
                result_json = export_results_to_json(results)
                st.download_button(
                    label="📥 Download Full Results (JSON)",
                    data=result_json,
                    file_name=f"bn_analysis_{results['timestamp'][:10]}.json",
                    mime="application/json"
                )
            with export_col2:
                if st.button("💾 Save Trained Model"):
                    if 'analyzer' in st.session_state:
                        # Reserve a unique path and close the handle before the
                        # analyzer writes to it (an open NamedTemporaryFile cannot
                        # be reopened by name on Windows).
                        with tempfile.NamedTemporaryFile(delete=False, suffix='.pkl') as tmp_file:
                            model_path = tmp_file.name
                        try:
                            st.session_state.analyzer.save_model(model_path)
                            with open(model_path, 'rb') as fh:
                                model_bytes = fh.read()
                        except Exception as exc:
                            st.error(f"❌ Failed to save model: {exc}")
                        else:
                            st.download_button(
                                label="📥 Download Model File (.pkl)",
                                data=model_bytes,
                                file_name=f"bn_model_{results['timestamp'][:10]}.pkl",
                                mime="application/octet-stream",
                                key="download_model_btn"
                            )
                        finally:
                            # Always remove the temp file, even if saving failed.
                            with contextlib.suppress(FileNotFoundError):
                                os.unlink(model_path)
                    else:
                        st.error("❌ Analyzer not found in session state")
# ---------------------------------------------------------------------------
# Tab 2: AI assistant (free-form chat, personalised prediction, quick questions)
# ---------------------------------------------------------------------------
def _assistant_error_message(exc):
    """Build the chat-visible error text shown when an assistant call fails."""
    return f"❌ Error: {str(exc)}\n\nPlease try rephrasing your question or check the model status."


with tab2:
    st.header("💬 AI Analysis Assistant")

    if not st.session_state.get('api_key'):
        st.warning("⚠️ Please enter your OpenAI API Key in the sidebar to use the AI assistant.")
    elif not st.session_state.model_trained:
        st.info("ℹ️ Please train a model first in the Analysis tab to use the AI assistant.")
    else:
        # (Re)create the assistant when none exists yet or when the sidebar key
        # changed since it was built; otherwise a corrected key would never take
        # effect for the rest of the session.
        if (
            'llm_assistant' not in st.session_state
            or st.session_state.get('llm_assistant_api_key') != st.session_state.api_key
        ):
            st.session_state.llm_assistant = LLMAssistant(
                api_key=st.session_state.api_key,
                session_id=st.session_state.session_id
            )
            st.session_state.llm_assistant_api_key = st.session_state.api_key
        assistant = st.session_state.llm_assistant

        # Replay the conversation so far.
        chat_container = st.container()
        with chat_container:
            for message in st.session_state.chat_history:
                with st.chat_message(message["role"]):
                    st.markdown(message["content"])

        # Chat input
        if prompt := st.chat_input("Ask me anything about your analysis results..."):
            st.session_state.chat_history.append({
                "role": "user",
                "content": prompt
            })
            with st.chat_message("user"):
                st.markdown(prompt)

            # Route to personalised prediction when the prompt mentions any of
            # these keywords. NOTE(review): this is plain substring matching,
            # so e.g. "showcase" also matches "case".
            prediction_keywords = ['predict', 'risk', 'patient', 'case', 'my risk', 'calculate', 'probability', 'chance']
            lowered_prompt = prompt.lower()
            is_prediction_request = any(keyword in lowered_prompt for keyword in prediction_keywords)

            with st.chat_message("assistant"):
                with st.spinner("Analyzing..." if is_prediction_request else "Thinking..."):
                    try:
                        if is_prediction_request:
                            if 'analyzer' not in st.session_state:
                                # The fitted analyzer is stored by the Analysis tab
                                # after training; without it no prediction is possible.
                                st.error("❌ Model not found. Please train a model first in the Analysis tab.")
                                response = "I cannot perform predictions because the model is not available. Please train a model first."
                            else:
                                params = st.session_state.analysis_results['parameters']
                                response = assistant.predict_from_text(
                                    user_description=prompt,
                                    analyzer=st.session_state.analyzer,
                                    target_variable=params['target_variable'],
                                    feature_list=params['cat_features'] + params['con_features']
                                )
                        else:
                            response = assistant.get_response(
                                user_message=prompt,
                                analysis_results=st.session_state.analysis_results
                            )
                        st.markdown(response)
                    except Exception as e:
                        response = _assistant_error_message(e)
                        st.error(response)

            st.session_state.chat_history.append({
                "role": "assistant",
                "content": response
            })

        # Quick-question buttons
        st.markdown("---")
        st.subheader("💡 Quick Questions")
        quick_questions = [
            "📊 Give me a summary of the analysis results",
            "🎯 What is the model's performance?",
            "🔍 Explain the Bayesian Network structure",
            "⚠️ What are the limitations of this model?",
            "💡 How can I improve the model?"
        ]
        cols = st.columns(len(quick_questions))
        for idx, (col, question) in enumerate(zip(cols, quick_questions)):
            if col.button(question, key=f"quick_{idx}"):
                st.session_state.chat_history.append({
                    "role": "user",
                    "content": question
                })
                # Same error handling as the chat input, so an API failure is
                # recorded in the chat instead of crashing the page.
                try:
                    response = assistant.get_response(
                        user_message=question,
                        analysis_results=st.session_state.analysis_results
                    )
                except Exception as e:
                    response = _assistant_error_message(e)
                st.session_state.chat_history.append({
                    "role": "assistant",
                    "content": response
                })
                st.rerun()
# Footer: app banner plus the first 8 characters of this session's ID.
_short_session_id = st.session_state.session_id[:8]
st.markdown("---")
st.markdown(
    f"""
<div style='text-align: center'>
<p>🔬 Bayesian Network Analysis System | Built with Streamlit</p>
<p>Powered by OpenAI GPT-4 | Session ID: {_short_session_id}</p>
</div>
""",
    unsafe_allow_html=True,
)