2catycm commited on
Commit
86e568b
·
1 Parent(s): 4a7325a

feat: KAN可视化有bug

Browse files
Files changed (2) hide show
  1. .gitignore +2 -0
  2. app.py +384 -22
.gitignore CHANGED
@@ -1,3 +1,5 @@
 
 
1
  *.npz
2
  # Byte-compiled / optimized / DLL files
3
  __pycache__/
 
1
+ figures/
2
+ model/
3
  *.npz
4
  # Byte-compiled / optimized / DLL files
5
  __pycache__/
app.py CHANGED
@@ -5,6 +5,311 @@ from experiments.gmm_dataset import GeneralizedGaussianMixture
5
  import plotly.graph_objects as go
6
  from plotly.subplots import make_subplots
7
  from typing import List, Tuple
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  def init_session_state():
10
  """初始化session state"""
@@ -20,6 +325,8 @@ def init_session_state():
20
  st.session_state.weights = np.ones(3, dtype=np.float64) / 3
21
  if 'sample_points' not in st.session_state:
22
  st.session_state.sample_points = None
 
 
23
 
24
  def create_default_parameters(K: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
25
  """创建默认参数"""
@@ -45,7 +352,7 @@ def generate_latex_formula(p: float, K: int, centers: np.ndarray,
45
  c = centers[k]
46
  s = scales[k]
47
  w = weights[k]
48
- component = f"P_{k+1}(x) = \\frac{{{p:.1f}}}{{2\\alpha_{k+1} \\Gamma(1/{p:.1f})}}\\exp(-|\\frac{{x-({c[0]:.1f}, {c[1]:.1f})}}{{{s[0]:.1f}, {s[1]:.1f}}}|^{{{p:.1f}}}) \\\\"
49
  formula += component
50
  formula += f"\\pi_{k+1} = {w:.2f} \\\\"
51
 
@@ -116,15 +423,19 @@ with st.sidebar:
116
  st.subheader("采样设置")
117
  n_samples = st.slider("采样点数", 5, 20, 10)
118
  if st.button("重新采样"):
119
- # 生成随机样本
120
- samples = []
121
- for _ in range(n_samples):
122
- # 选择分量
123
- k = np.random.choice(K, p=weights)
124
- # 从选定的分量生成样本
125
- sample = np.random.normal(centers[k], scales[k], size=2)
126
- samples.append(sample)
127
- st.session_state.sample_points = np.array(samples)
 
 
 
 
128
 
129
  # 创建GMM数据集
130
  dataset = GeneralizedGaussianMixture(
@@ -196,7 +507,7 @@ if st.session_state.sample_points is not None:
196
  posteriors = []
197
  for sample in samples:
198
  component_probs = [
199
- weights[k] * np.exp(-np.sum(((sample - centers[k]) / scales[k])**st.session_state.p))
200
  for k in range(K)
201
  ]
202
  total = sum(component_probs)
@@ -218,16 +529,6 @@ if st.session_state.sample_points is not None:
218
  ),
219
  row=1, col=2
220
  )
221
-
222
- # 显示样本点的概率信息
223
- st.subheader("采样点信息")
224
- for i, (sample, prob, post) in enumerate(zip(samples, probs, posteriors)):
225
- st.write(f"样本点 S{i+1} ({sample[0]:.2f}, {sample[1]:.2f}):")
226
- st.write(f"- 概率密度: {prob:.4f}")
227
- st.write("- 后验概率:")
228
- for k in range(K):
229
- st.write(f" - 分量 {k+1}: {post[k]:.4f}")
230
- st.write("---")
231
 
232
  # 更新布局
233
  fig.update_layout(
@@ -246,9 +547,70 @@ fig.update_layout(
246
  fig.update_xaxes(title_text='X', row=1, col=2)
247
  fig.update_yaxes(title_text='Y', row=1, col=2)
248
 
249
- # 显示图形
250
  st.plotly_chart(fig, use_container_width=True)
251
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  # 添加参数说明
253
  with st.expander("分布参数说明"):
254
  st.markdown("""
 
5
  import plotly.graph_objects as go
6
  from plotly.subplots import make_subplots
7
  from typing import List, Tuple
8
+ import torch
9
+ import os
10
+ import sys
11
+ import matplotlib.pyplot as plt
12
+
13
+ # Set torch path
14
+ torch.classes.__path__ = [os.path.join(torch.__path__[0], torch.classes.__file__ or "")]
15
+
16
+ # Add pykan to path
17
+ pykan_path = Path(__file__).parent.parent / 'third_party' / 'pykan'
18
+ sys.path.append(str(pykan_path))
19
+
20
+ # Import KAN related modules
21
+ from kan import KAN # type: ignore
22
+ from kan.utils import create_dataset, ex_round # type: ignore
23
+
24
+ # Set torch dtype
25
+ torch.set_default_dtype(torch.float64)
26
+
27
+ def show_kan_prediction(model, device, samples, placeholder):
28
+ """显示KAN的预测结果"""
29
+ # 生成网格数据
30
+ x = np.linspace(-5, 5, 100)
31
+ y = np.linspace(-5, 5, 100)
32
+ X, Y = np.meshgrid(x, y)
33
+ xy = np.column_stack((X.ravel(), Y.ravel()))
34
+
35
+ # 使用KAN预测
36
+ grid_points = torch.from_numpy(xy).to(device)
37
+ with torch.no_grad():
38
+ Z_kan = model(grid_points).cpu().numpy().reshape(X.shape)
39
+
40
+ # 创建预测的概率密度图
41
+ fig_kan = make_subplots(
42
+ rows=1, cols=2,
43
+ specs=[[{'type': 'surface'}, {'type': 'contour'}]],
44
+ subplot_titles=('KAN预测的3D概率密度曲面', 'KAN预测的等高线图')
45
+ )
46
+
47
+ # 3D Surface
48
+ surface_kan = go.Surface(
49
+ x=X, y=Y, z=Z_kan,
50
+ colorscale='viridis',
51
+ showscale=True,
52
+ colorbar=dict(x=0.45)
53
+ )
54
+ fig_kan.add_trace(surface_kan, row=1, col=1)
55
+
56
+ # Contour Plot
57
+ contour_kan = go.Contour(
58
+ x=x, y=y, z=Z_kan,
59
+ colorscale='viridis',
60
+ showscale=True,
61
+ colorbar=dict(x=1.0),
62
+ contours=dict(
63
+ showlabels=True,
64
+ labelfont=dict(size=12)
65
+ )
66
+ )
67
+ fig_kan.add_trace(contour_kan, row=1, col=2)
68
+
69
+ # 添加采样点
70
+ if samples is not None:
71
+ fig_kan.add_trace(
72
+ go.Scatter(
73
+ x=samples[:, 0], y=samples[:, 1],
74
+ mode='markers',
75
+ marker=dict(
76
+ size=8,
77
+ color='yellow',
78
+ line=dict(color='black', width=1)
79
+ ),
80
+ name='训练点'
81
+ ),
82
+ row=1, col=2
83
+ )
84
+
85
+ # 更新布局
86
+ fig_kan.update_layout(
87
+ title='KAN预测结果',
88
+ showlegend=True,
89
+ width=1200,
90
+ height=600,
91
+ scene=dict(
92
+ xaxis_title='X',
93
+ yaxis_title='Y',
94
+ zaxis_title='密度'
95
+ )
96
+ )
97
+
98
+ # 更新2D图的坐标轴
99
+ fig_kan.update_xaxes(title_text='X', row=1, col=2)
100
+ fig_kan.update_yaxes(title_text='Y', row=1, col=2)
101
+
102
+ # 使用占位符显示图形
103
+ placeholder.plotly_chart(fig_kan, use_container_width=True)
104
+
105
+ def create_gmm_plot(dataset, centers, K, samples=None):
106
+ """创建GMM分布的可视化图形"""
107
+ # 生成网格数据
108
+ x = np.linspace(-5, 5, 100)
109
+ y = np.linspace(-5, 5, 100)
110
+ X, Y = np.meshgrid(x, y)
111
+ xy = np.column_stack((X.ravel(), Y.ravel()))
112
+
113
+ # 计算概率密度
114
+ Z = dataset.pdf(xy).reshape(X.shape)
115
+
116
+ # 创建2D和3D可视化
117
+ fig = make_subplots(
118
+ rows=1, cols=2,
119
+ specs=[[{'type': 'surface'}, {'type': 'contour'}]],
120
+ subplot_titles=('3D概率密度曲面', '等高线图与分量中心')
121
+ )
122
+
123
+ # 3D Surface
124
+ surface = go.Surface(
125
+ x=X, y=Y, z=Z,
126
+ colorscale='viridis',
127
+ showscale=True,
128
+ colorbar=dict(x=0.45)
129
+ )
130
+ fig.add_trace(surface, row=1, col=1)
131
+
132
+ # Contour Plot
133
+ contour = go.Contour(
134
+ x=x, y=y, z=Z,
135
+ colorscale='viridis',
136
+ showscale=True,
137
+ colorbar=dict(x=1.0),
138
+ contours=dict(
139
+ showlabels=True,
140
+ labelfont=dict(size=12)
141
+ )
142
+ )
143
+ fig.add_trace(contour, row=1, col=2)
144
+
145
+ # 添加分量中心点
146
+ fig.add_trace(
147
+ go.Scatter(
148
+ x=centers[:K, 0], y=centers[:K, 1],
149
+ mode='markers+text',
150
+ marker=dict(size=10, color='red'),
151
+ text=[f'C{i+1}' for i in range(K)],
152
+ textposition="top center",
153
+ name='分量中心'
154
+ ),
155
+ row=1, col=2
156
+ )
157
+
158
+ # 添加采样点(如果有)
159
+ if samples is not None:
160
+ fig.add_trace(
161
+ go.Scatter(
162
+ x=samples[:, 0], y=samples[:, 1],
163
+ mode='markers+text',
164
+ marker=dict(
165
+ size=8,
166
+ color='yellow',
167
+ line=dict(color='black', width=1)
168
+ ),
169
+ text=[f'S{i+1}' for i in range(len(samples))],
170
+ textposition="bottom center",
171
+ name='采样点'
172
+ ),
173
+ row=1, col=2
174
+ )
175
+
176
+ # 更新布局
177
+ fig.update_layout(
178
+ title='广义高斯混合分布',
179
+ showlegend=True,
180
+ width=1200,
181
+ height=600,
182
+ scene=dict(
183
+ xaxis_title='X',
184
+ yaxis_title='Y',
185
+ zaxis_title='密度'
186
+ )
187
+ )
188
+
189
+ # 更新2D图的坐标轴
190
+ fig.update_xaxes(title_text='X', row=1, col=2)
191
+ fig.update_yaxes(title_text='Y', row=1, col=2)
192
+
193
+ return fig
194
+
195
+ def train_kan(samples, gmm_dataset, device='cuda'):
196
+ """训练KAN网络"""
197
+ if torch.cuda.is_available() and device == 'cuda':
198
+ device = torch.device('cuda')
199
+ else:
200
+ device = torch.device('cpu')
201
+
202
+ # 转换采样点为tensor
203
+ samples = torch.from_numpy(samples).to(device)
204
+ # 计算标签(概率密度值)
205
+ labels = torch.from_numpy(gmm_dataset.pdf(samples.cpu().numpy())).reshape(-1, 1).to(device)
206
+
207
+ # 创建KAN模型
208
+ model = KAN(width=[2,5,1], grid=3, k=3, seed=42, device=device)
209
+ # 创建训练和测试数据集
210
+ train_size = int(0.8 * samples.shape[0])
211
+ train_dataset = {
212
+ 'train_input': samples[:train_size],
213
+ 'train_label': labels[:train_size],
214
+ 'test_input': samples[train_size:],
215
+ 'test_label': labels[train_size:]
216
+ }
217
+
218
+ # 创建训练进度显示组件
219
+ st.write("网络结构:")
220
+ kan_fig_placeholder = st.empty()
221
+ st.write("预测结果:")
222
+ kan_plot_placeholder = st.empty()
223
+ progress_container = st.container()
224
+
225
+ total_steps = 100
226
+ steps_per_update = 10
227
+
228
+ def calculate_error(model, x, y):
229
+ """计算预测误差"""
230
+ with torch.no_grad():
231
+ pred = model(x)
232
+ return torch.mean((pred - y) ** 2).item()
233
+
234
+ def train_phase(phase_name, steps, lamb=None, show_plot=True):
235
+ with progress_container:
236
+ progress_bar = st.progress(0)
237
+ status_text = st.empty()
238
+
239
+ for step in range(0, steps, steps_per_update):
240
+ # 训练几步
241
+ if lamb is not None:
242
+ model.fit(train_dataset, opt="LBFGS", steps=steps_per_update, lamb=lamb)
243
+ else:
244
+ model.fit(train_dataset, opt="LBFGS", steps=steps_per_update)
245
+
246
+ # 更新进度和误差
247
+ progress = (step + steps_per_update) / steps
248
+ progress_bar.progress(progress)
249
+
250
+ # 计算当前误差
251
+ train_error = calculate_error(model, train_dataset['train_input'], train_dataset['train_label'])
252
+ test_error = calculate_error(model, train_dataset['test_input'], train_dataset['test_label'])
253
+ # 使用表格格式显示进度和误差
254
+ status_text.markdown(f"""
255
+ ### {phase_name}
256
+ | 项目 | 值 |
257
+ |:---|:---|
258
+ | 进度 | {progress:.0%} |
259
+ | 训练误差 | {train_error:.8f} |
260
+ | 测试误差 | {test_error:.8f} |
261
+ """)
262
+
263
+ # 更新进度和预测结果
264
+ show_kan_prediction(model, device, samples, kan_plot_placeholder)
265
+
266
+ # 更新可视化(每5步更新一次)
267
+ if step % (steps_per_update * 5) == 0 or step + steps_per_update >= steps:
268
+ # 更新预测结果
269
+ show_kan_prediction(model, device, samples, kan_plot_placeholder)
270
+
271
+ # 更新网络结构图(可选)
272
+ if show_plot:
273
+ try:
274
+ kan_fig = model.plot()
275
+ if isinstance(kan_fig, tuple):
276
+ kan_fig = kan_fig[0] # 如果是元组,取第一个元素
277
+ if kan_fig is not None:
278
+ kan_fig_placeholder.pyplot(kan_fig)
279
+ plt.close('all') # 确保关闭所有图形
280
+ except Exception as e:
281
+ if step == 0: # 只在第一次出错时显示警告
282
+ st.warning(f"注意:网络结构图显示失败 ({str(e)})")
283
+
284
+ with progress_container:
285
+ st.markdown("#### 训练过程")
286
+ error_text = st.empty()
287
+
288
+ # 第一阶段训练
289
+ # 第一阶段:初始训练
290
+ with st.spinner("初始训练阶段..."):
291
+ train_phase("第一阶段", total_steps, lamb=0.001, show_plot=False) # 第一阶段不显示网络图
292
+
293
+ # 剪枝阶段
294
+ with st.spinner("正在进行网络剪枝优化..."):
295
+ model = model.prune()
296
+ progress_container.info("网络剪枝完成")
297
+
298
+ # 第二阶段:精调
299
+ with st.spinner("最终调优阶段..."):
300
+ train_phase("第二阶段", total_steps, show_plot=True) # 第二阶段显示网络图
301
+
302
+ # 显示最终误差
303
+ train_error = calculate_error(model, train_dataset['train_input'], train_dataset['train_label'])
304
+ test_error = calculate_error(model, train_dataset['test_input'], train_dataset['test_label'])
305
+ error_text.markdown(f"""
306
+ #### 训练结果
307
+ - 训练集误差: {train_error:.6f}
308
+ - 测试集误差: {test_error:.6f}
309
+ """)
310
+
311
+ progress_container.success("🎉 训练完成!")
312
+ return model
313
 
314
  def init_session_state():
315
  """初始化session state"""
 
325
  st.session_state.weights = np.ones(3, dtype=np.float64) / 3
326
  if 'sample_points' not in st.session_state:
327
  st.session_state.sample_points = None
328
+ if 'kan_model' not in st.session_state:
329
+ st.session_state.kan_model = None
330
 
331
  def create_default_parameters(K: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
332
  """创建默认参数"""
 
352
  c = centers[k]
353
  s = scales[k]
354
  w = weights[k]
355
+ component = f"P_{{\\theta_{k+1}}}(x) = \\frac{{{p:.1f}}}{{2\\alpha_{k+1} \\Gamma(1/{p:.1f})}}\\exp(-|\\frac{{x-({c[0]:.1f}, {c[1]:.1f})}}{{{s[0]:.1f}, {s[1]:.1f}}}|^{{{p:.1f}}}) \\\\"
356
  formula += component
357
  formula += f"\\pi_{k+1} = {w:.2f} \\\\"
358
 
 
423
  st.subheader("采样设置")
424
  n_samples = st.slider("采样点数", 5, 20, 10)
425
  if st.button("重新采样"):
426
+ # 创建GMM数据集进行采样
427
+ gmm = GeneralizedGaussianMixture(
428
+ D=2,
429
+ K=K,
430
+ p=st.session_state.p,
431
+ centers=centers[:K],
432
+ scales=scales[:K],
433
+ weights=weights[:K]
434
+ )
435
+ # 使用GMM生成采样点
436
+ samples, _ = gmm.generate_samples(n_samples)
437
+ st.session_state.sample_points = samples
438
+ st.session_state.kan_model = None # 重置KAN模型
439
 
440
  # 创建GMM数据集
441
  dataset = GeneralizedGaussianMixture(
 
507
  posteriors = []
508
  for sample in samples:
509
  component_probs = [
510
+ weights[k] * np.exp(-np.sum(((sample - centers[k]) / scales[k])**st.session_state.p))
511
  for k in range(K)
512
  ]
513
  total = sum(component_probs)
 
529
  ),
530
  row=1, col=2
531
  )
 
 
 
 
 
 
 
 
 
 
532
 
533
  # 更新布局
534
  fig.update_layout(
 
547
  fig.update_xaxes(title_text='X', row=1, col=2)
548
  fig.update_yaxes(title_text='Y', row=1, col=2)
549
 
550
+ # 显示GMM主图
551
  st.plotly_chart(fig, use_container_width=True)
552
 
553
+ # KAN网络训练和预测部分
554
+ if st.session_state.sample_points is not None:
555
+ st.markdown("---")
556
+ st.subheader("KAN网络训练与预测")
557
+
558
+ # 训练控制按钮
559
+ col1, col2, col3 = st.columns([1, 2, 1])
560
+ with col1:
561
+ if st.button("拟合KAN", use_container_width=True):
562
+ with st.spinner('训练KAN网络中...'):
563
+ st.session_state.kan_model = train_kan(st.session_state.sample_points, dataset)
564
+ st.balloons()
565
+
566
+ with col3:
567
+ if st.session_state.kan_model is not None:
568
+ if st.button("清除KAN结果", use_container_width=True):
569
+ st.session_state.kan_model = None
570
+ st.rerun()
571
+
572
+ # 显示KAN预测结果
573
+ if st.session_state.kan_model is not None:
574
+ st.subheader("KAN预测结果")
575
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
576
+ kan_plot_placeholder = st.empty()
577
+ show_kan_prediction(st.session_state.kan_model, device,
578
+ st.session_state.sample_points, kan_plot_placeholder)
579
+
580
+ st.markdown("---")
581
+
582
+ # 显示采样点信息
583
+ if st.session_state.sample_points is not None:
584
+ # 重新计算采样点的概率密度和后验概率
585
+ samples = st.session_state.sample_points
586
+ probs = dataset.pdf(samples)
587
+ posteriors = []
588
+ for sample in samples:
589
+ component_probs = [
590
+ weights[k] * np.exp(-np.sum(((sample - centers[k]) / scales[k])**st.session_state.p))
591
+ for k in range(K)
592
+ ]
593
+ total = sum(component_probs)
594
+ posteriors.append([p/total for p in component_probs])
595
+
596
+ with st.expander("采样点信息"):
597
+ # 创建数据列表
598
+ point_data = []
599
+ for i, (sample, prob, post) in enumerate(zip(samples, probs, posteriors)):
600
+ row = {
601
+ '采样点': f'S{i+1}',
602
+ 'X坐标': f'{sample[0]:.2f}',
603
+ 'Y坐标': f'{sample[1]:.2f}',
604
+ '概率密度': f'{prob:.4f}'
605
+ }
606
+ # 添加每个分量的后验概率
607
+ for k in range(K):
608
+ row[f'分量{k+1}后验概率'] = f'{post[k]:.4f}'
609
+ point_data.append(row)
610
+
611
+ # 显示dataframe
612
+ st.dataframe(point_data)
613
+
614
  # 添加参数说明
615
  with st.expander("分布参数说明"):
616
  st.markdown("""