Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import uuid | |
| from datetime import datetime, timedelta | |
| import atexit | |
| import os | |
| import sys | |
# Page configuration: browser-tab title and icon, wide layout, sidebar open on load.
_PAGE_SETTINGS = {
    "page_title": "Bayesian Hierarchical Model - Pokémon Speed Analysis",
    "page_icon": "🎲",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_SETTINGS)
# Custom CSS: restyles expander headers, metric cards and buttons, and defines
# the .success-box / .warning-box helper classes.
_CUSTOM_CSS = """
<style>
.streamlit-expanderHeader {
background-color: #e8f1f8;
border: 1px solid #b0cfe8;
border-radius: 5px;
font-weight: 600;
color: #1b4f72;
}
.streamlit-expanderHeader:hover {
background-color: #d0e7f8;
}
.stMetric {
background-color: #f8fbff;
padding: 10px;
border-radius: 5px;
border: 1px solid #d0e4f5;
}
.stButton > button {
width: 100%;
border-radius: 20px;
font-weight: 600;
transition: all 0.3s ease;
}
.stButton > button:hover {
transform: translateY(-2px);
box-shadow: 0 4px 8px rgba(0,0,0,0.2);
}
.success-box {
background-color: #d4edda;
border: 1px solid #c3e6cb;
border-radius: 5px;
padding: 10px;
margin: 10px 0;
}
.warning-box {
background-color: #fff3cd;
border: 1px solid #ffeaa7;
border-radius: 5px;
padding: 10px;
margin: 10px 0;
}
</style>
"""

# unsafe_allow_html is required so the <style> element is emitted as HTML.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
| # 導入自定義模組 | |
| from bayesian_core import BayesianHierarchicalAnalyzer | |
| # 注意:如果要啟用 DAG 動態生成功能,請將下行改為: | |
| # from bayesian_llm_assistant_enhanced import BayesianLLMAssistant | |
| from bayesian_llm_assistant import BayesianLLMAssistant | |
| from bayesian_utils import ( | |
| plot_trace, | |
| plot_posterior, | |
| plot_forest, | |
| plot_model_dag, | |
| create_summary_table, | |
| create_trial_results_table, | |
| export_results_to_text, | |
| plot_odds_ratio_comparison | |
| ) | |
# Session cleanup helper
def cleanup_old_sessions():
    """Discard cached analysis results that are more than one hour old.

    Walks a snapshot of the keys of
    ``BayesianHierarchicalAnalyzer._session_results`` and, for every non-empty
    entry whose ISO-format ``'timestamp'`` lies more than one hour before
    ``datetime.now()``, calls
    ``BayesianHierarchicalAnalyzer.clear_session_results(session_id)``.
    """
    max_age = timedelta(hours=1)
    now = datetime.now()
    # Iterate over a snapshot of the keys; entries are looked up again on each
    # pass, so one that disappeared in the meantime is simply skipped.
    for session_id in tuple(BayesianHierarchicalAnalyzer._session_results.keys()):
        entry = BayesianHierarchicalAnalyzer._session_results.get(session_id)
        if not entry:
            continue
        age = now - datetime.fromisoformat(entry['timestamp'])
        if age > max_age:
            BayesianHierarchicalAnalyzer.clear_session_results(session_id)
# Register the cleanup hook with atexit once per server process.
# Streamlit re-executes this script on every rerun, so a bare
# atexit.register() call here would append one more handler each time.
@st.cache_resource(show_spinner=False)
def _register_cleanup_once():
    """Register ``cleanup_old_sessions`` as an interpreter-exit handler.

    Wrapped in ``st.cache_resource`` so the body runs only on the first call
    in the process; later reruns hit the cache and register nothing.
    """
    atexit.register(cleanup_old_sessions)


_register_cleanup_once()
# Seed st.session_state: each key is created only when absent, so values set
# by earlier runs of the script are left untouched.
_SESSION_DEFAULT_FACTORIES = {
    'session_id': lambda: str(uuid.uuid4()),
    'analysis_results': lambda: None,
    'chat_history': list,
    'analyzer': lambda: None,
    'trace_img': lambda: None,
    'posterior_img': lambda: None,
    'forest_img': lambda: None,
    'dag_img': lambda: None,
}
# Factories (not plain values) keep the UUID lazy and give each session its
# own chat-history list.
for _state_key, _make_default in _SESSION_DEFAULT_FACTORIES.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# Page heading: title, subtitle and a horizontal rule.
st.title("🎲 Bayesian Hierarchical Model Analysis")
for _heading_md in ("### 貝氏統合分析模型", "---"):
    st.markdown(_heading_md)
