"""report.py -- side-by-side comparison report generator for DQN algorithms.

Responsibilities
----------------
Scan the ``results/`` directory for every ``best_model_train_*.pth`` and, for
each algorithm:
1. Load the checkpoint (automatically choosing DQNNetwork / DuelingDQNNetwork).
2. Evaluate it on the holdout set (seed + 200000; 100 maps that never took
   part in training).
3. Summarise success rate, SPL, saved episode and training AvgReward, and
   output a comparison table.

Output
------
* Print the comparison table to the terminal.
* Save ``reports/comparison.md``.

Usage
-----
    python src/report.py                        # use the default config.yaml
    python src/report.py --config config.yaml   # name the config file explicitly
"""
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import os |
| import sys |
| from pathlib import Path |
|
|
| import torch |
| import yaml |
|
|
| |
| from src.model import DQNNetwork, DuelingDQNNetwork |
| from src.train import run_evaluation |
|
|
|
|
| |
| |
| |
|
|
def _algo_from_path(pth: "str | Path") -> str:
    """Extract ``<algo>`` from a ``best_model_train_<algo>.pth`` file name.

    Args:
        pth: Checkpoint path. A ``Path``, plain string or other path-like
            object is accepted; only the final component is inspected.

    Returns:
        The file stem with the ``best_model_train_`` prefix removed, or the
        whole stem when the prefix is absent.
    """
    prefix = "best_model_train_"
    # Path() is a no-op for Path inputs and lets callers pass plain strings.
    stem = Path(pth).stem
    return stem[len(prefix):] if stem.startswith(prefix) else stem
