feat: 修复多个bug
Browse files
app.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import numpy as np
|
| 3 |
from pathlib import Path
|
|
@@ -24,7 +25,7 @@ from kan.utils import create_dataset, ex_round # type: ignore
|
|
| 24 |
# Set torch dtype
|
| 25 |
torch.set_default_dtype(torch.float64)
|
| 26 |
|
| 27 |
-
def show_kan_prediction(model, device, samples, placeholder):
|
| 28 |
"""显示KAN的预测结果"""
|
| 29 |
# 生成网格数据
|
| 30 |
x = np.linspace(-5, 5, 100)
|
|
@@ -84,7 +85,7 @@ def show_kan_prediction(model, device, samples, placeholder):
|
|
| 84 |
|
| 85 |
# 更新布局
|
| 86 |
fig_kan.update_layout(
|
| 87 |
-
title='KAN
|
| 88 |
showlegend=True,
|
| 89 |
width=1200,
|
| 90 |
height=600,
|
|
@@ -100,7 +101,10 @@ def show_kan_prediction(model, device, samples, placeholder):
|
|
| 100 |
fig_kan.update_yaxes(title_text='Y', row=1, col=2)
|
| 101 |
|
| 102 |
# 使用占位符显示图形
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
def create_gmm_plot(dataset, centers, K, samples=None):
|
| 106 |
"""创建GMM分布的可视化图形"""
|
|
@@ -198,6 +202,7 @@ def train_kan(samples, gmm_dataset, device='cuda'):
|
|
| 198 |
device = torch.device('cuda')
|
| 199 |
else:
|
| 200 |
device = torch.device('cpu')
|
|
|
|
| 201 |
|
| 202 |
# 转换采样点为tensor
|
| 203 |
samples = torch.from_numpy(samples).to(device)
|
|
@@ -216,13 +221,17 @@ def train_kan(samples, gmm_dataset, device='cuda'):
|
|
| 216 |
}
|
| 217 |
|
| 218 |
# 创建训练进度显示组件
|
| 219 |
-
st.write("
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
|
|
|
|
|
|
|
|
|
| 223 |
progress_container = st.container()
|
| 224 |
|
| 225 |
-
total_steps = 100
|
|
|
|
| 226 |
steps_per_update = 10
|
| 227 |
|
| 228 |
def calculate_error(model, x, y):
|
|
@@ -260,44 +269,69 @@ def train_kan(samples, gmm_dataset, device='cuda'):
|
|
| 260 |
| 测试误差 | {test_error:.8f} |
|
| 261 |
""")
|
| 262 |
|
| 263 |
-
|
| 264 |
-
show_kan_prediction(model, device, samples, kan_plot_placeholder)
|
| 265 |
-
|
| 266 |
# 更新可视化(每5步更新一次)
|
| 267 |
-
if step % (steps_per_update * 5) == 0 or step + steps_per_update >= steps:
|
| 268 |
-
|
| 269 |
-
|
| 270 |
|
| 271 |
# 更新网络结构图(可选)
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
|
|
|
|
|
|
| 283 |
|
|
|
|
|
|
|
|
|
|
| 284 |
with progress_container:
|
| 285 |
st.markdown("#### 训练过程")
|
| 286 |
error_text = st.empty()
|
| 287 |
|
| 288 |
# 第一阶段训练
|
| 289 |
# 第一阶段:初始训练
|
| 290 |
-
with st.spinner("
|
| 291 |
-
train_phase("
|
| 292 |
|
| 293 |
# 剪枝阶段
|
| 294 |
with st.spinner("正在进行网络剪枝优化..."):
|
| 295 |
model = model.prune()
|
| 296 |
progress_container.info("网络剪枝完成")
|
| 297 |
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
|
| 302 |
# 显示最终误差
|
| 303 |
train_error = calculate_error(model, train_dataset['train_input'], train_dataset['train_label'])
|
|
@@ -307,7 +341,7 @@ def train_kan(samples, gmm_dataset, device='cuda'):
|
|
| 307 |
- 训练集误差: {train_error:.6f}
|
| 308 |
- 测试集误差: {test_error:.6f}
|
| 309 |
""")
|
| 310 |
-
|
| 311 |
progress_container.success("🎉 训练完成!")
|
| 312 |
return model
|
| 313 |
|
|
@@ -421,7 +455,7 @@ with st.sidebar:
|
|
| 421 |
|
| 422 |
# 采样设置
|
| 423 |
st.subheader("采样设置")
|
| 424 |
-
n_samples = st.slider("采样点数", 5,
|
| 425 |
if st.button("重新采样"):
|
| 426 |
# 创建GMM数据集进行采样
|
| 427 |
gmm = GeneralizedGaussianMixture(
|
|
@@ -548,34 +582,37 @@ fig.update_xaxes(title_text='X', row=1, col=2)
|
|
| 548 |
fig.update_yaxes(title_text='Y', row=1, col=2)
|
| 549 |
|
| 550 |
# 显示GMM主图
|
| 551 |
-
st.plotly_chart(fig, use_container_width=
|
|
|
|
| 552 |
|
| 553 |
# KAN网络训练和预测部分
|
| 554 |
if st.session_state.sample_points is not None:
|
| 555 |
st.markdown("---")
|
| 556 |
st.subheader("KAN网络训练与预测")
|
|
|
|
|
|
|
| 557 |
|
| 558 |
# 训练控制按钮
|
| 559 |
col1, col2, col3 = st.columns([1, 2, 1])
|
| 560 |
with col1:
|
| 561 |
-
if st.button("拟合KAN", use_container_width=
|
| 562 |
with st.spinner('训练KAN网络中...'):
|
| 563 |
st.session_state.kan_model = train_kan(st.session_state.sample_points, dataset)
|
| 564 |
st.balloons()
|
| 565 |
|
| 566 |
with col3:
|
| 567 |
if st.session_state.kan_model is not None:
|
| 568 |
-
if st.button("清除KAN结果", use_container_width=
|
| 569 |
st.session_state.kan_model = None
|
| 570 |
st.rerun()
|
| 571 |
|
| 572 |
# 显示KAN预测结果
|
| 573 |
-
if st.session_state.kan_model is not None:
|
| 574 |
-
st.subheader("KAN预测结果")
|
| 575 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 576 |
-
kan_plot_placeholder = st.empty()
|
| 577 |
-
show_kan_prediction(st.session_state.kan_model, device,
|
| 578 |
-
|
| 579 |
|
| 580 |
st.markdown("---")
|
| 581 |
|
|
@@ -625,4 +662,5 @@ with st.expander("分布参数说明"):
|
|
| 625 |
""")
|
| 626 |
|
| 627 |
# 显示当前参数的数学公式
|
| 628 |
-
st.
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
import streamlit as st
|
| 3 |
import numpy as np
|
| 4 |
from pathlib import Path
|
|
|
|
| 25 |
# Set torch dtype
|
| 26 |
torch.set_default_dtype(torch.float64)
|
| 27 |
|
| 28 |
+
def show_kan_prediction(model, device, samples, placeholder, phase_name):
|
| 29 |
"""显示KAN的预测结果"""
|
| 30 |
# 生成网格数据
|
| 31 |
x = np.linspace(-5, 5, 100)
|
|
|
|
| 85 |
|
| 86 |
# 更新布局
|
| 87 |
fig_kan.update_layout(
|
| 88 |
+
title='KAN预测分布',
|
| 89 |
showlegend=True,
|
| 90 |
width=1200,
|
| 91 |
height=600,
|
|
|
|
| 101 |
fig_kan.update_yaxes(title_text='Y', row=1, col=2)
|
| 102 |
|
| 103 |
# 使用占位符显示图形
|
| 104 |
+
|
| 105 |
+
placeholder.plotly_chart(fig_kan,
|
| 106 |
+
use_container_width=False,
|
| 107 |
+
key=f"kan_plot_{phase_name}_{time.time()}")
|
| 108 |
|
| 109 |
def create_gmm_plot(dataset, centers, K, samples=None):
|
| 110 |
"""创建GMM分布的可视化图形"""
|
|
|
|
| 202 |
device = torch.device('cuda')
|
| 203 |
else:
|
| 204 |
device = torch.device('cpu')
|
| 205 |
+
st.info(f"使用设备: {device} 训练网络")
|
| 206 |
|
| 207 |
# 转换采样点为tensor
|
| 208 |
samples = torch.from_numpy(samples).to(device)
|
|
|
|
| 221 |
}
|
| 222 |
|
| 223 |
# 创建训练进度显示组件
|
| 224 |
+
# st.write("网络预测分布:")
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
st.write("网络图形结构:")
|
| 228 |
+
kan_network_arch_placeholder = st.empty()
|
| 229 |
+
|
| 230 |
+
|
| 231 |
progress_container = st.container()
|
| 232 |
|
| 233 |
+
# total_steps = 100
|
| 234 |
+
total_steps = 50
|
| 235 |
steps_per_update = 10
|
| 236 |
|
| 237 |
def calculate_error(model, x, y):
|
|
|
|
| 269 |
| 测试误差 | {test_error:.8f} |
|
| 270 |
""")
|
| 271 |
|
| 272 |
+
|
|
|
|
|
|
|
| 273 |
# 更新可视化(每5步更新一次)
|
| 274 |
+
# if step % (steps_per_update * 5) == 0 or step + steps_per_update >= steps:
|
| 275 |
+
# # 更新预测结果
|
| 276 |
+
# show_kan_prediction(model, device, samples, kan_plot_placeholder, phase_name)
|
| 277 |
|
| 278 |
# 更新网络结构图(可选)
|
| 279 |
+
if show_plot:
|
| 280 |
+
try:
|
| 281 |
+
model.plot()
|
| 282 |
+
kan_fig = plt.gcf()
|
| 283 |
+
# if isinstance(kan_fig, tuple):
|
| 284 |
+
# kan_fig = kan_fig[0] # 如果是元组,取第一个元素
|
| 285 |
+
# if kan_fig is not None:
|
| 286 |
+
kan_network_arch_placeholder.pyplot(kan_fig, use_container_width=False)
|
| 287 |
+
# plt.close('all') # 确保关闭所有图形
|
| 288 |
+
except Exception as e:
|
| 289 |
+
if step == 0: # 只在第一次出错时显示警告
|
| 290 |
+
st.warning(f"注意:网络结构图显示失败 ({str(e)})")
|
| 291 |
+
|
| 292 |
|
| 293 |
+
# 更新进度和预测结果
|
| 294 |
+
show_kan_prediction(model, device, samples, kan_distribution_plot_placeholder, phase_name)
|
| 295 |
+
|
| 296 |
with progress_container:
|
| 297 |
st.markdown("#### 训练过程")
|
| 298 |
error_text = st.empty()
|
| 299 |
|
| 300 |
# 第一阶段训练
|
| 301 |
# 第一阶段:初始训练
|
| 302 |
+
with st.spinner("参数调整中..."):
|
| 303 |
+
train_phase("第一阶段: 正则化训练", total_steps, lamb=0.001, show_plot=True)
|
| 304 |
|
| 305 |
# 剪枝阶段
|
| 306 |
with st.spinner("正在进行网络剪枝优化..."):
|
| 307 |
model = model.prune()
|
| 308 |
progress_container.info("网络剪枝完成")
|
| 309 |
|
| 310 |
+
with st.spinner("参数调整中..."):
|
| 311 |
+
train_phase("第二阶段: 剪枝适应性训练", total_steps, show_plot=True)
|
| 312 |
+
|
| 313 |
+
with st.spinner("正在进行网格精细化..."):
|
| 314 |
+
model = model.refine(10)
|
| 315 |
+
progress_container.info("网格精细化完成")
|
| 316 |
+
|
| 317 |
+
with st.spinner("参数调整中..."):
|
| 318 |
+
train_phase("第三阶段: 网格适应性训练", total_steps, show_plot=True)
|
| 319 |
+
|
| 320 |
+
with st.spinner("符号简化中..."):
|
| 321 |
+
# model = model.prune()
|
| 322 |
+
# progress_container.info("网络剪枝完成")
|
| 323 |
+
model.auto_symbolic()
|
| 324 |
+
progress_container.info("符号简化完成")
|
| 325 |
+
|
| 326 |
+
with st.spinner("参数调整中..."):
|
| 327 |
+
train_phase("第四阶段:符号适应性训练", total_steps, show_plot=True)
|
| 328 |
+
|
| 329 |
+
from kan.utils import ex_round
|
| 330 |
+
from sympy import latex
|
| 331 |
+
s= ex_round(model.symbolic_formula()[0][0],4)
|
| 332 |
+
|
| 333 |
+
st.write("网络公式:")
|
| 334 |
+
st.latex(latex(s))
|
| 335 |
|
| 336 |
# 显示最终误差
|
| 337 |
train_error = calculate_error(model, train_dataset['train_input'], train_dataset['train_label'])
|
|
|
|
| 341 |
- 训练集误差: {train_error:.6f}
|
| 342 |
- 测试集误差: {test_error:.6f}
|
| 343 |
""")
|
| 344 |
+
|
| 345 |
progress_container.success("🎉 训练完成!")
|
| 346 |
return model
|
| 347 |
|
|
|
|
| 455 |
|
| 456 |
# 采样设置
|
| 457 |
st.subheader("采样设置")
|
| 458 |
+
n_samples = st.slider("采样点数", 5, 1000, 100)
|
| 459 |
if st.button("重新采样"):
|
| 460 |
# 创建GMM数据集进行采样
|
| 461 |
gmm = GeneralizedGaussianMixture(
|
|
|
|
| 582 |
fig.update_yaxes(title_text='Y', row=1, col=2)
|
| 583 |
|
| 584 |
# 显示GMM主图
|
| 585 |
+
st.plotly_chart(fig, use_container_width=False)
|
| 586 |
+
|
| 587 |
|
| 588 |
# KAN网络训练和预测部分
|
| 589 |
if st.session_state.sample_points is not None:
|
| 590 |
st.markdown("---")
|
| 591 |
st.subheader("KAN网络训练与预测")
|
| 592 |
+
|
| 593 |
+
kan_distribution_plot_placeholder = st.empty()
|
| 594 |
|
| 595 |
# 训练控制按钮
|
| 596 |
col1, col2, col3 = st.columns([1, 2, 1])
|
| 597 |
with col1:
|
| 598 |
+
if st.button("拟合KAN", use_container_width=False):
|
| 599 |
with st.spinner('训练KAN网络中...'):
|
| 600 |
st.session_state.kan_model = train_kan(st.session_state.sample_points, dataset)
|
| 601 |
st.balloons()
|
| 602 |
|
| 603 |
with col3:
|
| 604 |
if st.session_state.kan_model is not None:
|
| 605 |
+
if st.button("清除KAN结果", use_container_width=False):
|
| 606 |
st.session_state.kan_model = None
|
| 607 |
st.rerun()
|
| 608 |
|
| 609 |
# 显示KAN预测结果
|
| 610 |
+
# if st.session_state.kan_model is not None:
|
| 611 |
+
# st.subheader("KAN预测结果")
|
| 612 |
+
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 613 |
+
# kan_plot_placeholder = st.empty()
|
| 614 |
+
# show_kan_prediction(st.session_state.kan_model, device,
|
| 615 |
+
# st.session_state.sample_points, kan_plot_placeholder, "显示结果")
|
| 616 |
|
| 617 |
st.markdown("---")
|
| 618 |
|
|
|
|
| 662 |
""")
|
| 663 |
|
| 664 |
# 显示当前参数的数学公式
|
| 665 |
+
with st.expander("分布概率密度函数公式"):
|
| 666 |
+
st.latex(generate_latex_formula(st.session_state.p, K, centers[:K], scales[:K], weights[:K]))
|