| | |
| | import torch |
| | import torch.nn as nn |
| | import random |
| | import numpy as np |
| | from torch.distributions import Normal |
| | from torch.amp import autocast |
| | from torch.cuda.amp import GradScaler |
| |
|
| | |
# Choose the compute device once at import time: CUDA when PyTorch reports it
# as available, otherwise the CPU. The choice is announced on stdout.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using CUDA (NVIDIA GPU)" if device.type == "cuda" else "Using CPU")
| |
|
def set_global_seed(seed: int, *, deterministic: bool = False) -> None:
    """Seed every random number generator used by this module.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU
    generator. When CUDA is available it also seeds every CUDA device and
    configures cuDNN.

    Args:
        seed: Value passed to each generator.
        deterministic: Controls the cuDNN flags (only touched when CUDA is
            available). ``False`` (the default, and the historical behaviour
            of this function) sets ``cudnn.deterministic = False`` and
            ``cudnn.benchmark = True``, which favours speed but lets cuDNN
            pick algorithms that may differ between runs. ``True`` sets
            ``cudnn.deterministic = True`` and ``cudnn.benchmark = False`` so
            GPU runs with the same seed are repeatable.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Autotuning (benchmark) and deterministic kernels are mutually
        # exclusive goals, so the two flags are always set as opposites.
        torch.backends.cudnn.deterministic = deterministic
        torch.backends.cudnn.benchmark = not deterministic
| |
|
# Module-wide seed. It is applied once at import time and reused further down
# in this file wherever a seeded generator is needed.
SEED = 42
set_global_seed(seed=SEED)
| |
|
class MLP(nn.Module):
    """Fully connected network: ``Linear -> ReLU`` per hidden width, then a linear head.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dims: Iterable of hidden layer widths (may be empty, in which
            case the network is a single linear layer).
        output_dim: Size of the last dimension of the output.
    """

    def __init__(self, input_dim, hidden_dims, output_dim):
        super().__init__()
        widths = [input_dim, *hidden_dims]
        modules = []
        # Consecutive width pairs give each hidden layer's (in, out) features.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(nn.ReLU())
        # The output layer has no activation after it.
        modules.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the stacked layers to ``x`` and return the result."""
        return self.net(x)
| |
|
class Actor(nn.Module):
    """Policy network producing the parameters of a diagonal Gaussian.

    The mean comes from an MLP over the observation. The standard deviation is
    ``exp(log_std)``, where ``log_std`` is a learned parameter of length
    ``act_dim`` that does not depend on the input (initialised to zero, i.e.
    a standard deviation of one).
    """

    def __init__(self, obs_dim, act_dim, hidden=(64,64)):
        super().__init__()
        self.net = MLP(obs_dim, hidden, act_dim)
        self.log_std = nn.Parameter(torch.zeros(act_dim))

    def forward(self, x):
        """Return ``(mean, std)``; ``std`` has shape ``(act_dim,)`` and is not expanded to the batch."""
        return self.net(x), self.log_std.exp()
| |
|
class Critic(nn.Module):
    """Value network: an MLP with a single output unit per input row."""

    def __init__(self, state_dim, hidden=(128,128)):
        super().__init__()
        self.net = MLP(state_dim, hidden, 1)

    def forward(self, x):
        """Return the value estimate for ``x`` with the trailing size-1 dimension removed."""
        value = self.net(x)
        return value.squeeze(-1)
