import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
from torch.utils.data import Dataset, DataLoader


class ReplayBufferDataset(Dataset):
    """Replay buffer exposed as a map-style Dataset so a DataLoader can sample it."""

    def __init__(self, max_size=100000):
        self.buffer = deque(maxlen=max_size)

    def add(self, states, actions, rewards, next_states, done):
        # Cast everything to float32 up front so __getitem__ returns tensors
        # that match the networks' parameter dtype (from_numpy preserves dtype,
        # so a float64 observation would otherwise break the forward pass).
        data = (
            np.asarray(states, dtype=np.float32),
            np.asarray(actions, dtype=np.float32),
            np.asarray(rewards, dtype=np.float32),
            np.asarray(next_states, dtype=np.float32),
            np.float32(done),
        )
        self.buffer.append(data)

    def __len__(self):
        return len(self.buffer)

    def __getitem__(self, idx):
        states, actions, rewards, next_states, done = self.buffer[idx]
        return (
            torch.from_numpy(states),
            torch.from_numpy(actions),
            torch.from_numpy(rewards),
            torch.from_numpy(next_states),
            torch.tensor(done, dtype=torch.float32),
        )


class Actor(nn.Module):
    """Decentralized policy: maps one agent's local observation to an action in [0, 1]."""

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Sigmoid(),  # actions are bounded to [0, 1]
        )

    def forward(self, state):
        return self.net(state)


class SharedCritic(nn.Module):
    """Centralized critic: takes the concatenated states and actions of all
    agents and outputs one Q-value per agent (one head per agent)."""

    def __init__(self, global_state_dim, global_action_dim, hidden_dim=128, num_agents=1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(global_state_dim + global_action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_agents),
        )

    def forward(self, global_state, global_action):
        x = torch.cat([global_state, global_action], dim=1)
        return self.net(x)


class Agent:
    """Standalone per-agent actor wrapper. Not used by the shared-parameter
    MADDPG below; kept for a variant with one independent actor per agent."""

    def __init__(self, local_state_dim, action_dim, lr_actor=1e-3, device=torch.device('cpu')):
        self.device = device
        self.actor = Actor(local_state_dim, action_dim).to(device)
        self.target_actor = Actor(local_state_dim, action_dim).to(device)
        self.actor_optim = optim.Adam(self.actor.parameters(), lr=lr_actor)
        self.target_actor.load_state_dict(self.actor.state_dict())

    def sync_target(self, tau):
        # Polyak averaging: target <- tau * online + (1 - tau) * target
        for tp, p in zip(self.target_actor.parameters(), self.actor.parameters()):
            tp.data.copy_(tau * p.data + (1.0 - tau) * tp.data)


class MADDPG:
    """MADDPG with a single actor shared by all agents (parameter sharing)
    and a centralized critic with one Q-head per agent."""

    def __init__(self, num_agents, local_state_dim, action_dim, gamma=0.95, tau=0.01,
                 lr_actor=1e-4, lr_critic=1e-3, buffer_size=100000, noise_episodes=100,
                 init_sigma=0.2, final_sigma=0.01, batch_size=128, num_workers=0):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.num_agents = num_agents
        self.gamma = gamma
        self.tau = tau
        self.init_sigma = init_sigma
        self.final_sigma = final_sigma
        self.noise_episodes = noise_episodes
        self.current_episode = 0

        # One actor shared across all agents.
        self.actor = Actor(local_state_dim, action_dim).to(self.device)
        self.target_actor = Actor(local_state_dim, action_dim).to(self.device)
        self.target_actor.load_state_dict(self.actor.state_dict())
        self.actor_optim = optim.Adam(self.actor.parameters(), lr=lr_actor)

        # Centralized critic over the concatenation of all agents' states/actions.
        global_state_dim = num_agents * local_state_dim
        global_action_dim = num_agents * action_dim
        self.critic = SharedCritic(global_state_dim, global_action_dim,
                                   num_agents=num_agents).to(self.device)
        self.target_critic = SharedCritic(global_state_dim, global_action_dim,
                                          num_agents=num_agents).to(self.device)
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.critic_optim = optim.Adam(self.critic.parameters(), lr=lr_critic)

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.memory = ReplayBufferDataset(max_size=buffer_size)
        self.dataloader = None
        self.loader_iter = None

    def select_actions(self, states, evaluate=False):
        # states: (num_agents, local_state_dim)
        states_t = torch.as_tensor(states, dtype=torch.float32, device=self.device)
        with torch.no_grad():
            actions_t = torch.stack(
                [self.actor(states_t[i]) for i in range(self.num_agents)], dim=0
            )
        actions = actions_t.cpu().numpy()
        if not evaluate:
            # Gaussian exploration noise, annealed linearly from init_sigma
            # to final_sigma over the first noise_episodes episodes.
            frac = min(float(self.current_episode) / self.noise_episodes, 1.0)
            current_sigma = self.init_sigma - frac * (self.init_sigma - self.final_sigma)
            noise = np.random.normal(0, current_sigma, size=actions.shape)
            actions += noise
        return np.clip(actions, 0.0, 1.0)

    def store_transition(self, states, actions, rewards, next_states, done):
        self.memory.add(states, actions, rewards, next_states, done)

    def train(self):
        if len(self.memory) < self.batch_size:
            return
        should_pin_memory = self.device.type == 'cuda'
        if self.dataloader is None:
            self.dataloader = DataLoader(self.memory, batch_size=self.batch_size,
                                         shuffle=True, num_workers=self.num_workers,
                                         pin_memory=should_pin_memory, drop_last=True)
            self.loader_iter = iter(self.dataloader)
        try:
            s, a, r, s2, d = next(self.loader_iter)
        except StopIteration:
            # Rebuild the loader so newly stored transitions become visible
            # (with num_workers > 0, workers snapshot the buffer at iterator creation).
            self.dataloader = DataLoader(self.memory, batch_size=self.batch_size,
                                         shuffle=True, num_workers=self.num_workers,
                                         pin_memory=should_pin_memory, drop_last=True)
            self.loader_iter = iter(self.dataloader)
            s, a, r, s2, d = next(self.loader_iter)

        # Shapes: s, s2 -> (B, num_agents, state_dim); a -> (B, num_agents, action_dim);
        # r -> (B, num_agents); d -> (B,), unsqueezed to (B, 1) so it broadcasts
        # against the critic's per-agent Q heads.
        s_t, a_t, r_t, s2_t, d_t = (s.to(self.device), a.to(self.device),
                                    r.to(self.device), s2.to(self.device),
                                    d.to(self.device).unsqueeze(-1))
        # Per-batch reward normalization to stabilize the target scale.
        r_t = (r_t - r_t.mean()) / (r_t.std() + 1e-7)
        batch_len = s_t.shape[0]
        # Flatten the agent dimension into the centralized (global) views.
        gs = s_t.reshape(batch_len, -1)
        ga = a_t.reshape(batch_len, -1)
        ns = s2_t.reshape(batch_len, -1)

        # Critic update: one-step TD targets from the target networks.
        with torch.no_grad():
            targ_actions = torch.cat(
                [self.target_actor(s2_t[:, i, :]) for i in range(self.num_agents)], dim=1
            )
            Q_prime = self.target_critic(ns, targ_actions)
            targets = r_t + self.gamma * (1 - d_t) * Q_prime
        Q = self.critic(gs, ga)
        critic_loss = nn.MSELoss()(Q, targets)
        self.critic_optim.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 1.0)
        self.critic_optim.step()

        # Actor update: ascend the critic's value of the current policy's actions.
        all_actions = torch.cat(
            [self.actor(s_t[:, i, :]) for i in range(self.num_agents)], dim=1
        )
        actor_loss = -self.critic(gs, all_actions).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 1.0)
        self.actor_optim.step()

        # Soft (Polyak) update of both target networks.
        for tp, p in zip(self.target_actor.parameters(), self.actor.parameters()):
            tp.data.copy_(self.tau * p.data + (1.0 - self.tau) * tp.data)
        for tp, p in zip(self.target_critic.parameters(), self.critic.parameters()):
            tp.data.copy_(self.tau * p.data + (1.0 - self.tau) * tp.data)

    def on_episode_end(self):
        self.current_episode += 1

    def save(self, path: str):
        payload = {
            "critic": self.critic.state_dict(),
            "target_critic": self.target_critic.state_dict(),
            "critic_optim": self.critic_optim.state_dict(),
            "actor": self.actor.state_dict(),
            "target_actor": self.target_actor.state_dict(),
            "actor_optim": self.actor_optim.state_dict(),
            "current_episode": self.current_episode,
        }
        torch.save(payload, path)

    def load(self, path: str):
        checkpoint = torch.load(path, map_location=self.device)
        self.critic.load_state_dict(checkpoint["critic"])
        self.target_critic.load_state_dict(checkpoint["target_critic"])
        self.critic_optim.load_state_dict(checkpoint["critic_optim"])
        self.actor.load_state_dict(checkpoint["actor"])
        self.target_actor.load_state_dict(checkpoint["target_actor"])
        self.actor_optim.load_state_dict(checkpoint["actor_optim"])
        self.current_episode = checkpoint.get("current_episode", 0)
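

# --- Minimal usage sketch (assumption: no environment ships with this file) ---
# The random "environment" below is hypothetical and exists only to show the
# expected shapes and call order: select_actions -> step -> store_transition
# -> train -> on_episode_end. Swap in a real multi-agent environment that
# yields (num_agents, state_dim) observations and (num_agents,) rewards.
if __name__ == "__main__":
    num_agents, state_dim, action_dim = 3, 8, 2
    algo = MADDPG(num_agents, state_dim, action_dim, batch_size=32)

    for episode in range(5):
        # Hypothetical reset: one local observation per agent.
        states = np.random.rand(num_agents, state_dim).astype(np.float32)
        for step in range(50):
            actions = algo.select_actions(states)  # (num_agents, action_dim) in [0, 1]
            # Hypothetical transition: random next states and per-agent rewards.
            next_states = np.random.rand(num_agents, state_dim).astype(np.float32)
            rewards = np.random.rand(num_agents).astype(np.float32)
            done = step == 49
            algo.store_transition(states, actions, rewards, next_states, done)
            algo.train()  # no-op until the buffer holds batch_size transitions
            states = next_states
        algo.on_episode_end()  # advances the exploration-noise schedule

    algo.save("maddpg_checkpoint.pt")  # checkpoint path is arbitrary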