Lee93whut committed on
Commit ·
a264030
1
Parent(s): a91b194
feat(demo): Streamlit web demo — Plotly heatmap, anti-loop inference
Browse files
app.py:
- Interactive 10×10 maze rendered as Plotly go.Heatmap
- Dropdown + random button for start/goal selection
- Load any of 4 algorithm weights (Vanilla/Double/Dueling/Double+Dueling)
- DQN rollout with anti-loop inference guard:
visit_cnt >= 2 → Q[action] -= 3.0 × visit_cnt
(inference-only Q-value patch, does not affect training distribution)
- BFS shortest path overlay for SPL ground-truth comparison
- Deployed on Hugging Face Spaces (Docker SDK)
app.py
ADDED
|
@@ -0,0 +1,811 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""app.py —— DQN 迷宫寻路可视化 Web App
|
| 2 |
+
Hugging Face Spaces (Docker SDK) 专用
|
| 3 |
+
|
| 4 |
+
部署清单(上传到 HF Space 的全部文件)
|
| 5 |
+
--------------------------------------
|
| 6 |
+
app.py 本文件
|
| 7 |
+
src/model.py 神经网络架构
|
| 8 |
+
results/best_model_train_vanilla.pth vanilla DQN 权重
|
| 9 |
+
results/best_model_train_double.pth Double DQN 权重
|
| 10 |
+
results/best_model_train_dueling.pth Dueling DQN 权重
|
| 11 |
+
results/best_model_train_double_dueling.pth Double Dueling DQN 权重
|
| 12 |
+
config.yaml 环境配置(grid_size / obstacle_density / max_steps)
|
| 13 |
+
requirements.txt 依赖列表
|
| 14 |
+
|
| 15 |
+
导入策略
|
| 16 |
+
--------
|
| 17 |
+
* maze_env 通过 `pip install -e .` 安装(见 Dockerfile),直接 import。
|
| 18 |
+
* src/ 通过 pyproject.toml packages.find 配置,同样可安装,直接 import。
|
| 19 |
+
* 所有模块均通过标准 import 路径解析,无需 sys.path 手动注入。
|
| 20 |
+
|
| 21 |
+
端口说明
|
| 22 |
+
--------
|
| 23 |
+
HF Docker Space 固定使用 7860 端口(见 Dockerfile / README)。
|
| 24 |
+
本地调试:streamlit run app.py
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
from __future__ import annotations
|
| 28 |
+
|
| 29 |
+
import random
|
| 30 |
+
import time
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
from typing import Optional
|
| 33 |
+
|
| 34 |
+
import numpy as np
|
| 35 |
+
import plotly.graph_objects as go
|
| 36 |
+
import streamlit as st
|
| 37 |
+
import torch
|
| 38 |
+
import yaml
|
| 39 |
+
|
| 40 |
+
# ── maze_env 包(已安装,直接导入)──────────────────────────────────────────
|
| 41 |
+
from maze_env import MazeEnv
|
| 42 |
+
from maze_env.bfs import bfs as bfs_solve
|
| 43 |
+
|
| 44 |
+
# ── src 包(pip install -e . 后可直接导入)───────────────────────────────────
|
| 45 |
+
import torch.nn as nn
|
| 46 |
+
from src.model import DQNNetwork, DuelingDQNNetwork
|
| 47 |
+
|
| 48 |
+
# ===========================================================================
|
| 49 |
+
# 常量 & 配置
|
| 50 |
+
# ===========================================================================
|
| 51 |
+
# Environment settings are read from the config.yaml next to this file.
_CONFIG_PATH = Path(__file__).parent / "config.yaml"
if _CONFIG_PATH.exists():
    # yaml.safe_load returns None for an empty document; coerce to a dict so
    # the .get() lookups below cannot raise AttributeError.
    _cfg = yaml.safe_load(_CONFIG_PATH.read_text(encoding="utf-8")) or {}
else:
    import warnings

    warnings.warn(
        f"config.yaml 未找到({_CONFIG_PATH}),使用内置默认值。"
        "若训练时使用了非默认 grid_size,推理结果可能错误。",
        stacklevel=1,
    )
    _cfg = {}
# A bare "maze:" key parses to None, so fall back to an empty mapping as well.
_maze_cfg = _cfg.get("maze") or {}

GRID_SIZE = int(_maze_cfg.get("grid_size", 10))
# Read from maze.obstacle_density so the demo uses the configured density.
OBSTACLE_DENSITY = float(_maze_cfg.get("obstacle_density", 0.25))
# Step budget passed to MazeEnv during inference rollouts.
MAX_STEPS = int(_maze_cfg.get("max_steps", 200))

