# Source: Hugging Face Space by antonisbast — "Add training information tab to frontend" (commit af63b78)
"""
DQN Space Invaders Demo โ€” Watch trained agents play Atari Space Invaders.
Hugging Face Spaces deployment with Gradio interface.
"""
import os
import gradio as gr
import gymnasium as gym
import ale_py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import imageio
import tempfile
# Register Atari environments so gym.make("ALE/...") ids resolve (used below
# by play_game).
gym.register_envs(ale_py)
# Inference-only demo; CPU — presumably so the Space runs without a GPU.
device = torch.device("cpu")
# ─── Network Architectures ──────────────────────────────────────────
class QNetwork(nn.Module):
    """Vanilla DQN: three conv layers followed by two fully connected layers.

    Attribute names (conv1..fc2) are part of the checkpoint contract —
    renaming them would break ``load_state_dict`` on saved weights.
    """

    def __init__(self, action_size, seed=42, frame_stack=4, frame_size=84):
        super().__init__()
        # Seeding here makes weight initialisation reproducible.
        self.seed = torch.manual_seed(seed)
        self.conv1 = nn.Conv2d(frame_stack, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        flat_features = self._conv_output_numel(frame_stack, frame_size)
        self.fc1 = nn.Linear(flat_features, 512)
        self.fc2 = nn.Linear(512, action_size)

    def _conv_output_numel(self, channels, size):
        """Pass a zero tensor through the conv stack to size the first FC layer."""
        probe = torch.zeros(1, channels, size, size)
        for conv in (self.conv1, self.conv2, self.conv3):
            probe = F.relu(conv(probe))
        return probe.numel()

    def forward(self, state):
        """Map a (batch, frame_stack, H, W) state to per-action Q-values."""
        features = state
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))
        features = torch.flatten(features, start_dim=1)
        return self.fc2(F.relu(self.fc1(features)))
class DuelingQNetwork(nn.Module):
    """Dueling DQN: shared conv trunk with separate value and advantage heads.

    Q(s, a) = V(s) + A(s, a) - mean_a A(s, a).
    Attribute names are part of the checkpoint contract — do not rename.
    """

    def __init__(self, action_size, seed=42, frame_stack=4, frame_size=84):
        super().__init__()
        # Seeding here makes weight initialisation reproducible.
        self.seed = torch.manual_seed(seed)
        self.action_size = action_size
        self.conv1 = nn.Conv2d(frame_stack, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        flat_features = self._conv_output_numel(frame_stack, frame_size)
        # Value head: a single scalar V(s).
        self.value_fc = nn.Linear(flat_features, 512)
        self.value = nn.Linear(512, 1)
        # Advantage head: one A(s, a) per action.
        self.advantage_fc = nn.Linear(flat_features, 512)
        self.advantage = nn.Linear(512, action_size)

    def _conv_output_numel(self, channels, size):
        """Pass a zero tensor through the conv stack to size the FC heads."""
        probe = torch.zeros(1, channels, size, size)
        for conv in (self.conv1, self.conv2, self.conv3):
            probe = F.relu(conv(probe))
        return probe.numel()

    def forward(self, state):
        """Map a (batch, frame_stack, H, W) state to per-action Q-values."""
        features = state
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))
        features = torch.flatten(features, start_dim=1)
        v = self.value(F.relu(self.value_fc(features)))
        a = self.advantage(F.relu(self.advantage_fc(features)))
        # Subtracting the mean advantage makes the V/A decomposition identifiable.
        return v + a - a.mean(dim=1, keepdim=True)
# ─── Preprocessing ──────────────────────────────────────────────────
def preprocess_frame(frame, size=84):
    """Convert an RGB Atari frame into a normalised grayscale square.

    Steps: RGB -> grayscale, crop rows 20:200 (the play field), resize to
    size x size with area interpolation, scale to [0, 1] floats.
    """
    play_area = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)[20:200, :]
    shrunk = cv2.resize(play_area, (size, size), interpolation=cv2.INTER_AREA)
    return shrunk / 255.0
