"""Multi-agent PPO (MAPPO) with a shared actor and a centralized critic.

Every agent acts from its own local observation through one shared Gaussian
policy, while a single critic scores the global state. One episode of
transitions is kept in fixed-size NumPy buffers and consumed by
``MAPPO.update``.
"""

import random

import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Normal

if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA (NVIDIA GPU)")
else:
    device = torch.device("cpu")
    print("Using CPU")


def set_global_seed(seed: int, *, deterministic: bool = False) -> None:
    """Seed the Python, NumPy and PyTorch RNGs.

    Args:
        seed: Seed applied to every RNG.
        deterministic: When True, force deterministic cuDNN kernels and turn
            off cuDNN autotuning so GPU runs can be reproduced. The default
            (False) keeps the speed-oriented settings, under which GPU
            results may still differ between runs despite the fixed seed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic


SEED = 42
set_global_seed(SEED)


class MLP(nn.Module):
    """Feed-forward network: Linear+ReLU per hidden size, then a Linear output."""

    def __init__(self, input_dim, hidden_dims, output_dim):
        super().__init__()
        layers = []
        last_dim = input_dim
        for h in hidden_dims:
            layers += [nn.Linear(last_dim, h), nn.ReLU()]
            last_dim = h
        layers.append(nn.Linear(last_dim, output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class Actor(nn.Module):
    """Gaussian policy with a state-independent, learnable log standard deviation."""

    def __init__(self, obs_dim, act_dim, hidden=(64, 64)):
        super().__init__()
        self.net = MLP(obs_dim, hidden, act_dim)
        self.log_std = nn.Parameter(torch.zeros(act_dim))

    def forward(self, x):
        """Return ``(mean, std)`` of the action distribution for observations ``x``."""
        mean = self.net(x)
        std = torch.exp(self.log_std)
        return mean, std


class Critic(nn.Module):
    """State-value network; the trailing singleton output dimension is squeezed."""

    def __init__(self, state_dim, hidden=(128, 128)):
        super().__init__()
        self.net = MLP(state_dim, hidden, 1)

    def forward(self, x):
        return self.net(x).squeeze(-1)


class MAPPO:
    """PPO with a parameter-shared actor over agents and a global-state critic.

    Args:
        n_agents: Number of agents sharing the actor.
        local_dim: Size of each agent's local observation.
        global_dim: Size of the global state fed to the critic.
        act_dim: Action dimensionality per agent.
        lr: Adam learning rate for both actor and critic.
        gamma: Discount factor.
        lam: GAE lambda.
        clip_eps: PPO ratio clipping range.
        k_epochs: Optimisation passes over the rollout per ``update()``.
        batch_size: Minibatch size, counted in (step, agent) samples.
        episode_len: Buffer capacity in environment steps.
        ent_coef: Weight of the entropy bonus in the actor loss.
    """

    def __init__(
        self,
        n_agents,
        local_dim,
        global_dim,
        act_dim,
        lr=3e-4,
        gamma=0.99,
        lam=0.95,
        clip_eps=0.2,
        k_epochs=10,
        batch_size=1024,
        episode_len=96,
        ent_coef=0.01,
    ):
        self.n_agents = n_agents
        self.local_dim = local_dim
        self.global_dim = global_dim
        self.act_dim = act_dim
        self.gamma = gamma
        self.lam = lam
        self.clip_eps = clip_eps
        self.k_epochs = k_epochs
        self.batch_size = batch_size
        self.episode_len = episode_len
        self.ent_coef = ent_coef
        self.actor = Actor(local_dim, act_dim).to(device)
        self.critic = Critic(global_dim).to(device)
        self.opt_a = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.opt_c = torch.optim.Adam(self.critic.parameters(), lr=lr)
        # Seeded once, so minibatch order is reproducible yet differs between
        # successive update() calls (re-seeding per update repeats it exactly).
        self._shuffle_gen = torch.Generator()
        self._shuffle_gen.manual_seed(SEED)
        print("MAPPO CUDA AMP is disabled for stability.")
        self.init_buffer()

    def init_buffer(self):
        """Allocate empty rollout buffers for one episode and reset the write index.

        float32 is used throughout: float16 visibly perturbs stored log-probs
        (so the PPO ratio is not 1 on the first epoch) and overflows for
        rewards above 65504.
        """
        T, N = self.episode_len, self.n_agents
        self.ls_buf = np.zeros((T, N, self.local_dim), dtype=np.float32)
        self.gs_buf = np.zeros((T, self.global_dim), dtype=np.float32)
        self.ac_buf = np.zeros((T, N, self.act_dim), dtype=np.float32)
        self.lp_buf = np.zeros((T, N), dtype=np.float32)
        self.rw_buf = np.zeros((T, N), dtype=np.float32)
        self.done_buf = np.zeros((T, N), dtype=np.float32)
        self.next_gs_buf = np.zeros((T, self.global_dim), dtype=np.float32)
        self.step_idx = 0

    @torch.no_grad()
    def select_action(self, local_obs, global_obs):
        """Sample one action per agent from the shared policy.

        ``global_obs`` is not used here; it is accepted so call sites can pass
        both observations uniformly.

        Returns:
            ``(actions, log_probs)`` as NumPy arrays. The log-probs are summed
            over the last (action) dimension.
        """
        l = torch.from_numpy(local_obs).float().to(device)
        mean, std = self.actor(l)
        dist = Normal(mean, std)
        a = dist.sample()
        return a.cpu().numpy(), dist.log_prob(a).sum(-1).cpu().numpy()

    def store(self, local_obs, global_obs, action, logp, reward, done, next_global_obs):
        """Append one transition at the current step index.

        Returns:
            True if stored; False if the buffer already holds ``episode_len``
            steps, in which case the transition is dropped. Call ``update()``
            to flush the buffer.
        """
        if self.step_idx >= self.episode_len:
            return False
        i = self.step_idx
        self.ls_buf[i] = local_obs
        self.gs_buf[i] = global_obs
        self.ac_buf[i] = action
        self.lp_buf[i] = logp
        self.rw_buf[i] = reward
        self.done_buf[i] = done
        self.next_gs_buf[i] = next_global_obs
        self.step_idx += 1
        return True

    def compute_gae(self, T, vals):
        """Compute GAE(lambda) advantages and value targets for the first T steps.

        The shared critic value V(s_t) is broadcast to every agent. Each
        agent's own done flag masks both its bootstrap term and the
        propagation of later advantages. If not every agent is done at step
        T-1, the last step bootstraps from the critic's value of
        ``next_gs_buf[T-1]``.

        Args:
            T: Number of stored steps to use.
            vals: Critic values for ``gs_buf[:T]``; expected shape (T,).

        Returns:
            ``(adv_flat, ret_flat)``: float32 tensors of length ``T * n_agents``
            on ``device``, flattened step-major (step, agent). ``ret_flat`` is
            ``adv + V`` and is computed before any advantage normalisation.
        """
        n = self.n_agents
        values = vals.detach().float().cpu().numpy().astype(np.float32)
        values = np.repeat(values[:, None], n, axis=1)  # (T, N)

        next_values = np.zeros_like(values)
        next_values[:-1] = values[1:]
        if not self.done_buf[T - 1].all():
            with torch.no_grad():
                v_last = self.critic(
                    torch.from_numpy(self.next_gs_buf[T - 1]).float().to(device)
                ).item()
            next_values[T - 1, :] = v_last

        masks = 1.0 - self.done_buf[:T].astype(np.float32)
        rewards = self.rw_buf[:T].astype(np.float32)
        deltas = rewards + self.gamma * next_values * masks - values

        # Backward recursion: A_t = delta_t + gamma * lam * mask_t * A_{t+1}.
        adv = np.zeros_like(deltas)
        running = np.zeros(n, dtype=np.float32)
        for t in reversed(range(T)):
            running = deltas[t] + self.gamma * self.lam * masks[t] * running
            adv[t] = running
        ret = adv + values

        adv_flat = torch.from_numpy(adv.reshape(-1)).to(device)
        ret_flat = torch.from_numpy(ret.reshape(-1)).to(device)
        return adv_flat, ret_flat

    def update(self):
        """Run ``k_epochs`` of clipped-PPO minibatch updates on the stored rollout.

        Each epoch is one shuffled pass over the T*N samples. The actor and
        critic are stepped separately on each minibatch. The buffer write
        index is reset afterwards; nothing happens if the buffer is empty.
        """
        T = self.step_idx
        if T == 0:
            return
        n = self.n_agents
        total = T * n

        gs_tensor = torch.from_numpy(self.gs_buf[:T]).float().to(device)
        ls_tensor = torch.from_numpy(self.ls_buf[:T]).float().to(device).reshape(total, -1)
        ac_tensor = torch.from_numpy(self.ac_buf[:T]).float().to(device).reshape(total, -1)
        lp_tensor = torch.from_numpy(self.lp_buf[:T]).float().to(device).reshape(-1)

        with torch.no_grad():
            vals = self.critic(gs_tensor)
        adv_flat, ret_flat = self.compute_gae(T, vals)
        adv_flat = adv_flat - adv_flat.mean()
        if adv_flat.numel() > 1:  # std of a single sample is NaN
            adv_flat = adv_flat / (adv_flat.std() + 1e-8)

        # Every agent at step t shares the global state of step t.
        gs_for_batch = gs_tensor.unsqueeze(1).expand(-1, n, -1).reshape(total, self.global_dim)

        for _ in range(self.k_epochs):
            perm = torch.randperm(total, generator=self._shuffle_gen).to(device)
            for start in range(0, total, self.batch_size):
                idx = perm[start:start + self.batch_size]
                b_ls, b_gs, b_ac = ls_tensor[idx], gs_for_batch[idx], ac_tensor[idx]
                b_lp, b_adv, b_ret = lp_tensor[idx], adv_flat[idx], ret_flat[idx]

                mean, std = self.actor(b_ls)
                dist = Normal(mean, std)
                entropy = dist.entropy().mean()
                lp_new = dist.log_prob(b_ac).sum(-1)
                ratio = torch.exp(lp_new - b_lp)
                surr1 = ratio * b_adv
                surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * b_adv
                actor_loss = -torch.min(surr1, surr2).mean() - self.ent_coef * entropy
                self.opt_a.zero_grad()
                actor_loss.backward()
                self.opt_a.step()

                val_pred = self.critic(b_gs)
                critic_loss = nn.functional.mse_loss(val_pred, b_ret)
                self.opt_c.zero_grad()
                critic_loss.backward()
                self.opt_c.step()

        self.step_idx = 0

    def save(self, path):
        """Save actor and critic state dicts to ``path``."""
        torch.save({'actor': self.actor.state_dict(), 'critic': self.critic.state_dict()}, path)

    def load(self, path):
        """Load actor and critic state dicts saved by ``save``.

        NOTE: ``torch.load`` may unpickle arbitrary objects; only load trusted files.
        """
        data = torch.load(path, map_location=device)
        self.actor.load_state_dict(data['actor'])
        self.critic.load_state_dict(data['critic'])