# The four selectable algorithms (list order is the intended UI dropdown order).
ALGO_OPTIONS: list[str] = ["double_dueling", "dueling", "double", "vanilla"]
ALGO_LABELS: dict[str, str] = {
    "vanilla": "Vanilla DQN(基准)",
    "double": "Double DQN(抑制高估)",
    "dueling": "Dueling DQN(V+A 分解)",
    "double_dueling": "Double + Dueling(推荐)",
}
# Default algorithm: prefer dqn.algorithm from config.yaml, otherwise (or when
# the configured value is not a known option) fall back to double_dueling.
_default_algo = str((_cfg.get("dqn") or {}).get("algorithm", "double_dueling")).strip().lower()
DEFAULT_ALGO: str = _default_algo if _default_algo in ALGO_OPTIONS else "double_dueling"
|
| 79 |
+
|
| 80 |
+
def model_path_for(algo: str) -> Path:
    """Return the weight-file path for the given algorithm name.

    The path is ``results/best_model_train_<algo>.pth`` relative to the
    directory containing this file. The file is not checked for existence.

    Args:
        algo: Algorithm name used as the file-name suffix.

    Returns:
        Path to the corresponding ``.pth`` checkpoint.
    """
    return Path(__file__).parent / "results" / f"best_model_train_{algo}.pth"
|
| 83 |
+
|
| 84 |
+
# Seed for the maze shown on first load.
# A fixed value is meant to give every visitor the same initial map; set it to
# None to get a random maze instead (hence the Optional annotation).
DEFAULT_MAZE_SEED: Optional[int] = 42

# Delay between animation frames, in seconds.
ANIM_DELAY = 0.08

