File size: 12,829 Bytes
89b9684
0ba59a3
 
89b9684
79eea14
0ba59a3
 
 
89b9684
79eea14
89b9684
79eea14
0ba59a3
89b9684
79eea14
 
 
 
 
 
 
 
89b9684
 
79eea14
89b9684
79eea14
89b9684
79eea14
 
89b9684
 
79eea14
 
 
 
89b9684
79eea14
 
0ba59a3
 
 
 
 
 
 
89b9684
79eea14
 
 
 
 
 
 
 
 
 
 
0ba59a3
 
79eea14
0ba59a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79eea14
 
 
89b9684
79eea14
 
 
89b9684
79eea14
89b9684
79eea14
89b9684
79eea14
 
0ba59a3
79eea14
 
 
0ba59a3
79eea14
0ba59a3
 
 
79eea14
 
 
 
0ba59a3
 
 
79eea14
0ba59a3
79eea14
 
0ba59a3
79eea14
 
 
0ba59a3
79eea14
 
 
 
0ba59a3
 
 
 
 
79eea14
0ba59a3
 
79eea14
 
0ba59a3
 
79eea14
0ba59a3
 
 
 
 
79eea14
0ba59a3
 
79eea14
0ba59a3
 
 
79eea14
0ba59a3
 
79eea14
 
 
 
0ba59a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79eea14
 
 
 
0ba59a3
 
 
 
 
 
 
 
 
 
 
 
79eea14
 
 
 
 
 
 
 
 
 
0ba59a3
 
79eea14
0ba59a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79eea14
0ba59a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89b9684
0ba59a3
 
 
 
 
 
 
 
 
 
 
 
 
 
79eea14
0ba59a3
 
 
 
89b9684
0ba59a3
 
 
 
 
79eea14
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
import pandas as pd
import numpy as np
import pymc as pm
import arviz as az
import threading
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

