interview / src /report.py
Lee93whut
fix: eliminate infinite-loop risk in maze start/goal sampling
10926f0
"""report.py —— DQN 算法横向对比报告生成器

职责
----
扫描 ``results/`` 目录下所有 ``best_model_train_*.pth``，对每个算法：
1. 加载 checkpoint（自动识别 DQNNetwork / DuelingDQNNetwork）
2. 在 Holdout 集（seed + 200000，共 100 张从未参与训练的地图）上运行评估
3. 汇总成功率、SPL、保存 Episode、训练 AvgReward，输出对比表格

输出
----
* 终端打印对比表格
* 保存 ``reports/comparison.md``

用法
----
python src/report.py # 使用默认 config.yaml
python src/report.py --config config.yaml # 显式指定配置文件
"""
from __future__ import annotations

import argparse
import os
import sys
import unicodedata
from pathlib import Path

import torch
import yaml
# โ”€โ”€ ้กน็›ฎๅ†…้ƒจๆจกๅ— โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
from src.model import DQNNetwork, DuelingDQNNetwork
from src.train import run_evaluation
# ===========================================================================
# 工具：从文件名提取算法标签
# ===========================================================================
def _algo_from_path(pth: Path) -> str:
"""ไปŽ best_model_train_<algo>.pth ไธญๆๅ– <algo>ใ€‚"""
stem = pth.stem # e.g. "best_model_train_double_dueling"
prefix = "best_model_train_"
if stem.startswith(prefix):
return stem[len(prefix):]
return stem # ๅ…œๅบ•๏ผšๅŽŸๅง‹ๆ–‡ไปถๅๅŽปๆ‰ฉๅฑ•ๅ
# ===========================================================================
# 主逻辑
# ===========================================================================
def build_report(config_path: str = "config.yaml") -> None:
"""ๆ‰ซๆ results/๏ผŒ่ฏ„ไผฐๆ‰€ๆœ‰็ฎ—ๆณ•๏ผŒ่พ“ๅ‡บๅฏนๆฏ”ๆŠฅๅ‘Šใ€‚"""
# โ”€โ”€ ่ฏปๅ–้…็ฝฎ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
cfg_file = Path(config_path)
if not cfg_file.exists():
print(f"[WARN] ้…็ฝฎๆ–‡ไปถๆœชๆ‰พๅˆฐ๏ผš{cfg_file}๏ผŒไฝฟ็”จๅ†…็ฝฎ้ป˜่ฎคๅ€ผใ€‚")
cfg = {}
else:
cfg = yaml.safe_load(cfg_file.read_text(encoding="utf-8"))
maze_cfg = cfg.get("maze", {})
reward_cfg = cfg.get("rewards", {})
dqn_cfg = cfg.get("dqn", {})
grid_size = int(maze_cfg.get("grid_size", 10))
obstacle_density = float(maze_cfg.get("obstacle_density", 0.25))
max_steps = int(maze_cfg.get("max_steps", 200))
reward_goal = float(reward_cfg.get("goal", 100.0))
reward_wall_hit = float(reward_cfg.get("wall_hit", -10.0))
reward_step = float(reward_cfg.get("step", -1.0))
seed = int(dqn_cfg.get("seed", 42))
save_dir = str(dqn_cfg.get("save_dir", "results"))
# Holdout ้›†๏ผšseed+200000๏ผŒ100 ๅผ ๅœจๆ•ดไธช่ฎญ็ปƒ่ฟ‡็จ‹ไธญไปŽๆœชๅ‡บ็Žฐ็š„ๅœฐๅ›พ
holdout_seed_base = seed + 200000
holdout_seeds = [holdout_seed_base + i for i in range(100)]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# โ”€โ”€ ๆ‰ซๆ results/ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
results_dir = Path(save_dir)
pth_files = sorted(results_dir.glob("best_model_train_*.pth"))
if not pth_files:
print(f"[ERROR] ๅœจ {results_dir.resolve()} ไธญๆœชๆ‰พๅˆฐไปปไฝ• best_model_train_*.pthใ€‚")
print(" ่ฏทๅ…ˆ่ฟ่กŒ python src/train.py ๆˆ– ./pipeline.sh ๅฎŒๆˆ่ฎญ็ปƒใ€‚")
sys.exit(1)
# โ”€โ”€ ้€็ฎ—ๆณ•่ฏ„ไผฐ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
rows: list[dict] = []
for pth in pth_files:
algo = _algo_from_path(pth)
print(f" ่ฏ„ไผฐ [{algo}] {pth.name} โ€ฆ", end=" ", flush=True)
try:
ckpt = torch.load(pth, map_location=device, weights_only=True)
saved_gs = ckpt.get("grid_size", grid_size)
ckpt_algo = ckpt.get("algorithm", algo).strip().lower()
NetClass = DuelingDQNNetwork if "dueling" in ckpt_algo else DQNNetwork
net = NetClass(grid_size=saved_gs).to(device)
net.load_state_dict(ckpt["state_dict"])
success_rate, spl = run_evaluation(
policy_net=net,
grid_size=saved_gs,
obstacle_density=obstacle_density,
max_steps=max_steps,
device=device,
test_seeds=holdout_seeds,
reward_goal=reward_goal,
reward_wall_hit=reward_wall_hit,
reward_step=reward_step,
random_start_goal=False, # ๆจชๅ‘ๅฏนๆฏ”๏ผšๅ››็ฎ—ๆณ•็ปŸไธ€ๅ›บๅฎš่ตท็ปˆ็‚น่ฏ„ไผฐ๏ผŒๆถˆ้™ค่ตท็ปˆ็‚น้šๆœบๅ™ชๅฃฐ๏ผŒ็กฎไฟๆฏ”่พƒๅ…ฌๅนณ
)
rows.append({
"algo": algo,
"success": success_rate,
"spl": spl,
"episode": ckpt.get("episode", -1),
"avg_reward": ckpt.get("avg_reward", float("nan")),
})
print(f"Success={success_rate:.1f}% SPL={spl:.3f}")
except Exception as exc:
print(f"[SKIP] ๅŠ ่ฝฝๅคฑ่ดฅ๏ผš{exc}")
if not rows:
print("[ERROR] ๆฒกๆœ‰ๆˆๅŠŸๅŠ ่ฝฝไปปไฝ•ๆจกๅž‹๏ผŒๆŠฅๅ‘Š็”Ÿๆˆไธญๆญขใ€‚")
sys.exit(1)
# โ”€โ”€ ๆŽ’ๅบ๏ผšHoldout ๆˆๅŠŸ็އ้™ๅบ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
rows.sort(key=lambda r: r["success"], reverse=True)
best = rows[0]
# โ”€โ”€ ๆ ผๅผๅŒ–่กจๆ ผ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
SEP = "=" * 62
HDR = f"{'็ฎ—ๆณ•':<18} {'ๆˆๅŠŸ็އ':>6} {'SPL':>6} {'ไฟๅญ˜Episode':>11} {'่ฎญ็ปƒAvgReward':>13}"
lines = [
SEP,
" DQN ็ฎ—ๆณ•ๅฏนๆฏ”ๆŠฅๅ‘Š๏ผˆHoldout Test๏ผŒ100 ๅผ ็‹ฌ็ซ‹ๅœฐๅ›พ๏ผ‰",
SEP,
HDR,
]
for r in rows:
lines.append(
f"{r['algo']:<18} {r['success']:>5.1f}% {r['spl']:>6.3f}"
f" {r['episode']:>11d} {r['avg_reward']:>13.1f}"
)
lines += [
SEP,
f"ๆœ€ไผ˜็ฎ—ๆณ•๏ผš{best['algo']}๏ผˆHoldout ๆˆๅŠŸ็އ {best['success']:.1f}%๏ผ‰",
]
# โ”€โ”€ ็ปˆ็ซฏ่พ“ๅ‡บ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
print()
for line in lines:
print(line)
# โ”€โ”€ Markdown ๆŠฅๅ‘Š โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
reports_dir = Path("reports")
reports_dir.mkdir(exist_ok=True)
md_path = reports_dir / "comparison.md"
md_rows_header = "| ็ฎ—ๆณ• | ๆˆๅŠŸ็އ | SPL | ไฟๅญ˜Episode | ่ฎญ็ปƒAvgReward |"
md_rows_sep = "|------|-------:|----:|------------:|--------------:|"
md_data_rows = [
f"| {r['algo']} | {r['success']:.1f}% | {r['spl']:.3f}"
f" | {r['episode']} | {r['avg_reward']:.1f} |"
for r in rows
]
md_content = "\n".join([
"# DQN ็ฎ—ๆณ•ๅฏนๆฏ”ๆŠฅๅ‘Š",
"",
"> Holdout Test๏ผš100 ๅผ ็‹ฌ็ซ‹ๅœฐๅ›พ๏ผˆseed+200000๏ผ‰๏ผŒๆ•ดไธช่ฎญ็ปƒ่ฟ‡็จ‹ไธญไปŽๆœชไฝฟ็”จใ€‚",
"",
md_rows_header,
md_rows_sep,
*md_data_rows,
"",
f"**ๆœ€ไผ˜็ฎ—ๆณ•๏ผš{best['algo']}**๏ผˆHoldout ๆˆๅŠŸ็އ {best['success']:.1f}%๏ผ‰",
"",
])
md_path.write_text(md_content, encoding="utf-8")
print(f"ๆŠฅๅ‘Šๅทฒไฟๅญ˜่‡ณ๏ผš{md_path.resolve()}")
print(SEP + "\n")
# ===========================================================================
# 入口
# ===========================================================================
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="DQN ็ฎ—ๆณ•ๅฏนๆฏ”ๆŠฅๅ‘Š็”Ÿๆˆๅ™จ")
parser.add_argument(
"--config", type=str, default="config.yaml",
help="YAML ้…็ฝฎๆ–‡ไปถ่ทฏๅพ„๏ผˆ้ป˜่ฎค๏ผšconfig.yaml๏ผ‰",
)
return parser.parse_args()
if __name__ == "__main__":
    cli = _parse_args()
    # Make a relative --config work whether the script is launched from the
    # project root or from src/. Try the path as given, then relative to the
    # project root (the parent of this file's directory). If neither exists,
    # keep it unchanged and let build_report fall back to its defaults.
    chosen = Path(cli.config)
    if not chosen.is_absolute():
        project_root = Path(__file__).resolve().parent.parent
        chosen = next(
            (candidate for candidate in (chosen, project_root / chosen) if candidate.exists()),
            chosen,
        )
    build_report(config_path=str(chosen))