# Colour palette (hex strings) used by the Plotly maze figure.
COLOR_EMPTY = "#F8F9FA"     # near-white: walkable floor
COLOR_WALL = "#2C3E50"      # dark blue-grey: wall
COLOR_START = "#27AE60"     # green: start cell
COLOR_GOAL = "#E74C3C"      # red: goal cell
COLOR_DQN_PATH = "#3498DB"  # blue: DQN trajectory
COLOR_BFS_PATH = "#F39C12"  # orange: BFS shortest path
COLOR_AGENT = "#9B59B6"     # purple: current agent position
|
| 99 |
+
|
| 100 |
+
# ===========================================================================
|
| 101 |
+
# 工具函数
|
| 102 |
+
# ===========================================================================
|
| 103 |
+
|
| 104 |
+
def generate_maze(seed: Optional[int] = None) -> np.ndarray:
    """Generate a GRID_SIZE x GRID_SIZE maze via ``MazeEnv.reset``.

    Maze generation is delegated to :class:`MazeEnv` so the demo reuses the
    environment's own generator instead of re-implementing it here.
    NOTE(review): the original docstring states that MazeEnv adds border walls
    and guarantees (1, 1) -> (N-2, N-2) connectivity; that is not visible in
    this file, so confirm against maze_env.

    Args:
        seed: Random seed forwarded to ``env.reset``; ``None`` leaves the
            randomness unseeded.

    Returns:
        ``env.wall_map`` cast to ``int32`` (presumably 0 = free, 1 = wall,
        shape ``(N, N)`` -- verify against MazeEnv).
    """
    env = MazeEnv(
        grid_size=GRID_SIZE,
        obstacle_density=OBSTACLE_DENSITY,
    )
    env.reset(seed=seed)
    return env.wall_map.astype(np.int32)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def generate_maze_with_random_sg(
    seed: Optional[int] = None,
) -> tuple[np.ndarray, tuple[int, int], tuple[int, int]]:
    """Generate a maze and pick a random start and goal among inner free cells.

    The maze comes from ``MazeEnv.reset``; start and goal are two distinct
    cells drawn with ``env.np_random`` from the free cells that are not on the
    outer border.
    NOTE(review): the original docstring says this mirrors the
    ``random_start_goal=True`` logic in train.py; that file is not visible
    here, so confirm the two stay in sync. Nothing in this function checks
    that the chosen goal is reachable from the chosen start.

    Args:
        seed: Random seed forwarded to ``env.reset``; ``None`` leaves the
            randomness unseeded.

    Returns:
        ``(wall_map, start, goal)`` where ``wall_map`` is ``env.wall_map``
        cast to ``int32`` and ``start`` / ``goal`` are ``(row, col)`` tuples.
    """
    env = MazeEnv(
        grid_size=GRID_SIZE,
        obstacle_density=OBSTACLE_DENSITY,
    )
    env.reset(seed=seed)
    wall_map = env.wall_map.astype(np.int32)

    # Collect free cells (value 0) that are strictly inside the outer border.
    rows, cols = np.where(wall_map == 0)
    inner_cells: list[tuple[int, int]] = [
        (int(r), int(c))
        for r, c in zip(rows, cols)
        if 0 < r < GRID_SIZE - 1 and 0 < c < GRID_SIZE - 1
    ]

    if len(inner_cells) < 2:
        # Degenerate maze with fewer than two inner free cells: fall back to
        # the fixed corner start/goal.
        return wall_map, (1, 1), (GRID_SIZE - 2, GRID_SIZE - 2)

    # Draw from the environment's own generator (not the global random state)
    # and redraw until the two indices differ.
    idxs = env.np_random.integers(0, len(inner_cells), size=2)
    while idxs[0] == idxs[1]:
        idxs = env.np_random.integers(0, len(inner_cells), size=2)

    start = inner_cells[int(idxs[0])]
    goal = inner_cells[int(idxs[1])]
    return wall_map, start, goal
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def load_model(algo: str = DEFAULT_ALGO, grid_size: int = GRID_SIZE) -> tuple[Optional[nn.Module], int]:
    """Load the DQN weights for ``algo`` and return ``(net, saved_grid_size)``.

    The checkpoint is expected to be a mapping with a ``"state_dict"`` entry
    and optional ``"grid_size"`` / ``"algorithm"`` metadata. The network class
    is chosen from the checkpoint's ``"algorithm"`` value; when that key is
    missing, the requested ``algo`` is used instead, so a dueling checkpoint
    without metadata is not rebuilt as a plain network.

    Args:
        algo: Algorithm name; selects the weight file via ``model_path_for``.
        grid_size: Grid size returned when loading fails, and used as the
            default when the checkpoint carries no ``"grid_size"``.

    Returns:
        ``(net, saved_grid_size)`` with ``net`` in eval mode, or
        ``(None, grid_size)`` when the file is missing or loading fails.
        Callers can compare ``saved_grid_size`` with the current grid size to
        detect an input-dimension mismatch before running inference.
    """
    path = model_path_for(algo)
    if not path.exists():
        return None, grid_size
    try:
        ckpt = torch.load(path, map_location="cpu", weights_only=True)
        saved_gs = ckpt.get("grid_size", grid_size)
        # str() tolerates non-string metadata; fall back to the requested algo.
        algorithm = str(ckpt.get("algorithm", algo)).strip().lower()
        net_class = DuelingDQNNetwork if "dueling" in algorithm else DQNNetwork
        net = net_class(grid_size=saved_gs)
        net.load_state_dict(ckpt["state_dict"])
        net.eval()
        return net, saved_gs
    except Exception as e:
        # Deliberately broad: this is the UI boundary, and any load failure is
        # reported in the page instead of crashing the app.
        st.error(f"❌ 模型加载失败:{e}")
        return None, grid_size
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def dqn_rollout(
    net: nn.Module,
    wall_map: np.ndarray,
    start: tuple,
    goal: tuple,
) -> list[tuple]:
    """Run the agent greedily (no exploration) and return its trajectory.

    The rollout goes through ``MazeEnv.reset`` / ``MazeEnv.step`` with the
    given map, start and goal injected via ``options``, so observation
    encoding and collision handling stay inside the environment.

    Args:
        net: Network with loaded weights, already in eval mode.
        wall_map: Maze grid; cast to ``float32`` before being handed to the
            environment (presumably 0 = free, 1 = wall -- verify against
            MazeEnv).
        start: Agent start as ``(row, col)``.
        goal: Goal as ``(row, col)``.

    Returns:
        Visited positions including the start. Steps that hit a wall are not
        appended, so consecutive duplicates from wall hits do not appear.
    """
    env = MazeEnv(
        grid_size=wall_map.shape[0],
        obstacle_density=0.0,  # irrelevant here: the map is injected below
        max_steps=MAX_STEPS,
    )
    obs, _ = env.reset(options={
        "wall_map": wall_map.astype(np.float32),
        "start": start,
        "goal": goal,
    })

    path = [env.agent_pos]

    # NOTE(review): per the original comment, the observation includes a
    # visited-map channel, so no Q-value penalty is injected at inference
    # time and actions are chosen by plain argmax. Confirm against MazeEnv.
    while True:
        s = torch.from_numpy(obs).unsqueeze(0)
        with torch.no_grad():
            q_values = net(s)[0]  # Q-values of the single batch element

        action = int(q_values.argmax().item())
        obs, _reward, terminated, truncated, info = env.step(action)

        # Record the position only when the agent actually moved; a wall hit
        # would otherwise repeat the coordinate and stutter the animation.
        if not info["hit_wall"]:
            path.append(env.agent_pos)

        if terminated or truncated:
            break

    return path
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
# ===========================================================================
|
| 252 |
+
# Plotly 迷宫绘制
|
| 253 |
+
# ===========================================================================
|
| 254 |
+
|
| 255 |
+
def build_maze_figure(
    wall_map: np.ndarray,
    start: tuple,
    goal: tuple,
    dqn_path: Optional[list] = None,
    bfs_path: Optional[list] = None,
    agent_pos: Optional[tuple] = None,
    highlight_dqn_step: int = -1,
) -> go.Figure:
    """Build the Plotly maze figure with optional path overlays.

    Args:
        wall_map: Square maze grid with 0 for free cells and 1 for walls.
        start: Start cell as ``(row, col)``.
        goal: Goal cell as ``(row, col)``.
        dqn_path: DQN trajectory as ``(row, col)`` pairs; drawn as a solid
            line when it has more than one point.
        bfs_path: BFS path as ``(row, col)`` pairs; drawn as a dotted line
            when it has more than one point.
        agent_pos: Explicit agent marker position; takes precedence over the
            position derived from ``dqn_path``.
        highlight_dqn_step: When >= 0, only ``dqn_path[:step + 1]`` is drawn
            and the agent marker sits on that step (clamped to the last
            point). When negative, the whole path is drawn.

    Returns:
        The assembled ``go.Figure``.
    """
    N = wall_map.shape[0]

    # Base layer: one Heatmap trace for the whole grid.
    # Cell codes: 0 = free, 1 = wall, 2 = start, 3 = goal.
    z = wall_map.astype(float).copy()
    z[start[0], start[1]] = 2.0
    z[goal[0], goal[1]] = 3.0

    # Discrete colour scale: with zmin=0 and zmax=3 the codes normalise to
    # 0, 1/3, 2/3 and 1, each landing inside its own flat colour band.
    colorscale = [
        [0.00, COLOR_EMPTY],  # 0 = free
        [0.25, COLOR_EMPTY],
        [0.25, COLOR_WALL],   # 1 = wall
        [0.50, COLOR_WALL],
        [0.50, COLOR_START],  # 2 = start
        [0.75, COLOR_START],
        [0.75, COLOR_GOAL],   # 3 = goal
        [1.00, COLOR_GOAL],
    ]

    fig = go.Figure()
    fig.add_trace(go.Heatmap(
        z=z,
        colorscale=colorscale,
        zmin=0, zmax=3,
        showscale=False,
        xgap=1, ygap=1,
        hoverinfo="skip",
    ))

    # BFS path (orange dotted line); x is the column, y is the row.
    if bfs_path and len(bfs_path) > 1:
        bx = [c for _, c in bfs_path]
        by = [r for r, _ in bfs_path]
        fig.add_trace(go.Scatter(
            x=bx, y=by,
            mode="lines+markers",
            name="BFS 最短路",
            line=dict(color=COLOR_BFS_PATH, width=3, dash="dot"),
            marker=dict(size=6, color=COLOR_BFS_PATH, opacity=0.7),
        ))

    # DQN path (blue solid line), truncated to the highlighted step if set.
    if dqn_path and len(dqn_path) > 1:
        end_idx = highlight_dqn_step + 1 if highlight_dqn_step >= 0 else len(dqn_path)
        sub_path = dqn_path[:end_idx]
        dx = [c for _, c in sub_path]
        dy = [r for r, _ in sub_path]
        fig.add_trace(go.Scatter(
            x=dx, y=dy,
            mode="lines+markers",
            name="DQN 轨迹",
            line=dict(color=COLOR_DQN_PATH, width=3),
            marker=dict(size=7, color=COLOR_DQN_PATH),
        ))

    # Current agent position (large purple dot). Precedence: an explicit
    # agent_pos, then the highlighted step of dqn_path, otherwise the start.
    if agent_pos:
        ap = agent_pos
    elif dqn_path and highlight_dqn_step >= 0:
        ap = dqn_path[min(highlight_dqn_step, len(dqn_path) - 1)]
    else:
        ap = start
    fig.add_trace(go.Scatter(
        x=[ap[1]], y=[ap[0]],
        mode="markers",
        name="Agent",
        marker=dict(size=16, color=COLOR_AGENT, symbol="circle",
                    line=dict(color="white", width=2)),
        showlegend=True,
    ))

    # "S" / "G" text labels on the start and goal cells.
    fig.add_trace(go.Scatter(
        x=[start[1], goal[1]],
        y=[start[0], goal[0]],
        mode="markers+text",
        text=["S", "G"],
        textposition="middle center",
        textfont=dict(size=13, color="white", family="Arial Black"),
        marker=dict(size=22, color=[COLOR_START, COLOR_GOAL],
                    symbol="square", opacity=0.0),  # transparent marker: text only
        showlegend=False,
        hoverinfo="skip",
    ))

    # Layout: the y axis range is reversed so row 0 is drawn at the top.
    fig.update_layout(
        width=560, height=560,
        margin=dict(l=10, r=10, t=30, b=10),
        xaxis=dict(
            range=[-0.5, N - 0.5], tickvals=list(range(N)),
            showgrid=False, zeroline=False, title="列 (col)",
        ),
        yaxis=dict(
            range=[N - 0.5, -0.5],
            tickvals=list(range(N)),
            showgrid=False, zeroline=False, title="行 (row)",
        ),
        legend=dict(x=1.01, y=1, bgcolor="rgba(255,255,255,0.8)",
                    bordercolor="#BDC3C7", borderwidth=1),
        paper_bgcolor="white",
        plot_bgcolor="white",
        title=dict(text="🏁 DQN 迷宫寻路", x=0.5, font=dict(size=16)),
    )
    return fig
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def _find_cell_index(free_cells: list[tuple], pos: tuple) -> int:
|
| 372 |
+
"""在 free_cells 列表中查找 pos 的索引;未找到时返回 0(安全回退)。"""
|
| 373 |
+
try:
|
| 374 |
+
return free_cells.index(pos)
|
| 375 |
+
except ValueError:
|
| 376 |
+
return 0
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
# ===========================================================================
|
| 380 |
+
# Session State 初始化
|
| 381 |
+
# ===========================================================================
|
| 382 |
+
|
| 383 |
+
def _init_state() -> None:
    """Populate ``st.session_state`` with defaults for any missing keys.

    Keys that already exist are left untouched, so calling this again on a
    later run of the script does not overwrite them.
    """
    state = st.session_state

    if "wall_map" not in state:
        # First load: seeded maze with a randomly drawn start and goal, so the
        # initial view is reproducible.
        wm, sg_start, sg_goal = generate_maze_with_random_sg(seed=DEFAULT_MAZE_SEED)
        state.wall_map = wm
        state.start = sg_start
        state.goal = sg_goal

    # Plain defaults, applied only where the key is still missing. "start"
    # and "goal" are fallbacks for a state that has a map but no endpoints.
    simple_defaults = {
        "start": (1, 1),
        "goal": (GRID_SIZE - 2, GRID_SIZE - 2),
        "dqn_path": None,
        "bfs_path": None,
        "metrics": None,
        "selected_algo": DEFAULT_ALGO,
        "maze_seed": DEFAULT_MAZE_SEED,
        "anim_running": False,
        "anim_step": 0,
        "anim_path": None,
    }
    for key, value in simple_defaults.items():
        if key not in state:
            state[key] = value

    if "model" not in state:
        # Load once; load_model returns (None, GRID_SIZE) when loading fails.
        net, saved_gs = load_model(algo=DEFAULT_ALGO)
        state.model = net
        state.model_grid_size = saved_gs
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
# ===========================================================================
|
| 417 |
+
# 主程序
|
| 418 |
+
# ===========================================================================
|
| 419 |
+
|
| 420 |
+
def main() -> None:
    """Render the whole Streamlit page.

    Layout: a control panel in the left column (maze seed, start/goal
    selection, algorithm choice, run buttons, legend, model status) and the
    maze canvas with the metric dashboard in the right column, followed by a
    footer.

    The function is executed from top to bottom on every script run; all
    state that must survive between runs lives in ``st.session_state``.
    Every action that invalidates the cached paths goes through
    ``_reset_results`` so that a running animation is stopped as well.
    """
    st.set_page_config(
        page_title="DQN 迷宫寻路 Demo",
        page_icon="🤖",
        layout="wide",
    )

    _inject_global_styles()
    _init_state()

    st.title("🤖 DQN 迷宫寻路 · 可视化 Demo")
    st.caption("Deep Q-Network × BFS Ground-Truth · Hugging Face Spaces")

    # Regular two-column layout (click handling stays inside the right
    # column so it does not disturb the overall layout).
    left_col, right_col = st.columns([1, 2.2], gap="large")

    # ------------------------------------------------------------------
    # Left column: control panel
    # ------------------------------------------------------------------
    with left_col:
        st.subheader("⚙️ 控制面板")

        # -- Maze generation ------------------------------------------
        st.markdown("**① 迷宫地图**")
        col_seed, col_rand = st.columns([3, 1])
        with col_seed:
            input_seed = st.number_input(
                "迷宫 Seed",
                min_value=0,
                max_value=999999,
                value=st.session_state.maze_seed,
                step=1,
                help="固定数字可复现指定地图;点击右侧按钮随机生成新地图",
            )
        with col_rand:
            st.write("")  # spacer so the button lines up with the input
            if st.button("🎲 随机"):
                # Random seed: regenerates the map and the start/goal pair.
                _apply_new_maze(random.randint(0, 999999))
                # Ends this run immediately, so the manual-seed check below
                # is not executed for the same click.
                st.rerun()

        # Triggered when the user edits the seed input by hand.
        if input_seed != st.session_state.maze_seed:
            _apply_new_maze(input_seed)
            st.rerun()

        st.divider()

        # -- Start / goal selection -----------------------------------
        st.markdown("**② 起点 & 终点**")

        # Pick two distinct free inner cells of the current map at random.
        if st.button("🎲 随机起终点", use_container_width=True,
                     help="从当前地图可通行格随机选取起点和终点,与训练分布完全一致"):
            candidates = _free_inner_cells(st.session_state.wall_map)
            if len(candidates) >= 2:
                pick_a, pick_b = random.sample(range(len(candidates)), 2)
                st.session_state.start = candidates[pick_a]
                st.session_state.goal = candidates[pick_b]
                _reset_results()
                # Keep the select boxes (keyed widgets) in sync.
                st.session_state.start_select = _find_cell_index(
                    candidates, candidates[pick_a]
                )
                st.session_state.goal_select = _find_cell_index(
                    candidates, candidates[pick_b]
                )
                st.rerun()

        free_cells = _free_inner_cells(st.session_state.wall_map)
        cell_labels = [f"({r},{c})" for r, c in free_cells]

        start_idx = st.selectbox(
            "起点 (row, col)",
            options=range(len(free_cells)),
            format_func=lambda i: cell_labels[i],
            index=_find_cell_index(free_cells, st.session_state.start),
            key="start_select",
        )
        goal_idx = st.selectbox(
            "终点 (row, col)",
            options=range(len(free_cells)),
            format_func=lambda i: cell_labels[i],
            index=_find_cell_index(free_cells, st.session_state.goal),
            key="goal_select",
        )
        new_start = free_cells[start_idx]
        new_goal = free_cells[goal_idx]

        if new_start == new_goal:
            st.warning("⚠️ 起点与终点不能相同,请重新选择。")
        elif new_start != st.session_state.start or new_goal != st.session_state.goal:
            st.session_state.start = new_start
            st.session_state.goal = new_goal
            _reset_results()

        st.divider()

        # -- Algorithm selection and run buttons ----------------------
        # The heading below was garbled (U+FFFD) in the source; restored
        # to match the wording used by the neighbouring labels.
        st.markdown("**③ 寻路算法**")

        selected_algo = st.selectbox(
            "DQN 算法变体",
            options=ALGO_OPTIONS,
            format_func=lambda a: ALGO_LABELS[a],
            index=ALGO_OPTIONS.index(st.session_state.selected_algo),
            key="algo_select",
            help="切换算法后点击「DQN 寻路」按钮可对比不同算法在同一地图上的路径",
        )
        # On an algorithm switch, load the matching weights and drop the
        # previous DQN result. The BFS path is kept: it does not depend on
        # the chosen network.
        if selected_algo != st.session_state.selected_algo:
            st.session_state.selected_algo = selected_algo
            net, saved_gs = load_model(algo=selected_algo)
            st.session_state.model = net
            st.session_state.model_grid_size = saved_gs
            _reset_results(keep_bfs=True)
            st.rerun()

        run_dqn = st.button(
            "🤖 DQN 智能体寻路",
            use_container_width=True,
            type="primary",
        )
        run_bfs = st.button(
            "📐 BFS 专家寻路",
            use_container_width=True,
        )

        st.divider()

        # -- Legend ---------------------------------------------------
        st.markdown("**图例**")
        legend_html = """
        <div style='font-size:13px; line-height:2'>
        🟩 <b>S</b> 起点
        🟥 <b>G</b> 终点<br>
        ⬛ 墙壁
        ⬜ 通路<br>
        🔵 DQN 轨迹
        🟠 BFS 最短路<br>
        🟣 Agent 当前位置
        </div>
        """
        st.markdown(legend_html, unsafe_allow_html=True)

        # -- Model status ---------------------------------------------
        st.divider()
        _cur_algo = st.session_state.get("selected_algo", DEFAULT_ALGO)
        _cur_path = model_path_for(_cur_algo)
        if st.session_state.model is not None:
            st.success(f"✅ 模型已加载 ({_cur_path.name})")
            # Warn early when the stored grid size differs from GRID_SIZE:
            # per the original author's note the network expects
            # (3, saved_gs, saved_gs) observations while the inference
            # environment produces (3, GRID_SIZE, GRID_SIZE), which would
            # fail inside the network's forward pass.
            _saved_gs = st.session_state.get("model_grid_size", GRID_SIZE)
            if _saved_gs != GRID_SIZE:
                st.warning(
                    f"⚠️ 模型训练于 {_saved_gs}×{_saved_gs} 迷宫,"
                    f"当前配置为 {GRID_SIZE}×{GRID_SIZE}。\n"
                    "推理时输入维度不匹配,将导致运行时错误。\n"
                    "请使用匹配 grid_size 的模型,或更新 config.yaml。"
                )
        else:
            st.error(f"❌ 未找到 {_cur_path.name}")
            st.info(f"请先运行 `python src/train.py --algorithm {_cur_algo}` 训练模型。")

    # ------------------------------------------------------------------
    # Right column: main canvas
    # ------------------------------------------------------------------
    with right_col:
        wall_map = st.session_state.wall_map
        start = st.session_state.start
        goal = st.session_state.goal

        status_placeholder = st.empty()

        # -- BFS button -----------------------------------------------
        if run_bfs:
            result = bfs_solve(wall_map.astype(np.int32), start, goal)
            if result["success"]:
                st.session_state.bfs_path = result["path"]
                status_placeholder.success(
                    f"✅ BFS 完成!最短步数 = **{result['steps']}**,"
                    f"耗时 {result['execution_time_ms']:.3f} ms"
                )
            else:
                st.session_state.bfs_path = None
                # Trailing character was garbled (U+FFFD) in the source.
                status_placeholder.error("❌ BFS:起点与终点之间无可达路径!")

        # -- DQN button -----------------------------------------------
        if run_dqn:
            model = st.session_state.model
            if model is None:
                status_placeholder.error("❌ 模型未加载,无法推理。")
            elif st.session_state.get("model_grid_size", GRID_SIZE) != GRID_SIZE:
                _mgs = st.session_state.model_grid_size
                status_placeholder.error(
                    f"❌ 模型训练于 {_mgs}×{_mgs},当前为 {GRID_SIZE}×{GRID_SIZE},维度不匹配。"
                )
            else:
                bfs_result = bfs_solve(wall_map.astype(np.int32), start, goal)
                if not bfs_result["success"]:
                    status_placeholder.error("❌ 该迷宫配置无解,请换起终点。")
                else:
                    with st.spinner("🤖 DQN 推理中…"):
                        dqn_path = dqn_rollout(model, wall_map, start, goal)

                    ai_steps = len(dqn_path) - 1
                    bfs_steps = bfs_result["steps"]
                    success = (dqn_path[-1] == goal)
                    # Ratio is only meaningful when the goal was reached.
                    por = round(bfs_steps / ai_steps, 4) if (success and ai_steps > 0) else 0.0

                    st.session_state.dqn_path = dqn_path
                    st.session_state.bfs_path = bfs_result["path"]
                    st.session_state.metrics = {
                        "ai_steps": ai_steps, "bfs_steps": bfs_steps,
                        "success": success, "por": por,
                    }
                    # Start the frame animation on the next run.
                    st.session_state.anim_running = True
                    st.session_state.anim_step = 0
                    st.session_state.anim_path = dqn_path
                    st.rerun()

        # -- Animation driver (one frame per script run) --------------
        if st.session_state.anim_running:
            step_i = st.session_state.anim_step
            anim_p = st.session_state.anim_path
            total = len(anim_p)
            status_placeholder.info(f"🎬 动画播放中… {step_i + 1}/{total}")

            fig = build_maze_figure(
                wall_map, start, goal,
                dqn_path=anim_p,
                bfs_path=st.session_state.bfs_path,
                highlight_dqn_step=step_i,
            )
            st.plotly_chart(fig, use_container_width=False, key=f"anim_{step_i}")

            if step_i + 1 < total:
                time.sleep(ANIM_DELAY)
                st.session_state.anim_step += 1
                st.rerun()
            else:
                st.session_state.anim_running = False
                m = st.session_state.metrics
                if m is not None:
                    ok = m["success"]
                    # A failed rollout is reported as an error, not inside
                    # a green success box.
                    report = status_placeholder.success if ok else status_placeholder.error
                    report(
                        f"{'✅' if ok else '❌'} DQN 寻路{'成功' if ok else '失败'}!"
                        f" AI 步数 = **{m['ai_steps']}** | BFS 最短 = **{m['bfs_steps']}**"
                    )

        # -- Static maze ----------------------------------------------
        # Drawn whenever no animation frame was drawn above. This includes
        # the case where the DQN button was pressed but inference was
        # refused, so the maze stays visible next to the error message.
        else:
            fig = build_maze_figure(
                wall_map, start, goal,
                dqn_path=st.session_state.dqn_path,
                bfs_path=st.session_state.bfs_path,
                highlight_dqn_step=-1,
            )
            st.plotly_chart(fig, use_container_width=False, key="maze_static")

        # -- Metric dashboard -----------------------------------------
        m = st.session_state.metrics
        if m:
            _render_metrics(m)

    # -- Footer -------------------------------------------------------
    st.divider()
    st.markdown(
        "<div style='text-align:center;color:#95A5A6;font-size:12px'>"
        "DQN Maze Solver · PyTorch + Gymnasium + Streamlit · "
        "Hugging Face Spaces Demo"
        "</div>",
        unsafe_allow_html=True,
    )


