| | import pickle |
| | import random |
| | import time |
| | from collections import deque |
| |
|
| | import gym_super_mario_bros |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torch.optim as optim |
| | from gym_super_mario_bros.actions import COMPLEX_MOVEMENT |
| | from nes_py.wrappers import JoypadSpace |
| |
|
| | from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, |
| | QHBoxLayout, QPushButton, QLabel, QComboBox, |
| | QTextEdit, QProgressBar, QTabWidget, QFrame, QGroupBox) |
| | from PyQt5.QtCore import QTimer, Qt, pyqtSignal, QThread |
| | from PyQt5.QtGui import QImage, QPixmap, QFont |
| | import sys |
| | import cv2 |
| |
|
| | |
# Prefer the project's full Mario preprocessing wrappers; fall back to a
# minimal pass-through wrapper so the script still runs without wrappers.py.
try:
    from wrappers import *
except ImportError:

    class SimpleWrapper:
        """Minimal stand-in that forwards the core gym env interface unchanged."""

        def __init__(self, env):
            self.env = env
            # Mirror the spaces so downstream code can query them directly.
            self.action_space = env.action_space
            self.observation_space = env.observation_space

        def reset(self):
            """Reset the wrapped env and return its initial observation."""
            return self.env.reset()

        def step(self, action):
            """Forward one action; returns (obs, reward, done, info) unchanged."""
            return self.env.step(action)

        def render(self, mode='rgb_array'):
            # mode is passed positionally to match the wrapped env's signature.
            return self.env.render(mode)

        def close(self):
            # Guarded: some envs may not expose close(); stay a safe no-op.
            if hasattr(self.env, 'close'):
                self.env.close()

    def wrap_mario(env):
        """Fallback wrap_mario: apply only the pass-through SimpleWrapper."""
        return SimpleWrapper(env)
| |
|
| |
|
class FrameStacker:
    """Maintains a rolling stack of preprocessed (grayscale, resized) frames."""

    def __init__(self, frame_size=(84, 84), stack_size=4):
        self.frame_size = frame_size
        self.stack_size = stack_size
        # deque(maxlen=...) drops the oldest frame automatically on append.
        self.frames = deque(maxlen=stack_size)

    def preprocess_frame(self, frame):
        """Grayscale an RGB frame, resize to `frame_size`, scale into [0, 1]."""
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        shrunk = cv2.resize(gray, self.frame_size, interpolation=cv2.INTER_AREA)
        return shrunk.astype(np.float32) / 255.0

    def reset(self, frame):
        """Start a fresh stack by repeating the initial frame stack_size times."""
        self.frames.clear()
        first = self.preprocess_frame(frame)
        self.frames.extend([first] * self.stack_size)
        return self.get_stacked_frames()

    def append(self, frame):
        """Push the newest frame (oldest falls off) and return the new stack."""
        self.frames.append(self.preprocess_frame(frame))
        return self.get_stacked_frames()

    def get_stacked_frames(self):
        """Return the stack as a C-contiguous (stack_size, H, W) float32 array."""
        return np.ascontiguousarray(np.array(self.frames))
| |
|
| |
|
class replay_memory(object):
    """Fixed-capacity FIFO transition buffer with uniform random sampling."""

    def __init__(self, N):
        # Bounded deque: once N transitions are held, the oldest is evicted.
        self.memory = deque(maxlen=N)

    def push(self, transition):
        """Store one transition tuple, evicting the oldest entry when full."""
        self.memory.append(transition)

    def sample(self, n):
        """Draw n distinct transitions uniformly, without replacement."""
        return random.sample(self.memory, n)

    def __len__(self):
        """Number of transitions currently held."""
        return len(self.memory)
