Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| import numpy as np | |
| import pymc as pm | |
| import arviz as az | |
| import xarray as xr | |
| import threading | |
| from datetime import datetime | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
class BayesianMetaAnalyzer:
    """Bayesian random-effects meta-analyzer.

    Pools win/loss records of two Pokemon types (treatment vs. control)
    across multiple gyms ("studies") with a hierarchical binomial-logit
    model, and summarizes the pooled log-odds effect `d`, the between-study
    heterogeneity `sigma`, the pooled odds ratio, and a predictive effect
    `delta_new` for a new gym.
    """

    # Class-level lock: serializes run_analysis across threads so only one
    # MCMC run mutates shared state at a time.
    _lock = threading.Lock()
    # Results cache keyed by session_id, shared across all instances.
    _session_results = {}

    def __init__(self, session_id, treatment_type, control_type):
        """Initialize the analyzer.

        Args:
            session_id: unique session identifier (key into the results cache).
            treatment_type: name of the treatment-group type.
            control_type: name of the control-group type.
        """
        self.session_id = session_id
        self.treatment_type = treatment_type
        self.control_type = control_type
        self.df = None          # per-study counts: rt/nt (treatment), rc/nc (control)
        self.model = None       # PyMC model, built inside run_analysis
        self.trace_full = None  # warmup + posterior draws, concatenated
        self.trace = None       # post-warmup (converged) draws only

    def load_data(self, df_full, treatment_type=None, control_type=None):
        """Extract and validate per-study win/battle counts.

        Args:
            df_full: DataFrame with columns
                ``{type}_win_count`` and ``{type}_total_battles`` for each type.
            treatment_type: treatment-group type; defaults to the one given
                at construction time.
            control_type: control-group type; defaults to the one given
                at construction time.

        Returns:
            True on success.

        Raises:
            ValueError: if any win count exceeds its total battle count.
        """
        # Fall back to the types this instance was constructed with.
        if treatment_type is None:
            treatment_type = self.treatment_type
        if control_type is None:
            control_type = self.control_type

        # Assemble the analysis table.
        data = {
            'rt': df_full[f'{treatment_type}_win_count'].values,
            'nt': df_full[f'{treatment_type}_total_battles'].values,
            'rc': df_full[f'{control_type}_win_count'].values,
            'nc': df_full[f'{control_type}_total_battles'].values
        }
        self.df = pd.DataFrame(data)

        # Sanity-check: wins can never exceed battles fought.
        if (self.df['rt'] > self.df['nt']).any():
            raise ValueError("實驗組勝場數不能大於總場數")
        if (self.df['rc'] > self.df['nc']).any():
            raise ValueError("對照組勝場數不能大於總場數")
        return True

    def run_analysis(self, n_warmup=1000, n_samples=2000, n_chains=1, target_accept=0.95):
        """Run the Bayesian meta-analysis via staged MCMC sampling.

        Args:
            n_warmup: number of warmup (burn-in) draws.
            n_samples: number of posterior draws.
            n_chains: number of MCMC chains.
            target_accept: NUTS target acceptance probability.

        Returns:
            dict: full analysis results (also cached under this session_id
            in the class-level results store).

        Raises:
            Exception: wraps any underlying failure; the original exception
            is chained as ``__cause__``.
        """
        with self._lock:
            try:
                if self.df is None:
                    raise ValueError("請先載入資料")
                Num = len(self.df)

                # Build the hierarchical binomial-logit model.
                with pm.Model() as self.model:
                    # Vague priors for pooled effect and heterogeneity precision.
                    d = pm.Normal('d', mu=0, sigma=1000)
                    tau = pm.Gamma('tau', alpha=0.001, beta=0.001)
                    sigma = pm.Deterministic('sigma', 1 / pm.math.sqrt(tau))

                    # Study-specific baselines (mu) and effects (delta).
                    mu = pm.Normal('mu', mu=0, sigma=100, shape=Num)
                    delta = pm.Normal('delta', mu=d, sigma=1 / pm.math.sqrt(tau), shape=Num)

                    # Win rates on the probability scale via inverse logit.
                    pc = pm.Deterministic(f'win_rate_{self.control_type}', pm.math.invlogit(mu))
                    pt = pm.Deterministic(f'win_rate_{self.treatment_type}', pm.math.invlogit(mu + delta))

                    # Binomial likelihoods for observed wins.
                    rc_obs = pm.Binomial(
                        f'{self.control_type}_observed_wins',
                        n=self.df['nc'].values,
                        p=pc,
                        observed=self.df['rc'].values
                    )
                    rt_obs = pm.Binomial(
                        f'{self.treatment_type}_observed_wins',
                        n=self.df['nt'].values,
                        p=pt,
                        observed=self.df['rt'].values
                    )

                    # Predictive effect for a hypothetical new study.
                    delta_new = pm.Normal('delta_new', mu=d, sigma=1 / pm.math.sqrt(tau))

                    # Pooled odds ratio.
                    or_pooled = pm.Deterministic('or', pm.math.exp(d))

                    # ===== staged MCMC sampling =====
                    # Step 1: draw the "warmup" segment (tune=0 so these raw,
                    # unadapted draws are kept — used later for trace plots).
                    trace_warmup = pm.sample(
                        n_warmup,
                        tune=0,
                        chains=n_chains,
                        return_inferencedata=True,
                        progressbar=False,
                        discard_tuned_samples=False
                    )

                    # Step 2: draw the posterior segment with proper tuning.
                    trace_posterior = pm.sample(
                        n_samples,
                        tune=n_warmup,
                        chains=n_chains,
                        target_accept=target_accept,
                        return_inferencedata=True,
                        progressbar=False,
                        discard_tuned_samples=False  # keep tune samples too
                    )

                    # Step 3: manually concatenate the two segments.
                    warmup_data = trace_warmup.posterior
                    posterior_data = trace_posterior.posterior

                    # Concatenate variable-by-variable along the draw dimension.
                    combined_datasets = {}
                    for var_name in posterior_data.data_vars:
                        warmup_var = warmup_data[var_name]
                        posterior_var = posterior_data[var_name]
                        combined_var = xr.concat([warmup_var, posterior_var], dim='draw')
                        combined_datasets[var_name] = combined_var

                    # Re-index draws 0..total_draws-1 over the combined set.
                    combined_posterior = xr.Dataset(combined_datasets)
                    total_draws = n_warmup + n_samples
                    combined_posterior = combined_posterior.assign_coords(draw=range(total_draws))

                    # Wrap as InferenceData; keep the full trace for plotting.
                    self.trace_full = az.InferenceData(posterior=combined_posterior)

                    # Converged samples: drop the leading warmup draws.
                    self.trace = self.trace_full.sel(draw=slice(n_warmup, None))

                # Summary statistics from the converged samples only.
                summary = az.summary(
                    self.trace,
                    var_names=['d', 'delta_new', 'sigma', 'or'],
                    hdi_prob=0.95
                )

                # Flattened posterior samples of the key quantities.
                d_samples = self.trace.posterior['d'].values.flatten()
                delta_new_samples = self.trace.posterior['delta_new'].values.flatten()
                sigma_samples = self.trace.posterior['sigma'].values.flatten()
                or_samples = self.trace.posterior['or'].values.flatten()

                # Per-study delta statistics (flatten chain x draw).
                delta_posterior = self.trace.posterior['delta'].values
                if delta_posterior.ndim == 3:  # (chain, draw, study)
                    delta_posterior = delta_posterior.reshape(-1, Num)
                delta_mean = delta_posterior.mean(axis=0)
                delta_std = delta_posterior.std(axis=0)
                delta_hdi = az.hdi(self.trace, var_names=['delta'], hdi_prob=0.95)['delta'].values

                # Assemble the results payload.
                results = {
                    'timestamp': datetime.now().isoformat(),
                    'treatment_type': self.treatment_type,
                    'control_type': self.control_type,
                    'n_studies': Num,
                    # Overall (pooled) effect.
                    'overall': {
                        'd_mean': float(summary.loc['d', 'mean']),
                        'd_sd': float(summary.loc['d', 'sd']),
                        'd_hdi_low': float(summary.loc['d', 'hdi_2.5%']),
                        'd_hdi_high': float(summary.loc['d', 'hdi_97.5%']),
                        'sigma_mean': float(summary.loc['sigma', 'mean']),
                        'sigma_sd': float(summary.loc['sigma', 'sd']),
                        'sigma_hdi_low': float(summary.loc['sigma', 'hdi_2.5%']),
                        'sigma_hdi_high': float(summary.loc['sigma', 'hdi_97.5%']),
                        # OR summaries derived from d on the log scale
                        # (exp of mean/HDI of d), not from the 'or' samples.
                        'or_mean': float(np.exp(summary.loc['d', 'mean'])),
                        'or_sd': float(summary.loc['or', 'sd']),
                        'or_hdi_low': float(np.exp(summary.loc['d', 'hdi_2.5%'])),
                        'or_hdi_high': float(np.exp(summary.loc['d', 'hdi_97.5%'])),
                    },
                    # Predictive effect for a new study.
                    'predictive': {
                        'delta_new_mean': float(summary.loc['delta_new', 'mean']),
                        'delta_new_sd': float(summary.loc['delta_new', 'sd']),
                        'delta_new_hdi_low': float(summary.loc['delta_new', 'hdi_2.5%']),
                        'delta_new_hdi_high': float(summary.loc['delta_new', 'hdi_97.5%']),
                        'or_new_mean': float(np.exp(delta_new_samples.mean())),
                        'or_new_hdi_low': float(np.percentile(np.exp(delta_new_samples), 2.5)),
                        'or_new_hdi_high': float(np.percentile(np.exp(delta_new_samples), 97.5)),
                        'uncertainty_ratio': float(delta_new_samples.std() / d_samples.std()),
                    },
                    # Per-study effects.
                    'by_study': {
                        'delta_mean': delta_mean.tolist(),
                        'delta_std': delta_std.tolist(),
                        'delta_hdi_low': delta_hdi[:, 0].tolist(),
                        'delta_hdi_high': delta_hdi[:, 1].tolist(),
                    },
                    # Raw input data.
                    'data': self.df.to_dict('records'),
                    # Sampler parameters.
                    'model_params': {
                        'n_warmup': n_warmup,
                        'n_samples': n_samples,
                        'n_chains': n_chains,
                        'total_draws': n_warmup + n_samples,
                    },
                    # Convergence diagnostics.
                    'diagnostics': self._compute_diagnostics(summary),
                    # Full trace data (warmup + posterior) for plotting.
                    'trace_data': {
                        'd': self.trace_full.posterior['d'].values.tolist(),  # keeps (chain, draw) shape
                        'sigma': self.trace_full.posterior['sigma'].values.tolist(),
                        'delta_new': self.trace_full.posterior['delta_new'].values.tolist(),
                        'or': self.trace_full.posterior['or'].values.tolist(),
                        'warmup_end': n_warmup,
                        'n_chains': n_chains,
                    }
                }

                # Cache into the shared session store.
                self._session_results[self.session_id] = results
                return results
            except Exception as e:
                # Chain the original exception so the traceback is preserved.
                raise Exception(f"分析失敗: {str(e)}") from e

    def _compute_diagnostics(self, summary):
        """Compute convergence diagnostics from an ArviZ summary table.

        Args:
            summary: DataFrame indexed by variable name; `r_hat` and
                `ess_bulk` columns are used when present.

        Returns:
            dict with R-hat and bulk ESS for `d` and `sigma`, plus a
            `converged` flag (R-hat < 1.1 for both, when available).
        """
        try:
            # R-hat should be close to 1.0 for converged chains.
            rhat_d = float(summary.loc['d', 'r_hat']) if 'r_hat' in summary.columns else None
            rhat_sigma = float(summary.loc['sigma', 'r_hat']) if 'r_hat' in summary.columns else None
            # Effective sample size (bulk).
            ess_d = float(summary.loc['d', 'ess_bulk']) if 'ess_bulk' in summary.columns else None
            ess_sigma = float(summary.loc['sigma', 'ess_bulk']) if 'ess_bulk' in summary.columns else None
            return {
                'rhat_d': rhat_d,
                'rhat_sigma': rhat_sigma,
                'ess_d': ess_d,
                'ess_sigma': ess_sigma,
                'converged': (rhat_d is None or rhat_d < 1.1) and (rhat_sigma is None or rhat_sigma < 1.1)
            }
        except Exception:  # diagnostics are best-effort; never fail the run
            return {
                'converged': None,
                'rhat_d': None,
                'rhat_sigma': None,
                'ess_d': None,
                'ess_sigma': None
            }

    @classmethod
    def get_session_results(cls, session_id):
        """Return the cached results for session_id, or None if absent."""
        return cls._session_results.get(session_id)

    @classmethod
    def clear_session_results(cls, session_id):
        """Remove the cached results for session_id, if present."""
        cls._session_results.pop(session_id, None)