# Sidebar: LLM credentials, MCMC settings, cache cleanup, data source, about box.
_DATA_FORMAT_HELP = """
**必要欄位格式:**
- `Trial_Type`: 配對名稱(例如:Pair_1, Pair_2)
- `rt`: 火系(治療組)的勝場數
- `nt`: 火系的總場數
- `rc`: 水系(對照組)的勝場數
- `nc`: 水系的總場數
**範例:**
```
Trial_Type,rt,nt,rc,nc
Pair_1,122,133,22,145
Pair_2,85,132,17,135
Pair_3,52,129,41,134
```
"""

_ABOUT_SYSTEM_TEXT = """
**貝氏階層模型分析系統**
本系統使用貝氏階層模型來分析速度對寶可夢勝率的影響,
並考慮不同屬性之間的異質性。
**主要功能:**
- 🎲 貝氏推論與後驗分佈
- 📊 階層模型(借用資訊)
- 📈 4 種視覺化圖表
- 💬 AI 助手解釋
- 🎮 屬性對抗策略建議
**適用場景:**
- 分析火系對水系的配對勝率
- 理解不同配對間的異質性
- 評估屬性優劣勢
"""

with st.sidebar:
    st.header("⚙️ 配置設定")

    # LLM provider selection
    api_choice = st.radio(
        "選擇 LLM API",
        options=["Google Gemini", "Anthropic Claude"],
        index=0,
        help="選擇要使用的 AI 助手",
    )

    # API key input: label and help text follow the selected provider.
    if api_choice == "Google Gemini":
        _key_label = "Google Gemini API Key"
        _key_help = "輸入您的 Google Gemini API Key"
    else:  # Anthropic Claude
        _key_label = "Anthropic Claude API Key"
        _key_help = "輸入您的 Anthropic API Key (https://console.anthropic.com)"
    api_key = st.text_input(_key_label, type="password", help=_key_help)

    # Persist the key together with the provider it belongs to.
    if api_key:
        st.session_state.api_key = api_key
        st.session_state.api_choice = api_choice
        st.success(f"✅ {api_choice} API Key 已載入")

    st.markdown("---")

    # MCMC sampler settings
    st.subheader("🔬 蒙地卡羅馬可夫鏈 參數")
    n_samples = st.number_input(
        "抽樣數 (Samples)",
        min_value=500,
        max_value=10000,
        value=2000,
        step=500,
        help="每條鏈的抽樣數量",
    )
    n_tune = st.number_input(
        "Burn-in",
        min_value=200,
        max_value=5000,
        value=1000,
        step=200,
        help="Burn-in 的樣本數",
    )
    n_chains = st.selectbox(
        "鏈數 (Chains)",
        options=[1, 2, 4],
        index=1,
        help="平行運行的鏈數",
    )
    target_accept = st.slider(
        "目標接受率",
        min_value=0.80,
        max_value=0.99,
        value=0.95,
        step=0.01,
        help="NUTS 採樣器的目標接受率",
    )

    st.markdown("---")

    # Manual cleanup of expired cached results
    if st.button("🧹 清理過期資料"):
        cleanup_old_sessions()
        st.success("✅ 清理完成")
        st.rerun()

    st.markdown("---")

    # Data source selection
    st.subheader("📊 資料來源")
    data_source = st.radio(
        "選擇資料來源:",
        ["使用預設資料集", "上傳您的資料"],
    )

    uploaded_file = None
    if data_source == "上傳您的資料":
        uploaded_file = st.file_uploader(
            "上傳 CSV 檔案",
            type=['csv'],
            help="上傳寶可夢速度對戰資料",
        )
        with st.expander("📖 資料格式說明"):
            st.markdown(_DATA_FORMAT_HELP)

    st.markdown("---")

    # About this system
    with st.expander("ℹ️ 關於此系統"):
        st.markdown(_ABOUT_SYSTEM_TEXT)
# Main content area: two tabs (Bayesian analysis / AI assistant).
tab1, tab2 = st.tabs(["📊 貝氏分析", "💬 AI 助手"])

