Lee93whut
chore: codebase hygiene pass — untrack weights, migrate to logging, tidy comments
17bc537 | """app.py — DQN maze pathfinding visualisation web app | |
| For Hugging Face Spaces (Docker SDK) only | |
| Deployment checklist (every file uploaded to the HF Space): | |
| -------------------------------------- | |
| app.py this file | |
| src/model.py neural network architecture | |
| results/best_model_train_vanilla.pth vanilla DQN weights | |
| results/best_model_train_double.pth Double DQN weights | |
| results/best_model_train_dueling.pth Dueling DQN weights | |
| results/best_model_train_double_dueling.pth Double Dueling DQN weights | |
| config.yaml environment configuration (grid_size / obstacle_density / max_steps) | |
| requirements.txt dependency list | |
| Import strategy | |
| -------- | |
| * maze_env is installed via `pip install -e .` (see Dockerfile) and imported directly. | |
| * src/ is configured through pyproject.toml packages.find, is installable the same way, and is imported directly. | |
| * All modules resolve through standard import paths; no manual sys.path injection is needed. | |
| Port | |
| -------- | |
| An HF Docker Space always uses port 7860 (see Dockerfile / README). | |
| Local debugging: streamlit run app.py | |
| """ | |
| from __future__ import annotations | |
| import random | |
| import time | |
| from pathlib import Path | |
| from typing import Optional | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| import streamlit as st | |
| import torch | |
| import yaml | |
| # ── maze_env package (installed; imported directly) ── | |
| from maze_env import MazeEnv | |
| from maze_env.bfs import bfs as bfs_solve | |
| from maze_env.actions import DELTAS | |
| # ── src package (importable directly after `pip install -e .`) ── | |
| import torch.nn as nn | |
| from src.model import DQNNetwork, DuelingDQNNetwork | |
| # =========================================================================== | |
| # Constants & configuration | |
| # =========================================================================== | |
# Load config.yaml from the directory that contains this file; fall back to
# built-in defaults (and warn) when the file is missing.
_CONFIG_PATH = Path(__file__).parent / "config.yaml"
if _CONFIG_PATH.exists():
    # yaml.safe_load returns None for an empty document; normalise to a dict so
    # the .get() lookups below cannot fail with AttributeError.
    _cfg = yaml.safe_load(_CONFIG_PATH.read_text(encoding="utf-8")) or {}
else:
    import warnings

    warnings.warn(
        f"config.yaml ๆชๆพๅฐ๏ผ{_CONFIG_PATH}๏ผ๏ผไฝฟ็จๅ ็ฝฎ้ป่ฎคๅผใ"
        "่ฅ่ฎญ็ปๆถไฝฟ็จไบ้้ป่ฎค grid_size๏ผๆจ็็ปๆๅฏ่ฝ้่ฏฏใ",
        stacklevel=1,
    )
    _cfg = {}

# A bare `maze:` key parses to None, so a plain .get("maze", {}) is not enough.
_maze_cfg = _cfg.get("maze") or {}
GRID_SIZE = int(_maze_cfg.get("grid_size", 10))
# Kept in step with config.yaml maze.obstacle_density so the demo uses the
# same maze distribution as training (per the original author's note).
OBSTACLE_DENSITY = float(_maze_cfg.get("obstacle_density", 0.25))
# Kept in step with training so the inference step budget lines up.
MAX_STEPS = int(_maze_cfg.get("max_steps", 200))
# Algorithms that can be switched between; the list order is the order of the
# entries in the UI dropdown.
ALGO_OPTIONS: list[str] = [
    "double_dueling",
    "dueling",
    "double",
    "vanilla",
]

# Human-readable dropdown label for each algorithm key.
ALGO_LABELS: dict[str, str] = dict(
    vanilla="Vanilla DQN๏ผๅบๅ๏ผ",
    double="Double DQN๏ผๆๅถ้ซไผฐ๏ผ",
    dueling="Dueling DQN๏ผV+A ๅ่งฃ๏ผ",
    double_dueling="Double + Dueling๏ผV+A + ๆๅถ้ซไผฐ๏ผ",
)

# Success rate (percent) on the holdout test set: an independent evaluation,
# not the eval_success measured during training (per the original note).
# A value of None means "no figure available" and hides the suffix in the label.
ALGO_SUCCESS_RATES: dict[str, Optional[float]] = dict(
    vanilla=75.0,
    double=78.0,
    dueling=84.0,
    double_dueling=81.0,
)
def algo_display_label(algo: str) -> str:
    """Return the dropdown text for *algo*.

    The text is the ``ALGO_LABELS`` entry, followed by the holdout success
    rate when ``ALGO_SUCCESS_RATES`` holds a non-None value for the algorithm.

    Raises:
        KeyError: if *algo* has no entry in ``ALGO_LABELS``.
    """
    label = ALGO_LABELS[algo]
    holdout = ALGO_SUCCESS_RATES.get(algo)
    return label if holdout is None else f"{label} | holdout {holdout:.0f}%"
