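"""
DQN agent implemented in PyTorch: epsilon-greedy exploration, experience
replay, and a periodically synchronized target network.
"""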
import random
from collections import deque
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class DQNAgent(nn.Module):
    """
    DQN-based reinforcement learning agent with epsilon-greedy exploration,
    experience replay, and a periodically updated target network.
    """
    def __init__(self, config: Dict):
        super().__init__()
        self.config = config
        self.state_dim = config['state_dim']
        self.action_dim = config['action_dim']
        self.learning_rate = config.get('learning_rate', 1e-4)

        # Online and target networks
        self.q_net = self._build_model()
        self.target_net = self._build_model()

        # Optimizer updates only the online network; the target network is
        # refreshed by copying weights, never by gradient descent
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=self.learning_rate)

        # Experience replay buffer (deque evicts the oldest transition when full)
        self.batch_size = config.get('batch_size', 64)
        self.memory_size = config.get('memory_size', 10000)
        self.memory = deque(maxlen=self.memory_size)

        # Training parameters
        self.gamma = config.get('gamma', 0.99)        # Discount factor
        self.epsilon = config.get('epsilon', 1.0)     # Exploration rate
        self.epsilon_min = config.get('epsilon_min', 0.01)
        self.epsilon_decay = config.get('epsilon_decay', 0.995)

        # Target network: start synchronized with the online network and
        # re-sync every `target_update_interval` training steps
        self.target_net.load_state_dict(self.q_net.state_dict())
        self.target_net.eval()
        for param in self.target_net.parameters():
            param.requires_grad = False
        self.target_update_interval = config.get('target_update_interval', 10)
        self.update_count = 0

    def _build_model(self) -> nn.Sequential:
        """Build a small MLP mapping a state to per-action Q-values."""
        return nn.Sequential(
            nn.Linear(self.state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, self.action_dim),
        )

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return Q-values for a batch of states from the online network."""
        return self.q_net(state)

    def get_action(self, state: np.ndarray, training: bool = True) -> int:
        """Select an action epsilon-greedily (greedy when training=False)."""
        if training and np.random.random() < self.epsilon:
            return np.random.randint(self.action_dim)
        state_t = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            q_values = self.forward(state_t)
        return q_values.argmax().item()

    def remember(self, state: np.ndarray, action: int, reward: float,
                 next_state: np.ndarray, done: bool):
        """Store a transition; the deque drops the oldest entry when full."""
        self.memory.append((state, action, reward, next_state, done))

    def train_step(self) -> float:
        """Sample a batch from replay memory and take one gradient step."""
        if len(self.memory) < self.batch_size:
            return 0.0

        # Sample a batch of transitions uniformly without replacement
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.as_tensor(np.array(states), dtype=torch.float32)
        actions = torch.as_tensor(actions, dtype=torch.int64)
        rewards = torch.as_tensor(rewards, dtype=torch.float32)
        next_states = torch.as_tensor(np.array(next_states), dtype=torch.float32)
        dones = torch.as_tensor(dones, dtype=torch.float32)

        # Q(s, a) for the actions actually taken
        current_q_values = self.forward(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # Bootstrapped target: r + gamma * max_a' Q_target(s', a'),
        # zeroed at terminal states
        with torch.no_grad():
            next_q_values = self.target_net(next_states).max(1)[0]
        target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        # Compute loss and update the online network
        loss = F.mse_loss(current_q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Periodically sync the target network with the online network
        self.update_count += 1
        if self.update_count % self.target_update_interval == 0:
            self.target_net.load_state_dict(self.q_net.state_dict())

        # Decay the exploration rate
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
        return loss.item()

    def save(self, path: str):
        """Save network weights, optimizer state, and epsilon to `path`."""
        torch.save({
            'q_net_state_dict': self.q_net.state_dict(),
            'target_net_state_dict': self.target_net.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epsilon': self.epsilon
        }, path)

    def load(self, path: str):
        """Load network weights, optimizer state, and epsilon from `path`."""
        checkpoint = torch.load(path)
        self.q_net.load_state_dict(checkpoint['q_net_state_dict'])
        self.target_net.load_state_dict(checkpoint['target_net_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.epsilon = checkpoint['epsilon']
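
# --- Minimal usage sketch (not part of the agent) ---
# Shows the shape of the training loop on synthetic random transitions so it
# runs standalone; in practice, replace the random states/rewards with a real
# environment (e.g. a Gym-style env). The dimensions below are arbitrary
# example values, not anything prescribed by the agent.
if __name__ == "__main__":
    agent = DQNAgent({'state_dim': 4, 'action_dim': 2})
    state = np.random.randn(4).astype(np.float32)
    for step in range(200):
        action = agent.get_action(state)
        # Fake transition: random next state and reward, episode never ends
        next_state = np.random.randn(4).astype(np.float32)
        reward = float(np.random.randn())
        done = False
        agent.remember(state, action, reward, next_state, done)
        loss = agent.train_step()  # returns 0.0 until the buffer holds a batch
        state = next_state
        if step % 50 == 0:
            print(f"step={step} loss={loss:.4f} epsilon={agent.epsilon:.3f}")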