Spaces:
Sleeping
Sleeping
| import plotly.graph_objects as go | |
| import plotly.express as px | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import matplotlib | |
| matplotlib.use('Agg') # 使用非互動式後端 | |
| import arviz as az | |
| import io | |
| import base64 | |
| from PIL import Image | |
def plot_trace(trace, var_names=('d', 'sigma')):
    """
    Draw MCMC convergence diagnostics: per-chain density and trace plots.

    One row is drawn per variable. The left panel shows a Gaussian KDE of
    each chain's posterior draws (warmup draws are excluded). The right
    panel shows the sampled values per iteration; when the trace carries a
    ``warmup_posterior`` group, the warmup draws are plotted first, followed
    by the posterior draws in the same colour, with a dashed red line
    marking the end of tuning.

    Args:
        trace: ArviZ InferenceData object with a ``posterior`` group.
        var_names: Name, or iterable of names, of the variables to plot.

    Returns:
        PIL Image of the rendered figure (PNG, 300 dpi).

    Raises:
        ValueError: If ``var_names`` is empty.
    """
    # Imported locally (and only once): SciPy is needed by this function only.
    from scipy import stats

    # A bare string would otherwise be iterated character by character.
    if isinstance(var_names, str):
        var_names = [var_names]
    var_names = list(var_names)
    if not var_names:
        raise ValueError("var_names must contain at least one variable")

    n_vars = len(var_names)
    # squeeze=False always yields a 2-D axes array, even for a single variable.
    fig, axes = plt.subplots(n_vars, 2, figsize=(14, 4 * n_vars), squeeze=False)

    try:
        # Warmup draws are optional; they exist only if the sampler kept them.
        has_warmup = getattr(trace, 'warmup_posterior', None) is not None
        # Fixed palette so warmup and posterior segments of a chain match.
        colors = plt.cm.tab10.colors

        for idx, var_name in enumerate(var_names):
            density_ax, trace_ax = axes[idx]
            post_data = trace.posterior[var_name].values

            # Left panel: KDE per chain (posterior only, no warmup).
            for chain_idx in range(post_data.shape[0]):
                data = post_data[chain_idx].flatten()
                density = stats.gaussian_kde(data)
                xs = np.linspace(data.min(), data.max(), 200)
                density_ax.plot(xs, density(xs), alpha=0.8,
                                label=f'Chain {chain_idx+1}')
            density_ax.set_xlabel(var_name, fontsize=12)
            density_ax.set_ylabel('Density', fontsize=12)
            density_ax.set_title(f'{var_name}', fontsize=13, fontweight='bold')
            if idx == 0:
                density_ax.legend()

            # Right panel: trace (warmup + posterior when available).
            if has_warmup:
                warmup_data = trace.warmup_posterior[var_name].values
                n_warmup = warmup_data.shape[1]
                n_post = post_data.shape[1]
                x_warmup = np.arange(n_warmup)
                # Posterior iterations continue where the warmup ended.
                x_post = np.arange(n_warmup, n_warmup + n_post)

                for chain_idx in range(warmup_data.shape[0]):
                    chain_color = colors[chain_idx % len(colors)]
                    trace_ax.plot(x_warmup, warmup_data[chain_idx].flatten(),
                                  color=chain_color,
                                  alpha=0.7, linewidth=0.5,
                                  label=f'Chain {chain_idx+1}' if idx == 0 else '')
                    trace_ax.plot(x_post, post_data[chain_idx].flatten(),
                                  color=chain_color,
                                  alpha=0.7, linewidth=0.5)

                # Red dashed line marks the end of the tuning phase.
                trace_ax.axvline(x=n_warmup, color='red', linestyle='--',
                                 linewidth=2, alpha=0.7,
                                 label='Tune結束' if idx == 0 else '')
            else:
                for chain_idx in range(post_data.shape[0]):
                    trace_ax.plot(post_data[chain_idx].flatten(),
                                  alpha=0.7, linewidth=0.5,
                                  label=f'Chain {chain_idx+1}' if idx == 0 else '')

            trace_ax.set_xlabel('Iteration', fontsize=12)
            trace_ax.set_ylabel(var_name, fontsize=12)
            trace_ax.set_title(f'{var_name} trace', fontsize=13, fontweight='bold')
            if idx == 0:
                trace_ax.legend(loc='upper right', fontsize=9)
            trace_ax.grid(alpha=0.3)

        fig.tight_layout()

        # Render to an in-memory PNG and hand it back as a PIL image.
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
        buf.seek(0)
        return Image.open(buf)
    finally:
        # Close this specific figure even on failure so figures do not pile up.
        plt.close(fig)