# Default algorithm: prefer dqn.algorithm from config.yaml, fall back to
# double_dueling when it is absent or not one of ALGO_OPTIONS.
# `or {}` covers a bare `dqn:` key, which YAML parses to None.
_default_algo = str((_cfg.get("dqn") or {}).get("algorithm", "double_dueling")).strip().lower()
DEFAULT_ALGO: str = _default_algo if _default_algo in ALGO_OPTIONS else "double_dueling"
def model_path_for(algo: str) -> Path:
    """Return the weight-file path for *algo*.

    The path is ``results/best_model_train_<algo>.pth`` inside the directory
    that contains this file; the file is not checked for existence.
    """
    results_dir = Path(__file__).parent / "results"
    return results_dir / f"best_model_train_{algo}.pth"
# Default maze seed for the first page load.
# A fixed value lets people who share a link see the same map; per the
# original note, switching it to None would randomise the maze on each refresh.
DEFAULT_MAZE_SEED: int = 42

# Delay between animation frames, in seconds.
ANIM_DELAY = 0.08

# Colour palette used by the Plotly maze figure.
COLOR_EMPTY = "#F8F9FA"  # white / light grey — walkable floor
COLOR_WALL = "#2C3E50"  # dark blue-grey — wall
COLOR_START = "#27AE60"  # green — start cell
COLOR_GOAL = "#E74C3C"  # red — goal cell
COLOR_DQN_PATH = "#3498DB"  # blue — DQN trajectory
COLOR_BFS_PATH = "#F39C12"  # orange — BFS shortest path
COLOR_AGENT = "#9B59B6"  # purple — current agent position
| # =========================================================================== | |
| # Utility functions | |
| # =========================================================================== | |
def generate_maze(seed: Optional[int] = None) -> np.ndarray:
    """Generate a GRID_SIZE x GRID_SIZE wall map by resetting a fresh MazeEnv.

    Maze generation is delegated to ``MazeEnv.reset()`` so the demo reuses the
    environment's own logic instead of re-implementing it.
    NOTE(review): the original docstring states that the result has border
    walls and a guaranteed path from (1, 1) to (N-2, N-2); that is a property
    of MazeEnv, which is not visible from this file — confirm.

    Args:
        seed: Random seed forwarded to ``reset``; ``None`` leaves it unseeded.

    Returns:
        ``env.wall_map`` cast to ``int32``; presumably shape ``(N, N)`` with
        0 = open cell and 1 = wall (per the original docstring).
    """
    maze_kwargs = {
        "grid_size": GRID_SIZE,
        "obstacle_density": OBSTACLE_DENSITY,
    }
    env = MazeEnv(**maze_kwargs)
    env.reset(seed=seed)
    walls = env.wall_map
    return walls.astype(np.int32)
def generate_maze_with_random_sg(
    seed: Optional[int] = None,
) -> tuple[np.ndarray, tuple[int, int], tuple[int, int]]:
    """Generate a maze and draw a random start and goal from its open interior cells.

    A fresh ``MazeEnv`` is reset with *seed*; two distinct interior open cells
    are then drawn without replacement using ``env.np_random``.
    NOTE(review): the original docstring says this mirrors the
    ``random_start_goal=True`` logic of train.py so that the demo and training
    share one distribution; train.py is not visible here — confirm.

    Args:
        seed: Random seed forwarded to ``reset``; ``None`` leaves it unseeded.

    Returns:
        ``(wall_map, start, goal)`` where ``wall_map`` is ``env.wall_map`` cast
        to ``int32`` and ``start`` / ``goal`` are ``(row, col)`` tuples. When
        fewer than two interior open cells exist, the fixed pair
        ``(1, 1)`` / ``(GRID_SIZE - 2, GRID_SIZE - 2)`` is returned instead.
    """
    env = MazeEnv(grid_size=GRID_SIZE, obstacle_density=OBSTACLE_DENSITY)
    env.reset(seed=seed)
    wall_map = env.wall_map.astype(np.int32)

    # Collect open cells that are not on the border, in row-major order
    # (the order np.where yields them in).
    border = GRID_SIZE - 1
    inner_cells: list[tuple[int, int]] = []
    for r, c in zip(*np.where(wall_map == 0)):
        if 0 < r < border and 0 < c < border:
            inner_cells.append((int(r), int(c)))

    if len(inner_cells) < 2:
        # Degenerate map (e.g. very high obstacle density): use fixed endpoints.
        return wall_map, (1, 1), (GRID_SIZE - 2, GRID_SIZE - 2)

    # One choice(..., replace=False) call yields two distinct indices, so no
    # rejection-sampling loop is needed.
    first, second = env.np_random.choice(len(inner_cells), size=2, replace=False)
    return wall_map, inner_cells[int(first)], inner_cells[int(second)]
