Spaces:
Sleeping
Sleeping
| """ | |
| DQN Space Invaders Demo โ Watch trained agents play Atari Space Invaders. | |
| Hugging Face Spaces deployment with Gradio interface. | |
| """ | |
| import os | |
| import gradio as gr | |
| import gymnasium as gym | |
| import ale_py | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import cv2 | |
| import imageio | |
| import tempfile | |
| # Register Atari environments | |
| gym.register_envs(ale_py) | |
| device = torch.device("cpu") | |
| # โโโ Network Architectures โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
class QNetwork(nn.Module):
    """Standard DQN convolutional network (Mnih et al., 2015 layout).

    Maps a stack of grayscale frames to one Q-value per action.

    NOTE: attribute names (conv1..conv3, fc1, fc2) and the order in which the
    layers are created are part of the checkpoint contract — the saved
    state_dicts key on these names, and ``torch.manual_seed`` makes the
    initialization reproducible only if the parameter creation order is fixed.
    """

    def __init__(self, action_size, seed=42, frame_stack=4, frame_size=84):
        super(QNetwork, self).__init__()
        # Seed the global RNG so weight initialization is reproducible.
        self.seed = torch.manual_seed(seed)
        self.conv1 = nn.Conv2d(frame_stack, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        flat_features = self._get_conv_out_size(frame_stack, frame_size)
        self.fc1 = nn.Linear(flat_features, 512)
        self.fc2 = nn.Linear(512, action_size)

    def _get_conv_out_size(self, channels, size):
        """Count the features the conv trunk yields for a single input."""
        probe = torch.zeros(1, channels, size, size)
        return self._features(probe).numel()

    def _features(self, x):
        """Run the shared convolutional trunk with ReLU activations."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        return x

    def forward(self, state):
        """Return Q-values of shape (batch, action_size) for *state*."""
        flat = torch.flatten(self._features(state), start_dim=1)
        return self.fc2(F.relu(self.fc1(flat)))
class DuelingQNetwork(nn.Module):
    """Dueling DQN (Wang et al., 2016): Q(s,a) = V(s) + A(s,a) - mean_a A(s,a).

    Shares the same convolutional trunk as QNetwork but splits the head into
    a scalar value stream and a per-action advantage stream.

    NOTE: attribute names and layer creation order are fixed by the saved
    checkpoint state_dicts and by the seeded, order-dependent initialization.
    """

    def __init__(self, action_size, seed=42, frame_stack=4, frame_size=84):
        super(DuelingQNetwork, self).__init__()
        # Seed the global RNG so weight initialization is reproducible.
        self.seed = torch.manual_seed(seed)
        self.action_size = action_size
        self.conv1 = nn.Conv2d(frame_stack, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        flat_features = self._get_conv_out_size(frame_stack, frame_size)
        # Value stream: scalar state value V(s).
        self.value_fc = nn.Linear(flat_features, 512)
        self.value = nn.Linear(512, 1)
        # Advantage stream: per-action advantage A(s, a).
        self.advantage_fc = nn.Linear(flat_features, 512)
        self.advantage = nn.Linear(512, action_size)

    def _get_conv_out_size(self, channels, size):
        """Count the features the conv trunk yields for a single input."""
        probe = torch.zeros(1, channels, size, size)
        return self._features(probe).numel()

    def _features(self, x):
        """Run the shared convolutional trunk with ReLU activations."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        return x

    def forward(self, state):
        """Return Q-values of shape (batch, action_size) for *state*."""
        flat = torch.flatten(self._features(state), start_dim=1)
        v = self.value(F.relu(self.value_fc(flat)))
        a = self.advantage(F.relu(self.advantage_fc(flat)))
        # Centering the advantages makes the V/A decomposition identifiable.
        return v + a - a.mean(dim=1, keepdim=True)
| # โโโ Preprocessing โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
def preprocess_frame(frame, size=84):
    """Turn a raw RGB Atari frame into a normalized grayscale square.

    Pipeline: RGB -> grayscale, crop rows 20:200 (the play area), resize to
    (size, size) with area interpolation, then scale pixels into [0, 1].
    """
    shrunk = cv2.resize(
        cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)[20:200, :],
        (size, size),
        interpolation=cv2.INTER_AREA,
    )
    return shrunk / 255.0
