# BayesianPyMc / app_bayesian.py
# Uploaded by Wen1201 — "Upload 9 files" (commit 0ba59a3, verified)
import streamlit as st
import pandas as pd
import uuid
from datetime import datetime, timedelta
import atexit
import os
import sys
# Page configuration. This is kept as the first st.* call in the script,
# ahead of all rendering code below.
_PAGE_CONFIG = {
    "page_title": "Bayesian Hierarchical Model - Pokémon Speed Analysis",
    "page_icon": "🎲",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_CONFIG)
# Custom CSS injected as raw HTML on every run. It styles expander headers,
# metric cards and buttons (full width, rounded, lift on hover), and defines
# the .success-box / .warning-box helper classes.
_CUSTOM_CSS = """
<style>
.streamlit-expanderHeader {
background-color: #e8f1f8;
border: 1px solid #b0cfe8;
border-radius: 5px;
font-weight: 600;
color: #1b4f72;
}
.streamlit-expanderHeader:hover {
background-color: #d0e7f8;
}
.stMetric {
background-color: #f8fbff;
padding: 10px;
border-radius: 5px;
border: 1px solid #d0e4f5;
}
.stButton > button {
width: 100%;
border-radius: 20px;
font-weight: 600;
transition: all 0.3s ease;
}
.stButton > button:hover {
transform: translateY(-2px);
box-shadow: 0 4px 8px rgba(0,0,0,0.2);
}
.success-box {
background-color: #d4edda;
border: 1px solid #c3e6cb;
border-radius: 5px;
padding: 10px;
margin: 10px 0;
}
.warning-box {
background-color: #fff3cd;
border: 1px solid #ffeaa7;
border-radius: 5px;
padding: 10px;
margin: 10px 0;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# 導入自定義模組
from bayesian_core import BayesianHierarchicalAnalyzer
# 注意:如果要啟用 DAG 動態生成功能,請將下行改為:
# from bayesian_llm_assistant_enhanced import BayesianLLMAssistant
from bayesian_llm_assistant import BayesianLLMAssistant
from bayesian_utils import (
plot_trace,
plot_posterior,
plot_forest,
plot_model_dag,
create_summary_table,
create_trial_results_table,
export_results_to_text,
plot_odds_ratio_comparison
)
# Cleanup of stale per-session analysis results.
def cleanup_old_sessions(max_age=timedelta(hours=1)):
    """Clear stored analysis results whose timestamp is older than ``max_age``.

    Walks ``BayesianHierarchicalAnalyzer._session_results`` (keyed by session
    id; each value is expected to carry an ISO-8601 ``'timestamp'``) and calls
    ``BayesianHierarchicalAnalyzer.clear_session_results`` for every expired
    entry. Entries whose timestamp is missing, malformed, or not comparable with
    the naive local ``datetime.now()`` are skipped rather than aborting the
    whole sweep.

    Args:
        max_age: Age threshold as a ``timedelta``; defaults to one hour.

    Returns:
        The number of sessions that were cleared.
    """
    now = datetime.now()
    removed = 0
    # Snapshot the entries so clearing sessions cannot disturb the iteration.
    for session_id, result in list(BayesianHierarchicalAnalyzer._session_results.items()):
        if not result:
            continue
        try:
            result_time = datetime.fromisoformat(result['timestamp'])
            expired = now - result_time > max_age
        except (KeyError, TypeError, ValueError):
            # No/invalid timestamp, or a tz-aware value that cannot be
            # subtracted from the naive "now": leave this entry untouched.
            continue
        if expired:
            BayesianHierarchicalAnalyzer.clear_session_results(session_id)
            removed += 1
    return removed


@st.cache_resource(show_spinner=False)
def _register_session_cleanup():
    """Register cleanup_old_sessions with atexit once per server process.

    Streamlit re-executes this script on every rerun, so a bare
    ``atexit.register(...)`` here would pile up another handler (and keep
    another copy of the script namespace alive) on each interaction.
    ``st.cache_resource`` makes the body run only once per process.
    """
    atexit.register(cleanup_old_sessions)
    return True


_register_session_cleanup()
# Seed per-browser-session state. Each key is written only when it is absent,
# so values already stored in st.session_state are left untouched on reruns.
if 'session_id' not in st.session_state:
    st.session_state.session_id = str(uuid.uuid4())

_SESSION_DEFAULTS = {
    'analysis_results': None,
    'chat_history': [],
    'analyzer': None,
    'trace_img': None,
    'posterior_img': None,
    'forest_img': None,
    'dag_img': None,
}
for _state_key, _default_value in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default_value
# Page header: title, subtitle and a horizontal rule.
_SUBTITLE = "### 火系 vs 水系寶可夢配對勝率的貝氏階層分析"
st.title("🎲 Bayesian Hierarchical Model Analysis")
st.markdown(_SUBTITLE)
st.markdown("---")
# Sidebar: LLM provider / API key, MCMC sampler settings, a maintenance
# button, the data-source selection and an "about" panel. The values bound
# here (api_choice, n_samples, n_tune, n_chains, target_accept, data_source,
# uploaded_file) are read by the main tabs further down the script.
with st.sidebar:
    st.header("⚙️ 配置設定")

    # LLM provider selection.
    api_choice = st.radio(
        "選擇 LLM API",
        options=["Google Gemini", "Anthropic Claude"],
        index=0,
        help="選擇要使用的 AI 助手"
    )

    # API key input; label and help text depend on the chosen provider.
    if api_choice == "Google Gemini":
        api_key = st.text_input(
            "Google Gemini API Key",
            type="password",
            help="輸入您的 Google Gemini API Key"
        )
    else:  # Claude
        api_key = st.text_input(
            "Anthropic Claude API Key",
            type="password",
            help="輸入您的 Anthropic API Key (https://console.anthropic.com)"
        )
    # Only a non-empty key overwrites the stored key/provider pair, so a key
    # entered earlier stays in session state while the box is empty.
    if api_key:
        st.session_state.api_key = api_key
        st.session_state.api_choice = api_choice  # provider the stored key belongs to
        st.success(f"✅ {api_choice} API Key 已載入")
    st.markdown("---")

    # MCMC sampler settings, passed to run_analysis() in the analysis tab.
    st.subheader("🔬 MCMC 參數")
    n_samples = st.number_input(
        "抽樣數 (Samples)",
        min_value=500,
        max_value=10000,
        value=2000,
        step=500,
        help="每條鏈的抽樣數量"
    )
    n_tune = st.number_input(
        "調整期 (Tune)",
        min_value=200,
        max_value=5000,
        value=1000,
        step=200,
        help="調整期的樣本數"
    )
    n_chains = st.selectbox(
        "鏈數 (Chains)",
        options=[1, 2, 4],
        index=1,
        help="平行運行的鏈數"
    )
    target_accept = st.slider(
        "目標接受率",
        min_value=0.80,
        max_value=0.99,
        value=0.95,
        step=0.01,
        help="NUTS 採樣器的目標接受率"
    )
    st.markdown("---")

    # Maintenance: clear stored analysis results older than the cleanup age.
    if st.button("🧹 清理過期資料"):
        cleanup_old_sessions()
        # Deliberately no st.rerun(): rerunning right away would discard this
        # confirmation before the user sees it. The visible page renders from
        # st.session_state, which the cleanup does not modify.
        st.success("✅ 清理完成")
    st.markdown("---")

    # Data source selection.
    st.subheader("📊 資料來源")
    data_source = st.radio(
        "選擇資料來源:",
        ["使用預設資料集", "上傳您的資料"]
    )
    uploaded_file = None
    if data_source == "上傳您的資料":
        uploaded_file = st.file_uploader(
            "上傳 CSV 檔案",
            type=['csv'],
            help="上傳寶可夢速度對戰資料"
        )
        with st.expander("📖 資料格式說明"):
            st.markdown("""