_DEFAULT_DATA_PATH = "fire_water_converted.csv"

_TRACE_HELP = """
**Trace Plot 用途**:
- 檢查 MCMC 抽樣是否收斂
- 左圖:後驗分佈密度
- 右圖:抽樣軌跡
"""

_POSTERIOR_HELP = """
**Posterior Plot 用途**:
- 顯示參數的後驗分佈
- 包含 95% HDI(最高密度區間)
- 顯示平均值
"""

_FOREST_HELP = """
**Forest Plot 用途**:
- 顯示每個屬性的速度效應(delta)
- 點:平均效應
- 線:95% HDI
- ★ 標記:顯著正效應(HDI 不包含 0)
- ☆ 標記:顯著負效應
"""

_DAG_HELP = """
**DAG(有向無環圖)用途**:
- 視覺化模型的階層結構
- 顯示變數之間的依賴關係
- 圓形/橢圓:隨機變數
- 矩形:觀測資料
- 菱形:推導變數
"""

_GRAPHVIZ_INSTALL_HELP = """
**安裝 Graphviz:**
- Windows: `choco install graphviz`
- Mac: `brew install graphviz`
- Ubuntu: `sudo apt-get install graphviz`
"""

_INTERPRETATION_TEMPLATE = """
**整體效應**: {overall_effect}
**顯著性**: {overall_significance}
**效果大小**: {effect_size}
**異質性**: {heterogeneity}
"""


def _read_csv_or_report(source):
    """Parse *source* with ``pd.read_csv``; on failure show an error and return None.

    Args:
        source: A file path or file-like object accepted by ``pd.read_csv``.

    Returns:
        The parsed DataFrame, or ``None`` when the file is empty, malformed,
        not decodable as text, or unreadable.
    """
    try:
        return pd.read_csv(source)
    except (pd.errors.EmptyDataError, pd.errors.ParserError,
            UnicodeDecodeError, OSError) as exc:
        # Report the problem in the page instead of crashing the script run.
        st.error(f"❌ 無法讀取 CSV 檔案: {exc}")
        return None


def _load_dataset():
    """Return the DataFrame for the data source chosen in the sidebar, or None.

    Reads the module-level ``data_source`` and ``uploaded_file`` set by the
    sidebar and shows a status message for every outcome.
    """
    if data_source == "使用預設資料集":
        # The bundled dataset may be absent from the working directory.
        if not os.path.exists(_DEFAULT_DATA_PATH):
            st.warning("⚠️ 找不到預設資料集,請上傳您的資料")
            return None
        frame = _read_csv_or_report(_DEFAULT_DATA_PATH)
        if frame is not None:
            st.success(f"✅ 已載入預設資料集({len(frame)} 組配對)")
        return frame

    if uploaded_file is None:
        st.info("📁 請在左側上傳 CSV 檔案")
        return None
    frame = _read_csv_or_report(uploaded_file)
    if frame is not None:
        st.success(f"✅ 已載入資料({len(frame)} 組配對)")
    return frame


def _run_analysis(frame):
    """Fit the model on *frame* and cache results and plots in session state.

    Uses the sampler settings chosen in the sidebar (``n_samples``, ``n_tune``,
    ``n_chains``, ``target_accept``). Any exception is reported with
    ``st.error`` rather than propagated, so the page keeps rendering.
    """
    with st.spinner(f"正在執行貝氏分析... (抽樣 {n_samples} × {n_chains} 條鏈)"):
        try:
            # The analyzer is created once per session and then reused.
            if st.session_state.analyzer is None:
                st.session_state.analyzer = BayesianHierarchicalAnalyzer(
                    st.session_state.session_id
                )
            analyzer = st.session_state.analyzer
            analyzer.load_data(frame)
            fitted = analyzer.run_analysis(
                n_samples=n_samples,
                n_tune=n_tune,
                n_chains=n_chains,
                target_accept=target_accept,
            )
            st.session_state.analysis_results = fitted

            # Render the plots once and keep them for later reruns.
            with st.spinner("生成視覺化圖表..."):
                st.session_state.trace_img = plot_trace(analyzer.trace)
                st.session_state.posterior_img = plot_posterior(analyzer.trace)
                st.session_state.forest_img = plot_forest(
                    analyzer.trace,
                    fitted['trial_labels'],
                )
                st.session_state.dag_img = plot_model_dag(analyzer)

            st.success("✅ 分析完成!")
            st.balloons()
        except Exception as e:
            # Top-level UI boundary: surface the failure to the user.
            st.error(f"❌ 分析失敗: {str(e)}")


