BayesianPyMc1 / bayesian_core.py
Wen1201's picture
Upload bayesian_core.py
5f41fef verified
import pandas as pd
import numpy as np
import pymc as pm
import arviz as az
import threading
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')
class BayesianHierarchicalAnalyzer:
"""
貝氏階層模型分析器
用於分析寶可夢速度對勝率的影響(跨屬性)
"""
# 類別級的鎖,用於執行緒安全
_lock = threading.Lock()
# 儲存各 session 的分析結果
_session_results = {}
def __init__(self, session_id):
"""
初始化分析器
Args:
session_id: 唯一的 session 識別碼
"""
self.session_id = session_id
self.df = None
self.model = None
self.trace = None
# 👇 加入這些屬性
self.col_trial_type = None # 配對名稱欄位
self.col_control_win = None # 控制組勝場欄位
self.col_control_total = None # 控制組總場欄位
self.col_treatment_win = None # 實驗組勝場欄位
self.col_treatment_total = None # 實驗組總場欄位
def load_data(self, csv_path_or_df):
"""
載入資料 (自動識別欄位名稱)
Args:
csv_path_or_df: CSV 檔案路徑或 DataFrame
Expected format:
第 1 欄: 配對名稱 (Trial_Type)
第 2 欄: 控制組勝場 (例如 water_win)
第 3 欄: 控制組總場 (例如 water_battles)
第 4 欄: 實驗組勝場 (例如 fire_win)
第 5 欄: 實驗組總場 (例如 fire_battles)
"""
if isinstance(csv_path_or_df, str):
self.df = pd.read_csv(csv_path_or_df)
else:
self.df = csv_path_or_df.copy()
# 檢查欄位數量
if len(self.df.columns) < 5:
raise ValueError(f"資料至少需要 5 欄,目前只有 {len(self.df.columns)} 欄")
# 自動識別欄位名稱 (假設前 5 欄按照固定順序)
cols = self.df.columns.tolist()
self.col_trial_type = cols[0]
self.col_control_win = cols[1]
self.col_control_total = cols[2]
self.col_treatment_win = cols[3]
self.col_treatment_total = cols[4]
print(f"✓ 自動識別欄位:")
print(f" - 配對名稱: {self.col_trial_type}")
print(f" - 控制組: {self.col_control_win}/{self.col_control_total}")
print(f" - 實驗組: {self.col_treatment_win}/{self.col_treatment_total}")
return True
def validate_data(self):
"""驗證資料有效性"""
if self.df is None:
raise ValueError("請先載入資料")
# 檢查數值欄位
numeric_cols = [
self.col_control_win,
self.col_control_total,
self.col_treatment_win,
self.col_treatment_total
]
for col in numeric_cols:
if not pd.api.types.is_numeric_dtype(self.df[col]):
raise ValueError(f"欄位 {col} 必須是數值類型")
# 檢查邏輯約束
if (self.df[self.col_control_win] > self.df[self.col_control_total]).any():
raise ValueError(f"{self.col_control_win} (勝場數) 不能大於 {self.col_control_total} (總場數)")
if (self.df[self.col_treatment_win] > self.df[self.col_treatment_total]).any():
raise ValueError(f"{self.col_treatment_win} (勝場數) 不能大於 {self.col_treatment_total} (總場數)")
return True
def run_analysis(self, n_samples=2000, n_tune=1000, n_chains=2, target_accept=0.95):
"""
執行貝氏階層模型分析
Args:
n_samples: MCMC 抽樣數
n_tune: 調整期樣本數
n_chains: 鏈數
target_accept: 目標接受率
Returns:
dict: 包含所有分析結果的字典
"""
with self._lock:
try:
self.validate_data()
# 準備資料
trial_labels = self.df[self.col_trial_type].values
num_trials = len(self.df)
# 提取欄位名稱用於模型
control_win_name = self.col_control_win
control_total_name = self.col_control_total
treatment_win_name = self.col_treatment_win
treatment_total_name = self.col_treatment_total
# 提取前綴用於變數命名 (例如 "water_win" → "water")
control_prefix = control_win_name.replace('_win', '').replace('_battles', '').replace('_total', '')
treatment_prefix = treatment_win_name.replace('_win', '').replace('_battles', '').replace('_total', '')
# 建立模型
with pm.Model() as self.model:
# --- 先驗分佈 (Priors) ---
d = pm.Normal('d', mu=0, sigma=10)
tau = pm.Gamma('tau', alpha=0.001, beta=0.001)
sigma = pm.Deterministic('sigma', 1 / pm.math.sqrt(tau))
# --- 各配對特定效應 (Pair-specific effects) ---
mu = pm.Normal('mu', mu=0, sigma=10, shape=num_trials)
delta = pm.Normal('delta', mu=d, sigma=1 / pm.math.sqrt(tau), shape=num_trials)
# --- 轉換與似然函數 (Logit Link & Likelihood) ---
# 使用動態命名
p_control = pm.Deterministic(f'p_{control_prefix}', pm.math.invlogit(mu))
p_treatment = pm.Deterministic(f'p_{treatment_prefix}', pm.math.invlogit(mu + delta))
# 使用動態欄位名稱創建觀測值
control_obs = pm.Binomial(
f'{control_win_name}_obs',
n=self.df[control_total_name].values,
p=p_control,
observed=self.df[control_win_name].values
)
treatment_obs = pm.Binomial(
f'{treatment_win_name}_obs',
n=self.df[treatment_total_name].values,
p=p_treatment,
observed=self.df[treatment_win_name].values
)
# --- 其他統計量 ---
delta_new = pm.Normal('delta_new', mu=d, sigma=1 / pm.math.sqrt(tau))
or_speed = pm.Deterministic('or_speed', pm.math.exp(d))
# 執行 MCMC 抽樣
self.trace = pm.sample(
draws=n_samples,
tune=n_tune,
chains=n_chains,
target_accept=target_accept,
return_inferencedata=True,
progressbar=False, # 在 Streamlit 中關閉進度條
discard_tuned_samples=False # 保留 tune 樣本
)
# 生成摘要統計
summary = az.summary(self.trace, var_names=['d', 'sigma', 'or_speed'], hdi_prob=0.95)
# 計算各配對的 delta 統計量
delta_posterior = self.trace.posterior['delta'].values.reshape(-1, num_trials)
delta_mean = delta_posterior.mean(axis=0)
delta_std = delta_posterior.std(axis=0)
delta_hdi = az.hdi(self.trace, var_names=['delta'], hdi_prob=0.95)['delta'].values
# 判斷顯著性(HDI 不包含 0)
delta_significant = (delta_hdi[:, 0] > 0) | (delta_hdi[:, 1] < 0)
# 計算控制組和實驗組的勝率 (使用動態變數名稱)
p_control_posterior = self.trace.posterior[f'p_{control_prefix}'].values.reshape(-1, num_trials)
p_treatment_posterior = self.trace.posterior[f'p_{treatment_prefix}'].values.reshape(-1, num_trials)
p_control_mean = p_control_posterior.mean(axis=0)
p_treatment_mean = p_treatment_posterior.mean(axis=0)
# 整理結果
results = {
'timestamp': datetime.now().isoformat(),
'n_trials': num_trials,
'trial_labels': trial_labels.tolist(),
# 欄位名稱資訊
'column_names': {
'trial_type': self.col_trial_type,
'control_win': control_win_name,
'control_total': control_total_name,
'treatment_win': treatment_win_name,
'treatment_total': treatment_total_name,
'control_prefix': control_prefix,
'treatment_prefix': treatment_prefix
},
# 整體效應
'overall': {
'd_mean': float(summary.loc['d', 'mean']),
'd_sd': float(summary.loc['d', 'sd']),
'd_hdi_low': float(summary.loc['d', 'hdi_2.5%']),
'd_hdi_high': float(summary.loc['d', 'hdi_97.5%']),
'sigma_mean': float(summary.loc['sigma', 'mean']),
'sigma_sd': float(summary.loc['sigma', 'sd']),
'sigma_hdi_low': float(summary.loc['sigma', 'hdi_2.5%']),
'sigma_hdi_high': float(summary.loc['sigma', 'hdi_97.5%']),
'or_mean': float(summary.loc['or_speed', 'mean']),
'or_sd': float(summary.loc['or_speed', 'sd']),
'or_hdi_low': float(summary.loc['or_speed', 'hdi_2.5%']),
'or_hdi_high': float(summary.loc['or_speed', 'hdi_97.5%']),
},
# 各配對的效應 (使用動態鍵名)
'by_trial': {
'delta_mean': delta_mean.tolist(),
'delta_std': delta_std.tolist(),
'delta_hdi_low': delta_hdi[:, 0].tolist(),
'delta_hdi_high': delta_hdi[:, 1].tolist(),
'delta_significant': delta_significant.tolist(),
f'p_{control_prefix}_mean': p_control_mean.tolist(),
f'p_{treatment_prefix}_mean': p_treatment_mean.tolist(),
},
# 原始資料
'data': self.df.to_dict('records'),
# 模型參數
'model_params': {
'n_samples': n_samples,
'n_tune': n_tune,
'n_chains': n_chains,
'target_accept': target_accept
},
# 收斂診斷
'diagnostics': self._compute_diagnostics(summary),
# 解釋
'interpretation': self._interpret_results(
summary.loc['or_speed', 'mean'],
summary.loc['or_speed', 'hdi_2.5%'],
summary.loc['or_speed', 'hdi_97.5%'],
summary.loc['sigma', 'mean']
)
}
# 儲存到 session results
self._session_results[self.session_id] = results
return results
except Exception as e:
raise Exception(f"分析失敗: {str(e)}")
def _compute_diagnostics(self, summary):
"""計算收斂診斷指標"""
try:
# R-hat (應該接近 1.0)
rhat_d = float(summary.loc['d', 'r_hat']) if 'r_hat' in summary.columns else None
rhat_sigma = float(summary.loc['sigma', 'r_hat']) if 'r_hat' in summary.columns else None
# ESS (有效樣本數)
ess_d = float(summary.loc['d', 'ess_bulk']) if 'ess_bulk' in summary.columns else None
ess_sigma = float(summary.loc['sigma', 'ess_bulk']) if 'ess_bulk' in summary.columns else None
return {
'rhat_d': rhat_d,
'rhat_sigma': rhat_sigma,
'ess_d': ess_d,
'ess_sigma': ess_sigma,
'converged': (rhat_d is None or rhat_d < 1.1) and (rhat_sigma is None or rhat_sigma < 1.1)
}
except:
return {
'converged': None,
'rhat_d': None,
'rhat_sigma': None,
'ess_d': None,
'ess_sigma': None
}
def _interpret_results(self, or_mean, or_low, or_high, sigma_mean):
"""解釋分析結果"""
# 整體效應顯著性
if or_low > 1:
overall_effect = "火系寶可夢相對於水系顯著更容易獲勝"
overall_significance = "顯著正效應"
elif or_high < 1:
overall_effect = "水系寶可夢相對於火系顯著更容易獲勝"
overall_significance = "顯著負效應"
else:
overall_effect = "火系與水系勝率無顯著差異"
overall_significance = "不顯著"
# 效果大小
if or_mean > 2:
effect_size = "大效果 (OR > 2) - 火系有明顯優勢"
elif or_mean > 1.5:
effect_size = "中等效果 (OR > 1.5) - 火系有一定優勢"
elif or_mean > 1:
effect_size = "小效果 (OR > 1) - 火系略有優勢"
elif or_mean == 1:
effect_size = "無差異 (OR = 1) - 火系與水系勢均力敵"
elif or_mean > 0.67:
effect_size = "小效果 (OR < 1) - 水系略有優勢"
elif or_mean > 0.5:
effect_size = "中等效果 (OR < 0.67) - 水系有一定優勢"
else:
effect_size = "大效果 (OR < 0.5) - 水系有明顯優勢"
# 異質性評估
if sigma_mean > 0.5:
heterogeneity = "高異質性 - 不同配對的勝率差異很大"
elif sigma_mean > 0.3:
heterogeneity = "中等異質性 - 不同配對的勝率有一定差異"
else:
heterogeneity = "低異質性 - 不同配對的勝率相對一致"
return {
'overall_effect': overall_effect,
'overall_significance': overall_significance,
'effect_size': effect_size,
'heterogeneity': heterogeneity
}
def get_model_graph(self):
"""生成模型 DAG 圖(返回 graphviz 物件)"""
if self.model is None:
raise ValueError("請先執行分析")
try:
gv = pm.model_to_graphviz(self.model)
# 嘗試美化標籤 (如果失敗就用原圖)
try:
# 獲取欄位前綴
control_prefix = self.col_control_win.replace('_win', '').replace('_battles', '').replace('_total', '')
treatment_prefix = self.col_treatment_win.replace('_win', '').replace('_battles', '').replace('_total', '')
# 簡單替換標籤
src = gv.source
src = src.replace('label="d"', f'label="d\\n({treatment_prefix} vs {control_prefix})"')
src = src.replace(f'label="p_{control_prefix}"', f'label="p_{control_prefix}[i]"')
src = src.replace(f'label="p_{treatment_prefix}"', f'label="p_{treatment_prefix}[i]"')
src = src.replace(f'label="{self.col_control_win}_obs"', f'label="{self.col_control_win}_obs[i]"')
src = src.replace(f'label="{self.col_treatment_win}_obs"', f'label="{self.col_treatment_win}_obs[i]"')
src = src.replace('label="mu"', 'label="mu[i]"')
src = src.replace('label="delta"', 'label="delta[i]"')
gv.source = src
except:
pass
return gv
except Exception as e:
raise Exception(f"無法生成 DAG 圖: {str(e)}")
@classmethod
def get_session_results(cls, session_id):
"""獲取特定 session 的結果"""
return cls._session_results.get(session_id)
@classmethod
def clear_session_results(cls, session_id):
"""清除特定 session 的結果"""
if session_id in cls._session_results:
del cls._session_results[session_id]