dl-from-scratch / scripts /gen_dqn_notebook.py
yusiwen's picture
refactor: restructure into domain-grounded directories (ml/ cv/ gen/ graph/ rl/ nlp/)
73526b1 unverified
Raw
History Blame Contribute Delete
4.69 kB
#!/usr/bin/env python3
"""Generate DQN notebook."""
import nbformat as nbf
# Empty v4 notebook that the cells built below are attached to at the end.
nb = nbf.v4.new_notebook()
# Kernel and language metadata stored in the notebook file.
nb.metadata = {
    "kernelspec": {
        "display_name": "Python 3",
        "language": "python",
        "name": "python3",
    },
    "language_info": {
        "name": "python",
        "version": "3.12.0",
    },
}
# Cells are collected here in document order by md() and code().
cells = []
def md(s):
    """Append a Markdown cell whose source is ``s`` to the module-level ``cells`` list."""
    markdown_cell = nbf.v4.new_markdown_cell(s)
    cells.append(markdown_cell)
def code(s):
    """Append a code cell whose source is ``s`` to the module-level ``cells`` list."""
    code_cell = nbf.v4.new_code_cell(s)
    cells.append(code_cell)
# Title cell: notebook heading followed by a one-line summary of the experiment.
md("# DQN: Deep Q-Network\n\nReinforcement learning with experience replay on CartPole-v1.")
# Background cell (Chinese prose): the origin of DQN, its two key ideas
# (experience replay and the target network) and the CartPole-v1 task.
# The blank line after the bullet list is required: without it Markdown treats
# the following line as a lazy continuation of the last bullet, so the
# environment description would be rendered inside the "target network" item.
md("""## 背景
DQN(Mnih et al. 2015)将深度学习与 Q-Learning 结合,首次在 Atari 游戏上达到人类水平。
核心创新:
- **经验回放(Experience Replay)**:存储过去经验,随机采样训练,打破数据相关性
- **目标网络(Target Network)**:固定 Q-target,稳定训练

环境:**CartPole-v1** — 控制小车左右移动,保持杆子不倒。
状态:4 维(位置、速度、角度、角速度),动作:2 维(左、右)。
""")
# Math cell: the Q-learning update rule, the DQN loss and the epsilon-greedy
# policy, written as LaTeX. The string is a regular (non-raw) literal, so each
# LaTeX backslash is doubled here and "\\\\" becomes the LaTeX line break "\\"
# inside the `cases` environment.
md("""## 数学原理
### Q-Learning
$$Q(s, a) \\leftarrow Q(s, a) + \\alpha \\left(r + \\gamma \\max_{a'} Q(s', a') - Q(s, a)\\right)$$
### DQN Loss
$$\\mathcal{L} = \\mathbb{E}_{(s,a,r,s') \\sim \\mathcal{D}} \\left[\\left(r + \\gamma \\max_{a'} Q_{\\theta^-}(s', a') - Q_\\theta(s, a)\\right)^2\\right]$$
- $\\mathcal{D}$: 经验回放缓冲区
- $\\theta$: 在线网络参数
- $\\theta^-$: 目标网络参数(每隔 $C$ 步复制一次)
### ε-greedy 探索
$$a = \\begin{cases} \\text{random}, & \\text{概率 } \\varepsilon \\\\ \\arg\\max_a Q(s, a), & \\text{概率 } 1-\\varepsilon \\end{cases}$$
ε 随时间指数衰减。
""")
# Setup cell: imports, device selection and environment creation.
# The device is "mps" when torch.backends.mps reports it available, otherwise
# "cpu"; state_dim / action_dim are read from the CartPole-v1 spaces.
# NOTE(review): the cell imports `rl.dqn.dqn` and `utils.config`, so the
# notebook presumably has to run with the repository root importable - confirm.
code("""\
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from rl.dqn.dqn import DQN, ReplayBuffer, train_episode, epsilon_by_episode
from utils.config import load_config
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Device: {device}")
env = gym.make("CartPole-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
print(f"State: {state_dim} Action: {action_dim}")
""")
# Network cell: builds the online network and a second network of the same
# shape, copies the online weights into it, switches it to eval mode and
# prints the online network's parameter count.
code("""\
model = DQN(state_dim, action_dim, hidden_dim=128).to(device)
target = DQN(state_dim, action_dim, hidden_dim=128).to(device)
target.load_state_dict(model.state_dict())
target.eval()
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
""")
# Training section heading with the author's rough runtime estimate
# (500 episodes at about 0.2 s each, quoted for an M4 Max).
md("""## 训练
> ⏱ 预估耗时:**500 episode × ~0.2s ≈ 2 分钟**(M4 Max)
""")
# Training cell: hyperparameters followed by the episode loop.
# The cell source must carry real indentation; with the loop bodies flush
# left the generated cell fails with IndentationError as soon as it is run.
# Structure of the loop:
#   * per episode: reset the env and compute epsilon from the episode index;
#   * per step: epsilon-greedy action, env.step, push the transition to the
#     replay buffer, then call train_episode once;
#   * after each episode: copy the online weights into the target network when
#     `episode % TARGET_UPDATE == 0`, record the return, and print a summary
#     every 50 episodes.
# The target sync and the logging sit at episode level because both conditions
# test `episode`, which does not change inside the step loop.
code("""\
NUM_EPISODES = 500
LR = 0.001
GAMMA = 0.99
BATCH_SIZE = 64
BUFFER_SIZE = 50000
TARGET_UPDATE = 100
EPSILON_START = 1.0
EPSILON_END = 0.01
EPSILON_DECAY = 500
optimizer = optim.Adam(model.parameters(), lr=LR)
buffer = ReplayBuffer(BUFFER_SIZE)
rewards = []
for episode in range(1, NUM_EPISODES + 1):
    state, _ = env.reset()
    episode_reward = 0
    eps = epsilon_by_episode(episode, EPSILON_START, EPSILON_END, EPSILON_DECAY)
    while True:
        if np.random.random() < eps:
            action = env.action_space.sample()
        else:
            with torch.no_grad():
                q = model(torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0))
            action = q.argmax().item()
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        buffer.push(state, action, reward, next_state, done)
        state = next_state
        episode_reward += reward
        _ = train_episode(model, target, optimizer, buffer, BATCH_SIZE, GAMMA)
        if done:
            break
    if episode % TARGET_UPDATE == 0:
        target.load_state_dict(model.state_dict())
    rewards.append(episode_reward)
    if episode % 50 == 0:
        avg = np.mean(rewards[-50:])
        print(f"Episode [{episode:3d}/{NUM_EPISODES}] Reward: {episode_reward:.0f} Avg(50): {avg:.1f} ε: {eps:.3f}")
env.close()
""")
# Reward-curve section: a heading cell, then a plotting cell that draws the
# per-episode return stored in `rewards` with a dashed horizontal reference
# line at 195.
md("""## Reward 曲线""")
# NOTE(review): 195 is the figure usually quoted as "solved" for CartPole-v0,
# while this notebook trains on CartPole-v1 (500-step episodes) - confirm
# whether the reference line and its label should use the v1 threshold.
code("""\
import matplotlib.pyplot as plt
plt.figure(figsize=(8, 4))
plt.plot(rewards)
plt.xlabel("Episode"); plt.ylabel("Total Reward"); plt.title("DQN Training on CartPole")
plt.grid(True)
plt.axhline(y=195, color='r', linestyle='--', label='Solved (195)')
plt.legend()
plt.show()
""")
# Closing cell: four review questions (experience replay, target network,
# epsilon decay, effect of shrinking hidden_dim). The trailing backslash after
# the opening quotes keeps the cell source from starting with an empty line.
md("""\
## 思考题
1. 为什么需要经验回放(Experience Replay)?在线学习会有什么问题?
2. 目标网络(Target Network)解决了什么?如果不固定 target,Q 值会发散吗?
3. ϵ-greedy 中的 ϵ 从 1.0 开始衰减有什么含义?
4. 如果把 `hidden_dim` 从 128 改到 32,训练速度会怎样?收敛难度呢?
""")
# Attach the collected cells and write the notebook to disk.
nb.cells = cells
# The cells contain non-ASCII text (Chinese prose, "ε"), so the encoding is
# pinned to UTF-8 rather than left to the platform default, which may be
# unable to encode these characters (e.g. a legacy Windows code page).
# The output path is relative to the current working directory.
with open("rl/dqn/dqn.ipynb", "w", encoding="utf-8") as f:
    nbf.write(nb, f)
print("Generated rl/dqn/dqn.ipynb")