File size: 17,212 Bytes
662a0be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d95a80
662a0be
 
 
 
 
 
 
 
 
 
 
9d95a80
 
 
 
 
 
 
 
 
662a0be
 
9d95a80
662a0be
 
 
 
9d95a80
 
 
 
 
 
662a0be
 
 
 
 
 
9d95a80
 
 
662a0be
9d95a80
 
 
 
 
 
 
 
 
 
 
 
 
 
662a0be
 
 
 
 
 
 
9d95a80
 
 
 
 
 
 
 
662a0be
 
 
 
9d95a80
 
662a0be
9d95a80
 
662a0be
9d95a80
662a0be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d95a80
662a0be
 
9d95a80
 
 
 
 
 
 
 
 
 
662a0be
 
 
 
 
 
 
9d95a80
662a0be
 
 
 
9d95a80
 
 
 
 
 
 
 
 
 
 
662a0be
9d95a80
 
 
 
 
 
662a0be
 
 
 
 
 
 
 
 
 
 
 
9d95a80
 
662a0be
 
 
 
 
9d95a80
662a0be
 
 
 
 
 
 
 
9d95a80
 
 
662a0be
9d95a80
 
662a0be
 
 
 
 
 
 
9d95a80
 
 
 
 
 
 
 
 
 
 
662a0be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d95a80
662a0be
 
 
 
 
 
9d95a80
 
662a0be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d95a80
662a0be
9d95a80
662a0be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d13eaf
662a0be
8d13eaf
662a0be
 
8d13eaf
662a0be
 
8d13eaf
 
 
 
662a0be
8d13eaf
662a0be
8d13eaf
662a0be
8d13eaf
662a0be
8d13eaf
 
 
 
 
662a0be
8d13eaf
 
662a0be
 
 
8d13eaf
662a0be
8d13eaf
662a0be
8d13eaf
 
662a0be
 
 
 
 
 
 
de34a92
52fc51b
662a0be
 
 
 
52fc51b
328bfd2
5f41fef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de34a92
52fc51b
662a0be
 
52fc51b
 
662a0be
 
 
 
 
 
 
 
328bfd2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
import pandas as pd
import numpy as np
import pymc as pm
import arviz as az
import threading
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

