Wen1201 commited on
Commit
e8792d0
·
verified ·
1 Parent(s): bca7bb0

Upload bayesian_utils.py

Browse files
Files changed (1) hide show
  1. bayesian_utils.py +12 -1
bayesian_utils.py CHANGED
@@ -53,22 +53,33 @@ def plot_trace(trace, var_names=['d', 'sigma']):
53
  n_warmup = warmup_data.shape[1]
54
  n_post = post_data.shape[1]
55
 
 
 
 
56
  for chain_idx in range(warmup_data.shape[0]):
 
 
57
  # 繪 warmup 部分
58
  x_warmup = np.arange(n_warmup)
59
  axes[idx, 1].plot(x_warmup, warmup_data[chain_idx].flatten(),
 
60
  alpha=0.7, linewidth=0.5,
61
  label=f'Chain {chain_idx+1}' if idx == 0 else '')
62
 
63
- # 繪 posterior 部分
64
  x_post = np.arange(n_warmup, n_warmup + n_post)
65
  axes[idx, 1].plot(x_post, post_data[chain_idx].flatten(),
 
66
  alpha=0.7, linewidth=0.5)
67
 
68
  # 加 Tune 結束的紅線
69
  axes[idx, 1].axvline(x=n_warmup, color='red', linestyle='--',
70
  linewidth=2, alpha=0.7,
71
  label='Tune結束' if idx == 0 else '')
 
 
 
 
72
  else:
73
  # 沒有 warmup: 只用 posterior
74
  post_data = trace.posterior[var_name].values
 
53
  n_warmup = warmup_data.shape[1]
54
  n_post = post_data.shape[1]
55
 
56
+ # 定義顏色,讓每條鏈用固定顏色
57
+ colors = plt.cm.tab10.colors # 使用 matplotlib 的顏色循環
58
+
59
  for chain_idx in range(warmup_data.shape[0]):
60
+ chain_color = colors[chain_idx % len(colors)] # 每條鏈一個固定顏色
61
+
62
  # 繪 warmup 部分
63
  x_warmup = np.arange(n_warmup)
64
  axes[idx, 1].plot(x_warmup, warmup_data[chain_idx].flatten(),
65
+ color=chain_color, # 👈 指定顏色
66
  alpha=0.7, linewidth=0.5,
67
  label=f'Chain {chain_idx+1}' if idx == 0 else '')
68
 
69
+ # 繪 posterior 部分 (用同樣的顏色!)
70
  x_post = np.arange(n_warmup, n_warmup + n_post)
71
  axes[idx, 1].plot(x_post, post_data[chain_idx].flatten(),
72
+ color=chain_color, # 👈 同一個顏色
73
  alpha=0.7, linewidth=0.5)
74
 
75
  # 加 Tune 結束的紅線
76
  axes[idx, 1].axvline(x=n_warmup, color='red', linestyle='--',
77
  linewidth=2, alpha=0.7,
78
  label='Tune結束' if idx == 0 else '')
79
+
80
+
81
+
82
+
83
  else:
84
  # 沒有 warmup: 只用 posterior
85
  post_data = trace.posterior[var_name].values