def _render_overview(results):
    """Render the overview sub-tab: headline metrics, interpretation,
    convergence diagnostics, summary tables and the odds-ratio chart."""
    st.markdown("### 🎯 整體效應摘要")
    overall = results['overall']
    interp = results['interpretation']

    # Headline metrics: (label, key prefix in `overall`, format of the mean).
    headline = (
        ("d (整體效應)", "d", ".4f"),
        ("勝算比 (OR)", "or", ".3f"),
        ("sigma (異質性)", "sigma", ".4f"),
    )
    for column, (label, prefix, mean_spec) in zip(st.columns(3), headline):
        with column:
            mean_text = format(overall[f"{prefix}_mean"], mean_spec)
            low = overall[f"{prefix}_hdi_low"]
            high = overall[f"{prefix}_hdi_high"]
            st.metric(label, mean_text, delta=f"HDI: [{low:.3f}, {high:.3f}]")

    st.markdown("---")

    # Textual interpretation
    st.markdown("### 📖 結果解釋")
    st.info(_INTERPRETATION_TEMPLATE.format(
        overall_effect=interp['overall_effect'],
        overall_significance=interp['overall_significance'],
        effect_size=interp['effect_size'],
        heterogeneity=interp['heterogeneity'],
    ))

    st.markdown("---")

    # Convergence diagnostics
    st.markdown("### 🔍 模型收斂診斷")
    diag = results['diagnostics']
    rhat_col, ess_col = st.columns(2)
    with rhat_col:
        st.markdown("**R-hat 診斷** (應 < 1.1):")
        for param in ("d", "sigma"):
            rhat = diag[f"rhat_{param}"]
            if rhat:  # falsy values are not displayed
                st.metric(
                    f"R-hat ({param})",
                    f"{rhat:.4f}",
                    delta="✓ 良好" if rhat < 1.1 else "✗ 需改善",
                )
    with ess_col:
        st.markdown("**有效樣本數 (ESS)**:")
        for param in ("d", "sigma"):
            ess = diag[f"ess_{param}"]
            if ess:
                st.metric(f"ESS ({param})", f"{int(ess)}")

    if diag['converged']:
        st.success("✅ 模型已收斂,結果可信")
    else:
        st.warning("⚠️ 模型可能未完全收斂,建議增加抽樣數或鏈數")

    st.markdown("---")

    # Summary table
    st.markdown("### 📊 統計摘要表")
    st.dataframe(create_summary_table(results), use_container_width=True)

    st.markdown("---")

    # Per-trial results
    st.markdown("### 🎮 各屬性詳細結果")
    st.dataframe(create_trial_results_table(results), use_container_width=True)

    st.markdown("---")

    # Odds-ratio comparison chart
    st.markdown("### 📊 各屬性速度效應比較")
    st.plotly_chart(plot_odds_ratio_comparison(results), use_container_width=True)


def _render_plot_section(heading, help_text, image, missing_hint):
    """Render a heading, its explanation, and either the cached image or a hint."""
    st.markdown(heading)
    st.markdown(help_text)
    if image:
        # NOTE(review): newer Streamlit releases deprecate use_column_width for
        # st.image in favour of use_container_width -- confirm the pinned version.
        st.image(image, use_column_width=True)
    else:
        st.info(missing_hint)


def _render_dag_section():
    """Render the model-structure sub-tab, with install hints when no image exists."""
    st.markdown("### 🔍 模型結構圖 (DAG)")
    st.markdown(_DAG_HELP)
    if st.session_state.dag_img:
        st.image(st.session_state.dag_img, use_column_width=True)
    else:
        st.warning("⚠️ 無法生成 DAG 圖(可能需要安裝 Graphviz)")
        st.markdown(_GRAPHVIZ_INSTALL_HELP)


def _render_report(results):
    """Render the full text report together with a download button."""
    st.markdown("### 📋 完整分析報告")
    text_report = export_results_to_text(results)
    st.text_area(
        "報告內容",
        text_report,
        height=500,
    )
    st.download_button(
        label="📥 下載完整報告 (.txt)",
        data=text_report,
        # The first 10 characters of the timestamp string go into the file name.
        file_name=f"bayesian_report_{results['timestamp'][:10]}.txt",
        mime="text/plain",
    )