| |
|
class MeanField:
    """Clipped-surrogate (PPO-style) learner with one actor shared by all agents.

    The actor sees each agent's local observation concatenated with the global
    observation; the critic sees the global observation only. Transitions are
    written into fixed-size numpy buffers of ``episode_len`` steps and consumed
    by :meth:`update`.

    Args:
        n_agents: Number of agents acting at every step.
        local_dim: Size of one agent's local observation.
        global_dim: Size of the global observation.
        act_dim: Size of one agent's action vector.
        lr: Adam learning rate used for both actor and critic.
        gamma: Discount factor.
        lam: GAE lambda.
        clip_eps: Clipping range of the probability ratio.
        k_epochs: Number of passes over the buffer per :meth:`update`.
        batch_size: Minibatch size (rows are agent-steps).
        episode_len: Buffer capacity in environment steps.
        ent_coef: Weight of the entropy bonus in the actor loss.
        max_grad_norm: Gradient-norm clipping threshold for both networks.
    """

    def __init__(
        self,
        n_agents,
        local_dim,
        global_dim,
        act_dim,
        lr=3e-4,
        gamma=0.99,
        lam=0.95,
        clip_eps=0.2,
        k_epochs=10,
        batch_size=1024,
        episode_len=96,
        ent_coef=0.01,
        max_grad_norm=0.5,
    ):
        self.n_agents = n_agents
        self.local_dim = local_dim
        self.global_dim = global_dim
        self.act_dim = act_dim
        self.gamma = gamma
        self.lam = lam
        self.clip_eps = clip_eps
        self.k_epochs = k_epochs
        self.batch_size = batch_size
        self.episode_len = episode_len
        self.ent_coef = ent_coef
        self.max_grad_norm = max_grad_norm

        # Capture the module-level device once so every method agrees on it.
        self.device = device

        self.actor = Actor(local_dim + global_dim, act_dim).to(self.device)
        self.critic = Critic(global_dim).to(self.device)

        self.opt_a = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.opt_c = torch.optim.Adam(self.critic.parameters(), lr=lr)

        # One persistent seeded generator for minibatch shuffling. Re-creating
        # it inside update() would replay the same permutations every update.
        self._shuffle_gen = torch.Generator()
        self._shuffle_gen.manual_seed(SEED)

        print("MeanField CUDA AMP is disabled for stability.")

        self.init_buffer()

    def init_buffer(self):
        """Allocate zeroed float32 rollout buffers and reset the write cursor."""
        steps, agents = self.episode_len, self.n_agents
        self.ls_buf = np.zeros((steps, agents, self.local_dim), dtype=np.float32)
        self.gs_buf = np.zeros((steps, self.global_dim), dtype=np.float32)
        self.ac_buf = np.zeros((steps, agents, self.act_dim), dtype=np.float32)
        self.lp_buf = np.zeros((steps, agents), dtype=np.float32)
        self.rw_buf = np.zeros((steps, agents), dtype=np.float32)
        self.done_buf = np.zeros((steps, agents), dtype=np.float32)
        self.next_gs_buf = np.zeros((steps, self.global_dim), dtype=np.float32)
        self.step_idx = 0

    @torch.no_grad()
    def select_action(self, local_obs, global_obs):
        """Sample one action per agent from the current policy.

        Args:
            local_obs: Array-like of per-agent local observations; it is
                concatenated row-wise with ``n_agents`` copies of
                ``global_obs``, so it must have ``n_agents`` rows.
            global_obs: 1-D array-like global observation.

        Returns:
            ``(actions, log_probs)`` as numpy arrays; ``log_probs`` is summed
            over the action dimension, giving one value per agent.
        """
        local_t = torch.as_tensor(
            np.asarray(local_obs), dtype=torch.float32, device=self.device
        )
        global_t = torch.as_tensor(
            np.asarray(global_obs), dtype=torch.float32, device=self.device
        )
        # Every agent receives the same global observation.
        global_t = global_t.unsqueeze(0).expand(self.n_agents, -1)
        mean, std = self.actor(torch.cat([local_t, global_t], dim=-1))
        dist = Normal(mean, std)
        action = dist.sample()
        logp = dist.log_prob(action).sum(-1)
        return action.cpu().numpy(), logp.cpu().numpy()

    def store(self, local_obs, global_obs, action, logp, reward, done, next_global_obs):
        """Append one environment step to the rollout buffers.

        Once ``episode_len`` steps have been stored, further calls are ignored
        until :meth:`update` (or :meth:`init_buffer`) resets the cursor.
        """
        i = self.step_idx
        if i >= self.episode_len:
            return
        self.ls_buf[i] = local_obs
        self.gs_buf[i] = global_obs
        self.ac_buf[i] = action
        self.lp_buf[i] = logp
        self.rw_buf[i] = reward
        self.done_buf[i] = done
        self.next_gs_buf[i] = next_global_obs
        self.step_idx = i + 1

    def compute_gae(self, T, vals):
        """Compute Generalized Advantage Estimation (GAE) over the first ``T`` steps.

        Args:
            T: Number of stored steps to use.
            vals: Tensor of ``T`` critic values, one per stored global state;
                the same value is used for every agent at a step.

        Returns:
            ``(advantages, returns)`` as flat float tensors of length
            ``T * n_agents`` on ``self.device``, ordered step-major.
        """
        n_agents = self.n_agents
        rewards = self.rw_buf[:T]
        masks = 1.0 - self.done_buf[:T]
        adv_buf = np.zeros_like(rewards)

        # Bootstrap from the critic unless every agent finished on the last step.
        if self.done_buf[T - 1].all():
            v_last = 0.0
        else:
            with torch.no_grad():
                last_obs = torch.from_numpy(self.next_gs_buf[T - 1]).float().to(self.device)
                v_last = self.critic(last_obs).cpu().numpy()

        vals_agent = vals.detach().unsqueeze(1).expand(-1, n_agents).cpu().numpy()

        gae = 0
        for t in reversed(range(T)):
            v_next = vals_agent[t + 1] if t < T - 1 else v_last
            # masks[t] zeroes both the bootstrap value and the running GAE
            # for agents whose episode ended at step t.
            delta = rewards[t] + self.gamma * v_next * masks[t] - vals_agent[t]
            adv_buf[t] = gae = delta + self.gamma * self.lam * masks[t] * gae

        ret_buf = adv_buf + vals_agent
        adv_flat = torch.from_numpy(adv_buf.flatten()).float().to(self.device)
        ret_flat = torch.from_numpy(ret_buf.flatten()).float().to(self.device)
        return adv_flat, ret_flat

    @staticmethod
    def _normalize_advantages(adv):
        """Standardise advantages to zero mean / unit std.

        With fewer than two samples the sample standard deviation is NaN, so
        the input is returned unchanged in that case.
        """
        if adv.numel() < 2:
            return adv
        return (adv - adv.mean()) / (adv.std() + 1e-8)

    def update(self):
        """Run ``k_epochs`` of clipped-surrogate updates on the stored rollout.

        Does nothing when the buffer is empty. Afterwards the write cursor is
        reset so the buffers are overwritten by the next rollout.
        """
        T = self.step_idx
        if T == 0:
            return

        dev = self.device
        n_rows = T * self.n_agents

        gs_tensor = torch.from_numpy(self.gs_buf[:T]).float().to(dev)
        ls_rows = torch.from_numpy(self.ls_buf[:T]).float().to(dev).view(n_rows, self.local_dim)
        ac_rows = torch.from_numpy(self.ac_buf[:T]).float().to(dev).view(n_rows, self.act_dim)
        old_logp = torch.from_numpy(self.lp_buf[:T]).float().to(dev).view(-1)

        with torch.no_grad():
            vals = self.critic(gs_tensor)

        adv_rows, ret_rows = self.compute_gae(T, vals)
        adv_rows = self._normalize_advantages(adv_rows)

        # Repeat each step's global state once per agent so rows line up with
        # the flattened (step, agent) ordering of the other tensors.
        gs_rows = (
            gs_tensor.unsqueeze(1)
            .expand(-1, self.n_agents, -1)
            .reshape(n_rows, self.global_dim)
        )

        low, high = 1 - self.clip_eps, 1 + self.clip_eps
        for _ in range(self.k_epochs):
            # Fresh permutation per epoch; index whole minibatches at once
            # instead of collating sample by sample.
            perm = torch.randperm(n_rows, generator=self._shuffle_gen).to(dev)
            for start in range(0, n_rows, self.batch_size):
                idx = perm[start:start + self.batch_size]
                b_ls, b_gs, b_ac = ls_rows[idx], gs_rows[idx], ac_rows[idx]
                b_lp, b_adv, b_ret = old_logp[idx], adv_rows[idx], ret_rows[idx]

                # ---- actor step ----
                mean, std = self.actor(torch.cat([b_ls, b_gs], dim=-1))
                dist = Normal(mean, std)
                entropy = dist.entropy().mean()
                lp_new = dist.log_prob(b_ac).sum(-1)
                ratio = torch.exp(lp_new - b_lp)
                surr1 = ratio * b_adv
                surr2 = torch.clamp(ratio, low, high) * b_adv
                actor_loss = -torch.min(surr1, surr2).mean() - self.ent_coef * entropy

                self.opt_a.zero_grad()
                actor_loss.backward()
                nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=self.max_grad_norm)
                self.opt_a.step()

                # ---- critic step ----
                val_pred = self.critic(b_gs)
                critic_loss = nn.functional.mse_loss(val_pred, b_ret)

                self.opt_c.zero_grad()
                critic_loss.backward()
                nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=self.max_grad_norm)
                self.opt_c.step()

        self.step_idx = 0

    def save(self, path):
        """Write the actor and critic state_dicts to ``path`` (optimizer state is not saved)."""
        torch.save(
            {'actor': self.actor.state_dict(), 'critic': self.critic.state_dict()},
            path,
        )

    def load(self, path):
        """Load actor and critic weights from a checkpoint written by :meth:`save`."""
        # The checkpoint holds only tensors, so restrict unpickling to
        # weights: torch.load otherwise executes arbitrary pickled code.
        data = torch.load(path, map_location=self.device, weights_only=True)
        self.actor.load_state_dict(data['actor'])
        self.critic.load_state_dict(data['critic'])