"""
DQN Space Invaders Demo — Watch trained agents play Atari Space Invaders.
Hugging Face Spaces deployment with Gradio interface.
"""

import functools
import os
import tempfile
from collections import deque

import ale_py
import cv2
import gradio as gr
import gymnasium as gym
import imageio
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Register Atari environments
gym.register_envs(ale_py)

device = torch.device("cpu")

DEFAULT_SEED = 42


# ─── Network Architectures ────────────────────────────────────

class QNetwork(nn.Module):
    """Standard DQN CNN architecture.

    Three convolutional layers followed by a single fully connected head
    that emits one Q-value per action.

    Args:
        action_size: Number of discrete actions (size of the output layer).
        seed: Value passed to ``torch.manual_seed`` before layers are built.
        frame_stack: Number of stacked frames, i.e. input channels.
        frame_size: Height and width of each (square) input frame.
    """

    def __init__(self, action_size, seed=42, frame_stack=4, frame_size=84):
        super().__init__()
        self.seed = torch.manual_seed(seed)

        self.conv1 = nn.Conv2d(frame_stack, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        conv_out_size = self._get_conv_out_size(frame_stack, frame_size)
        self.fc1 = nn.Linear(conv_out_size, 512)
        self.fc2 = nn.Linear(512, action_size)

    def _conv_features(self, x):
        """Apply the three conv layers, each followed by ReLU."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return F.relu(self.conv3(x))

    def _get_conv_out_size(self, channels, size):
        """Return the flattened conv output size for one ``size``×``size`` input."""
        # Shape probe only: no need to record an autograd graph for it.
        with torch.no_grad():
            probe = self._conv_features(torch.zeros(1, channels, size, size))
        return int(probe.numel())

    def forward(self, state):
        """Map a batch of stacked frames to a batch of per-action Q-values."""
        x = self._conv_features(state)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


class DuelingQNetwork(nn.Module):
    """Dueling DQN with separate value and advantage streams.

    Shares the convolutional layout of ``QNetwork`` and combines the two
    streams as ``V + A - mean(A)``.

    Args:
        action_size: Number of discrete actions (width of the advantage head).
        seed: Value passed to ``torch.manual_seed`` before layers are built.
        frame_stack: Number of stacked frames, i.e. input channels.
        frame_size: Height and width of each (square) input frame.
    """

    def __init__(self, action_size, seed=42, frame_stack=4, frame_size=84):
        super().__init__()
        self.seed = torch.manual_seed(seed)
        self.action_size = action_size

        self.conv1 = nn.Conv2d(frame_stack, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        conv_out_size = self._get_conv_out_size(frame_stack, frame_size)
        self.value_fc = nn.Linear(conv_out_size, 512)
        self.value = nn.Linear(512, 1)
        self.advantage_fc = nn.Linear(conv_out_size, 512)
        self.advantage = nn.Linear(512, action_size)

    def _conv_features(self, x):
        """Apply the three conv layers, each followed by ReLU."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return F.relu(self.conv3(x))

    def _get_conv_out_size(self, channels, size):
        """Return the flattened conv output size for one ``size``×``size`` input."""
        # Shape probe only: no need to record an autograd graph for it.
        with torch.no_grad():
            probe = self._conv_features(torch.zeros(1, channels, size, size))
        return int(probe.numel())

    def forward(self, state):
        """Map a batch of stacked frames to a batch of per-action Q-values."""
        x = self._conv_features(state)
        x = x.view(x.size(0), -1)
        value = self.value(F.relu(self.value_fc(x)))
        advantage = self.advantage(F.relu(self.advantage_fc(x)))
        # Subtract the mean advantage so the value/advantage split is unique.
        return value + advantage - advantage.mean(dim=1, keepdim=True)


# ─── Preprocessing ────────────────────────────────────────────

def preprocess_frame(frame, size=84):
    """Convert an RGB frame to a cropped, resized grayscale image in [0, 1].

    Args:
        frame: RGB image array as returned by the environment.
        size: Side length of the square output image.

    Returns:
        A ``size``×``size`` float array with values scaled by 1/255.
    """
    gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    cropped = gray[20:200, :]  # keep rows 20..199 only
    resized = cv2.resize(cropped, (size, size), interpolation=cv2.INTER_AREA)
    return resized / 255.0


class FrameStack:
    """Rolling window of the most recent preprocessed frames.

    Args:
        num_frames: How many frames make up one state.
        frame_size: Side length passed to ``preprocess_frame``.
    """

    def __init__(self, num_frames=4, frame_size=84):
        self.num_frames = num_frames
        self.frame_size = frame_size
        # Bounded deque: appending past ``num_frames`` drops the oldest frame.
        self.frames = deque(maxlen=num_frames)

    def reset(self, frame):
        """Fill the window with copies of ``frame`` and return the state."""
        processed = preprocess_frame(frame, self.frame_size)
        self.frames.clear()
        self.frames.extend([processed] * self.num_frames)
        return self._get_state()

    def step(self, frame):
        """Push ``frame`` into the window and return the updated state."""
        if not self.frames:
            # No reset() yet: a partially filled window would yield a state
            # with the wrong number of channels, so start from a full one.
            return self.reset(frame)
        self.frames.append(preprocess_frame(frame, self.frame_size))
        return self._get_state()

    def _get_state(self):
        """Return the window as a float32 array, oldest frame first."""
        return np.array(list(self.frames), dtype=np.float32)


# ─── Model Loading ────────────────────────────────────────────

ACTION_NAMES = ["NOOP", "FIRE", "RIGHT", "LEFT", "RIGHTFIRE", "LEFTFIRE"]

CHECKPOINTS = {
    "Baseline DQN (avg: 524.75)": {
        "file": "checkpoints/Baseline_DQN_best_checkpoint.pth",
        "dueling": False,
    },
    "Double DQN (avg: 650.20) ⭐ Best": {
        "file": "checkpoints/Double_DQN_best_checkpoint.pth",
        "dueling": False,
    },
    "Dueling DQN (avg: 497.55)": {
        "file": "checkpoints/Dueling_DQN_best_checkpoint.pth",
        "dueling": True,
    },
}


@functools.lru_cache(maxsize=None)
def load_model(variant_name):
    """Build the network for ``variant_name`` and load its checkpoint.

    The result is cached per variant so repeated games do not re-read the
    checkpoint from disk; the returned model is in eval mode and is only
    used for inference.

    Args:
        variant_name: A key of ``CHECKPOINTS``.

    Returns:
        The loaded model on ``device``, in eval mode.

    Raises:
        KeyError: If ``variant_name`` is not a key of ``CHECKPOINTS``.
    """
    config = CHECKPOINTS[variant_name]
    network_cls = DuelingQNetwork if config["dueling"] else QNetwork
    model = network_cls(action_size=len(ACTION_NAMES), seed=42).to(device)
    model.load_state_dict(
        torch.load(config["file"], map_location=device, weights_only=True)
    )
    model.eval()
    return model


# ─── Game Runner ──────────────────────────────────────────────

def play_game(variant_name, seed=42, max_steps=3000, fps=30):
    """Run one full game and return a video file path.

    Frames are encoded as they are produced instead of being buffered, so
    memory use does not grow with the length of the episode.

    Args:
        variant_name: A key of ``CHECKPOINTS`` selecting the agent.
        seed: Seed passed to ``env.reset``.
        max_steps: Upper bound on environment steps.
        fps: Frame rate of the output video; also the number of freeze
            frames appended when the episode ends.

    Returns:
        ``(video_path, summary_markdown)``.
    """
    model = load_model(variant_name)

    # Reserve a path and release the handle right away; the video writer
    # reopens the file by name.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        video_path = tmp.name

    total_reward = 0
    step_count = 0
    env = gym.make("ALE/SpaceInvaders-v5", render_mode="rgb_array")
    try:
        frame_stack = FrameStack(4, 84)
        obs, _info = env.reset(seed=seed)
        state = frame_stack.reset(obs)

        with imageio.get_writer(video_path, fps=fps, quality=8) as writer:
            while step_count < max_steps:
                # Greedy action selection (no exploration)
                state_tensor = (
                    torch.from_numpy(state).float().unsqueeze(0).to(device)
                )
                with torch.no_grad():
                    action = model(state_tensor).argmax(1).item()

                obs, reward, terminated, truncated, _info = env.step(action)
                state = frame_stack.step(obs)
                total_reward += reward
                step_count += 1

                # Capture rendered frame with HUD overlay
                frame_with_hud = add_hud(
                    env.render(),
                    total_reward,
                    step_count,
                    ACTION_NAMES[action],
                    variant_name,
                )
                writer.append_data(frame_with_hud)

                if terminated or truncated:
                    # Add a few freeze frames at the end
                    for _ in range(fps):
                        writer.append_data(frame_with_hud)
                    break
    except Exception:
        # Do not leave a partial video behind when the run fails.
        try:
            os.remove(video_path)
        except OSError:
            pass
        raise
    finally:
        env.close()

    return video_path, f"**Score: {int(total_reward)}** | Steps: {step_count}"


def add_hud(frame, score, steps, action, variant):
    """Add score/action overlay to the frame.

    The input frame is not modified; a new image is returned.

    Args:
        frame: Rendered RGB frame.
        score: Running score (shown as an integer).
        steps: Current step count.
        action: Name of the action taken this step.
        variant: Agent variant name (currently not drawn).

    Returns:
        A new frame with a darkened bottom bar and the text overlay.
    """
    h, w = frame.shape[:2]
    scale = max(1, w // 160)

    # Semi-transparent bar at bottom
    overlay = frame.copy()
    bar_h = 18 * scale
    cv2.rectangle(overlay, (0, h - bar_h), (w, h), (0, 0, 0), -1)
    frame = cv2.addWeighted(overlay, 0.7, frame, 0.3, 0)

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.35 * scale
    thickness = max(1, scale // 2)
    baseline_y = h - 5 * scale

    cv2.putText(frame, f"Score: {int(score)}", (4, baseline_y),
                font, font_scale, (0, 255, 0), thickness)
    cv2.putText(frame, f"{action}", (w // 2 - 15 * scale, baseline_y),
                font, font_scale, (255, 255, 0), thickness)
    cv2.putText(frame, f"Step: {steps}", (w - 60 * scale, baseline_y),
                font, font_scale, (200, 200, 200), thickness)
    return frame


# ─── Gradio Interface ─────────────────────────────────────────

def run_demo(variant, seed):
    """Gradio click handler: play one game and return video + summary.

    Args:
        variant: Selected key of ``CHECKPOINTS``.
        seed: Value of the seed field; ``None`` when the field is empty.

    Returns:
        ``(video_path, summary_markdown)`` from ``play_game``.

    Raises:
        gr.Error: If the seed is negative.
    """
    # An emptied Number field arrives as None; fall back to the default.
    seed = DEFAULT_SEED if seed is None else int(seed)
    if seed < 0:
        raise gr.Error("Seed must be a non-negative integer.")
    return play_game(variant, seed=seed)


_INTRO_MD = """\
# 🕹️ Deep Q-Network — Space Invaders
Watch trained DQN agents play Atari Space Invaders in real-time.
Three variants trained from raw pixels using PyTorch.
"""

_ABOUT_AGENTS_MD = """\
---

**About the agents:**
- **Baseline DQN** — Standard architecture
- **Double DQN** ⭐ — Reduces Q-value overestimation
- **Dueling DQN** — Separates state value from action advantage
"""

_TRAINING_INFO_MD = """\
## Training Results

All three variants exceeded 490+ average score over 100 consecutive episodes.

| Variant | Avg Score | Best Score | Episodes | Training Time |
|---------|-----------|-----------|----------|---------------|
| Baseline DQN | 524.75 | 586.45 | 1,470 | 7,000 |
| Double DQN | 650.20 | 650.20 | 1,355 | 6,090 |
| Dueling DQN | 497.55 | 647.05 | 1,465 | 7,000 |

**Key Finding:** Double DQN achieved the highest sustained performance with zero degradation between best and final averages. Dueling DQN reached the highest peak (647) but exhibited catastrophic forgetting in extended training — demonstrating why model checkpointing is critical in RL.

---

## Network Architecture

All variants share a convolutional backbone from Mnih et al. (2015):

| Layer | Filters | Kernel | Stride | Output |
|-------|---------|--------|--------|--------|
| Conv1 | 32 | 8×8 | 4 | 20×20×32 |
| Conv2 | 64 | 4×4 | 2 | 9×9×64 |
| Conv3 | 64 | 3×3 | 1 | 7×7×64 |

**Variant-specific heads:**
- **Baseline DQN:** Standard single-stream fully connected head → Q-values
- **Double DQN:** Same architecture, but decouples action selection (local network) from evaluation (target network) to reduce overestimation bias
- **Dueling DQN:** Splits into Value and Advantage streams — Q(s,a) = V(s) + A(s,a) - mean(A) — allowing the network to learn state quality independently of action choice

---

## Preprocessing Pipeline

Raw Atari frames (210×160×3) are converted into CNN-ready tensors (4×84×84), reducing input dimensionality by 72% while preserving all gameplay information:

1. Convert RGB → Grayscale
2. Crop to game region (20:200)
3. Resize to 84×84
4. Stack 4 consecutive frames
5. Normalize to [0, 1]

---

## Training Configuration

| Config | Baseline | Double | Dueling |
|--------|----------|--------|---------|
| Learning Rate | 1e-4 | 1.5e-4 | 1e-4 |
| Epsilon Decay | 0.9993 | 0.9992 | 0.9995 |
| Batch Size | 32 | 32 | 32 |
| Gamma (Discount) | 0.99 | 0.99 | 0.99 |

**Design rationale:**
- Double DQN uses a higher learning rate because its more conservative Q-estimates can tolerate faster learning without divergence
- Dueling DQN uses slower epsilon decay to benefit from extended exploration during value/advantage stream learning

---

## Key Components

- **Experience Replay Buffer** (100K transitions): Breaks temporal correlation and enables sample reuse
- **Target Network** with soft updates (τ=0.001): Stabilizes Q-value targets
- **ε-Greedy Exploration:** Decays from 1.0 → 0.01 at variant-specific rates
- **Gradient Clipping** (max norm 10): Prevents exploding gradients from large TD errors

---

## Key Findings

1. **Low learning rate is essential** — Standard 1e-3 causes divergence in Atari; 1e-4 provides stable convergence
2. **Double DQN is the most stable** — Zero gap between best and final averages; the only variant to reach its extended target
3. **Longer training ≠ better** — All variants showed diminishing returns or degradation past ~6,000 episodes
4. **Checkpointing matters** — Dueling DQN's peak performance (647) was 150 points above its final average; the best policy is not always the last one

---

## References

- Mnih, V. et al. (2015). [Human-level control through deep reinforcement learning](https://www.nature.com/articles/nature14236). Nature, 518(7540).
- Van Hasselt, H. et al. (2016). [Deep Reinforcement Learning with Double Q-learning](https://arxiv.org/abs/1509.06461). AAAI.
- Wang, Z. et al. (2016). [Dueling Network Architectures for Deep RL](https://arxiv.org/abs/1511.06581). ICML.

[GitHub Repository](https://github.com/antonisbast/DQN-SpaceInvaders)
"""


with gr.Blocks(
    title="DQN Space Invaders",
    theme=gr.themes.Base(
        primary_hue="blue",
        neutral_hue="slate",
    ),
) as demo:
    gr.Markdown(_INTRO_MD)

    with gr.Tabs():
        with gr.TabItem("🎮 Play Game"):
            with gr.Row():
                with gr.Column(scale=1):
                    variant_dropdown = gr.Dropdown(
                        choices=list(CHECKPOINTS.keys()),
                        value="Double DQN (avg: 650.20) ⭐ Best",
                        label="Agent Variant",
                    )
                    seed_input = gr.Number(
                        value=42,
                        label="Random Seed",
                        info="Change for a different game",
                        precision=0,
                    )
                    play_btn = gr.Button(
                        "▶ Play Game", variant="primary", size="lg"
                    )
                    result_text = gr.Markdown("")
                    gr.Markdown(_ABOUT_AGENTS_MD)

                with gr.Column(scale=2):
                    video_output = gr.Video(label="Gameplay", autoplay=True)

            play_btn.click(
                fn=run_demo,
                inputs=[variant_dropdown, seed_input],
                outputs=[video_output, result_text],
            )

        with gr.TabItem("📊 Training Info"):
            gr.Markdown(_TRAINING_INFO_MD)


if __name__ == "__main__":
    demo.launch()