def _render_results(results):
    """Render the five result sub-tabs for a finished analysis."""
    st.markdown("---")
    st.subheader("📊 分析結果")

    # Five sub-pages
    result_tabs = st.tabs([
        "📊 概覽",
        "📈 Trace & Posterior",
        "🌲 Forest Plot",
        "🔍 DAG 模型圖",
        "📋 詳細報告",
    ])

    with result_tabs[0]:
        _render_overview(results)

    with result_tabs[1]:
        _render_plot_section(
            "### 📈 Trace Plot(抽樣軌跡)",
            _TRACE_HELP,
            st.session_state.trace_img,
            "請先執行分析以生成 Trace Plot",
        )
        st.markdown("---")
        _render_plot_section(
            "### 📊 Posterior Plot(後驗分佈)",
            _POSTERIOR_HELP,
            st.session_state.posterior_img,
            "請先執行分析以生成 Posterior Plot",
        )

    with result_tabs[2]:
        _render_plot_section(
            "### 🌲 Forest Plot(各屬性效應)",
            _FOREST_HELP,
            st.session_state.forest_img,
            "請先執行分析以生成 Forest Plot",
        )

    with result_tabs[3]:
        _render_dag_section()

    with result_tabs[4]:
        _render_report(results)


# Tab 1: Bayesian analysis
with tab1:
    st.header("📊 貝氏統合分析模型")

    # Load the data
    df = _load_dataset()

    if df is not None:
        # Data preview
        with st.expander("👀 資料預覽"):
            st.dataframe(df, use_container_width=True)

        st.markdown("---")

        # Analysis button, centred in the middle column
        _, _center_col, _ = st.columns([1, 2, 1])
        with _center_col:
            analyze_button = st.button(
                "🔬 開始貝氏分析",
                type="primary",
                use_container_width=True,
            )

        # Run the analysis
        if analyze_button:
            _run_analysis(df)

        # Show the results
        if st.session_state.analysis_results is not None:
            results = st.session_state.analysis_results
            _render_results(results)
# Tab 2: AI assistant
_DAG_REQUEST_PROMPT = "請畫一個貝氏階層模型的 DAG 圖,並用繁體中文解釋每個節點的意義"


def _ensure_llm_assistant():
    """Return the session's LLM assistant, rebuilding it when credentials change.

    The assistant is cached in ``st.session_state.llm_assistant`` together with
    the (API key, provider) pair it was built from. If the user later enters a
    different key or switches provider, a new assistant is created instead of
    silently reusing the old one.
    """
    provider = st.session_state.get('api_choice', 'Google Gemini')
    wanted_config = (st.session_state.api_key, provider)
    is_stale = st.session_state.get('llm_assistant_config') != wanted_config
    if 'llm_assistant' not in st.session_state or is_stale:
        st.session_state.llm_assistant = BayesianLLMAssistant(
            api_key=st.session_state.api_key,
            session_id=st.session_state.session_id,
            api_provider=provider,
        )
        st.session_state.llm_assistant_config = wanted_config
    return st.session_state.llm_assistant


def _render_chat_history():
    """Replay every stored chat message, including any saved DAG image."""
    for message in st.session_state.chat_history:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
            # Show the stored image when the message carries a DAG.
            if message.get("has_dag", False) and message.get("dag_image") is not None:
                st.image(message["dag_image"], caption="🎨 生成的 DAG 圖", use_column_width=True)


def _record_assistant_reply(response, dag_image):
    """Append an assistant message (and its optional DAG image) to the history."""
    st.session_state.chat_history.append({
        "role": "assistant",
        "content": response,
        "has_dag": dag_image is not None,
        "dag_image": dag_image,
    })