# ---------------------------------------------------------------------------
# Private helpers used by main()
# ---------------------------------------------------------------------------

def _inject_global_styles() -> None:
    """Emit the page-wide CSS (metric cards, POR colours, button sizing)."""
    # The block is indented for readability; st.markdown dedents its input.
    st.markdown("""
    <style>
    .metric-card {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        border-radius: 12px; padding: 16px 20px; color: white;
        text-align: center; margin: 6px 0;
    }
    .metric-label { font-size: 13px; opacity: 0.85; margin-bottom: 4px; }
    .metric-value { font-size: 28px; font-weight: 700; }
    .por-perfect { color: #2ECC71; font-weight: 800; }
    .por-good { color: #F39C12; font-weight: 700; }
    .por-bad { color: #E74C3C; font-weight: 600; }
    div[data-testid="stButton"] button {
        width: 100%; border-radius: 8px; font-weight: 600;
    }
    /* 迷宫按钮网格:每格紧凑正方形,无内边距 */
    div[data-testid="stHorizontalBlock"] div[data-testid="stButton"] button {
        padding: 0 !important;
        min-height: 40px !important;
        font-size: 15px !important;
        border-radius: 3px !important;
        border: 1px solid #ccc !important;
        line-height: 1 !important;
    }
    </style>
    """, unsafe_allow_html=True)


