interview / app.py
Lee93whut
chore: codebase hygiene pass — untrack weights, migrate to logging, tidy comments
17bc537
"""app.py -- DQN maze path-finding visualisation web app.

Intended for Hugging Face Spaces (Docker SDK).

Deployment manifest (all files uploaded to the HF Space)
--------------------------------------------------------
app.py                                        this file
src/model.py                                  neural-network architectures
results/best_model_train_vanilla.pth          vanilla DQN weights
results/best_model_train_double.pth           Double DQN weights
results/best_model_train_dueling.pth          Dueling DQN weights
results/best_model_train_double_dueling.pth   Double Dueling DQN weights
config.yaml                                   environment configuration (grid_size / obstacle_density / max_steps)
requirements.txt                              dependency list

Import strategy
---------------
* maze_env is installed via `pip install -e .` (see Dockerfile) and imported directly.
* src/ is configured through pyproject.toml packages.find, so it is installable and imported directly as well.
* Every module resolves through the standard import path; no manual sys.path injection is needed.

Port notes
----------
An HF Docker Space always uses port 7860 (see Dockerfile / README).
Local debugging: streamlit run app.py
"""
from __future__ import annotations
import random
import time
from pathlib import Path
from typing import Optional
import numpy as np
import plotly.graph_objects as go
import streamlit as st
import torch
import yaml
# -- maze_env package (installed; imported directly) -------------------------
from maze_env import MazeEnv
from maze_env.bfs import bfs as bfs_solve
from maze_env.actions import DELTAS
# -- src package (importable directly after `pip install -e .`) --------------
import torch.nn as nn
from src.model import DQNNetwork, DuelingDQNNetwork
# ===========================================================================
# Constants & configuration
# ===========================================================================
# Optional runtime configuration: a config.yaml sitting next to this file.
_CONFIG_PATH = Path(__file__).parent / "config.yaml"
if _CONFIG_PATH.exists():
    # An empty (or comment-only) YAML document parses to None; normalise it
    # to an empty mapping so the .get() lookups below cannot fail.
    _cfg = yaml.safe_load(_CONFIG_PATH.read_text(encoding="utf-8")) or {}
else:
    import warnings

    warnings.warn(
        f"config.yaml ๆœชๆ‰พๅˆฐ๏ผˆ{_CONFIG_PATH}๏ผ‰๏ผŒไฝฟ็”จๅ†…็ฝฎ้ป˜่ฎคๅ€ผใ€‚"
        "่‹ฅ่ฎญ็ปƒๆ—ถไฝฟ็”จไบ†้ž้ป˜่ฎค grid_size๏ผŒๆŽจ็†็ป“ๆžœๅฏ่ƒฝ้”™่ฏฏใ€‚",
        stacklevel=1,
    )
    _cfg = {}

# A "maze:" key with no value also parses to None, so fall back to {} here too.
_maze_cfg = _cfg.get("maze") or {}
GRID_SIZE = int(_maze_cfg.get("grid_size", 10))
# Read from config.yaml maze.obstacle_density so the demo uses the configured density.
OBSTACLE_DENSITY = float(_maze_cfg.get("obstacle_density", 0.25))
# Step budget handed to the inference environment; read from config.yaml maze.max_steps.
MAX_STEPS = int(_maze_cfg.get("max_steps", 200))
# The four switchable algorithm keys; the list order is the order of the UI dropdown.
ALGO_OPTIONS: list[str] = ["double_dueling", "dueling", "double", "vanilla"]
# Display label for each algorithm key (used as the base of the dropdown caption).
ALGO_LABELS: dict[str, str] = {
    "vanilla": "Vanilla DQN๏ผˆๅŸบๅ‡†๏ผ‰",
    "double": "Double DQN๏ผˆๆŠ‘ๅˆถ้ซ˜ไผฐ๏ผ‰",
    "dueling": "Dueling DQN๏ผˆV+A ๅˆ†่งฃ๏ผ‰",
    "double_dueling": "Double + Dueling๏ผˆV+A + ๆŠ‘ๅˆถ้ซ˜ไผฐ๏ผ‰",
}
# Success rate in percent per algorithm. Per the original author's note these come
# from an independent holdout test set, not from the eval_success tracked during training.
# A value of None means "no rate available" and suppresses the suffix in the caption.
ALGO_SUCCESS_RATES: dict[str, Optional[float]] = {
    "vanilla": 75.0,
    "double": 78.0,
    "dueling": 84.0,
    "double_dueling": 81.0,
}
def algo_display_label(algo: str) -> str:
    """Build the dropdown caption for ``algo``.

    The caption is the ``ALGO_LABELS`` entry for the algorithm. When
    ``ALGO_SUCCESS_RATES`` holds a non-``None`` rate for it, the rate is
    appended as `` | holdout <rate>%`` (rounded to a whole percent).

    Raises:
        KeyError: if ``algo`` has no entry in ``ALGO_LABELS``.
    """
    label = ALGO_LABELS[algo]
    holdout_rate = ALGO_SUCCESS_RATES.get(algo)
    if holdout_rate is None:
        return label
    return f"{label} | holdout {holdout_rate:.0f}%"
# Default algorithm: prefer dqn.algorithm from config.yaml, fall back to double_dueling.
# A "dqn:" key with no value parses to None, so coerce it to a mapping before the lookup.
_default_algo = str((_cfg.get("dqn") or {}).get("algorithm", "double_dueling")).strip().lower()
# Unknown names in the config are ignored rather than propagated to the UI.
DEFAULT_ALGO: str = _default_algo if _default_algo in ALGO_OPTIONS else "double_dueling"
def model_path_for(algo: str) -> Path:
    """Return the weight-file path for ``algo``.

    The file is ``best_model_train_<algo>.pth`` inside the ``results``
    directory that sits next to this source file. The path is only built,
    not checked for existence.
    """
    weights_dir = Path(__file__).parent / "results"
    filename = f"best_model_train_{algo}.pth"
    return weights_dir / filename
