import random
from collections import deque
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class DQNAgent(nn.Module):
    """DQN agent with a target network, replay memory, and epsilon-greedy exploration."""

    def __init__(self, config: Dict):
        super().__init__()

        self.config = config
        self.state_dim = config['state_dim']
        self.action_dim = config['action_dim']
        self.learning_rate = config.get('learning_rate', 1e-4)

        # Online and target networks share the same architecture.
        self.q_net = self._build_model()
        self.target_net = self._build_model()

        # Optimize only the online network; the target network is updated by
        # copying weights from the online network, never by gradient descent.
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=self.learning_rate)

        # Replay memory: deque(maxlen=...) evicts the oldest transition in O(1).
        self.batch_size = config.get('batch_size', 64)
        self.memory_size = config.get('memory_size', 10000)
        self.memory = deque(maxlen=self.memory_size)

        # Discount factor and epsilon-greedy exploration schedule.
        self.gamma = config.get('gamma', 0.99)
        self.epsilon = config.get('epsilon', 1.0)
        self.epsilon_min = config.get('epsilon_min', 0.01)
        self.epsilon_decay = config.get('epsilon_decay', 0.995)

        # Start the target network as an exact copy of the online network.
        self.target_net.load_state_dict(self.q_net.state_dict())
        self.target_net.eval()
        self.target_update_interval = config.get('target_update_interval', 10)
        self.update_count = 0

    def _build_model(self) -> nn.Sequential:
        return nn.Sequential(
            nn.Linear(self.state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, self.action_dim),
        )

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Forward pass through the online network."""
        return self.q_net(state)

    def get_action(self, state: np.ndarray, training: bool = True) -> int:
        """Select an action epsilon-greedily during training, greedily otherwise."""
        if training and np.random.random() < self.epsilon:
            return np.random.randint(self.action_dim)

        # Inference only: no gradient tracking is needed for action selection.
        with torch.no_grad():
            q_values = self.forward(torch.FloatTensor(state).unsqueeze(0))
        return q_values.argmax().item()

    def remember(self, state: np.ndarray, action: int, reward: float,
                 next_state: np.ndarray, done: bool):
        """Store a transition; the deque drops the oldest entry when full."""
        self.memory.append((state, action, reward, next_state, done))

    def train_step(self) -> float:
        """Train on one sampled batch and return the loss."""
        if len(self.memory) < self.batch_size:
            return 0.0

        # Sample a batch of transitions without replacement and unzip it
        # into parallel sequences.
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.FloatTensor(np.array(states))
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(np.array(next_states))
        dones = torch.FloatTensor(dones)

        # Q(s, a) from the online network; bootstrap targets come from the
        # target network: y = r + gamma * (1 - done) * max_a' Q_target(s', a').
        current_q_values = self.forward(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            next_q_values = self.target_net(next_states).max(1)[0]
        target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        loss = F.mse_loss(current_q_values, target_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Periodically sync the target network with the online network.
        self.update_count += 1
        if self.update_count % self.target_update_interval == 0:
            self.target_net.load_state_dict(self.q_net.state_dict())

        # Decay exploration toward epsilon_min.
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

        return loss.item()

    def save(self, path: str):
        """Save networks, optimizer state, and exploration rate."""
        torch.save({
            'q_net_state_dict': self.q_net.state_dict(),
            'target_net_state_dict': self.target_net.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epsilon': self.epsilon,
        }, path)

    def load(self, path: str):
        """Load networks, optimizer state, and exploration rate."""
        checkpoint = torch.load(path)
        self.q_net.load_state_dict(checkpoint['q_net_state_dict'])
        self.target_net.load_state_dict(checkpoint['target_net_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.epsilon = checkpoint['epsilon']
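

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration added here, not part of the original
# module): it exercises the agent on randomly generated transitions as a
# smoke test. A real training loop would draw transitions from an environment
# with a Gym-style reset()/step() interface instead of np.random; the config
# values below are illustrative, not prescribed by the class.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = {'state_dim': 4, 'action_dim': 2, 'batch_size': 32}
    agent = DQNAgent(config)

    loss = 0.0
    for _ in range(200):
        state = np.random.randn(4).astype(np.float32)
        next_state = np.random.randn(4).astype(np.float32)
        action = agent.get_action(state)
        reward = float(np.random.randn())
        done = bool(np.random.rand() < 0.05)
        agent.remember(state, action, reward, next_state, done)
        loss = agent.train_step()  # returns 0.0 until the memory holds a full batch

    print(f"final epsilon: {agent.epsilon:.3f}, last loss: {loss:.4f}")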