def _free_inner_cells(wall_map) -> list:
    """Return the ``(row, col)`` cells inside the border whose value is 0.

    Cells are listed in row-major order, which is the order the start/goal
    select boxes index into.
    """
    inner = range(1, GRID_SIZE - 1)
    return [(r, c) for r in inner for c in inner if wall_map[r, c] == 0]


def _reset_results(*, keep_bfs: bool = False) -> None:
    """Forget cached path results and stop any running animation.

    Args:
        keep_bfs: When true the stored BFS path is preserved (used on an
            algorithm switch, where only the DQN result becomes stale).
    """
    state = st.session_state
    state.dqn_path = None
    if not keep_bfs:
        state.bfs_path = None
    state.metrics = None
    # Without this, an action taken while frames are still playing would
    # keep animating the stale path and then read ``metrics`` as None on
    # the final frame.
    state.anim_running = False
    state.anim_step = 0
    state.anim_path = None


def _apply_new_maze(seed: int) -> None:
    """Generate the maze for ``seed`` and store it with its start/goal pair.

    Also records the seed, clears previous results and updates the keyed
    start/goal select boxes so they do not keep an index from the old map.
    The caller is expected to trigger ``st.rerun()`` afterwards.
    """
    wall_map, sg_start, sg_goal = generate_maze_with_random_sg(seed=seed)
    state = st.session_state
    state.maze_seed = seed
    state.wall_map = wall_map
    state.start = sg_start
    state.goal = sg_goal
    _reset_results()
    cells = _free_inner_cells(wall_map)
    state.start_select = _find_cell_index(cells, sg_start)
    state.goal_select = _find_cell_index(cells, sg_goal)