# ============================================
# Replacement note:
# In bayesian_utils.py, replace the entire plot_trace function
# (lines 13-51 of the earlier version) with the version above.
# ============================================
def plot_posterior(trace, var_names=('d', 'sigma', 'or_speed'), hdi_prob=0.95):
    """
    Draw the posterior distribution plot via ``az.plot_posterior``.

    Args:
        trace: ArviZ InferenceData object.
        var_names: Name, or iterable of names, of the variables to plot.
            ``None`` is passed through to ArviZ unchanged.
        hdi_prob: Probability mass of the HDI shown on each panel.

    Returns:
        PIL Image of the rendered figure (PNG, 300 dpi).
    """
    if isinstance(var_names, str):
        var_names = [var_names]
    elif var_names is not None:
        var_names = list(var_names)

    # az.plot_posterior returns axes, not a figure; the figure it drew on is
    # the current pyplot figure right after the call.
    az.plot_posterior(trace, var_names=var_names, hdi_prob=hdi_prob, figsize=(14, 5))
    fig = plt.gcf()

    try:
        fig.tight_layout()

        # Render to an in-memory PNG and hand it back as a PIL image.
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
        buf.seek(0)
        return Image.open(buf)
    finally:
        # Close this specific figure even if rendering failed.
        plt.close(fig)
def plot_forest(trace, trial_labels, title='Effect of Speed on Win Rate by Type',
                hdi_prob=0.95):
    """
    Draw a forest plot of the per-trial ``delta`` effects.

    Each row shows the posterior mean of ``delta`` (dot) and its HDI
    (horizontal bar). Rows whose HDI lies entirely above zero are marked
    with a gold ★, rows whose HDI lies entirely below zero with a red ☆.

    Args:
        trace: ArviZ InferenceData object whose posterior contains ``delta``.
        trial_labels: Labels for the rows, one per ``delta`` entry.
        title: Figure title.
        hdi_prob: Probability mass of the HDI (defaults to 0.95).

    Returns:
        PIL Image of the rendered figure (PNG, 300 dpi).

    Raises:
        ValueError: If ``trial_labels`` is empty or its length does not match
            the number of ``delta`` entries.
    """
    trial_labels = list(trial_labels)
    num_trials = len(trial_labels)
    if num_trials == 0:
        raise ValueError("trial_labels must contain at least one label")

    # Pool all chains/draws: one column per trial.
    delta_posterior = trace.posterior['delta'].values.reshape(-1, num_trials)
    delta_mean = delta_posterior.mean(axis=0)
    delta_hdi = az.hdi(trace, var_names=['delta'], hdi_prob=hdi_prob)['delta'].values
    # Guard against a silent mis-reshape when labels and delta disagree.
    if delta_hdi.shape != (num_trials, 2):
        raise ValueError(
            f"expected {num_trials} delta entries to match trial_labels, "
            f"got HDI array of shape {delta_hdi.shape}"
        )

    fig, ax = plt.subplots(figsize=(12, max(10, num_trials * 0.4)))
    try:
        y_pos = np.arange(num_trials)

        # HDI (credible interval) as horizontal bars.
        ax.hlines(y_pos, delta_hdi[:, 0], delta_hdi[:, 1], color='steelblue',
                  linewidth=3, label=f'{hdi_prob * 100:g}% HDI')
        # Posterior means as dots on top of the bars.
        ax.scatter(delta_mean, y_pos, color='darkblue', s=120, zorder=3,
                   edgecolors='white', linewidth=1.5, label='Mean')

        # Flag rows whose HDI excludes zero.
        for i, (mean, hdi) in enumerate(zip(delta_mean, delta_hdi)):
            if hdi[0] > 0:  # credibly positive effect
                ax.text(mean, i, ' ★', fontsize=15, ha='left', va='center', color='gold')
            elif hdi[1] < 0:  # credibly negative effect
                ax.text(mean, i, ' ☆', fontsize=15, ha='left', va='center', color='red')

        ax.set_yticks(y_pos)
        ax.set_yticklabels(trial_labels, fontsize=11)
        ax.invert_yaxis()  # first label at the top
        ax.axvline(0, color='red', linestyle='--', linewidth=2, label='No Effect (δ=0)')
        ax.set_xlabel('Delta (Log Odds Ratio)', fontsize=13)
        ax.set_title(title, fontsize=15, fontweight='bold', pad=20)
        ax.legend(loc='lower right')
        ax.grid(axis='x', alpha=0.3)
        fig.tight_layout()

        # Render to an in-memory PNG and hand it back as a PIL image.
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
        buf.seek(0)
        return Image.open(buf)
    finally:
        # Close this specific figure even if rendering failed.
        plt.close(fig)