class BayesianHierarchicalAnalyzer:
    """
    Bayesian hierarchical model analyzer.

    Fits a binomial/logit model to paired win counts: every pair (row) gets
    its own baseline ``mu[i]`` and its own treatment effect
    ``delta[i] ~ Normal(d, 1 / sqrt(tau))``; ``d`` is the overall effect on
    the log-odds scale and ``or_speed = exp(d)`` the overall odds ratio.
    The result of the most recent run per session id is kept in a
    class-level dictionary.
    """

    # Class-level lock; run_analysis holds it for the whole model build and
    # sampling run.
    _lock = threading.Lock()

    # Most recent analysis result per session id (shared by all instances).
    _session_results = {}

    # Substrings removed from a column name to obtain the group prefix
    # (e.g. "water_win" -> "water").
    _PREFIX_TOKENS = ('_win', '_battles', '_total')

    def __init__(self, session_id):
        """
        Initialize the analyzer.

        Args:
            session_id: Unique session identifier; used as the key under
                which results are stored.
        """
        self.session_id = session_id
        self.df = None
        self.model = None
        self.trace = None

        # Column names, filled in by load_data().
        self.col_trial_type = None       # pair / trial label column
        self.col_control_win = None      # control-group wins column
        self.col_control_total = None    # control-group total battles column
        self.col_treatment_win = None    # treatment-group wins column
        self.col_treatment_total = None  # treatment-group total battles column

    @staticmethod
    def _column_prefix(column_name):
        """Return the column name with every '_win', '_battles' and '_total' removed."""
        prefix = str(column_name)
        for token in BayesianHierarchicalAnalyzer._PREFIX_TOKENS:
            prefix = prefix.replace(token, '')
        return prefix

    @staticmethod
    def _overall_summary(summary):
        """Flatten the 'd', 'sigma' and 'or_speed' rows of a summary table into a dict of floats."""
        overall = {}
        for key, var_name in (('d', 'd'), ('sigma', 'sigma'), ('or', 'or_speed')):
            overall[f'{key}_mean'] = float(summary.loc[var_name, 'mean'])
            overall[f'{key}_sd'] = float(summary.loc[var_name, 'sd'])
            overall[f'{key}_hdi_low'] = float(summary.loc[var_name, 'hdi_2.5%'])
            overall[f'{key}_hdi_high'] = float(summary.loc[var_name, 'hdi_97.5%'])
        return overall

    def load_data(self, csv_path_or_df):
        """
        Load the data and identify the columns by position.

        Args:
            csv_path_or_df: A DataFrame (copied, never modified), or anything
                ``pandas.read_csv`` accepts: a path string, a path object or
                a file-like object.

        Expected format:
            Column 1: pair name (Trial_Type)
            Column 2: control-group wins (e.g. water_win)
            Column 3: control-group total battles (e.g. water_battles)
            Column 4: treatment-group wins (e.g. fire_win)
            Column 5: treatment-group total battles (e.g. fire_battles)

        Returns:
            bool: Always True on success.

        Raises:
            ValueError: If the data has fewer than 5 columns. The previously
                loaded data (if any) is left untouched in that case.
        """
        if isinstance(csv_path_or_df, pd.DataFrame):
            df = csv_path_or_df.copy()
        else:
            # Everything that is not a DataFrame is handed to read_csv, so
            # path objects and uploaded file buffers work as well as strings.
            df = pd.read_csv(csv_path_or_df)

        # Check the column count before touching any instance state.
        if len(df.columns) < 5:
            raise ValueError(f"資料至少需要 5 欄,目前只有 {len(df.columns)} 欄")

        # The first five columns are assumed to be in the fixed order above.
        cols = df.columns.tolist()
        self.df = df
        (
            self.col_trial_type,
            self.col_control_win,
            self.col_control_total,
            self.col_treatment_win,
            self.col_treatment_total,
        ) = cols[:5]

        print(f"✓ 自動識別欄位:")
        print(f"  - 配對名稱: {self.col_trial_type}")
        print(f"  - 控制組: {self.col_control_win}/{self.col_control_total}")
        print(f"  - 實驗組: {self.col_treatment_win}/{self.col_treatment_total}")

        return True

    def validate_data(self):
        """
        Validate the loaded data.

        Returns:
            bool: True when every check passes.

        Raises:
            ValueError: If no data is loaded, the data has no rows, a count
                column is not numeric, contains missing or negative values,
                or a win count exceeds the matching total.
        """
        if self.df is None:
            raise ValueError("請先載入資料")

        numeric_cols = [
            self.col_control_win,
            self.col_control_total,
            self.col_treatment_win,
            self.col_treatment_total,
        ]

        # Dtype check on all four count columns.
        for col in numeric_cols:
            if not pd.api.types.is_numeric_dtype(self.df[col]):
                raise ValueError(f"欄位 {col} 必須是數值類型")

        if len(self.df) == 0:
            raise ValueError("資料不可為空")

        for col in numeric_cols:
            # NaN compares False against everything, so it would slip
            # through the wins <= totals checks below.
            if self.df[col].isna().any():
                raise ValueError(f"欄位 {col} 含有缺失值")
            if (self.df[col] < 0).any():
                raise ValueError(f"欄位 {col} 不能有負值")

        # Logical constraint: wins can never exceed total battles.
        if (self.df[self.col_control_win] > self.df[self.col_control_total]).any():
            raise ValueError(f"{self.col_control_win} (勝場數) 不能大於 {self.col_control_total} (總場數)")

        if (self.df[self.col_treatment_win] > self.df[self.col_treatment_total]).any():
            raise ValueError(f"{self.col_treatment_win} (勝場數) 不能大於 {self.col_treatment_total} (總場數)")

        return True

    def run_analysis(self, n_samples=2000, n_tune=1000, n_chains=2, target_accept=0.95):
        """
        Build the hierarchical model, sample it and summarize the posterior.

        Args:
            n_samples: Number of posterior draws per chain.
            n_tune: Number of tuning draws per chain.
            n_chains: Number of chains.
            target_accept: Target acceptance rate passed to the sampler.

        Returns:
            dict: All analysis results (overall effect, per-pair effects,
            raw data, sampler settings, diagnostics and interpretation).
            The same dict is stored under this instance's session id.

        Raises:
            Exception: Wraps any failure (including validation errors) with
                the prefix "分析失敗: "; the original exception is attached
                as ``__cause__``.
        """
        with self._lock:
            try:
                self.validate_data()

                # Prepare the data.
                frame = self.df
                trial_labels = frame[self.col_trial_type].values
                num_trials = len(frame)

                control_win_name = self.col_control_win
                control_total_name = self.col_control_total
                treatment_win_name = self.col_treatment_win
                treatment_total_name = self.col_treatment_total

                # Group prefixes drive the variable names, e.g. "water_win" -> "p_water".
                control_prefix = self._column_prefix(control_win_name)
                treatment_prefix = self._column_prefix(treatment_win_name)
                p_control_name = f'p_{control_prefix}'
                p_treatment_name = f'p_{treatment_prefix}'

                with pm.Model() as self.model:
                    # --- Priors ---
                    d = pm.Normal('d', mu=0, sigma=10)
                    tau = pm.Gamma('tau', alpha=0.001, beta=0.001)
                    # Between-pair standard deviation implied by precision tau.
                    pair_sd = 1 / pm.math.sqrt(tau)
                    pm.Deterministic('sigma', pair_sd)

                    # --- Pair-specific effects ---
                    mu = pm.Normal('mu', mu=0, sigma=10, shape=num_trials)
                    delta = pm.Normal('delta', mu=d, sigma=pair_sd, shape=num_trials)

                    # --- Logit link & likelihood ---
                    p_control = pm.Deterministic(p_control_name, pm.math.invlogit(mu))
                    p_treatment = pm.Deterministic(p_treatment_name, pm.math.invlogit(mu + delta))

                    pm.Binomial(
                        f'{control_win_name}_obs',
                        n=frame[control_total_name].values,
                        p=p_control,
                        observed=frame[control_win_name].values,
                    )
                    pm.Binomial(
                        f'{treatment_win_name}_obs',
                        n=frame[treatment_total_name].values,
                        p=p_treatment,
                        observed=frame[treatment_win_name].values,
                    )

                    # --- Derived quantities ---
                    # Effect drawn for a new, unobserved pair, and the overall odds ratio.
                    pm.Normal('delta_new', mu=d, sigma=pair_sd)
                    pm.Deterministic('or_speed', pm.math.exp(d))

                    # --- MCMC sampling ---
                    self.trace = pm.sample(
                        draws=n_samples,
                        tune=n_tune,
                        chains=n_chains,
                        target_accept=target_accept,
                        return_inferencedata=True,
                        progressbar=False,  # progress bar is switched off (Streamlit, per the original note)
                        discard_tuned_samples=False,  # keep the tuning draws
                    )

                # Summary statistics for the overall parameters.
                summary = az.summary(self.trace, var_names=['d', 'sigma', 'or_speed'], hdi_prob=0.95)

                # Per-pair delta statistics; chain and draw axes are flattened together.
                posterior = self.trace.posterior
                delta_posterior = posterior['delta'].values.reshape(-1, num_trials)
                delta_mean = delta_posterior.mean(axis=0)
                delta_std = delta_posterior.std(axis=0)
                delta_hdi = az.hdi(self.trace, var_names=['delta'], hdi_prob=0.95)['delta'].values

                # A pair is flagged significant when its 95% HDI excludes 0.
                delta_significant = (delta_hdi[:, 0] > 0) | (delta_hdi[:, 1] < 0)

                # Posterior mean win rate of each group, per pair.
                p_control_mean = posterior[p_control_name].values.reshape(-1, num_trials).mean(axis=0)
                p_treatment_mean = posterior[p_treatment_name].values.reshape(-1, num_trials).mean(axis=0)

                results = {
                    'timestamp': datetime.now().isoformat(),
                    'n_trials': num_trials,
                    'trial_labels': trial_labels.tolist(),

                    # Column name information
                    'column_names': {
                        'trial_type': self.col_trial_type,
                        'control_win': control_win_name,
                        'control_total': control_total_name,
                        'treatment_win': treatment_win_name,
                        'treatment_total': treatment_total_name,
                        'control_prefix': control_prefix,
                        'treatment_prefix': treatment_prefix,
                    },

                    # Overall effect
                    'overall': self._overall_summary(summary),

                    # Per-pair effects (win-rate keys depend on the prefixes)
                    'by_trial': {
                        'delta_mean': delta_mean.tolist(),
                        'delta_std': delta_std.tolist(),
                        'delta_hdi_low': delta_hdi[:, 0].tolist(),
                        'delta_hdi_high': delta_hdi[:, 1].tolist(),
                        'delta_significant': delta_significant.tolist(),
                        f'{p_control_name}_mean': p_control_mean.tolist(),
                        f'{p_treatment_name}_mean': p_treatment_mean.tolist(),
                    },

                    # Raw data
                    'data': frame.to_dict('records'),

                    # Sampler settings
                    'model_params': {
                        'n_samples': n_samples,
                        'n_tune': n_tune,
                        'n_chains': n_chains,
                        'target_accept': target_accept,
                    },

                    # Convergence diagnostics
                    'diagnostics': self._compute_diagnostics(summary),

                    # Plain-language interpretation
                    'interpretation': self._interpret_results(
                        summary.loc['or_speed', 'mean'],
                        summary.loc['or_speed', 'hdi_2.5%'],
                        summary.loc['or_speed', 'hdi_97.5%'],
                        summary.loc['sigma', 'mean'],
                    ),
                }

                # Store under this session id.
                self._session_results[self.session_id] = results

                return results

            except Exception as e:
                # Same exception type and message as before; "from e" records
                # the original error as the explicit cause.
                raise Exception(f"分析失敗: {str(e)}") from e

    def _compute_diagnostics(self, summary):
        """
        Compute convergence diagnostics for 'd' and 'sigma'.

        Args:
            summary: Summary table indexed by variable name; the 'r_hat'
                and 'ess_bulk' columns are read when present.

        Returns:
            dict: 'rhat_d', 'rhat_sigma', 'ess_d', 'ess_sigma' (float or
            None) and 'converged' (True when every available R-hat is below
            1.1). If the table cannot be read, every value is None.
        """
        try:
            has_rhat = 'r_hat' in summary.columns
            has_ess = 'ess_bulk' in summary.columns

            # R-hat (should be close to 1.0)
            rhat_d = float(summary.loc['d', 'r_hat']) if has_rhat else None
            rhat_sigma = float(summary.loc['sigma', 'r_hat']) if has_rhat else None

            # Effective sample size
            ess_d = float(summary.loc['d', 'ess_bulk']) if has_ess else None
            ess_sigma = float(summary.loc['sigma', 'ess_bulk']) if has_ess else None

            return {
                'rhat_d': rhat_d,
                'rhat_sigma': rhat_sigma,
                'ess_d': ess_d,
                'ess_sigma': ess_sigma,
                'converged': (rhat_d is None or rhat_d < 1.1) and (rhat_sigma is None or rhat_sigma < 1.1),
            }
        except Exception:
            # Diagnostics are best-effort: an unreadable table must not fail
            # the whole analysis. KeyboardInterrupt/SystemExit now propagate.
            return {
                'converged': None,
                'rhat_d': None,
                'rhat_sigma': None,
                'ess_d': None,
                'ess_sigma': None,
            }

    def _interpret_results(self, or_mean, or_low, or_high, sigma_mean):
        """
        Turn the overall estimates into short descriptive texts.

        NOTE(review): the texts are hard-coded to fire (treatment) vs water
        (control), although the column names are detected dynamically.

        Args:
            or_mean: Posterior mean of the odds ratio.
            or_low: Lower bound of the odds-ratio interval.
            or_high: Upper bound of the odds-ratio interval.
            sigma_mean: Posterior mean of the between-pair standard deviation.

        Returns:
            dict: 'overall_effect', 'overall_significance', 'effect_size'
            and 'heterogeneity' strings.
        """
        # Overall significance: the interval lies entirely on one side of 1.
        if or_low > 1:
            overall_effect = "火系寶可夢相對於水系顯著更容易獲勝"
            overall_significance = "顯著正效應"
        elif or_high < 1:
            overall_effect = "水系寶可夢相對於火系顯著更容易獲勝"
            overall_significance = "顯著負效應"
        else:
            overall_effect = "火系與水系勝率無顯著差異"
            overall_significance = "不顯著"

        # Effect size, judged on the posterior mean odds ratio.
        if or_mean > 2:
            effect_size = "大效果 (OR > 2) - 火系有明顯優勢"
        elif or_mean > 1.5:
            effect_size = "中等效果 (OR > 1.5) - 火系有一定優勢"
        elif or_mean > 1:
            effect_size = "小效果 (OR > 1) - 火系略有優勢"
        elif or_mean == 1:
            effect_size = "無差異 (OR = 1) - 火系與水系勢均力敵"
        elif or_mean > 0.67:
            effect_size = "小效果 (OR < 1) - 水系略有優勢"
        elif or_mean > 0.5:
            effect_size = "中等效果 (OR < 0.67) - 水系有一定優勢"
        else:
            effect_size = "大效果 (OR < 0.5) - 水系有明顯優勢"

        # Heterogeneity across pairs.
        if sigma_mean > 0.5:
            heterogeneity = "高異質性 - 不同配對的勝率差異很大"
        elif sigma_mean > 0.3:
            heterogeneity = "中等異質性 - 不同配對的勝率有一定差異"
        else:
            heterogeneity = "低異質性 - 不同配對的勝率相對一致"

        return {
            'overall_effect': overall_effect,
            'overall_significance': overall_significance,
            'effect_size': effect_size,
            'heterogeneity': heterogeneity,
        }

    def get_model_graph(self):
        """
        Build the model DAG.

        Returns:
            The object returned by ``pm.model_to_graphviz`` for the fitted
            model.

        Raises:
            ValueError: If no analysis has been run yet.
            Exception: If the graph cannot be generated; the original error
                is attached as ``__cause__``.
        """
        if self.model is None:
            raise ValueError("請先執行分析")

        try:
            gv = pm.model_to_graphviz(self.model)

            # Best-effort label prettifying; on any failure the unmodified
            # graph is returned.
            # NOTE(review): `source` looks like a read-only property on
            # graphviz graph objects, in which case the assignment below
            # always fails and the relabelling never takes effect - confirm.
            try:
                control_prefix = self._column_prefix(self.col_control_win)
                treatment_prefix = self._column_prefix(self.col_treatment_win)

                replacements = [
                    ('label="d"', f'label="d\\n({treatment_prefix} vs {control_prefix})"'),
                    (f'label="p_{control_prefix}"', f'label="p_{control_prefix}[i]"'),
                    (f'label="p_{treatment_prefix}"', f'label="p_{treatment_prefix}[i]"'),
                    (f'label="{self.col_control_win}_obs"', f'label="{self.col_control_win}_obs[i]"'),
                    (f'label="{self.col_treatment_win}_obs"', f'label="{self.col_treatment_win}_obs[i]"'),
                    ('label="mu"', 'label="mu[i]"'),
                    ('label="delta"', 'label="delta[i]"'),
                ]
                src = gv.source
                for old, new in replacements:
                    src = src.replace(old, new)
                gv.source = src
            except Exception:
                pass

            return gv
        except Exception as e:
            raise Exception(f"無法生成 DAG 圖: {str(e)}") from e

    @classmethod
    def get_session_results(cls, session_id):
        """Return the stored results for a session, or None if there are none."""
        return cls._session_results.get(session_id)

    @classmethod
    def clear_session_results(cls, session_id):
        """Remove the stored results for a session; does nothing if absent."""
        # Single-step removal: no window between the membership test and the
        # delete in which another thread could remove the key first.
        cls._session_results.pop(session_id, None)