| # ============================================================ | |
| # Unit 3 - DQN 训练 Space Invaders (太空侵略者) | |
| # ============================================================ | |
| # --- 导入必要的库 --- | |
| import gymnasium as gym # OpenAI Gym 的新版本,用于强化学习环境 | |
| import torch # PyTorch 深度学习框架 | |
| import numpy as np # 数值计算库 | |
| import ale_py # 🔥 Atari Learning Environment - 提供 Atari 游戏环境 | |
| # 🔥 关键步骤:显式注册 ALE 环境到 gymnasium | |
| # Gymnasium 1.x 不再自动注册 ALE 环境,必须手动调用 | |
| gym.register_envs(ale_py) | |
| from stable_baselines3 import DQN # DQN 算法实现 | |
| from stable_baselines3.common.env_util import make_atari_env # Atari 环境创建工具 | |
| from stable_baselines3.common.vec_env import VecFrameStack # 帧堆叠包装器(将连续4帧堆叠作为状态) | |
| from huggingface_sb3 import package_to_hub # 用于上传模型到 Hugging Face Hub | |
| import time # 计时工具 | |
| # ============================================================ | |
| # 1. 检查 GPU 是否可用 | |
| # ============================================================ | |
| print(f"Checking GPU: {torch.cuda.is_available()}") | |
| if torch.cuda.is_available(): | |
| # 如果 CUDA 可用,打印 GPU 型号 | |
| print(f"Device Name: {torch.cuda.get_device_name(0)}") | |
| else: | |
| # 如果没有 GPU,训练会非常慢(DQN 需要大量计算) | |
| print("❌ WARNING: CUDA not available! Training will be painfully slow.") | |
| # ============================================================ | |
| # 2. 虚拟显示(可选,如果需要渲染可以取消注释) | |
| # ============================================================ | |
| # 在无头服务器(如 WSL2)上运行时,需要虚拟显示来渲染游戏画面 | |
| # 如果你的环境有 GUI,可以跳过这部分 | |
| # from pyvirtualdisplay import Display | |
| # print("Starting Virtual Display...") | |
| # try: | |
| # virtual_display = Display(visible=0, size=(1400, 900)) | |
| # virtual_display.start() | |
| # print("Virtual Display Started.") | |
| # except Exception as e: | |
| # print(f"❌ Display Error: {e}") | |
| # # 如果报错,WSL2 用户需要安装: sudo apt-get install xvfb | |
| # ============================================================ | |
| # 3. 创建训练环境 | |
| # ============================================================ | |
| print("Creating Atari Environment (SpaceInvaders)...") | |
| # 环境 ID(新版本格式) | |
| # "ALE/SpaceInvaders-v5" 是 Space Invaders 游戏 | |
| # ALE = Arcade Learning Environment (街机学习环境) | |
| env_id = "ALE/SpaceInvaders-v5" | |
| # 创建 6 个并行环境(加速数据收集) | |
| # n_envs=6: 同时运行 6 个游戏副本 | |
| # seed=0: 随机种子,保证可复现性 | |
| env = make_atari_env(env_id, n_envs=6, seed=0) | |
| # 帧堆叠:将连续 4 帧堆叠成一个状态 | |
| # 为什么需要?单帧图像无法表达速度信息(比如子弹移动方向) | |
| # n_stack=4: 每个状态包含 4 帧历史信息 | |
| env = VecFrameStack(env, n_stack=4) | |
| print("Environment Ready.") | |
| # ============================================================ | |
| # 4. 定义 DQN 模型 | |
| # ============================================================ | |
| print("Initializing DQN Model...") | |
| model = DQN( | |
| # 策略网络类型:CnnPolicy = 卷积神经网络策略(适合图像输入) | |
| policy="CnnPolicy", | |
| # 训练环境 | |
| env=env, | |
| # --- 训练超参数 --- | |
| batch_size=32, # 每次从经验回放缓冲区采样 64 个样本进行训练 | |
| buffer_size=200000, # 经验回放缓冲区大小(存储最近的 20万 个经验) | |
| # 正式训练可以调到 300,000-1,000,000 | |
| optimize_memory_usage=False, # 优化内存使用(减少 RAM 消耗) | |
| # 🔥 关键参数:控制 GPU 何时介入 | |
| learning_starts=10000, # 收集 10000 步经验后开始训练 | |
| # 默认是 50000,改小后 GPU 更快开始工作 | |
| # 太小会导致早期经验质量差 | |
| # --- Q-Learning 核心参数 --- | |
| learning_rate=1e-4, # 学习率(Adam 优化器) | |
| gamma=0.99, # 折扣因子(未来奖励的权重,0.99 表示重视长期奖励) | |
| target_update_interval=1000, # 每 1000 步更新一次目标网络(DQN 的稳定技巧) | |
| train_freq=4, # 每收集 4 步经验就训练一次 | |
| gradient_steps=1, # 每次训练执行 1 步梯度下降 | |
| # --- 探索策略(ε-greedy)--- | |
| exploration_fraction=0.1, # 前 10% 的训练步数用于从 ε=1 衰减到 ε_final | |
| exploration_final_eps=0.01, # 最终探索率(1% 的随机动作) | |
| # --- 日志和设备 --- | |
| verbose=1, # 打印训练进度 | |
| tensorboard_log="./board/", # TensorBoard 日志目录(可视化训练曲线) | |
| device="cuda" # 使用 GPU 加速训练 | |
| ) | |
| print("Model Initialized.") | |
| # ============================================================ | |
| # 5. 开始训练 | |
| # ============================================================ | |
| print("Start Training Loop... (Wait for 'learning_starts' steps)") | |
| start_time = time.time() | |
| # 训练 200,0000 步(约 1-2 小时,取决于 GPU) | |
| # total_timesteps: 总训练步数(不是 episode 数) | |
| # log_interval: 每 5000 次更新打印一次日志 | |
| model.learn(total_timesteps=10000000, log_interval=50000) | |
| print(f"Training finished in {time.time() - start_time:.2f} seconds.") | |
| # ============================================================ | |
| # 6. 保存模型 | |
| # ============================================================ | |
| model_save_name = "dqn-SpaceInvaders" | |
| model.save(model_save_name) # 保存为 dqn-SpaceInvaders.zip | |
| print("Training Finished!") | |