def plot_model_dag(analyzer):
    """
    Render the model's DAG as an image (best effort).

    Args:
        analyzer: Object exposing ``get_model_graph()``; the returned graph
            must support ``pipe(format='png')``.

    Returns:
        PIL Image, or None if the graph could not be produced.
    """
    try:
        graph = analyzer.get_model_graph()
        # Render to PNG bytes and wrap them for PIL.
        png_stream = io.BytesIO(graph.pipe(format='png'))
        return Image.open(png_stream)
    except Exception as e:
        # Any failure is reported and turned into None.
        print(f"無法生成 DAG 圖: {e}")
        return None
def create_summary_table(results):
    """
    Build the summary table of the overall parameters.

    Args:
        results: Analysis result dict; ``results['overall']`` must hold
            ``<prefix>_mean``, ``<prefix>_sd``, ``<prefix>_hdi_low`` and
            ``<prefix>_hdi_high`` for the prefixes ``d``, ``sigma`` and ``or``.

    Returns:
        pandas DataFrame with one row per parameter; numeric cells are
        strings formatted to four decimals.
    """
    stats = results['overall']

    # (key prefix in `overall`, label shown in the first column)
    parameters = (
        ('d', 'd (整體效應)'),
        ('sigma', 'sigma (配對間變異)'),
        ('or', 'or_speed (勝算比)'),
    )
    # (column header, key suffix in `overall`)
    measures = (
        ('平均值', 'mean'),
        ('標準差', 'sd'),
        ('95% HDI 下界', 'hdi_low'),
        ('95% HDI 上界', 'hdi_high'),
    )

    table = {'參數': [label for _, label in parameters]}
    for header, suffix in measures:
        table[header] = [
            f"{stats[f'{prefix}_{suffix}']:.4f}" for prefix, _ in parameters
        ]
    return pd.DataFrame(table)
def create_trial_results_table(results):
    """
    Build the per-trial results table, with column names derived from
    ``results['column_names']``.

    Args:
        results: Analysis result dict providing ``trial_labels``,
            ``by_trial``, ``data`` and ``column_names``.

    Returns:
        pandas DataFrame with one row per trial; numeric cells are
        pre-formatted strings.
    """
    labels = results['trial_labels']
    per_trial = results['by_trial']
    observations = results['data']
    names = results['column_names']

    ctrl = names['control_prefix']
    treat = names['treatment_prefix']

    def fixed4(values):
        # Four-decimal fixed-point strings.
        return [f"{v:.4f}" for v in values]

    def percent(values):
        # Percentages with two decimals.
        return [f"{v:.2%}" for v in values]

    def tally(win_field, total_field):
        # "wins/total" per observation, using the configured column names.
        return [f"{row[names[win_field]]}/{row[names[total_field]]}"
                for row in observations]

    return pd.DataFrame({
        '配對': labels,
        'Delta (平均)': fixed4(per_trial['delta_mean']),
        'Delta (標準差)': fixed4(per_trial['delta_std']),
        '95% HDI 下界': fixed4(per_trial['delta_hdi_low']),
        '95% HDI 上界': fixed4(per_trial['delta_hdi_high']),
        '顯著性': ['★ 顯著' if flag else '不顯著'
                for flag in per_trial['delta_significant']],
        f"{ctrl}勝率": percent(per_trial[f"p_{ctrl}_mean"]),
        f"{treat}勝率": percent(per_trial[f"p_{treat}_mean"]),
        f"{ctrl} (勝/總)": tally('control_win', 'control_total'),
        f"{treat} (勝/總)": tally('treatment_win', 'treatment_total'),
    })
