# mappo.py
"""Multi-agent PPO (MAPPO) with one shared actor and a centralized critic.

The actor maps each agent's local observation to a diagonal Gaussian over
actions; the critic maps the global observation to a single state value that
is shared by all agents at that timestep.

Importing this module selects the compute device, prints which one was
chosen, and seeds every RNG with ``SEED``.
"""
import random

import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Normal


def set_global_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic."""
    random.seed(seed)                         # Python
    np.random.seed(seed)                      # NumPy
    torch.manual_seed(seed)                   # PyTorch CPU
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)      # PyTorch GPU
    # Make cuDNN deterministic (may slow things down a bit).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Universal device selection
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA (NVIDIA GPU)")
# elif torch.backends.mps.is_available():
#     device = torch.device("mps")
#     print("Using MPS (Apple Silicon GPU)")
else:
    device = torch.device("cpu")
    print("Using CPU")

# fix EVERYTHING
SEED = 42
set_global_seed(SEED)


def _as_float_tensor(data) -> torch.Tensor:
    """Convert (nested) array-like ``data`` to a float32 tensor on ``device``.

    Going through one contiguous NumPy array avoids the slow element-wise path
    PyTorch takes when handed a Python list of ndarrays.
    """
    return torch.as_tensor(np.asarray(data, dtype=np.float32), device=device)


class MLP(nn.Module):
    """Fully connected network: ``Linear -> ReLU`` per hidden size, then a
    final ``Linear`` with no activation."""

    def __init__(self, input_dim, hidden_dims, output_dim):
        super().__init__()
        layers = []
        last_dim = input_dim
        for h in hidden_dims:
            layers += [nn.Linear(last_dim, h), nn.ReLU()]
            last_dim = h
        layers.append(nn.Linear(last_dim, output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for ``x`` (last dim = ``input_dim``)."""
        return self.net(x)


class Actor(nn.Module):
    """Gaussian policy head with a state-independent, learnable log-std."""

    def __init__(self, obs_dim, act_dim, hidden=(64, 64)):
        super().__init__()
        self.net = MLP(obs_dim, hidden, act_dim)
        # One log-std per action dimension; zeros => initial std of 1.
        self.log_std = nn.Parameter(torch.zeros(act_dim))

    def forward(self, x):
        """Return ``(mean, std)``; ``std`` has shape ``[act_dim]`` and
        broadcasts against ``mean``."""
        mean = self.net(x)
        std = torch.exp(self.log_std)
        return mean, std


class Critic(nn.Module):
    """State-value network over the global observation."""

    def __init__(self, state_dim, hidden=(128, 128)):
        super().__init__()
        self.net = MLP(state_dim, hidden, 1)

    def forward(self, x):
        """Return V(x) with the trailing singleton dimension removed."""
        return self.net(x).squeeze(-1)