class FrameStack:
    """Keeps the most recent ``num_frames`` preprocessed frames as the state."""

    def __init__(self, num_frames=4, frame_size=84):
        self.num_frames = num_frames
        self.frame_size = frame_size
        self.frames = []

    def reset(self, frame):
        """Start an episode: fill the stack with copies of the first frame."""
        first = preprocess_frame(frame, self.frame_size)
        self.frames = [first for _ in range(self.num_frames)]
        return self._get_state()

    def step(self, frame):
        """Push a new frame, evicting the oldest once the stack is full."""
        if len(self.frames) >= self.num_frames:
            del self.frames[0]
        self.frames.append(preprocess_frame(frame, self.frame_size))
        return self._get_state()

    def _get_state(self):
        # (num_frames, frame_size, frame_size) float32, ready for torch.
        return np.asarray(self.frames, dtype=np.float32)
# ─── Model Loading ────────────────────────────────────────────

# Maps UI dropdown labels to trained checkpoint files. The keys double as
# the gr.Dropdown choices, so they must stay byte-identical to the values
# used in the interface (including the star marker). "dueling" selects
# which network class is instantiated before loading the state_dict.
CHECKPOINTS = {
    "Baseline DQN (avg: 524.75)": {
        "file": "checkpoints/Baseline_DQN_best_checkpoint.pth",
        "dueling": False,
    },
    "Double DQN (avg: 650.20) โญ Best": {
        "file": "checkpoints/Double_DQN_best_checkpoint.pth",
        "dueling": False,
    },
    "Dueling DQN (avg: 497.55)": {
        "file": "checkpoints/Dueling_DQN_best_checkpoint.pth",
        "dueling": True,
    },
}
def load_model(variant_name, action_size=6):
    """Build the network for *variant_name* and load its trained weights.

    Parameters
    ----------
    variant_name : str
        A key of ``CHECKPOINTS`` (the dropdown label shown in the UI).
    action_size : int, optional
        Number of discrete actions; defaults to 6, the Space Invaders
        action-space size the checkpoints were trained with.

    Returns
    -------
    torch.nn.Module
        The network on the global ``device``, in eval mode.

    Raises
    ------
    KeyError
        If ``variant_name`` is not a known checkpoint.
    """
    config = CHECKPOINTS[variant_name]
    # Dueling checkpoints have different parameter names, so the matching
    # architecture must be instantiated before loading the state_dict.
    NetworkClass = DuelingQNetwork if config["dueling"] else QNetwork
    model = NetworkClass(action_size=action_size, seed=42).to(device)
    model.load_state_dict(
        # weights_only=True restricts the pickle payload to tensors.
        torch.load(config["file"], map_location=device, weights_only=True)
    )
    model.eval()
    return model
# ─── Game Runner ──────────────────────────────────────────────

# Human-readable labels for the 6 actions of ALE/SpaceInvaders-v5,
# indexed by the integer action the network selects.
ACTION_NAMES = ["NOOP", "FIRE", "RIGHT", "LEFT", "RIGHTFIRE", "LEFTFIRE"]
def play_game(variant_name, seed=42, max_steps=3000, fps=30):
    """Run one full greedy-policy episode and encode it to an MP4 file.

    Parameters
    ----------
    variant_name : str
        A key of ``CHECKPOINTS`` selecting which trained agent plays.
    seed : int, optional
        Environment reset seed; change it to see a different game.
    max_steps : int, optional
        Hard cap on environment steps, in case the episode never ends.
    fps : int, optional
        Video frame rate; also the number of freeze frames appended at
        the end (~1 second of "game over" hold).

    Returns
    -------
    tuple[str, str]
        Path of the written MP4 and a Markdown result summary.
    """
    model = load_model(variant_name)
    env = gym.make("ALE/SpaceInvaders-v5", render_mode="rgb_array")
    frame_stack = FrameStack(4, 84)
    obs, info = env.reset(seed=seed)
    state = frame_stack.reset(obs)
    frames = []
    total_reward = 0
    step_count = 0
    try:
        while step_count < max_steps:
            # Greedy action selection (no exploration at demo time).
            state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(device)
            with torch.no_grad():
                q_values = model(state_tensor)
            action = q_values.argmax(1).item()
            obs, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            state = frame_stack.step(obs)
            total_reward += reward
            step_count += 1
            # Capture the rendered frame with a HUD overlay.
            frame = env.render()
            frame_with_hud = add_hud(frame.copy(), total_reward, step_count,
                                     ACTION_NAMES[action], variant_name)
            frames.append(frame_with_hud)
            if done:
                # Hold the final frame for ~1 second so the ending is visible.
                frames.extend([frame_with_hud] * fps)
                break
    finally:
        # Release the emulator even if rendering/stepping raised.
        env.close()
    # Write the video. Close the temp file handle first: the fd is otherwise
    # leaked, and on Windows imageio could not reopen the path while it is
    # still held open here.
    tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    tmp.close()
    writer = imageio.get_writer(tmp.name, fps=fps, quality=8)
    try:
        for f in frames:
            writer.append_data(f)
    finally:
        writer.close()
    return tmp.name, f"**Score: {int(total_reward)}** | Steps: {step_count}"