class BayesianHierarchicalAnalyzer:
    """
    貝氏階層模型分析器
    用於分析寶可夢速度對勝率的影響(跨屬性)
    """
    
    # 類別級的鎖,用於執行緒安全
    _lock = threading.Lock()
    
    # 儲存各 session 的分析結果
    _session_results = {}
    
    def __init__(self, session_id):
        """
        初始化分析器
        
        Args:
            session_id: 唯一的 session 識別碼
        """
        self.session_id = session_id
        self.df = None
        self.model = None
        self.trace = None
    
    def load_data(self, csv_path_or_df):
        """
        載入資料
        
        Args:
            csv_path_or_df: CSV 檔案路徑或 DataFrame
            
        Expected columns:
            - Trial_Type: 屬性名稱 (e.g., Water, Fire, Grass)
            - rc: 控制組(速度慢)的勝場數
            - nc: 控制組的總場數
            - rt: 實驗組(速度快)的勝場數
            - nt: 實驗組的總場數
        """
        if isinstance(csv_path_or_df, str):
            self.df = pd.read_csv(csv_path_or_df)
        else:
            self.df = csv_path_or_df.copy()
        
        # 驗證必要欄位
        required_cols = ['Trial_Type', 'rc', 'nc', 'rt', 'nt']
        missing_cols = [col for col in required_cols if col not in self.df.columns]
        
        if missing_cols:
            raise ValueError(f"資料缺少必要欄位: {missing_cols}")
        
        return True
    
    def validate_data(self):
        """驗證資料有效性"""
        if self.df is None:
            raise ValueError("請先載入資料")
        
        # 檢查數值欄位
        for col in ['rc', 'nc', 'rt', 'nt']:
            if not pd.api.types.is_numeric_dtype(self.df[col]):
                raise ValueError(f"欄位 {col} 必須是數值類型")
        
        # 檢查邏輯約束
        if (self.df['rc'] > self.df['nc']).any():
            raise ValueError("rc (勝場數) 不能大於 nc (總場數)")
        
        if (self.df['rt'] > self.df['nt']).any():
            raise ValueError("rt (勝場數) 不能大於 nt (總場數)")
        
        return True
    
    def run_analysis(self, n_samples=2000, n_tune=1000, n_chains=2, target_accept=0.95):
        """
        執行貝氏階層模型分析
        
        Args:
            n_samples: MCMC 抽樣數
            n_tune: 調整期樣本數
            n_chains: 鏈數
            target_accept: 目標接受率
            
        Returns:
            dict: 包含所有分析結果的字典
        """
        with self._lock:
            try:
                self.validate_data()
                
                # 準備資料
                trial_labels = self.df['Trial_Type'].values
                num_trials = len(self.df)
                
                # 建立模型
                with pm.Model() as self.model:
                    # --- 先驗分佈 (Priors) ---
                    d = pm.Normal('d', mu=0, sigma=10)
                    tau = pm.Gamma('tau', alpha=0.001, beta=0.001)
                    sigma = pm.Deterministic('sigma', 1 / pm.math.sqrt(tau))
                    
                    # --- 各屬性特定效應 (Trial-specific effects) ---
                    mu = pm.Normal('mu', mu=0, sigma=10, shape=num_trials)
                    delta = pm.Normal('delta', mu=d, sigma=1 / pm.math.sqrt(tau), shape=num_trials)
                    
                    # --- 轉換與似然函數 (Logit Link & Likelihood) ---
                    pc = pm.Deterministic('pc', pm.math.invlogit(mu))
                    pt = pm.Deterministic('pt', pm.math.invlogit(mu + delta))
                    
                    rc_obs = pm.Binomial('rc_obs', n=self.df['nc'].values, p=pc, observed=self.df['rc'].values)
                    rt_obs = pm.Binomial('rt_obs', n=self.df['nt'].values, p=pt, observed=self.df['rt'].values)
                    
                    # --- 其他統計量 ---
                    delta_new = pm.Normal('delta_new', mu=d, sigma=1 / pm.math.sqrt(tau))
                    or_speed = pm.Deterministic('or_speed', pm.math.exp(d))
                    
                    # 執行 MCMC 抽樣
                    self.trace = pm.sample(
                        draws=n_samples,
                        tune=n_tune,
                        chains=n_chains,
                        target_accept=target_accept,
                        return_inferencedata=True,
                        progressbar=False, # 在 Streamlit 中關閉進度條
                        discard_tuned_samples=False  # 👈 加這行!保留 tune 樣本
                    )
                
                # 生成摘要統計
                summary = az.summary(self.trace, var_names=['d', 'sigma', 'or_speed'], hdi_prob=0.95)
                
                # 計算各屬性的 delta 統計量
                delta_posterior = self.trace.posterior['delta'].values.reshape(-1, num_trials)
                delta_mean = delta_posterior.mean(axis=0)
                delta_std = delta_posterior.std(axis=0)
                delta_hdi = az.hdi(self.trace, var_names=['delta'], hdi_prob=0.95)['delta'].values
                
                # 判斷顯著性(HDI 不包含 0)
                delta_significant = (delta_hdi[:, 0] > 0) | (delta_hdi[:, 1] < 0)
                
                # 計算控制組和實驗組的勝率
                pc_posterior = self.trace.posterior['pc'].values.reshape(-1, num_trials)
                pt_posterior = self.trace.posterior['pt'].values.reshape(-1, num_trials)
                
                pc_mean = pc_posterior.mean(axis=0)
                pt_mean = pt_posterior.mean(axis=0)
                
                # 整理結果
                results = {
                    'timestamp': datetime.now().isoformat(),
                    'n_trials': num_trials,
                    'trial_labels': trial_labels.tolist(),
                    
                    # 整體效應
                    'overall': {
                        'd_mean': float(summary.loc['d', 'mean']),
                        'd_sd': float(summary.loc['d', 'sd']),
                        'd_hdi_low': float(summary.loc['d', 'hdi_2.5%']),
                        'd_hdi_high': float(summary.loc['d', 'hdi_97.5%']),
                        
                        'sigma_mean': float(summary.loc['sigma', 'mean']),
                        'sigma_sd': float(summary.loc['sigma', 'sd']),
                        'sigma_hdi_low': float(summary.loc['sigma', 'hdi_2.5%']),
                        'sigma_hdi_high': float(summary.loc['sigma', 'hdi_97.5%']),
                        
                        'or_mean': float(summary.loc['or_speed', 'mean']),
                        'or_sd': float(summary.loc['or_speed', 'sd']),
                        'or_hdi_low': float(summary.loc['or_speed', 'hdi_2.5%']),
                        'or_hdi_high': float(summary.loc['or_speed', 'hdi_97.5%']),
                    },
                    
                    # 各屬性的效應
                    'by_trial': {
                        'delta_mean': delta_mean.tolist(),
                        'delta_std': delta_std.tolist(),
                        'delta_hdi_low': delta_hdi[:, 0].tolist(),
                        'delta_hdi_high': delta_hdi[:, 1].tolist(),
                        'delta_significant': delta_significant.tolist(),
                        'pc_mean': pc_mean.tolist(),
                        'pt_mean': pt_mean.tolist(),
                    },
                    
                    # 原始資料
                    'data': self.df.to_dict('records'),
                    
                    # 模型參數
                    'model_params': {
                        'n_samples': n_samples,
                        'n_tune': n_tune,
                        'n_chains': n_chains,
                        'target_accept': target_accept
                    },
                    
                    # 收斂診斷
                    'diagnostics': self._compute_diagnostics(summary),
                    
                    # 解釋
                    'interpretation': self._interpret_results(
                        summary.loc['or_speed', 'mean'],
                        summary.loc['or_speed', 'hdi_2.5%'],
                        summary.loc['or_speed', 'hdi_97.5%'],
                        summary.loc['sigma', 'mean']
                    )
                }
                
                # 儲存到 session results
                self._session_results[self.session_id] = results
                
                return results
                
            except Exception as e:
                raise Exception(f"分析失敗: {str(e)}")
    
    def _compute_diagnostics(self, summary):
        """計算收斂診斷指標"""
        try:
            # R-hat (應該接近 1.0)
            rhat_d = float(summary.loc['d', 'r_hat']) if 'r_hat' in summary.columns else None
            rhat_sigma = float(summary.loc['sigma', 'r_hat']) if 'r_hat' in summary.columns else None
            
            # ESS (有效樣本數)
            ess_d = float(summary.loc['d', 'ess_bulk']) if 'ess_bulk' in summary.columns else None
            ess_sigma = float(summary.loc['sigma', 'ess_bulk']) if 'ess_bulk' in summary.columns else None
            
            return {
                'rhat_d': rhat_d,
                'rhat_sigma': rhat_sigma,
                'ess_d': ess_d,
                'ess_sigma': ess_sigma,
                'converged': (rhat_d is None or rhat_d < 1.1) and (rhat_sigma is None or rhat_sigma < 1.1)
            }
        except:
            return {
                'converged': None,
                'rhat_d': None,
                'rhat_sigma': None,
                'ess_d': None,
                'ess_sigma': None
            }
    
    def _interpret_results(self, or_mean, or_low, or_high, sigma_mean):
        """解釋分析結果"""
        # 整體效應顯著性       
        if or_low > 1:
            overall_effect = "火系寶可夢相對於水系顯著更容易獲勝"
            overall_significance = "顯著正效應"
        elif or_high < 1:
            overall_effect = "水系寶可夢相對於火系顯著更容易獲勝"
            overall_significance = "顯著負效應"
        else:
            overall_effect = "火系與水系勝率無顯著差異"
            overall_significance = "不顯著"        

        # 效果大小           
        if or_mean > 2:
            effect_size = "大效果 (OR > 2) - 火系有明顯優勢"
        elif or_mean > 1.5:
            effect_size = "中等效果 (OR > 1.5) - 火系有一定優勢"
        elif or_mean > 1:
            effect_size = "小效果 (OR > 1) - 火系略有優勢"
        elif or_mean == 1:
            effect_size = "無差異 (OR = 1) - 火系與水系勢均力敵"
        elif or_mean > 0.67:
            effect_size = "小效果 (OR < 1) - 水系略有優勢"
        elif or_mean > 0.5:
            effect_size = "中等效果 (OR < 0.67) - 水系有一定優勢"
        else:
            effect_size = "大效果 (OR < 0.5) - 水系有明顯優勢"          
                    
        
        # 異質性評估
        if sigma_mean > 0.5:
            heterogeneity = "高異質性 - 不同配對的勝率差異很大"
        elif sigma_mean > 0.3:
            heterogeneity = "中等異質性 - 不同配對的勝率有一定差異"
        else:
            heterogeneity = "低異質性 - 不同配對的勝率相對一致"        
              
        return {
            'overall_effect': overall_effect,
            'overall_significance': overall_significance,
            'effect_size': effect_size,
            'heterogeneity': heterogeneity
        }
    
    def get_model_graph(self):
        """生成模型 DAG 圖(返回 graphviz 物件)"""
        if self.model is None:
            raise ValueError("請先執行分析")
        
        try:
            gv = pm.model_to_graphviz(self.model)
            return gv
        except Exception as e:
            raise Exception(f"無法生成 DAG 圖: {str(e)}")
    
    @classmethod
    def get_session_results(cls, session_id):
        """獲取特定 session 的結果"""
        return cls._session_results.get(session_id)
    
    @classmethod
    def clear_session_results(cls, session_id):
        """清除特定 session 的結果"""
        if session_id in cls._session_results:
            del cls._session_results[session_id]