**必要欄位格式:**
- `Trial_Type`: 配對名稱(例如:Pair_1, Pair_2)
- `rt`: 火系(治療組)的勝場數
- `nt`: 火系的總場數
- `rc`: 水系(對照組)的勝場數
- `nc`: 水系的總場數
**範例:**
```
Trial_Type,rt,nt,rc,nc
Pair_1,122,133,22,145
Pair_2,85,132,17,135
Pair_3,52,129,41,134
```
""")
    st.markdown("---")

    # About this app.
    with st.expander("ℹ️ 關於此系統"):
        st.markdown("""
**貝氏階層模型分析系統**
本系統使用貝氏階層模型來分析速度對寶可夢勝率的影響,
並考慮不同屬性之間的異質性。
**主要功能:**
- 🎲 貝氏推論與後驗分佈
- 📊 階層模型(借用資訊)
- 📈 4 種視覺化圖表
- 💬 AI 助手解釋
- 🎮 屬性對抗策略建議
**適用場景:**
- 分析火系對水系的配對勝率
- 理解不同配對間的異質性
- 評估屬性優劣勢
""")
# Main content area: two top-level tabs (Bayesian analysis, AI assistant).
_MAIN_TAB_LABELS = ["📊 貝氏分析", "💬 AI 助手"]
tab1, tab2 = st.tabs(_MAIN_TAB_LABELS)
# Tab 1: Bayesian analysis — load data, run the sampler, render the results.


def _read_csv_or_report(path_or_buffer):
    """Read a CSV with pandas; on failure show an error and return None.

    pandas' ParserError / EmptyDataError and UnicodeDecodeError all derive from
    ValueError, and OSError covers unreadable files, so a malformed upload no
    longer crashes the whole page.
    """
    try:
        return pd.read_csv(path_or_buffer)
    except (OSError, ValueError) as exc:
        st.error(f"❌ 無法讀取 CSV 檔案: {exc}")
        return None


def _load_dataframe(source, upload):
    """Load the data set chosen in the sidebar and report its status.

    Args:
        source: Sidebar radio value ("使用預設資料集" or "上傳您的資料").
        upload: The uploaded CSV file object, or None if nothing was uploaded.

    Returns:
        A DataFrame, or None when no usable data is available. A status
        message (success / warning / info / error) is rendered either way.
    """
    if source == "使用預設資料集":
        default_data_path = "fire_water_converted.csv"
        if not os.path.exists(default_data_path):
            st.warning("⚠️ 找不到預設資料集,請上傳您的資料")
            return None
        df = _read_csv_or_report(default_data_path)
        if df is not None:
            st.success(f"✅ 已載入預設資料集({len(df)} 組配對)")
        return df

    if upload is None:
        st.info("📁 請在左側上傳 CSV 檔案")
        return None
    df = _read_csv_or_report(upload)
    if df is not None:
        st.success(f"✅ 已載入資料({len(df)} 組配對)")
    return df


def _run_analysis(df, n_samples, n_tune, n_chains, target_accept):
    """Fit the hierarchical model on ``df`` and regenerate every plot.

    Results and plot images are stored in st.session_state so they survive
    reruns. Any failure is shown with st.error instead of crashing the page;
    in that case previously stored results are kept.
    """
    with st.spinner(f"正在執行貝氏分析... (抽樣 {n_samples} × {n_chains} 條鏈)"):
        try:
            # One analyzer object is reused for the whole browser session.
            if st.session_state.analyzer is None:
                st.session_state.analyzer = BayesianHierarchicalAnalyzer(st.session_state.session_id)
            analyzer = st.session_state.analyzer

            analyzer.load_data(df)
            results = analyzer.run_analysis(
                n_samples=n_samples,
                n_tune=n_tune,
                n_chains=n_chains,
                target_accept=target_accept
            )
            st.session_state.analysis_results = results

            # Drop the previous run's plots first, so a plotting failure below
            # cannot leave stale images next to the new results.
            for image_key in ('trace_img', 'posterior_img', 'forest_img', 'dag_img'):
                st.session_state[image_key] = None

            with st.spinner("生成視覺化圖表..."):
                st.session_state.trace_img = plot_trace(analyzer.trace)
                st.session_state.posterior_img = plot_posterior(analyzer.trace)
                st.session_state.forest_img = plot_forest(
                    analyzer.trace,
                    results['trial_labels']
                )
                st.session_state.dag_img = plot_model_dag(analyzer)

            st.success("✅ 分析完成!")
            st.balloons()
        except Exception as e:  # UI boundary: surface any model/plot failure
            st.error(f"❌ 分析失敗: {str(e)}")


def _show_image_or_info(image, placeholder):
    """Show a stored plot image, or an info placeholder if it is missing."""
    if image:
        st.image(image, use_column_width=True)
    else:
        st.info(placeholder)


def _render_overview(results):
    """Render the overview sub-tab.

    Reads results['overall'], results['interpretation'] and
    results['diagnostics']. Shows the headline metrics with their HDI bounds,
    the textual interpretation, R-hat / ESS diagnostics, the summary and
    per-trial tables, and the odds-ratio comparison chart.
    """
    st.markdown("### 🎯 整體效應摘要")
    overall = results['overall']
    interp = results['interpretation']

    # Headline metrics: posterior mean, with the HDI shown as the delta text.
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric(
            "d (整體效應)",
            f"{overall['d_mean']:.4f}",
            delta=f"HDI: [{overall['d_hdi_low']:.3f}, {overall['d_hdi_high']:.3f}]"
        )
    with col2:
        st.metric(
            "勝算比 (OR)",
            f"{overall['or_mean']:.3f}",
            delta=f"HDI: [{overall['or_hdi_low']:.3f}, {overall['or_hdi_high']:.3f}]"
        )
    with col3:
        st.metric(
            "sigma (異質性)",
            f"{overall['sigma_mean']:.4f}",
            delta=f"HDI: [{overall['sigma_hdi_low']:.3f}, {overall['sigma_hdi_high']:.3f}]"
        )
    st.markdown("---")

    # Plain-language interpretation.
    st.markdown("### 📖 結果解釋")
    st.info(f"""