# ้ฆ–ๅฑ้ป˜่ฎค่ฟทๅฎซ seedใ€‚
# ๅ›บๅฎšๅ€ผไฟ่ฏๅˆ†ไบซ้“พๆŽฅๆ—ถๅŒๆ–น็œ‹ๅˆฐ็›ธๅŒๅœฐๅ›พ๏ผ›ๆ”นไธบ None ๅฏ่ฎฉๆฏๆฌกๅˆทๆ–ฐ้šๆœบ็”Ÿๆˆใ€‚
DEFAULT_MAZE_SEED: int = 42
# ๅŠจ็”ปๅธง้—ด้š”๏ผˆ็ง’๏ผ‰
ANIM_DELAY = 0.08
# ้ขœ่‰ฒๆ˜ ๅฐ„๏ผˆRGB ๅˆ—่กจ๏ผŒไพ› Plotly heatmap๏ผ‰
COLOR_EMPTY = "#F8F9FA" # ็™ฝ/ๆต…็ฐ โ€”โ€” ๅฏ้€š่กŒๅœฐๆฟ
COLOR_WALL = "#2C3E50" # ๆทฑ่“็ฐ โ€”โ€” ๅข™ๅฃ
COLOR_START = "#27AE60" # ็ปฟ่‰ฒ โ€”โ€” ่ตท็‚น
COLOR_GOAL = "#E74C3C" # ็บข่‰ฒ โ€”โ€” ็ปˆ็‚น
COLOR_DQN_PATH = "#3498DB" # ่“่‰ฒ โ€”โ€” DQN ่ฝจ่ฟน
COLOR_BFS_PATH = "#F39C12" # ๆฉ™่‰ฒ โ€”โ€” BFS ๆœ€็Ÿญ่ทฏ
COLOR_AGENT = "#9B59B6" # ็ดซ่‰ฒ โ€”โ€” ๅฝ“ๅ‰ Agent ไฝ็ฝฎ
# ===========================================================================
# Utility functions
# ===========================================================================
def generate_maze(seed: Optional[int] = None) -> np.ndarray:
    """Generate a maze wall map by resetting a fresh :class:`MazeEnv`.

    The environment is built with the module-level ``GRID_SIZE`` and
    ``OBSTACLE_DENSITY`` and reset once with ``seed``; its ``wall_map`` is
    returned. Map layout and any connectivity guarantee come from
    ``MazeEnv.reset()`` itself (the original notes state that (1, 1) and
    (N-2, N-2) are kept reachable -- verify against ``maze_env``).

    Args:
        seed: Seed forwarded to ``MazeEnv.reset``; ``None`` leaves the
            randomness unseeded.

    Returns:
        The environment's ``wall_map`` cast to ``int32``
        (per the original docs: 0 = open cell, 1 = wall).
    """
    maze_env = MazeEnv(grid_size=GRID_SIZE, obstacle_density=OBSTACLE_DENSITY)
    maze_env.reset(seed=seed)
    walls = maze_env.wall_map
    return walls.astype(np.int32)
def generate_maze_with_random_sg(
    seed: Optional[int] = None,
) -> tuple[np.ndarray, tuple[int, int], tuple[int, int]]:
    """Generate a maze and draw a random start and goal from its open interior cells.

    A fresh :class:`MazeEnv` is reset with ``seed``; two distinct open cells
    that are not on the outer border are then drawn with the environment's
    own ``np_random`` generator, so the whole result is reproducible from
    ``seed``. Per the original author's note this is meant to mirror the
    ``random_start_goal=True`` logic of train.py (not visible from here).

    Args:
        seed: Seed forwarded to ``MazeEnv.reset``; ``None`` leaves the
            randomness unseeded.

    Returns:
        ``(wall_map, start, goal)`` where ``wall_map`` is the environment's
        wall map cast to ``int32`` and ``start`` / ``goal`` are
        ``(row, col)`` tuples of Python ints.
    """
    maze_env = MazeEnv(grid_size=GRID_SIZE, obstacle_density=OBSTACLE_DENSITY)
    maze_env.reset(seed=seed)
    wall_map = maze_env.wall_map.astype(np.int32)

    # Open cells in row-major order, restricted to the interior (border excluded).
    open_rows, open_cols = np.where(wall_map == 0)
    last = GRID_SIZE - 1
    inner_cells: list[tuple[int, int]] = []
    for row, col in zip(open_rows, open_cols):
        if 0 < row < last and 0 < col < last:
            inner_cells.append((int(row), int(col)))

    # Degenerate map (fewer than two open interior cells): use the fixed corners.
    if len(inner_cells) < 2:
        return wall_map, (1, 1), (GRID_SIZE - 2, GRID_SIZE - 2)

    # A single choice(replace=False) call yields two different indices,
    # so no retry loop is needed to keep start and goal apart.
    picked = maze_env.np_random.choice(len(inner_cells), size=2, replace=False)
    start_cell = inner_cells[int(picked[0])]
    goal_cell = inner_cells[int(picked[1])]
    return wall_map, start_cell, goal_cell
