# meta_analysis / meta_analysis_core.py
# (Hugging Face file-viewer residue removed: uploader banner and blob-size
#  lines are not part of the module source and broke Python parsing.)
import pandas as pd
import numpy as np
import pymc as pm
import arviz as az
import xarray as xr
import threading
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')
class BayesianMetaAnalyzer:
    """
    Bayesian meta-analyzer.

    Fits a Bayesian random-effects meta-analysis (per-study binomial
    likelihoods with a hierarchical normal prior on the log odds ratios)
    to compare the battle performance of two Pokemon types across
    multiple gyms, each gym acting as one "study".
    """
    # Class-level lock: model building/sampling for concurrent sessions is
    # serialized so simultaneous requests cannot interleave PyMC contexts.
    _lock = threading.Lock()
    # Completed analysis results keyed by session_id (shared across instances).
    _session_results = {}

    def __init__(self, session_id, treatment_type, control_type):
        """
        Initialize the analyzer.

        Args:
            session_id: unique session identifier
            treatment_type: name of the treatment-group type
            control_type: name of the control-group type
        """
        self.session_id = session_id
        self.treatment_type = treatment_type
        self.control_type = control_type
        self.df = None          # per-study counts: rt, nt, rc, nc
        self.model = None       # PyMC model (built in run_analysis)
        self.trace_full = None  # warmup + posterior draws combined
        self.trace = None       # post-warmup (converged) draws only

    def load_data(self, df_full, treatment_type, control_type):
        """
        Load per-study win/battle counts into the analysis table.

        Args:
            df_full: DataFrame holding '{type}_win_count' and
                '{type}_total_battles' columns for both types
            treatment_type: treatment-group type (column prefix)
            control_type: control-group type (column prefix)

        Returns:
            bool: True on success.

        Raises:
            ValueError: if any win count exceeds its total battle count.
            KeyError: if an expected column is missing from df_full.
        """
        # Assemble the analysis table: rt/nt = treatment wins/totals,
        # rc/nc = control wins/totals, one row per gym ("study").
        data = {
            'rt': df_full[f'{treatment_type}_win_count'].values,
            'nt': df_full[f'{treatment_type}_total_battles'].values,
            'rc': df_full[f'{control_type}_win_count'].values,
            'nc': df_full[f'{control_type}_total_battles'].values
        }
        self.df = pd.DataFrame(data)
        # Sanity checks: wins can never exceed battles.
        if (self.df['rt'] > self.df['nt']).any():
            raise ValueError("實驗組勝場數不能大於總場數")
        if (self.df['rc'] > self.df['nc']).any():
            raise ValueError("對照組勝場數不能大於總場數")
        return True

    def run_analysis(self, n_warmup=1000, n_samples=2000, n_chains=1, target_accept=0.95):
        """
        Run the Bayesian meta-analysis (MCMC) and collect results.

        Args:
            n_warmup: number of warmup (burn-in) draws
            n_samples: number of posterior draws
            n_chains: number of MCMC chains
            target_accept: NUTS target acceptance probability

        Returns:
            dict: overall/predictive/per-study effects, raw data, model
            parameters, convergence diagnostics and trace data for plots.

        Raises:
            Exception: wraps any failure during model building/sampling
                (original exception chained via __cause__).
        """
        with self._lock:
            try:
                if self.df is None:
                    raise ValueError("請先載入資料")
                Num = len(self.df)
                # ===== Build the hierarchical model =====
                with pm.Model() as self.model:
                    # Vague priors on the pooled log-OR (d) and the
                    # between-study precision (tau); sigma = 1/sqrt(tau).
                    d = pm.Normal('d', mu=0, sigma=1000)
                    tau = pm.Gamma('tau', alpha=0.001, beta=0.001)
                    sigma = pm.Deterministic('sigma', 1 / pm.math.sqrt(tau))
                    # Study-specific baselines (mu) and effects (delta).
                    mu = pm.Normal('mu', mu=0, sigma=100, shape=Num)
                    delta = pm.Normal('delta', mu=d, sigma=1 / pm.math.sqrt(tau), shape=Num)
                    # Win probabilities via inverse-logit transforms.
                    pc = pm.Deterministic(f'win_rate_{self.control_type}', pm.math.invlogit(mu))
                    pt = pm.Deterministic(f'win_rate_{self.treatment_type}', pm.math.invlogit(mu + delta))
                    # Binomial likelihoods for the observed win counts.
                    rc_obs = pm.Binomial(
                        f'{self.control_type}_observed_wins',
                        n=self.df['nc'].values,
                        p=pc,
                        observed=self.df['rc'].values
                    )
                    rt_obs = pm.Binomial(
                        f'{self.treatment_type}_observed_wins',
                        n=self.df['nt'].values,
                        p=pt,
                        observed=self.df['rt'].values
                    )
                    # Predictive distribution for a new study's effect.
                    delta_new = pm.Normal('delta_new', mu=d, sigma=1 / pm.math.sqrt(tau))
                    # Pooled odds ratio.
                    or_pooled = pm.Deterministic('or', pm.math.exp(d))
                    # ===== MCMC sampling (two stages) =====
                    # Stage 1: "warmup" draws taken with tune=0 so the
                    # pre-convergence trajectory can be shown in trace plots.
                    trace_warmup = pm.sample(
                        n_warmup,
                        tune=0,
                        chains=n_chains,
                        return_inferencedata=True,
                        progressbar=False,
                        discard_tuned_samples=False
                    )
                    # Stage 2: properly tuned posterior draws.
                    trace_posterior = pm.sample(
                        n_samples,
                        tune=n_warmup,
                        chains=n_chains,
                        target_accept=target_accept,
                        return_inferencedata=True,
                        progressbar=False,
                        discard_tuned_samples=False  # keep tune samples
                    )
                    # Stage 3: manually concatenate warmup and posterior
                    # draws along the 'draw' dimension.
                    warmup_data = trace_warmup.posterior
                    posterior_data = trace_posterior.posterior
                    combined_datasets = {}
                    for var_name in posterior_data.data_vars:
                        warmup_var = warmup_data[var_name]
                        posterior_var = posterior_data[var_name]
                        combined_var = xr.concat([warmup_var, posterior_var], dim='draw')
                        combined_datasets[var_name] = combined_var
                    combined_posterior = xr.Dataset(combined_datasets)
                    total_draws = n_warmup + n_samples
                    combined_posterior = combined_posterior.assign_coords(draw=range(total_draws))
                    # Full trace (warmup included, for convergence plots) ...
                    self.trace_full = az.InferenceData(posterior=combined_posterior)
                    # ... and the converged draws only (warmup dropped).
                    self.trace = self.trace_full.sel(draw=slice(n_warmup, None))
                    # Posterior summary computed from converged draws only.
                    summary = az.summary(
                        self.trace,
                        var_names=['d', 'delta_new', 'sigma', 'or'],
                        hdi_prob=0.95
                    )
                    # Flattened posterior samples of the scalars used below.
                    d_samples = self.trace.posterior['d'].values.flatten()
                    delta_new_samples = self.trace.posterior['delta_new'].values.flatten()
                    # Per-study effect statistics.
                    delta_posterior = self.trace.posterior['delta'].values
                    if delta_posterior.ndim == 3:  # (chain, draw, study)
                        delta_posterior = delta_posterior.reshape(-1, Num)
                    delta_mean = delta_posterior.mean(axis=0)
                    delta_std = delta_posterior.std(axis=0)
                    delta_hdi = az.hdi(self.trace, var_names=['delta'], hdi_prob=0.95)['delta'].values
                    # Assemble the results payload.
                    results = {
                        'timestamp': datetime.now().isoformat(),
                        'treatment_type': self.treatment_type,
                        'control_type': self.control_type,
                        'n_studies': Num,
                        # Pooled effect. OR interval bounds are exp(HDI(d))
                        # rather than the 'or' node's own HDI: the HDI is not
                        # invariant under exp, and the d-scale interval is the
                        # one intended here.
                        'overall': {
                            'd_mean': float(summary.loc['d', 'mean']),
                            'd_sd': float(summary.loc['d', 'sd']),
                            'd_hdi_low': float(summary.loc['d', 'hdi_2.5%']),
                            'd_hdi_high': float(summary.loc['d', 'hdi_97.5%']),
                            'sigma_mean': float(summary.loc['sigma', 'mean']),
                            'sigma_sd': float(summary.loc['sigma', 'sd']),
                            'sigma_hdi_low': float(summary.loc['sigma', 'hdi_2.5%']),
                            'sigma_hdi_high': float(summary.loc['sigma', 'hdi_97.5%']),
                            'or_mean': float(np.exp(summary.loc['d', 'mean'])),
                            'or_sd': float(summary.loc['or', 'sd']),
                            'or_hdi_low': float(np.exp(summary.loc['d', 'hdi_2.5%'])),
                            'or_hdi_high': float(np.exp(summary.loc['d', 'hdi_97.5%'])),
                        },
                        # Predictive effect for a hypothetical new study.
                        # NOTE: exp(mean(delta_new)) is the geometric-mean OR,
                        # deliberately not mean(exp(delta_new)).
                        'predictive': {
                            'delta_new_mean': float(summary.loc['delta_new', 'mean']),
                            'delta_new_sd': float(summary.loc['delta_new', 'sd']),
                            'delta_new_hdi_low': float(summary.loc['delta_new', 'hdi_2.5%']),
                            'delta_new_hdi_high': float(summary.loc['delta_new', 'hdi_97.5%']),
                            'or_new_mean': float(np.exp(delta_new_samples.mean())),
                            'or_new_hdi_low': float(np.percentile(np.exp(delta_new_samples), 2.5)),
                            'or_new_hdi_high': float(np.percentile(np.exp(delta_new_samples), 97.5)),
                            'uncertainty_ratio': float(delta_new_samples.std() / d_samples.std()),
                        },
                        # Per-study effects.
                        'by_study': {
                            'delta_mean': delta_mean.tolist(),
                            'delta_std': delta_std.tolist(),
                            'delta_hdi_low': delta_hdi[:, 0].tolist(),
                            'delta_hdi_high': delta_hdi[:, 1].tolist(),
                        },
                        # Raw input data.
                        'data': self.df.to_dict('records'),
                        # Sampler settings.
                        'model_params': {
                            'n_warmup': n_warmup,
                            'n_samples': n_samples,
                            'n_chains': n_chains,
                            'total_draws': n_warmup + n_samples,
                        },
                        # Convergence diagnostics.
                        'diagnostics': self._compute_diagnostics(summary),
                        # Trace data for plotting (full chain shape kept).
                        'trace_data': {
                            'd': self.trace_full.posterior['d'].values.tolist(),
                            'sigma': self.trace_full.posterior['sigma'].values.tolist(),
                            'delta_new': self.trace_full.posterior['delta_new'].values.tolist(),
                            'or': self.trace_full.posterior['or'].values.tolist(),
                            'warmup_end': n_warmup,
                            'n_chains': n_chains,
                        }
                    }
                    # Persist for later retrieval by session id.
                    self._session_results[self.session_id] = results
                    return results
            except Exception as e:
                # Chain the cause so the original traceback is preserved.
                raise Exception(f"分析失敗: {str(e)}") from e

    def _compute_diagnostics(self, summary):
        """Extract convergence diagnostics (R-hat, bulk ESS) from a summary.

        Returns None for any metric absent from the summary; 'converged'
        uses the conventional R-hat < 1.1 threshold.
        """
        try:
            # R-hat should be close to 1.0 for converged chains.
            rhat_d = float(summary.loc['d', 'r_hat']) if 'r_hat' in summary.columns else None
            rhat_sigma = float(summary.loc['sigma', 'r_hat']) if 'r_hat' in summary.columns else None
            # Effective sample sizes.
            ess_d = float(summary.loc['d', 'ess_bulk']) if 'ess_bulk' in summary.columns else None
            ess_sigma = float(summary.loc['sigma', 'ess_bulk']) if 'ess_bulk' in summary.columns else None
            return {
                'rhat_d': rhat_d,
                'rhat_sigma': rhat_sigma,
                'ess_d': ess_d,
                'ess_sigma': ess_sigma,
                'converged': (rhat_d is None or rhat_d < 1.1) and (rhat_sigma is None or rhat_sigma < 1.1)
            }
        except Exception:
            # Best-effort: diagnostics are informational, never fatal.
            return {
                'converged': None,
                'rhat_d': None,
                'rhat_sigma': None,
                'ess_d': None,
                'ess_sigma': None
            }

    @classmethod
    def get_session_results(cls, session_id):
        """Return the stored results for session_id, or None if absent."""
        return cls._session_results.get(session_id)

    @classmethod
    def clear_session_results(cls, session_id):
        """Discard the stored results for session_id, if any."""
        cls._session_results.pop(session_id, None)