**整體效應**: {interp['overall_effect']}
**顯著性**: {interp['overall_significance']}
**效果大小**: {interp['effect_size']}
**異質性**: {interp['heterogeneity']}
""")
    st.markdown("---")

    # Convergence diagnostics; the labels use R-hat < 1.1 as "good".
    st.markdown("### 🔍 模型收斂診斷")
    diag = results['diagnostics']
    col1, col2 = st.columns(2)
    with col1:
        st.markdown("**R-hat 診斷** (應 < 1.1):")
        if diag['rhat_d']:
            st.metric("R-hat (d)", f"{diag['rhat_d']:.4f}",
                      delta="✓ 良好" if diag['rhat_d'] < 1.1 else "✗ 需改善")
        if diag['rhat_sigma']:
            st.metric("R-hat (sigma)", f"{diag['rhat_sigma']:.4f}",
                      delta="✓ 良好" if diag['rhat_sigma'] < 1.1 else "✗ 需改善")
    with col2:
        st.markdown("**有效樣本數 (ESS)**:")
        if diag['ess_d']:
            st.metric("ESS (d)", f"{int(diag['ess_d'])}")
        if diag['ess_sigma']:
            st.metric("ESS (sigma)", f"{int(diag['ess_sigma'])}")
    if diag['converged']:
        st.success("✅ 模型已收斂,結果可信")
    else:
        st.warning("⚠️ 模型可能未完全收斂,建議增加抽樣數或鏈數")
    st.markdown("---")

    # Summary statistics table.
    st.markdown("### 📊 統計摘要表")
    summary_df = create_summary_table(results)
    st.dataframe(summary_df, use_container_width=True)
    st.markdown("---")

    # Per-trial (per-type) results.
    st.markdown("### 🎮 各屬性詳細結果")
    trial_df = create_trial_results_table(results)
    st.dataframe(trial_df, use_container_width=True)
    st.markdown("---")

    # Odds-ratio comparison chart.
    st.markdown("### 📊 各屬性速度效應比較")
    or_fig = plot_odds_ratio_comparison(results)
    st.plotly_chart(or_fig, use_container_width=True)


def _render_results(results):
    """Render stored analysis results as five sub-tabs.

    The sub-tabs are: overview, trace/posterior, forest, DAG, and report.
    """
    st.markdown("---")
    st.subheader("📊 分析結果")
    overview_tab, trace_tab, forest_tab, dag_tab, report_tab = st.tabs([
        "📊 概覽",
        "📈 Trace & Posterior",
        "🌲 Forest Plot",
        "🔍 DAG 模型圖",
        "📋 詳細報告"
    ])

    with overview_tab:
        _render_overview(results)

    with trace_tab:
        st.markdown("### 📈 Trace Plot(收斂診斷)")
        st.markdown("""
