Commit af63b78 (parent: 8c7f4b9) — Add training information tab to frontend

app.py CHANGED
@@ -245,47 +245,133 @@ with gr.Blocks(
 
     Watch trained DQN agents play Atari Space Invaders in real-time.
     Three variants trained from raw pixels using PyTorch.
-
-    **Pick an agent and hit Play!**
     """
 )
 
-    with gr.
-    with gr.
-    …
 )
-    play_btn = gr.Button("▶ Play Game", variant="primary", size="lg")
-    result_text = gr.Markdown("")
 
     gr.Markdown(
         """
         ---
-        **About the agents:**
-        - **Baseline DQN** — Standard architecture
-        - **Double DQN** ⭐ — Reduces Q-value overestimation
-        - **Dueling DQN** — Separates state value from action advantage
 
-
-        """
-    )
 
-
-    video_output = gr.Video(label="Gameplay", autoplay=True)
 
-
-    …
 
 if __name__ == "__main__":
     demo.launch()
 
     Watch trained DQN agents play Atari Space Invaders in real-time.
     Three variants trained from raw pixels using PyTorch.
     """
 )
 
+    with gr.Tabs():
+        with gr.TabItem("🎮 Play Game"):
+            with gr.Row():
+                with gr.Column(scale=1):
+                    variant_dropdown = gr.Dropdown(
+                        choices=list(CHECKPOINTS.keys()),
+                        value="Double DQN (avg: 650.20) ⭐ Best",
+                        label="Agent Variant",
+                    )
+                    seed_input = gr.Number(
+                        value=42,
+                        label="Random Seed",
+                        info="Change for a different game",
+                        precision=0,
+                    )
+                    play_btn = gr.Button("▶ Play Game", variant="primary", size="lg")
+                    result_text = gr.Markdown("")
+
+                    gr.Markdown(
+                        """
+                        ---
+                        **About the agents:**
+                        - **Baseline DQN** — Standard architecture
+                        - **Double DQN** ⭐ — Reduces Q-value overestimation
+                        - **Dueling DQN** — Separates state value from action advantage
+                        """
+                    )
+
+                with gr.Column(scale=2):
+                    video_output = gr.Video(label="Gameplay", autoplay=True)
+
+            play_btn.click(
+                fn=run_demo,
+                inputs=[variant_dropdown, seed_input],
+                outputs=[video_output, result_text],
 )
 
+        with gr.TabItem("📊 Training Info"):
             gr.Markdown(
                 """
+                ## Training Results
+
+                All three variants exceeded an average score of 490 over 100 consecutive episodes.
+
+                | Variant | Avg Score | Best Score | Episodes | Training Time |
+                |---------|-----------|------------|----------|---------------|
+                | Baseline DQN | 524.75 | 586.45 | 1,470 | 7,000 |
+                | Double DQN | 650.20 | 650.20 | 1,355 | 6,090 |
+                | Dueling DQN | 497.55 | 647.05 | 1,465 | 7,000 |
+
+                **Key Finding:** Double DQN achieved the highest sustained performance, with zero degradation between its best and final averages. Dueling DQN reached the highest peak (647) but exhibited catastrophic forgetting in extended training, demonstrating why model checkpointing is critical in RL.
+
                 ---
+
+                ## Network Architecture
+
+                All variants share a convolutional backbone from Mnih et al. (2015):
+
+                | Layer | Filters | Kernel | Stride | Output |
+                |-------|---------|--------|--------|--------|
+                | Conv1 | 32 | 8×8 | 4 | 20×20×32 |
+                | Conv2 | 64 | 4×4 | 2 | 9×9×64 |
+                | Conv3 | 64 | 3×3 | 1 | 7×7×64 |
+
+                **Variant-specific heads:**
+                - **Baseline DQN:** Standard single-stream fully connected head → Q-values
+                - **Double DQN:** Same architecture, but decouples action selection (local network) from evaluation (target network) to reduce overestimation bias
+                - **Dueling DQN:** Splits into Value and Advantage streams — Q(s,a) = V(s) + A(s,a) − mean(A) — allowing the network to learn state quality independently of action choice
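The dueling decomposition above can be sketched in PyTorch. The 512-unit hidden size and the 6-action output are assumptions (standard choices for Atari DQNs), not values taken from this repo:

```python
import torch
import torch.nn as nn

class DuelingHead(nn.Module):
    """Dueling head: Q(s,a) = V(s) + A(s,a) - mean(A).

    Expects flattened conv features (7*7*64 = 3136 for an 84x84 input).
    Hidden size 512 and n_actions=6 are illustrative assumptions.
    """

    def __init__(self, in_features: int = 7 * 7 * 64, n_actions: int = 6, hidden: int = 512):
        super().__init__()
        # Two separate streams: scalar state value and per-action advantage.
        self.value = nn.Sequential(nn.Linear(in_features, hidden), nn.ReLU(), nn.Linear(hidden, 1))
        self.advantage = nn.Sequential(nn.Linear(in_features, hidden), nn.ReLU(), nn.Linear(hidden, n_actions))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        v = self.value(x)        # (batch, 1)
        a = self.advantage(x)    # (batch, n_actions)
        # Subtracting the mean advantage makes the V/A split identifiable.
        return v + a - a.mean(dim=1, keepdim=True)
```

A useful sanity check: because the mean advantage is subtracted, the per-state mean of the Q-values equals V(s) exactly.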
+
+                ---
+
+                ## Preprocessing Pipeline
+
+                Raw Atari frames (210×160×3) are converted into CNN-ready tensors (4×84×84), reducing input dimensionality by 72% while preserving the gameplay-relevant information:
+                1. Convert RGB → Grayscale
+                2. Crop to game region (20:200)
+                3. Resize to 84×84
+                4. Stack 4 consecutive frames
+                5. Normalize to [0, 1]
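The steps above can be sketched with NumPy. The nearest-neighbour resize and the Rec. 601 grayscale weights are stand-ins, assumptions rather than the training code's actual choices:

```python
import numpy as np
from collections import deque

def preprocess_frame(frame: np.ndarray) -> np.ndarray:
    """Raw Atari frame (210, 160, 3) uint8 -> (84, 84) float32 in [0, 1]."""
    # 1. RGB -> grayscale (Rec. 601 luma weights, an illustrative choice)
    gray = frame.astype(np.float32) @ np.array([0.299, 0.587, 0.114], dtype=np.float32)
    # 2. Crop to the game region
    cropped = gray[20:200, :]
    # 3. Nearest-neighbour resize to 84x84 (stand-in for real interpolation)
    rows = np.arange(84) * cropped.shape[0] // 84
    cols = np.arange(84) * cropped.shape[1] // 84
    resized = cropped[rows][:, cols]
    # 5. Normalize to [0, 1]
    return resized / 255.0

class FrameStack:
    """Step 4: maintain the 4-frame stack fed to the CNN, shape (4, 84, 84)."""

    def __init__(self, k: int = 4):
        self.frames = deque(maxlen=k)

    def push(self, frame: np.ndarray) -> np.ndarray:
        processed = preprocess_frame(frame)
        if not self.frames:
            # On reset, pad the stack with copies of the first frame.
            for _ in range(self.frames.maxlen):
                self.frames.append(processed)
        else:
            self.frames.append(processed)
        return np.stack(self.frames)
```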
+
+                ---
+
+                ## Training Configuration
+
+                | Config | Baseline | Double | Dueling |
+                |--------|----------|--------|---------|
+                | Learning Rate | 1e-4 | 1.5e-4 | 1e-4 |
+                | Epsilon Decay | 0.9993 | 0.9992 | 0.9995 |
+                | Batch Size | 32 | 32 | 32 |
+                | Gamma (Discount) | 0.99 | 0.99 | 0.99 |
+
+                **Design rationale:**
+                - Double DQN uses a higher learning rate because its more conservative Q-estimates can tolerate faster learning without divergence
+                - Dueling DQN uses slower epsilon decay to benefit from extended exploration during value/advantage stream learning
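To see what the decay rates in the table imply, a small helper (illustrative, not from the repo) counts how many episodes ε takes to fall from 1.0 to its floor under per-episode multiplicative decay:

```python
def episodes_to_epsilon(decay: float, eps_min: float = 0.01) -> int:
    """Episodes needed for epsilon to fall from 1.0 to eps_min
    when epsilon is multiplied by `decay` once per episode."""
    eps, episodes = 1.0, 0
    while eps > eps_min:
        eps *= decay
        episodes += 1
    return episodes
```

Under this assumption, 0.9992 reaches the 0.01 floor in roughly 5,800 episodes and 0.9995 in roughly 9,200, which matches the intent of giving Dueling DQN a longer exploration phase.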
+
+                ---
+
+                ## Key Components
+
+                - **Experience Replay Buffer** (100K transitions): Breaks temporal correlation and enables sample reuse
+                - **Target Network** with soft updates (τ=0.001): Stabilizes Q-value targets
+                - **ε-Greedy Exploration:** Decays from 1.0 → 0.01 at variant-specific rates
+                - **Gradient Clipping** (max norm 10): Prevents exploding gradients from large TD errors
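A sketch of how the target network and the Double DQN selection/evaluation split fit together. Function names and tensor shapes are illustrative assumptions, not the repo's API:

```python
import torch

def double_dqn_targets(q_local, q_target, next_states, rewards, dones, gamma=0.99):
    """Double DQN target: the local net *selects* the action,
    the target net *evaluates* it, reducing overestimation bias."""
    with torch.no_grad():
        best_actions = q_local(next_states).argmax(dim=1, keepdim=True)    # selection
        next_q = q_target(next_states).gather(1, best_actions).squeeze(1)  # evaluation
    return rewards + gamma * next_q * (1.0 - dones)

def soft_update(local_net, target_net, tau=0.001):
    """Polyak averaging: theta_target <- tau*theta_local + (1-tau)*theta_target."""
    for t_param, l_param in zip(target_net.parameters(), local_net.parameters()):
        t_param.data.copy_(tau * l_param.data + (1.0 - tau) * t_param.data)
```

With τ=0.001 the target network trails the local network slowly, which is what keeps the bootstrapped Q-targets stable.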
+
+                ---
+
+                ## Key Findings
+
+                1. **Low learning rate is essential** — a standard 1e-3 causes divergence on Atari; 1e-4 gives stable convergence
+                2. **Double DQN is the most stable** — zero gap between best and final averages; the only variant to reach its extended target
+                3. **Longer training ≠ better** — all variants showed diminishing returns or degradation past ~6,000 episodes
+                4. **Checkpointing matters** — Dueling DQN's peak (647) was 150 points above its final average; the best policy is not always the last one
+
+                ---
+
+                ## References
+
+                - Mnih, V. et al. (2015). [Human-level control through deep reinforcement learning](https://www.nature.com/articles/nature14236). Nature, 518(7540).
+                - Van Hasselt, H. et al. (2016). [Deep Reinforcement Learning with Double Q-learning](https://arxiv.org/abs/1509.06461). AAAI.
+                - Wang, Z. et al. (2016). [Dueling Network Architectures for Deep Reinforcement Learning](https://arxiv.org/abs/1511.06581). ICML.
+
+                [GitHub Repository](https://github.com/antonisbast/DQN-SpaceInvaders)
+                """
+            )
 
 if __name__ == "__main__":
     demo.launch()