| |
|
| |
|
class DuelingDQNModel(nn.Module):
    """Dueling DQN: a shared conv trunk feeding separate value/advantage heads.

    Attribute names (conv_layers, advantage_stream, value_stream) are part of
    the checkpoint format via state_dict and must not change.
    """

    def __init__(self, n_frame, n_action, device):
        super(DuelingDQNModel, self).__init__()

        # Classic Atari DQN trunk over stacked 84x84 grayscale frames.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(n_frame, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        # Probe the trunk once with a dummy input to size the dense layers.
        self.conv_out_size = self._get_conv_out((n_frame, 84, 84))

        # Per-action advantage head A(s, a).
        self.advantage_stream = nn.Sequential(
            nn.Linear(self.conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_action),
        )

        # Scalar state-value head V(s).
        self.value_stream = nn.Sequential(
            nn.Linear(self.conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
        )

        self.device = device
        self.apply(self.init_weights)

    def _get_conv_out(self, shape):
        """Return the flattened conv-trunk output size for a `shape` input."""
        with torch.no_grad():
            probe = self.conv_layers(torch.zeros(1, *shape))
        return int(np.prod(probe.size()))

    def init_weights(self, m):
        """Xavier-uniform weights and 0.01 bias for every conv/linear layer."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                m.bias.data.fill_(0.01)

    def forward(self, x):
        """Compute Q-values; numpy input is converted and moved to self.device."""
        if not isinstance(x, torch.Tensor):
            x = torch.FloatTensor(x).to(self.device)

        features = self.conv_layers(x)
        features = features.view(features.size(0), -1)

        advantage = self.advantage_stream(features)
        value = self.value_stream(features)

        # Dueling aggregation; mean-subtracting the advantages keeps the
        # V/A decomposition identifiable.
        return value + (advantage - advantage.mean(dim=1, keepdim=True))
| |
|
| |
|
def train(q, q_target, memory, batch_size, gamma, optimizer, device):
    """Run one DQN optimization step on a uniformly sampled minibatch.

    Args:
        q: online network being optimized.
        q_target: target network used for the bootstrap term (not updated here).
        memory: replay buffer exposing __len__ and sample(n). Each transition is
            (state, reward, action, next_state, not_done), where not_done is
            int(1 - done) as pushed by MarioTrainingThread.run — i.e. 1 while
            the episode continues and 0 on the terminal transition.
        batch_size: number of transitions per minibatch.
        gamma: discount factor.
        optimizer: optimizer over q.parameters().
        device: torch device string for the batch tensors.

    Returns:
        The scalar smooth-L1 loss, or 0.0 when the buffer is still too small.
    """
    if len(memory) < batch_size:
        return 0.0

    transitions = memory.sample(batch_size)
    s, r, a, s_prime, not_done = list(map(list, zip(*transitions)))

    # Stack states into contiguous (batch, ...) arrays before tensor conversion.
    s = np.array([np.ascontiguousarray(arr) for arr in s])
    s_prime = np.array([np.ascontiguousarray(arr) for arr in s_prime])

    s_tensor = torch.FloatTensor(s).to(device)
    s_prime_tensor = torch.FloatTensor(s_prime).to(device)

    # Bootstrap value max_a' Q_target(s', a'), detached from the graph.
    # (The old argmax-then-gather on the same tensor was just max.)
    with torch.no_grad():
        next_q_value = q_target(s_prime_tensor).max(1, keepdim=True)[0]

    r = torch.FloatTensor(r).unsqueeze(1).to(device)
    not_done = torch.FloatTensor(not_done).unsqueeze(1).to(device)
    # BUG FIX: the stored flag is already the continuation mask (1 - done), so
    # it multiplies the bootstrap term directly. The previous (1 - not_done)
    # inverted it — bootstrapping ONLY on terminal transitions and never on
    # ordinary steps.
    target_q_values = r + gamma * next_q_value * not_done

    # Q(s, a) for the actions actually taken.
    a_tensor = torch.LongTensor(a).unsqueeze(1).to(device)
    current_q_values = q(s_tensor).gather(1, a_tensor)

    # Huber loss is robust to the occasional large TD error.
    loss = F.smooth_l1_loss(current_q_values, target_q_values)

    optimizer.zero_grad()
    loss.backward()
    # Clip gradient norm to stabilize training.
    torch.nn.utils.clip_grad_norm_(q.parameters(), max_norm=10.0)
    optimizer.step()
    return loss.item()
| |
|
| |
|
def copy_weights(q, q_target):
    """Hard target-network sync: overwrite q_target's parameters with q's."""
    q_target.load_state_dict(q.state_dict())
| |
|
| |
|
class MarioTrainingThread(QThread):
    """Background QThread running the full DQN training loop for Mario.

    Communicates with the GUI exclusively via Qt signals so no widget is
    touched from this worker thread:
      - update_signal carries a dict of training stats and/or a 'log_message'.
      - frame_signal carries the current RGB game frame as a numpy array.
    """

    update_signal = pyqtSignal(dict)
    frame_signal = pyqtSignal(np.ndarray)

    def __init__(self, device="cpu"):
        super().__init__()
        self.device = device          # torch device string: "cpu", "cuda", or "mps"
        self.running = False          # loop flag; stop() clears it to end training
        self.env = None
        self.q = None                 # online network
        self.q_target = None          # target network, hard-synced periodically
        self.optimizer = None
        self.frame_stacker = None

        # Hyperparameters.
        self.gamma = 0.99             # discount factor
        self.batch_size = 32
        self.memory_size = 10000      # replay buffer capacity
        self.eps = 1.0                # epsilon-greedy exploration rate
        self.eps_min = 0.01
        self.eps_decay = 0.995        # multiplicative decay applied once per episode
        self.update_interval = 1000   # gradient steps between target-network syncs
        self.save_interval = 100      # episodes between periodic checkpoints
        self.print_interval = 10      # episodes between summary log lines

        self.memory = None
        self.t = 0                    # total gradient steps taken
        self.k = 0                    # last completed episode index
        self.total_score = 0.0        # raw env score accumulated since last summary
        self.loss_accumulator = 0.0   # loss accumulated since last summary
        self.best_score = -float('inf')
        self.last_x_pos = 0           # Mario's x position at the previous step

    def setup_training(self):
        """Build the env, networks, optimizer, and replay buffer.

        On any failure, logs the error and clears self.running so run() exits.
        """
        n_frame = 4  # number of stacked grayscale frames fed to the network
        try:
            self.env = gym_super_mario_bros.make("SuperMarioBros-v3")
            self.env = JoypadSpace(self.env, COMPLEX_MOVEMENT)
            self.env = wrap_mario(self.env)

            self.frame_stacker = FrameStacker(frame_size=(84, 84), stack_size=n_frame)

            self.q = DuelingDQNModel(n_frame, self.env.action_space.n, self.device).to(self.device)
            self.q_target = DuelingDQNModel(n_frame, self.env.action_space.n, self.device).to(self.device)

            # Start the target network as an exact copy of the online network.
            copy_weights(self.q, self.q_target)

            # The target network is never trained directly.
            self.q_target.eval()

            self.optimizer = optim.Adam(self.q.parameters(), lr=0.0001, weight_decay=1e-5)

            self.memory = replay_memory(self.memory_size)

            self.log_message(f"✅ Training setup complete - Actions: {self.env.action_space.n}, Device: {self.device}")

        except Exception as e:
            self.log_message(f"❌ Error setting up training: {e}")
            import traceback
            traceback.print_exc()
            self.running = False

    def run(self):
        """Main training loop: episodes of act -> store -> learn -> report."""
        self.running = True
        self.setup_training()

        # setup_training clears self.running on failure.
        if not self.running:
            return

        start_time = time.perf_counter()
        score_lst = []  # per-summary average scores, pickled to score.p

        try:
            for k in range(1000000):
                if not self.running:
                    break

                # Start a new episode: reset env and seed the frame stack.
                frame = self.env.reset()
                s = self.frame_stacker.reset(frame)
                done = False
                episode_loss = 0.0
                episode_steps = 0
                episode_score = 0.0
                self.last_x_pos = 0

                while not done and self.running:
                    s_processed = np.ascontiguousarray(s)

                    # Epsilon-greedy action selection.
                    if np.random.random() <= self.eps:
                        a = self.env.action_space.sample()
                    else:
                        with torch.no_grad():
                            state_tensor = torch.FloatTensor(s_processed).unsqueeze(0).to(self.device)
                            q_values = self.q(state_tensor)

                            # Accelerator tensors must come back to host memory
                            # before numpy conversion.
                            if self.device == "cuda" or self.device == "mps":
                                a = np.argmax(q_values.cpu().numpy())
                            else:
                                a = np.argmax(q_values.detach().numpy())

                    frame, r, done, info = self.env.step(a)

                    s_prime = self.frame_stacker.append(frame)

                    # episode_score tracks the raw env reward (no shaping).
                    episode_score += r

                    # Shaped reward used for learning.
                    reward = r

                    # Bonus of 0.1 per pixel of rightward progress.
                    if 'x_pos' in info:
                        x_pos = info['x_pos']
                        x_progress = x_pos - self.last_x_pos
                        if x_progress > 0:
                            reward += 0.1 * x_progress
                        self.last_x_pos = x_pos

                    # Large bonus for reaching the level-end flag.
                    if done and info.get('flag_get', False):
                        reward += 100.0
                        self.log_message(f"🎉 LEVEL COMPLETED at episode {k}! 🎉")

                    # Store the transition. NOTE: the fifth element is the
                    # continuation flag int(1 - done) — 1 while the episode
                    # continues, 0 on the terminal transition; train() must
                    # interpret it accordingly.
                    s_contiguous = np.ascontiguousarray(s)
                    s_prime_contiguous = np.ascontiguousarray(s_prime)
                    self.memory.push((s_contiguous, float(reward), int(a), s_prime_contiguous, int(1 - done)))

                    s = s_prime
                    stage = info.get('stage', 1)
                    world = info.get('world', 1)

                    # Stream the rendered frame to the GUI; on render failure,
                    # emit a black NES-sized placeholder instead.
                    try:
                        display_frame = self.env.render()
                        if display_frame is not None:
                            frame_contiguous = np.ascontiguousarray(display_frame)
                            self.frame_signal.emit(frame_contiguous)
                    except Exception as e:
                        frame = np.zeros((240, 256, 3), dtype=np.uint8)
                        self.frame_signal.emit(frame)

                    # One gradient step per env step once the buffer exceeds
                    # (strictly) one batch.
                    if len(self.memory) > self.batch_size:
                        loss_val = train(self.q, self.q_target, self.memory, self.batch_size,
                                         self.gamma, self.optimizer, self.device)
                        if loss_val > 0:
                            self.loss_accumulator += loss_val
                            episode_loss += loss_val
                        self.t += 1

                        # Hard-sync the target network every update_interval steps.
                        if self.t % self.update_interval == 0:
                            copy_weights(self.q, self.q_target)

                    episode_steps += 1

                    # Push a stats snapshot to the GUI every 10 steps.
                    if episode_steps % 10 == 0:
                        progress_data = {
                            'episode': k,
                            'total_reward': episode_score,
                            'steps': episode_steps,
                            'epsilon': self.eps,
                            'world': world,
                            'stage': stage,
                            'loss': episode_loss / (episode_steps + 1e-8),
                            'memory_size': len(self.memory),
                            'x_pos': info.get('x_pos', 0),
                            'score': info.get('score', 0),
                            'coins': info.get('coins', 0),
                            'time': info.get('time', 400),
                            'flag_get': info.get('flag_get', False)
                        }
                        self.update_signal.emit(progress_data)

                # Decay exploration once per episode, floored at eps_min.
                if self.eps > self.eps_min:
                    self.eps *= self.eps_decay

                self.total_score += episode_score

                # Checkpoint whenever a new best episode score is reached.
                if episode_score > self.best_score and k > 0:
                    self.best_score = episode_score
                    torch.save(self.q.state_dict(), "enhanced_mario_q_best.pth")
                    torch.save(self.q_target.state_dict(), "enhanced_mario_q_target_best.pth")
                    self.log_message(f"💾 New best model saved! Score: {self.best_score:.2f}")

                # Periodic checkpoint regardless of score.
                if k % self.save_interval == 0 and k > 0:
                    torch.save(self.q.state_dict(), "enhanced_mario_q.pth")
                    torch.save(self.q_target.state_dict(), "enhanced_mario_q_target.pth")
                    self.log_message(f"💾 Models saved at episode {k}")

                # Summary every print_interval episodes.
                if k % self.print_interval == 0 and k > 0:
                    time_spent = time.perf_counter() - start_time
                    start_time = time.perf_counter()

                    # NOTE(review): avg_loss divides by print_interval times the
                    # LAST episode's step count — an approximation, not an exact
                    # mean over all steps in the window.
                    avg_loss = self.loss_accumulator / (self.print_interval * max(episode_steps, 1))
                    avg_score = self.total_score / self.print_interval

                    log_msg = (
                        f"{self.device} | Ep: {k} | Score: {avg_score:.2f} | Loss: {avg_loss:.4f} | "
                        f"Stage: {world}-{stage} | Eps: {self.eps:.3f} | Time: {time_spent:.2f}s | "
                        f"Mem: {len(self.memory)} | Steps: {episode_steps}"
                    )
                    self.log_message(log_msg)

                    score_lst.append(avg_score)
                    self.total_score = 0.0
                    self.loss_accumulator = 0.0

                    # Persist the score history; best-effort only.
                    try:
                        pickle.dump(score_lst, open("score.p", "wb"))
                    except Exception as e:
                        self.log_message(f"⚠️ Could not save scores: {e}")

                self.k = k

        except Exception as e:
            self.log_message(f"❌ Training error: {e}")
            import traceback
            traceback.print_exc()

    def log_message(self, message):
        """Forward a log line to the GUI through update_signal."""
        progress_data = {
            'log_message': message
        }
        self.update_signal.emit(progress_data)

    def stop(self):
        """Ask the training loop to exit and close the environment."""
        self.running = False
        if self.env:
            try:
                self.env.close()
            except:
                pass
| |
|
| |
|
class MarioRLApp(QMainWindow):
    """Main window: training controls, live game view, stats panel, and log."""

    def __init__(self):
        super().__init__()
        self.training_thread = None  # MarioTrainingThread while training runs
        self.init_ui()

    def init_ui(self):
        """Build the UI: title, control row, game display (left), stats+log (right)."""
        self.setWindowTitle('🎮 Super Mario Bros - Dueling DQN Training')
        self.setGeometry(100, 100, 1200, 800)

        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        layout = QVBoxLayout(central_widget)

        # Title banner.
        title = QLabel('🎮 Super Mario Bros - Enhanced Dueling DQN')
        title.setFont(QFont('Arial', 16, QFont.Bold))
        title.setAlignment(Qt.AlignCenter)
        layout.addWidget(title)

        # Control row: device selector plus start/stop/load buttons.
        control_layout = QHBoxLayout()

        self.device_combo = QComboBox()
        self.device_combo.addItems(['cpu', 'cuda', 'mps'])

        self.start_btn = QPushButton('Start Training')
        self.start_btn.clicked.connect(self.start_training)

        self.stop_btn = QPushButton('Stop Training')
        self.stop_btn.clicked.connect(self.stop_training)
        self.stop_btn.setEnabled(False)

        self.load_btn = QPushButton('Load Model')
        self.load_btn.clicked.connect(self.load_model)

        control_layout.addWidget(QLabel('Device:'))
        control_layout.addWidget(self.device_combo)
        control_layout.addWidget(self.start_btn)
        control_layout.addWidget(self.stop_btn)
        control_layout.addWidget(self.load_btn)
        control_layout.addStretch()

        layout.addLayout(control_layout)

        # Main content: game view on the left, stats/log on the right.
        content_layout = QHBoxLayout()

        left_frame = QFrame()
        left_frame.setFrameStyle(QFrame.Box)
        left_layout = QVBoxLayout(left_frame)

        self.game_display = QLabel()
        self.game_display.setMinimumSize(400, 300)
        self.game_display.setAlignment(Qt.AlignCenter)
        self.game_display.setText('Game display will appear here\nPress "Start Training" to begin')
        self.game_display.setStyleSheet('border: 1px solid gray; background-color: black; color: white;')

        left_layout.addWidget(QLabel('Mario Game Display:'))
        left_layout.addWidget(self.game_display)

        right_frame = QFrame()
        right_frame.setFrameStyle(QFrame.Box)
        right_layout = QVBoxLayout(right_frame)

        # Live statistics labels, refreshed by update_training_info().
        stats_group = QGroupBox("Training Statistics")
        stats_layout = QVBoxLayout(stats_group)

        self.episode_label = QLabel('Episode: 0')
        self.world_label = QLabel('World: 1-1')
        self.score_label = QLabel('Score: 0')
        self.reward_label = QLabel('Episode Reward: 0')
        self.steps_label = QLabel('Steps: 0')
        self.epsilon_label = QLabel('Epsilon: 1.000')
        self.loss_label = QLabel('Loss: 0.0000')
        self.memory_label = QLabel('Memory: 0')
        self.xpos_label = QLabel('X Position: 0')
        self.coins_label = QLabel('Coins: 0')
        self.time_label = QLabel('Time: 400')
        self.flag_label = QLabel('Flag: No')

        stats_layout.addWidget(self.episode_label)
        stats_layout.addWidget(self.world_label)
        stats_layout.addWidget(self.score_label)
        stats_layout.addWidget(self.reward_label)
        stats_layout.addWidget(self.steps_label)
        stats_layout.addWidget(self.epsilon_label)
        stats_layout.addWidget(self.loss_label)
        stats_layout.addWidget(self.memory_label)
        stats_layout.addWidget(self.xpos_label)
        stats_layout.addWidget(self.coins_label)
        stats_layout.addWidget(self.time_label)
        stats_layout.addWidget(self.flag_label)

        right_layout.addWidget(stats_group)

        # Scrolling training log.
        right_layout.addWidget(QLabel('Training Log:'))
        self.log_text = QTextEdit()
        self.log_text.setMaximumHeight(300)
        right_layout.addWidget(self.log_text)

        content_layout.addWidget(left_frame)
        content_layout.addWidget(right_frame)
        layout.addLayout(content_layout)

    def start_training(self):
        """Validate the selected device, then spawn and wire the training thread."""
        device = self.device_combo.currentText()

        # Fall back to CPU when the requested accelerator is unavailable.
        # FIX: guard with hasattr — torch.backends.mps only exists in
        # torch >= 1.12, so the old code raised AttributeError on older builds
        # instead of falling back.
        if device == "cuda" and not torch.cuda.is_available():
            self.log_text.append("❌ CUDA not available, using CPU instead")
            device = "cpu"
        elif device == "mps" and not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()):
            self.log_text.append("❌ MPS not available, using CPU instead")
            device = "cpu"

        self.training_thread = MarioTrainingThread(device)
        self.training_thread.update_signal.connect(self.update_training_info)
        self.training_thread.frame_signal.connect(self.update_game_display)
        self.training_thread.start()

        self.start_btn.setEnabled(False)
        self.stop_btn.setEnabled(True)

        self.log_text.append(f'🚀 Started Dueling DQN training on {device}...')

    def stop_training(self):
        """Signal the training thread to stop and wait for it to finish."""
        if self.training_thread:
            self.training_thread.stop()
            self.training_thread.wait()

        self.start_btn.setEnabled(True)
        self.stop_btn.setEnabled(False)
        self.log_text.append('⏹️ Training stopped.')

    def load_model(self):
        """Placeholder: model loading is not implemented yet."""
        self.log_text.append('📁 Load model functionality not implemented yet')

    def update_training_info(self, data):
        """Slot for update_signal: refresh whichever stats keys `data` carries."""
        if 'episode' in data:
            self.episode_label.setText(f'Episode: {data["episode"]}')
        if 'world' in data and 'stage' in data:
            self.world_label.setText(f'World: {data["world"]}-{data["stage"]}')
        if 'score' in data:
            self.score_label.setText(f'Score: {data["score"]}')
        if 'total_reward' in data:
            self.reward_label.setText(f'Episode Reward: {data["total_reward"]:.2f}')
        if 'steps' in data:
            self.steps_label.setText(f'Steps: {data["steps"]}')
        if 'epsilon' in data:
            self.epsilon_label.setText(f'Epsilon: {data["epsilon"]:.3f}')
        if 'loss' in data:
            self.loss_label.setText(f'Loss: {data["loss"]:.4f}')
        if 'memory_size' in data:
            self.memory_label.setText(f'Memory: {data["memory_size"]}')
        if 'x_pos' in data:
            self.xpos_label.setText(f'X Position: {data["x_pos"]}')
        if 'coins' in data:
            self.coins_label.setText(f'Coins: {data["coins"]}')
        if 'time' in data:
            self.time_label.setText(f'Time: {data["time"]}')
        if 'flag_get' in data:
            flag_text = "Yes" if data["flag_get"] else "No"
            self.flag_label.setText(f'Flag: {flag_text}')
        if 'log_message' in data:
            self.log_text.append(data['log_message'])

        # Keep the log scrolled to the most recent line.
        self.log_text.verticalScrollBar().setValue(
            self.log_text.verticalScrollBar().maximum()
        )

    def update_game_display(self, frame):
        """Slot for frame_signal: render an RGB numpy frame into the game label."""
        if frame is not None:
            try:
                h, w, ch = frame.shape
                bytes_per_line = ch * w
                # QImage needs a contiguous buffer; QPixmap.fromImage copies it.
                frame_contiguous = np.ascontiguousarray(frame)
                q_img = QImage(frame_contiguous.data, w, h, bytes_per_line, QImage.Format_RGB888)
                pixmap = QPixmap.fromImage(q_img)
                self.game_display.setPixmap(pixmap.scaled(400, 300, Qt.KeepAspectRatio))
            except Exception as e:
                print(f"Error updating display: {e}")

    def closeEvent(self, event):
        """Stop the worker thread cleanly before the window closes."""
        self.stop_training()
        event.accept()
| |
|
| |
|
def main():
    """Seed all RNGs for reproducibility, then launch the Qt application."""
    seed = 42
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    app = QApplication(sys.argv)
    main_window = MarioRLApp()
    main_window.show()
    sys.exit(app.exec_())
| |
|
| |
|
# Standard script guard: start the GUI only when executed directly.
if __name__ == '__main__':
    main()