**Trace Plot 用途**:
- 檢查 MCMC 抽樣是否收斂
- 左圖:抽樣軌跡(應該像「毛毛蟲」)
- 右圖:後驗分佈密度
""")
        _show_image_or_info(st.session_state.trace_img, "請先執行分析以生成 Trace Plot")
        st.markdown("---")
        st.markdown("### 📊 Posterior Plot(後驗分佈)")
        st.markdown("""
**Posterior Plot 用途**:
- 顯示參數的後驗分佈
- 包含 95% HDI(最高密度區間)
- 顯示平均值
""")
        _show_image_or_info(st.session_state.posterior_img, "請先執行分析以生成 Posterior Plot")

    with forest_tab:
        st.markdown("### 🌲 Forest Plot(各屬性效應)")
        st.markdown("""
**Forest Plot 用途**:
- 顯示每個屬性的速度效應(delta)
- 點:平均效應
- 線:95% HDI
- ★ 標記:顯著正效應(HDI 不包含 0)
- ☆ 標記:顯著負效應
""")
        _show_image_or_info(st.session_state.forest_img, "請先執行分析以生成 Forest Plot")

    with dag_tab:
        st.markdown("### 🔍 模型結構圖 (DAG)")
        st.markdown("""
**DAG(有向無環圖)用途**:
- 視覺化模型的階層結構
- 顯示變數之間的依賴關係
- 圓形/橢圓:隨機變數
- 矩形:觀測資料
- 菱形:推導變數
""")
        if st.session_state.dag_img:
            st.image(st.session_state.dag_img, use_column_width=True)
        else:
            st.warning("⚠️ 無法生成 DAG 圖(可能需要安裝 Graphviz)")
            st.markdown("""