def _render_metrics(m: dict) -> None:
    """Draw the three metric cards and the explanatory expander.

    Args:
        m: Metrics dict with the keys ``ai_steps``, ``bfs_steps``, ``por``
            and ``success`` as stored by ``main`` after a DQN rollout.
    """
    ai_s = m["ai_steps"]
    bfs_s = m["bfs_steps"]
    por = m["por"]
    ok = m["success"]

    # Colour grade for the Path Optimality Ratio.
    if ok and por >= 0.99:
        por_cls = "por-perfect"
        por_text = f"{por:.2f} 🏆 100% Perfect"
    elif ok and por >= 0.75:
        por_cls = "por-good"
        por_text = f"{por:.2f} 👍 Good"
    elif ok:
        por_cls = "por-bad"
        por_text = f"{por:.2f} ⚠️ Sub-optimal"
    else:
        por_cls = "por-bad"
        por_text = "N/A ❌ 未到达终点"

    mc1, mc2, mc3 = st.columns(3)
    with mc1:
        st.markdown(f"""
        <div class='metric-card'>
        <div class='metric-label'>🤖 AI 实际步数</div>
        <div class='metric-value'>{ai_s}</div>
        </div>""", unsafe_allow_html=True)
    with mc2:
        st.markdown(f"""
        <div class='metric-card'>
        <div class='metric-label'>📐 BFS 理论最短</div>
        <div class='metric-value'>{bfs_s}</div>
        </div>""", unsafe_allow_html=True)
    with mc3:
        st.markdown(f"""
        <div class='metric-card' style='background:linear-gradient(135deg,#11998e,#38ef7d)'>
        <div class='metric-label'>⚡ Path Optimality Ratio</div>
        <div class='metric-value {por_cls}'>{por_text}</div>
        </div>""", unsafe_allow_html=True)

    with st.expander("📊 指标说明"):
        st.markdown("""
        | 指标 | 含义 |
        |------|------|
        | **AI 实际步数** | DQN Agent 从起点走到终点(或超时)所用的总步数 |
        | **BFS 理论最短** | BFS 算法计算的绝对最短路径步数(Ground Truth)|
        | **Path Optimality Ratio** | `BFS步数 / AI步数`,越接近 **1.00** 越完美。等于 1.00 说明 AI 走出了与 BFS 完全相同的最短路! |
        """)
|
| 808 |
+
|
| 809 |
+
|
| 810 |
+
# Build the page only when this file is run as the entry script, so that
# importing the module does not render anything.
if __name__ == "__main__":
    main()
|