Lee93whut commited on
Commit
a264030
·
1 Parent(s): a91b194

feat(demo): Streamlit web demo — Plotly heatmap, anti-loop inference

Browse files

app.py:
- Interactive 10×10 maze rendered as Plotly go.Heatmap
- Dropdown + random button for start/goal selection
- Load any of 4 algorithm weights (Vanilla/Double/Dueling/Double+Dueling)
- DQN rollout with anti-loop inference guard:
visit_cnt >= 2 → Q[action] -= 3.0 × visit_cnt
(inference-only Q-value patch, does not affect training distribution)
- BFS shortest path overlay for SPL ground-truth comparison
- Deployed on Hugging Face Spaces (Docker SDK)

Files changed (1) hide show
  1. app.py +811 -0
app.py ADDED
@@ -0,0 +1,811 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """app.py —— DQN 迷宫寻路可视化 Web App
2
+ Hugging Face Spaces (Docker SDK) 专用
3
+
4
+ 部署清单(上传到 HF Space 的全部文件)
5
+ --------------------------------------
6
+ app.py 本文件
7
+ src/model.py 神经网络架构
8
+ results/best_model_train_vanilla.pth vanilla DQN 权重
9
+ results/best_model_train_double.pth Double DQN 权重
10
+ results/best_model_train_dueling.pth Dueling DQN 权重
11
+ results/best_model_train_double_dueling.pth Double Dueling DQN 权重
12
+ config.yaml 环境配置(grid_size / obstacle_density / max_steps)
13
+ requirements.txt 依赖列表
14
+
15
+ 导入策略
16
+ --------
17
+ * maze_env 通过 `pip install -e .` 安装(见 Dockerfile),直接 import。
18
+ * src/ 通过 pyproject.toml packages.find 配置,同样可安装,直接 import。
19
+ * 所有模块均通过标准 import 路径解析,无需 sys.path 手动注入。
20
+
21
+ 端口说明
22
+ --------
23
+ HF Docker Space 固定使用 7860 端口(见 Dockerfile / README)。
24
+ 本地调试:streamlit run app.py
25
+ """
26
+
27
+ from __future__ import annotations
28
+
29
+ import random
30
+ import time
31
+ from pathlib import Path
32
+ from typing import Optional
33
+
34
+ import numpy as np
35
+ import plotly.graph_objects as go
36
+ import streamlit as st
37
+ import torch
38
+ import yaml
39
+
40
+ # ── maze_env 包(已安装,直接导入)──────────────────────────────────────────
41
+ from maze_env import MazeEnv
42
+ from maze_env.bfs import bfs as bfs_solve
43
+
44
+ # ── src 包(pip install -e . 后可直接导入)───────────────────────────────────
45
+ import torch.nn as nn
46
+ from src.model import DQNNetwork, DuelingDQNNetwork
47
+
48
+ # ===========================================================================
49
+ # 常量 & 配置
50
+ # ===========================================================================
51
+ _CONFIG_PATH = Path(__file__).parent / "config.yaml"
52
+ if _CONFIG_PATH.exists():
53
+ _cfg = yaml.safe_load(_CONFIG_PATH.read_text(encoding="utf-8"))
54
+ else:
55
+ import warnings
56
+ warnings.warn(
57
+ f"config.yaml 未找到({_CONFIG_PATH}),使用内置默认值。"
58
+ "若训练时使用了非默认 grid_size,推理结果可能错误。",
59
+ stacklevel=1,
60
+ )
61
+ _cfg = {}
62
+ _maze_cfg = _cfg.get("maze", {})
63
+
64
+ GRID_SIZE = int(_maze_cfg.get("grid_size", 10))
65
+ OBSTACLE_DENSITY = float(_maze_cfg.get("obstacle_density", 0.25)) # 与 config.yaml maze.obstacle_density 保持一致,确保 Demo 与训练分布相同
66
+ MAX_STEPS = int(_maze_cfg.get("max_steps", 200)) # 与训练保持一致,推理步数预算对齐
67
+
68
+ # 支持切换的四算法(顺序决定 UI 下拉框排列)
69
+ ALGO_OPTIONS: list[str] = ["double_dueling", "dueling", "double", "vanilla"]
70
+ ALGO_LABELS: dict[str, str] = {
71
+ "vanilla": "Vanilla DQN(基准)",
72
+ "double": "Double DQN(抑制高估)",
73
+ "dueling": "Dueling DQN(V+A 分解)",
74
+ "double_dueling": "Double + Dueling(推荐)",
75
+ }
76
+ # 默认算法:优先读 config.yaml,fallback 到 double_dueling
77
+ _default_algo = str(_cfg.get("dqn", {}).get("algorithm", "double_dueling")).strip().lower()
78
+ DEFAULT_ALGO: str = _default_algo if _default_algo in ALGO_OPTIONS else "double_dueling"
79
+
80
+ def model_path_for(algo: str) -> Path:
81
+ """根据算法名返回对应权重文件路径。"""
82
+ return Path(__file__).parent / "results" / f"best_model_train_{algo}.pth"
83
+
84
+ # 首屏默认迷宫 seed。
85
+ # 固定值保证分享链接时双方看到相同地图;改为 None 可让每次刷新随机生成。
86
+ DEFAULT_MAZE_SEED: int = 42
87
+
88
+ # 动画帧间隔(秒)
89
+ ANIM_DELAY = 0.08
90
+
91
+ # 颜色映射(RGB 列表,供 Plotly heatmap)
92
+ COLOR_EMPTY = "#F8F9FA" # 白/浅灰 —— 可通行地板
93
+ COLOR_WALL = "#2C3E50" # 深蓝灰 —— 墙壁
94
+ COLOR_START = "#27AE60" # 绿色 —— 起点
95
+ COLOR_GOAL = "#E74C3C" # 红色 —— 终点
96
+ COLOR_DQN_PATH = "#3498DB" # 蓝色 —— DQN 轨迹
97
+ COLOR_BFS_PATH = "#F39C12" # 橙色 —— BFS 最短路
98
+ COLOR_AGENT = "#9B59B6" # 紫色 —— 当前 Agent 位置
99
+
100
+ # ===========================================================================
101
+ # 工具函数
102
+ # ===========================================================================
103
+
104
+ def generate_maze(seed: Optional[int] = None) -> np.ndarray:
105
+ """生成 GRID_SIZE×GRID_SIZE 迷宫,保证起点 (1,1) 与终点 (N-2,N-2) 可达。
106
+
107
+ 委托给 :class:`MazeEnv` 的 ``reset()`` 方法,确保与训练环境完全一致
108
+ (相同的边界墙、障碍密度、BFS 连通性保证,不重复造轮子)。
109
+
110
+ Args:
111
+ seed: 随机种子;``None`` 表示不固定随机性。
112
+
113
+ Returns:
114
+ wall_map: shape ``(N, N)``,dtype ``int32``,0=通路,1=墙壁。
115
+ """
116
+ env = MazeEnv(
117
+ grid_size=GRID_SIZE,
118
+ obstacle_density=OBSTACLE_DENSITY,
119
+ )
120
+ env.reset(seed=seed)
121
+ return env.wall_map.astype(np.int32)
122
+
123
+
124
+ def generate_maze_with_random_sg(
125
+ seed: Optional[int] = None,
126
+ ) -> tuple[np.ndarray, tuple[int, int], tuple[int, int]]:
127
+ """生成迷宫并从可通行内部格随机选取起点和终点,与训练分布完全一致。
128
+
129
+ 复现 train.py 中 ``random_start_goal=True`` 的逻辑:
130
+ 先生成迷宫,再用 ``env.np_random``(Gymnasium 注入的唯一随机源)
131
+ 从内部可通行格中不放回地抽取两个不同坐标,确保 Demo 与训练同分布。
132
+
133
+ Args:
134
+ seed: 随机种子;``None`` 表示不固定随机性。
135
+
136
+ Returns:
137
+ (wall_map, start, goal):
138
+ * wall_map: shape ``(N, N)``,dtype ``int32``。
139
+ * start: 起点坐标 ``(row, col)``。
140
+ * goal: 终点坐标 ``(row, col)``。
141
+ """
142
+ env = MazeEnv(
143
+ grid_size=GRID_SIZE,
144
+ obstacle_density=OBSTACLE_DENSITY,
145
+ )
146
+ env.reset(seed=seed)
147
+ wall_map = env.wall_map.astype(np.int32) # (N, N)
148
+
149
+ # 收集内部(非边界)可通行格,与 train.py 过滤条件完全相同
150
+ rows, cols = np.where(wall_map == 0)
151
+ inner_cells: list[tuple[int, int]] = [
152
+ (int(r), int(c))
153
+ for r, c in zip(rows, cols)
154
+ if 0 < r < GRID_SIZE - 1 and 0 < c < GRID_SIZE - 1
155
+ ]
156
+
157
+ if len(inner_cells) < 2:
158
+ # 极端情况(障碍密度极高):退回到固定起终点
159
+ return wall_map, (1, 1), (GRID_SIZE - 2, GRID_SIZE - 2)
160
+
161
+ # 使用 env.np_random(与训练逻辑完全一致,不污染全局随机状态)
162
+ idxs = env.np_random.integers(0, len(inner_cells), size=2)
163
+ while idxs[0] == idxs[1]:
164
+ idxs = env.np_random.integers(0, len(inner_cells), size=2)
165
+
166
+ start = inner_cells[int(idxs[0])]
167
+ goal = inner_cells[int(idxs[1])]
168
+ return wall_map, start, goal
169
+
170
+
171
+ def load_model(algo: str = DEFAULT_ALGO, grid_size: int = GRID_SIZE) -> tuple[Optional[nn.Module], int]:
172
+ """加载指定算法的 DQN 模型权重,返回 (net, saved_grid_size)。
173
+
174
+ Args:
175
+ algo: 算法名,须在 ALGO_OPTIONS 中。
176
+ grid_size: 当前环境 grid_size,用于维度不一致时的 fallback 返回值。
177
+
178
+ 失败时返回 (None, grid_size)。saved_grid_size 供调用方检测维度是否与
179
+ 当前 GRID_SIZE 一致;不一致时推理输入维度会与网络期望不符,应提前告警。
180
+ """
181
+ path = model_path_for(algo)
182
+ if not path.exists():
183
+ return None, grid_size
184
+ try:
185
+ ckpt = torch.load(path, map_location="cpu", weights_only=True)
186
+ saved_gs = ckpt.get("grid_size", grid_size)
187
+ algorithm = ckpt.get("algorithm", "vanilla").strip().lower()
188
+ NetClass = DuelingDQNNetwork if "dueling" in algorithm else DQNNetwork
189
+ net = NetClass(grid_size=saved_gs)
190
+ net.load_state_dict(ckpt["state_dict"])
191
+ net.eval()
192
+ return net, saved_gs
193
+ except Exception as e:
194
+ st.error(f"❌ 模型加载失败:{e}")
195
+ return None, grid_size
196
+
197
+
198
+ def dqn_rollout(
199
+ net: nn.Module,
200
+ wall_map: np.ndarray,
201
+ start: tuple,
202
+ goal: tuple,
203
+ ) -> list[tuple]:
204
+ """纯推理(ε=0)运行 DQN Agent,返回完整轨迹坐标列表。
205
+
206
+ 委托给 :class:`MazeEnv` 的标准 ``reset()`` / ``step()`` 接口,
207
+ 保证观测编码与训练时完全一致,无需在 app.py 中重复实现碰撞检测。
208
+
209
+ Args:
210
+ net: 已加载权重、处于 eval 模式的 DQN 网络。
211
+ wall_map: shape ``(N, N)``,dtype int32,0=通路,1=墙壁。
212
+ start: Agent 起点 ``(row, col)``。
213
+ goal: 终点 ``(row, col)``。
214
+
215
+ Returns:
216
+ 完整轨迹(含起点),每条为 ``(row, col)``。
217
+ """
218
+ env = MazeEnv(
219
+ grid_size=wall_map.shape[0],
220
+ obstacle_density=0.0, # 密度无关,地图由外部注入
221
+ max_steps=MAX_STEPS,
222
+ )
223
+ obs, _ = env.reset(options={
224
+ "wall_map": wall_map.astype(np.float32),
225
+ "start": start,
226
+ "goal": goal,
227
+ })
228
+
229
+ path = [env.agent_pos]
230
+
231
+ # 注:R4 起观测已包含 visited_map 第4通道(ch3),Agent 天然感知访问历史,
232
+ # 无需在推理侧注入 Q 值惩罚。直接贪心执行即可。
233
+ while True:
234
+ s = torch.from_numpy(obs).unsqueeze(0)
235
+ with torch.no_grad():
236
+ q_values = net(s)[0] # shape: (num_actions,)
237
+
238
+ action = int(q_values.argmax().item())
239
+ obs, _reward, terminated, truncated, info = env.step(action)
240
+
241
+ # 只在实际移动时追加(撞墙时位置不变,避免重复坐标导致动画抖帧)
242
+ if not info["hit_wall"]:
243
+ path.append(env.agent_pos)
244
+
245
+ if terminated or truncated:
246
+ break
247
+
248
+ return path
249
+
250
+
251
+ # ===========================================================================
252
+ # Plotly 迷宫绘制
253
+ # ===========================================================================
254
+
255
+ def build_maze_figure(
256
+ wall_map: np.ndarray,
257
+ start: tuple,
258
+ goal: tuple,
259
+ dqn_path: Optional[list] = None,
260
+ bfs_path: Optional[list] = None,
261
+ agent_pos: Optional[tuple] = None,
262
+ highlight_dqn_step: int = -1,
263
+ ) -> go.Figure:
264
+ """构建 Plotly 迷宫图,支持叠加 DQN / BFS 路径与动态 Agent 标记。"""
265
+ N = wall_map.shape[0]
266
+
267
+ # ── 底层热力图(单 Heatmap trace,O(1) traces vs O(N²) shapes)─────────
268
+ # 数值矩阵:0=通路, 1=墙, 2=起点, 3=终点
269
+ z = wall_map.astype(float).copy()
270
+ z[start[0], start[1]] = 2.0
271
+ z[goal[0], goal[1]] = 3.0
272
+
273
+ # 离散颜色映射:值 → 颜色
274
+ colorscale = [
275
+ [0.00, COLOR_EMPTY], # 0 = 通路
276
+ [0.25, COLOR_EMPTY],
277
+ [0.25, COLOR_WALL], # 1 = 墙
278
+ [0.50, COLOR_WALL],
279
+ [0.50, COLOR_START], # 2 = 起点
280
+ [0.75, COLOR_START],
281
+ [0.75, COLOR_GOAL], # 3 = 终点
282
+ [1.00, COLOR_GOAL],
283
+ ]
284
+
285
+ fig = go.Figure()
286
+ fig.add_trace(go.Heatmap(
287
+ z=z,
288
+ colorscale=colorscale,
289
+ zmin=0, zmax=3,
290
+ showscale=False,
291
+ xgap=1, ygap=1,
292
+ hoverinfo="skip",
293
+ ))
294
+
295
+ # ── BFS 路径(橙色虚线)──────────────────────────────────────────────
296
+ if bfs_path and len(bfs_path) > 1:
297
+ bx = [c for r, c in bfs_path]
298
+ by = [r for r, c in bfs_path]
299
+ fig.add_trace(go.Scatter(
300
+ x=bx, y=by,
301
+ mode="lines+markers",
302
+ name="BFS 最短路",
303
+ line=dict(color=COLOR_BFS_PATH, width=3, dash="dot"),
304
+ marker=dict(size=6, color=COLOR_BFS_PATH, opacity=0.7),
305
+ ))
306
+
307
+ # ── DQN 路径(蓝色实线)──────────────────────────────────────────────
308
+ if dqn_path and len(dqn_path) > 1:
309
+ # 截取到 highlight_dqn_step(动画用)
310
+ end_idx = highlight_dqn_step + 1 if highlight_dqn_step >= 0 else len(dqn_path)
311
+ sub_path = dqn_path[:end_idx]
312
+ dx = [c for r, c in sub_path]
313
+ dy = [r for r, c in sub_path]
314
+ fig.add_trace(go.Scatter(
315
+ x=dx, y=dy,
316
+ mode="lines+markers",
317
+ name="DQN 轨迹",
318
+ line=dict(color=COLOR_DQN_PATH, width=3),
319
+ marker=dict(size=7, color=COLOR_DQN_PATH),
320
+ ))
321
+
322
+ # ── 当前 Agent 位置(紫色大圆点)────────────────────────────────────
323
+ ap = agent_pos if agent_pos else (start if not dqn_path else
324
+ (dqn_path[min(highlight_dqn_step, len(dqn_path)-1)]
325
+ if highlight_dqn_step >= 0 else start))
326
+ fig.add_trace(go.Scatter(
327
+ x=[ap[1]], y=[ap[0]],
328
+ mode="markers",
329
+ name="Agent",
330
+ marker=dict(size=16, color=COLOR_AGENT, symbol="circle",
331
+ line=dict(color="white", width=2)),
332
+ showlegend=True,
333
+ ))
334
+
335
+ # ── 起点 / 终点标签 ───────────────────────────────────────────────────
336
+ fig.add_trace(go.Scatter(
337
+ x=[start[1], goal[1]],
338
+ y=[start[0], goal[0]],
339
+ mode="markers+text",
340
+ text=["S", "G"],
341
+ textposition="middle center",
342
+ textfont=dict(size=13, color="white", family="Arial Black"),
343
+ marker=dict(size=22, color=[COLOR_START, COLOR_GOAL],
344
+ symbol="square", opacity=0.0), # 透明底,只显示字
345
+ showlegend=False,
346
+ hoverinfo="skip",
347
+ ))
348
+
349
+ # ── 布局 ─────────────────────────────────────────────────────────────
350
+ fig.update_layout(
351
+ width=560, height=560,
352
+ margin=dict(l=10, r=10, t=30, b=10),
353
+ xaxis=dict(
354
+ range=[-0.5, N - 0.5], tickvals=list(range(N)),
355
+ showgrid=False, zeroline=False, title="列 (col)",
356
+ ),
357
+ yaxis=dict(
358
+ range=[N - 0.5, -0.5],
359
+ tickvals=list(range(N)),
360
+ showgrid=False, zeroline=False, title="行 (row)",
361
+ ),
362
+ legend=dict(x=1.01, y=1, bgcolor="rgba(255,255,255,0.8)",
363
+ bordercolor="#BDC3C7", borderwidth=1),
364
+ paper_bgcolor="white",
365
+ plot_bgcolor="white",
366
+ title=dict(text="🏁 DQN 迷宫寻路", x=0.5, font=dict(size=16)),
367
+ )
368
+ return fig
369
+
370
+
371
+ def _find_cell_index(free_cells: list[tuple], pos: tuple) -> int:
372
+ """在 free_cells 列表中查找 pos 的索引;未找到时返回 0(安全回退)。"""
373
+ try:
374
+ return free_cells.index(pos)
375
+ except ValueError:
376
+ return 0
377
+
378
+
379
+ # ===========================================================================
380
+ # Session State 初始化
381
+ # ===========================================================================
382
+
383
+ def _init_state() -> None:
384
+ if "wall_map" not in st.session_state:
385
+ # 首屏使用随机起终点(与训练分布一致),固定 seed 保证可复现
386
+ wm, sg_start, sg_goal = generate_maze_with_random_sg(seed=DEFAULT_MAZE_SEED)
387
+ st.session_state.wall_map = wm
388
+ st.session_state.start = sg_start
389
+ st.session_state.goal = sg_goal
390
+ if "start" not in st.session_state:
391
+ st.session_state.start = (1, 1)
392
+ if "goal" not in st.session_state:
393
+ st.session_state.goal = (GRID_SIZE - 2, GRID_SIZE - 2)
394
+ if "dqn_path" not in st.session_state:
395
+ st.session_state.dqn_path = None
396
+ if "bfs_path" not in st.session_state:
397
+ st.session_state.bfs_path = None
398
+ if "metrics" not in st.session_state:
399
+ st.session_state.metrics = None
400
+ if "selected_algo" not in st.session_state:
401
+ st.session_state.selected_algo = DEFAULT_ALGO
402
+ if "model" not in st.session_state:
403
+ net, saved_gs = load_model(algo=DEFAULT_ALGO)
404
+ st.session_state.model = net
405
+ st.session_state.model_grid_size = saved_gs
406
+ if "maze_seed" not in st.session_state:
407
+ st.session_state.maze_seed = DEFAULT_MAZE_SEED
408
+ if "anim_running" not in st.session_state:
409
+ st.session_state.anim_running = False
410
+ if "anim_step" not in st.session_state:
411
+ st.session_state.anim_step = 0
412
+ if "anim_path" not in st.session_state:
413
+ st.session_state.anim_path = None
414
+
415
+
416
+ # ===========================================================================
417
+ # 主程序
418
+ # ===========================================================================
419
+
420
+ def main() -> None:
421
+ st.set_page_config(
422
+ page_title="DQN 迷宫寻路 Demo",
423
+ page_icon="🤖",
424
+ layout="wide",
425
+ )
426
+
427
+ # ── 全局样式注入 ────────────────────────────────────────────────────────
428
+ st.markdown("""
429
+ <style>
430
+ .metric-card {
431
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
432
+ border-radius: 12px; padding: 16px 20px; color: white;
433
+ text-align: center; margin: 6px 0;
434
+ }
435
+ .metric-label { font-size: 13px; opacity: 0.85; margin-bottom: 4px; }
436
+ .metric-value { font-size: 28px; font-weight: 700; }
437
+ .por-perfect { color: #2ECC71; font-weight: 800; }
438
+ .por-good { color: #F39C12; font-weight: 700; }
439
+ .por-bad { color: #E74C3C; font-weight: 600; }
440
+ div[data-testid="stButton"] button {
441
+ width: 100%; border-radius: 8px; font-weight: 600;
442
+ }
443
+ /* 迷宫按钮网格:每格紧凑正方形,无内边距 */
444
+ div[data-testid="stHorizontalBlock"] div[data-testid="stButton"] button {
445
+ padding: 0 !important;
446
+ min-height: 40px !important;
447
+ font-size: 15px !important;
448
+ border-radius: 3px !important;
449
+ border: 1px solid #ccc !important;
450
+ line-height: 1 !important;
451
+ }
452
+ </style>
453
+ """, unsafe_allow_html=True)
454
+
455
+ _init_state()
456
+
457
+ st.title("🤖 DQN 迷宫寻路 · 可视化 Demo")
458
+ st.caption("Deep Q-Network × BFS Ground-Truth · Hugging Face Spaces")
459
+
460
+ # ═══════════════════════════════════════════════════════════════════════
461
+ # 正常双栏布局(点击模式在右栏内处理,不破坏整体布局)
462
+ # ═══════════════════════════════════════════════════════════════════════
463
+ left_col, right_col = st.columns([1, 2.2], gap="large")
464
+
465
+ # ───────────────────────────────────────────────────────────────────────
466
+ # 左栏:控制面板
467
+ # ───────────────────────────────────────────────────────────────────────
468
+ with left_col:
469
+ st.subheader("⚙️ 控制面板")
470
+
471
+ # ── 迷宫生成 ─────────────────────────────────────────────────────
472
+ st.markdown("**① 迷宫地图**")
473
+ col_seed, col_rand = st.columns([3, 1])
474
+ with col_seed:
475
+ input_seed = st.number_input(
476
+ "迷宫 Seed",
477
+ min_value=0,
478
+ max_value=999999,
479
+ value=st.session_state.maze_seed,
480
+ step=1,
481
+ help="固定数字可复现指定地图;点击右侧按钮随机生成新地图",
482
+ )
483
+ with col_rand:
484
+ st.write("") # 对齐占位
485
+ if st.button("🎲 随机"):
486
+ # 随机 seed:同时随机生成地图和起终点(与训练分布一致)
487
+ new_seed = random.randint(0, 999999)
488
+ wm, sg_start, sg_goal = generate_maze_with_random_sg(seed=new_seed)
489
+ st.session_state.maze_seed = new_seed
490
+ st.session_state.wall_map = wm
491
+ st.session_state.start = sg_start
492
+ st.session_state.goal = sg_goal
493
+ st.session_state.dqn_path = None
494
+ st.session_state.bfs_path = None
495
+ st.session_state.metrics = None
496
+ # 同步下拉框索引,避免 selectbox key 缓存旧值
497
+ _fc = [(r,c) for r in range(1,GRID_SIZE-1) for c in range(1,GRID_SIZE-1) if wm[r,c]==0]
498
+ st.session_state.start_select = _find_cell_index(_fc, sg_start)
499
+ st.session_state.goal_select = _find_cell_index(_fc, sg_goal)
500
+ st.rerun() # 立即终止当前脚本,下方 input_seed 检测不会执行
501
+
502
+ # 手动修改 seed 输入框时触发(随机按钮已由上方 rerun 短路,不会重复)
503
+ if input_seed != st.session_state.maze_seed:
504
+ wm, sg_start, sg_goal = generate_maze_with_random_sg(seed=input_seed)
505
+ st.session_state.maze_seed = input_seed
506
+ st.session_state.wall_map = wm
507
+ st.session_state.start = sg_start
508
+ st.session_state.goal = sg_goal
509
+ st.session_state.dqn_path = None
510
+ st.session_state.bfs_path = None
511
+ st.session_state.metrics = None
512
+ _fc = [(r,c) for r in range(1,GRID_SIZE-1) for c in range(1,GRID_SIZE-1) if wm[r,c]==0]
513
+ st.session_state.start_select = _find_cell_index(_fc, sg_start)
514
+ st.session_state.goal_select = _find_cell_index(_fc, sg_goal)
515
+ st.rerun()
516
+
517
+ st.divider()
518
+
519
+ # ── 起点 / 终点选择 ────────────────────────────────────────────────
520
+ st.markdown("**② 起点 & 终点**")
521
+
522
+ # 「随机起终点」按钮:从当前地图的可通行格随机选取,与训练分布一致
523
+ if st.button("🎲 随机起终点", use_container_width=True,
524
+ help="从当前地图可通行格随机选取起点和终点,与训练分布完全一致"):
525
+ _wm = st.session_state.wall_map
526
+ _rows, _cols = np.where(_wm == 0)
527
+ _inner = [
528
+ (int(r), int(c))
529
+ for r, c in zip(_rows, _cols)
530
+ if 0 < r < GRID_SIZE - 1 and 0 < c < GRID_SIZE - 1
531
+ ]
532
+ if len(_inner) >= 2:
533
+ _i, _j = random.sample(range(len(_inner)), 2)
534
+ st.session_state.start = _inner[_i]
535
+ st.session_state.goal = _inner[_j]
536
+ st.session_state.dqn_path = None
537
+ st.session_state.bfs_path = None
538
+ st.session_state.metrics = None
539
+ st.session_state.start_select = _find_cell_index(_inner, _inner[_i])
540
+ st.session_state.goal_select = _find_cell_index(_inner, _inner[_j])
541
+ st.rerun()
542
+
543
+ N = GRID_SIZE
544
+ free_cells = [
545
+ (r, c)
546
+ for r in range(1, N - 1)
547
+ for c in range(1, N - 1)
548
+ if st.session_state.wall_map[r, c] == 0
549
+ ]
550
+ cell_labels = [f"({r},{c})" for r, c in free_cells]
551
+
552
+ start_idx = st.selectbox(
553
+ "起点 (row, col)",
554
+ options=range(len(free_cells)),
555
+ format_func=lambda i: cell_labels[i],
556
+ index=_find_cell_index(free_cells, st.session_state.start),
557
+ key="start_select",
558
+ )
559
+ goal_idx = st.selectbox(
560
+ "终点 (row, col)",
561
+ options=range(len(free_cells)),
562
+ format_func=lambda i: cell_labels[i],
563
+ index=_find_cell_index(free_cells, st.session_state.goal),
564
+ key="goal_select",
565
+ )
566
+ new_start = free_cells[start_idx]
567
+ new_goal = free_cells[goal_idx]
568
+
569
+ if new_start == new_goal:
570
+ st.warning("⚠️ 起点与终点不能相同,请重新选择。")
571
+ elif new_start != st.session_state.start or new_goal != st.session_state.goal:
572
+ st.session_state.start = new_start
573
+ st.session_state.goal = new_goal
574
+ st.session_state.dqn_path = None
575
+ st.session_state.bfs_path = None
576
+ st.session_state.metrics = None
577
+
578
+ st.divider()
579
+
580
+ # ── 算法选择 & 寻路触发按钮 ───────────────────────────────────────
581
+ st.markdown("**③ ��路算法**")
582
+
583
+ selected_algo = st.selectbox(
584
+ "DQN 算法变体",
585
+ options=ALGO_OPTIONS,
586
+ format_func=lambda a: ALGO_LABELS[a],
587
+ index=ALGO_OPTIONS.index(st.session_state.selected_algo),
588
+ key="algo_select",
589
+ help="切换算法后点击「DQN 寻路」按钮可对比不同算法在同一地图上的路径",
590
+ )
591
+ # 算法切换时重新加载对应模型,清空上次路径结果
592
+ if selected_algo != st.session_state.selected_algo:
593
+ st.session_state.selected_algo = selected_algo
594
+ net, saved_gs = load_model(algo=selected_algo)
595
+ st.session_state.model = net
596
+ st.session_state.model_grid_size = saved_gs
597
+ st.session_state.dqn_path = None
598
+ st.session_state.metrics = None
599
+ st.rerun()
600
+
601
+ run_dqn = st.button(
602
+ "🤖 DQN 智能体寻路",
603
+ use_container_width=True,
604
+ type="primary",
605
+ )
606
+ run_bfs = st.button(
607
+ "📐 BFS 专家寻路",
608
+ use_container_width=True,
609
+ )
610
+
611
+ st.divider()
612
+
613
+ # ── 图例说明 ────────────────────────────────────────────────────
614
+ st.markdown("**图例**")
615
+ legend_html = """
616
+ <div style='font-size:13px; line-height:2'>
617
+ 🟩 <b>S</b> 起点 &nbsp;&nbsp;
618
+ 🟥 <b>G</b> 终点<br>
619
+ ⬛ 墙壁 &nbsp;&nbsp;
620
+ ⬜ 通路<br>
621
+ 🔵 DQN 轨迹 &nbsp;&nbsp;
622
+ 🟠 BFS 最短路<br>
623
+ 🟣 Agent 当前位置
624
+ </div>
625
+ """
626
+ st.markdown(legend_html, unsafe_allow_html=True)
627
+
628
+ # ── 模型状态 ────────────────────────────────────────────────────
629
+ st.divider()
630
+ _cur_algo = st.session_state.get("selected_algo", DEFAULT_ALGO)
631
+ _cur_path = model_path_for(_cur_algo)
632
+ if st.session_state.model is not None:
633
+ st.success(f"✅ 模型已加载 ({_cur_path.name})")
634
+ # 维度不一致时提前告警:网络期望 (3, saved_gs, saved_gs) 输入,
635
+ # 而推理环境会生成 (3, GRID_SIZE, GRID_SIZE) 观测,两者不符会在
636
+ # 网络 forward 时抛出张量尺寸异常。提前展示警告便于用户定位原因。
637
+ _saved_gs = st.session_state.get("model_grid_size", GRID_SIZE)
638
+ if _saved_gs != GRID_SIZE:
639
+ st.warning(
640
+ f"⚠️ 模型训练于 {_saved_gs}×{_saved_gs} 迷宫,"
641
+ f"当前配置为 {GRID_SIZE}×{GRID_SIZE}。\n"
642
+ "推理时输入维度不匹配,将导致运行时错误。\n"
643
+ "请使用匹配 grid_size 的模型,或更新 config.yaml。"
644
+ )
645
+ else:
646
+ st.error(f"❌ 未找到 {_cur_path.name}")
647
+ st.info(f"请先运行 `python src/train.py --algorithm {_cur_algo}` 训练模型。")
648
+
649
+ # ───────────────────────────────────────────────────────────────────────
650
+ # 右栏:主画布
651
+ # ───────────────────────────────────────────────────────────────────────
652
+ # ───────────────────────────────────────────────────────────────────────
653
+ # 右栏:主画布
654
+ # ───────────────────────────────────────────────────────────────────────
655
+ with right_col:
656
+ wall_map = st.session_state.wall_map
657
+ start = st.session_state.start
658
+ goal = st.session_state.goal
659
+
660
+ status_placeholder = st.empty()
661
+
662
+ # ── BFS 寻路 ─────────────────────────────────────────────────────
663
+ if run_bfs:
664
+ result = bfs_solve(wall_map.astype(np.int32), start, goal)
665
+ if result["success"]:
666
+ st.session_state.bfs_path = result["path"]
667
+ status_placeholder.success(
668
+ f"✅ BFS 完成!最短步数 = **{result['steps']}**,"
669
+ f"耗时 {result['execution_time_ms']:.3f} ms"
670
+ )
671
+ else:
672
+ st.session_state.bfs_path = None
673
+ status_placeholder.error("❌ BFS:起点与终点之间无可达路���!")
674
+
675
+ # ── DQN 寻路按钮触发 ──────────────────────────────────────────────
676
+ if run_dqn:
677
+ model = st.session_state.model
678
+ if model is None:
679
+ status_placeholder.error("❌ 模型未加载,无法推理。")
680
+ elif st.session_state.get("model_grid_size", GRID_SIZE) != GRID_SIZE:
681
+ _mgs = st.session_state.model_grid_size
682
+ status_placeholder.error(
683
+ f"❌ 模型训练于 {_mgs}×{_mgs},当前为 {GRID_SIZE}×{GRID_SIZE},维度不匹配。"
684
+ )
685
+ else:
686
+ bfs_result = bfs_solve(wall_map.astype(np.int32), start, goal)
687
+ if not bfs_result["success"]:
688
+ status_placeholder.error("❌ 该迷宫配置无解,请换起终点。")
689
+ else:
690
+ with st.spinner("🤖 DQN 推理中…"):
691
+ dqn_path = dqn_rollout(model, wall_map, start, goal)
692
+
693
+ ai_steps = len(dqn_path) - 1
694
+ bfs_steps = bfs_result["steps"]
695
+ success = (dqn_path[-1] == goal)
696
+ por = round(bfs_steps / ai_steps, 4) if (success and ai_steps > 0) else 0.0
697
+
698
+ st.session_state.dqn_path = dqn_path
699
+ st.session_state.bfs_path = bfs_result["path"]
700
+ st.session_state.metrics = {
701
+ "ai_steps": ai_steps, "bfs_steps": bfs_steps,
702
+ "success": success, "por": por,
703
+ }
704
+ # 启动帧动画
705
+ st.session_state.anim_running = True
706
+ st.session_state.anim_step = 0
707
+ st.session_state.anim_path = dqn_path
708
+ st.rerun()
709
+
710
+ # ── 动画驱动(session_state 帧推进)──────────────────────────────
711
+ if st.session_state.anim_running:
712
+ step_i = st.session_state.anim_step
713
+ anim_p = st.session_state.anim_path
714
+ total = len(anim_p)
715
+ status_placeholder.info(f"🎬 动画播放中… {step_i + 1}/{total}")
716
+
717
+ fig = build_maze_figure(
718
+ wall_map, start, goal,
719
+ dqn_path=anim_p,
720
+ bfs_path=st.session_state.bfs_path,
721
+ highlight_dqn_step=step_i,
722
+ )
723
+ st.plotly_chart(fig, use_container_width=False, key=f"anim_{step_i}")
724
+
725
+ if step_i + 1 < total:
726
+ time.sleep(ANIM_DELAY)
727
+ st.session_state.anim_step += 1
728
+ st.rerun()
729
+ else:
730
+ st.session_state.anim_running = False
731
+ m = st.session_state.metrics
732
+ ok = m["success"]
733
+ status_placeholder.success(
734
+ f"{'✅' if ok else '❌'} DQN 寻路{'成功' if ok else '失败'}!"
735
+ f" AI 步数 = **{m['ai_steps']}** | BFS 最短 = **{m['bfs_steps']}**"
736
+ )
737
+
738
+ # ── 静态迷宫图 ────────────────────────────────────────────────────
739
+ elif not run_dqn:
740
+ fig = build_maze_figure(
741
+ wall_map, start, goal,
742
+ dqn_path=st.session_state.dqn_path,
743
+ bfs_path=st.session_state.bfs_path,
744
+ highlight_dqn_step=-1,
745
+ )
746
+ st.plotly_chart(fig, use_container_width=False, key="maze_static")
747
+
748
+ # ── 指标仪表盘 ───────────────────────────────────────────────────
749
+ m = st.session_state.metrics
750
+ if m:
751
+ ai_s = m["ai_steps"]
752
+ bfs_s = m["bfs_steps"]
753
+ por = m["por"]
754
+ ok = m["success"]
755
+
756
+ # POR 分级颜色
757
+ if ok and por >= 0.99:
758
+ por_cls = "por-perfect"
759
+ por_text = f"{por:.2f} 🏆 100% Perfect"
760
+ elif ok and por >= 0.75:
761
+ por_cls = "por-good"
762
+ por_text = f"{por:.2f} 👍 Good"
763
+ elif ok:
764
+ por_cls = "por-bad"
765
+ por_text = f"{por:.2f} ⚠️ Sub-optimal"
766
+ else:
767
+ por_cls = "por-bad"
768
+ por_text = "N/A ❌ 未到达终点"
769
+
770
+ mc1, mc2, mc3 = st.columns(3)
771
+ with mc1:
772
+ st.markdown(f"""
773
+ <div class='metric-card'>
774
+ <div class='metric-label'>🤖 AI 实际步数</div>
775
+ <div class='metric-value'>{ai_s}</div>
776
+ </div>""", unsafe_allow_html=True)
777
+ with mc2:
778
+ st.markdown(f"""
779
+ <div class='metric-card'>
780
+ <div class='metric-label'>📐 BFS 理论最短</div>
781
+ <div class='metric-value'>{bfs_s}</div>
782
+ </div>""", unsafe_allow_html=True)
783
+ with mc3:
784
+ st.markdown(f"""
785
+ <div class='metric-card' style='background:linear-gradient(135deg,#11998e,#38ef7d)'>
786
+ <div class='metric-label'>⚡ Path Optimality Ratio</div>
787
+ <div class='metric-value {por_cls}'>{por_text}</div>
788
+ </div>""", unsafe_allow_html=True)
789
+
790
+ with st.expander("📊 指标说明"):
791
+ st.markdown("""
792
+ | 指标 | 含义 |
793
+ |------|------|
794
+ | **AI 实际步数** | DQN Agent 从起点走到终点(或超时)所用的总步数 |
795
+ | **BFS 理论最短** | BFS 算法计算的绝对最短路径步数(Ground Truth)|
796
+ | **Path Optimality Ratio** | `BFS步数 / AI步数`,越接近 **1.00** 越完美。等于 1.00 说明 AI 走出了与 BFS 完全相同的最短路! |
797
+ """)
798
+
799
+ # ── 页脚 ─────────────────────────────────────────────────────────────
800
+ st.divider()
801
+ st.markdown(
802
+ "<div style='text-align:center;color:#95A5A6;font-size:12px'>"
803
+ "DQN Maze Solver · PyTorch + Gymnasium + Streamlit · "
804
+ "Hugging Face Spaces Demo"
805
+ "</div>",
806
+ unsafe_allow_html=True,
807
+ )
808
+
809
+
810
+ if __name__ == "__main__":
811
+ main()