Wen1201 commited on
Commit
bca7bb0
·
verified ·
1 Parent(s): 284de4f

Upload bayesian_utils.py

Browse files
Files changed (1) hide show
  1. bayesian_utils.py +66 -14
bayesian_utils.py CHANGED
@@ -13,6 +13,7 @@ from PIL import Image
13
  def plot_trace(trace, var_names=['d', 'sigma']):
14
  """
15
  繪製 Trace Plot(MCMC 收斂診斷)
 
16
 
17
  Args:
18
  trace: ArviZ InferenceData 物件
@@ -24,21 +25,65 @@ def plot_trace(trace, var_names=['d', 'sigma']):
24
  fig, axes = plt.subplots(len(var_names), 2, figsize=(14, 4 * len(var_names)))
25
  if len(var_names) == 1:
26
  axes = axes.reshape(1, -1)
27
-
28
- az.plot_trace(trace, var_names=var_names, axes=axes,
29
- combined=False,
30
- compact=False)
31
-
32
- # 加上 Tune 結束的垂直線
33
- for i in range(len(var_names)):
34
- if i == 0:
35
- axes[i, 1].axvline(x=1000, color='red', linestyle='--',
36
- linewidth=2, alpha=0.7, label='Tune結束')
37
- axes[i, 1].legend(loc='upper right', fontsize=9)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  else:
39
- axes[i, 1].axvline(x=1000, color='red', linestyle='--',
40
- linewidth=2, alpha=0.7)
41
-
 
 
 
 
 
 
 
 
 
 
 
42
  plt.tight_layout()
43
 
44
  # 轉換為圖片
@@ -50,6 +95,13 @@ def plot_trace(trace, var_names=['d', 'sigma']):
50
 
51
  return img
52
 
 
 
 
 
 
 
 
53
  def plot_posterior(trace, var_names=['d', 'sigma', 'or_speed'], hdi_prob=0.95):
54
  """
55
  繪製後驗分佈圖
 
13
  def plot_trace(trace, var_names=['d', 'sigma']):
14
  """
15
  繪製 Trace Plot(MCMC 收斂診斷)
16
+ 包含完整的 warmup + posterior
17
 
18
  Args:
19
  trace: ArviZ InferenceData 物件
 
25
  fig, axes = plt.subplots(len(var_names), 2, figsize=(14, 4 * len(var_names)))
26
  if len(var_names) == 1:
27
  axes = axes.reshape(1, -1)
28
+
29
+ # 檢查是否有 warmup_posterior
30
+ has_warmup = hasattr(trace, 'warmup_posterior') and trace.warmup_posterior is not None
31
+
32
+ for idx, var_name in enumerate(var_names):
33
+ # 左圖: KDE 密度圖(只用 posterior, 不用 warmup)
34
+ post_data = trace.posterior[var_name].values
35
+ for chain_idx in range(post_data.shape[0]):
36
+ from scipy import stats
37
+ data = post_data[chain_idx].flatten()
38
+ density = stats.gaussian_kde(data)
39
+ xs = np.linspace(data.min(), data.max(), 200)
40
+ axes[idx, 0].plot(xs, density(xs), alpha=0.8, label=f'Chain {chain_idx+1}')
41
+ axes[idx, 0].set_xlabel(var_name, fontsize=12)
42
+ axes[idx, 0].set_ylabel('Density', fontsize=12)
43
+ axes[idx, 0].set_title(f'{var_name}', fontsize=13, fontweight='bold')
44
+ if idx == 0:
45
+ axes[idx, 0].legend()
46
+
47
+ # 右圖: Trace 圖(完整 warmup + posterior)
48
+ if has_warmup:
49
+ # 有 warmup: 合併繪製
50
+ warmup_data = trace.warmup_posterior[var_name].values
51
+ post_data = trace.posterior[var_name].values
52
+
53
+ n_warmup = warmup_data.shape[1]
54
+ n_post = post_data.shape[1]
55
+
56
+ for chain_idx in range(warmup_data.shape[0]):
57
+ # 繪 warmup 部分
58
+ x_warmup = np.arange(n_warmup)
59
+ axes[idx, 1].plot(x_warmup, warmup_data[chain_idx].flatten(),
60
+ alpha=0.7, linewidth=0.5,
61
+ label=f'Chain {chain_idx+1}' if idx == 0 else '')
62
+
63
+ # 繪 posterior 部分
64
+ x_post = np.arange(n_warmup, n_warmup + n_post)
65
+ axes[idx, 1].plot(x_post, post_data[chain_idx].flatten(),
66
+ alpha=0.7, linewidth=0.5)
67
+
68
+ # 加 Tune 結束的紅線
69
+ axes[idx, 1].axvline(x=n_warmup, color='red', linestyle='--',
70
+ linewidth=2, alpha=0.7,
71
+ label='Tune結束' if idx == 0 else '')
72
  else:
73
+ # 沒有 warmup: 只用 posterior
74
+ post_data = trace.posterior[var_name].values
75
+ for chain_idx in range(post_data.shape[0]):
76
+ axes[idx, 1].plot(post_data[chain_idx].flatten(),
77
+ alpha=0.7, linewidth=0.5,
78
+ label=f'Chain {chain_idx+1}' if idx == 0 else '')
79
+
80
+ axes[idx, 1].set_xlabel('Iteration', fontsize=12)
81
+ axes[idx, 1].set_ylabel(var_name, fontsize=12)
82
+ axes[idx, 1].set_title(f'{var_name} trace', fontsize=13, fontweight='bold')
83
+ if idx == 0:
84
+ axes[idx, 1].legend(loc='upper right', fontsize=9)
85
+ axes[idx, 1].grid(alpha=0.3)
86
+
87
  plt.tight_layout()
88
 
89
  # 轉換為圖片
 
95
 
96
  return img
97
 
98
+
99
+ # ============================================
100
+ # 替換說明:
101
+ # 在 bayesian_utils.py 中,把第 13-51 行的整個 plot_trace 函數
102
+ # 替換成上面這個版本
103
+ # ============================================
104
+
105
  def plot_posterior(trace, var_names=['d', 'sigma', 'or_speed'], hdi_prob=0.95):
106
  """
107
  繪製後驗分佈圖