class FrameStack:
    """Maintains the most recent ``num_frames`` preprocessed frames as the state."""

    def __init__(self, num_frames=4, frame_size=84):
        self.num_frames = num_frames
        self.frame_size = frame_size
        self.frames = []

    def reset(self, frame):
        """Start a new episode: fill the stack with copies of the first frame."""
        first = preprocess_frame(frame, self.frame_size)
        self.frames = [first for _ in range(self.num_frames)]
        return self._get_state()

    def step(self, frame):
        """Push the newest frame and drop anything older than the window."""
        self.frames.append(preprocess_frame(frame, self.frame_size))
        del self.frames[:-self.num_frames]
        return self._get_state()

    def _get_state(self):
        # (num_frames, frame_size, frame_size) float32 — the network input layout.
        return np.asarray(self.frames, dtype=np.float32)
# ─── Model Loading ──────────────────────────────────────────────────
# Maps UI dropdown labels -> checkpoint path and architecture flag.
# NOTE: the label strings double as the dropdown choices/default in the UI,
# so renaming a key requires updating the Gradio default value too. The
# stray characters in the Double DQN key are mojibake for a star emoji —
# kept as-is because the string is matched verbatim elsewhere.
CHECKPOINTS = {
    "Baseline DQN (avg: 524.75)": {
        "file": "checkpoints/Baseline_DQN_best_checkpoint.pth",
        "dueling": False,
    },
    "Double DQN (avg: 650.20) โญ Best": {
        "file": "checkpoints/Double_DQN_best_checkpoint.pth",
        "dueling": False,
    },
    "Dueling DQN (avg: 497.55)": {
        "file": "checkpoints/Dueling_DQN_best_checkpoint.pth",
        "dueling": True,
    },
}
def load_model(variant_name):
    """Build the architecture for *variant_name* and load its checkpoint.

    Returns the network in eval mode on the module-level ``device``.
    """
    cfg = CHECKPOINTS[variant_name]
    net_cls = DuelingQNetwork if cfg["dueling"] else QNetwork
    net = net_cls(action_size=6, seed=42).to(device)
    # weights_only=True restricts torch.load to tensors (safe deserialisation).
    weights = torch.load(cfg["file"], map_location=device, weights_only=True)
    net.load_state_dict(weights)
    net.eval()
    return net
# ─── Game Runner ────────────────────────────────────────────────────
# Human-readable names for the 6 SpaceInvaders actions, indexed by action id.
ACTION_NAMES = ["NOOP", "FIRE", "RIGHT", "LEFT", "RIGHTFIRE", "LEFTFIRE"]


def play_game(variant_name, seed=42, max_steps=3000, fps=30):
    """Run one full game with a greedy policy and return the recording.

    Args:
        variant_name: key into CHECKPOINTS selecting the trained agent.
        seed: environment reset seed (same seed -> same game).
        max_steps: hard cap on environment steps.
        fps: output video frame rate; also the number of freeze frames
            appended when the episode ends (~1 second hold).

    Returns:
        Tuple of (path to an .mp4 file, markdown summary string).
    """
    model = load_model(variant_name)
    env = gym.make("ALE/SpaceInvaders-v5", render_mode="rgb_array")
    frame_stack = FrameStack(4, 84)
    obs, info = env.reset(seed=seed)
    state = frame_stack.reset(obs)
    frames = []
    total_reward = 0
    step_count = 0
    while step_count < max_steps:
        # Greedy action selection (no exploration).
        state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(device)
        with torch.no_grad():
            q_values = model(state_tensor)
        action = q_values.argmax(1).item()
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        state = frame_stack.step(obs)
        total_reward += reward
        step_count += 1
        # Capture the rendered frame with a HUD overlay.
        frame = env.render()
        frame_with_hud = add_hud(frame.copy(), total_reward, step_count,
                                 ACTION_NAMES[action], variant_name)
        frames.append(frame_with_hud)
        if done:
            # Hold the final frame so the end state stays visible.
            frames.extend([frame_with_hud] * fps)
            break
    env.close()
    # Write the video. Close the NamedTemporaryFile handle BEFORE imageio
    # opens the same path: the original code left it open, which leaks a
    # file descriptor and fails on Windows, where an open temp file cannot
    # be reopened by another writer.
    tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    tmp.close()
    writer = imageio.get_writer(tmp.name, fps=fps, quality=8)
    try:
        for f in frames:
            writer.append_data(f)
    finally:
        # Always release the writer, even if encoding a frame fails.
        writer.close()
    return tmp.name, f"**Score: {int(total_reward)}** | Steps: {step_count}"