def add_hud(frame, score, steps, action, variant):
    """Overlay score, current action, and step count on a rendered frame.

    ``variant`` is accepted for call-site compatibility but not drawn.
    Returns a new frame array with the HUD composited in.
    """
    height, width = frame.shape[:2]
    scale = max(1, width // 160)

    # Blend a dark bar along the bottom edge so the text stays readable.
    bar = frame.copy()
    bar_height = 18 * scale
    cv2.rectangle(bar, (0, height - bar_height), (width, height), (0, 0, 0), -1)
    frame = cv2.addWeighted(bar, 0.7, frame, 0.3, 0)

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.35 * scale
    thickness = max(1, scale // 2)
    text_y = height - 5 * scale
    cv2.putText(frame, f"Score: {int(score)}", (4, text_y),
                font, font_scale, (0, 255, 0), thickness)
    cv2.putText(frame, f"{action}", (width // 2 - 15 * scale, text_y),
                font, font_scale, (255, 255, 0), thickness)
    cv2.putText(frame, f"Step: {steps}", (width - 60 * scale, text_y),
                font, font_scale, (200, 200, 200), thickness)
    return frame
| # โโโ Gradio Interface โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
def run_demo(variant, seed):
    """Gradio click handler: play one game, return (video_path, result_md)."""
    # gr.Number may hand us a float; the env seed must be an int.
    return play_game(variant, seed=int(seed))
# Build the UI at import time so Hugging Face Spaces picks up `demo`.
with gr.Blocks(
    title="DQN Space Invaders",
    theme=gr.themes.Base(
        primary_hue="blue",
        neutral_hue="slate",
    ),
) as demo:
    gr.Markdown(
        """
        # ๐น๏ธ Deep Q-Network โ Space Invaders
        Watch trained DQN agents play Atari Space Invaders in real-time.
        Three variants trained from raw pixels using PyTorch.
        """
    )
    with gr.Tabs():
        with gr.TabItem("๐ฎ Play Game"):
            with gr.Row():
                # Left column: controls and agent descriptions.
                with gr.Column(scale=1):
                    variant_dropdown = gr.Dropdown(
                        # Choices come straight from CHECKPOINTS, so the
                        # selected label can be used as a lookup key.
                        choices=list(CHECKPOINTS.keys()),
                        value="Double DQN (avg: 650.20) โญ Best",
                        label="Agent Variant",
                    )
                    seed_input = gr.Number(
                        value=42,
                        label="Random Seed",
                        info="Change for a different game",
                        precision=0,
                    )
                    play_btn = gr.Button("โถ Play Game", variant="primary", size="lg")
                    # Filled with the score summary after each game.
                    result_text = gr.Markdown("")
                    gr.Markdown(
                        """
                        ---
                        **About the agents:**
                        - **Baseline DQN** โ Standard architecture
                        - **Double DQN** โญ โ Reduces Q-value overestimation
                        - **Dueling DQN** โ Separates state value from action advantage
                        """
                    )
                # Right column: the rendered gameplay video.
                with gr.Column(scale=2):
                    video_output = gr.Video(label="Gameplay", autoplay=True)
            # Wire the button: run one game, then show video + score.
            play_btn.click(
                fn=run_demo,
                inputs=[variant_dropdown, seed_input],
                outputs=[video_output, result_text],
            )
        with gr.TabItem("๐ Training Info"):
            gr.Markdown(
                """
                ## Training Results
                All three variants exceeded 490+ average score over 100 consecutive episodes.
                | Variant | Avg Score | Best Score | Episodes | Training Time |
                |---------|-----------|-----------|----------|---------------|
                | Baseline DQN | 524.75 | 586.45 | 1,470 | 7,000 |
                | Double DQN | 650.20 | 650.20 | 1,355 | 6,090 |
                | Dueling DQN | 497.55 | 647.05 | 1,465 | 7,000 |
                **Key Finding:** Double DQN achieved the highest sustained performance with zero degradation between best and final averages. Dueling DQN reached the highest peak (647) but exhibited catastrophic forgetting in extended training โ demonstrating why model checkpointing is critical in RL.
                ---
                ## Network Architecture
                All variants share a convolutional backbone from Mnih et al. (2015):
                | Layer | Filters | Kernel | Stride | Output |
                |-------|---------|--------|--------|--------|
                | Conv1 | 32 | 8ร8 | 4 | 20ร20ร32 |
                | Conv2 | 64 | 4ร4 | 2 | 9ร9ร64 |
                | Conv3 | 64 | 3ร3 | 1 | 7ร7ร64 |
                **Variant-specific heads:**
                - **Baseline DQN:** Standard single-stream fully connected head โ Q-values
                - **Double DQN:** Same architecture, but decouples action selection (local network) from evaluation (target network) to reduce overestimation bias
                - **Dueling DQN:** Splits into Value and Advantage streams โ Q(s,a) = V(s) + A(s,a) - mean(A) โ allowing the network to learn state quality independently of action choice
                ---
                ## Preprocessing Pipeline
                Raw Atari frames (210ร160ร3) are converted into CNN-ready tensors (4ร84ร84), reducing input dimensionality by 72% while preserving all gameplay information:
                1. Convert RGB โ Grayscale
                2. Crop to game region (20:200)
                3. Resize to 84ร84
                4. Stack 4 consecutive frames
                5. Normalize to [0, 1]
                ---
                ## Training Configuration
                | Config | Baseline | Double | Dueling |
                |--------|----------|--------|---------|
                | Learning Rate | 1e-4 | 1.5e-4 | 1e-4 |
                | Epsilon Decay | 0.9993 | 0.9992 | 0.9995 |
                | Batch Size | 32 | 32 | 32 |
                | Gamma (Discount) | 0.99 | 0.99 | 0.99 |
                **Design rationale:**
                - Double DQN uses a higher learning rate because its more conservative Q-estimates can tolerate faster learning without divergence
                - Dueling DQN uses slower epsilon decay to benefit from extended exploration during value/advantage stream learning
                ---
                ## Key Components
                - **Experience Replay Buffer** (100K transitions): Breaks temporal correlation and enables sample reuse
                - **Target Network** with soft updates (ฯ=0.001): Stabilizes Q-value targets
                - **ฮต-Greedy Exploration:** Decays from 1.0 โ 0.01 at variant-specific rates
                - **Gradient Clipping** (max norm 10): Prevents exploding gradients from large TD errors
                ---
                ## Key Findings
                1. **Low learning rate is essential** โ Standard 1e-3 causes divergence in Atari; 1e-4 provides stable convergence
                2. **Double DQN is the most stable** โ Zero gap between best and final averages; the only variant to reach its extended target
                3. **Longer training โ better** โ All variants showed diminishing returns or degradation past ~6,000 episodes
                4. **Checkpointing matters** โ Dueling DQN's peak performance (647) was 150 points above its final average; the best policy is not always the last one
                ---
                ## References
                - Mnih, V. et al. (2015). [Human-level control through deep reinforcement learning](https://www.nature.com/articles/nature14236). Nature, 518(7540).
                - Van Hasselt, H. et al. (2016). [Deep Reinforcement Learning with Double Q-learning](https://arxiv.org/abs/1509.06461). AAAI.
                - Wang, Z. et al. (2016). [Dueling Network Architectures for Deep RL](https://arxiv.org/abs/1511.06581). ICML.
                [GitHub Repository](https://github.com/antonisbast/DQN-SpaceInvaders)
                """
            )
# Local development entry point; on Spaces the `demo` object is served directly.
if __name__ == "__main__":
    demo.launch()