def load_model(algo: str = DEFAULT_ALGO, grid_size: int = GRID_SIZE) -> tuple[Optional[nn.Module], int]:
    """Load the checkpoint for *algo* and return ``(net, saved_grid_size)``.

    The checkpoint is read with ``torch.load(..., weights_only=True)`` onto the
    CPU. The network class is chosen from the checkpoint's ``"algorithm"``
    entry (``DuelingDQNNetwork`` when it contains "dueling", otherwise
    ``DQNNetwork``; a missing entry is treated as "vanilla"), and the input
    channel count is read from the shape of ``state_dict["conv.0.weight"]``.

    Args:
        algo: Algorithm name used to locate the weight file via ``model_path_for``.
        grid_size: Value returned as the second element when loading fails, and
            used when the checkpoint has no ``"grid_size"`` entry.

    Returns:
        ``(net, saved_gs)`` with the network in eval mode on success.
        ``(None, grid_size)`` when the file does not exist or loading raises;
        in the latter case the error is also shown through ``st.error``.
        Callers compare ``saved_gs`` with the current ``GRID_SIZE`` to warn
        about a dimension mismatch before running inference.
    """
    weights_file = model_path_for(algo)
    if not weights_file.exists():
        return None, grid_size

    try:
        checkpoint = torch.load(weights_file, map_location="cpu", weights_only=True)
        saved_gs = checkpoint.get("grid_size", grid_size)
        trained_algo = checkpoint.get("algorithm", "vanilla").strip().lower()
        if "dueling" in trained_algo:
            net_cls = DuelingDQNNetwork
        else:
            net_cls = DQNNetwork
        state_dict = checkpoint["state_dict"]
        input_channels = state_dict["conv.0.weight"].shape[1]
        net = net_cls(grid_size=saved_gs, input_channels=input_channels)
        net.load_state_dict(state_dict)
        net.eval()
    except Exception as e:
        # UI boundary: report any load failure to the user instead of crashing.
        st.error(f"โ ๆจกๅๅ ่ฝฝๅคฑ่ดฅ๏ผ{e}")
        return None, grid_size
    return net, saved_gs
