# (repository upload header, preserved as comments so the file parses)
# SolarSys2025's picture
# Upload 30 files
# 55da406 verified
# mappo.py
import torch
import torch.nn as nn
import random
import numpy as np
from torch.distributions import Normal
def set_global_seed(seed: int):
    """Seed every RNG source (Python, NumPy, PyTorch) for reproducible runs."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Force deterministic CuDNN kernels; repeatability at a small speed cost.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
# Universal device selection: prefer CUDA, fall back to CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print("Using CUDA (NVIDIA GPU)" if use_cuda else "Using CPU")
# MPS (Apple Silicon) support left disabled:
# elif torch.backends.mps.is_available():
#     device = torch.device("mps")
#     print("Using MPS (Apple Silicon GPU)")

# Fix every RNG source up front so the whole run is reproducible.
SEED = 42
set_global_seed(SEED)
class MLP(nn.Module):
    """Plain feed-forward network: Linear+ReLU for each hidden width, then a
    final Linear projection to ``output_dim`` (no activation on the output)."""

    def __init__(self, input_dim, hidden_dims, output_dim):
        super().__init__()
        widths = [input_dim, *hidden_dims]
        modules = []
        # Pair consecutive widths to build the hidden stack.
        for in_w, out_w in zip(widths[:-1], widths[1:]):
            modules.extend((nn.Linear(in_w, out_w), nn.ReLU()))
        modules.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        return self.net(x)
class Actor(nn.Module):
    """Diagonal-Gaussian policy: an MLP produces the mean; a single learnable,
    state-independent log-std parameter (one entry per action dim) sets the scale."""

    def __init__(self, obs_dim, act_dim, hidden=(64, 64)):
        super().__init__()
        self.net = MLP(obs_dim, hidden, act_dim)
        # Shared across all states; exp() in forward keeps std strictly positive.
        self.log_std = nn.Parameter(torch.zeros(act_dim))

    def forward(self, x):
        return self.net(x), self.log_std.exp()
class Critic(nn.Module):
    """Centralized state-value head: maps a global state to a scalar V(s)."""

    def __init__(self, state_dim, hidden=(128, 128)):
        super().__init__()
        self.net = MLP(state_dim, hidden, 1)

    def forward(self, x):
        # Drop the trailing singleton dim so callers see shape [...] not [..., 1].
        values = self.net(x)
        return values.squeeze(-1)
class MAPPO:
    """Multi-Agent PPO (centralized training, decentralized execution).

    A single actor, shared by all agents, maps each agent's local observation
    to a diagonal-Gaussian action distribution. A single centralized critic
    maps the global state to one value V(s) per timestep; that value is
    broadcast to every agent when computing GAE advantages.

    Fix vs. the original: tensors were built with ``torch.FloatTensor`` on
    Python lists of numpy arrays, which takes an extremely slow per-element
    path (and emits a UserWarning on recent PyTorch). Lists are now stacked
    with ``np.asarray`` first and converted via ``torch.as_tensor`` — the
    resulting values are identical.
    """

    def __init__(
        self,
        n_agents,
        local_dim,
        global_dim,
        act_dim,
        lr=3e-4,
        gamma=0.99,
        lam=0.95,
        clip_eps=0.2,
        k_epochs=10,
        batch_size=1024
    ):
        self.n_agents = n_agents
        self.gamma = gamma          # discount factor
        self.lam = lam              # GAE lambda
        self.clip_eps = clip_eps    # PPO clipping range epsilon
        self.k_epochs = k_epochs    # optimization epochs per rollout
        self.batch_size = batch_size
        self.actor = Actor(local_dim, act_dim).to(device)
        self.critic = Critic(global_dim).to(device)
        self.opt_a = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.opt_c = torch.optim.Adam(self.critic.parameters(), lr=lr)
        self.local_dim = local_dim
        self.global_dim = global_dim
        self.act_dim = act_dim
        self.clear_buffer()

    def clear_buffer(self):
        """Reset the on-policy rollout buffer (one entry per environment step).

        Assumed per-step shapes (inferred from usage below — confirm against
        the environment loop): ls [N, local_dim], gs [global_dim],
        ac [N, act_dim], lp [N], rw [N], done scalar, next_gs [global_dim].
        """
        self.ls = []       # local observations
        self.gs = []       # global observations
        self.ac = []       # actions
        self.lp = []       # log-probs
        self.rw = []       # rewards
        self.done = []     # done flags
        self.next_gs = []  # next global observations

    @torch.no_grad()
    def select_action(self, local_obs, global_obs):
        """Sample actions for all agents from the current policy.

        ``global_obs`` is accepted for interface symmetry but unused here —
        the critic is only consulted during update/GAE.

        Returns:
            (actions, log_probs) as numpy arrays; log-probs are summed over
            action dimensions, giving one scalar per agent.
        """
        l = torch.as_tensor(np.asarray(local_obs), dtype=torch.float32).to(device)
        mean, std = self.actor(l)
        dist = Normal(mean, std)
        a = dist.sample()
        return a.cpu().numpy(), dist.log_prob(a).sum(-1).cpu().numpy()

    def store(self, local_obs, global_obs, action, logp, reward, done, next_global_obs):
        """Append one environment transition to the rollout buffer."""
        self.ls.append(local_obs)
        self.gs.append(global_obs)
        self.ac.append(action)
        self.lp.append(logp)
        self.rw.append(reward)
        self.done.append(done)
        self.next_gs.append(next_global_obs)

    def compute_gae(self, values):
        """Generalized Advantage Estimation over the stored rollout.

        values: torch.Tensor shape [T] (one central V(s) per timestep)
        returns:
            adv_flat: torch.Tensor shape [T * n_agents]
            ret_flat: torch.Tensor shape [T * n_agents]

        The single central value is broadcast to every agent; the done mask is
        a scalar per step, i.e. all agents are assumed to terminate together
        (NOTE(review): confirm the env never emits per-agent dones).
        """
        # 1) get raw arrays
        vals_1d = values.cpu().numpy()  # [T]
        T = len(vals_1d)
        N = self.n_agents
        # 2) broadcast to per-agent: vals_agent[t, i] = V(state_t)
        vals_agent = np.tile(vals_1d[:, None], (1, N))  # [T, N]
        # 3) next-step values: shift by one; zeros at the final step by default
        next_vals = np.zeros_like(vals_agent)  # [T, N]
        next_vals[:-1] = vals_agent[1:]
        # If the episode did not end at the final step, bootstrap from the
        # critic's estimate of the last next-state.
        if not self.done[-1]:
            with torch.no_grad():
                v_last = self.critic(
                    torch.as_tensor(np.asarray(self.next_gs[-1]),
                                    dtype=torch.float32).to(device)
                ).cpu().item()
            next_vals[-1, :] = v_last
        # 4) backward GAE recursion over (T, N)
        adv = np.zeros_like(vals_agent, dtype=np.float32)
        prev_adv = np.zeros(N, dtype=np.float32)
        for t in reversed(range(T)):
            mask = 1.0 - float(self.done[t])  # scalar 0/1; zero cuts credit at episode ends
            rew_t = np.array(self.rw[t], dtype=np.float32)  # [N]
            delta = rew_t + self.gamma * next_vals[t] * mask - vals_agent[t]
            prev_adv = delta + self.gamma * self.lam * mask * prev_adv
            adv[t] = prev_adv
        # 5) returns = advantages + values; flatten agent-major per timestep
        ret = adv + vals_agent  # [T, N]
        adv_flat = torch.from_numpy(adv.flatten()).to(device)
        ret_flat = torch.from_numpy(ret.flatten()).to(device)
        return adv_flat, ret_flat

    def update(self):
        """Run ``k_epochs`` of clipped-PPO minibatch updates on the stored
        rollout, then clear the buffer."""
        # 1) Raw global states tensor [T, G] (np.asarray first — see class doc)
        raw_gs = torch.as_tensor(np.asarray(self.gs), dtype=torch.float32).to(device)
        # 2) One central value V(s_t) per timestep
        with torch.no_grad():
            vals = self.critic(raw_gs).cpu()  # [T]
        # 3) GAE advantages/returns, flattened to [T * N]
        adv_flat, ret_flat = self.compute_gae(vals)
        # 4) Per-agent flattened training tensors
        ls = torch.as_tensor(np.asarray(self.ls),
                             dtype=torch.float32).view(-1, self.local_dim).to(device)
        ac = torch.as_tensor(np.asarray(self.ac),
                             dtype=torch.float32).view(-1, self.act_dim).to(device)
        old_lp = torch.as_tensor(np.asarray(self.lp),
                                 dtype=torch.float32).view(-1).to(device)
        # Broadcast global states per agent: [T, G] -> [T, N, G] -> [T*N, G]
        gs = raw_gs.unsqueeze(1).expand(-1, self.n_agents, -1)
        gs = gs.reshape(-1, self.global_dim).to(device)
        dataset = torch.utils.data.TensorDataset(
            ls, gs, ac, old_lp, adv_flat, ret_flat
        )
        # Seeded generator keeps minibatch shuffling deterministic.
        gen = torch.Generator()
        gen.manual_seed(SEED)
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=0,
            generator=gen
        )
        mse = nn.MSELoss()  # hoisted out of the epoch loop
        # 5) PPO update loop
        for _ in range(self.k_epochs):
            for b_ls, b_gs, b_ac, b_lp, b_adv, b_ret in loader:
                # Actor: clipped surrogate objective
                mean, std = self.actor(b_ls)
                dist = Normal(mean, std)
                lp_new = dist.log_prob(b_ac).sum(-1)
                ratio = torch.exp(lp_new - b_lp)
                surr1 = ratio * b_adv
                surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * b_adv
                actor_loss = -torch.min(surr1, surr2).mean()
                self.opt_a.zero_grad()
                actor_loss.backward()
                self.opt_a.step()
                # Critic: MSE regression toward GAE returns
                val_pred = self.critic(b_gs)
                critic_loss = mse(val_pred, b_ret)
                self.opt_c.zero_grad()
                critic_loss.backward()
                self.opt_c.step()
        # 6) On-policy: discard the rollout after use
        self.clear_buffer()

    def save(self, path):
        """Persist actor and critic weights to ``path``."""
        torch.save({'actor': self.actor.state_dict(),
                    'critic': self.critic.state_dict()}, path)

    def load(self, path):
        """Restore actor and critic weights saved by :meth:`save`."""
        data = torch.load(path, map_location=device)
        self.actor.load_state_dict(data['actor'])
        self.critic.load_state_dict(data['critic'])