Spaces:
Sleeping
Sleeping
Upload bayesian_utils.py
Browse files- bayesian_utils.py +12 -1
bayesian_utils.py
CHANGED
|
@@ -53,22 +53,33 @@ def plot_trace(trace, var_names=['d', 'sigma']):
|
|
| 53 |
n_warmup = warmup_data.shape[1]
|
| 54 |
n_post = post_data.shape[1]
|
| 55 |
|
|
|
|
|
|
|
|
|
|
| 56 |
for chain_idx in range(warmup_data.shape[0]):
|
|
|
|
|
|
|
| 57 |
# 繪 warmup 部分
|
| 58 |
x_warmup = np.arange(n_warmup)
|
| 59 |
axes[idx, 1].plot(x_warmup, warmup_data[chain_idx].flatten(),
|
|
|
|
| 60 |
alpha=0.7, linewidth=0.5,
|
| 61 |
label=f'Chain {chain_idx+1}' if idx == 0 else '')
|
| 62 |
|
| 63 |
-
# 繪 posterior 部分
|
| 64 |
x_post = np.arange(n_warmup, n_warmup + n_post)
|
| 65 |
axes[idx, 1].plot(x_post, post_data[chain_idx].flatten(),
|
|
|
|
| 66 |
alpha=0.7, linewidth=0.5)
|
| 67 |
|
| 68 |
# 加 Tune 結束的紅線
|
| 69 |
axes[idx, 1].axvline(x=n_warmup, color='red', linestyle='--',
|
| 70 |
linewidth=2, alpha=0.7,
|
| 71 |
label='Tune結束' if idx == 0 else '')
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
else:
|
| 73 |
# 沒有 warmup: 只用 posterior
|
| 74 |
post_data = trace.posterior[var_name].values
|
|
|
|
| 53 |
n_warmup = warmup_data.shape[1]
|
| 54 |
n_post = post_data.shape[1]
|
| 55 |
|
| 56 |
+
# 定義顏色,讓每條鏈用固定顏色
|
| 57 |
+
colors = plt.cm.tab10.colors # 使用 matplotlib 的顏色循環
|
| 58 |
+
|
| 59 |
for chain_idx in range(warmup_data.shape[0]):
|
| 60 |
+
chain_color = colors[chain_idx % len(colors)] # 每條鏈一個固定顏色
|
| 61 |
+
|
| 62 |
# 繪 warmup 部分
|
| 63 |
x_warmup = np.arange(n_warmup)
|
| 64 |
axes[idx, 1].plot(x_warmup, warmup_data[chain_idx].flatten(),
|
| 65 |
+
color=chain_color, # 👈 指定顏色
|
| 66 |
alpha=0.7, linewidth=0.5,
|
| 67 |
label=f'Chain {chain_idx+1}' if idx == 0 else '')
|
| 68 |
|
| 69 |
+
# 繪 posterior 部分 (用同樣的顏色!)
|
| 70 |
x_post = np.arange(n_warmup, n_warmup + n_post)
|
| 71 |
axes[idx, 1].plot(x_post, post_data[chain_idx].flatten(),
|
| 72 |
+
color=chain_color, # 👈 同一個顏色
|
| 73 |
alpha=0.7, linewidth=0.5)
|
| 74 |
|
| 75 |
# 加 Tune 結束的紅線
|
| 76 |
axes[idx, 1].axvline(x=n_warmup, color='red', linestyle='--',
|
| 77 |
linewidth=2, alpha=0.7,
|
| 78 |
label='Tune結束' if idx == 0 else '')
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
|
| 83 |
else:
|
| 84 |
# 沒有 warmup: 只用 posterior
|
| 85 |
post_data = trace.posterior[var_name].values
|