| def dqn_rollout( | |
| net: nn.Module, | |
| wall_map: np.ndarray, | |
| start: tuple, | |
| goal: tuple, | |
| ) -> list[tuple]: | |
| """Run the DQN agent greedily (epsilon = 0) and return its full trajectory. | |
| Delegates to the standard ``reset()`` / ``step()`` interface of :class:`MazeEnv` | |
| instead of re-implementing collision detection in app.py. | |
| NOTE(review): the original note says this keeps the observation encoding | |
| identical to training; that depends on MazeEnv / train.py, not visible here. | |
| Args: | |
| net: DQN network with weights loaded, expected to be in eval mode. | |
| wall_map: wall grid, passed to ``env.reset`` as float32; presumably shape ``(N, N)`` with 0 = open, 1 = wall. | |
| start: agent start ``(row, col)``. | |
| goal: goal ``(row, col)``. | |
| Returns: | |
| List of agent positions beginning with the position after reset; a position | |
| is appended only for steps whose ``info["hit_wall"]`` is falsy. | |
| """ | |
| env = MazeEnv( | |
| grid_size=wall_map.shape[0], | |
| obstacle_density=0.0, # density is irrelevant: the map is injected from outside | |
| max_steps=MAX_STEPS, | |
| ) | |
| obs, _ = env.reset(options={ | |
| "wall_map": wall_map.astype(np.float32), | |
| "start": start, | |
| "goal": goal, | |
| }) | |
| path = [env.agent_pos] | |
| # Inference-side anti-loop hardening (original author's note): visited_map (ch3) | |
| # already lets the Q-function internalise visit history, but insufficiently | |
| # covered states can still fall into a two-cell loop. | |
| # When the visit count is >= 2, a progressive Q-value penalty is applied to the | |
| # current argmax action as a safety net; network weights are not modified. | |
| visited_count: dict[tuple, int] = {} | |
| while True: | |
| s = torch.from_numpy(obs).unsqueeze(0) | |
| with torch.no_grad(): | |
| q_values = net(s)[0].clone() # shape: (num_actions,) | |
| # Penalise the current best action on frequently revisited cells | |
| cur_pos = env.agent_pos | |
| cnt = visited_count.get(cur_pos, 0) | |
| if cnt >= 2: | |
| action_candidate = int(q_values.argmax().item()) | |
| q_values[action_candidate] -= 3.0 * cnt | |
| # Look ahead at each action's target cell: if it is also a frequently visited cell, penalise it further | |
| cur_r, cur_c = cur_pos | |
| N = env.grid_size | |
| for a, (dr, dc) in enumerate(DELTAS): | |
| nr, nc = cur_r + dr, cur_c + dc | |
| if 0 <= nr < N and 0 <= nc < N: | |
| next_cnt = visited_count.get((nr, nc), 0) | |
| if next_cnt >= 2: | |
| q_values[a] -= 3.0 * next_cnt | |
| action = int(q_values.argmax().item()) | |
| visited_count[cur_pos] = cnt + 1 | |
| obs, _reward, terminated, truncated, info = env.step(action) | |
| # Append only when the agent actually moved (position is unchanged on a wall hit), so duplicate coordinates do not stall the animation | |
| if not info["hit_wall"]: | |
| path.append(env.agent_pos) | |
| if terminated or truncated: | |
| break | |
| return path | |
| # =========================================================================== | |
| # Plotly maze rendering | |
| # =========================================================================== | |
def build_maze_figure(
    wall_map: np.ndarray,
    start: tuple,
    goal: tuple,
    dqn_path: Optional[list] = None,
    bfs_path: Optional[list] = None,
    agent_pos: Optional[tuple] = None,
    highlight_dqn_step: int = -1,
) -> go.Figure:
    """Build the Plotly maze figure with optional BFS / DQN paths and an agent marker.

    Args:
        wall_map: Square grid; its first dimension sets the axis ranges. Cell
            values are drawn with a four-band colour scale (0 = open, 1 = wall).
        start: Start cell ``(row, col)``; drawn as band 2 and labelled "S".
        goal: Goal cell ``(row, col)``; drawn as band 3 and labelled "G".
        dqn_path: DQN trajectory as ``(row, col)`` pairs; drawn when it has
            more than one point.
        bfs_path: BFS path as ``(row, col)`` pairs; drawn when it has more than
            one point.
        agent_pos: Explicit agent marker position; when falsy the marker falls
            back to the highlighted DQN step, or to *start*.
        highlight_dqn_step: When >= 0, the DQN path is truncated to this step
            (inclusive) for animation; -1 draws the whole path.

    Returns:
        The assembled ``go.Figure``.
    """
    size = wall_map.shape[0]
    animating = highlight_dqn_step >= 0

    # Base layer: a single Heatmap trace instead of one shape per cell.
    # Cell codes: 0 = open, 1 = wall, 2 = start, 3 = goal.
    cells = wall_map.astype(float).copy()
    cells[start[0], start[1]] = 2.0
    cells[goal[0], goal[1]] = 3.0

    # Discrete colour scale: each code owns a quarter of the [0, 1] range.
    colorscale = []
    for band, colour in enumerate((COLOR_EMPTY, COLOR_WALL, COLOR_START, COLOR_GOAL)):
        colorscale.append([band * 0.25, colour])
        colorscale.append([(band + 1) * 0.25, colour])

    fig = go.Figure()
    fig.add_trace(
        go.Heatmap(
            z=cells,
            colorscale=colorscale,
            zmin=0,
            zmax=3,
            showscale=False,
            xgap=1,
            ygap=1,
            hoverinfo="skip",
        )
    )

    # BFS path: dotted orange line.
    if bfs_path and len(bfs_path) > 1:
        fig.add_trace(
            go.Scatter(
                x=[col for _, col in bfs_path],
                y=[row for row, _ in bfs_path],
                mode="lines+markers",
                name="BFS ๆ็ญ่ทฏ",
                line={"color": COLOR_BFS_PATH, "width": 3, "dash": "dot"},
                marker={"size": 6, "color": COLOR_BFS_PATH, "opacity": 0.7},
            )
        )

    # DQN path: solid blue line, truncated to the highlighted step when animating.
    if dqn_path and len(dqn_path) > 1:
        stop = highlight_dqn_step + 1 if animating else len(dqn_path)
        shown = dqn_path[:stop]
        fig.add_trace(
            go.Scatter(
                x=[col for _, col in shown],
                y=[row for row, _ in shown],
                mode="lines+markers",
                name="DQN ่ฝจ่ฟน",
                line={"color": COLOR_DQN_PATH, "width": 3},
                marker={"size": 7, "color": COLOR_DQN_PATH},
            )
        )

    # Current agent position: large purple dot.
    if agent_pos:
        marker_at = agent_pos
    elif dqn_path and animating:
        marker_at = dqn_path[min(highlight_dqn_step, len(dqn_path) - 1)]
    else:
        marker_at = start
    fig.add_trace(
        go.Scatter(
            x=[marker_at[1]],
            y=[marker_at[0]],
            mode="markers",
            name="Agent",
            marker={
                "size": 16,
                "color": COLOR_AGENT,
                "symbol": "circle",
                "line": {"color": "white", "width": 2},
            },
            showlegend=True,
        )
    )

    # Start / goal labels: fully transparent markers so only the text shows.
    fig.add_trace(
        go.Scatter(
            x=[start[1], goal[1]],
            y=[start[0], goal[0]],
            mode="markers+text",
            text=["S", "G"],
            textposition="middle center",
            textfont={"size": 13, "color": "white", "family": "Arial Black"},
            marker={
                "size": 22,
                "color": [COLOR_START, COLOR_GOAL],
                "symbol": "square",
                "opacity": 0.0,
            },
            showlegend=False,
            hoverinfo="skip",
        )
    )

    # Layout: the y axis is reversed so row 0 is drawn at the top.
    ticks = list(range(size))
    fig.update_layout(
        width=560,
        height=560,
        margin={"l": 10, "r": 10, "t": 30, "b": 10},
        xaxis={
            "range": [-0.5, size - 0.5],
            "tickvals": ticks,
            "showgrid": False,
            "zeroline": False,
            "title": "ๅ (col)",
        },
        yaxis={
            "range": [size - 0.5, -0.5],
            "tickvals": list(ticks),
            "showgrid": False,
            "zeroline": False,
            "title": "่ก (row)",
        },
        legend={
            "x": 1.01,
            "y": 1,
            "bgcolor": "rgba(255,255,255,0.8)",
            "bordercolor": "#BDC3C7",
            "borderwidth": 1,
        },
        paper_bgcolor="white",
        plot_bgcolor="white",
        title={"text": "๐ DQN ่ฟทๅฎซๅฏป่ทฏ", "x": 0.5, "font": {"size": 16}},
    )
    return fig
