# (Hugging Face Spaces page header removed — application source begins below.)
"""Bayesian Network Analysis System — Streamlit front-end.

Page-level setup: imports, Streamlit page configuration, and custom CSS
that styles the widgets to resemble a Django admin form.
"""

# Standard library
import base64
import json
import uuid
from datetime import datetime
from io import BytesIO

# Third-party
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st

# Page configuration (must run before any other st.* call)
st.set_page_config(
    page_title="Bayesian Network Analysis System",
    page_icon="🔬",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Custom CSS — make the interface look more like Django.
st.markdown("""
<style>
/* Expander style - collapsible area similar to Django */
.streamlit-expanderHeader {
    background-color: #e8f1f8;
    border: 1px solid #b0cfe8;
    border-radius: 5px;
    font-weight: 600;
    color: #1b4f72;
}
.streamlit-expanderHeader:hover {
    background-color: #d0e7f8;
}
/* Checkbox style */
.stCheckbox {
    padding: 2px 0;
}
/* Radio button style */
.stRadio > label {
    font-weight: 600;
    color: #1b4f72;
}
/* Select box style */
.stSelectbox > label, .stNumberInput > label {
    font-weight: 600;
    color: #1b4f72;
}
/* Divider */
hr {
    margin: 1rem 0;
    border-top: 2px solid #b0cfe8;
}
/* Form container */
.element-container {
    margin-bottom: 0.5rem;
}
/* Collapsed content area */
.streamlit-expanderContent {
    background-color: #f8fbff;
    border: 1px solid #d0e4f5;
    border-top: none;
    padding: 1rem;
}
/* Button style */
.stButton > button {
    width: 100%;
    border-radius: 20px;
    font-weight: 600;
    transition: all 0.3s ease;
}
.stButton > button:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 8px rgba(0,0,0,0.2);
}
</style>
""", unsafe_allow_html=True)
# Import project modules
from bn_core import BayesianNetworkAnalyzer
from llm_assistant import LLMAssistant
from utils import (
    plot_roc_curve,
    plot_confusion_matrix,
    plot_probability_distribution,
    generate_network_graph,
    create_cpd_table,
    export_results_to_json
)

# Initialize session state: each key gets its default on the first run of
# a new browser session; factories avoid sharing mutable defaults.
_SESSION_DEFAULTS = {
    'session_id': lambda: str(uuid.uuid4()),
    'analysis_results': lambda: None,
    'chat_history': lambda: [],
    'model_trained': lambda: False,
}
for _key, _default_factory in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default_factory()

# Title
st.title("🔬 Bayesian Network Analysis System")
st.markdown("---")
# Sidebar — OpenAI API key and data-source selection.
with st.sidebar:
    st.header("⚙️ Configuration")

    api_key = st.text_input(
        "OpenAI API Key",
        type="password",
        help="Enter your OpenAI API key to use the AI assistant",
    )
    if api_key:
        # Persist across reruns so the assistant tab can use it.
        st.session_state.api_key = api_key
        st.success("✅ API Key loaded")

    st.markdown("---")

    # Data-source selection
    st.subheader("📊 Data Source")
    data_source = st.radio(
        "Select data source:",
        ["Use Default Dataset", "Upload Your Data"],
    )

    uploaded_file = None
    if data_source == "Upload Your Data":
        uploaded_file = st.file_uploader(
            "Upload CSV file",
            type=['csv'],
            help="Upload your dataset in CSV format",
        )

# Main content area
tab1, tab2 = st.tabs(["📈 Analysis", "💬 AI Assistant"])
# Tab 1: analysis interface — data loading, feature selection, model
# configuration, and the "Run Analysis" handler.
with tab1:
    col1, col2 = st.columns([2, 1])

    with col1:
        st.header("Model Configuration")

        # --- Load data -------------------------------------------------
        if data_source == "Use Default Dataset":
            def load_default_data():
                """Read the bundled default dataset from the app directory."""
                df = pd.read_csv("BC_imputed_micerf_period13_fid_course_D4.csv")
                return df

            try:
                df = load_default_data()
                st.success(f"✅ Default dataset loaded: {df.shape[0]} rows, {df.shape[1]} columns")
            # Narrowed from a bare `except:` — only a missing/unreadable
            # file should fall back to the upload prompt; other errors
            # (e.g. a malformed CSV) should surface, not be hidden.
            except (FileNotFoundError, OSError, pd.errors.ParserError):
                st.error("❌ Default dataset not found. Please upload your own data.")
                df = None
        else:
            if uploaded_file:
                df = pd.read_csv(uploaded_file)
                st.success(f"✅ Data loaded: {df.shape[0]} rows, {df.shape[1]} columns")
            else:
                st.info("👆 Please upload a CSV file to begin")
                df = None

        if df is not None:
            # --- Feature selection (collapsible expanders) -------------
            st.subheader("🎯 Input Features")

            # Manually fixed feature types for the default breast-cancer
            # dataset; uploaded data is typed automatically from dtypes.
            if data_source == "Use Default Dataset":
                numeric_cols = ['size', 'stime']  # only continuous variables in this dataset
                categorical_cols = [col for col in df.columns if col not in numeric_cols]
            else:
                numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
                categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()

            # Binary variables are the only valid targets.
            binary_cols = [col for col in df.columns if df[col].nunique() == 2]

            col_feat1, col_feat2 = st.columns(2)
            with col_feat1:
                with st.expander("**Continuous**", expanded=False):
                    st.caption("Select continuous features:")
                    con_features = []
                    for col in numeric_cols:
                        if st.checkbox(col, value=False, key=f"con_{col}"):
                            con_features.append(col)
            with col_feat2:
                with st.expander("**Categorical**", expanded=True):
                    st.caption("Select categorical features:")
                    cat_features = []
                    for col in categorical_cols:
                        # Pre-check only the first five when the list is long.
                        default_checked = categorical_cols.index(col) < 5 if len(categorical_cols) > 5 else True
                        if st.checkbox(col, value=default_checked, key=f"cat_{col}"):
                            cat_features.append(col)

            # --- Target variable / split -------------------------------
            st.markdown("---")
            col_target1, col_target2 = st.columns([1, 2])
            with col_target1:
                target_variable = st.selectbox(
                    "Target Variable (Y):",
                    options=binary_cols,
                    help="Must be a binary classification variable"
                )
            with col_target2:
                test_fraction = st.number_input(
                    "Test Dataset Proportion:",
                    min_value=0.10,
                    max_value=0.50,
                    value=0.25,
                    step=0.05,
                    format="%.2f"
                )

            # Validate selection early: the target cannot also be a feature.
            selected_features = cat_features + con_features
            if target_variable in selected_features:
                st.error("❌ Target variable cannot be in feature list!")
                st.stop()

            st.markdown("---")

            # --- Model parameters --------------------------------------
            st.subheader("⚙️ Model Configuration")
            col_param1, col_param2 = st.columns(2)

            with col_param1:
                algorithm = st.radio(
                    "Network Structure:",
                    options=['NB', 'TAN', 'CL', 'HC', 'PC'],
                    format_func=lambda x: {
                        'NB': 'Naive Bayes (NB)',
                        'TAN': 'Tree-Augmented Naive Bayes (TAN)',
                        'CL': 'Chow-Liu',
                        'HC': 'Hill Climbing',
                        'PC': 'PC'
                    }[x],
                    help="Select structure learning algorithm"
                )
                # Conditional parameter — only Hill Climbing uses a score.
                if algorithm == 'HC':
                    score_method = st.selectbox(
                        "Scoring Method:",
                        options=['BIC', 'AIC', 'K2', 'BDeu', 'BDs'],
                        help="Select scoring method for Hill Climbing"
                    )
                else:
                    score_method = 'BIC'
                # Conditional parameter — only PC uses a significance level.
                if algorithm == 'PC':
                    sig_level = st.number_input(
                        "Significance Level:",
                        min_value=0.01,
                        max_value=1.0,
                        value=0.05,
                        step=0.01,
                        help="Significance level for PC algorithm"
                    )
                else:
                    sig_level = 0.05

            with col_param2:
                estimator = st.radio(
                    "Parameter Estimator:",
                    options=['ml', 'bn'],
                    format_func=lambda x: {
                        'ml': 'MaximumLikelihoodEstimator',
                        'bn': 'BayesianEstimator'
                    }[x],
                    help="Select parameter estimation method"
                )
                if estimator == 'bn':
                    equivalent_sample_size = st.number_input(
                        "Equivalent Sample Size:",
                        min_value=1,
                        value=3,
                        step=1,
                        help="Prior strength for Bayesian estimation"
                    )
                else:
                    equivalent_sample_size = 3
                # Decision subset — only meaningful for the default dataset.
                if data_source == "Use Default Dataset":
                    decision = st.selectbox(
                        "Decision:",
                        options=['OverAll', 'Exposed', 'Unexposed'],
                        index=0,
                        help="Analysis subset selection"
                    )
                else:
                    decision = 'OverAll'

            # --- Provide evidence (optional inference inputs) ----------
            st.markdown("---")
            with st.expander("**Provide Evidence**", expanded=False):
                st.caption("Enter evidence values for inference (optional):")
                evidence_cols = st.columns(2)
                evidence_dict = {}
                # One text input per selected non-target variable.
                all_vars = [v for v in selected_features if v != target_variable]
                for idx, var in enumerate(all_vars):
                    with evidence_cols[idx % 2]:
                        val = st.text_input(
                            f"{var}:",
                            value="",
                            key=f"evidence_{var}",
                            help=f"Enter value for {var} (leave empty to ignore)"
                        )
                        if val.strip():
                            evidence_dict[var] = val.strip()

            # --- Advanced parameters -----------------------------------
            with st.expander("**Advanced Parameters**", expanded=False):
                n_bins = st.slider(
                    "Number of Bins (for continuous variables):",
                    min_value=3,
                    max_value=20,
                    value=10,
                    step=1,
                    help="Number of bins for discretizing continuous features"
                )

            # --- Run / Reset buttons -----------------------------------
            st.markdown("---")
            col_btn1, col_btn2 = st.columns([3, 1])
            with col_btn1:
                run_button = st.button("🚀 Run Analysis", type="primary", use_container_width=True)
            with col_btn2:
                if st.button("🔄 Reset", use_container_width=True):
                    st.session_state.analysis_results = None
                    st.session_state.model_trained = False
                    st.session_state.chat_history = []
                    st.rerun()

            # Analysis-step explanation
            with st.expander("ℹ️ Analysis Steps", expanded=False):
                st.markdown("""
**Process:**
1. Split data (train/test)
2. Learn network structure
3. Process features (bins from train)
4. Estimate parameters
5. Evaluate performance

**Note:** Test set bins are derived from training set to prevent data leakage.
""")

            if run_button:
                # Validate selections before any expensive work.
                if not selected_features:
                    st.error("❌ Please select at least one feature!")
                    st.stop()
                if target_variable in selected_features:
                    st.error("❌ Target variable cannot be in feature list!")
                    st.stop()

                with st.spinner("🔄 Training Bayesian Network..."):
                    progress_bar = st.progress(0)
                    status_text = st.empty()
                    try:
                        status_text.text("📊 Initializing analyzer...")
                        progress_bar.progress(10)
                        analyzer = BayesianNetworkAnalyzer(
                            session_id=st.session_state.session_id
                        )

                        status_text.text(f"📐 Learning {algorithm} structure...")
                        progress_bar.progress(30)

                        # Run the full pipeline.
                        results = analyzer.run_analysis(
                            df=df,
                            cat_features=cat_features,
                            con_features=con_features,
                            target_variable=target_variable,
                            test_fraction=test_fraction,
                            algorithm=algorithm,
                            estimator=estimator,
                            equivalent_sample_size=equivalent_sample_size,
                            score_method=score_method,
                            sig_level=sig_level,
                            n_bins=n_bins
                        )

                        status_text.text("✅ Analysis completed!")
                        progress_bar.progress(100)

                        # Persist results and the analyzer itself (the
                        # assistant tab uses it for personalized predictions).
                        st.session_state.analysis_results = results
                        st.session_state.model_trained = True
                        st.session_state.analyzer = analyzer

                        st.success("✅ Analysis completed successfully!")
                        st.balloons()

                        # Brief pause so the success state is visible
                        # before clearing the progress widgets.
                        import time
                        time.sleep(1)
                        progress_bar.empty()
                        status_text.empty()
                        st.rerun()
                    except Exception as e:
                        st.error(f"❌ Error during analysis: {str(e)}")
                        st.exception(e)
                        progress_bar.empty()
                        status_text.empty()
| with col2: | |
| st.header("Quick Stats") | |
| if df is not None: | |
| st.metric("Total Samples", df.shape[0]) | |
| st.metric("Total Features", df.shape[1]) | |
| st.metric("Selected Features", len(selected_features) if 'selected_features' in locals() else 0) | |
| if st.session_state.model_trained: | |
| st.success("✅ Model Trained") | |
| else: | |
| st.info("⏳ Awaiting Training") | |
| # 顯示結果 | |
| if st.session_state.analysis_results: | |
| st.markdown("---") | |
| st.header("📊 Analysis Results") | |
| results = st.session_state.analysis_results | |
| # 使用 tabs 來組織結果 | |
| result_tabs = st.tabs([ | |
| "🕸️ Network Structure", | |
| "📈 Performance Metrics", | |
| "📋 CPD Tables", | |
| "📊 Model Scores" | |
| ]) | |
| # Tab 1: 網路結構 | |
| with result_tabs[0]: | |
| network_fig = generate_network_graph(results['model']) | |
| st.plotly_chart(network_fig, use_container_width=True) | |
| # 顯示邊的列表 | |
| with st.expander("View Network Edges", expanded=False): | |
| edges = list(results['model'].edges()) | |
| st.write(f"Total edges: {len(edges)}") | |
| # 每行顯示 3 個邊 | |
| for i in range(0, len(edges), 3): | |
| cols = st.columns(3) | |
| for j, col in enumerate(cols): | |
| if i + j < len(edges): | |
| edge = edges[i + j] | |
| col.markdown(f"**{edge[0]}** → {edge[1]}") | |
| # Tab 2: 效能指標 | |
| with result_tabs[1]: | |
| col_m1, col_m2 = st.columns(2) | |
| with col_m1: | |
| st.markdown("### Training Set") | |
| train_metrics = results['train_metrics'] | |
| # 使用 metrics 卡片 | |
| metric_cols = st.columns(4) | |
| metric_cols[0].metric("Accuracy", f"{train_metrics['accuracy']:.2f}%") | |
| metric_cols[1].metric("Precision", f"{train_metrics['precision']:.2f}%") | |
| metric_cols[2].metric("Recall", f"{train_metrics['recall']:.2f}%") | |
| metric_cols[3].metric("F1-Score", f"{train_metrics['f1']:.2f}%") | |
| metric_cols2 = st.columns(4) | |
| metric_cols2[0].metric("AUC", f"{train_metrics['auc']:.4f}") | |
| metric_cols2[1].metric("G-mean", f"{train_metrics['g_mean']:.2f}%") | |
| metric_cols2[2].metric("P-mean", f"{train_metrics['p_mean']:.2f}%") | |
| metric_cols2[3].metric("Specificity", f"{train_metrics['specificity']:.2f}%") | |
| # 混淆矩陣 | |
| with st.expander("Confusion Matrix", expanded=True): | |
| conf_fig_train = plot_confusion_matrix( | |
| train_metrics['confusion_matrix'], | |
| title="Training Set" | |
| ) | |
| st.plotly_chart(conf_fig_train, use_container_width=True) | |
| # ROC Curve | |
| with st.expander("ROC Curve", expanded=False): | |
| roc_fig_train = plot_roc_curve( | |
| train_metrics['fpr'], | |
| train_metrics['tpr'], | |
| train_metrics['auc'], | |
| title="Training Set" | |
| ) | |
| st.plotly_chart(roc_fig_train, use_container_width=True) | |
| with col_m2: | |
| st.markdown("### Test Set") | |
| test_metrics = results['test_metrics'] | |
| metric_cols = st.columns(4) | |
| metric_cols[0].metric("Accuracy", f"{test_metrics['accuracy']:.2f}%") | |
| metric_cols[1].metric("Precision", f"{test_metrics['precision']:.2f}%") | |
| metric_cols[2].metric("Recall", f"{test_metrics['recall']:.2f}%") | |
| metric_cols[3].metric("F1-Score", f"{test_metrics['f1']:.2f}%") | |
| metric_cols2 = st.columns(4) | |
| metric_cols2[0].metric("AUC", f"{test_metrics['auc']:.4f}") | |
| metric_cols2[1].metric("G-mean", f"{test_metrics['g_mean']:.2f}%") | |
| metric_cols2[2].metric("P-mean", f"{test_metrics['p_mean']:.2f}%") | |
| metric_cols2[3].metric("Specificity", f"{test_metrics['specificity']:.2f}%") | |
| # 混淆矩陣 | |
| with st.expander("Confusion Matrix", expanded=True): | |
| conf_fig_test = plot_confusion_matrix( | |
| test_metrics['confusion_matrix'], | |
| title="Test Set" | |
| ) | |
| st.plotly_chart(conf_fig_test, use_container_width=True) | |
| # ROC Curve | |
| with st.expander("ROC Curve", expanded=False): | |
| roc_fig_test = plot_roc_curve( | |
| test_metrics['fpr'], | |
| test_metrics['tpr'], | |
| test_metrics['auc'], | |
| title="Test Set" | |
| ) | |
| st.plotly_chart(roc_fig_test, use_container_width=True) | |
| # Tab 3: 條件機率表 | |
| with result_tabs[2]: | |
| selected_node = st.selectbox( | |
| "Select a node to view its CPD:", | |
| options=list(results['cpds'].keys()) | |
| ) | |
| if selected_node: | |
| cpd_df = create_cpd_table(results['cpds'][selected_node]) | |
| st.dataframe(cpd_df, use_container_width=True) | |
| # 下載按鈕 | |
| csv = cpd_df.to_csv() | |
| st.download_button( | |
| label="📥 Download CPD as CSV", | |
| data=csv, | |
| file_name=f"cpd_{selected_node}.csv", | |
| mime="text/csv" | |
| ) | |
| # Tab 4: 模型評分 | |
| with result_tabs[3]: | |
| scores = results['scores'] | |
| score_cols = st.columns(5) | |
| score_cols[0].metric("Log-Likelihood", f"{scores['log_likelihood']:.2f}") | |
| score_cols[1].metric("BIC Score", f"{scores['bic']:.2f}") | |
| score_cols[2].metric("K2 Score", f"{scores['k2']:.2f}") | |
| score_cols[3].metric("BDeu Score", f"{scores['bdeu']:.2f}") | |
| score_cols[4].metric("BDs Score", f"{scores['bds']:.2f}") | |
| # 參數摘要 | |
| with st.expander("Analysis Parameters", expanded=True): | |
| params = results['parameters'] | |
| col1, col2, col3 = st.columns(3) | |
| with col1: | |
| st.markdown("**Algorithm Settings**") | |
| st.write(f"- Algorithm: {params['algorithm']}") | |
| st.write(f"- Estimator: {params['estimator']}") | |
| st.write(f"- Test Fraction: {params['test_fraction']:.2%}") | |
| with col2: | |
| st.markdown("**Feature Information**") | |
| st.write(f"- Total Features: {params['n_features']}") | |
| st.write(f"- Categorical: {len(params['cat_features'])}") | |
| st.write(f"- Continuous: {len(params['con_features'])}") | |
| st.write(f"- Target: {params['target_variable']}") | |
| with col3: | |
| st.markdown("**Other Parameters**") | |
| st.write(f"- Bins: {params['n_bins']}") | |
| st.write(f"- Score Method: {params['score_method']}") | |
| st.write(f"- Significance Level: {params['sig_level']}") | |
| st.write(f"- Equivalent Sample Size: {params['equivalent_sample_size']}") | |
| # 匯出結果 | |
| with st.expander("Export Results", expanded=False): | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| # 原本的 JSON 下載 | |
| result_json = export_results_to_json(results) | |
| st.download_button( | |
| label="📥 Download Full Results (JSON)", | |
| data=result_json, | |
| file_name=f"bn_analysis_{results['timestamp'][:10]}.json", | |
| mime="application/json" | |
| ) | |
| with col2: | |
| # 🆕 新增:下載模型 | |
| if st.button("💾 Save Trained Model"): | |
| if 'analyzer' in st.session_state: | |
| import tempfile | |
| import os | |
| # 創建臨時文件 | |
| with tempfile.NamedTemporaryFile(delete=False, suffix='.pkl') as tmp_file: | |
| model_path = tmp_file.name | |
| st.session_state.analyzer.save_model(model_path) | |
| # 讀取並提供下載 | |
| with open(model_path, 'rb') as f: | |
| st.download_button( | |
| label="📥 Download Model File (.pkl)", | |
| data=f, | |
| file_name=f"bn_model_{results['timestamp'][:10]}.pkl", | |
| mime="application/octet-stream", | |
| key="download_model_btn" | |
| ) | |
| # 清理臨時文件 | |
| os.unlink(model_path) | |
| else: | |
| st.error("❌ Analyzer not found in session state") | |
# Tab 2: AI assistant — chat interface over the analysis results, with
# keyword-triggered personalized prediction support.
with tab2:
    st.header("💬 AI Analysis Assistant")

    if not st.session_state.get('api_key'):
        st.warning("⚠️ Please enter your OpenAI API Key in the sidebar to use the AI assistant.")
    elif not st.session_state.model_trained:
        st.info("ℹ️ Please train a model first in the Analysis tab to use the AI assistant.")
    else:
        # Lazily create the LLM assistant once per session.
        if 'llm_assistant' not in st.session_state:
            st.session_state.llm_assistant = LLMAssistant(
                api_key=st.session_state.api_key,
                session_id=st.session_state.session_id
            )

        # Render chat history
        chat_container = st.container()
        with chat_container:
            for message in st.session_state.chat_history:
                with st.chat_message(message["role"]):
                    st.markdown(message["content"])

        # Chat input
        if prompt := st.chat_input("Ask me anything about your analysis results..."):
            # Append the user message
            st.session_state.chat_history.append({
                "role": "user",
                "content": prompt
            })
            with st.chat_message("user"):
                st.markdown(prompt)

            # Heuristic: route prediction-style questions to the
            # personalized-prediction path instead of plain chat.
            prediction_keywords = ['predict', 'risk', 'patient', 'case', 'my risk', 'calculate', 'probability', 'chance']
            is_prediction_request = any(keyword in prompt.lower() for keyword in prediction_keywords)

            # Get the AI response
            with st.chat_message("assistant"):
                with st.spinner("Analyzing..." if is_prediction_request else "Thinking..."):
                    try:
                        if is_prediction_request:
                            # Personalized prediction needs the trained
                            # analyzer stored in session_state.
                            results = st.session_state.analysis_results
                            if 'analyzer' not in st.session_state:
                                st.error("❌ Model not found. Please train a model first in the Analysis tab.")
                                response = "I cannot perform predictions because the model is not available. Please train a model first."
                            else:
                                response = st.session_state.llm_assistant.predict_from_text(
                                    user_description=prompt,
                                    analyzer=st.session_state.analyzer,
                                    target_variable=results['parameters']['target_variable'],
                                    feature_list=results['parameters']['cat_features'] + results['parameters']['con_features']
                                )
                        else:
                            # Plain conversational response
                            response = st.session_state.llm_assistant.get_response(
                                user_message=prompt,
                                analysis_results=st.session_state.analysis_results
                            )
                        st.markdown(response)
                    except Exception as e:
                        error_msg = f"❌ Error: {str(e)}\n\nPlease try rephrasing your question or check the model status."
                        st.error(error_msg)
                        response = error_msg

            # Append the assistant message
            st.session_state.chat_history.append({
                "role": "assistant",
                "content": response
            })

        # Quick-question shortcut buttons
        st.markdown("---")
        st.subheader("💡 Quick Questions")
        quick_questions = [
            "📊 Give me a summary of the analysis results",
            "🎯 What is the model's performance?",
            "🔍 Explain the Bayesian Network structure",
            "⚠️ What are the limitations of this model?",
            "💡 How can I improve the model?"
        ]
        cols = st.columns(len(quick_questions))
        for idx, (col, question) in enumerate(zip(cols, quick_questions)):
            if col.button(question, key=f"quick_{idx}"):
                st.session_state.chat_history.append({
                    "role": "user",
                    "content": question
                })
                # Catch API errors here too — previously an exception in
                # this path crashed the whole script run, unlike the
                # chat-input path above which reports the error in-line.
                try:
                    response = st.session_state.llm_assistant.get_response(
                        user_message=question,
                        analysis_results=st.session_state.analysis_results
                    )
                except Exception as e:
                    response = f"❌ Error: {str(e)}\n\nPlease try rephrasing your question or check the model status."
                st.session_state.chat_history.append({
                    "role": "assistant",
                    "content": response
                })
                st.rerun()
# Footer — branding plus a truncated session id for support/debugging.
st.markdown("---")
st.markdown(
    f"""
    <div style='text-align: center'>
        <p>🔬 Bayesian Network Analysis System | Built with Streamlit</p>
        <p>Powered by OpenAI GPT-4 | Session ID: {st.session_state.session_id[:8]}</p>
    </div>
    """,
    unsafe_allow_html=True,
)