ImaghT's picture
Update Unit_3.py
ad5646f verified
# ============================================================
# Unit 3 - DQN 训练 Space Invaders (太空侵略者)
# ============================================================
# --- 导入必要的库 ---
import gymnasium as gym # OpenAI Gym 的新版本,用于强化学习环境
import torch # PyTorch 深度学习框架
import numpy as np # 数值计算库
import ale_py # 🔥 Atari Learning Environment - 提供 Atari 游戏环境
# 🔥 关键步骤:显式注册 ALE 环境到 gymnasium
# Gymnasium 1.x 不再自动注册 ALE 环境,必须手动调用
gym.register_envs(ale_py)
from stable_baselines3 import DQN # DQN 算法实现
from stable_baselines3.common.env_util import make_atari_env # Atari 环境创建工具
from stable_baselines3.common.vec_env import VecFrameStack # 帧堆叠包装器(将连续4帧堆叠作为状态)
from huggingface_sb3 import package_to_hub # 用于上传模型到 Hugging Face Hub
import time # 计时工具
# ============================================================
# 1. 检查 GPU 是否可用
# ============================================================
print(f"Checking GPU: {torch.cuda.is_available()}")
if torch.cuda.is_available():
# 如果 CUDA 可用,打印 GPU 型号
print(f"Device Name: {torch.cuda.get_device_name(0)}")
else:
# 如果没有 GPU,训练会非常慢(DQN 需要大量计算)
print("❌ WARNING: CUDA not available! Training will be painfully slow.")
# ============================================================
# 2. 虚拟显示(可选,如果需要渲染可以取消注释)
# ============================================================
# 在无头服务器(如 WSL2)上运行时,需要虚拟显示来渲染游戏画面
# 如果你的环境有 GUI,可以跳过这部分
# from pyvirtualdisplay import Display
# print("Starting Virtual Display...")
# try:
# virtual_display = Display(visible=0, size=(1400, 900))
# virtual_display.start()
# print("Virtual Display Started.")
# except Exception as e:
# print(f"❌ Display Error: {e}")
# # 如果报错,WSL2 用户需要安装: sudo apt-get install xvfb
# ============================================================
# 3. 创建训练环境
# ============================================================
print("Creating Atari Environment (SpaceInvaders)...")
# 环境 ID(新版本格式)
# "ALE/SpaceInvaders-v5" 是 Space Invaders 游戏
# ALE = Arcade Learning Environment (街机学习环境)
env_id = "ALE/SpaceInvaders-v5"
# 创建 6 个并行环境(加速数据收集)
# n_envs=6: 同时运行 6 个游戏副本
# seed=0: 随机种子,保证可复现性
env = make_atari_env(env_id, n_envs=6, seed=0)
# 帧堆叠:将连续 4 帧堆叠成一个状态
# 为什么需要?单帧图像无法表达速度信息(比如子弹移动方向)
# n_stack=4: 每个状态包含 4 帧历史信息
env = VecFrameStack(env, n_stack=4)
print("Environment Ready.")
# ============================================================
# 4. 定义 DQN 模型
# ============================================================
print("Initializing DQN Model...")
model = DQN(
# 策略网络类型:CnnPolicy = 卷积神经网络策略(适合图像输入)
policy="CnnPolicy",
# 训练环境
env=env,
# --- 训练超参数 ---
batch_size=32, # 每次从经验回放缓冲区采样 64 个样本进行训练
buffer_size=200000, # 经验回放缓冲区大小(存储最近的 20万 个经验)
# 正式训练可以调到 300,000-1,000,000
optimize_memory_usage=False, # 优化内存使用(减少 RAM 消耗)
# 🔥 关键参数:控制 GPU 何时介入
learning_starts=10000, # 收集 10000 步经验后开始训练
# 默认是 50000,改小后 GPU 更快开始工作
# 太小会导致早期经验质量差
# --- Q-Learning 核心参数 ---
learning_rate=1e-4, # 学习率(Adam 优化器)
gamma=0.99, # 折扣因子(未来奖励的权重,0.99 表示重视长期奖励)
target_update_interval=1000, # 每 1000 步更新一次目标网络(DQN 的稳定技巧)
train_freq=4, # 每收集 4 步经验就训练一次
gradient_steps=1, # 每次训练执行 1 步梯度下降
# --- 探索策略(ε-greedy)---
exploration_fraction=0.1, # 前 10% 的训练步数用于从 ε=1 衰减到 ε_final
exploration_final_eps=0.01, # 最终探索率(1% 的随机动作)
# --- 日志和设备 ---
verbose=1, # 打印训练进度
tensorboard_log="./board/", # TensorBoard 日志目录(可视化训练曲线)
device="cuda" # 使用 GPU 加速训练
)
print("Model Initialized.")
# ============================================================
# 5. 开始训练
# ============================================================
print("Start Training Loop... (Wait for 'learning_starts' steps)")
start_time = time.time()
# 训练 200,0000 步(约 1-2 小时,取决于 GPU)
# total_timesteps: 总训练步数(不是 episode 数)
# log_interval: 每 5000 次更新打印一次日志
model.learn(total_timesteps=10000000, log_interval=50000)
print(f"Training finished in {time.time() - start_time:.2f} seconds.")
# ============================================================
# 6. 保存模型
# ============================================================
model_save_name = "dqn-SpaceInvaders"
model.save(model_save_name) # 保存为 dqn-SpaceInvaders.zip
print("Training Finished!")