class MAPPO:
    """PPO trainer with a shared actor and a centralized critic.

    Transitions are appended with :meth:`store` (one entry per environment
    timestep, covering all agents) and consumed by :meth:`update`, which runs
    clipped-surrogate PPO and then empties the buffer.

    Args:
        n_agents: number of agents sharing the actor.
        local_dim: size of one agent's local observation.
        global_dim: size of the global observation fed to the critic.
        act_dim: size of one agent's action vector.
        lr: Adam learning rate, used for both actor and critic.
        gamma: discount factor.
        lam: GAE lambda.
        clip_eps: PPO ratio clipping range.
        k_epochs: passes over the buffer per :meth:`update`.
        batch_size: minibatch size (in agent-timestep samples).
    """

    def __init__(
        self,
        n_agents,
        local_dim,
        global_dim,
        act_dim,
        lr=3e-4,
        gamma=0.99,
        lam=0.95,
        clip_eps=0.2,
        k_epochs=10,
        batch_size=1024,
    ):
        self.n_agents = n_agents
        self.gamma = gamma
        self.lam = lam
        self.clip_eps = clip_eps
        self.k_epochs = k_epochs
        self.batch_size = batch_size

        self.actor = Actor(local_dim, act_dim).to(device)
        self.critic = Critic(global_dim).to(device)
        self.opt_a = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.opt_c = torch.optim.Adam(self.critic.parameters(), lr=lr)

        self.local_dim = local_dim
        self.global_dim = global_dim
        self.act_dim = act_dim

        # Seed the minibatch-shuffling generator once.  Re-seeding it inside
        # every update() would replay the identical permutation sequence on
        # each call; a persistent generator stays deterministic while still
        # producing a fresh shuffle per update.
        self._shuffle_gen = torch.Generator()
        self._shuffle_gen.manual_seed(SEED)

        self.clear_buffer()

    def clear_buffer(self):
        """Reset the rollout buffer to empty lists."""
        self.ls = []        # local observations
        self.gs = []        # global observations
        self.ac = []        # actions
        self.lp = []        # log-probs
        self.rw = []        # rewards
        self.done = []      # done flags
        self.next_gs = []   # next global observations

    @torch.no_grad()
    def select_action(self, local_obs, global_obs):
        """Sample actions from the current policy.

        Args:
            local_obs: array-like whose last dimension is ``local_dim``.
            global_obs: unused here; kept so callers can pass the same
                arguments they later hand to :meth:`store`.

        Returns:
            ``(action, logp)`` as NumPy arrays; ``logp`` is the log-prob of
            ``action`` summed over the action dimension.
        """
        obs = _as_float_tensor(local_obs)
        mean, std = self.actor(obs)
        dist = Normal(mean, std)
        action = dist.sample()
        logp = dist.log_prob(action).sum(-1)
        return action.cpu().numpy(), logp.cpu().numpy()

    def store(self, local_obs, global_obs, action, logp, reward, done,
              next_global_obs):
        """Append one environment timestep (all agents) to the buffer."""
        self.ls.append(local_obs)
        self.gs.append(global_obs)
        self.ac.append(action)
        self.lp.append(logp)
        self.rw.append(reward)
        self.done.append(done)
        self.next_gs.append(next_global_obs)

    def compute_gae(self, values):
        """Compute GAE advantages and returns for the buffered rollout.

        Args:
            values: ``torch.Tensor`` of shape ``[T]`` holding one central
                V(s) per buffered timestep.

        Returns:
            ``(adv_flat, ret_flat)``: two tensors of shape ``[T * n_agents]``
            on ``device``, flattened timestep-major (all agents of step 0,
            then step 1, ...).  Both are empty when ``T == 0``.
        """
        # 1) Raw arrays (detach so a grad-carrying input does not raise).
        vals_1d = values.detach().cpu().numpy()                  # [T]
        T = len(vals_1d)
        N = self.n_agents
        if T == 0:
            empty = torch.zeros(0, dtype=torch.float32, device=device)
            return empty, empty.clone()

        # 2) Broadcast to per-agent: vals_agent[t, i] = V(state_t).
        vals_agent = np.tile(vals_1d[:, None], (1, N))           # [T, N]

        # 3) Next-state values: shift by one step inside the buffer ...
        next_vals = np.zeros_like(vals_agent)                    # [T, N]
        next_vals[:-1] = vals_agent[1:]
        # ... and bootstrap the final step from the critic unless the
        # rollout ended exactly on an episode boundary.
        if not self.done[-1]:
            with torch.no_grad():
                v_last = self.critic(
                    _as_float_tensor(self.next_gs[-1])
                ).item()
            next_vals[-1, :] = v_last

        # 4) Backward GAE recursion; `mask` cuts both the bootstrap and the
        #    accumulated advantage at episode ends.
        adv = np.zeros_like(vals_agent, dtype=np.float32)
        prev_adv = np.zeros(N, dtype=np.float32)
        for t in reversed(range(T)):
            mask = 1.0 - float(self.done[t])                     # scalar 0/1
            # Per-agent rewards [N]; a scalar reward broadcasts as well.
            rew_t = np.array(self.rw[t], dtype=np.float32)
            delta = rew_t + self.gamma * next_vals[t] * mask - vals_agent[t]
            prev_adv = delta + self.gamma * self.lam * mask * prev_adv
            adv[t] = prev_adv

        # 5) Returns = advantages + values; flatten to [T * N].
        ret = adv + vals_agent
        adv_flat = torch.from_numpy(adv.flatten()).to(device)
        ret_flat = torch.from_numpy(ret.flatten()).to(device)
        return adv_flat, ret_flat

    def update(self):
        """Run ``k_epochs`` of clipped PPO on the buffer, then clear it.

        Does nothing when the buffer is empty.
        """
        if not self.gs:
            return

        # 1) Raw global states [T, G].
        raw_gs = _as_float_tensor(self.gs)

        # 2) One value V(s_t) per timestep.
        with torch.no_grad():
            vals = self.critic(raw_gs).cpu()                     # [T]

        # 3) Advantages and returns, both flattened to [T * N].
        adv_flat, ret_flat = self.compute_gae(vals)

        # 4) Per-agent flattened training tensors.
        ls = _as_float_tensor(self.ls).reshape(-1, self.local_dim)  # [T*N, L]
        ac = _as_float_tensor(self.ac).reshape(-1, self.act_dim)    # [T*N, A]
        old_lp = _as_float_tensor(self.lp).reshape(-1)              # [T*N]
        # Broadcast global states per agent: [T, G] -> [T, N, G] -> [T*N, G].
        gs = raw_gs.unsqueeze(1).expand(-1, self.n_agents, -1)
        gs = gs.reshape(-1, self.global_dim)

        dataset = torch.utils.data.TensorDataset(
            ls, gs, ac, old_lp, adv_flat, ret_flat
        )
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=0,
            generator=self._shuffle_gen,
        )

        # 5) PPO update loop.
        for _ in range(self.k_epochs):
            for b_ls, b_gs, b_ac, b_lp, b_adv, b_ret in loader:
                # Actor update: clipped surrogate objective.
                mean, std = self.actor(b_ls)
                dist = Normal(mean, std)
                lp_new = dist.log_prob(b_ac).sum(-1)
                ratio = torch.exp(lp_new - b_lp)
                surr1 = ratio * b_adv
                surr2 = torch.clamp(
                    ratio, 1 - self.clip_eps, 1 + self.clip_eps
                ) * b_adv
                actor_loss = -torch.min(surr1, surr2).mean()
                self.opt_a.zero_grad()
                actor_loss.backward()
                self.opt_a.step()

                # Critic update: regress V(s) onto the GAE returns.
                val_pred = self.critic(b_gs)
                critic_loss = nn.functional.mse_loss(val_pred, b_ret)
                self.opt_c.zero_grad()
                critic_loss.backward()
                self.opt_c.step()

        # 6) Clear buffers for the next rollout.
        self.clear_buffer()

    def save(self, path):
        """Write the actor and critic state dicts to ``path``."""
        torch.save(
            {
                'actor': self.actor.state_dict(),
                'critic': self.critic.state_dict(),
            },
            path,
        )

    def load(self, path):
        """Load actor and critic weights previously written by :meth:`save`.

        NOTE(security): ``torch.load`` unpickles the file; only load
        checkpoints from trusted sources.
        """
        data = torch.load(path, map_location=device)
        self.actor.load_state_dict(data['actor'])
        self.critic.load_state_dict(data['critic'])