Spaces:
Sleeping
Sleeping
Upload bayesian_utils.py
Browse files- bayesian_utils.py +66 -14
bayesian_utils.py
CHANGED
|
@@ -13,6 +13,7 @@ from PIL import Image
|
|
| 13 |
def plot_trace(trace, var_names=['d', 'sigma']):
|
| 14 |
"""
|
| 15 |
繪製 Trace Plot(MCMC 收斂診斷)
|
|
|
|
| 16 |
|
| 17 |
Args:
|
| 18 |
trace: ArviZ InferenceData 物件
|
|
@@ -24,21 +25,65 @@ def plot_trace(trace, var_names=['d', 'sigma']):
|
|
| 24 |
fig, axes = plt.subplots(len(var_names), 2, figsize=(14, 4 * len(var_names)))
|
| 25 |
if len(var_names) == 1:
|
| 26 |
axes = axes.reshape(1, -1)
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
else:
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
plt.tight_layout()
|
| 43 |
|
| 44 |
# 轉換為圖片
|
|
@@ -50,6 +95,13 @@ def plot_trace(trace, var_names=['d', 'sigma']):
|
|
| 50 |
|
| 51 |
return img
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
def plot_posterior(trace, var_names=['d', 'sigma', 'or_speed'], hdi_prob=0.95):
|
| 54 |
"""
|
| 55 |
繪製後驗分佈圖
|
|
|
|
| 13 |
def plot_trace(trace, var_names=['d', 'sigma']):
|
| 14 |
"""
|
| 15 |
繪製 Trace Plot(MCMC 收斂診斷)
|
| 16 |
+
包含完整的 warmup + posterior
|
| 17 |
|
| 18 |
Args:
|
| 19 |
trace: ArviZ InferenceData 物件
|
|
|
|
| 25 |
fig, axes = plt.subplots(len(var_names), 2, figsize=(14, 4 * len(var_names)))
|
| 26 |
if len(var_names) == 1:
|
| 27 |
axes = axes.reshape(1, -1)
|
| 28 |
+
|
| 29 |
+
# 檢查是否有 warmup_posterior
|
| 30 |
+
has_warmup = hasattr(trace, 'warmup_posterior') and trace.warmup_posterior is not None
|
| 31 |
+
|
| 32 |
+
for idx, var_name in enumerate(var_names):
|
| 33 |
+
# 左圖: KDE 密度圖(只用 posterior, 不用 warmup)
|
| 34 |
+
post_data = trace.posterior[var_name].values
|
| 35 |
+
for chain_idx in range(post_data.shape[0]):
|
| 36 |
+
from scipy import stats
|
| 37 |
+
data = post_data[chain_idx].flatten()
|
| 38 |
+
density = stats.gaussian_kde(data)
|
| 39 |
+
xs = np.linspace(data.min(), data.max(), 200)
|
| 40 |
+
axes[idx, 0].plot(xs, density(xs), alpha=0.8, label=f'Chain {chain_idx+1}')
|
| 41 |
+
axes[idx, 0].set_xlabel(var_name, fontsize=12)
|
| 42 |
+
axes[idx, 0].set_ylabel('Density', fontsize=12)
|
| 43 |
+
axes[idx, 0].set_title(f'{var_name}', fontsize=13, fontweight='bold')
|
| 44 |
+
if idx == 0:
|
| 45 |
+
axes[idx, 0].legend()
|
| 46 |
+
|
| 47 |
+
# 右圖: Trace 圖(完整 warmup + posterior)
|
| 48 |
+
if has_warmup:
|
| 49 |
+
# 有 warmup: 合併繪製
|
| 50 |
+
warmup_data = trace.warmup_posterior[var_name].values
|
| 51 |
+
post_data = trace.posterior[var_name].values
|
| 52 |
+
|
| 53 |
+
n_warmup = warmup_data.shape[1]
|
| 54 |
+
n_post = post_data.shape[1]
|
| 55 |
+
|
| 56 |
+
for chain_idx in range(warmup_data.shape[0]):
|
| 57 |
+
# 繪 warmup 部分
|
| 58 |
+
x_warmup = np.arange(n_warmup)
|
| 59 |
+
axes[idx, 1].plot(x_warmup, warmup_data[chain_idx].flatten(),
|
| 60 |
+
alpha=0.7, linewidth=0.5,
|
| 61 |
+
label=f'Chain {chain_idx+1}' if idx == 0 else '')
|
| 62 |
+
|
| 63 |
+
# 繪 posterior 部分
|
| 64 |
+
x_post = np.arange(n_warmup, n_warmup + n_post)
|
| 65 |
+
axes[idx, 1].plot(x_post, post_data[chain_idx].flatten(),
|
| 66 |
+
alpha=0.7, linewidth=0.5)
|
| 67 |
+
|
| 68 |
+
# 加 Tune 結束的紅線
|
| 69 |
+
axes[idx, 1].axvline(x=n_warmup, color='red', linestyle='--',
|
| 70 |
+
linewidth=2, alpha=0.7,
|
| 71 |
+
label='Tune結束' if idx == 0 else '')
|
| 72 |
else:
|
| 73 |
+
# 沒有 warmup: 只用 posterior
|
| 74 |
+
post_data = trace.posterior[var_name].values
|
| 75 |
+
for chain_idx in range(post_data.shape[0]):
|
| 76 |
+
axes[idx, 1].plot(post_data[chain_idx].flatten(),
|
| 77 |
+
alpha=0.7, linewidth=0.5,
|
| 78 |
+
label=f'Chain {chain_idx+1}' if idx == 0 else '')
|
| 79 |
+
|
| 80 |
+
axes[idx, 1].set_xlabel('Iteration', fontsize=12)
|
| 81 |
+
axes[idx, 1].set_ylabel(var_name, fontsize=12)
|
| 82 |
+
axes[idx, 1].set_title(f'{var_name} trace', fontsize=13, fontweight='bold')
|
| 83 |
+
if idx == 0:
|
| 84 |
+
axes[idx, 1].legend(loc='upper right', fontsize=9)
|
| 85 |
+
axes[idx, 1].grid(alpha=0.3)
|
| 86 |
+
|
| 87 |
plt.tight_layout()
|
| 88 |
|
| 89 |
# 轉換為圖片
|
|
|
|
| 95 |
|
| 96 |
return img
|
| 97 |
|
| 98 |
+
|
| 99 |
+
# ============================================
|
| 100 |
+
# 替換說明:
|
| 101 |
+
# 在 bayesian_utils.py 中,把第 13-51 行的整個 plot_trace 函數
|
| 102 |
+
# 替換成上面這個版本
|
| 103 |
+
# ============================================
|
| 104 |
+
|
| 105 |
def plot_posterior(trace, var_names=['d', 'sigma', 'or_speed'], hdi_prob=0.95):
|
| 106 |
"""
|
| 107 |
繪製後驗分佈圖
|