chore: codebase hygiene pass — untrack weights, migrate to logging, tidy comments
Browse filesThree independent cleanups consolidated into one commit since they were
all part of the same one-shot editing pass and share the theme
"non-functional codebase tightening":
1) Stop tracking model weights in git
- .gitignore: add `results/`, update header comment to
"runs/logs/weights 不进版本控制"
- README §quickstart already names HF Space as the canonical weight
store ("权重文件统一存放于 HF Spaces lil58/interview") and
download_weights.py is idempotent for re-pulls, so git was just
duplicating state
- After this change: code/configs/docs in git, .pth files in HF Space
only, `python download_weights.py` populates results/ on a fresh
clone
2) Migrate train.py print() to Python logging module
- new _setup_logging() helper with LOG_LEVEL env-var override
- module-level logger = _setup_logging() at module load
- replace ~25 print() calls with logger.info() / logger.warning()
(overfit-assertion failure path goes to .warning() so a real
failure is structurally distinct from info output)
- docs/experiment_log.md §R4-A3 inline code snippet synced
(logger.info, not print)
- Visible output unchanged at default INFO level; enables
LOG_LEVEL=WARNING to mute, and caplog-based test capture
3) Remove duplicate section-divider comment in app.py
- "右栏:主画布" divider block was duplicated back-to-back from a
previous copy-paste glitch; pure comment cleanup
4 files changed: 63 insertions / 40 deletions.
Co-Authored-By: Lee93whut <30529279@qq.com>
- .gitignore +2 -1
- app.py +0 -3
- docs/experiment_log.md +1 -1
- src/train.py +60 -35
|
@@ -23,9 +23,10 @@ coverage.xml
|
|
| 23 |
*.xml
|
| 24 |
.pytest_cache/
|
| 25 |
|
| 26 |
-
# ── 训练产物(runs/logs 不进版本控制
|
| 27 |
runs/
|
| 28 |
logs/
|
|
|
|
| 29 |
|
| 30 |
# ── IDE ───────────────────────────────────────────────────────────────────────
|
| 31 |
.idea/
|
|
|
|
| 23 |
*.xml
|
| 24 |
.pytest_cache/
|
| 25 |
|
| 26 |
+
# ── 训练产物(runs/logs/weights 不进版本控制)──────────────────────────────
|
| 27 |
runs/
|
| 28 |
logs/
|
| 29 |
+
results/
|
| 30 |
|
| 31 |
# ── IDE ───────────────────────────────────────────────────────────────────────
|
| 32 |
.idea/
|
|
@@ -690,9 +690,6 @@ def main() -> None:
|
|
| 690 |
st.error(f"❌ 未找到 {_cur_path.name}")
|
| 691 |
st.info(f"请先运行 `python src/train.py --algorithm {_cur_algo}` 训练模型。")
|
| 692 |
|
| 693 |
-
# ───────────────────────────────────────────────────────────────────────
|
| 694 |
-
# 右栏:主画布
|
| 695 |
-
# ───────────────────────────────────────────────────────────────────────
|
| 696 |
# ───────────────────────────────────────────────────────────────────────
|
| 697 |
# 右栏:主画布
|
| 698 |
# ───────────────────────────────────────────────────────────────────────
|
|
|
|
| 690 |
st.error(f"❌ 未找到 {_cur_path.name}")
|
| 691 |
st.info(f"请先运行 `python src/train.py --algorithm {_cur_algo}` 训练模型。")
|
| 692 |
|
|
|
|
|
|
|
|
|
|
| 693 |
# ───────────────────────────────────────────────────────────────────────
|
| 694 |
# 右栏:主画布
|
| 695 |
# ───────────────────────────────────────────────────────────────────────
|
|
@@ -779,7 +779,7 @@ best_eval_success = float("-inf")
|
|
| 779 |
if not in_warmup and test_success_rate > best_eval_success:
|
| 780 |
best_eval_success = test_success_rate
|
| 781 |
torch.save({"state_dict": policy_net.state_dict(), ...}, best_model_path)
|
| 782 |
-
|
| 783 |
# 训练奖励保存块保留 ✓ 标记,不再写入权重
|
| 784 |
```
|
| 785 |
|
|
|
|
| 779 |
if not in_warmup and test_success_rate > best_eval_success:
|
| 780 |
best_eval_success = test_success_rate
|
| 781 |
torch.save({"state_dict": policy_net.state_dict(), ...}, best_model_path)
|
| 782 |
+
logger.info(f" [EVAL SAVE] EVAL 新高 {best_eval_success:.1f}%")
|
| 783 |
# 训练奖励保存块保留 ✓ 标记,不再写入权重
|
| 784 |
```
|
| 785 |
|
|
@@ -29,8 +29,10 @@ python src/train.py --config config.yaml --overfit
|
|
| 29 |
from __future__ import annotations
|
| 30 |
|
| 31 |
import argparse
|
|
|
|
| 32 |
import os
|
| 33 |
import random
|
|
|
|
| 34 |
import time
|
| 35 |
from collections import deque
|
| 36 |
from pathlib import Path
|
|
@@ -47,6 +49,29 @@ from torch.utils.tensorboard import SummaryWriter
|
|
| 47 |
# benchmark 实测:8线程 13.6s vs 16线程 528s(0.03x),4线程约快 2-3x
|
| 48 |
torch.set_num_threads(4)
|
| 49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
# ── 项目内部模块 ──────────────────────────────────────────────────────────────
|
| 51 |
# maze_env 通过 `pip install -e .` 安装,可直接 import。
|
| 52 |
# src/ 通过 pyproject.toml packages.find 配置,同样作为包安装,可直接 import。
|
|
@@ -383,9 +408,9 @@ def train(cfg: dict[str, Any], overfit_mode: bool = False) -> None:
|
|
| 383 |
"num_test_mazes": ov.get("num_test_mazes", 10),
|
| 384 |
})
|
| 385 |
run_tag = f"overfit_5x5_{algorithm}"
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
else:
|
| 390 |
run_tag = f"train_{algorithm}"
|
| 391 |
|
|
@@ -427,14 +452,14 @@ def train(cfg: dict[str, Any], overfit_mode: bool = False) -> None:
|
|
| 427 |
# ── Seed Lock ────────────────────────────────────────────────────────
|
| 428 |
set_seed(seed)
|
| 429 |
|
| 430 |
-
|
| 431 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 432 |
-
|
| 433 |
f"Episodes {num_episodes} | Seed {seed}")
|
| 434 |
-
|
| 435 |
f"Net={'Dueling' if use_dueling else 'Vanilla'} | "
|
| 436 |
f"Target={'Double' if use_double else 'Vanilla'}")
|
| 437 |
-
|
| 438 |
|
| 439 |
# ── 环境(训练用)──────────────────────────────────────────────────────
|
| 440 |
# 正常训练:不传 seed,每局 reset() 使用 Gymnasium 内部 RNG 续进,
|
|
@@ -467,7 +492,7 @@ def train(cfg: dict[str, Any], overfit_mode: bool = False) -> None:
|
|
| 467 |
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
| 468 |
writer_dir = os.path.join(log_dir, f"{run_tag}_{timestamp}")
|
| 469 |
writer = SummaryWriter(log_dir=writer_dir)
|
| 470 |
-
|
| 471 |
|
| 472 |
# ── 保存目录 ───────────────────────────────────────────────────────────
|
| 473 |
os.makedirs(save_dir, exist_ok=True)
|
|
@@ -484,10 +509,10 @@ def train(cfg: dict[str, Any], overfit_mode: bool = False) -> None:
|
|
| 484 |
global_update_steps = 0 # Backend_Net/ 横坐标
|
| 485 |
total_env_steps = 0 # 全局环境交互步数(用于 Target Net 同步)
|
| 486 |
|
| 487 |
-
|
| 488 |
-
|
| 489 |
f"{'Loss':>8} {'AvgQ':>7} {'Suc%':>6} {'BestR':>8}")
|
| 490 |
-
|
| 491 |
|
| 492 |
# =========================================================
|
| 493 |
# 主训练循环
|
|
@@ -615,7 +640,7 @@ def train(cfg: dict[str, Any], overfit_mode: bool = False) -> None:
|
|
| 615 |
)
|
| 616 |
writer.add_scalar("Evaluation_Exam/Test_Success_Rate", test_success_rate, episode)
|
| 617 |
writer.add_scalar("Evaluation_Exam/SPL", test_spl, episode)
|
| 618 |
-
|
| 619 |
f"Test_Success={test_success_rate:.1f}% "
|
| 620 |
f"SPL={test_spl:.3f} "
|
| 621 |
f"(越接近 1.0 越好,失败局贡献 0)")
|
|
@@ -637,7 +662,7 @@ def train(cfg: dict[str, Any], overfit_mode: bool = False) -> None:
|
|
| 637 |
},
|
| 638 |
best_model_path,
|
| 639 |
)
|
| 640 |
-
|
| 641 |
|
| 642 |
# ── Best Model Save(训练奖励,仅用于控制台 ✓ 标记,不再保存权重)────
|
| 643 |
# 权重保存已移至 EVAL-based checkpoint(见上方 EVAL 块)。
|
|
@@ -653,13 +678,13 @@ def train(cfg: dict[str, Any], overfit_mode: bool = False) -> None:
|
|
| 653 |
# 每 20 行数据前重打一次表头,方便在长日志中快速定位列含义
|
| 654 |
_rows_printed = (episode // print_every)
|
| 655 |
if episode == 1 or _rows_printed % 20 == 0:
|
| 656 |
-
|
| 657 |
-
|
| 658 |
f"{'Loss':>8} {'AvgQ':>7} {'Suc%':>6} {'BestR':>8}")
|
| 659 |
-
|
| 660 |
warmup_flag = " [WARMUP]" if in_warmup else ""
|
| 661 |
saved_flag = " ✓" if model_saved else ""
|
| 662 |
-
|
| 663 |
f"{episode:>6d} "
|
| 664 |
f"{ep_reward:>8.1f} "
|
| 665 |
f"{ep_steps:>6d} "
|
|
@@ -672,22 +697,22 @@ def train(cfg: dict[str, Any], overfit_mode: bool = False) -> None:
|
|
| 672 |
|
| 673 |
# ── 训练结束 ──────────────────────────────────────────────────────────
|
| 674 |
writer.close()
|
| 675 |
-
|
| 676 |
-
|
| 677 |
f"{global_update_steps} 梯度步。")
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
|
| 684 |
# ── Holdout Test:训练后一次性最终评估(仅正常训练模式执行)─────────────
|
| 685 |
# Holdout 地图(seed+200000)在整个训练过程中从未使用,
|
| 686 |
# 是唯一可对外报告的无偏泛化性能数字。
|
| 687 |
if not overfit_mode and os.path.exists(best_model_path):
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
| 691 |
holdout_seed_base = seed + 200000
|
| 692 |
holdout_seeds = [holdout_seed_base + i for i in range(100)]
|
| 693 |
|
|
@@ -708,10 +733,10 @@ def train(cfg: dict[str, Any], overfit_mode: bool = False) -> None:
|
|
| 708 |
reward_step=reward_step_r,
|
| 709 |
random_start_goal=random_start_goal,
|
| 710 |
)
|
| 711 |
-
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
|
| 715 |
|
| 716 |
# ── 过拟合模式验收断言 ─────────────────────────────────────────────────
|
| 717 |
if overfit_mode:
|
|
@@ -730,12 +755,12 @@ def train(cfg: dict[str, Any], overfit_mode: bool = False) -> None:
|
|
| 730 |
reward_step=reward_step_r,
|
| 731 |
random_start_goal=False, # overfit 模式始终固定起终点
|
| 732 |
)
|
| 733 |
-
|
| 734 |
f"{final_success_rate:.1f}% SPL={final_spl:.3f}")
|
| 735 |
if final_success_rate >= 80.0:
|
| 736 |
-
|
| 737 |
else:
|
| 738 |
-
|
| 739 |
|
| 740 |
|
| 741 |
# ===========================================================================
|
|
@@ -785,6 +810,6 @@ if __name__ == "__main__": # pragma: no cover
|
|
| 785 |
if args.algorithm is not None:
|
| 786 |
key = "overfit" if overfit_mode else "dqn"
|
| 787 |
cfg.setdefault(key, {})["algorithm"] = args.algorithm
|
| 788 |
-
|
| 789 |
|
| 790 |
train(cfg, overfit_mode=overfit_mode)
|
|
|
|
| 29 |
from __future__ import annotations
|
| 30 |
|
| 31 |
import argparse
|
| 32 |
+
import logging
|
| 33 |
import os
|
| 34 |
import random
|
| 35 |
+
import sys
|
| 36 |
import time
|
| 37 |
from collections import deque
|
| 38 |
from pathlib import Path
|
|
|
|
| 49 |
# benchmark 实测:8线程 13.6s vs 16线程 528s(0.03x),4线程约快 2-3x
|
| 50 |
torch.set_num_threads(4)
|
| 51 |
|
| 52 |
+
# ── 日志配置 ─────────────────────────────────────────────────────────────────
|
| 53 |
+
def _setup_logging(level: int = logging.INFO) -> logging.Logger:
|
| 54 |
+
"""配置模块级 logger,输出到控制台。
|
| 55 |
+
|
| 56 |
+
日志格式:时间戳 | 级别 | 消息
|
| 57 |
+
可通过环境变量 LOG_LEVEL 覆盖默认级别(例:export LOG_LEVEL=DEBUG)
|
| 58 |
+
"""
|
| 59 |
+
env_level = os.environ.get("LOG_LEVEL", "").upper()
|
| 60 |
+
if env_level in logging._levelToName.values(): # type: ignore[attr-defined]
|
| 61 |
+
level = getattr(logging, env_level, level)
|
| 62 |
+
|
| 63 |
+
logging.basicConfig(
|
| 64 |
+
level=level,
|
| 65 |
+
format="%(asctime)s | %(levelname)-7s | %(message)s",
|
| 66 |
+
datefmt="%H:%M:%S",
|
| 67 |
+
stream=sys.stdout,
|
| 68 |
+
)
|
| 69 |
+
logger = logging.getLogger("train")
|
| 70 |
+
return logger
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
logger = _setup_logging()
|
| 74 |
+
|
| 75 |
# ── 项目内部模块 ──────────────────────────────────────────────────────────────
|
| 76 |
# maze_env 通过 `pip install -e .` 安装,可直接 import。
|
| 77 |
# src/ 通过 pyproject.toml packages.find 配置,同样作为包安装,可直接 import。
|
|
|
|
| 408 |
"num_test_mazes": ov.get("num_test_mazes", 10),
|
| 409 |
})
|
| 410 |
run_tag = f"overfit_5x5_{algorithm}"
|
| 411 |
+
logger.info("=" * 60)
|
| 412 |
+
logger.info(" [OVERFIT MODE] 5×5 超小迷宫过拟合调试")
|
| 413 |
+
logger.info("=" * 60)
|
| 414 |
else:
|
| 415 |
run_tag = f"train_{algorithm}"
|
| 416 |
|
|
|
|
| 452 |
# ── Seed Lock ────────────────────────────────────────────────────────
|
| 453 |
set_seed(seed)
|
| 454 |
|
| 455 |
+
# ── 设备 ───────────────────────────────────────────────────────────────
|
| 456 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 457 |
+
logger.info(f"[Device] {device} | Grid {grid_size}×{grid_size} | "
|
| 458 |
f"Episodes {num_episodes} | Seed {seed}")
|
| 459 |
+
logger.info(f"[Algorithm] {algorithm.upper()} | "
|
| 460 |
f"Net={'Dueling' if use_dueling else 'Vanilla'} | "
|
| 461 |
f"Target={'Double' if use_double else 'Vanilla'}")
|
| 462 |
+
logger.info(f"[Warmup] 前 {warmup_episodes} 局纯随机探索,不执行梯度更新")
|
| 463 |
|
| 464 |
# ── 环境(训练用)──────────────────────────────────────────────────────
|
| 465 |
# 正常训练:不传 seed,每局 reset() 使用 Gymnasium 内部 RNG 续进,
|
|
|
|
| 492 |
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
| 493 |
writer_dir = os.path.join(log_dir, f"{run_tag}_{timestamp}")
|
| 494 |
writer = SummaryWriter(log_dir=writer_dir)
|
| 495 |
+
logger.info(f"[TensorBoard] tensorboard --logdir={log_dir}")
|
| 496 |
|
| 497 |
# ── 保存目录 ───────────────────────────────────────────────────────────
|
| 498 |
os.makedirs(save_dir, exist_ok=True)
|
|
|
|
| 509 |
global_update_steps = 0 # Backend_Net/ 横坐标
|
| 510 |
total_env_steps = 0 # 全局环境交互步数(用于 Target Net 同步)
|
| 511 |
|
| 512 |
+
logger.info(f"\n{'─'*70}")
|
| 513 |
+
logger.info(f"{'Ep':>6} {'Reward':>8} {'Steps':>6} {'Eps':>7} "
|
| 514 |
f"{'Loss':>8} {'AvgQ':>7} {'Suc%':>6} {'BestR':>8}")
|
| 515 |
+
logger.info(f"{'─'*70}")
|
| 516 |
|
| 517 |
# =========================================================
|
| 518 |
# 主训练循环
|
|
|
|
| 640 |
)
|
| 641 |
writer.add_scalar("Evaluation_Exam/Test_Success_Rate", test_success_rate, episode)
|
| 642 |
writer.add_scalar("Evaluation_Exam/SPL", test_spl, episode)
|
| 643 |
+
logger.info(f" [EVAL ep={episode:4d}] "
|
| 644 |
f"Test_Success={test_success_rate:.1f}% "
|
| 645 |
f"SPL={test_spl:.3f} "
|
| 646 |
f"(越接近 1.0 越好,失败局贡献 0)")
|
|
|
|
| 662 |
},
|
| 663 |
best_model_path,
|
| 664 |
)
|
| 665 |
+
logger.info(f" [EVAL SAVE] EVAL 新高 {best_eval_success:.1f}% → 已保存 {best_model_path}")
|
| 666 |
|
| 667 |
# ── Best Model Save(训练奖励,仅用于控制台 ✓ 标记,不再保存权重)────
|
| 668 |
# 权重保存已移至 EVAL-based checkpoint(见上方 EVAL 块)。
|
|
|
|
| 678 |
# 每 20 行数据前重打一次表头,方便在长日志中快速定位列含义
|
| 679 |
_rows_printed = (episode // print_every)
|
| 680 |
if episode == 1 or _rows_printed % 20 == 0:
|
| 681 |
+
logger.info(f"{'─'*70}")
|
| 682 |
+
logger.info(f"{'Ep':>6} {'Reward':>8} {'Steps':>6} {'Eps':>7} "
|
| 683 |
f"{'Loss':>8} {'AvgQ':>7} {'Suc%':>6} {'BestR':>8}")
|
| 684 |
+
logger.info(f"{'─'*70}")
|
| 685 |
warmup_flag = " [WARMUP]" if in_warmup else ""
|
| 686 |
saved_flag = " ✓" if model_saved else ""
|
| 687 |
+
logger.info(
|
| 688 |
f"{episode:>6d} "
|
| 689 |
f"{ep_reward:>8.1f} "
|
| 690 |
f"{ep_steps:>6d} "
|
|
|
|
| 697 |
|
| 698 |
# ── 训练结束 ──────────────────────────────────────────────────────────
|
| 699 |
writer.close()
|
| 700 |
+
logger.info(f"\n{'═'*70}")
|
| 701 |
+
logger.info(f" 训练完成。共 {num_episodes} 个 Episode,{total_env_steps} 环境步,"
|
| 702 |
f"{global_update_steps} 梯度步。")
|
| 703 |
+
logger.info(f" Best Avg Reward(近{save_window}局): {best_avg_reward:.2f}")
|
| 704 |
+
logger.info(f" 最终 ε = {epsilon:.4f}")
|
| 705 |
+
logger.info(f" 模型已保存至:{best_model_path}")
|
| 706 |
+
logger.info(f" TensorBoard:tensorboard --logdir={log_dir}")
|
| 707 |
+
logger.info(f"{'═'*70}\n")
|
| 708 |
|
| 709 |
# ── Holdout Test:训练后一次性最终评估(仅正常训练模式执行)─────────────
|
| 710 |
# Holdout 地图(seed+200000)在整个训练过程中从未使用,
|
| 711 |
# 是唯一可对外报告的无偏泛化性能数字。
|
| 712 |
if not overfit_mode and os.path.exists(best_model_path):
|
| 713 |
+
logger.info("=" * 70)
|
| 714 |
+
logger.info(" [HOLDOUT TEST] 加载 best_model.pth,在 100 张全新地图上最终评估")
|
| 715 |
+
logger.info("=" * 70)
|
| 716 |
holdout_seed_base = seed + 200000
|
| 717 |
holdout_seeds = [holdout_seed_base + i for i in range(100)]
|
| 718 |
|
|
|
|
| 733 |
reward_step=reward_step_r,
|
| 734 |
random_start_goal=random_start_goal,
|
| 735 |
)
|
| 736 |
+
logger.info(f" Holdout Success Rate : {holdout_sr:.1f}% (100 张独立地图)")
|
| 737 |
+
logger.info(f" Holdout SPL : {holdout_spl:.3f} (Success-weighted Path Length,越接近 1.0 越好)")
|
| 738 |
+
logger.info(f" ← 此数字为唯一可信的最终泛化性能,可对外报告。")
|
| 739 |
+
logger.info("=" * 70 + "\n")
|
| 740 |
|
| 741 |
# ── 过拟合模式验收断言 ─────────────────────────────────────────────────
|
| 742 |
if overfit_mode:
|
|
|
|
| 755 |
reward_step=reward_step_r,
|
| 756 |
random_start_goal=False, # overfit 模式始终固定起终点
|
| 757 |
)
|
| 758 |
+
logger.info(f"[OVERFIT 验收] 固定地图(seed={overfit_eval_seed})成功率: "
|
| 759 |
f"{final_success_rate:.1f}% SPL={final_spl:.3f}")
|
| 760 |
if final_success_rate >= 80.0:
|
| 761 |
+
logger.info("✅ 过拟合测试通过:Agent 已在 5×5 迷宫上充分收敛。")
|
| 762 |
else:
|
| 763 |
+
logger.warning("⚠️ 过拟合测试未达预期(成功率 < 80%),请检查超参数。")
|
| 764 |
|
| 765 |
|
| 766 |
# ===========================================================================
|
|
|
|
| 810 |
if args.algorithm is not None:
|
| 811 |
key = "overfit" if overfit_mode else "dqn"
|
| 812 |
cfg.setdefault(key, {})["algorithm"] = args.algorithm
|
| 813 |
+
logger.info(f"[CLI] --algorithm 覆盖 config.yaml:algorithm = {args.algorithm}")
|
| 814 |
|
| 815 |
train(cfg, overfit_mode=overfit_mode)
|