|
|
|
|
| |
| |
| |
|
|
| def build_report(config_path: str = "config.yaml") -> None: |
| """ๆซๆ results/๏ผ่ฏไผฐๆๆ็ฎๆณ๏ผ่พๅบๅฏนๆฏๆฅๅใ""" |
|
|
| |
| cfg_file = Path(config_path) |
| if not cfg_file.exists(): |
| print(f"[WARN] ้
็ฝฎๆไปถๆชๆพๅฐ๏ผ{cfg_file}๏ผไฝฟ็จๅ
็ฝฎ้ป่ฎคๅผใ") |
| cfg = {} |
| else: |
| cfg = yaml.safe_load(cfg_file.read_text(encoding="utf-8")) |
|
|
| maze_cfg = cfg.get("maze", {}) |
| reward_cfg = cfg.get("rewards", {}) |
| dqn_cfg = cfg.get("dqn", {}) |
|
|
| grid_size = int(maze_cfg.get("grid_size", 10)) |
| obstacle_density = float(maze_cfg.get("obstacle_density", 0.25)) |
| max_steps = int(maze_cfg.get("max_steps", 200)) |
| reward_goal = float(reward_cfg.get("goal", 100.0)) |
| reward_wall_hit = float(reward_cfg.get("wall_hit", -10.0)) |
| reward_step = float(reward_cfg.get("step", -1.0)) |
| seed = int(dqn_cfg.get("seed", 42)) |
| save_dir = str(dqn_cfg.get("save_dir", "results")) |
|
|
| |
| holdout_seed_base = seed + 200000 |
| holdout_seeds = [holdout_seed_base + i for i in range(100)] |
|
|
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
| |
| results_dir = Path(save_dir) |
| pth_files = sorted(results_dir.glob("best_model_train_*.pth")) |
| if not pth_files: |
| print(f"[ERROR] ๅจ {results_dir.resolve()} ไธญๆชๆพๅฐไปปไฝ best_model_train_*.pthใ") |
| print(" ่ฏทๅ
่ฟ่ก python src/train.py ๆ ./pipeline.sh ๅฎๆ่ฎญ็ปใ") |
| sys.exit(1) |
|
|
| |
| rows: list[dict] = [] |
| for pth in pth_files: |
| algo = _algo_from_path(pth) |
| print(f" ่ฏไผฐ [{algo}] {pth.name} โฆ", end=" ", flush=True) |
|
|
| try: |
| ckpt = torch.load(pth, map_location=device, weights_only=True) |
| saved_gs = ckpt.get("grid_size", grid_size) |
| ckpt_algo = ckpt.get("algorithm", algo).strip().lower() |
| NetClass = DuelingDQNNetwork if "dueling" in ckpt_algo else DQNNetwork |
| net = NetClass(grid_size=saved_gs).to(device) |
| net.load_state_dict(ckpt["state_dict"]) |
|
|
| success_rate, spl = run_evaluation( |
| policy_net=net, |
| grid_size=saved_gs, |
| obstacle_density=obstacle_density, |
| max_steps=max_steps, |
| device=device, |
| test_seeds=holdout_seeds, |
| reward_goal=reward_goal, |
| reward_wall_hit=reward_wall_hit, |
| reward_step=reward_step, |
| random_start_goal=False, |
| ) |
| rows.append({ |
| "algo": algo, |
| "success": success_rate, |
| "spl": spl, |
| "episode": ckpt.get("episode", -1), |
| "avg_reward": ckpt.get("avg_reward", float("nan")), |
| }) |
| print(f"Success={success_rate:.1f}% SPL={spl:.3f}") |
| except Exception as exc: |
| print(f"[SKIP] ๅ ่ฝฝๅคฑ่ดฅ๏ผ{exc}") |
|
|
| if not rows: |
| print("[ERROR] ๆฒกๆๆๅๅ ่ฝฝไปปไฝๆจกๅ๏ผๆฅๅ็ๆไธญๆญขใ") |
| sys.exit(1) |
|
|
| |
| rows.sort(key=lambda r: r["success"], reverse=True) |
| best = rows[0] |
|
|
| |
| SEP = "=" * 62 |
| HDR = f"{'็ฎๆณ':<18} {'ๆๅ็':>6} {'SPL':>6} {'ไฟๅญEpisode':>11} {'่ฎญ็ปAvgReward':>13}" |
| lines = [ |
| SEP, |
| " DQN ็ฎๆณๅฏนๆฏๆฅๅ๏ผHoldout Test๏ผ100 ๅผ ็ฌ็ซๅฐๅพ๏ผ", |
| SEP, |
| HDR, |
| ] |
| for r in rows: |
| lines.append( |
| f"{r['algo']:<18} {r['success']:>5.1f}% {r['spl']:>6.3f}" |
| f" {r['episode']:>11d} {r['avg_reward']:>13.1f}" |
| ) |
| lines += [ |
| SEP, |
| f"ๆไผ็ฎๆณ๏ผ{best['algo']}๏ผHoldout ๆๅ็ {best['success']:.1f}%๏ผ", |
| ] |
|
|
| |
| print() |
| for line in lines: |
| print(line) |
|
|
| |
| reports_dir = Path("reports") |
| reports_dir.mkdir(exist_ok=True) |
| md_path = reports_dir / "comparison.md" |
|
|
| md_rows_header = "| ็ฎๆณ | ๆๅ็ | SPL | ไฟๅญEpisode | ่ฎญ็ปAvgReward |" |
| md_rows_sep = "|------|-------:|----:|------------:|--------------:|" |
| md_data_rows = [ |
| f"| {r['algo']} | {r['success']:.1f}% | {r['spl']:.3f}" |
| f" | {r['episode']} | {r['avg_reward']:.1f} |" |
| for r in rows |
| ] |
| md_content = "\n".join([ |
| "# DQN ็ฎๆณๅฏนๆฏๆฅๅ", |
| "", |
| "> Holdout Test๏ผ100 ๅผ ็ฌ็ซๅฐๅพ๏ผseed+200000๏ผ๏ผๆดไธช่ฎญ็ป่ฟ็จไธญไปๆชไฝฟ็จใ", |
| "", |
| md_rows_header, |
| md_rows_sep, |
| *md_data_rows, |
| "", |
| f"**ๆไผ็ฎๆณ๏ผ{best['algo']}**๏ผHoldout ๆๅ็ {best['success']:.1f}%๏ผ", |
| "", |
| ]) |
| md_path.write_text(md_content, encoding="utf-8") |
| print(f"ๆฅๅๅทฒไฟๅญ่ณ๏ผ{md_path.resolve()}") |
| print(SEP + "\n") |
|
|
|
|
| |
| |
| |
|
|
| def _parse_args() -> argparse.Namespace: |
| parser = argparse.ArgumentParser(description="DQN ็ฎๆณๅฏนๆฏๆฅๅ็ๆๅจ") |
| parser.add_argument( |
| "--config", type=str, default="config.yaml", |
| help="YAML ้
็ฝฎๆไปถ่ทฏๅพ๏ผ้ป่ฎค๏ผconfig.yaml๏ผ", |
| ) |
| return parser.parse_args() |
|
|
|
|
if __name__ == "__main__":
    cli_args = _parse_args()

    # A relative --config is looked up first as given (i.e. against the
    # current working directory), then against the parent of this file's
    # directory. The first existing candidate wins; if neither exists the
    # path is passed through unchanged and build_report handles the miss.
    config_file = Path(cli_args.config)
    if not config_file.is_absolute():
        base_dir = Path(__file__).resolve().parent.parent
        config_file = next(
            (p for p in (config_file, base_dir / config_file) if p.exists()),
            config_file,
        )
    build_report(config_path=str(config_file))
|
|