import pandas as pd
import numpy as np
import pymc as pm
import arviz as az
import threading
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')
class BayesianHierarchicalAnalyzer:
    """Bayesian hierarchical (random-effects) analyzer for paired win/loss data.

    Each row of the input describes one pairing with a control group and a
    treatment group (wins and total battles for each). The model estimates a
    per-pairing log-odds difference ``delta[i]`` drawn from a common
    ``Normal(d, sigma)`` and reports the overall effect ``d``, the
    between-pairing spread ``sigma`` and the odds ratio ``exp(d)``.

    Results are cached per session id in a class-level dictionary shared by
    all instances.
    """

    # Class-level lock; run_analysis() holds it for the whole model run.
    _lock = threading.Lock()
    # Latest analysis result for each session id (shared by all instances).
    _session_results = {}
    # Substrings removed from a column name to derive its short prefix,
    # applied in this order (e.g. "water_win" -> "water").
    _COLUMN_SUFFIXES = ('_win', '_battles', '_total')
    # Display labels for known column prefixes; unknown prefixes are shown
    # verbatim in the interpretation text.
    _TYPE_LABELS = {'fire': '火系', 'water': '水系'}

    def __init__(self, session_id):
        """Create an analyzer bound to one session.

        Args:
            session_id: Unique identifier used as the key under which
                results are cached.
        """
        self.session_id = session_id
        self.df = None
        self.model = None
        self.trace = None
        # Column names, filled in by load_data() from the first five columns.
        self.col_trial_type = None        # pairing name
        self.col_control_win = None       # control group wins
        self.col_control_total = None     # control group total battles
        self.col_treatment_win = None     # treatment group wins
        self.col_treatment_total = None   # treatment group total battles

    @classmethod
    def _column_prefix(cls, column_name):
        """Return ``column_name`` with every known suffix substring removed."""
        prefix = str(column_name)
        for suffix in cls._COLUMN_SUFFIXES:
            prefix = prefix.replace(suffix, '')
        return prefix

    @classmethod
    def _type_label(cls, prefix):
        """Return the display label for a column prefix (prefix itself if unknown)."""
        return cls._TYPE_LABELS.get(prefix.lower(), prefix)

    def load_data(self, csv_path_or_df):
        """Load the data set and detect column names by position.

        Args:
            csv_path_or_df: A DataFrame (copied), or anything
                ``pandas.read_csv`` accepts (path string, path object,
                file-like buffer).

        Expected column order:
            1. pairing name
            2. control wins        (e.g. ``water_win``)
            3. control battles     (e.g. ``water_battles``)
            4. treatment wins      (e.g. ``fire_win``)
            5. treatment battles   (e.g. ``fire_battles``)

        Returns:
            True on success.

        Raises:
            ValueError: If the data has fewer than five columns. In that
                case any previously loaded data is left untouched.
        """
        if isinstance(csv_path_or_df, pd.DataFrame):
            df = csv_path_or_df.copy()
        else:
            df = pd.read_csv(csv_path_or_df)

        # Validate before touching instance state so a rejected input does
        # not replace data that was loaded successfully earlier.
        if len(df.columns) < 5:
            raise ValueError(f"資料至少需要 5 欄,目前只有 {len(df.columns)} 欄")

        cols = df.columns.tolist()
        self.df = df
        (
            self.col_trial_type,
            self.col_control_win,
            self.col_control_total,
            self.col_treatment_win,
            self.col_treatment_total,
        ) = cols[:5]

        print(f"✓ 自動識別欄位:")
        print(f" - 配對名稱: {self.col_trial_type}")
        print(f" - 控制組: {self.col_control_win}/{self.col_control_total}")
        print(f" - 實驗組: {self.col_treatment_win}/{self.col_treatment_total}")
        return True

    def validate_data(self):
        """Check that the loaded data can be fed to the binomial model.

        Returns:
            True when every check passes.

        Raises:
            ValueError: If no data (or no column mapping) is loaded, the
                data has no rows, a count column is not numeric, a count is
                negative, or wins exceed total battles in any row.
        """
        if self.df is None:
            raise ValueError("請先載入資料")

        win_total_pairs = [
            (self.col_control_win, self.col_control_total),
            (self.col_treatment_win, self.col_treatment_total),
        ]
        numeric_cols = [col for pair in win_total_pairs for col in pair]

        # df may have been assigned directly without load_data().
        if any(col is None for col in numeric_cols):
            raise ValueError("請先載入資料")

        if self.df.empty:
            raise ValueError("資料不能為空")

        for col in numeric_cols:
            if not pd.api.types.is_numeric_dtype(self.df[col]):
                raise ValueError(f"欄位 {col} 必須是數值類型")

        # Negative counts are impossible for a binomial likelihood.
        for col in numeric_cols:
            if (self.df[col] < 0).any():
                raise ValueError(f"欄位 {col} 不能為負數")

        for win_col, total_col in win_total_pairs:
            if (self.df[win_col] > self.df[total_col]).any():
                raise ValueError(f"{win_col} (勝場數) 不能大於 {total_col} (總場數)")
        return True

    def run_analysis(self, n_samples=2000, n_tune=1000, n_chains=2, target_accept=0.95):
        """Fit the hierarchical model with MCMC and summarise the posterior.

        Args:
            n_samples: Number of posterior draws per chain.
            n_tune: Number of tuning (warm-up) iterations per chain.
            n_chains: Number of chains.
            target_accept: Target acceptance rate passed to the sampler.

        Returns:
            dict with the overall effect, per-pairing effects, the raw data,
            sampler settings, convergence diagnostics and a textual
            interpretation. The same dict is cached under this instance's
            session id.

        Raises:
            Exception: Any failure (including validation errors) is
                re-raised as ``Exception("分析失敗: ...")`` with the original
                exception attached as ``__cause__``.
        """
        with self._lock:
            try:
                self.validate_data()

                trial_labels = self.df[self.col_trial_type].values
                num_trials = len(self.df)

                control_win_name = self.col_control_win
                control_total_name = self.col_control_total
                treatment_win_name = self.col_treatment_win
                treatment_total_name = self.col_treatment_total

                # Short prefixes used to name model variables and result keys.
                control_prefix = self._column_prefix(control_win_name)
                treatment_prefix = self._column_prefix(treatment_win_name)
                if control_prefix == treatment_prefix:
                    # Identical prefixes would produce duplicate variable
                    # names (p_<prefix>) and colliding result keys.
                    raise ValueError(
                        f"控制組與實驗組欄位前綴相同 ({control_prefix}),無法區分兩組"
                    )

                with pm.Model() as self.model:
                    # --- Priors ---
                    d = pm.Normal('d', mu=0, sigma=10)
                    tau = pm.Gamma('tau', alpha=0.001, beta=0.001)
                    pm.Deterministic('sigma', 1 / pm.math.sqrt(tau))

                    # --- Pairing-specific effects ---
                    mu = pm.Normal('mu', mu=0, sigma=10, shape=num_trials)
                    delta = pm.Normal('delta', mu=d, sigma=1 / pm.math.sqrt(tau), shape=num_trials)

                    # --- Logit link and likelihood ---
                    p_control = pm.Deterministic(f'p_{control_prefix}', pm.math.invlogit(mu))
                    p_treatment = pm.Deterministic(f'p_{treatment_prefix}', pm.math.invlogit(mu + delta))

                    pm.Binomial(
                        f'{control_win_name}_obs',
                        n=self.df[control_total_name].values,
                        p=p_control,
                        observed=self.df[control_win_name].values
                    )
                    pm.Binomial(
                        f'{treatment_win_name}_obs',
                        n=self.df[treatment_total_name].values,
                        p=p_treatment,
                        observed=self.df[treatment_win_name].values
                    )

                    # --- Derived quantities ---
                    # Effect for a new, unobserved pairing, and the odds ratio.
                    pm.Normal('delta_new', mu=d, sigma=1 / pm.math.sqrt(tau))
                    pm.Deterministic('or_speed', pm.math.exp(d))

                    self.trace = pm.sample(
                        draws=n_samples,
                        tune=n_tune,
                        chains=n_chains,
                        target_accept=target_accept,
                        return_inferencedata=True,
                        progressbar=False,  # progress bar is disabled for the Streamlit host
                        discard_tuned_samples=False  # keep the tuning samples in the trace
                    )

                summary = az.summary(self.trace, var_names=['d', 'sigma', 'or_speed'], hdi_prob=0.95)

                # Per-pairing delta statistics; chains and draws are flattened
                # into a single sample axis.
                delta_posterior = self.trace.posterior['delta'].values.reshape(-1, num_trials)
                delta_mean = delta_posterior.mean(axis=0)
                delta_std = delta_posterior.std(axis=0)
                delta_hdi = az.hdi(self.trace, var_names=['delta'], hdi_prob=0.95)['delta'].values

                # An effect is flagged when its 95% HDI excludes zero.
                delta_significant = (delta_hdi[:, 0] > 0) | (delta_hdi[:, 1] < 0)

                p_control_posterior = self.trace.posterior[f'p_{control_prefix}'].values.reshape(-1, num_trials)
                p_treatment_posterior = self.trace.posterior[f'p_{treatment_prefix}'].values.reshape(-1, num_trials)
                p_control_mean = p_control_posterior.mean(axis=0)
                p_treatment_mean = p_treatment_posterior.mean(axis=0)

                results = {
                    'timestamp': datetime.now().isoformat(),
                    'n_trials': num_trials,
                    'trial_labels': trial_labels.tolist(),
                    # Column name information
                    'column_names': {
                        'trial_type': self.col_trial_type,
                        'control_win': control_win_name,
                        'control_total': control_total_name,
                        'treatment_win': treatment_win_name,
                        'treatment_total': treatment_total_name,
                        'control_prefix': control_prefix,
                        'treatment_prefix': treatment_prefix
                    },
                    # Overall effect
                    'overall': self._overall_stats(summary),
                    # Per-pairing effects (keys depend on the column prefixes)
                    'by_trial': {
                        'delta_mean': delta_mean.tolist(),
                        'delta_std': delta_std.tolist(),
                        'delta_hdi_low': delta_hdi[:, 0].tolist(),
                        'delta_hdi_high': delta_hdi[:, 1].tolist(),
                        'delta_significant': delta_significant.tolist(),
                        f'p_{control_prefix}_mean': p_control_mean.tolist(),
                        f'p_{treatment_prefix}_mean': p_treatment_mean.tolist(),
                    },
                    # Raw data
                    'data': self.df.to_dict('records'),
                    # Sampler settings
                    'model_params': {
                        'n_samples': n_samples,
                        'n_tune': n_tune,
                        'n_chains': n_chains,
                        'target_accept': target_accept
                    },
                    # Convergence diagnostics
                    'diagnostics': self._compute_diagnostics(summary),
                    # Textual interpretation, labelled after the detected columns
                    'interpretation': self._interpret_results(
                        summary.loc['or_speed', 'mean'],
                        summary.loc['or_speed', 'hdi_2.5%'],
                        summary.loc['or_speed', 'hdi_97.5%'],
                        summary.loc['sigma', 'mean'],
                        treatment_label=self._type_label(treatment_prefix),
                        control_label=self._type_label(control_prefix),
                    )
                }

                self._session_results[self.session_id] = results
                return results
            except Exception as e:
                raise Exception(f"分析失敗: {str(e)}") from e

    @staticmethod
    def _overall_stats(summary):
        """Flatten the summary rows for d, sigma and or_speed into one dict.

        Args:
            summary: Summary table indexed by variable name with columns
                ``mean``, ``sd``, ``hdi_2.5%`` and ``hdi_97.5%``.

        Returns:
            dict with keys ``<name>_mean``, ``<name>_sd``, ``<name>_hdi_low``
            and ``<name>_hdi_high`` for ``name`` in ``d``, ``sigma``, ``or``.
        """
        rows = (('d', 'd'), ('sigma', 'sigma'), ('or_speed', 'or'))
        stats = (
            ('mean', 'mean'),
            ('sd', 'sd'),
            ('hdi_low', 'hdi_2.5%'),
            ('hdi_high', 'hdi_97.5%'),
        )
        return {
            f'{key}_{stat}': float(summary.loc[row, column])
            for row, key in rows
            for stat, column in stats
        }

    def _compute_diagnostics(self, summary):
        """Extract R-hat and bulk ESS for ``d`` and ``sigma`` from the summary.

        Args:
            summary: Summary table indexed by variable name; the ``r_hat``
                and ``ess_bulk`` columns are optional.

        Returns:
            dict with ``rhat_d``, ``rhat_sigma``, ``ess_d``, ``ess_sigma``
            and ``converged``. ``converged`` is True when every available
            R-hat is below 1.1, False otherwise, and None when no R-hat is
            available or the summary cannot be read.
        """
        def cell(row, column):
            # A missing column means "not reported"; a missing row is an error.
            if column not in summary.columns:
                return None
            return float(summary.loc[row, column])

        try:
            rhat_d = cell('d', 'r_hat')
            rhat_sigma = cell('sigma', 'r_hat')
            ess_d = cell('d', 'ess_bulk')
            ess_sigma = cell('sigma', 'ess_bulk')

            # Without any R-hat there is no evidence either way, so report
            # "unknown" instead of claiming convergence.
            available = [r for r in (rhat_d, rhat_sigma) if r is not None]
            converged = all(r < 1.1 for r in available) if available else None

            return {
                'rhat_d': rhat_d,
                'rhat_sigma': rhat_sigma,
                'ess_d': ess_d,
                'ess_sigma': ess_sigma,
                'converged': converged
            }
        except Exception:
            # Diagnostics are best-effort and must not fail the analysis.
            return {
                'converged': None,
                'rhat_d': None,
                'rhat_sigma': None,
                'ess_d': None,
                'ess_sigma': None
            }

    def _interpret_results(self, or_mean, or_low, or_high, sigma_mean,
                           treatment_label='火系', control_label='水系'):
        """Turn the numeric summary into human-readable statements.

        Args:
            or_mean: Posterior mean of the odds ratio (treatment vs control).
            or_low: Lower bound of the odds-ratio HDI.
            or_high: Upper bound of the odds-ratio HDI.
            sigma_mean: Posterior mean of the between-pairing spread.
            treatment_label: Display name of the treatment group.
            control_label: Display name of the control group.

        Returns:
            dict with ``overall_effect``, ``overall_significance``,
            ``effect_size`` and ``heterogeneity`` strings.
        """
        # Overall effect: significant when the HDI lies entirely on one side of 1.
        if or_low > 1:
            overall_effect = f"{treatment_label}寶可夢相對於{control_label}顯著更容易獲勝"
            overall_significance = "顯著正效應"
        elif or_high < 1:
            overall_effect = f"{control_label}寶可夢相對於{treatment_label}顯著更容易獲勝"
            overall_significance = "顯著負效應"
        else:
            overall_effect = f"{treatment_label}與{control_label}勝率無顯著差異"
            overall_significance = "不顯著"

        # Effect size, bucketed by the mean odds ratio.
        if or_mean > 2:
            effect_size = f"大效果 (OR > 2) - {treatment_label}有明顯優勢"
        elif or_mean > 1.5:
            effect_size = f"中等效果 (OR > 1.5) - {treatment_label}有一定優勢"
        elif or_mean > 1:
            effect_size = f"小效果 (OR > 1) - {treatment_label}略有優勢"
        elif or_mean == 1:
            effect_size = f"無差異 (OR = 1) - {treatment_label}與{control_label}勢均力敵"
        elif or_mean > 0.67:
            effect_size = f"小效果 (OR < 1) - {control_label}略有優勢"
        elif or_mean > 0.5:
            effect_size = f"中等效果 (OR < 0.67) - {control_label}有一定優勢"
        else:
            effect_size = f"大效果 (OR < 0.5) - {control_label}有明顯優勢"

        # Heterogeneity across pairings.
        if sigma_mean > 0.5:
            heterogeneity = "高異質性 - 不同配對的勝率差異很大"
        elif sigma_mean > 0.3:
            heterogeneity = "中等異質性 - 不同配對的勝率有一定差異"
        else:
            heterogeneity = "低異質性 - 不同配對的勝率相對一致"

        return {
            'overall_effect': overall_effect,
            'overall_significance': overall_significance,
            'effect_size': effect_size,
            'heterogeneity': heterogeneity
        }

    def get_model_graph(self):
        """Build the model's DAG as a graphviz object.

        Returns:
            The object returned by ``pm.model_to_graphviz``, with node
            labels relabelled on a best-effort basis.

        Raises:
            ValueError: If no model has been built yet.
            Exception: If the graph cannot be generated.
        """
        if self.model is None:
            raise ValueError("請先執行分析")

        try:
            gv = pm.model_to_graphviz(self.model)
        except Exception as e:
            raise Exception(f"無法生成 DAG 圖: {str(e)}") from e

        # Best-effort label clean-up; on any failure the plain graph is returned.
        try:
            control_prefix = self._column_prefix(self.col_control_win)
            treatment_prefix = self._column_prefix(self.col_treatment_win)
            replacements = [
                ('label="d"', f'label="d\\n({treatment_prefix} vs {control_prefix})"'),
                (f'label="p_{control_prefix}"', f'label="p_{control_prefix}[i]"'),
                (f'label="p_{treatment_prefix}"', f'label="p_{treatment_prefix}[i]"'),
                (f'label="{self.col_control_win}_obs"', f'label="{self.col_control_win}_obs[i]"'),
                (f'label="{self.col_treatment_win}_obs"', f'label="{self.col_treatment_win}_obs[i]"'),
                ('label="mu"', 'label="mu[i]"'),
                ('label="delta"', 'label="delta[i]"'),
            ]
            # The graph's `source` is a read-only property, so the DOT body
            # lines are rewritten instead. The new body is built completely
            # before assignment so a failure leaves the graph unchanged.
            relabelled = []
            for line in gv.body:
                for old, new in replacements:
                    line = line.replace(old, new)
                relabelled.append(line)
            gv.body = relabelled
        except Exception:
            pass
        return gv

    @classmethod
    def get_session_results(cls, session_id):
        """Return the cached result for ``session_id``, or None if absent."""
        return cls._session_results.get(session_id)

    @classmethod
    def clear_session_results(cls, session_id):
        """Drop the cached result for ``session_id``; a missing id is ignored."""
        cls._session_results.pop(session_id, None)