import random
from collections import deque
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class DQNAgent(nn.Module):
    """DQN agent with experience replay and a periodically-synced target network.

    Config keys (defaults in parentheses): ``state_dim``, ``action_dim``,
    ``learning_rate`` (1e-4), ``batch_size`` (64), ``memory_size`` (10000),
    ``gamma`` (0.99), ``epsilon`` (1.0), ``epsilon_min`` (0.01),
    ``epsilon_decay`` (0.995), ``target_update_interval`` (10).
    """

    def __init__(self, config: Dict):
        super().__init__()
        self.config = config
        self.state_dim = config['state_dim']
        self.action_dim = config['action_dim']
        self.learning_rate = config.get('learning_rate', 1e-4)

        # Online network (trained) and target network (frozen copy, synced
        # every `target_update_interval` train steps).
        self.q_net = self._build_model()
        self.target_net = self._build_model()
        self.target_net.load_state_dict(self.q_net.state_dict())
        self.target_net.eval()

        # BUGFIX: optimize only the online network. The original passed
        # self.parameters(), which also enrolled the target network's
        # weights in the optimizer.
        self.optimizer = torch.optim.Adam(self.q_net.parameters(),
                                          lr=self.learning_rate)

        # Experience replay buffer. deque(maxlen=...) evicts the oldest
        # transition in O(1), unlike the original list.pop(0) which is O(n).
        self.batch_size = config.get('batch_size', 64)
        self.memory_size = config.get('memory_size', 10000)
        self.memory = deque(maxlen=self.memory_size)

        # Training hyper-parameters.
        self.gamma = config.get('gamma', 0.99)        # Discount factor
        self.epsilon = config.get('epsilon', 1.0)     # Exploration rate
        self.epsilon_min = config.get('epsilon_min', 0.01)
        self.epsilon_decay = config.get('epsilon_decay', 0.995)

        # Target-network refresh cadence (in gradient steps).
        self.target_update_interval = config.get('target_update_interval', 10)
        self.update_count = 0

    def _build_model(self) -> nn.Sequential:
        """Build the MLP mapping a state vector to per-action Q-values."""
        return nn.Sequential(
            nn.Linear(self.state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, self.action_dim),
        )

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return Q-values for `state` from the online network."""
        return self.q_net(state)

    def get_action(self, state: np.ndarray, training: bool = True) -> int:
        """Select an action epsilon-greedily.

        With probability epsilon (during training only) a uniformly random
        action is returned; otherwise the greedy action under the online
        network. `state` is a 1-D feature vector of length `state_dim`.
        """
        if training and np.random.random() < self.epsilon:
            return int(np.random.randint(self.action_dim))
        # no_grad: pure inference — do not build an autograd graph.
        with torch.no_grad():
            q_values = self.forward(torch.FloatTensor(state).unsqueeze(0))
        return int(q_values.argmax().item())

    def remember(self, state: np.ndarray, action: int, reward: float,
                 next_state: np.ndarray, done: bool):
        """Store one transition; the deque drops the oldest when full."""
        self.memory.append((state, action, reward, next_state, done))

    def train_step(self) -> float:
        """Run one gradient step on a random minibatch.

        Returns the scalar MSE loss, or 0.0 if the buffer does not yet
        hold a full batch. Also decays epsilon and periodically syncs
        the target network.
        """
        if len(self.memory) < self.batch_size:
            return 0.0

        # Sample a minibatch without replacement and unzip its fields.
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.FloatTensor(np.array(states))
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(np.array(next_states))
        dones = torch.FloatTensor(dones)

        # Q(s, a) for the actions actually taken.
        # squeeze(1) (not squeeze()) so a batch of 1 stays 1-D and does
        # not silently broadcast in mse_loss.
        current_q_values = self.forward(states).gather(
            1, actions.unsqueeze(1)).squeeze(1)

        # Bootstrapped target from the frozen target network; terminal
        # transitions (done=1) contribute only the immediate reward.
        with torch.no_grad():
            next_q_values = self.target_net(next_states).max(1)[0]
        target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        # Gradient step on the online network only.
        loss = F.mse_loss(current_q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Periodically copy the online weights into the target network.
        self.update_count += 1
        if self.update_count % self.target_update_interval == 0:
            self.target_net.load_state_dict(self.q_net.state_dict())

        # Decay exploration toward epsilon_min.
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
        return loss.item()

    def save(self, path: str):
        """Save both networks, optimizer state, and epsilon to `path`."""
        torch.save({
            'q_net_state_dict': self.q_net.state_dict(),
            'target_net_state_dict': self.target_net.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epsilon': self.epsilon,
        }, path)

    def load(self, path: str):
        """Restore networks, optimizer state, and epsilon from `path`."""
        checkpoint = torch.load(path)
        self.q_net.load_state_dict(checkpoint['q_net_state_dict'])
        self.target_net.load_state_dict(checkpoint['target_net_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.epsilon = checkpoint['epsilon']