def load_model(algo: str = DEFAULT_ALGO, grid_size: int = GRID_SIZE) -> tuple[Optional[nn.Module], int]:
    """Load the DQN checkpoint for ``algo`` and return ``(net, saved_grid_size)``.

    The checkpoint is read from ``model_path_for(algo)`` on the CPU with
    ``weights_only=True``. The network class is chosen from the checkpoint's
    ``"algorithm"`` entry (any name containing ``"dueling"`` selects
    ``DuelingDQNNetwork``); when the checkpoint has no such entry the
    requested ``algo`` is used instead. The number of input channels is read
    from the shape of ``state_dict["conv.0.weight"]``.

    Args:
        algo: Algorithm key; expected to be one of ``ALGO_OPTIONS``.
        grid_size: Current environment grid size. Used as the returned size
            when loading fails and as the default when the checkpoint does
            not record its own ``"grid_size"``.

    Returns:
        ``(net, saved_grid_size)`` with ``net`` in eval mode, or
        ``(None, grid_size)`` when the file is missing or cannot be loaded.
        Callers compare ``saved_grid_size`` with ``GRID_SIZE`` to detect a
        checkpoint trained on a differently sized maze.
    """
    path = model_path_for(algo)
    if not path.exists():
        return None, grid_size
    try:
        ckpt = torch.load(path, map_location="cpu", weights_only=True)
        saved_gs = ckpt.get("grid_size", grid_size)
        # Fall back to the algorithm the caller asked for. Defaulting to
        # "vanilla" would build the wrong architecture for a dueling
        # checkpoint that carries no metadata and make load_state_dict fail.
        algorithm = str(ckpt.get("algorithm") or algo).strip().lower()
        NetClass = DuelingDQNNetwork if "dueling" in algorithm else DQNNetwork
        in_ch = ckpt["state_dict"]["conv.0.weight"].shape[1]
        net = NetClass(grid_size=saved_gs, input_channels=in_ch)
        net.load_state_dict(ckpt["state_dict"])
        net.eval()
        return net, saved_gs
    except Exception as e:
        # UI boundary: any load failure is shown to the user instead of
        # crashing the app; the caller then sees "no model".
        st.error(f"โŒ ๆจกๅž‹ๅŠ ่ฝฝๅคฑ่ดฅ๏ผš{e}")
        return None, grid_size
def dqn_rollout(
    net: nn.Module,
    wall_map: np.ndarray,
    start: tuple,
    goal: tuple,
) -> list[tuple]:
    """Run the agent greedily (argmax, no exploration) and return its trajectory.

    A fresh :class:`MazeEnv` sized from ``wall_map.shape[0]`` is reset with
    the supplied map, start and goal via ``options`` and stepped until it
    reports ``terminated`` or ``truncated``. Collision handling and
    observation encoding are left to the environment.

    NOTE(review): the nesting of this function was reconstructed from a
    flattened listing; the neighbour-penalty loop is assumed to sit inside
    ``if cnt >= 2`` -- confirm against the original file.

    Args:
        net: Network with loaded weights; called as ``net(obs_batch)`` under
            ``torch.no_grad()``.
        wall_map: Square wall map (per the original docs 0 = open, 1 = wall);
            cast to ``float32`` before being handed to the environment.
        start: Agent start ``(row, col)``.
        goal: Goal ``(row, col)``.

    Returns:
        Visited positions including the start. A step flagged
        ``info["hit_wall"]`` does not add an entry, so the list length is not
        necessarily the number of environment steps taken.
    """
    env = MazeEnv(
        grid_size=wall_map.shape[0],
        obstacle_density=0.0,  # density is irrelevant: the map is injected from outside
        max_steps=MAX_STEPS,
    )
    obs, _ = env.reset(options={
        "wall_map": wall_map.astype(np.float32),
        "start": start,
        "goal": goal,
    })
    path = [env.agent_pos]
    # Inference-side anti-loop safety net. Per the original author's note the
    # visited-map observation channel already lets the Q-function account for
    # visit history, but on poorly covered states the agent can still bounce
    # between two cells. Once a cell has been visited >= 2 times, the current
    # argmax action gets a Q-value penalty that grows with the visit count.
    # Only the local q_values copy is changed; network weights are untouched.
    visited_count: dict[tuple, int] = {}
    while True:
        s = torch.from_numpy(obs).unsqueeze(0)  # add a leading batch dimension
        with torch.no_grad():
            q_values = net(s)[0].clone()  # one Q-value per action; cloned so it can be edited
        # Penalise the currently best action on cells that are revisited often.
        cur_pos = env.agent_pos
        cnt = visited_count.get(cur_pos, 0)
        if cnt >= 2:
            action_candidate = int(q_values.argmax().item())
            q_values[action_candidate] -= 3.0 * cnt
            # Look one move ahead for every action: if the target cell has
            # itself been visited often, penalise that action as well.
            cur_r, cur_c = cur_pos
            N = env.grid_size
            for a, (dr, dc) in enumerate(DELTAS):
                nr, nc = cur_r + dr, cur_c + dc
                if 0 <= nr < N and 0 <= nc < N:
                    next_cnt = visited_count.get((nr, nc), 0)
                    if next_cnt >= 2:
                        q_values[a] -= 3.0 * next_cnt
        action = int(q_values.argmax().item())
        visited_count[cur_pos] = cnt + 1
        obs, _reward, terminated, truncated, info = env.step(action)
        # Append only on an actual move: after a wall hit the position is
        # unchanged and a duplicate coordinate would make the animation stutter.
        if not info["hit_wall"]:
            path.append(env.agent_pos)
        if terminated or truncated:
            break
    return path
