antonisbast committed on
Commit af63b78 · 1 Parent(s): 8c7f4b9

Add training information tab to frontend

Files changed (1):
  1. app.py (+116 -30)

app.py CHANGED
Old version (removed lines marked `-`):

```diff
@@ -245,47 +245,133 @@ with gr.Blocks(
 
     Watch trained DQN agents play Atari Space Invaders in real-time.
     Three variants trained from raw pixels using PyTorch.
-
-    **Pick an agent and hit Play!**
     """
     )
 
-    with gr.Row():
-        with gr.Column(scale=1):
-            variant_dropdown = gr.Dropdown(
-                choices=list(CHECKPOINTS.keys()),
-                value="Double DQN (avg: 650.20) ⭐ Best",
-                label="Agent Variant",
-            )
-            seed_input = gr.Number(
-                value=42,
-                label="Random Seed",
-                info="Change for a different game",
-                precision=0,
-            )
-            play_btn = gr.Button("▶ Play Game", variant="primary", size="lg")
-            result_text = gr.Markdown("")
-
-            gr.Markdown(
-                """
-                ---
-                **About the agents:**
-                - **Baseline DQN** — Standard architecture
-                - **Double DQN** ⭐ — Reduces Q-value overestimation
-                - **Dueling DQN** — Separates state value from action advantage
-
-                [GitHub](https://github.com/antonisbast/DQN-SpaceInvaders)
-                """
-            )
-
-        with gr.Column(scale=2):
-            video_output = gr.Video(label="Gameplay", autoplay=True)
-
-    play_btn.click(
-        fn=run_demo,
-        inputs=[variant_dropdown, seed_input],
-        outputs=[video_output, result_text],
-    )
 
 if __name__ == "__main__":
     demo.launch()
```
 
New version (added lines marked `+`):

```diff
 
     Watch trained DQN agents play Atari Space Invaders in real-time.
     Three variants trained from raw pixels using PyTorch.
     """
     )
 
+    with gr.Tabs():
+        with gr.TabItem("🎮 Play Game"):
+            with gr.Row():
+                with gr.Column(scale=1):
+                    variant_dropdown = gr.Dropdown(
+                        choices=list(CHECKPOINTS.keys()),
+                        value="Double DQN (avg: 650.20) ⭐ Best",
+                        label="Agent Variant",
+                    )
+                    seed_input = gr.Number(
+                        value=42,
+                        label="Random Seed",
+                        info="Change for a different game",
+                        precision=0,
+                    )
+                    play_btn = gr.Button("▶ Play Game", variant="primary", size="lg")
+                    result_text = gr.Markdown("")
+
+                    gr.Markdown(
+                        """
+                        ---
+                        **About the agents:**
+                        - **Baseline DQN** — Standard architecture
+                        - **Double DQN** ⭐ — Reduces Q-value overestimation
+                        - **Dueling DQN** — Separates state value from action advantage
+                        """
+                    )
+
+                with gr.Column(scale=2):
+                    video_output = gr.Video(label="Gameplay", autoplay=True)
+
+            play_btn.click(
+                fn=run_demo,
+                inputs=[variant_dropdown, seed_input],
+                outputs=[video_output, result_text],
+            )
+
```
```diff
+        with gr.TabItem("📊 Training Info"):
+            gr.Markdown(
+                """
+## Training Results
+
+All three variants exceeded a 490 average score over 100 consecutive episodes.
+
+| Variant      | Avg Score | Best Score | Episodes | Training Time |
+|--------------|-----------|------------|----------|---------------|
+| Baseline DQN | 524.75    | 586.45     | 1,470    | 7,000         |
+| Double DQN   | 650.20    | 650.20     | 1,355    | 6,090         |
+| Dueling DQN  | 497.55    | 647.05     | 1,465    | 7,000         |
+
+**Key Finding:** Double DQN achieved the highest sustained performance with zero degradation between best and final averages. Dueling DQN reached the highest peak (647) but exhibited catastrophic forgetting in extended training — demonstrating why model checkpointing is critical in RL.
+
+---
+
```
```diff
+## Network Architecture
+
+All variants share a convolutional backbone from Mnih et al. (2015):
+
+| Layer | Filters | Kernel | Stride | Output   |
+|-------|---------|--------|--------|----------|
+| Conv1 | 32      | 8×8    | 4      | 20×20×32 |
+| Conv2 | 64      | 4×4    | 2      | 9×9×64   |
+| Conv3 | 64      | 3×3    | 1      | 7×7×64   |
+
+**Variant-specific heads:**
+- **Baseline DQN:** Standard single-stream fully connected head → Q-values
+- **Double DQN:** Same architecture, but decouples action selection (local network) from evaluation (target network) to reduce overestimation bias
+- **Dueling DQN:** Splits into Value and Advantage streams — Q(s,a) = V(s) + A(s,a) - mean(A) — allowing the network to learn state quality independently of action choice
+
+---
```
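The dueling combination in the last bullet can be sketched with NumPy. This is a minimal illustration of the formula only, not the repo's PyTorch head; `dueling_q` and its shapes are hypothetical names for this sketch:

```python
import numpy as np

def dueling_q(value: np.ndarray, advantage: np.ndarray) -> np.ndarray:
    # Q(s,a) = V(s) + A(s,a) - mean_a A(s,a); subtracting the mean advantage
    # keeps the V/A decomposition identifiable.
    return value + advantage - advantage.mean(axis=1, keepdims=True)

# One state with V(s) = 2.0 and advantages for three actions:
q = dueling_q(np.array([[2.0]]), np.array([[1.0, 0.0, -1.0]]))
# the mean advantage is 0, so q is [[3.0, 2.0, 1.0]]
```

Note the best action's Q-value equals V plus its advantage margin, while the ranking of actions is set purely by the advantage stream.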
```diff
+
+## Preprocessing Pipeline
+
+Raw Atari frames (210×160×3) are converted into CNN-ready tensors (4×84×84), reducing input dimensionality by 72% while preserving the gameplay-relevant information:
+1. Convert RGB → Grayscale
+2. Crop to the game region (rows 20:200)
+3. Resize to 84×84
+4. Stack 4 consecutive frames
+5. Normalize to [0, 1]
+
+---
```
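The five steps above can be sketched for a single frame in NumPy. This uses a crude nearest-neighbour resize to stay dependency-free; the actual pipeline presumably uses a proper interpolating resize, and `preprocess` is an illustrative name:

```python
import numpy as np

def preprocess(frame: np.ndarray) -> np.ndarray:
    """Grayscale, crop, resize, and normalize one 210x160x3 Atari frame."""
    gray = frame.mean(axis=2)              # 1. RGB -> grayscale, (210, 160)
    cropped = gray[20:200]                 # 2. crop to the game region, (180, 160)
    rows = np.arange(84) * cropped.shape[0] // 84
    cols = np.arange(84) * cropped.shape[1] // 84
    resized = cropped[rows][:, cols]       # 3. nearest-neighbour resize to (84, 84)
    return resized / 255.0                 # 5. normalize to [0, 1]

# 4. stack the last four preprocessed frames into one (4, 84, 84) state
frames = [np.random.randint(0, 256, (210, 160, 3)) for _ in range(4)]
state = np.stack([preprocess(f) for f in frames])
```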
```diff
+
+## Training Configuration
+
+| Config           | Baseline | Double | Dueling |
+|------------------|----------|--------|---------|
+| Learning Rate    | 1e-4     | 1.5e-4 | 1e-4    |
+| Epsilon Decay    | 0.9993   | 0.9992 | 0.9995  |
+| Batch Size       | 32       | 32     | 32      |
+| Gamma (Discount) | 0.99     | 0.99   | 0.99    |
+
+**Design rationale:**
+- Double DQN uses a higher learning rate because its more conservative Q-estimates can tolerate faster learning without divergence
+- Dueling DQN uses slower epsilon decay to benefit from extended exploration while its value and advantage streams are learned
+
+---
```
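The decay rates in the table can be compared with a short sketch, assuming the decay is applied once per episode with the 0.01 exploration floor mentioned in the key components; `epsilon_at` is a hypothetical helper:

```python
def epsilon_at(episode: int, decay: float,
               eps_start: float = 1.0, eps_min: float = 0.01) -> float:
    # Multiplicative per-episode decay, clipped at the exploration floor.
    return max(eps_min, eps_start * decay ** episode)

# After 3,000 episodes the slower Dueling schedule (0.9995) is still
# exploring noticeably more than the Baseline schedule (0.9993):
e_base = epsilon_at(3000, 0.9993)   # ~0.12
e_duel = epsilon_at(3000, 0.9995)   # ~0.22
```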
```diff
+
+## Key Components
+
+- **Experience Replay Buffer** (100K transitions): Breaks temporal correlation and enables sample reuse
+- **Target Network** with soft updates (τ=0.001): Stabilizes Q-value targets
+- **ε-Greedy Exploration:** Decays from 1.0 → 0.01 at variant-specific rates
+- **Gradient Clipping** (max norm 10): Prevents exploding gradients from large TD errors
+
+---
```
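The soft target update with τ=0.001 can be sketched as follows. This is a NumPy illustration over a single flat weight vector; real training code applies the same rule to every parameter tensor:

```python
import numpy as np

TAU = 0.001  # soft-update rate from the component list above

def soft_update(local_w: np.ndarray, target_w: np.ndarray,
                tau: float = TAU) -> np.ndarray:
    # theta_target <- tau * theta_local + (1 - tau) * theta_target
    return tau * local_w + (1.0 - tau) * target_w

# Each target weight moves 0.1% of the way toward the local network per step,
# which keeps the Q-value targets slowly moving and therefore stable.
target = soft_update(np.ones(3), np.zeros(3))
```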
```diff
+
+## Key Findings
+
+1. **Low learning rate is essential** — Standard 1e-3 causes divergence in Atari; 1e-4 provides stable convergence
+2. **Double DQN is the most stable** — Zero gap between best and final averages; the only variant to reach its extended target
+3. **Longer training ≠ better** — All variants showed diminishing returns or degradation past ~6,000 episodes
+4. **Checkpointing matters** — Dueling DQN's peak performance (647) was 150 points above its final average; the best policy is not always the last one
+
+---
```
```diff
+
+## References
+
+- Mnih, V. et al. (2015). [Human-level control through deep reinforcement learning](https://www.nature.com/articles/nature14236). Nature, 518(7540).
+- Van Hasselt, H. et al. (2016). [Deep Reinforcement Learning with Double Q-learning](https://arxiv.org/abs/1509.06461). AAAI.
+- Wang, Z. et al. (2016). [Dueling Network Architectures for Deep RL](https://arxiv.org/abs/1511.06581). ICML.
+
+[GitHub Repository](https://github.com/antonisbast/DQN-SpaceInvaders)
+"""
+            )
 
 if __name__ == "__main__":
     demo.launch()
```
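The Double DQN decoupling described in the architecture notes (the local network selects the next action, the target network evaluates it) can be sketched as follows, with γ=0.99 from the configuration table and hypothetical Q-values:

```python
import numpy as np

GAMMA = 0.99  # discount factor from the training configuration

def double_dqn_targets(rewards, dones, q_local_next, q_target_next, gamma=GAMMA):
    # The local network picks the greedy next action...
    best_actions = q_local_next.argmax(axis=1)
    # ...but the target network supplies its value, reducing overestimation bias.
    evaluated = q_target_next[np.arange(len(rewards)), best_actions]
    return rewards + gamma * evaluated * (1.0 - dones)

# One non-terminal transition with reward 1.0 and two actions:
y = double_dqn_targets(np.array([1.0]), np.array([0.0]),
                       q_local_next=np.array([[0.2, 0.9]]),
                       q_target_next=np.array([[0.7, 0.5]]))
# local net picks action 1; target net values it at 0.5 -> y = 1.0 + 0.99 * 0.5 = 1.495
```

A vanilla DQN would instead take `q_target_next.max(axis=1)` (0.7 here), which is exactly the overestimation the decoupling avoids.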