File size: 5,932 Bytes
1b99688 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
# ============================================================
# Unit 3 - DQN 训练 Space Invaders (太空侵略者)
# ============================================================
# --- 导入必要的库 ---
import gymnasium as gym # OpenAI Gym 的新版本,用于强化学习环境
import torch # PyTorch 深度学习框架
import numpy as np # 数值计算库
import ale_py # 🔥 Atari Learning Environment - 提供 Atari 游戏环境
# 🔥 关键步骤:显式注册 ALE 环境到 gymnasium
# Gymnasium 1.x 不再自动注册 ALE 环境,必须手动调用
gym.register_envs(ale_py)
from stable_baselines3 import DQN # DQN 算法实现
from stable_baselines3.common.env_util import make_atari_env # Atari 环境创建工具
from stable_baselines3.common.vec_env import VecFrameStack # 帧堆叠包装器(将连续4帧堆叠作为状态)
from huggingface_sb3 import package_to_hub # 用于上传模型到 Hugging Face Hub
import time # 计时工具
# ============================================================
# 1. 检查 GPU 是否可用
# ============================================================
print(f"Checking GPU: {torch.cuda.is_available()}")
if torch.cuda.is_available():
# 如果 CUDA 可用,打印 GPU 型号
print(f"Device Name: {torch.cuda.get_device_name(0)}")
else:
# 如果没有 GPU,训练会非常慢(DQN 需要大量计算)
print("❌ WARNING: CUDA not available! Training will be painfully slow.")
# ============================================================
# 2. 虚拟显示(可选,如果需要渲染可以取消注释)
# ============================================================
# 在无头服务器(如 WSL2)上运行时,需要虚拟显示来渲染游戏画面
# 如果你的环境有 GUI,可以跳过这部分
# from pyvirtualdisplay import Display
# print("Starting Virtual Display...")
# try:
# virtual_display = Display(visible=0, size=(1400, 900))
# virtual_display.start()
# print("Virtual Display Started.")
# except Exception as e:
# print(f"❌ Display Error: {e}")
# # 如果报错,WSL2 用户需要安装: sudo apt-get install xvfb
# ============================================================
# 3. 创建训练环境
# ============================================================
print("Creating Atari Environment (SpaceInvaders)...")
# 环境 ID(新版本格式)
# "ALE/SpaceInvaders-v5" 是 Space Invaders 游戏
# ALE = Arcade Learning Environment (街机学习环境)
env_id = "ALE/SpaceInvaders-v5"
# 创建 6 个并行环境(加速数据收集)
# n_envs=6: 同时运行 6 个游戏副本
# seed=0: 随机种子,保证可复现性
env = make_atari_env(env_id, n_envs=6, seed=0)
# 帧堆叠:将连续 4 帧堆叠成一个状态
# 为什么需要?单帧图像无法表达速度信息(比如子弹移动方向)
# n_stack=4: 每个状态包含 4 帧历史信息
env = VecFrameStack(env, n_stack=4)
print("Environment Ready.")
# ============================================================
# 4. 定义 DQN 模型
# ============================================================
print("Initializing DQN Model...")
model = DQN(
# 策略网络类型:CnnPolicy = 卷积神经网络策略(适合图像输入)
policy="CnnPolicy",
# 训练环境
env=env,
# --- 训练超参数 ---
batch_size=32, # 每次从经验回放缓冲区采样 64 个样本进行训练
buffer_size=200000, # 经验回放缓冲区大小(存储最近的 20万 个经验)
# 正式训练可以调到 300,000-1,000,000
optimize_memory_usage=False, # 优化内存使用(减少 RAM 消耗)
# 🔥 关键参数:控制 GPU 何时介入
learning_starts=10000, # 收集 10000 步经验后开始训练
# 默认是 50000,改小后 GPU 更快开始工作
# 太小会导致早期经验质量差
# --- Q-Learning 核心参数 ---
learning_rate=1e-4, # 学习率(Adam 优化器)
gamma=0.99, # 折扣因子(未来奖励的权重,0.99 表示重视长期奖励)
target_update_interval=1000, # 每 1000 步更新一次目标网络(DQN 的稳定技巧)
train_freq=4, # 每收集 4 步经验就训练一次
gradient_steps=1, # 每次训练执行 1 步梯度下降
# --- 探索策略(ε-greedy)---
exploration_fraction=0.1, # 前 10% 的训练步数用于从 ε=1 衰减到 ε_final
exploration_final_eps=0.01, # 最终探索率(1% 的随机动作)
# --- 日志和设备 ---
verbose=1, # 打印训练进度
tensorboard_log="./board/", # TensorBoard 日志目录(可视化训练曲线)
device="cuda" # 使用 GPU 加速训练
)
print("Model Initialized.")
# ============================================================
# 5. 开始训练
# ============================================================
print("Start Training Loop... (Wait for 'learning_starts' steps)")
start_time = time.time()
# 训练 200,0000 步(约 1-2 小时,取决于 GPU)
# total_timesteps: 总训练步数(不是 episode 数)
# log_interval: 每 5000 次更新打印一次日志
model.learn(total_timesteps=10000000, log_interval=50000)
print(f"Training finished in {time.time() - start_time:.2f} seconds.")
# ============================================================
# 6. 保存模型
# ============================================================
model_save_name = "dqn-SpaceInvaders"
model.save(model_save_name) # 保存为 dqn-SpaceInvaders.zip
print("Training Finished!")
|