def export_results_to_text(results):
    """
    Export the analysis results as a plain-text report.

    The report has four sections: overall effect summary, convergence
    diagnostics, interpretation, and per-trial details.

    Args:
        results: Analysis result dict providing ``timestamp``, ``n_trials``,
            ``overall``, ``interpretation``, ``diagnostics``,
            ``column_names``, ``trial_labels`` and ``by_trial``.

    Returns:
        str: The formatted report.
    """
    overall = results['overall']
    interp = results['interpretation']
    diag = results['diagnostics']
    col_names = results['column_names']

    heavy_rule = "=============================================="
    light_rule = "----------------------------------------------"

    def fmt_rhat(value):
        # R-hat may be unavailable (None).
        return 'N/A' if value is None else f"{value:.4f}"

    def fmt_ess(value):
        # ESS may be unavailable (None) or non-finite (NaN/inf); int() would
        # raise on the latter, so report it as N/A instead.
        if value is None or not np.isfinite(value):
            return 'N/A'
        return str(int(value))

    def param_lines(heading, prefix):
        # Three-line summary (mean / sd / HDI) for one overall parameter.
        return [
            heading,
            f"- 平均值: {overall[f'{prefix}_mean']:.4f}",
            f"- 標準差: {overall[f'{prefix}_sd']:.4f}",
            f"- 95% HDI: [{overall[f'{prefix}_hdi_low']:.4f}, "
            f"{overall[f'{prefix}_hdi_high']:.4f}]",
        ]

    # Lines are collected and joined once instead of growing a string with +=.
    # An empty entry yields the blank line that separates report sections.
    lines = [
        "",
        heavy_rule,
        "貝氏階層模型分析報告",
        heavy_rule,
        f"分析時間: {results['timestamp']}",
        f"配對數量: {results['n_trials']}",
        light_rule,
        "1. 整體效應摘要",
        light_rule,
    ]
    lines += param_lines("d (整體效應 - Log OR):", 'd')
    lines += param_lines("sigma (配對間變異):", 'sigma')
    lines += param_lines("or_speed (勝算比):", 'or')
    lines += [
        light_rule,
        "2. 模型收斂診斷",
        light_rule,
        f"R-hat (d): {fmt_rhat(diag['rhat_d'])}",
        f"R-hat (sigma): {fmt_rhat(diag['rhat_sigma'])}",
        f"ESS (d): {fmt_ess(diag['ess_d'])}",
        f"ESS (sigma): {fmt_ess(diag['ess_sigma'])}",
        f"收斂狀態: {'✓ 已收斂' if diag['converged'] else '✗ 未收斂'}",
        light_rule,
        "3. 結果解釋",
        light_rule,
        f"整體效應: {interp['overall_effect']}",
        f"顯著性: {interp['overall_significance']}",
        f"效果大小: {interp['effect_size']}",
        f"異質性: {interp['heterogeneity']}",
        light_rule,
        "4. 各配對詳細結果",
        light_rule,
    ]

    # Per-trial details; win-rate keys depend on the configured prefixes.
    trial_labels = results['trial_labels']
    by_trial = results['by_trial']
    control_key = f"p_{col_names['control_prefix']}_mean"
    treatment_key = f"p_{col_names['treatment_prefix']}_mean"
    control_label = col_names['control_prefix'].capitalize()
    treatment_label = col_names['treatment_prefix'].capitalize()

    for i, label in enumerate(trial_labels):
        sig_marker = "★" if by_trial['delta_significant'][i] else " "
        control_rate = by_trial[control_key][i]
        treatment_rate = by_trial[treatment_key][i]
        lines += [
            "",
            f"{sig_marker} {label}:",
            f"Delta (平均): {by_trial['delta_mean'][i]:.4f}",
            f"95% HDI: [{by_trial['delta_hdi_low'][i]:.4f}, "
            f"{by_trial['delta_hdi_high'][i]:.4f}]",
            f"{control_label}勝率: {control_rate:.2%}",
            f"{treatment_label}勝率: {treatment_rate:.2%}",
            f"勝率差異: {(treatment_rate - control_rate):.2%}",
        ]

    # Trailing empty entry gives the final newline after the closing rule.
    lines += ["", heavy_rule, ""]
    return "\n".join(lines)
def plot_odds_ratio_comparison(results):
    """
    Build a horizontal bar chart comparing per-trial odds ratios (Plotly).

    Odds ratios are ``exp(delta_mean)``. Bars are listed in descending
    odds-ratio order; trials flagged in ``delta_significant`` are green, the
    rest grey. A dashed red line marks OR = 1.

    Args:
        results: Analysis result dict providing ``trial_labels`` and
            ``by_trial`` (with ``delta_mean`` and ``delta_significant``).

    Returns:
        plotly figure
    """
    trial_labels = results['trial_labels']
    per_trial = results['by_trial']

    # Log odds ratio -> odds ratio.
    odds_ratios = [np.exp(delta) for delta in per_trial['delta_mean']]

    # Indices in descending odds-ratio order.
    order = np.argsort(odds_ratios)[::-1]
    ordered_labels = []
    ordered_or = []
    bar_colors = []
    for i in order:
        ordered_labels.append(trial_labels[i])
        ordered_or.append(odds_ratios[i])
        is_significant = per_trial['delta_significant'][i]
        bar_colors.append('#2ecc71' if is_significant else '#95a5a6')

    bars = go.Bar(
        x=ordered_or,
        y=ordered_labels,
        orientation='h',
        marker=dict(
            color=bar_colors,
            line=dict(color='white', width=1)
        ),
        text=[f'{value:.2f}' for value in ordered_or],
        textposition='outside',
        hovertemplate='%{y}<br>OR: %{x:.3f}<extra></extra>'
    )
    fig = go.Figure(data=[bars])

    # Reference line at OR = 1 (no effect).
    fig.add_vline(x=1, line_dash="dash", line_color="red", line_width=2)

    fig.update_layout(
        title='各屬性速度效應(勝算比)',
        xaxis_title='Odds Ratio',
        yaxis_title='',
        width=800,
        height=max(400, len(trial_labels) * 25),
        template='plotly_white',
        showlegend=False
    )
    return fig