# ===========================================================================
# Plotly maze drawing
# ===========================================================================
def build_maze_figure(
    wall_map: np.ndarray,
    start: tuple,
    goal: tuple,
    dqn_path: Optional[list] = None,
    bfs_path: Optional[list] = None,
    agent_pos: Optional[tuple] = None,
    highlight_dqn_step: int = -1,
) -> go.Figure:
    """Draw the maze as a Plotly figure with optional path overlays.

    Layers, in drawing order: a heatmap of the cells, the BFS path, the DQN
    path (optionally truncated for animation), the agent marker and the
    "S"/"G" text labels.

    Args:
        wall_map: Square wall map; start and goal cells are recoloured on a copy.
        start: Start ``(row, col)``.
        goal: Goal ``(row, col)``.
        dqn_path: DQN trajectory as ``(row, col)`` entries; drawn when it has
            more than one point.
        bfs_path: BFS path as ``(row, col)`` entries; drawn when it has more
            than one point.
        agent_pos: Explicit agent marker position; when falsy the marker is
            placed from ``dqn_path`` / ``highlight_dqn_step`` or at ``start``.
        highlight_dqn_step: When >= 0, only ``dqn_path[:step + 1]`` is drawn
            and the agent marker sits on that step (clamped to the last
            point); a negative value draws the whole path.

    Returns:
        The assembled ``go.Figure``.
    """
    size = wall_map.shape[0]

    # -- Base layer: a single heatmap trace instead of one shape per cell ----
    # Cell codes: 0 = open, 1 = wall, 2 = start, 3 = goal.
    cell_codes = wall_map.astype(float).copy()
    cell_codes[start[0], start[1]] = 2.0
    cell_codes[goal[0], goal[1]] = 3.0

    # Discrete colour scale: four equal bands, one flat colour per cell code.
    band_colors = (COLOR_EMPTY, COLOR_WALL, COLOR_START, COLOR_GOAL)
    colorscale = []
    for band, color in enumerate(band_colors):
        colorscale.append([band * 0.25, color])
        colorscale.append([(band + 1) * 0.25, color])

    fig = go.Figure()
    grid_trace = go.Heatmap(
        z=cell_codes,
        colorscale=colorscale,
        zmin=0,
        zmax=3,
        showscale=False,
        xgap=1,
        ygap=1,
        hoverinfo="skip",
    )
    fig.add_trace(grid_trace)

    # -- BFS path (orange dotted line) ---------------------------------------
    if bfs_path and len(bfs_path) > 1:
        bfs_cols = [col for _row, col in bfs_path]
        bfs_rows = [row for row, _col in bfs_path]
        bfs_trace = go.Scatter(
            x=bfs_cols,
            y=bfs_rows,
            mode="lines+markers",
            name="BFS ๆœ€็Ÿญ่ทฏ",
            line={"color": COLOR_BFS_PATH, "width": 3, "dash": "dot"},
            marker={"size": 6, "color": COLOR_BFS_PATH, "opacity": 0.7},
        )
        fig.add_trace(bfs_trace)

    # -- DQN path (blue solid line) ------------------------------------------
    if dqn_path and len(dqn_path) > 1:
        # For animation frames only the prefix up to highlight_dqn_step is shown.
        if highlight_dqn_step >= 0:
            shown = dqn_path[:highlight_dqn_step + 1]
        else:
            shown = dqn_path[:len(dqn_path)]
        dqn_cols = [col for _row, col in shown]
        dqn_rows = [row for row, _col in shown]
        dqn_trace = go.Scatter(
            x=dqn_cols,
            y=dqn_rows,
            mode="lines+markers",
            name="DQN ่ฝจ่ฟน",
            line={"color": COLOR_DQN_PATH, "width": 3},
            marker={"size": 7, "color": COLOR_DQN_PATH},
        )
        fig.add_trace(dqn_trace)

    # -- Current agent position (large purple dot) ---------------------------
    if agent_pos:
        marker_pos = agent_pos
    elif dqn_path and highlight_dqn_step >= 0:
        marker_pos = dqn_path[min(highlight_dqn_step, len(dqn_path) - 1)]
    else:
        marker_pos = start
    agent_trace = go.Scatter(
        x=[marker_pos[1]],
        y=[marker_pos[0]],
        mode="markers",
        name="Agent",
        marker={
            "size": 16,
            "color": COLOR_AGENT,
            "symbol": "circle",
            "line": {"color": "white", "width": 2},
        },
        showlegend=True,
    )
    fig.add_trace(agent_trace)

    # -- Start / goal text labels --------------------------------------------
    label_trace = go.Scatter(
        x=[start[1], goal[1]],
        y=[start[0], goal[0]],
        mode="markers+text",
        text=["S", "G"],
        textposition="middle center",
        textfont={"size": 13, "color": "white", "family": "Arial Black"},
        # Fully transparent markers: only the letters are visible.
        marker={
            "size": 22,
            "color": [COLOR_START, COLOR_GOAL],
            "symbol": "square",
            "opacity": 0.0,
        },
        showlegend=False,
        hoverinfo="skip",
    )
    fig.add_trace(label_trace)

    # -- Layout ---------------------------------------------------------------
    ticks = list(range(size))
    fig.update_layout(
        width=560,
        height=560,
        margin={"l": 10, "r": 10, "t": 30, "b": 10},
        xaxis={
            "range": [-0.5, size - 0.5],
            "tickvals": ticks,
            "showgrid": False,
            "zeroline": False,
            "title": "ๅˆ— (col)",
        },
        # Reversed y range so that row 0 is drawn at the top.
        yaxis={
            "range": [size - 0.5, -0.5],
            "tickvals": list(range(size)),
            "showgrid": False,
            "zeroline": False,
            "title": "่กŒ (row)",
        },
        legend={
            "x": 1.01,
            "y": 1,
            "bgcolor": "rgba(255,255,255,0.8)",
            "bordercolor": "#BDC3C7",
            "borderwidth": 1,
        },
        paper_bgcolor="white",
        plot_bgcolor="white",
        title={"text": "๐Ÿ DQN ่ฟทๅฎซๅฏป่ทฏ", "x": 0.5, "font": {"size": 16}},
    )
    return fig