def _handle_chat_input(assistant):
    """Process free-text input: show it, query the assistant, store both turns."""
    if prompt := st.chat_input("詢問關於分析結果的任何問題..."):
        # Record and echo the user's message.
        st.session_state.chat_history.append({
            "role": "user",
            "content": prompt,
        })
        with st.chat_message("user"):
            st.markdown(prompt)

        # Assistant reply
        with st.chat_message("assistant"):
            with st.spinner("思考中..."):
                try:
                    # get_response returns the reply text and an optional DAG image.
                    response, dag_image = assistant.get_response(
                        user_message=prompt,
                        analysis_results=st.session_state.analysis_results,
                    )
                    st.markdown(response)
                    if dag_image is not None:
                        st.image(dag_image, caption="🎨 AI 生成的 DAG 圖", use_column_width=True)
                        st.success("✨ DAG 圖已生成!你可以繼續詢問圖表相關問題。")
                except Exception as e:
                    # UI boundary: show the failure and keep it in the transcript.
                    error_msg = f"❌ 錯誤: {str(e)}\n\n請檢查 API key 或重新表達問題。"
                    st.error(error_msg)
                    response = error_msg
                    dag_image = None

        _record_assistant_reply(response, dag_image)


def _render_quick_questions(assistant):
    """Render the quick-question buttons and answer whichever one is clicked.

    Each button is bound to its handler explicitly. Every handler returns a
    ``(response, dag_image)`` pair; text-only helpers supply ``None`` for the
    image. A failing handler is recorded as an error reply, matching the
    free-text chat path, instead of aborting the script run.
    """
    results = st.session_state.analysis_results
    quick_actions = [
        ("📊 給我這次分析的總結",
         lambda: (assistant.generate_summary(results), None)),
        ("🎯 解釋 d 和勝算比",
         lambda: (assistant.explain_metric('d', results), None)),
        ("🔍 解釋 sigma(異質性)",
         lambda: (assistant.explain_metric('sigma', results), None)),
        ("❓ 什麼是階層模型?",
         lambda: (assistant.explain_hierarchical_model(), None)),
        ("🎨 畫一個模型結構圖",  # DAG generation request
         lambda: assistant.get_response(_DAG_REQUEST_PROMPT, results)),
        ("🆚 貝氏 vs 頻率論",
         lambda: (assistant.explain_bayesian_vs_frequentist(), None)),
        ("⚔️ 對戰策略建議",
         lambda: (assistant.battle_strategy_advice(results), None)),
        ("🎮 比較不同屬性",
         lambda: (assistant.compare_types(results), None)),
    ]

    cols = st.columns(4)
    for idx, (question, handler) in enumerate(quick_actions):
        if cols[idx % 4].button(question, key=f"quick_{idx}"):
            try:
                response, dag_image = handler()
            except Exception as e:
                response = f"❌ 錯誤: {str(e)}\n\n請檢查 API key 或重新表達問題。"
                dag_image = None

            # Store both turns, then rerun so the transcript above shows them.
            st.session_state.chat_history.append({
                "role": "user",
                "content": question,
            })
            _record_assistant_reply(response, dag_image)
            st.rerun()


def _render_assistant_tab():
    """Render the chat transcript, input box, quick questions and reset button."""
    assistant = _ensure_llm_assistant()

    # Chat transcript
    with st.container():
        _render_chat_history()

    # User input
    _handle_chat_input(assistant)

    st.markdown("---")

    # Quick-question buttons
    st.subheader("💡 快速問題")
    st.info("💡 提示:你可以要求助手「畫一個 DAG 圖」來視覺化模型結構!")
    _render_quick_questions(assistant)

    # Reset-conversation button
    st.markdown("---")
    if st.button("🔄 重置對話"):
        assistant.reset_conversation()
        st.session_state.chat_history = []
        st.success("✅ 對話已重置")
        st.rerun()


with tab2:
    st.header("💬 AI 分析助手")

    if not st.session_state.get('api_key'):
        # Name the provider currently selected in the sidebar.
        st.warning(f"⚠️ 請在左側輸入您的 {api_choice} API Key 以使用 AI 助手")
    elif st.session_state.analysis_results is None:
        st.info("ℹ️ 請先在「貝氏分析」頁面執行分析")
    else:
        _render_assistant_tab()
# Footer: centred credits line plus the first 8 characters of the session id.
_session_tag = st.session_state.session_id[:8]
_footer_lines = [
    "",
    "<div style='text-align: center'>",
    "<p>🎲 Bayesian Hierarchical Model Analysis for Pokémon Speed | Built with Streamlit & PyMC</p>",
    f"<p>Session ID: {_session_tag} | Powered by Google Gemini 2.0 Flash</p>",
    "</div>",
    "",
]
st.markdown("---")
st.markdown("\n".join(_footer_lines), unsafe_allow_html=True)