import sys
import random
from collections import deque

import numpy as np
import gymnasium as gym
import ale_py

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
                             QHBoxLayout, QPushButton, QLabel, QComboBox,
                             QTextEdit, QFrame)
from PyQt5.QtCore import Qt, pyqtSignal, QThread
from PyQt5.QtGui import QImage, QPixmap, QFont

# Register the ALE (Atari) environments with Gymnasium.
gym.register_envs(ale_py)


def create_env(env_name='ALE/SpaceInvaders-v5'):
    """
    Create an ALE environment via the Gymnasium API.

    render_mode='rgb_array' makes env.render() return an RGB frame that the
    GUI can display.
    """
    env = gym.make(env_name, render_mode='rgb_array')
    return env
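

# Optional sanity check: a minimal sketch, not called anywhere by the GUI.
# It exercises the same Gymnasium reset/step API used in TrainingThread.run()
# and is handy for verifying the ALE install before launching the full app.
# The function name `_smoke_test_env` is illustrative, not part of the app.
def _smoke_test_env(env_name='ALE/Breakout-v5', max_steps=100):
    env = create_env(env_name)
    obs, info = env.reset(seed=0)
    total_reward = 0.0
    for _ in range(max_steps):
        # Sample random actions; we only care that the loop runs end to end.
        obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
        total_reward += reward
        if terminated or truncated:
            break
    env.close()
    return total_reward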


class DuelingDQN(nn.Module):
    """Dueling DQN (Wang et al., 2016): separate value and advantage streams."""

    def __init__(self, input_shape, n_actions):
        super(DuelingDQN, self).__init__()
        # Standard DQN convolutional trunk.
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        conv_out_size = self._get_conv_out(input_shape)

        # Advantage stream: one output per action.
        self.fc_advantage = nn.Sequential(
            nn.Linear(conv_out_size, 256),
            nn.ReLU(),
            nn.Linear(256, n_actions)
        )

        # Value stream: a single scalar V(s).
        self.fc_value = nn.Sequential(
            nn.Linear(conv_out_size, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

    def _get_conv_out(self, shape):
        # Run a dummy forward pass to infer the flattened conv output size.
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        conv_out = self.conv(x).view(x.size()[0], -1)
        advantage = self.fc_advantage(conv_out)
        value = self.fc_value(conv_out)
        # Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a'); the mean must be taken
        # per sample over the action dimension, not over the whole batch.
        return value + advantage - advantage.mean(dim=1, keepdim=True)
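

# A minimal shape sketch (illustrative, not used by the trainer): a batch of
# (C, H, W) frames yields one Q-value per action. The 1x210x160 input here
# matches what TrainingThread.preprocess_state() below produces.
def _dueling_shape_check():
    net = DuelingDQN((1, 210, 160), n_actions=4)
    q = net(torch.zeros(2, 1, 210, 160))
    assert q.shape == (2, 4)  # one row of Q-values per batch element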


class PPONetwork(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(PPONetwork, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        conv_out_size = self._get_conv_out(input_shape)

        self.actor = nn.Sequential(
            nn.Linear(conv_out_size, 256),
            nn.ReLU(),
            nn.Linear(256, n_actions),
            nn.Softmax(dim=-1)
        )

        self.critic = nn.Sequential(
            nn.Linear(conv_out_size, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

    def _get_conv_out(self, shape):
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        conv_out = self.conv(x).view(x.size()[0], -1)
        return self.actor(conv_out), self.critic(conv_out)
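

# Illustrative sketch of the PPONetwork outputs (not used by the trainer):
# the actor head returns a categorical distribution over actions (rows sum
# to 1 thanks to the Softmax), the critic head returns one value per state.
def _ppo_output_check():
    net = PPONetwork((1, 210, 160), n_actions=6)
    probs, value = net(torch.zeros(2, 1, 210, 160))
    assert probs.shape == (2, 6) and value.shape == (2, 1)
    assert torch.allclose(probs.sum(dim=-1), torch.ones(2))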


class DuelingDQNAgent:
    def __init__(self, state_dim, action_dim, lr=1e-4, gamma=0.99, epsilon=1.0,
                 epsilon_min=0.01, epsilon_decay=0.999, memory_size=50000, batch_size=32):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size

        self.memory = deque(maxlen=memory_size)
        self.model = DuelingDQN(state_dim, action_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-5)
        self.criterion = nn.SmoothL1Loss()

        # Target network for stable bootstrapping; synced every
        # `target_update_frequency` training steps.
        self.target_model = DuelingDQN(state_dim, action_dim)
        self.update_target_network()
        self.target_update_frequency = 1000
        self.train_step = 0

    def update_target_network(self):
        self.target_model.load_state_dict(self.model.state_dict())

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        # Epsilon-greedy exploration.
        if np.random.random() <= self.epsilon:
            return random.randrange(self.action_dim)

        state = torch.FloatTensor(state).unsqueeze(0)
        with torch.no_grad():
            q_values = self.model(state)
        return int(q_values.argmax(dim=1).item())

    def replay(self):
        if len(self.memory) < self.batch_size:
            return

        batch = random.sample(self.memory, self.batch_size)
        states = torch.FloatTensor(np.array([e[0] for e in batch]))
        actions = torch.LongTensor([e[1] for e in batch])
        rewards = torch.FloatTensor([e[2] for e in batch])
        next_states = torch.FloatTensor(np.array([e[3] for e in batch]))
        dones = torch.BoolTensor([e[4] for e in batch])

        current_q_values = self.model(states).gather(1, actions.unsqueeze(1))

        # Double DQN target: the online network selects the next action,
        # the target network evaluates it.
        with torch.no_grad():
            next_actions = self.model(next_states).max(1)[1]
            next_q_values = self.target_model(next_states).gather(1, next_actions.unsqueeze(1)).squeeze()

        target_q_values = rewards + (self.gamma * next_q_values * ~dones)

        loss = self.criterion(current_q_values.squeeze(), target_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        # Clip gradients to stabilise training.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()

        self.train_step += 1
        if self.train_step % self.target_update_frequency == 0:
            self.update_target_network()

        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
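

# A hedged usage sketch: how the DQN agent could be driven headlessly, without
# the Qt GUI. It mirrors the logic in TrainingThread.run(); the helper name
# `_train_dqn_headless` and the episode count are illustrative only.
def _train_dqn_headless(env_name='ALE/Breakout-v5', episodes=2):
    def preprocess(frame):
        # Same grayscale + channel-first + [0, 1] scaling as preprocess_state().
        return frame.mean(axis=2, keepdims=True).transpose((2, 0, 1)) / 255.0

    env = create_env(env_name)
    state, info = env.reset()
    state = preprocess(state)
    agent = DuelingDQNAgent(state.shape, env.action_space.n)
    for _ in range(episodes):
        done = False
        while not done:
            action = agent.act(state)
            next_state, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            next_state = preprocess(next_state)
            agent.remember(state, action, reward, next_state, done)
            agent.replay()
            state = next_state
        state, info = env.reset()
        state = preprocess(state)
    env.close()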


class PPOAgent:
    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99, epsilon=0.2,
                 entropy_coef=0.01, value_coef=0.5, ppo_epochs=4, batch_size=64):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon  # PPO clip range, not an exploration rate
        self.entropy_coef = entropy_coef
        self.value_coef = value_coef
        self.ppo_epochs = ppo_epochs
        self.batch_size = batch_size

        self.model = PPONetwork(state_dim, action_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        self.memory = []

    def remember(self, state, action, reward, value, log_prob, done):
        self.memory.append((state, action, reward, value, log_prob, done))

    def act(self, state):
        state = torch.FloatTensor(state).unsqueeze(0)
        with torch.no_grad():
            probs, value = self.model(state)
        dist = Categorical(probs)
        action = dist.sample()
        return action.item(), dist.log_prob(action), value.squeeze()

    def train(self):
        if len(self.memory) < self.batch_size:
            return

        states, actions, rewards, values, log_probs, dones = zip(*self.memory)

        # Discounted Monte Carlo returns, reset at episode boundaries so that
        # returns from one episode never bleed into the next (the buffer can
        # span several episodes when an episode is shorter than batch_size).
        returns = []
        R = 0
        for r, d in zip(reversed(rewards), reversed(dones)):
            if d:
                R = 0
            R = r + self.gamma * R
            returns.insert(0, R)

        returns = torch.FloatTensor(returns)
        old_values = torch.stack(values)
        advantages = returns - old_values

        # Normalise advantages for numerical stability.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        states_tensor = torch.FloatTensor(np.array(states))
        actions_tensor = torch.LongTensor(actions)
        old_log_probs = torch.stack(log_probs)

        for _ in range(self.ppo_epochs):
            new_probs, new_values = self.model(states_tensor)
            dist = Categorical(new_probs)
            new_log_probs = dist.log_prob(actions_tensor)
            entropy = dist.entropy().mean()

            # Clipped surrogate objective (PPO-Clip).
            ratio = (new_log_probs - old_log_probs).exp()
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * advantages
            actor_loss = -torch.min(surr1, surr2).mean()

            critic_loss = F.mse_loss(new_values.squeeze(), returns)

            total_loss = actor_loss + self.value_coef * critic_loss - self.entropy_coef * entropy

            self.optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
            self.optimizer.step()

        self.memory = []
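

# For reference, the loss above follows PPO-Clip (Schulman et al., 2017):
#
#   r_t(theta)  = pi_theta(a_t | s_t) / pi_theta_old(a_t | s_t)
#   L_CLIP      = E_t[ min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t) ]
#   total loss  = -L_CLIP + c_v * MSE(V(s_t), R_t) - c_e * entropy
#
# with eps = self.epsilon, c_v = value_coef and c_e = entropy_coef. Note that
# this implementation uses plain discounted returns minus the old values as
# advantages; GAE is a common refinement but is not used here.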


class TrainingThread(QThread):
    """Background worker that runs the RL loop and streams results to the GUI."""

    update_signal = pyqtSignal(dict)
    frame_signal = pyqtSignal(np.ndarray)

    def __init__(self, algorithm='dqn', env_name='ALE/Breakout-v5'):
        super().__init__()
        self.algorithm = algorithm
        self.env_name = env_name
        self.running = False
        self.env = None
        self.agent = None

    def preprocess_state(self, state):
        # Convert an RGB frame (H, W, 3) to a grayscale, channel-first,
        # [0, 1]-scaled array of shape (1, H, W).
        if len(state.shape) == 3:
            state = state.mean(axis=2, keepdims=True)
            state = state.transpose((2, 0, 1))
            state = state / 255.0
        return state
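
    # Note: this keeps the full 210x160 frame with no frame stacking. The
    # classic DeepMind pipeline instead resizes to 84x84 grayscale and stacks
    # 4 consecutive frames. A hedged sketch using Gymnasium's Atari wrappers
    # (wrapper names and defaults vary across Gymnasium versions, e.g.
    # FrameStack was renamed FrameStackObservation in Gymnasium 1.0):
    #
    #   from gymnasium.wrappers import AtariPreprocessing, FrameStack
    #   env = gym.make(env_name, render_mode='rgb_array', frameskip=1)
    #   env = AtariPreprocessing(env, screen_size=84, grayscale_obs=True)
    #   env = FrameStack(env, 4)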

    def run(self):
        self.running = True
        try:
            self.env = create_env(self.env_name)
            state, info = self.env.reset()
            state = self.preprocess_state(state)

            n_actions = self.env.action_space.n
            state_dim = state.shape

            print(f"Training on: {self.env_name}")
            print(f"State shape: {state_dim}, Actions: {n_actions}")
            print(f"Algorithm: {self.algorithm}")

            if self.algorithm == 'dqn':
                self.agent = DuelingDQNAgent(state_dim, n_actions)
            else:
                self.agent = PPOAgent(state_dim, n_actions)

            episode = 0
            total_reward = 0
            steps = 0
            episode_rewards = []
            best_reward = -float('inf')

            while self.running:
                try:
                    if self.algorithm == 'dqn':
                        action = self.agent.act(state)
                        next_state, reward, terminated, truncated, info = self.env.step(action)
                        done = terminated or truncated
                        next_state = self.preprocess_state(next_state)
                        self.agent.remember(state, action, reward, next_state, done)
                        self.agent.replay()
                    else:
                        action, log_prob, value = self.agent.act(state)
                        next_state, reward, terminated, truncated, info = self.env.step(action)
                        done = terminated or truncated
                        next_state = self.preprocess_state(next_state)
                        self.agent.remember(state, action, reward, value, log_prob, done)
                        if done:
                            self.agent.train()

                    state = next_state
                    total_reward += reward
                    steps += 1

                    # Stream the current frame to the GUI; fall back to a
                    # black frame if rendering fails.
                    try:
                        frame = self.env.render()
                        if frame is not None:
                            self.frame_signal.emit(frame)
                    except Exception:
                        frame = np.zeros((210, 160, 3), dtype=np.uint8)
                        self.frame_signal.emit(frame)

                    # Throttle GUI stat updates to every 5th step.
                    if steps % 5 == 0:
                        avg_reward = np.mean(episode_rewards[-10:]) if episode_rewards else total_reward
                        progress_data = {
                            'episode': episode,
                            'total_reward': total_reward,
                            'steps': steps,
                            # For PPO this is the clip range, not an exploration rate.
                            'epsilon': self.agent.epsilon,
                            'env_name': self.env_name,
                            'lives': info.get('lives', 0) if isinstance(info, dict) else 0,
                            'avg_reward': avg_reward,
                            'best_reward': best_reward
                        }
                        self.update_signal.emit(progress_data)

                    if done:
                        episode_rewards.append(total_reward)
                        if total_reward > best_reward:
                            best_reward = total_reward

                        avg_reward = np.mean(episode_rewards[-10:]) if episode_rewards else total_reward

                        print(f"Episode {episode}: Reward: {total_reward:.1f}, "
                              f"Steps: {steps}, Avg (last 10): {avg_reward:.1f}, "
                              f"Best: {best_reward:.1f}, Epsilon: {self.agent.epsilon:.3f}")

                        episode += 1
                        state, info = self.env.reset()
                        state = self.preprocess_state(state)
                        total_reward = 0
                        steps = 0

                except Exception as e:
                    print(f"Error in training loop: {e}")
                    import traceback
                    traceback.print_exc()
                    break

        except Exception as e:
            print(f"Error setting up environment: {e}")
            import traceback
            traceback.print_exc()

    def stop(self):
        self.running = False
        if self.env:
            self.env.close()


class ALE_RLApp(QMainWindow):
    def __init__(self):
        super().__init__()
        self.training_thread = None
        self.init_ui()

    def init_ui(self):
        self.setWindowTitle('ALE Arcade RL Training - Enhanced')
        self.setGeometry(100, 100, 1200, 800)

        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        layout = QVBoxLayout(central_widget)

        # Title
        title = QLabel('Arcade Reinforcement Learning (ALE) - Enhanced Training')
        title.setFont(QFont('Arial', 16, QFont.Bold))
        title.setAlignment(Qt.AlignCenter)
        layout.addWidget(title)

        # Control bar: algorithm/environment selection plus start/stop buttons.
        control_layout = QHBoxLayout()

        self.algorithm_combo = QComboBox()
        self.algorithm_combo.addItems(['Dueling DQN', 'PPO'])

        self.env_combo = QComboBox()
        self.env_combo.addItems([
            'ALE/Breakout-v5',
            'ALE/Pong-v5',
            'ALE/SpaceInvaders-v5',
            'ALE/Assault-v5',
            'ALE/BeamRider-v5',
            'ALE/Enduro-v5',
            'ALE/Seaquest-v5',
            'ALE/Qbert-v5'
        ])

        self.start_btn = QPushButton('Start Training')
        self.start_btn.clicked.connect(self.start_training)

        self.stop_btn = QPushButton('Stop Training')
        self.stop_btn.clicked.connect(self.stop_training)
        self.stop_btn.setEnabled(False)

        control_layout.addWidget(QLabel('Algorithm:'))
        control_layout.addWidget(self.algorithm_combo)
        control_layout.addWidget(QLabel('Environment:'))
        control_layout.addWidget(self.env_combo)
        control_layout.addWidget(self.start_btn)
        control_layout.addWidget(self.stop_btn)
        control_layout.addStretch()

        layout.addLayout(control_layout)

        # Main content: game display on the left, stats and log on the right.
        content_layout = QHBoxLayout()

        left_frame = QFrame()
        left_frame.setFrameStyle(QFrame.Box)
        left_layout = QVBoxLayout(left_frame)

        self.game_display = QLabel()
        self.game_display.setMinimumSize(400, 300)
        self.game_display.setAlignment(Qt.AlignCenter)
        self.game_display.setText('Game display will appear here\nPress "Start Training" to begin')
        self.game_display.setStyleSheet('border: 1px solid gray; background-color: black; color: white; font-size: 14px;')

        left_layout.addWidget(QLabel('Game Display:'))
        left_layout.addWidget(self.game_display)

        right_frame = QFrame()
        right_frame.setFrameStyle(QFrame.Box)
        right_layout = QVBoxLayout(right_frame)

        # Live training statistics.
        self.env_label = QLabel('Environment: Not started')
        self.episode_label = QLabel('Episode: 0')
        self.reward_label = QLabel('Total Reward: 0')
        self.avg_reward_label = QLabel('Avg Reward (last 10): 0')
        self.best_reward_label = QLabel('Best Reward: 0')
        self.steps_label = QLabel('Steps: 0')
        self.epsilon_label = QLabel('Epsilon: 0')
        self.lives_label = QLabel('Lives: 0')

        for label in [self.env_label, self.episode_label, self.reward_label,
                      self.avg_reward_label, self.best_reward_label, self.steps_label,
                      self.epsilon_label, self.lives_label]:
            label.setStyleSheet('font-weight: bold; font-size: 12px;')

        right_layout.addWidget(self.env_label)
        right_layout.addWidget(self.episode_label)
        right_layout.addWidget(self.reward_label)
        right_layout.addWidget(self.avg_reward_label)
        right_layout.addWidget(self.best_reward_label)
        right_layout.addWidget(self.steps_label)
        right_layout.addWidget(self.epsilon_label)
        right_layout.addWidget(self.lives_label)

        right_layout.addWidget(QLabel('Training Log:'))
        self.log_text = QTextEdit()
        self.log_text.setMaximumHeight(200)
        self.log_text.setStyleSheet('font-family: monospace; font-size: 10px;')
        right_layout.addWidget(self.log_text)

        content_layout.addWidget(left_frame)
        content_layout.addWidget(right_frame)
        layout.addLayout(content_layout)

    def start_training(self):
        algorithm = 'dqn' if self.algorithm_combo.currentText() == 'Dueling DQN' else 'ppo'
        env_name = self.env_combo.currentText()

        self.training_thread = TrainingThread(algorithm, env_name)
        self.training_thread.update_signal.connect(self.update_training_info)
        self.training_thread.frame_signal.connect(self.update_game_display)
        self.training_thread.start()

        self.start_btn.setEnabled(False)
        self.stop_btn.setEnabled(True)

        self.log_text.append(f'Started {self.algorithm_combo.currentText()} training on {env_name}...')

    def stop_training(self):
        if self.training_thread:
            self.training_thread.stop()
            self.training_thread.wait()

        self.start_btn.setEnabled(True)
        self.stop_btn.setEnabled(False)
        self.log_text.append('Training stopped.')

    def update_training_info(self, data):
        self.env_label.setText(f'Environment: {data.get("env_name", "Unknown")}')
        self.episode_label.setText(f'Episode: {data["episode"]}')
        self.reward_label.setText(f'Total Reward: {data["total_reward"]:.1f}')
        self.avg_reward_label.setText(f'Avg Reward (last 10): {data.get("avg_reward", 0):.1f}')
        self.best_reward_label.setText(f'Best Reward: {data.get("best_reward", 0):.1f}')
        self.steps_label.setText(f'Steps: {data["steps"]}')
        self.epsilon_label.setText(f'Epsilon: {data["epsilon"]:.3f}')
        self.lives_label.setText(f'Lives: {data.get("lives", 0)}')

    def update_game_display(self, frame):
        if frame is not None:
            try:
                # Wrap the RGB ndarray in a QImage and scale it to fit the label.
                h, w, ch = frame.shape
                bytes_per_line = ch * w
                q_img = QImage(frame.data, w, h, bytes_per_line, QImage.Format_RGB888)
                pixmap = QPixmap.fromImage(q_img)
                self.game_display.setPixmap(pixmap.scaled(400, 300, Qt.KeepAspectRatio))
            except Exception as e:
                print(f"Error updating display: {e}")

    def closeEvent(self, event):
        # Make sure the worker thread is stopped before the window closes.
        self.stop_training()
        event.accept()


def main():
    # Seed the RNGs for reproducibility.
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)

    app = QApplication(sys.argv)
    window = ALE_RLApp()
    window.show()
    sys.exit(app.exec_())


if __name__ == '__main__':
    main()