| def _find_cell_index(free_cells: list[tuple], pos: tuple) -> int: | |
| """ๅจ free_cells ๅ่กจไธญๆฅๆพ pos ็็ดขๅผ๏ผๆชๆพๅฐๆถ่ฟๅ 0๏ผๅฎๅ จๅ้๏ผใ""" | |
| try: | |
| return free_cells.index(pos) | |
| except ValueError: | |
| return 0 | |
| # =========================================================================== | |
| # Session State initialisation | |
| # =========================================================================== | |
def _init_state() -> None:
    """Populate ``st.session_state`` with every key the page reads, if missing.

    Existing keys are never overwritten, so the function is safe to call on
    every script rerun.
    """
    state = st.session_state

    if "wall_map" not in state:
        # First load: random start/goal on a fixed-seed maze, so the first
        # screen is reproducible.
        maze, first_start, first_goal = generate_maze_with_random_sg(seed=DEFAULT_MAZE_SEED)
        state.wall_map = maze
        state.start = first_start
        state.goal = first_goal

    early_defaults = (
        ("start", (1, 1)),
        ("goal", (GRID_SIZE - 2, GRID_SIZE - 2)),
        ("dqn_path", None),
        ("bfs_path", None),
        ("metrics", None),
        ("selected_algo", DEFAULT_ALGO),
    )
    for key, fallback in early_defaults:
        if key not in state:
            state[key] = fallback

    if "model" not in state:
        # Loaded lazily: only when no model has been stored for this session.
        net, saved_gs = load_model(algo=DEFAULT_ALGO)
        state.model = net
        state.model_grid_size = saved_gs

    late_defaults = (
        ("maze_seed", DEFAULT_MAZE_SEED),
        ("anim_running", False),
        ("anim_step", 0),
        ("anim_path", None),
    )
    for key, fallback in late_defaults:
        if key not in state:
            state[key] = fallback
| # =========================================================================== | |
| # Main program | |
| # =========================================================================== | |
def _interior_free_cells(wall_map: np.ndarray) -> list[tuple[int, int]]:
    """Return the open (== 0) non-border cells of *wall_map* in row-major order."""
    inner = range(1, GRID_SIZE - 1)
    return [(r, c) for r in inner for c in inner if wall_map[r, c] == 0]


def _reset_results(*, keep_bfs: bool = False) -> None:
    """Drop stale path/metric results and stop any running animation."""
    st.session_state.dqn_path = None
    if not keep_bfs:
        st.session_state.bfs_path = None
    st.session_state.metrics = None
    st.session_state.anim_running = False


def _set_endpoints(wall_map: np.ndarray, start: tuple, goal: tuple) -> None:
    """Store a new start/goal pair, clear stale results and sync both dropdowns.

    Must be called before the ``start_select`` / ``goal_select`` widgets are
    created in the current run, because it assigns their session-state keys.
    """
    st.session_state.start = start
    st.session_state.goal = goal
    _reset_results()
    cells = _interior_free_cells(wall_map)
    st.session_state.start_select = _find_cell_index(cells, start)
    st.session_state.goal_select = _find_cell_index(cells, goal)


def _load_maze(seed: int) -> None:
    """Generate the maze for *seed* and make it (with its endpoints) current."""
    wall_map, start, goal = generate_maze_with_random_sg(seed=seed)
    st.session_state.maze_seed = seed
    st.session_state.wall_map = wall_map
    _set_endpoints(wall_map, start, goal)


