feat: KAN可视化有bug
Browse files- .gitignore +2 -0
- app.py +384 -22
.gitignore
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
*.npz
|
| 2 |
# Byte-compiled / optimized / DLL files
|
| 3 |
__pycache__/
|
|
|
|
| 1 |
+
figures/
|
| 2 |
+
model/
|
| 3 |
*.npz
|
| 4 |
# Byte-compiled / optimized / DLL files
|
| 5 |
__pycache__/
|
app.py
CHANGED
|
@@ -5,6 +5,311 @@ from experiments.gmm_dataset import GeneralizedGaussianMixture
|
|
| 5 |
import plotly.graph_objects as go
|
| 6 |
from plotly.subplots import make_subplots
|
| 7 |
from typing import List, Tuple
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
def init_session_state():
|
| 10 |
"""初始化session state"""
|
|
@@ -20,6 +325,8 @@ def init_session_state():
|
|
| 20 |
st.session_state.weights = np.ones(3, dtype=np.float64) / 3
|
| 21 |
if 'sample_points' not in st.session_state:
|
| 22 |
st.session_state.sample_points = None
|
|
|
|
|
|
|
| 23 |
|
| 24 |
def create_default_parameters(K: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 25 |
"""创建默认参数"""
|
|
@@ -45,7 +352,7 @@ def generate_latex_formula(p: float, K: int, centers: np.ndarray,
|
|
| 45 |
c = centers[k]
|
| 46 |
s = scales[k]
|
| 47 |
w = weights[k]
|
| 48 |
-
component = f"P_{k+1}(x) = \\frac{{{p:.1f}}}{{2\\alpha_{k+1} \\Gamma(1/{p:.1f})}}\\exp(-|\\frac{{x-({c[0]:.1f}, {c[1]:.1f})}}{{{s[0]:.1f}, {s[1]:.1f}}}|^{{{p:.1f}}}) \\\\"
|
| 49 |
formula += component
|
| 50 |
formula += f"\\pi_{k+1} = {w:.2f} \\\\"
|
| 51 |
|
|
@@ -116,15 +423,19 @@ with st.sidebar:
|
|
| 116 |
st.subheader("采样设置")
|
| 117 |
n_samples = st.slider("采样点数", 5, 20, 10)
|
| 118 |
if st.button("重新采样"):
|
| 119 |
-
#
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
# 创建GMM数据集
|
| 130 |
dataset = GeneralizedGaussianMixture(
|
|
@@ -196,7 +507,7 @@ if st.session_state.sample_points is not None:
|
|
| 196 |
posteriors = []
|
| 197 |
for sample in samples:
|
| 198 |
component_probs = [
|
| 199 |
-
weights[k] * np.exp(-np.sum(((sample - centers[k]) / scales[k])**st.session_state.p))
|
| 200 |
for k in range(K)
|
| 201 |
]
|
| 202 |
total = sum(component_probs)
|
|
@@ -218,16 +529,6 @@ if st.session_state.sample_points is not None:
|
|
| 218 |
),
|
| 219 |
row=1, col=2
|
| 220 |
)
|
| 221 |
-
|
| 222 |
-
# 显示样本点的概率信息
|
| 223 |
-
st.subheader("采样点信息")
|
| 224 |
-
for i, (sample, prob, post) in enumerate(zip(samples, probs, posteriors)):
|
| 225 |
-
st.write(f"样本点 S{i+1} ({sample[0]:.2f}, {sample[1]:.2f}):")
|
| 226 |
-
st.write(f"- 概率密度: {prob:.4f}")
|
| 227 |
-
st.write("- 后验概率:")
|
| 228 |
-
for k in range(K):
|
| 229 |
-
st.write(f" - 分量 {k+1}: {post[k]:.4f}")
|
| 230 |
-
st.write("---")
|
| 231 |
|
| 232 |
# 更新布局
|
| 233 |
fig.update_layout(
|
|
@@ -246,9 +547,70 @@ fig.update_layout(
|
|
| 246 |
fig.update_xaxes(title_text='X', row=1, col=2)
|
| 247 |
fig.update_yaxes(title_text='Y', row=1, col=2)
|
| 248 |
|
| 249 |
-
#
|
| 250 |
st.plotly_chart(fig, use_container_width=True)
|
| 251 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
# 添加参数说明
|
| 253 |
with st.expander("分布参数说明"):
|
| 254 |
st.markdown("""
|
|
|
|
| 5 |
import plotly.graph_objects as go
|
| 6 |
from plotly.subplots import make_subplots
|
| 7 |
from typing import List, Tuple
|
| 8 |
+
import torch
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
import matplotlib.pyplot as plt
|
| 12 |
+
|
| 13 |
+
# Set torch path
|
| 14 |
+
torch.classes.__path__ = [os.path.join(torch.__path__[0], torch.classes.__file__ or "")]
|
| 15 |
+
|
| 16 |
+
# Add pykan to path
|
| 17 |
+
pykan_path = Path(__file__).parent.parent / 'third_party' / 'pykan'
|
| 18 |
+
sys.path.append(str(pykan_path))
|
| 19 |
+
|
| 20 |
+
# Import KAN related modules
|
| 21 |
+
from kan import KAN # type: ignore
|
| 22 |
+
from kan.utils import create_dataset, ex_round # type: ignore
|
| 23 |
+
|
| 24 |
+
# Set torch dtype
|
| 25 |
+
torch.set_default_dtype(torch.float64)
|
| 26 |
+
|
| 27 |
+
def show_kan_prediction(model, device, samples, placeholder):
|
| 28 |
+
"""显示KAN的预测结果"""
|
| 29 |
+
# 生成网格数据
|
| 30 |
+
x = np.linspace(-5, 5, 100)
|
| 31 |
+
y = np.linspace(-5, 5, 100)
|
| 32 |
+
X, Y = np.meshgrid(x, y)
|
| 33 |
+
xy = np.column_stack((X.ravel(), Y.ravel()))
|
| 34 |
+
|
| 35 |
+
# 使用KAN预测
|
| 36 |
+
grid_points = torch.from_numpy(xy).to(device)
|
| 37 |
+
with torch.no_grad():
|
| 38 |
+
Z_kan = model(grid_points).cpu().numpy().reshape(X.shape)
|
| 39 |
+
|
| 40 |
+
# 创建预测的概率密度图
|
| 41 |
+
fig_kan = make_subplots(
|
| 42 |
+
rows=1, cols=2,
|
| 43 |
+
specs=[[{'type': 'surface'}, {'type': 'contour'}]],
|
| 44 |
+
subplot_titles=('KAN预测的3D概率密度曲面', 'KAN预测的等高线图')
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
# 3D Surface
|
| 48 |
+
surface_kan = go.Surface(
|
| 49 |
+
x=X, y=Y, z=Z_kan,
|
| 50 |
+
colorscale='viridis',
|
| 51 |
+
showscale=True,
|
| 52 |
+
colorbar=dict(x=0.45)
|
| 53 |
+
)
|
| 54 |
+
fig_kan.add_trace(surface_kan, row=1, col=1)
|
| 55 |
+
|
| 56 |
+
# Contour Plot
|
| 57 |
+
contour_kan = go.Contour(
|
| 58 |
+
x=x, y=y, z=Z_kan,
|
| 59 |
+
colorscale='viridis',
|
| 60 |
+
showscale=True,
|
| 61 |
+
colorbar=dict(x=1.0),
|
| 62 |
+
contours=dict(
|
| 63 |
+
showlabels=True,
|
| 64 |
+
labelfont=dict(size=12)
|
| 65 |
+
)
|
| 66 |
+
)
|
| 67 |
+
fig_kan.add_trace(contour_kan, row=1, col=2)
|
| 68 |
+
|
| 69 |
+
# 添加采样点
|
| 70 |
+
if samples is not None:
|
| 71 |
+
fig_kan.add_trace(
|
| 72 |
+
go.Scatter(
|
| 73 |
+
x=samples[:, 0], y=samples[:, 1],
|
| 74 |
+
mode='markers',
|
| 75 |
+
marker=dict(
|
| 76 |
+
size=8,
|
| 77 |
+
color='yellow',
|
| 78 |
+
line=dict(color='black', width=1)
|
| 79 |
+
),
|
| 80 |
+
name='训练点'
|
| 81 |
+
),
|
| 82 |
+
row=1, col=2
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
# 更新布局
|
| 86 |
+
fig_kan.update_layout(
|
| 87 |
+
title='KAN预测结果',
|
| 88 |
+
showlegend=True,
|
| 89 |
+
width=1200,
|
| 90 |
+
height=600,
|
| 91 |
+
scene=dict(
|
| 92 |
+
xaxis_title='X',
|
| 93 |
+
yaxis_title='Y',
|
| 94 |
+
zaxis_title='密度'
|
| 95 |
+
)
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
# 更新2D图的坐标轴
|
| 99 |
+
fig_kan.update_xaxes(title_text='X', row=1, col=2)
|
| 100 |
+
fig_kan.update_yaxes(title_text='Y', row=1, col=2)
|
| 101 |
+
|
| 102 |
+
# 使用占位符显示图形
|
| 103 |
+
placeholder.plotly_chart(fig_kan, use_container_width=True)
|
| 104 |
+
|
| 105 |
+
def create_gmm_plot(dataset, centers, K, samples=None):
|
| 106 |
+
"""创建GMM分布的可视化图形"""
|
| 107 |
+
# 生成网格数据
|
| 108 |
+
x = np.linspace(-5, 5, 100)
|
| 109 |
+
y = np.linspace(-5, 5, 100)
|
| 110 |
+
X, Y = np.meshgrid(x, y)
|
| 111 |
+
xy = np.column_stack((X.ravel(), Y.ravel()))
|
| 112 |
+
|
| 113 |
+
# 计算概率密度
|
| 114 |
+
Z = dataset.pdf(xy).reshape(X.shape)
|
| 115 |
+
|
| 116 |
+
# 创建2D和3D可视化
|
| 117 |
+
fig = make_subplots(
|
| 118 |
+
rows=1, cols=2,
|
| 119 |
+
specs=[[{'type': 'surface'}, {'type': 'contour'}]],
|
| 120 |
+
subplot_titles=('3D概率密度曲面', '等高线图与分量中心')
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
# 3D Surface
|
| 124 |
+
surface = go.Surface(
|
| 125 |
+
x=X, y=Y, z=Z,
|
| 126 |
+
colorscale='viridis',
|
| 127 |
+
showscale=True,
|
| 128 |
+
colorbar=dict(x=0.45)
|
| 129 |
+
)
|
| 130 |
+
fig.add_trace(surface, row=1, col=1)
|
| 131 |
+
|
| 132 |
+
# Contour Plot
|
| 133 |
+
contour = go.Contour(
|
| 134 |
+
x=x, y=y, z=Z,
|
| 135 |
+
colorscale='viridis',
|
| 136 |
+
showscale=True,
|
| 137 |
+
colorbar=dict(x=1.0),
|
| 138 |
+
contours=dict(
|
| 139 |
+
showlabels=True,
|
| 140 |
+
labelfont=dict(size=12)
|
| 141 |
+
)
|
| 142 |
+
)
|
| 143 |
+
fig.add_trace(contour, row=1, col=2)
|
| 144 |
+
|
| 145 |
+
# 添加分量中心点
|
| 146 |
+
fig.add_trace(
|
| 147 |
+
go.Scatter(
|
| 148 |
+
x=centers[:K, 0], y=centers[:K, 1],
|
| 149 |
+
mode='markers+text',
|
| 150 |
+
marker=dict(size=10, color='red'),
|
| 151 |
+
text=[f'C{i+1}' for i in range(K)],
|
| 152 |
+
textposition="top center",
|
| 153 |
+
name='分量中心'
|
| 154 |
+
),
|
| 155 |
+
row=1, col=2
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
# 添加采样点(如果有)
|
| 159 |
+
if samples is not None:
|
| 160 |
+
fig.add_trace(
|
| 161 |
+
go.Scatter(
|
| 162 |
+
x=samples[:, 0], y=samples[:, 1],
|
| 163 |
+
mode='markers+text',
|
| 164 |
+
marker=dict(
|
| 165 |
+
size=8,
|
| 166 |
+
color='yellow',
|
| 167 |
+
line=dict(color='black', width=1)
|
| 168 |
+
),
|
| 169 |
+
text=[f'S{i+1}' for i in range(len(samples))],
|
| 170 |
+
textposition="bottom center",
|
| 171 |
+
name='采样点'
|
| 172 |
+
),
|
| 173 |
+
row=1, col=2
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
# 更新布局
|
| 177 |
+
fig.update_layout(
|
| 178 |
+
title='广义高斯混合分布',
|
| 179 |
+
showlegend=True,
|
| 180 |
+
width=1200,
|
| 181 |
+
height=600,
|
| 182 |
+
scene=dict(
|
| 183 |
+
xaxis_title='X',
|
| 184 |
+
yaxis_title='Y',
|
| 185 |
+
zaxis_title='密度'
|
| 186 |
+
)
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
# 更新2D图的坐标轴
|
| 190 |
+
fig.update_xaxes(title_text='X', row=1, col=2)
|
| 191 |
+
fig.update_yaxes(title_text='Y', row=1, col=2)
|
| 192 |
+
|
| 193 |
+
return fig
|
| 194 |
+
|
| 195 |
+
def train_kan(samples, gmm_dataset, device='cuda'):
|
| 196 |
+
"""训练KAN网络"""
|
| 197 |
+
if torch.cuda.is_available() and device == 'cuda':
|
| 198 |
+
device = torch.device('cuda')
|
| 199 |
+
else:
|
| 200 |
+
device = torch.device('cpu')
|
| 201 |
+
|
| 202 |
+
# 转换采样点为tensor
|
| 203 |
+
samples = torch.from_numpy(samples).to(device)
|
| 204 |
+
# 计算标签(概率密度值)
|
| 205 |
+
labels = torch.from_numpy(gmm_dataset.pdf(samples.cpu().numpy())).reshape(-1, 1).to(device)
|
| 206 |
+
|
| 207 |
+
# 创建KAN模型
|
| 208 |
+
model = KAN(width=[2,5,1], grid=3, k=3, seed=42, device=device)
|
| 209 |
+
# 创建训练和测试数据集
|
| 210 |
+
train_size = int(0.8 * samples.shape[0])
|
| 211 |
+
train_dataset = {
|
| 212 |
+
'train_input': samples[:train_size],
|
| 213 |
+
'train_label': labels[:train_size],
|
| 214 |
+
'test_input': samples[train_size:],
|
| 215 |
+
'test_label': labels[train_size:]
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
# 创建训练进度显示组件
|
| 219 |
+
st.write("网络结构:")
|
| 220 |
+
kan_fig_placeholder = st.empty()
|
| 221 |
+
st.write("预测结果:")
|
| 222 |
+
kan_plot_placeholder = st.empty()
|
| 223 |
+
progress_container = st.container()
|
| 224 |
+
|
| 225 |
+
total_steps = 100
|
| 226 |
+
steps_per_update = 10
|
| 227 |
+
|
| 228 |
+
def calculate_error(model, x, y):
|
| 229 |
+
"""计算预测误差"""
|
| 230 |
+
with torch.no_grad():
|
| 231 |
+
pred = model(x)
|
| 232 |
+
return torch.mean((pred - y) ** 2).item()
|
| 233 |
+
|
| 234 |
+
def train_phase(phase_name, steps, lamb=None, show_plot=True):
|
| 235 |
+
with progress_container:
|
| 236 |
+
progress_bar = st.progress(0)
|
| 237 |
+
status_text = st.empty()
|
| 238 |
+
|
| 239 |
+
for step in range(0, steps, steps_per_update):
|
| 240 |
+
# 训练几步
|
| 241 |
+
if lamb is not None:
|
| 242 |
+
model.fit(train_dataset, opt="LBFGS", steps=steps_per_update, lamb=lamb)
|
| 243 |
+
else:
|
| 244 |
+
model.fit(train_dataset, opt="LBFGS", steps=steps_per_update)
|
| 245 |
+
|
| 246 |
+
# 更新进度和误差
|
| 247 |
+
progress = (step + steps_per_update) / steps
|
| 248 |
+
progress_bar.progress(progress)
|
| 249 |
+
|
| 250 |
+
# 计算当前误差
|
| 251 |
+
train_error = calculate_error(model, train_dataset['train_input'], train_dataset['train_label'])
|
| 252 |
+
test_error = calculate_error(model, train_dataset['test_input'], train_dataset['test_label'])
|
| 253 |
+
# 使用表格格式显示进度和误差
|
| 254 |
+
status_text.markdown(f"""
|
| 255 |
+
### {phase_name}
|
| 256 |
+
| 项目 | 值 |
|
| 257 |
+
|:---|:---|
|
| 258 |
+
| 进度 | {progress:.0%} |
|
| 259 |
+
| 训练误差 | {train_error:.8f} |
|
| 260 |
+
| 测试误差 | {test_error:.8f} |
|
| 261 |
+
""")
|
| 262 |
+
|
| 263 |
+
# 更新进度和预测结果
|
| 264 |
+
show_kan_prediction(model, device, samples, kan_plot_placeholder)
|
| 265 |
+
|
| 266 |
+
# 更新可视化(每5步更新一次)
|
| 267 |
+
if step % (steps_per_update * 5) == 0 or step + steps_per_update >= steps:
|
| 268 |
+
# 更新预测结果
|
| 269 |
+
show_kan_prediction(model, device, samples, kan_plot_placeholder)
|
| 270 |
+
|
| 271 |
+
# 更新网络结构图(可选)
|
| 272 |
+
if show_plot:
|
| 273 |
+
try:
|
| 274 |
+
kan_fig = model.plot()
|
| 275 |
+
if isinstance(kan_fig, tuple):
|
| 276 |
+
kan_fig = kan_fig[0] # 如果是元组,取第一个元素
|
| 277 |
+
if kan_fig is not None:
|
| 278 |
+
kan_fig_placeholder.pyplot(kan_fig)
|
| 279 |
+
plt.close('all') # 确保关闭所有图形
|
| 280 |
+
except Exception as e:
|
| 281 |
+
if step == 0: # 只在第一次出错时显示警告
|
| 282 |
+
st.warning(f"注意:网络结构图显示失败 ({str(e)})")
|
| 283 |
+
|
| 284 |
+
with progress_container:
|
| 285 |
+
st.markdown("#### 训练过程")
|
| 286 |
+
error_text = st.empty()
|
| 287 |
+
|
| 288 |
+
# 第一阶段训练
|
| 289 |
+
# 第一阶段:初始训练
|
| 290 |
+
with st.spinner("初始训练阶段..."):
|
| 291 |
+
train_phase("第一阶段", total_steps, lamb=0.001, show_plot=False) # 第一阶段不显示网络图
|
| 292 |
+
|
| 293 |
+
# 剪枝阶段
|
| 294 |
+
with st.spinner("正在进行网络剪枝优化..."):
|
| 295 |
+
model = model.prune()
|
| 296 |
+
progress_container.info("网络剪枝完成")
|
| 297 |
+
|
| 298 |
+
# 第二阶段:精调
|
| 299 |
+
with st.spinner("最终调优阶段..."):
|
| 300 |
+
train_phase("第二阶段", total_steps, show_plot=True) # 第二阶段显示网络图
|
| 301 |
+
|
| 302 |
+
# 显示最终误差
|
| 303 |
+
train_error = calculate_error(model, train_dataset['train_input'], train_dataset['train_label'])
|
| 304 |
+
test_error = calculate_error(model, train_dataset['test_input'], train_dataset['test_label'])
|
| 305 |
+
error_text.markdown(f"""
|
| 306 |
+
#### 训练结果
|
| 307 |
+
- 训练集误差: {train_error:.6f}
|
| 308 |
+
- 测试集误差: {test_error:.6f}
|
| 309 |
+
""")
|
| 310 |
+
|
| 311 |
+
progress_container.success("🎉 训练完成!")
|
| 312 |
+
return model
|
| 313 |
|
| 314 |
def init_session_state():
|
| 315 |
"""初始化session state"""
|
|
|
|
| 325 |
st.session_state.weights = np.ones(3, dtype=np.float64) / 3
|
| 326 |
if 'sample_points' not in st.session_state:
|
| 327 |
st.session_state.sample_points = None
|
| 328 |
+
if 'kan_model' not in st.session_state:
|
| 329 |
+
st.session_state.kan_model = None
|
| 330 |
|
| 331 |
def create_default_parameters(K: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 332 |
"""创建默认参数"""
|
|
|
|
| 352 |
c = centers[k]
|
| 353 |
s = scales[k]
|
| 354 |
w = weights[k]
|
| 355 |
+
component = f"P_{{\\theta_{k+1}}}(x) = \\frac{{{p:.1f}}}{{2\\alpha_{k+1} \\Gamma(1/{p:.1f})}}\\exp(-|\\frac{{x-({c[0]:.1f}, {c[1]:.1f})}}{{{s[0]:.1f}, {s[1]:.1f}}}|^{{{p:.1f}}}) \\\\"
|
| 356 |
formula += component
|
| 357 |
formula += f"\\pi_{k+1} = {w:.2f} \\\\"
|
| 358 |
|
|
|
|
| 423 |
st.subheader("采样设置")
|
| 424 |
n_samples = st.slider("采样点数", 5, 20, 10)
|
| 425 |
if st.button("重新采样"):
|
| 426 |
+
# 创建GMM数据集进行采样
|
| 427 |
+
gmm = GeneralizedGaussianMixture(
|
| 428 |
+
D=2,
|
| 429 |
+
K=K,
|
| 430 |
+
p=st.session_state.p,
|
| 431 |
+
centers=centers[:K],
|
| 432 |
+
scales=scales[:K],
|
| 433 |
+
weights=weights[:K]
|
| 434 |
+
)
|
| 435 |
+
# 使用GMM生成采样点
|
| 436 |
+
samples, _ = gmm.generate_samples(n_samples)
|
| 437 |
+
st.session_state.sample_points = samples
|
| 438 |
+
st.session_state.kan_model = None # 重置KAN模型
|
| 439 |
|
| 440 |
# 创建GMM数据集
|
| 441 |
dataset = GeneralizedGaussianMixture(
|
|
|
|
| 507 |
posteriors = []
|
| 508 |
for sample in samples:
|
| 509 |
component_probs = [
|
| 510 |
+
weights[k] * np.exp(-np.sum(((sample - centers[k]) / scales[k])**st.session_state.p))
|
| 511 |
for k in range(K)
|
| 512 |
]
|
| 513 |
total = sum(component_probs)
|
|
|
|
| 529 |
),
|
| 530 |
row=1, col=2
|
| 531 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 532 |
|
| 533 |
# 更新布局
|
| 534 |
fig.update_layout(
|
|
|
|
| 547 |
fig.update_xaxes(title_text='X', row=1, col=2)
|
| 548 |
fig.update_yaxes(title_text='Y', row=1, col=2)
|
| 549 |
|
| 550 |
+
# 显示GMM主图
|
| 551 |
st.plotly_chart(fig, use_container_width=True)
|
| 552 |
|
| 553 |
+
# KAN网络训练和预测部分
|
| 554 |
+
if st.session_state.sample_points is not None:
|
| 555 |
+
st.markdown("---")
|
| 556 |
+
st.subheader("KAN网络训练与预测")
|
| 557 |
+
|
| 558 |
+
# 训练控制按钮
|
| 559 |
+
col1, col2, col3 = st.columns([1, 2, 1])
|
| 560 |
+
with col1:
|
| 561 |
+
if st.button("拟合KAN", use_container_width=True):
|
| 562 |
+
with st.spinner('训练KAN网络中...'):
|
| 563 |
+
st.session_state.kan_model = train_kan(st.session_state.sample_points, dataset)
|
| 564 |
+
st.balloons()
|
| 565 |
+
|
| 566 |
+
with col3:
|
| 567 |
+
if st.session_state.kan_model is not None:
|
| 568 |
+
if st.button("清除KAN结果", use_container_width=True):
|
| 569 |
+
st.session_state.kan_model = None
|
| 570 |
+
st.rerun()
|
| 571 |
+
|
| 572 |
+
# 显示KAN预测结果
|
| 573 |
+
if st.session_state.kan_model is not None:
|
| 574 |
+
st.subheader("KAN预测结果")
|
| 575 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 576 |
+
kan_plot_placeholder = st.empty()
|
| 577 |
+
show_kan_prediction(st.session_state.kan_model, device,
|
| 578 |
+
st.session_state.sample_points, kan_plot_placeholder)
|
| 579 |
+
|
| 580 |
+
st.markdown("---")
|
| 581 |
+
|
| 582 |
+
# 显示采样点信息
|
| 583 |
+
if st.session_state.sample_points is not None:
|
| 584 |
+
# 重新计算采样点的概率密度和后验概率
|
| 585 |
+
samples = st.session_state.sample_points
|
| 586 |
+
probs = dataset.pdf(samples)
|
| 587 |
+
posteriors = []
|
| 588 |
+
for sample in samples:
|
| 589 |
+
component_probs = [
|
| 590 |
+
weights[k] * np.exp(-np.sum(((sample - centers[k]) / scales[k])**st.session_state.p))
|
| 591 |
+
for k in range(K)
|
| 592 |
+
]
|
| 593 |
+
total = sum(component_probs)
|
| 594 |
+
posteriors.append([p/total for p in component_probs])
|
| 595 |
+
|
| 596 |
+
with st.expander("采样点信息"):
|
| 597 |
+
# 创建数据列表
|
| 598 |
+
point_data = []
|
| 599 |
+
for i, (sample, prob, post) in enumerate(zip(samples, probs, posteriors)):
|
| 600 |
+
row = {
|
| 601 |
+
'采样点': f'S{i+1}',
|
| 602 |
+
'X坐标': f'{sample[0]:.2f}',
|
| 603 |
+
'Y坐标': f'{sample[1]:.2f}',
|
| 604 |
+
'概率密度': f'{prob:.4f}'
|
| 605 |
+
}
|
| 606 |
+
# 添加每个分量的后验概率
|
| 607 |
+
for k in range(K):
|
| 608 |
+
row[f'分量{k+1}后验概率'] = f'{post[k]:.4f}'
|
| 609 |
+
point_data.append(row)
|
| 610 |
+
|
| 611 |
+
# 显示dataframe
|
| 612 |
+
st.dataframe(point_data)
|
| 613 |
+
|
| 614 |
# 添加参数说明
|
| 615 |
with st.expander("分布参数说明"):
|
| 616 |
st.markdown("""
|