def add_hud(frame, score, steps, action, variant):
    """Overlay score, current action, and step count along the frame's bottom.

    Returns a new frame; font/bar sizes scale with the frame width so the
    HUD stays legible at any render resolution.
    """
    h, w = frame.shape[:2]
    scale = max(1, w // 160)
    bar_h = 18 * scale
    # Darken a bottom bar by blending a filled black rectangle over it.
    shaded = frame.copy()
    cv2.rectangle(shaded, (0, h - bar_h), (w, h), (0, 0, 0), -1)
    frame = cv2.addWeighted(shaded, 0.7, frame, 0.3, 0)
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.35 * scale
    thickness = max(1, scale // 2)
    baseline_y = h - 5 * scale
    # (text, x position, BGR colour) for each HUD element, left to right.
    labels = [
        (f"Score: {int(score)}", 4, (0, 255, 0)),
        (f"{action}", w // 2 - 15 * scale, (255, 255, 0)),
        (f"Step: {steps}", w - 60 * scale, (200, 200, 200)),
    ]
    for text, x, color in labels:
        cv2.putText(frame, text, (x, baseline_y), font, font_scale, color, thickness)
    return frame
# ─── Gradio Interface ───────────────────────────────────────────────
def run_demo(variant, seed):
    """Gradio click handler: play one game for the chosen agent and seed.

    Returns play_game's (video_path, summary_markdown) pair, matching the
    click() outputs [video_output, result_text]. The seed arrives from a
    gr.Number component, so it is coerced to int first.
    """
    return play_game(variant, seed=int(seed))
# Build the UI at import time (standard Gradio Blocks pattern); `demo` is the
# app object launched below / served by Spaces.
# NOTE(review): many string literals below contain mojibake (e.g. "โ€”" for an
# em dash, "โญ" for a star emoji) from a bad encoding round-trip. They are
# kept verbatim here because the starred label must match its CHECKPOINTS key
# exactly; fixing the encoding needs a coordinated change.
with gr.Blocks(
    title="DQN Space Invaders",
    theme=gr.themes.Base(
        primary_hue="blue",
        neutral_hue="slate",
    ),
) as demo:
    # Page header.
    gr.Markdown(
        """
        # ๐Ÿ•น๏ธ Deep Q-Network โ€” Space Invaders
        Watch trained DQN agents play Atari Space Invaders in real-time.
        Three variants trained from raw pixels using PyTorch.
        """
    )
    with gr.Tabs():
        with gr.TabItem("๐ŸŽฎ Play Game"):
            with gr.Row():
                # Left column: agent selection, seed, play button, summary.
                with gr.Column(scale=1):
                    variant_dropdown = gr.Dropdown(
                        choices=list(CHECKPOINTS.keys()),
                        value="Double DQN (avg: 650.20) โญ Best",
                        label="Agent Variant",
                    )
                    seed_input = gr.Number(
                        value=42,
                        label="Random Seed",
                        info="Change for a different game",
                        precision=0,
                    )
                    play_btn = gr.Button("โ–ถ Play Game", variant="primary", size="lg")
                    # Filled with the score/steps summary after each run.
                    result_text = gr.Markdown("")
                    gr.Markdown(
                        """
                        ---
                        **About the agents:**
                        - **Baseline DQN** โ€” Standard architecture
                        - **Double DQN** โญ โ€” Reduces Q-value overestimation
                        - **Dueling DQN** โ€” Separates state value from action advantage
                        """
                    )
                # Right column: the rendered gameplay video.
                with gr.Column(scale=2):
                    video_output = gr.Video(label="Gameplay", autoplay=True)
            # Wire the button to the game runner.
            play_btn.click(
                fn=run_demo,
                inputs=[variant_dropdown, seed_input],
                outputs=[video_output, result_text],
            )
        with gr.TabItem("๐Ÿ“Š Training Info"):
            # Static training write-up.
            # NOTE(review): in the results table the "Training Time" values
            # (7,000 / 6,090) carry no units and look like episode or step
            # counts rather than durations - confirm against training logs.
            gr.Markdown(
                """
                ## Training Results
                All three variants exceeded 490+ average score over 100 consecutive episodes.
                | Variant | Avg Score | Best Score | Episodes | Training Time |
                |---------|-----------|-----------|----------|---------------|
                | Baseline DQN | 524.75 | 586.45 | 1,470 | 7,000 |
                | Double DQN | 650.20 | 650.20 | 1,355 | 6,090 |
                | Dueling DQN | 497.55 | 647.05 | 1,465 | 7,000 |
                **Key Finding:** Double DQN achieved the highest sustained performance with zero degradation between best and final averages. Dueling DQN reached the highest peak (647) but exhibited catastrophic forgetting in extended training โ€” demonstrating why model checkpointing is critical in RL.
                ---
                ## Network Architecture
                All variants share a convolutional backbone from Mnih et al. (2015):
                | Layer | Filters | Kernel | Stride | Output |
                |-------|---------|--------|--------|--------|
                | Conv1 | 32 | 8ร—8 | 4 | 20ร—20ร—32 |
                | Conv2 | 64 | 4ร—4 | 2 | 9ร—9ร—64 |
                | Conv3 | 64 | 3ร—3 | 1 | 7ร—7ร—64 |
                **Variant-specific heads:**
                - **Baseline DQN:** Standard single-stream fully connected head โ†’ Q-values
                - **Double DQN:** Same architecture, but decouples action selection (local network) from evaluation (target network) to reduce overestimation bias
                - **Dueling DQN:** Splits into Value and Advantage streams โ€” Q(s,a) = V(s) + A(s,a) - mean(A) โ€” allowing the network to learn state quality independently of action choice
                ---
                ## Preprocessing Pipeline
                Raw Atari frames (210ร—160ร—3) are converted into CNN-ready tensors (4ร—84ร—84), reducing input dimensionality by 72% while preserving all gameplay information:
                1. Convert RGB โ†’ Grayscale
                2. Crop to game region (20:200)
                3. Resize to 84ร—84
                4. Stack 4 consecutive frames
                5. Normalize to [0, 1]
                ---
                ## Training Configuration
                | Config | Baseline | Double | Dueling |
                |--------|----------|--------|---------|
                | Learning Rate | 1e-4 | 1.5e-4 | 1e-4 |
                | Epsilon Decay | 0.9993 | 0.9992 | 0.9995 |
                | Batch Size | 32 | 32 | 32 |
                | Gamma (Discount) | 0.99 | 0.99 | 0.99 |
                **Design rationale:**
                - Double DQN uses a higher learning rate because its more conservative Q-estimates can tolerate faster learning without divergence
                - Dueling DQN uses slower epsilon decay to benefit from extended exploration during value/advantage stream learning
                ---
                ## Key Components
                - **Experience Replay Buffer** (100K transitions): Breaks temporal correlation and enables sample reuse
                - **Target Network** with soft updates (ฯ„=0.001): Stabilizes Q-value targets
                - **ฮต-Greedy Exploration:** Decays from 1.0 โ†’ 0.01 at variant-specific rates
                - **Gradient Clipping** (max norm 10): Prevents exploding gradients from large TD errors
                ---
                ## Key Findings
                1. **Low learning rate is essential** โ€” Standard 1e-3 causes divergence in Atari; 1e-4 provides stable convergence
                2. **Double DQN is the most stable** โ€” Zero gap between best and final averages; the only variant to reach its extended target
                3. **Longer training โ‰  better** โ€” All variants showed diminishing returns or degradation past ~6,000 episodes
                4. **Checkpointing matters** โ€” Dueling DQN's peak performance (647) was 150 points above its final average; the best policy is not always the last one
                ---
                ## References
                - Mnih, V. et al. (2015). [Human-level control through deep reinforcement learning](https://www.nature.com/articles/nature14236). Nature, 518(7540).
                - Van Hasselt, H. et al. (2016). [Deep Reinforcement Learning with Double Q-learning](https://arxiv.org/abs/1509.06461). AAAI.
                - Wang, Z. et al. (2016). [Dueling Network Architectures for Deep RL](https://arxiv.org/abs/1511.06581). ICML.
                [GitHub Repository](https://github.com/antonisbast/DQN-SpaceInvaders)
                """
            )


# Local entry point; on HF Spaces the `demo` object is served directly.
if __name__ == "__main__":
    demo.launch()