def main() -> None:
    """Render the Streamlit page: control panel on the left, maze canvas on the right.

    The script is re-executed on every interaction; all persistent data lives
    in ``st.session_state`` (initialised by ``_init_state``). The DQN
    animation is driven by advancing ``anim_step`` and calling ``st.rerun()``
    once per frame.
    """
    st.set_page_config(
        page_title="DQN ่ฟทๅฎซๅฏป่ทฏ Demo",
        page_icon="๐ค",
        layout="wide",
    )

    # ── Global style injection ──
    st.markdown("""
<style>
.metric-card {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border-radius: 12px; padding: 16px 20px; color: white;
text-align: center; margin: 6px 0;
}
.metric-label { font-size: 13px; opacity: 0.85; margin-bottom: 4px; }
.metric-value { font-size: 28px; font-weight: 700; }
.por-perfect { color: #2ECC71; font-weight: 800; }
.por-good { color: #F39C12; font-weight: 700; }
.por-bad { color: #E74C3C; font-weight: 600; }
div[data-testid="stButton"] button {
width: 100%; border-radius: 8px; font-weight: 600;
}
/* ่ฟทๅฎซๆ้ฎ็ฝๆ ผ๏ผๆฏๆ ผ็ดงๅๆญฃๆนๅฝข๏ผๆ ๅ ่พน่ท */
div[data-testid="stHorizontalBlock"] div[data-testid="stButton"] button {
padding: 0 !important;
min-height: 40px !important;
font-size: 15px !important;
border-radius: 3px !important;
border: 1px solid #ccc !important;
line-height: 1 !important;
}
</style>
""", unsafe_allow_html=True)

    _init_state()

    st.title("๐ค DQN ่ฟทๅฎซๅฏป่ทฏ ยท ๅฏ่งๅ Demo")
    st.caption("Deep Q-Network ร BFS Ground-Truth ยท Hugging Face Spaces")

    # Regular two-column layout.
    left_col, right_col = st.columns([1, 2.2], gap="large")

    # ── Left column: control panel ──
    with left_col:
        st.subheader("โ๏ธ ๆงๅถ้ขๆฟ")

        # ── Maze generation ──
        st.markdown("**โ ่ฟทๅฎซๅฐๅพ**")
        col_seed, col_rand = st.columns([3, 1])
        with col_seed:
            input_seed = st.number_input(
                "่ฟทๅฎซ Seed",
                min_value=0,
                max_value=999999,
                value=st.session_state.maze_seed,
                step=1,
                help="ๅบๅฎๆฐๅญๅฏๅค็ฐๆๅฎๅฐๅพ๏ผ็นๅปๅณไพงๆ้ฎ้ๆบ็ๆๆฐๅฐๅพ",
            )
        with col_rand:
            st.write("")  # spacer so the button lines up with the input
            if st.button("๐ฒ ้ๆบ"):
                # Random seed: regenerates both the map and the start/goal pair.
                _load_maze(random.randint(0, 999999))
                # rerun ends this script run, so the seed check below is skipped.
                st.rerun()

        # Triggered when the user edits the seed box by hand.
        if input_seed != st.session_state.maze_seed:
            _load_maze(input_seed)
            st.rerun()

        st.divider()

        # ── Start / goal selection ──
        st.markdown("**โก ่ตท็น & ็ป็น**")

        # Random start/goal on the current map, drawn from interior open cells.
        if st.button("๐ฒ ้ๆบ่ตท็ป็น", use_container_width=True,
                     help="ไปๅฝๅๅฐๅพๅฏ้่กๆ ผ้ๆบ้ๅ่ตท็นๅ็ป็น๏ผไธ่ฎญ็ปๅๅธๅฎๅ จไธ่ด"):
            _inner = _interior_free_cells(st.session_state.wall_map)
            if len(_inner) >= 2:
                _i, _j = random.sample(range(len(_inner)), 2)
                _set_endpoints(st.session_state.wall_map, _inner[_i], _inner[_j])
                st.rerun()

        free_cells = _interior_free_cells(st.session_state.wall_map)
        cell_labels = [f"({r},{c})" for r, c in free_cells]

        start_idx = st.selectbox(
            "่ตท็น (row, col)",
            options=range(len(free_cells)),
            format_func=lambda i: cell_labels[i],
            index=_find_cell_index(free_cells, st.session_state.start),
            key="start_select",
        )
        goal_idx = st.selectbox(
            "็ป็น (row, col)",
            options=range(len(free_cells)),
            format_func=lambda i: cell_labels[i],
            index=_find_cell_index(free_cells, st.session_state.goal),
            key="goal_select",
        )
        new_start = free_cells[start_idx]
        new_goal = free_cells[goal_idx]

        if new_start == new_goal:
            st.warning("โ ๏ธ ่ตท็นไธ็ป็นไธ่ฝ็ธๅ๏ผ่ฏท้ๆฐ้ๆฉใ")
        elif new_start != st.session_state.start or new_goal != st.session_state.goal:
            st.session_state.start = new_start
            st.session_state.goal = new_goal
            # Also stops a running animation: otherwise it would keep playing
            # the stale path and then index into the cleared metrics (None).
            _reset_results()

        st.divider()

        # ── Algorithm selection & run buttons ──
        st.markdown("**โข ๅฏป่ทฏ็ฎๆณ**")
        selected_algo = st.selectbox(
            "DQN ็ฎๆณๅไฝ",
            options=ALGO_OPTIONS,
            format_func=algo_display_label,
            index=ALGO_OPTIONS.index(st.session_state.selected_algo),
            key="algo_select",
            help="ๅๆข็ฎๆณๅ็นๅปใDQN ๅฏป่ทฏใๆ้ฎๅฏๅฏนๆฏไธๅ็ฎๆณๅจๅไธๅฐๅพไธ็่ทฏๅพ",
        )
        # On an algorithm switch: reload the matching model and drop the
        # previous DQN result (the BFS path is kept).
        if selected_algo != st.session_state.selected_algo:
            st.session_state.selected_algo = selected_algo
            net, saved_gs = load_model(algo=selected_algo)
            st.session_state.model = net
            st.session_state.model_grid_size = saved_gs
            _reset_results(keep_bfs=True)
            st.rerun()

        run_dqn = st.button(
            "๐ค DQN ๆบ่ฝไฝๅฏป่ทฏ",
            use_container_width=True,
            type="primary",
        )
        run_bfs = st.button(
            "๐ BFS ไธๅฎถๅฏป่ทฏ",
            use_container_width=True,
        )

        st.divider()

        # ── Legend ──
        st.markdown("**ๅพไพ**")
        legend_html = """
<div style='font-size:13px; line-height:2'>
๐ฉ <b>S</b> ่ตท็น
๐ฅ <b>G</b> ็ป็น<br>
โฌ ๅขๅฃ
โฌ ้่ทฏ<br>
๐ต DQN ่ฝจ่ฟน
๐ BFS ๆ็ญ่ทฏ<br>
๐ฃ Agent ๅฝๅไฝ็ฝฎ
</div>
"""
        st.markdown(legend_html, unsafe_allow_html=True)

        # ── Model status ──
        st.divider()
        _cur_algo = st.session_state.get("selected_algo", DEFAULT_ALGO)
        _cur_path = model_path_for(_cur_algo)
        if st.session_state.model is not None:
            st.success(f"โ ๆจกๅๅทฒๅ ่ฝฝ ({_cur_path.name})")
            # Warn early on a grid-size mismatch: the network was built for the
            # saved grid size while the environment uses GRID_SIZE, so a
            # forward pass would fail on tensor shape.
            _saved_gs = st.session_state.get("model_grid_size", GRID_SIZE)
            if _saved_gs != GRID_SIZE:
                st.warning(
                    f"โ ๏ธ ๆจกๅ่ฎญ็ปไบ {_saved_gs}ร{_saved_gs} ่ฟทๅฎซ๏ผ"
                    f"ๅฝๅ้ ็ฝฎไธบ {GRID_SIZE}ร{GRID_SIZE}ใ\n"
                    "ๆจ็ๆถ่พๅ ฅ็ปดๅบฆไธๅน้ ๏ผๅฐๅฏผ่ด่ฟ่กๆถ้่ฏฏใ\n"
                    "่ฏทไฝฟ็จๅน้ grid_size ็ๆจกๅ๏ผๆๆดๆฐ config.yamlใ"
                )
        else:
            st.error(f"โ ๆชๆพๅฐ {_cur_path.name}")
            st.info(f"่ฏทๅ ่ฟ่ก `python src/train.py --algorithm {_cur_algo}` ่ฎญ็ปๆจกๅใ")

    # ── Right column: main canvas ──
    with right_col:
        wall_map = st.session_state.wall_map
        start = st.session_state.start
        goal = st.session_state.goal
        status_placeholder = st.empty()

        # ── BFS run ──
        if run_bfs:
            result = bfs_solve(wall_map.astype(np.int32), start, goal)
            if result["success"]:
                st.session_state.bfs_path = result["path"]
                status_placeholder.success(
                    f"โ BFS ๅฎๆ๏ผๆ็ญๆญฅๆฐ = **{result['steps']}**๏ผ"
                    f"่ๆถ {result['execution_time_ms']:.3f} ms"
                )
            else:
                st.session_state.bfs_path = None
                status_placeholder.error("โ BFS๏ผ่ตท็นไธ็ป็นไน้ดๆ ๅฏ่พพ่ทฏๅพ๏ผ")

        # ── DQN run ──
        if run_dqn:
            model = st.session_state.model
            if model is None:
                status_placeholder.error("โ ๆจกๅๆชๅ ่ฝฝ๏ผๆ ๆณๆจ็ใ")
            elif st.session_state.get("model_grid_size", GRID_SIZE) != GRID_SIZE:
                _mgs = st.session_state.model_grid_size
                status_placeholder.error(
                    f"โ ๆจกๅ่ฎญ็ปไบ {_mgs}ร{_mgs}๏ผๅฝๅไธบ {GRID_SIZE}ร{GRID_SIZE}๏ผ็ปดๅบฆไธๅน้ ใ"
                )
            else:
                bfs_result = bfs_solve(wall_map.astype(np.int32), start, goal)
                if not bfs_result["success"]:
                    status_placeholder.error("โ ่ฏฅ่ฟทๅฎซ้ ็ฝฎๆ ่งฃ๏ผ่ฏทๆข่ตท็ป็นใ")
                else:
                    with st.spinner("๐ค DQN ๆจ็ไธญโฆ"):
                        dqn_path = dqn_rollout(model, wall_map, start, goal)
                    ai_steps = len(dqn_path) - 1
                    bfs_steps = bfs_result["steps"]
                    success = (dqn_path[-1] == goal)
                    # Path Optimality Ratio: BFS steps / AI steps; 0.0 on failure.
                    por = round(bfs_steps / ai_steps, 4) if (success and ai_steps > 0) else 0.0
                    st.session_state.dqn_path = dqn_path
                    st.session_state.bfs_path = bfs_result["path"]
                    st.session_state.metrics = {
                        "ai_steps": ai_steps, "bfs_steps": bfs_steps,
                        "success": success, "por": por,
                    }
                    # Start the frame-by-frame animation.
                    st.session_state.anim_running = True
                    st.session_state.anim_step = 0
                    st.session_state.anim_path = dqn_path
                    st.rerun()

        # ── Animation driver (one frame per script run) ──
        if st.session_state.anim_running:
            step_i = st.session_state.anim_step
            anim_p = st.session_state.anim_path
            total = len(anim_p)
            status_placeholder.info(f"๐ฌ ๅจ็ปๆญๆพไธญโฆ {step_i + 1}/{total}")
            fig = build_maze_figure(
                wall_map, start, goal,
                dqn_path=anim_p,
                bfs_path=st.session_state.bfs_path,
                highlight_dqn_step=step_i,
            )
            st.plotly_chart(fig, use_container_width=False, key=f"anim_{step_i}")
            if step_i + 1 < total:
                time.sleep(ANIM_DELAY)
                st.session_state.anim_step += 1
                st.rerun()
            else:
                st.session_state.anim_running = False
                m = st.session_state.metrics
                ok = m["success"]
                # Report a failed rollout as an error rather than in a success box.
                notify = status_placeholder.success if ok else status_placeholder.error
                notify(
                    f"{'โ ' if ok else 'โ'} DQN ๅฏป่ทฏ{'ๆๅ' if ok else 'ๅคฑ่ดฅ'}๏ผ"
                    f" AI ๆญฅๆฐ = **{m['ai_steps']}** | BFS ๆ็ญ = **{m['bfs_steps']}**"
                )
        # ── Static maze figure ──
        elif not run_dqn:
            fig = build_maze_figure(
                wall_map, start, goal,
                dqn_path=st.session_state.dqn_path,
                bfs_path=st.session_state.bfs_path,
                highlight_dqn_step=-1,
            )
            st.plotly_chart(fig, use_container_width=False, key="maze_static")

        # ── Metrics dashboard ──
        m = st.session_state.metrics
        if m:
            ai_s = m["ai_steps"]
            bfs_s = m["bfs_steps"]
            por = m["por"]
            ok = m["success"]

            # POR colour grading.
            if ok and por >= 0.99:
                por_cls = "por-perfect"
                por_text = f"{por:.2f} ๐ 100% Perfect"
            elif ok and por >= 0.75:
                por_cls = "por-good"
                por_text = f"{por:.2f} ๐ Good"
            elif ok:
                por_cls = "por-bad"
                por_text = f"{por:.2f} โ ๏ธ Sub-optimal"
            else:
                por_cls = "por-bad"
                por_text = "N/A โ ๆชๅฐ่พพ็ป็น"

            mc1, mc2, mc3 = st.columns(3)
            with mc1:
                st.markdown(f"""
<div class='metric-card'>
<div class='metric-label'>๐ค AI ๅฎ้ ๆญฅๆฐ</div>
<div class='metric-value'>{ai_s}</div>
</div>""", unsafe_allow_html=True)
            with mc2:
                st.markdown(f"""
<div class='metric-card'>
<div class='metric-label'>๐ BFS ็่ฎบๆ็ญ</div>
<div class='metric-value'>{bfs_s}</div>
</div>""", unsafe_allow_html=True)
            with mc3:
                st.markdown(f"""
<div class='metric-card' style='background:linear-gradient(135deg,#11998e,#38ef7d)'>
<div class='metric-label'>โก Path Optimality Ratio</div>
<div class='metric-value {por_cls}'>{por_text}</div>
</div>""", unsafe_allow_html=True)

            # NOTE(review): shown only alongside metrics — confirm intended placement.
            with st.expander("๐ ๆๆ ่ฏดๆ"):
                st.markdown("""
| ๆๆ | ๅซไน |
|------|------|
| **AI ๅฎ้ ๆญฅๆฐ** | DQN Agent ไป่ตท็น่ตฐๅฐ็ป็น๏ผๆ่ถ ๆถ๏ผๆ็จ็ๆปๆญฅๆฐ |
| **BFS ็่ฎบๆ็ญ** | BFS ็ฎๆณ่ฎก็ฎ็็ปๅฏนๆ็ญ่ทฏๅพๆญฅๆฐ๏ผGround Truth๏ผ|
| **Path Optimality Ratio** | `BFSๆญฅๆฐ / AIๆญฅๆฐ`๏ผ่ถๆฅ่ฟ **1.00** ่ถๅฎ็พใ็ญไบ 1.00 ่ฏดๆ AI ่ตฐๅบไบไธ BFS ๅฎๅ จ็ธๅ็ๆ็ญ่ทฏ๏ผ |
""")

    # ── Footer ──
    st.divider()
    st.markdown(
        "<div style='text-align:center;color:#95A5A6;font-size:12px'>"
        "DQN Maze Solver ยท PyTorch + Gymnasium + Streamlit ยท "
        "Hugging Face Spaces Demo"
        "</div>",
        unsafe_allow_html=True,
    )


if __name__ == "__main__":
    main()