**安裝 Graphviz:**
- Windows: `choco install graphviz`
- Mac: `brew install graphviz`
- Ubuntu: `sudo apt-get install graphviz`
""")

    with report_tab:
        st.markdown("### 📋 完整分析報告")
        text_report = export_results_to_text(results)
        st.text_area(
            "報告內容",
            text_report,
            height=500
        )
        st.download_button(
            label="📥 下載完整報告 (.txt)",
            data=text_report,
            file_name=f"bayesian_report_{results['timestamp'][:10]}.txt",
            mime="text/plain"
        )


with tab1:
    st.header("📊 貝氏階層模型分析")

    df = _load_dataframe(data_source, uploaded_file)
    if df is not None:
        with st.expander("👀 資料預覽"):
            st.dataframe(df, use_container_width=True)
        st.markdown("---")

        # Centre the analysis button in the middle column.
        col1, col2, col3 = st.columns([1, 2, 1])
        with col2:
            analyze_button = st.button(
                "🔬 開始貝氏分析",
                type="primary",
                use_container_width=True
            )
        if analyze_button:
            _run_analysis(df, n_samples, n_tune, n_chains, target_accept)

    # Results persist in session state, so they are shown on every rerun.
    if st.session_state.analysis_results is not None:
        _render_results(st.session_state.analysis_results)
# Tab 2: AI assistant. It needs a stored API key and finished analysis
# results; otherwise it only explains what is missing.


def _get_llm_assistant():
    """Return the session's BayesianLLMAssistant, (re)building it as needed.

    The assistant is cached in st.session_state together with the
    (api_key, provider) pair it was built with. Without this check, a key
    corrected or a provider switched in the sidebar would never reach the
    cached client. A rebuilt assistant is a new object, so any conversation
    state it keeps internally (cf. reset_conversation()) starts over. The
    on-screen chat_history is kept.
    """
    provider = st.session_state.get('api_choice', 'Google Gemini')
    config = (st.session_state.api_key, provider)
    if (st.session_state.get('llm_assistant') is None
            or st.session_state.get('llm_assistant_config') != config):
        st.session_state.llm_assistant = BayesianLLMAssistant(
            api_key=st.session_state.api_key,
            session_id=st.session_state.session_id,
            api_provider=provider
        )
        st.session_state.llm_assistant_config = config
    return st.session_state.llm_assistant


def _render_chat_history():
    """Replay stored chat messages, including any DAG image a turn produced."""
    for message in st.session_state.chat_history:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
            if message.get("has_dag", False) and message.get("dag_image") is not None:
                st.image(message["dag_image"], caption="🎨 生成的 DAG 圖", use_column_width=True)


def _assistant_error(exc):
    """Show an assistant failure and return it as a (response, dag_image) pair."""
    error_msg = f"❌ 錯誤: {str(exc)}\n\n請檢查 API key 或重新表達問題。"
    st.error(error_msg)
    return error_msg, None


def _append_assistant_message(response, dag_image):
    """Store an assistant turn, keeping the DAG image (if any) for replay."""
    st.session_state.chat_history.append({
        "role": "assistant",
        "content": response,
        "has_dag": dag_image is not None,
        "dag_image": dag_image
    })


def _answer_quick_question(assistant, question, analysis_results):
    """Route a quick-question button label to the matching assistant method.

    Routing is by substring of the label, checked in this order. Only the DAG
    request and the generic fallback go through get_response(), which returns
    (response, dag_image). Every other branch pairs its text with None.

    Returns:
        A (response, dag_image) tuple.
    """
    if "總結" in question:
        return assistant.generate_summary(analysis_results), None
    if "d 和勝算比" in question:
        return assistant.explain_metric('d', analysis_results), None
    if "sigma" in question or "異質性" in question:
        return assistant.explain_metric('sigma', analysis_results), None
    if "階層模型" in question:
        return assistant.explain_hierarchical_model(), None
    if "畫一個" in question or "結構圖" in question:
        return assistant.get_response(
            "請畫一個貝氏階層模型的 DAG 圖,並用繁體中文解釋每個節點的意義",
            analysis_results
        )
    if "貝氏" in question and "頻率論" in question:
        return assistant.explain_bayesian_vs_frequentist(), None
    if "策略" in question:
        return assistant.battle_strategy_advice(analysis_results), None
    if "比較" in question:
        return assistant.compare_types(analysis_results), None
    return assistant.get_response(question, analysis_results)


with tab2:
    st.header("💬 AI 分析助手")

    if not st.session_state.get('api_key'):
        # Name the provider actually selected in the sidebar.
        st.warning(f"⚠️ 請在左側輸入您的 {api_choice} API Key 以使用 AI 助手")
    elif st.session_state.analysis_results is None:
        st.info("ℹ️ 請先在「貝氏分析」頁面執行分析")
    else:
        assistant = _get_llm_assistant()
        analysis_results = st.session_state.analysis_results

        with st.container():
            _render_chat_history()

        # Free-form question.
        if prompt := st.chat_input("詢問關於分析結果的任何問題..."):
            st.session_state.chat_history.append({
                "role": "user",
                "content": prompt
            })
            with st.chat_message("user"):
                st.markdown(prompt)

            with st.chat_message("assistant"):
                with st.spinner("思考中..."):
                    try:
                        response, dag_image = assistant.get_response(
                            user_message=prompt,
                            analysis_results=analysis_results
                        )
                        st.markdown(response)
                        if dag_image is not None:
                            st.image(dag_image, caption="🎨 AI 生成的 DAG 圖", use_column_width=True)
                            st.success("✨ DAG 圖已生成!你可以繼續詢問圖表相關問題。")
                    except Exception as e:  # UI boundary: API/network errors
                        response, dag_image = _assistant_error(e)
            _append_assistant_message(response, dag_image)

        st.markdown("---")

        # Quick-question buttons, laid out four per row.
        st.subheader("💡 快速問題")
        st.info("💡 提示:你可以要求助手「畫一個 DAG 圖」來視覺化模型結構!")
        quick_questions = [
            "📊 給我這次分析的總結",
            "🎯 解釋 d 和勝算比",
            "🔍 解釋 sigma(異質性)",
            "❓ 什麼是階層模型?",
            "🎨 畫一個模型結構圖",
            "🆚 貝氏 vs 頻率論",
            "⚔️ 對戰策略建議",
            "🎮 比較不同屬性"
        ]
        cols = st.columns(4)
        for idx, question in enumerate(quick_questions):
            if cols[idx % 4].button(question, key=f"quick_{idx}"):
                # Same error handling as the chat path: an API failure is
                # recorded in the conversation instead of crashing the page.
                try:
                    response, dag_image = _answer_quick_question(
                        assistant, question, analysis_results
                    )
                except Exception as e:  # UI boundary: API/network errors
                    response, dag_image = _assistant_error(e)
                st.session_state.chat_history.append({
                    "role": "user",
                    "content": question
                })
                _append_assistant_message(response, dag_image)
                st.rerun()

        # Reset the conversation. A rerun is required to clear the history
        # already drawn above, so the confirmation is carried across it via a
        # one-shot session flag instead of being discarded by the rerun.
        st.markdown("---")
        if st.button("🔄 重置對話"):
            assistant.reset_conversation()
            st.session_state.chat_history = []
            st.session_state['_chat_reset_notice'] = True
            st.rerun()
        if st.session_state.pop('_chat_reset_notice', False):
            st.success("✅ 對話已重置")
# Footer: credits plus the first 8 characters of this browser session's id.
# "Powered by" names the provider the stored API key belongs to (falling back
# to the sidebar selection) instead of always claiming Gemini.
_footer_provider = st.session_state.get('api_choice', api_choice)
_footer_model = (
    "Google Gemini 2.0 Flash" if _footer_provider == "Google Gemini" else _footer_provider
)
st.markdown("---")
st.markdown(
    f"""
<div style='text-align: center'>
<p>🎲 Bayesian Hierarchical Model Analysis for Pokémon Speed | Built with Streamlit & PyMC</p>
<p>Session ID: {st.session_state.session_id[:8]} | Powered by {_footer_model}</p>
</div>
""",
    unsafe_allow_html=True
)