import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
from torch.utils.data import Dataset, DataLoader
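
# MADDPG with parameter sharing: one actor is shared by all agents, and one
# centralized critic scores the joint (global) state-action and returns a
# Q-value per agent.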


class ReplayBufferDataset(Dataset):
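    """Replay buffer exposed as a torch Dataset so it can be sampled with a DataLoader."""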
    def __init__(self, max_size=100000):
        self.buffer = deque(maxlen=max_size)

    def add(self, states, actions, rewards, next_states, done):
        # Normalize every field to a float32 numpy array so __getitem__ can
        # rely on torch.from_numpy working on each stored element.
        data = (
            np.asarray(states, dtype=np.float32),
            np.asarray(actions, dtype=np.float32),
            np.asarray(rewards, dtype=np.float32),
            np.asarray(next_states, dtype=np.float32),
            np.float32(done),
        )
        self.buffer.append(data)
    def __len__(self):
        return len(self.buffer)

    def __getitem__(self, idx):
        states, actions, rewards, next_states, done = self.buffer[idx]
        return (
            torch.from_numpy(states),
            torch.from_numpy(actions),
            torch.from_numpy(rewards),
            torch.from_numpy(next_states),
            torch.tensor(done, dtype=torch.float32),
        )


class Actor(nn.Module):
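    """Deterministic policy: maps a local observation to actions in [0, 1]."""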
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Sigmoid(),  # squash actions into [0, 1], matching the clipping in select_actions
        )

    def forward(self, state):
        return self.net(state)


class SharedCritic(nn.Module):
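    """Centralized critic: scores the joint state-action pair and returns one Q-value per agent."""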
    def __init__(self, global_state_dim, global_action_dim, hidden_dim=128, num_agents=1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(global_state_dim + global_action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_agents),
        )

    def forward(self, global_state, global_action):
        x = torch.cat([global_state, global_action], dim=1)
        return self.net(x)


class Agent:
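    """Per-agent actor/target-actor bundle with its own optimizer.

    Note: the MADDPG class below shares a single actor across all agents and
    does not use this wrapper; it is the non-shared, per-agent variant.
    """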
    def __init__(self, local_state_dim, action_dim, lr_actor=1e-3, device=torch.device('cpu')):
        self.device = device
        self.actor = Actor(local_state_dim, action_dim).to(device)
        self.target_actor = Actor(local_state_dim, action_dim).to(device)
        self.actor_optim = optim.Adam(self.actor.parameters(), lr=lr_actor)
        self.target_actor.load_state_dict(self.actor.state_dict())

    def sync_target(self, tau):
        for tp, p in zip(self.target_actor.parameters(), self.actor.parameters()):
            tp.data.copy_(tau * p.data + (1.0 - tau) * tp.data)


class MADDPG:
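    """Multi-agent DDPG with a shared actor and a shared centralized critic."""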
    def __init__(self, num_agents, local_state_dim, action_dim,
                 gamma=0.95, tau=0.01, lr_actor=1e-4, lr_critic=1e-3,
                 buffer_size=100000, noise_episodes=100, init_sigma=0.2, final_sigma=0.01,
                 batch_size=128, num_workers=0):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.num_agents = num_agents
        self.gamma = gamma
        self.tau = tau
        self.init_sigma = init_sigma
        self.final_sigma = final_sigma
        self.noise_episodes = noise_episodes
        self.current_episode = 0
        # Shared actor (one policy network for all agents) and its target copy.
        self.actor = Actor(local_state_dim, action_dim).to(self.device)
        self.target_actor = Actor(local_state_dim, action_dim).to(self.device)
        self.target_actor.load_state_dict(self.actor.state_dict())
        self.actor_optim = optim.Adam(self.actor.parameters(), lr=lr_actor)
        # Centralized critic over the concatenated states and actions of all agents.
        global_state_dim = num_agents * local_state_dim
        global_action_dim = num_agents * action_dim
        self.critic = SharedCritic(global_state_dim, global_action_dim, num_agents=num_agents).to(self.device)
        self.target_critic = SharedCritic(global_state_dim, global_action_dim, num_agents=num_agents).to(self.device)
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.critic_optim = optim.Adam(self.critic.parameters(), lr=lr_critic)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.memory = ReplayBufferDataset(max_size=buffer_size)
        self.dataloader = None
        self.loader_iter = None
    def select_actions(self, states, evaluate=False):
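        """Return one action per agent; adds annealed Gaussian exploration noise unless evaluate=True."""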
        states_t = torch.as_tensor(states, dtype=torch.float32, device=self.device)
        with torch.no_grad():
            actions_t = torch.stack([
                self.actor(states_t[i]) for i in range(self.num_agents)
            ], dim=0)
        actions = actions_t.cpu().numpy()
        if not evaluate:
            # Linearly anneal sigma from init_sigma to final_sigma over noise_episodes.
            frac = min(float(self.current_episode) / max(self.noise_episodes, 1), 1.0)
            current_sigma = self.init_sigma - frac * (self.init_sigma - self.final_sigma)
            noise = np.random.normal(0, current_sigma, size=actions.shape)
            actions += noise
        return np.clip(actions, 0.0, 1.0)
    def store_transition(self, states, actions, rewards, next_states, done):
        self.memory.add(states, actions, rewards, next_states, done)
    def train(self):
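        """One gradient step each for the centralized critic and the shared actor from a replay minibatch."""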
        if len(self.memory) < self.batch_size:
            return
        should_pin_memory = self.device.type == 'cuda'
        if self.dataloader is None:
            self.dataloader = DataLoader(self.memory, batch_size=self.batch_size, shuffle=True,
                                         num_workers=self.num_workers, pin_memory=should_pin_memory,
                                         drop_last=True)
            self.loader_iter = iter(self.dataloader)
        try:
            s, a, r, s2, d = next(self.loader_iter)
        except StopIteration:
            # Loader exhausted: rebuild it so sampling sees the grown buffer.
            self.dataloader = DataLoader(self.memory, batch_size=self.batch_size, shuffle=True,
                                         num_workers=self.num_workers, pin_memory=should_pin_memory,
                                         drop_last=True)
            self.loader_iter = iter(self.dataloader)
            s, a, r, s2, d = next(self.loader_iter)
        s_t, a_t, r_t, s2_t, d_t = (s.to(self.device), a.to(self.device), r.to(self.device),
                                    s2.to(self.device), d.to(self.device).unsqueeze(-1))
        # Per-batch reward normalization for training stability.
        r_t = (r_t - r_t.mean()) / (r_t.std() + 1e-7)
        batch_len = s_t.shape[0]
        # Flatten the per-agent tensors into global (joint) state/action vectors.
        gs, ga, ns = s_t.reshape(batch_len, -1), a_t.reshape(batch_len, -1), s2_t.reshape(batch_len, -1)
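        # TD target per agent: y_i = r_i + gamma * (1 - done) * Q'_i(s', a'), using the target networks.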
        with torch.no_grad():
            targ_actions = torch.cat([self.target_actor(s2_t[:, i, :]) for i in range(self.num_agents)], dim=1)
            Q_prime = self.target_critic(ns, targ_actions)
            targets = r_t + self.gamma * (1 - d_t) * Q_prime
        Q = self.critic(gs, ga)
        critic_loss = nn.MSELoss()(Q, targets)
        self.critic_optim.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 1.0)
        self.critic_optim.step()
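        # Actor update: ascend the critic's Q-estimates with respect to the shared actor's actions.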
        all_actions = torch.cat([self.actor(s_t[:, i, :]) for i in range(self.num_agents)], dim=1)
        actor_loss = -self.critic(gs, all_actions).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 1.0)
        self.actor_optim.step()
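        # Polyak (soft) update of both target networks toward the online networks.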
        for tp, p in zip(self.target_actor.parameters(), self.actor.parameters()):
            tp.data.copy_(self.tau * p.data + (1.0 - self.tau) * tp.data)
        for tp, p in zip(self.target_critic.parameters(), self.critic.parameters()):
            tp.data.copy_(self.tau * p.data + (1.0 - self.tau) * tp.data)
    def on_episode_end(self):
        self.current_episode += 1

    def save(self, path: str):
        payload = {
            "critic": self.critic.state_dict(),
            "target_critic": self.target_critic.state_dict(),
            "critic_optim": self.critic_optim.state_dict(),
            "actor": self.actor.state_dict(),
            "target_actor": self.target_actor.state_dict(),
            "actor_optim": self.actor_optim.state_dict(),
            "current_episode": self.current_episode,
        }
        torch.save(payload, path)

    def load(self, path: str):
        checkpoint = torch.load(path, map_location=self.device)
        self.critic.load_state_dict(checkpoint["critic"])
        self.target_critic.load_state_dict(checkpoint["target_critic"])
        self.critic_optim.load_state_dict(checkpoint["critic_optim"])
        self.actor.load_state_dict(checkpoint["actor"])
        self.target_actor.load_state_dict(checkpoint["target_actor"])
        self.actor_optim.load_state_dict(checkpoint["actor_optim"])
        self.current_episode = checkpoint.get("current_episode", 0)
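

# A minimal smoke-test sketch (an assumption, not part of the original source):
# a toy setting with 3 agents, 4-dimensional local observations, 2-dimensional
# actions, and random transitions, just to exercise the API end to end.
if __name__ == "__main__":
    num_agents, state_dim, action_dim = 3, 4, 2
    algo = MADDPG(num_agents, state_dim, action_dim, batch_size=32)
    for episode in range(2):
        states = np.random.rand(num_agents, state_dim).astype(np.float32)
        for step in range(64):
            actions = algo.select_actions(states)
            # A real environment step would go here; random values stand in.
            next_states = np.random.rand(num_agents, state_dim).astype(np.float32)
            rewards = np.random.rand(num_agents).astype(np.float32)
            done = step == 63
            algo.store_transition(states, actions, rewards, next_states, done)
            algo.train()
            states = next_states
        algo.on_episode_end()
    algo.save("maddpg_checkpoint.pt")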