Spaces:
Sleeping
Sleeping
import os
import threading
import warnings
from datetime import datetime

import arviz as az
import numpy as np
import pandas as pd
import pymc as pm
# Install a global "ignore" filter: every Python warning issued after this
# point (by any module in the process) is suppressed.
warnings.filterwarnings('ignore')
class BayesianHierarchicalAnalyzer:
    """Bayesian hierarchical (random-effects) analyzer for paired win counts.

    Each row of the input describes one pairing ("trial") with the wins and
    total battles of a control group and of a treatment group. The model fits
    a per-trial baseline log-odds ``mu[i]`` and a per-trial log-odds
    difference ``delta[i]`` drawn from a shared ``Normal(d, sigma)``;
    ``exp(d)`` is reported as the overall odds ratio (``or_speed``).

    Results of the most recent run for every session id are kept in a
    class-level dictionary shared by all instances.
    """

    # Class-level lock. ``run_analysis`` holds it for its whole duration, so
    # analyses started from different instances/threads run one at a time.
    _lock = threading.Lock()

    # Most recent result dict per session id (shared by all instances).
    _session_results = {}

    # Suffixes removed from a column name to obtain its group prefix
    # (e.g. "water_win" -> "water").
    _NAME_SUFFIXES = ('_win', '_battles', '_total')

    # Column prefixes that have an established display label. Any other
    # prefix is shown verbatim in the interpretation text.
    _TYPE_LABELS = {'water': '水系', 'fire': '火系'}

    def __init__(self, session_id):
        """Create an analyzer bound to one session.

        Args:
            session_id: Unique identifier under which results are stored.
        """
        self.session_id = session_id
        self.df = None
        self.model = None
        self.trace = None

        # Column names, filled in by ``load_data``.
        self.col_trial_type = None       # pairing / trial label column
        self.col_control_win = None      # control-group wins
        self.col_control_total = None    # control-group total battles
        self.col_treatment_win = None    # treatment-group wins
        self.col_treatment_total = None  # treatment-group total battles

    @classmethod
    def _column_prefix(cls, column_name):
        """Return ``column_name`` with every known count suffix removed.

        The name is converted with ``str`` first so non-string column labels
        (e.g. integer headers) do not raise.
        """
        prefix = str(column_name)
        for suffix in cls._NAME_SUFFIXES:
            prefix = prefix.replace(suffix, '')
        return prefix

    @classmethod
    def _type_label(cls, prefix):
        """Return the display label for a column prefix.

        Known prefixes (case-insensitive) map to their entry in
        ``_TYPE_LABELS``; anything else is returned unchanged.
        """
        return cls._TYPE_LABELS.get(str(prefix).lower(), prefix)

    def load_data(self, csv_path_or_df):
        """Load the data set and bind the first five columns by position.

        Args:
            csv_path_or_df: Path to a CSV file (``str`` or ``os.PathLike``)
                or an object with a ``copy()`` method such as a DataFrame,
                which is copied so the caller's object is not modified.

        Expected column order:
            1. trial / pairing name   (e.g. ``Trial_Type``)
            2. control-group wins     (e.g. ``water_win``)
            3. control-group battles  (e.g. ``water_battles``)
            4. treatment-group wins   (e.g. ``fire_win``)
            5. treatment-group battles (e.g. ``fire_battles``)

        Returns:
            ``True`` on success.

        Raises:
            ValueError: If fewer than five columns are present.
        """
        # Accept pathlib.Path and other path-like objects, not just str.
        if isinstance(csv_path_or_df, (str, os.PathLike)):
            self.df = pd.read_csv(csv_path_or_df)
        else:
            self.df = csv_path_or_df.copy()

        n_columns = len(self.df.columns)
        if n_columns < 5:
            raise ValueError(f"資料至少需要 5 欄,目前只有 {n_columns} 欄")

        # Columns are identified purely by position; extra columns are ignored.
        (
            self.col_trial_type,
            self.col_control_win,
            self.col_control_total,
            self.col_treatment_win,
            self.col_treatment_total,
        ) = self.df.columns.tolist()[:5]

        print("✓ 自動識別欄位:")
        print(f" - 配對名稱: {self.col_trial_type}")
        print(f" - 控制組: {self.col_control_win}/{self.col_control_total}")
        print(f" - 實驗組: {self.col_treatment_win}/{self.col_treatment_total}")
        return True

    def validate_data(self):
        """Check that the loaded data can be fed to the binomial model.

        Returns:
            ``True`` when every check passes.

        Raises:
            ValueError: If no data is loaded, a count column is not numeric,
                a count is negative, or wins exceed total battles.
        """
        if self.df is None:
            raise ValueError("請先載入資料")

        count_columns = [
            self.col_control_win,
            self.col_control_total,
            self.col_treatment_win,
            self.col_treatment_total,
        ]
        for col in count_columns:
            if not pd.api.types.is_numeric_dtype(self.df[col]):
                raise ValueError(f"欄位 {col} 必須是數值類型")

        # Negative counts are impossible for a binomial likelihood; reject
        # them here instead of letting the sampler fail with an opaque error.
        for col in count_columns:
            if (self.df[col] < 0).any():
                raise ValueError(f"欄位 {col} 不能包含負值")

        # Wins can never exceed the number of battles played.
        if (self.df[self.col_control_win] > self.df[self.col_control_total]).any():
            raise ValueError(
                f"{self.col_control_win} (勝場數) 不能大於 {self.col_control_total} (總場數)"
            )
        if (self.df[self.col_treatment_win] > self.df[self.col_treatment_total]).any():
            raise ValueError(
                f"{self.col_treatment_win} (勝場數) 不能大於 {self.col_treatment_total} (總場數)"
            )
        return True

    def _define_model_variables(self, control_prefix, treatment_prefix):
        """Declare priors, likelihoods and derived quantities.

        Must be called inside an active ``pm.Model()`` context; the variables
        register themselves with that model.
        """
        num_trials = len(self.df)

        # --- Priors ---
        d = pm.Normal('d', mu=0, sigma=10)
        tau = pm.Gamma('tau', alpha=0.001, beta=0.001)
        # sigma is the standard deviation implied by the precision tau.
        pm.Deterministic('sigma', 1 / pm.math.sqrt(tau))

        # --- Trial-specific effects ---
        mu = pm.Normal('mu', mu=0, sigma=10, shape=num_trials)
        delta = pm.Normal('delta', mu=d, sigma=1 / pm.math.sqrt(tau), shape=num_trials)

        # --- Logit link and binomial likelihoods (names follow the columns) ---
        p_control = pm.Deterministic(f'p_{control_prefix}', pm.math.invlogit(mu))
        p_treatment = pm.Deterministic(
            f'p_{treatment_prefix}', pm.math.invlogit(mu + delta)
        )
        pm.Binomial(
            f'{self.col_control_win}_obs',
            n=self.df[self.col_control_total].values,
            p=p_control,
            observed=self.df[self.col_control_win].values,
        )
        pm.Binomial(
            f'{self.col_treatment_win}_obs',
            n=self.df[self.col_treatment_total].values,
            p=p_treatment,
            observed=self.df[self.col_treatment_win].values,
        )

        # --- Derived quantities ---
        # delta_new: effect drawn for a new, unobserved trial.
        pm.Normal('delta_new', mu=d, sigma=1 / pm.math.sqrt(tau))
        # or_speed: overall odds ratio, exp of the mean log-odds difference.
        pm.Deterministic('or_speed', pm.math.exp(d))

    def _build_results(self, summary, control_prefix, treatment_prefix, model_params):
        """Assemble the result dictionary from ``self.trace`` and ``summary``."""
        num_trials = len(self.df)
        posterior = self.trace.posterior

        def flatten(var_name):
            # Merge the chain and draw axes: (chain, draw, trial) -> (sample, trial).
            return posterior[var_name].values.reshape(-1, num_trials)

        delta_samples = flatten('delta')
        delta_hdi = az.hdi(self.trace, var_names=['delta'], hdi_prob=0.95)['delta'].values
        # A trial is flagged "significant" when its 95% HDI excludes zero.
        delta_significant = (delta_hdi[:, 0] > 0) | (delta_hdi[:, 1] < 0)

        def overall_stats(row, key):
            # Column names 'hdi_2.5%' / 'hdi_97.5%' follow from hdi_prob=0.95.
            return {
                f'{key}_mean': float(summary.loc[row, 'mean']),
                f'{key}_sd': float(summary.loc[row, 'sd']),
                f'{key}_hdi_low': float(summary.loc[row, 'hdi_2.5%']),
                f'{key}_hdi_high': float(summary.loc[row, 'hdi_97.5%']),
            }

        overall = {}
        for row, key in (('d', 'd'), ('sigma', 'sigma'), ('or_speed', 'or')):
            overall.update(overall_stats(row, key))

        return {
            'timestamp': datetime.now().isoformat(),
            'n_trials': num_trials,
            'trial_labels': self.df[self.col_trial_type].values.tolist(),
            # Column name information
            'column_names': {
                'trial_type': self.col_trial_type,
                'control_win': self.col_control_win,
                'control_total': self.col_control_total,
                'treatment_win': self.col_treatment_win,
                'treatment_total': self.col_treatment_total,
                'control_prefix': control_prefix,
                'treatment_prefix': treatment_prefix,
            },
            # Overall effect
            'overall': overall,
            # Per-trial effects (keys for win rates depend on the prefixes)
            'by_trial': {
                'delta_mean': delta_samples.mean(axis=0).tolist(),
                'delta_std': delta_samples.std(axis=0).tolist(),
                'delta_hdi_low': delta_hdi[:, 0].tolist(),
                'delta_hdi_high': delta_hdi[:, 1].tolist(),
                'delta_significant': delta_significant.tolist(),
                f'p_{control_prefix}_mean': flatten(f'p_{control_prefix}').mean(axis=0).tolist(),
                f'p_{treatment_prefix}_mean': flatten(f'p_{treatment_prefix}').mean(axis=0).tolist(),
            },
            # Raw input data
            'data': self.df.to_dict('records'),
            # Sampler settings
            'model_params': model_params,
            # Convergence diagnostics
            'diagnostics': self._compute_diagnostics(summary),
            # Plain-language interpretation
            'interpretation': self._interpret_results(
                summary.loc['or_speed', 'mean'],
                summary.loc['or_speed', 'hdi_2.5%'],
                summary.loc['or_speed', 'hdi_97.5%'],
                summary.loc['sigma', 'mean'],
                control_label=self._type_label(control_prefix),
                treatment_label=self._type_label(treatment_prefix),
            ),
        }

    def run_analysis(self, n_samples=2000, n_tune=1000, n_chains=2, target_accept=0.95):
        """Validate the data, fit the hierarchical model and summarize it.

        The class-level lock is held for the entire call. On success the
        result is also stored under ``self.session_id``.

        Args:
            n_samples: Number of posterior draws per chain.
            n_tune: Number of tuning (warm-up) iterations per chain.
            n_chains: Number of MCMC chains.
            target_accept: Target acceptance rate passed to the sampler.

        Returns:
            dict: All analysis results (overall effect, per-trial effects,
            diagnostics, interpretation, raw data and sampler settings).

        Raises:
            Exception: Any failure (including validation errors) is re-raised
                as ``Exception("分析失敗: ...")`` with the original error
                attached as ``__cause__``.
        """
        with self._lock:
            try:
                self.validate_data()

                # Group prefixes used for variable naming ("water_win" -> "water").
                control_prefix = self._column_prefix(self.col_control_win)
                treatment_prefix = self._column_prefix(self.col_treatment_win)

                with pm.Model() as self.model:
                    self._define_model_variables(control_prefix, treatment_prefix)
                    self.trace = pm.sample(
                        draws=n_samples,
                        tune=n_tune,
                        chains=n_chains,
                        target_accept=target_accept,
                        return_inferencedata=True,
                        progressbar=False,  # progress bar disabled for Streamlit
                        discard_tuned_samples=False,  # keep the tuning samples
                    )

                summary = az.summary(
                    self.trace, var_names=['d', 'sigma', 'or_speed'], hdi_prob=0.95
                )
                results = self._build_results(
                    summary,
                    control_prefix,
                    treatment_prefix,
                    {
                        'n_samples': n_samples,
                        'n_tune': n_tune,
                        'n_chains': n_chains,
                        'target_accept': target_accept,
                    },
                )

                # Publish for later retrieval via get_session_results.
                self._session_results[self.session_id] = results
                return results
            except Exception as e:
                raise Exception(f"分析失敗: {str(e)}") from e

    def _compute_diagnostics(self, summary):
        """Extract R-hat / bulk-ESS for ``d`` and ``sigma`` from a summary table.

        Args:
            summary: Table indexed by variable name, optionally holding
                ``r_hat`` and ``ess_bulk`` columns.

        Returns:
            dict with ``rhat_d``, ``rhat_sigma``, ``ess_d``, ``ess_sigma`` and
            ``converged``. ``converged`` is true when every available R-hat is
            below 1.1. If extraction fails, every value is ``None``.
        """
        def lookup(row, column):
            # A missing column means "not computed", reported as None.
            return float(summary.loc[row, column]) if column in summary.columns else None

        try:
            rhat_d = lookup('d', 'r_hat')
            rhat_sigma = lookup('sigma', 'r_hat')
            ess_d = lookup('d', 'ess_bulk')
            ess_sigma = lookup('sigma', 'ess_bulk')
            return {
                'rhat_d': rhat_d,
                'rhat_sigma': rhat_sigma,
                'ess_d': ess_d,
                'ess_sigma': ess_sigma,
                'converged': (rhat_d is None or rhat_d < 1.1)
                and (rhat_sigma is None or rhat_sigma < 1.1),
            }
        except Exception:
            # Diagnostics are best-effort: never fail the analysis over them.
            # (Narrowed from a bare except so Ctrl-C / SystemExit propagate.)
            return {
                'converged': None,
                'rhat_d': None,
                'rhat_sigma': None,
                'ess_d': None,
                'ess_sigma': None,
            }

    def _interpret_results(self, or_mean, or_low, or_high, sigma_mean,
                           control_label="水系", treatment_label="火系"):
        """Translate the overall odds ratio and sigma into descriptive text.

        Args:
            or_mean: Posterior mean of the odds ratio (treatment vs control).
            or_low: Lower HDI bound of the odds ratio.
            or_high: Upper HDI bound of the odds ratio.
            sigma_mean: Posterior mean of the between-trial standard deviation.
            control_label: Display name of the control group.
            treatment_label: Display name of the treatment group.

        Returns:
            dict with ``overall_effect``, ``overall_significance``,
            ``effect_size`` and ``heterogeneity`` strings.
        """
        # Overall effect: significant only if the HDI lies entirely on one
        # side of OR = 1.
        if or_low > 1:
            overall_effect = f"{treatment_label}寶可夢相對於{control_label}顯著更容易獲勝"
            overall_significance = "顯著正效應"
        elif or_high < 1:
            overall_effect = f"{control_label}寶可夢相對於{treatment_label}顯著更容易獲勝"
            overall_significance = "顯著負效應"
        else:
            overall_effect = f"{treatment_label}與{control_label}勝率無顯著差異"
            overall_significance = "不顯著"

        # Effect size buckets on the odds-ratio scale.
        if or_mean > 2:
            effect_size = f"大效果 (OR > 2) - {treatment_label}有明顯優勢"
        elif or_mean > 1.5:
            effect_size = f"中等效果 (OR > 1.5) - {treatment_label}有一定優勢"
        elif or_mean > 1:
            effect_size = f"小效果 (OR > 1) - {treatment_label}略有優勢"
        elif or_mean == 1:
            effect_size = f"無差異 (OR = 1) - {treatment_label}與{control_label}勢均力敵"
        elif or_mean > 0.67:
            effect_size = f"小效果 (OR < 1) - {control_label}略有優勢"
        elif or_mean > 0.5:
            effect_size = f"中等效果 (OR < 0.67) - {control_label}有一定優勢"
        else:
            effect_size = f"大效果 (OR < 0.5) - {control_label}有明顯優勢"

        # Heterogeneity buckets on sigma.
        if sigma_mean > 0.5:
            heterogeneity = "高異質性 - 不同配對的勝率差異很大"
        elif sigma_mean > 0.3:
            heterogeneity = "中等異質性 - 不同配對的勝率有一定差異"
        else:
            heterogeneity = "低異質性 - 不同配對的勝率相對一致"

        return {
            'overall_effect': overall_effect,
            'overall_significance': overall_significance,
            'effect_size': effect_size,
            'heterogeneity': heterogeneity,
        }

    def get_model_graph(self):
        """Return the model's DAG as produced by ``pm.model_to_graphviz``.

        A best-effort attempt is made to relabel some nodes; if that attempt
        raises, the unmodified graph is returned.

        Raises:
            ValueError: If no model has been built yet.
            Exception: If the graph itself cannot be generated.
        """
        if self.model is None:
            raise ValueError("請先執行分析")
        try:
            gv = pm.model_to_graphviz(self.model)
            try:
                control_prefix = self._column_prefix(self.col_control_win)
                treatment_prefix = self._column_prefix(self.col_treatment_win)

                # (old label, new label) pairs applied to the DOT source.
                relabels = [
                    ('label="d"', f'label="d\\n({treatment_prefix} vs {control_prefix})"'),
                    (f'label="p_{control_prefix}"', f'label="p_{control_prefix}[i]"'),
                    (f'label="p_{treatment_prefix}"', f'label="p_{treatment_prefix}[i]"'),
                    (f'label="{self.col_control_win}_obs"',
                     f'label="{self.col_control_win}_obs[i]"'),
                    (f'label="{self.col_treatment_win}_obs"',
                     f'label="{self.col_treatment_win}_obs[i]"'),
                    ('label="mu"', 'label="mu[i]"'),
                    ('label="delta"', 'label="delta[i]"'),
                ]
                src = gv.source
                for old, new in relabels:
                    src = src.replace(old, new)
                # NOTE(review): depending on the graphviz version, `source`
                # may be a read-only property; the assignment then raises and
                # the original graph is returned unchanged. Also verify that
                # the labels emitted by PyMC match the patterns above.
                gv.source = src
            except Exception:
                # Cosmetic step only: fall back to the unmodified graph.
                pass
            return gv
        except Exception as e:
            raise Exception(f"無法生成 DAG 圖: {str(e)}") from e

    @classmethod
    def get_session_results(cls, session_id):
        """Return the stored results for ``session_id``, or ``None``."""
        return cls._session_results.get(session_id)

    @classmethod
    def clear_session_results(cls, session_id):
        """Remove the stored results for ``session_id`` if present."""
        cls._session_results.pop(session_id, None)