def _find_cell_index(free_cells: list[tuple], pos: tuple) -> int:
"""ๅœจ free_cells ๅˆ—่กจไธญๆŸฅๆ‰พ pos ็š„็ดขๅผ•๏ผ›ๆœชๆ‰พๅˆฐๆ—ถ่ฟ”ๅ›ž 0๏ผˆๅฎ‰ๅ…จๅ›ž้€€๏ผ‰ใ€‚"""
try:
return free_cells.index(pos)
except ValueError:
return 0
# ===========================================================================
# Session-state initialisation
# ===========================================================================
def _init_state() -> None:
    """Fill ``st.session_state`` with a default for every key that is still unset.

    Existing values are never overwritten, so this is safe to call on every
    script run. The first call generates the initial maze from
    ``DEFAULT_MAZE_SEED`` and loads the model for ``DEFAULT_ALGO``.
    """
    state = st.session_state

    def _ensure(key: str, default) -> None:
        # Assign only when the key has not been set by an earlier run.
        if key not in state:
            state[key] = default

    if "wall_map" not in state:
        # First render: random start/goal drawn from a fixed seed, so the
        # initial screen is reproducible.
        maze, first_start, first_goal = generate_maze_with_random_sg(seed=DEFAULT_MAZE_SEED)
        state.wall_map = maze
        state.start = first_start
        state.goal = first_goal

    _ensure("start", (1, 1))
    _ensure("goal", (GRID_SIZE - 2, GRID_SIZE - 2))
    _ensure("dqn_path", None)
    _ensure("bfs_path", None)
    _ensure("metrics", None)
    _ensure("selected_algo", DEFAULT_ALGO)

    if "model" not in state:
        # The model is loaded lazily, once per session.
        network, checkpoint_gs = load_model(algo=DEFAULT_ALGO)
        state.model = network
        state.model_grid_size = checkpoint_gs

    _ensure("maze_seed", DEFAULT_MAZE_SEED)
    _ensure("anim_running", False)
    _ensure("anim_step", 0)
    _ensure("anim_path", None)
# ===========================================================================
# Main program
# ===========================================================================
def _regenerate_maze(seed: int) -> None:
    """Build the maze for ``seed`` and reset every session value derived from the old map."""
    wm, sg_start, sg_goal = generate_maze_with_random_sg(seed=seed)
    state = st.session_state
    state.maze_seed = seed
    state.wall_map = wm
    state.start = sg_start
    state.goal = sg_goal
    state.dqn_path = None
    state.bfs_path = None
    state.metrics = None
    # Keep the start/goal dropdown indices in step with the new map, so the
    # keyed selectboxes do not keep a stale index from the previous map.
    _fc = [
        (r, c)
        for r in range(1, GRID_SIZE - 1)
        for c in range(1, GRID_SIZE - 1)
        if wm[r, c] == 0
    ]
    state.start_select = _find_cell_index(_fc, sg_start)
    state.goal_select = _find_cell_index(_fc, sg_goal)
    state.anim_running = False


def main() -> None:
    """Render the Streamlit page: control panel on the left, maze canvas on the right.

    The left column handles maze generation (seed input / random button),
    start and goal selection, algorithm selection and the two path-finding
    buttons, and shows the model status. The right column runs BFS / DQN on
    request, drives the frame-by-frame animation through ``st.session_state``
    and ``st.rerun()``, draws the maze and shows the metric cards.
    """
    st.set_page_config(
        page_title="DQN ่ฟทๅฎซๅฏป่ทฏ Demo",
        page_icon="๐Ÿค–",
        layout="wide",
    )
    # -- Global CSS injection --------------------------------------------------
    st.markdown("""
<style>
.metric-card {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border-radius: 12px; padding: 16px 20px; color: white;
text-align: center; margin: 6px 0;
}
.metric-label { font-size: 13px; opacity: 0.85; margin-bottom: 4px; }
.metric-value { font-size: 28px; font-weight: 700; }
.por-perfect { color: #2ECC71; font-weight: 800; }
.por-good { color: #F39C12; font-weight: 700; }
.por-bad { color: #E74C3C; font-weight: 600; }
div[data-testid="stButton"] button {
width: 100%; border-radius: 8px; font-weight: 600;
}
/* ่ฟทๅฎซๆŒ‰้’ฎ็ฝ‘ๆ ผ๏ผšๆฏๆ ผ็ดงๅ‡‘ๆญฃๆ–นๅฝข๏ผŒๆ— ๅ†…่พน่ท */
div[data-testid="stHorizontalBlock"] div[data-testid="stButton"] button {
padding: 0 !important;
min-height: 40px !important;
font-size: 15px !important;
border-radius: 3px !important;
border: 1px solid #ccc !important;
line-height: 1 !important;
}
</style>
""", unsafe_allow_html=True)
    _init_state()
    st.title("๐Ÿค– DQN ่ฟทๅฎซๅฏป่ทฏ ยท ๅฏ่ง†ๅŒ– Demo")
    st.caption("Deep Q-Network ร— BFS Ground-Truth ยท Hugging Face Spaces")

    # Two-column layout: narrow control panel, wide canvas.
    left_col, right_col = st.columns([1, 2.2], gap="large")

    # -------------------------------------------------------------------------
    # Left column: control panel
    # -------------------------------------------------------------------------
    with left_col:
        st.subheader("โš™๏ธ ๆŽงๅˆถ้ขๆฟ")

        # -- Maze generation ---------------------------------------------------
        st.markdown("**โ‘  ่ฟทๅฎซๅœฐๅ›พ**")
        col_seed, col_rand = st.columns([3, 1])
        with col_seed:
            input_seed = st.number_input(
                "่ฟทๅฎซ Seed",
                min_value=0,
                max_value=999999,
                value=st.session_state.maze_seed,
                step=1,
                help="ๅ›บๅฎšๆ•ฐๅญ—ๅฏๅค็ŽฐๆŒ‡ๅฎšๅœฐๅ›พ๏ผ›็‚นๅ‡ปๅณไพงๆŒ‰้’ฎ้šๆœบ็”Ÿๆˆๆ–ฐๅœฐๅ›พ",
            )
        with col_rand:
            st.write("")  # spacer so the button lines up with the input
            if st.button("๐ŸŽฒ ้šๆœบ"):
                # Random seed: regenerates both the map and the start/goal pair.
                _regenerate_maze(random.randint(0, 999999))
                # rerun() ends this script run, so the manual-seed check
                # below is not evaluated for the same click.
                st.rerun()

        # Seed typed in by hand (the random button has already short-circuited above).
        if input_seed != st.session_state.maze_seed:
            _regenerate_maze(input_seed)
            st.rerun()

        st.divider()

        # -- Start / goal selection --------------------------------------------
        st.markdown("**โ‘ก ่ตท็‚น & ็ปˆ็‚น**")
        # "Random start/goal" button: draw two distinct open interior cells of
        # the current map.
        if st.button("๐ŸŽฒ ้šๆœบ่ตท็ปˆ็‚น", use_container_width=True,
                     help="ไปŽๅฝ“ๅ‰ๅœฐๅ›พๅฏ้€š่กŒๆ ผ้šๆœบ้€‰ๅ–่ตท็‚นๅ’Œ็ปˆ็‚น๏ผŒไธŽ่ฎญ็ปƒๅˆ†ๅธƒๅฎŒๅ…จไธ€่‡ด"):
            _wm = st.session_state.wall_map
            _rows, _cols = np.where(_wm == 0)
            _inner = [
                (int(r), int(c))
                for r, c in zip(_rows, _cols)
                if 0 < r < GRID_SIZE - 1 and 0 < c < GRID_SIZE - 1
            ]
            if len(_inner) >= 2:
                _i, _j = random.sample(range(len(_inner)), 2)
                st.session_state.start = _inner[_i]
                st.session_state.goal = _inner[_j]
                st.session_state.dqn_path = None
                st.session_state.bfs_path = None
                st.session_state.metrics = None
                # _inner holds each cell once, in the same row-major order as
                # the dropdown options, so the sampled positions are the indices.
                st.session_state.start_select = _i
                st.session_state.goal_select = _j
                st.session_state.anim_running = False
                st.rerun()

        N = GRID_SIZE
        free_cells = [
            (r, c)
            for r in range(1, N - 1)
            for c in range(1, N - 1)
            if st.session_state.wall_map[r, c] == 0
        ]
        cell_labels = [f"({r},{c})" for r, c in free_cells]
        start_idx = st.selectbox(
            "่ตท็‚น (row, col)",
            options=range(len(free_cells)),
            format_func=lambda i: cell_labels[i],
            index=_find_cell_index(free_cells, st.session_state.start),
            key="start_select",
        )
        goal_idx = st.selectbox(
            "็ปˆ็‚น (row, col)",
            options=range(len(free_cells)),
            format_func=lambda i: cell_labels[i],
            index=_find_cell_index(free_cells, st.session_state.goal),
            key="goal_select",
        )
        new_start = free_cells[start_idx]
        new_goal = free_cells[goal_idx]
        if new_start == new_goal:
            st.warning("โš ๏ธ ่ตท็‚นไธŽ็ปˆ็‚นไธ่ƒฝ็›ธๅŒ๏ผŒ่ฏท้‡ๆ–ฐ้€‰ๆ‹ฉใ€‚")
        elif new_start != st.session_state.start or new_goal != st.session_state.goal:
            st.session_state.start = new_start
            st.session_state.goal = new_goal
            st.session_state.dqn_path = None
            st.session_state.bfs_path = None
            st.session_state.metrics = None
            # Stop a running animation: its path belongs to the previous
            # start/goal, and its final frame reads the metrics cleared above.
            st.session_state.anim_running = False

        st.divider()

        # -- Algorithm selection & path-finding buttons ------------------------
        st.markdown("**โ‘ข ๅฏป่ทฏ็ฎ—ๆณ•**")
        selected_algo = st.selectbox(
            "DQN ็ฎ—ๆณ•ๅ˜ไฝ“",
            options=ALGO_OPTIONS,
            format_func=algo_display_label,
            index=ALGO_OPTIONS.index(st.session_state.selected_algo),
            key="algo_select",
            help="ๅˆ‡ๆข็ฎ—ๆณ•ๅŽ็‚นๅ‡ปใ€ŒDQN ๅฏป่ทฏใ€ๆŒ‰้’ฎๅฏๅฏนๆฏ”ไธๅŒ็ฎ—ๆณ•ๅœจๅŒไธ€ๅœฐๅ›พไธŠ็š„่ทฏๅพ„",
        )
        # On an algorithm switch, load the matching model and drop the previous
        # DQN result (the BFS path is independent of the model and is kept).
        if selected_algo != st.session_state.selected_algo:
            st.session_state.selected_algo = selected_algo
            net, saved_gs = load_model(algo=selected_algo)
            st.session_state.model = net
            st.session_state.model_grid_size = saved_gs
            st.session_state.dqn_path = None
            st.session_state.metrics = None
            st.session_state.anim_running = False
            st.rerun()

        run_dqn = st.button(
            "๐Ÿค– DQN ๆ™บ่ƒฝไฝ“ๅฏป่ทฏ",
            use_container_width=True,
            type="primary",
        )
        run_bfs = st.button(
            "๐Ÿ“ BFS ไธ“ๅฎถๅฏป่ทฏ",
            use_container_width=True,
        )

        st.divider()

        # -- Legend -------------------------------------------------------------
        st.markdown("**ๅ›พไพ‹**")
        legend_html = """
<div style='font-size:13px; line-height:2'>
๐ŸŸฉ <b>S</b> ่ตท็‚น &nbsp;&nbsp;
๐ŸŸฅ <b>G</b> ็ปˆ็‚น<br>
โฌ› ๅข™ๅฃ &nbsp;&nbsp;
โฌœ ้€š่ทฏ<br>
๐Ÿ”ต DQN ่ฝจ่ฟน &nbsp;&nbsp;
๐ŸŸ  BFS ๆœ€็Ÿญ่ทฏ<br>
๐ŸŸฃ Agent ๅฝ“ๅ‰ไฝ็ฝฎ
</div>
"""
        st.markdown(legend_html, unsafe_allow_html=True)

        # -- Model status -------------------------------------------------------
        st.divider()
        _cur_algo = st.session_state.get("selected_algo", DEFAULT_ALGO)
        _cur_path = model_path_for(_cur_algo)
        if st.session_state.model is not None:
            st.success(f"โœ… ๆจกๅž‹ๅทฒๅŠ ่ฝฝ ({_cur_path.name})")
            # Warn early about a grid-size mismatch: the network was built for
            # the checkpoint's grid size while the environment produces
            # GRID_SIZE-sized observations, so inference would fail on shape.
            _saved_gs = st.session_state.get("model_grid_size", GRID_SIZE)
            if _saved_gs != GRID_SIZE:
                st.warning(
                    f"โš ๏ธ ๆจกๅž‹่ฎญ็ปƒไบŽ {_saved_gs}ร—{_saved_gs} ่ฟทๅฎซ๏ผŒ"
                    f"ๅฝ“ๅ‰้…็ฝฎไธบ {GRID_SIZE}ร—{GRID_SIZE}ใ€‚\n"
                    "ๆŽจ็†ๆ—ถ่พ“ๅ…ฅ็ปดๅบฆไธๅŒน้…๏ผŒๅฐ†ๅฏผ่‡ด่ฟ่กŒๆ—ถ้”™่ฏฏใ€‚\n"
                    "่ฏทไฝฟ็”จๅŒน้… grid_size ็š„ๆจกๅž‹๏ผŒๆˆ–ๆ›ดๆ–ฐ config.yamlใ€‚"
                )
        else:
            st.error(f"โŒ ๆœชๆ‰พๅˆฐ {_cur_path.name}")
            st.info(f"่ฏทๅ…ˆ่ฟ่กŒ `python src/train.py --algorithm {_cur_algo}` ่ฎญ็ปƒๆจกๅž‹ใ€‚")

    # -------------------------------------------------------------------------
    # Right column: main canvas
    # -------------------------------------------------------------------------
    with right_col:
        wall_map = st.session_state.wall_map
        start = st.session_state.start
        goal = st.session_state.goal
        status_placeholder = st.empty()

        # -- BFS path-finding ---------------------------------------------------
        if run_bfs:
            result = bfs_solve(wall_map.astype(np.int32), start, goal)
            if result["success"]:
                st.session_state.bfs_path = result["path"]
                status_placeholder.success(
                    f"โœ… BFS ๅฎŒๆˆ๏ผๆœ€็Ÿญๆญฅๆ•ฐ = **{result['steps']}**๏ผŒ"
                    f"่€—ๆ—ถ {result['execution_time_ms']:.3f} ms"
                )
            else:
                st.session_state.bfs_path = None
                status_placeholder.error("โŒ BFS๏ผš่ตท็‚นไธŽ็ปˆ็‚นไน‹้—ดๆ— ๅฏ่พพ่ทฏๅพ„๏ผ")

        # -- DQN button ---------------------------------------------------------
        if run_dqn:
            model = st.session_state.model
            if model is None:
                status_placeholder.error("โŒ ๆจกๅž‹ๆœชๅŠ ่ฝฝ๏ผŒๆ— ๆณ•ๆŽจ็†ใ€‚")
            elif st.session_state.get("model_grid_size", GRID_SIZE) != GRID_SIZE:
                _mgs = st.session_state.model_grid_size
                status_placeholder.error(
                    f"โŒ ๆจกๅž‹่ฎญ็ปƒไบŽ {_mgs}ร—{_mgs}๏ผŒๅฝ“ๅ‰ไธบ {GRID_SIZE}ร—{GRID_SIZE}๏ผŒ็ปดๅบฆไธๅŒน้…ใ€‚"
                )
            else:
                # BFS runs first: it supplies the reference step count and
                # rejects unsolvable start/goal pairs before any inference.
                bfs_result = bfs_solve(wall_map.astype(np.int32), start, goal)
                if not bfs_result["success"]:
                    status_placeholder.error("โŒ ่ฏฅ่ฟทๅฎซ้…็ฝฎๆ— ่งฃ๏ผŒ่ฏทๆข่ตท็ปˆ็‚นใ€‚")
                else:
                    with st.spinner("๐Ÿค– DQN ๆŽจ็†ไธญโ€ฆ"):
                        dqn_path = dqn_rollout(model, wall_map, start, goal)
                    ai_steps = len(dqn_path) - 1
                    bfs_steps = bfs_result["steps"]
                    success = (dqn_path[-1] == goal)
                    # Path optimality ratio: BFS steps / AI steps, 0.0 on failure.
                    por = round(bfs_steps / ai_steps, 4) if (success and ai_steps > 0) else 0.0
                    st.session_state.dqn_path = dqn_path
                    st.session_state.bfs_path = bfs_result["path"]
                    st.session_state.metrics = {
                        "ai_steps": ai_steps, "bfs_steps": bfs_steps,
                        "success": success, "por": por,
                    }
                    # Start the frame animation.
                    st.session_state.anim_running = True
                    st.session_state.anim_step = 0
                    st.session_state.anim_path = dqn_path
                    st.rerun()

        # -- Animation driver (one frame per script run) -------------------------
        if st.session_state.anim_running:
            step_i = st.session_state.anim_step
            anim_p = st.session_state.anim_path
            total = len(anim_p)
            status_placeholder.info(f"๐ŸŽฌ ๅŠจ็”ปๆ’ญๆ”พไธญโ€ฆ {step_i + 1}/{total}")
            fig = build_maze_figure(
                wall_map, start, goal,
                dqn_path=anim_p,
                bfs_path=st.session_state.bfs_path,
                highlight_dqn_step=step_i,
            )
            st.plotly_chart(fig, use_container_width=False, key=f"anim_{step_i}")
            if step_i + 1 < total:
                time.sleep(ANIM_DELAY)
                st.session_state.anim_step += 1
                st.rerun()
            else:
                st.session_state.anim_running = False
                m = st.session_state.metrics
                ok = m["success"]
                # Report a failed run as an error box rather than a success box.
                notify = status_placeholder.success if ok else status_placeholder.error
                notify(
                    f"{'โœ…' if ok else 'โŒ'} DQN ๅฏป่ทฏ{'ๆˆๅŠŸ' if ok else 'ๅคฑ่ดฅ'}๏ผ"
                    f" AI ๆญฅๆ•ฐ = **{m['ai_steps']}** | BFS ๆœ€็Ÿญ = **{m['bfs_steps']}**"
                )
        # -- Static maze figure ---------------------------------------------------
        else:
            # Always draw the maze when no animation is running, including
            # after a DQN click that ended in an error message.
            fig = build_maze_figure(
                wall_map, start, goal,
                dqn_path=st.session_state.dqn_path,
                bfs_path=st.session_state.bfs_path,
                highlight_dqn_step=-1,
            )
            st.plotly_chart(fig, use_container_width=False, key="maze_static")

        # -- Metrics dashboard ----------------------------------------------------
        m = st.session_state.metrics
        if m:
            ai_s = m["ai_steps"]
            bfs_s = m["bfs_steps"]
            por = m["por"]
            ok = m["success"]
            # Colour class and text by POR tier.
            if ok and por >= 0.99:
                por_cls = "por-perfect"
                por_text = f"{por:.2f} ๐Ÿ† 100% Perfect"
            elif ok and por >= 0.75:
                por_cls = "por-good"
                por_text = f"{por:.2f} ๐Ÿ‘ Good"
            elif ok:
                por_cls = "por-bad"
                por_text = f"{por:.2f} โš ๏ธ Sub-optimal"
            else:
                por_cls = "por-bad"
                por_text = "N/A โŒ ๆœชๅˆฐ่พพ็ปˆ็‚น"
            mc1, mc2, mc3 = st.columns(3)
            with mc1:
                st.markdown(f"""
<div class='metric-card'>
<div class='metric-label'>๐Ÿค– AI ๅฎž้™…ๆญฅๆ•ฐ</div>
<div class='metric-value'>{ai_s}</div>
</div>""", unsafe_allow_html=True)
            with mc2:
                st.markdown(f"""
<div class='metric-card'>
<div class='metric-label'>๐Ÿ“ BFS ็†่ฎบๆœ€็Ÿญ</div>
<div class='metric-value'>{bfs_s}</div>
</div>""", unsafe_allow_html=True)
            with mc3:
                st.markdown(f"""
<div class='metric-card' style='background:linear-gradient(135deg,#11998e,#38ef7d)'>
<div class='metric-label'>โšก Path Optimality Ratio</div>
<div class='metric-value {por_cls}'>{por_text}</div>
</div>""", unsafe_allow_html=True)
            with st.expander("๐Ÿ“Š ๆŒ‡ๆ ‡่ฏดๆ˜Ž"):
                st.markdown("""
| ๆŒ‡ๆ ‡ | ๅซไน‰ |
|------|------|
| **AI ๅฎž้™…ๆญฅๆ•ฐ** | DQN Agent ไปŽ่ตท็‚น่ตฐๅˆฐ็ปˆ็‚น๏ผˆๆˆ–่ถ…ๆ—ถ๏ผ‰ๆ‰€็”จ็š„ๆ€ปๆญฅๆ•ฐ |
| **BFS ็†่ฎบๆœ€็Ÿญ** | BFS ็ฎ—ๆณ•่ฎก็ฎ—็š„็ปๅฏนๆœ€็Ÿญ่ทฏๅพ„ๆญฅๆ•ฐ๏ผˆGround Truth๏ผ‰|
| **Path Optimality Ratio** | `BFSๆญฅๆ•ฐ / AIๆญฅๆ•ฐ`๏ผŒ่ถŠๆŽฅ่ฟ‘ **1.00** ่ถŠๅฎŒ็พŽใ€‚็ญ‰ไบŽ 1.00 ่ฏดๆ˜Ž AI ่ตฐๅ‡บไบ†ไธŽ BFS ๅฎŒๅ…จ็›ธๅŒ็š„ๆœ€็Ÿญ่ทฏ๏ผ |
""")

    # -- Footer -------------------------------------------------------------------
    st.divider()
    st.markdown(
        "<div style='text-align:center;color:#95A5A6;font-size:12px'>"
        "DQN Maze Solver ยท PyTorch + Gymnasium + Streamlit ยท "
        "Hugging Face Spaces Demo"
        "</div>",
        unsafe_allow_html=True,
    )
# Script entry point: render the page when this file is executed as the main
# module (e.g. via `streamlit run app.py`); importing it does not call main().
if __name__ == "__main__":
    main()