# Source: SolarSys2025 — "Upload 30 files" (commit 55da406, verified)
import torch
import torch.nn as nn
import random
import numpy as np
from torch.distributions import Normal
# Select the compute device once at import time; every module and tensor
# in this file is moved to `device`.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_ok else "cpu")
print("Using CUDA (NVIDIA GPU)" if _cuda_ok else "Using CPU")
def set_global_seed(seed: int) -> None:
    """Seed every RNG used here (python, numpy, torch CPU, all CUDA devices).

    Also configures cuDNN for deterministic kernel selection.

    BUGFIX: previously this set ``cudnn.deterministic = False`` and
    ``cudnn.benchmark = True``, which lets cuDNN autotune nondeterministic
    kernels and defeats the purpose of seeding; reproducible runs require
    deterministic=True / benchmark=False.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

SEED = 42
set_global_seed(SEED)
class MLP(nn.Module):
    """Plain fully connected network: Linear+ReLU for each hidden width,
    then a final linear projection to ``output_dim``."""

    def __init__(self, input_dim, hidden_dims, output_dim):
        super().__init__()
        widths = [input_dim, *hidden_dims]
        modules = []
        # Layers are created in the same order as before, so seeded
        # weight initialization consumes the RNG identically.
        for in_w, out_w in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(in_w, out_w))
            modules.append(nn.ReLU())
        modules.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Run the input batch through the layer stack."""
        return self.net(x)
class Actor(nn.Module):
    """Diagonal-Gaussian policy: an MLP maps the observation to the action
    mean, while log-std is a free, state-independent learned parameter."""

    def __init__(self, obs_dim, act_dim, hidden=(64,64)):
        super().__init__()
        self.net = MLP(obs_dim, hidden, act_dim)
        # One log-std per action dimension, shared across all states.
        self.log_std = nn.Parameter(torch.zeros(act_dim))

    def forward(self, x):
        """Return (mean, std) of the action distribution for batch ``x``."""
        return self.net(x), self.log_std.exp()
class Critic(nn.Module):
    """Centralized state-value function V(s) over the global state."""

    def __init__(self, state_dim, hidden=(128,128)):
        super().__init__()
        self.net = MLP(state_dim, hidden, 1)

    def forward(self, x):
        """Return V(x) as a 1-D batch (the trailing singleton dim is dropped)."""
        values = self.net(x)
        return values.squeeze(-1)
class MAPPO:
    """Multi-Agent PPO: a shared actor acting on per-agent local observations
    and a centralized critic evaluating the global state.

    Rollouts of up to ``episode_len`` steps are accumulated in numpy buffers
    via :meth:`store`, then consumed by :meth:`update`, which runs
    ``k_epochs`` of clipped-PPO minibatch optimization and resets the buffer.
    """

    def __init__(
        self,
        n_agents,
        local_dim,
        global_dim,
        act_dim,
        lr=3e-4,
        gamma=0.99,
        lam=0.95,
        clip_eps=0.2,
        k_epochs=10,
        batch_size=1024,
        episode_len=96
    ):
        self.n_agents = n_agents
        self.local_dim = local_dim
        self.global_dim = global_dim
        self.act_dim = act_dim
        self.gamma = gamma
        self.lam = lam
        self.clip_eps = clip_eps
        self.k_epochs = k_epochs
        self.batch_size = batch_size
        self.episode_len = episode_len
        # Shared policy over local obs; centralized critic over global state.
        self.actor = Actor(local_dim, act_dim).to(device)
        self.critic = Critic(global_dim).to(device)
        self.opt_a = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.opt_c = torch.optim.Adam(self.critic.parameters(), lr=lr)
        print("MAPPO CUDA AMP is disabled for stability.")
        self.init_buffer()

    def init_buffer(self):
        """(Re)allocate rollout buffers and reset the write cursor.

        BUGFIX: buffers were float16, which truncates stored log-probs and
        rewards to ~3 significant digits and biases the PPO importance ratio
        ``exp(new_logp - old_logp)``; float32 keeps the ratio exact.
        """
        T, N = self.episode_len, self.n_agents
        self.ls_buf = np.zeros((T, N, self.local_dim), dtype=np.float32)
        self.gs_buf = np.zeros((T, self.global_dim), dtype=np.float32)
        self.ac_buf = np.zeros((T, N, self.act_dim), dtype=np.float32)
        self.lp_buf = np.zeros((T, N), dtype=np.float32)
        self.rw_buf = np.zeros((T, N), dtype=np.float32)
        self.done_buf = np.zeros((T, N), dtype=np.float32)
        self.next_gs_buf = np.zeros((T, self.global_dim), dtype=np.float32)
        self.step_idx = 0

    @torch.no_grad()
    def select_action(self, local_obs, global_obs):
        """Sample actions for all agents from the current Gaussian policy.

        Returns ``(actions, log_probs)`` as numpy arrays; ``global_obs`` is
        accepted for interface symmetry but is not used when acting.
        """
        l = torch.from_numpy(local_obs).float().to(device)
        mean, std = self.actor(l)
        dist = Normal(mean, std)
        a = dist.sample()
        return a.cpu().numpy(), dist.log_prob(a).sum(-1).cpu().numpy()

    def store(self, local_obs, global_obs, action, logp, reward, done, next_global_obs):
        """Append one transition; data past ``episode_len`` is silently dropped."""
        if self.step_idx < self.episode_len:
            self.ls_buf[self.step_idx] = local_obs
            self.gs_buf[self.step_idx] = global_obs
            self.ac_buf[self.step_idx] = action
            self.lp_buf[self.step_idx] = logp
            self.rw_buf[self.step_idx] = reward
            self.done_buf[self.step_idx] = done
            self.next_gs_buf[self.step_idx] = next_global_obs
            self.step_idx += 1

    def compute_gae(self, T, vals):
        """Compute GAE(gamma, lam) advantages and returns for the first T steps.

        BUGFIX: this previously returned the one-step TD error
        ``r + gamma*V(s') - V(s)`` and never used ``self.lam``; it now runs
        the standard backward GAE recursion
        ``A_t = delta_t + gamma*lam*mask_t*A_{t+1}``.

        Returns flattened (T*N,) tensors ``(adv_flat, ret_flat)`` on `device`.
        """
        N = self.n_agents
        # Broadcast the per-step centralized value over agents -> (T, N).
        vals_agent = np.repeat(
            vals.cpu().numpy().astype(np.float32)[:, None], N, axis=1
        )
        next_vals_agent = np.zeros_like(vals_agent)
        next_vals_agent[:-1] = vals_agent[1:]
        # Bootstrap the final step from the critic unless the episode ended there.
        if not self.done_buf[T - 1].all():
            with torch.no_grad():
                v_last = self.critic(
                    torch.from_numpy(self.next_gs_buf[T - 1]).float().to(device)
                ).cpu().item()
            next_vals_agent[T - 1, :] = v_last
        masks = 1.0 - self.done_buf[:T]
        rewards = self.rw_buf[:T]
        deltas = rewards + self.gamma * next_vals_agent * masks - vals_agent
        adv = np.zeros_like(deltas)
        gae = np.zeros(N, dtype=np.float32)
        for t in range(T - 1, -1, -1):
            gae = deltas[t] + self.gamma * self.lam * masks[t] * gae
            adv[t] = gae
        ret = adv + vals_agent
        adv_flat = torch.from_numpy(adv.flatten()).to(device)
        ret_flat = torch.from_numpy(ret.flatten()).to(device)
        return adv_flat, ret_flat

    def update(self):
        """Run ``k_epochs`` of clipped-PPO minibatch updates on the stored
        rollout, then reset the buffer cursor. No-op on an empty buffer."""
        T = self.step_idx
        if T == 0:
            return
        gs_tensor = torch.from_numpy(self.gs_buf[:T]).float().to(device)
        ls_tensor = torch.from_numpy(self.ls_buf[:T]).float().to(device).view(T * self.n_agents, -1)
        ac_tensor = torch.from_numpy(self.ac_buf[:T]).float().to(device).view(T * self.n_agents, -1)
        lp_tensor = torch.from_numpy(self.lp_buf[:T]).float().to(device).view(-1)
        with torch.no_grad():
            vals = self.critic(gs_tensor)
        adv_flat, ret_flat = self.compute_gae(T, vals)
        # Normalize advantages for a stable gradient scale.
        adv_flat = (adv_flat - adv_flat.mean()) / (adv_flat.std() + 1e-8)
        # Every agent shares the same centralized critic input at each step.
        gs_for_batch = gs_tensor.unsqueeze(1).expand(-1, self.n_agents, -1).reshape(T * self.n_agents, self.global_dim)
        dataset = torch.utils.data.TensorDataset(ls_tensor, gs_for_batch, ac_tensor, lp_tensor, adv_flat, ret_flat)
        # NOTE: reseeding with the global SEED yields the same shuffle order
        # on every update call — kept to preserve the original reproducible runs.
        gen = torch.Generator()
        gen.manual_seed(SEED)
        loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True, generator=gen)
        for _ in range(self.k_epochs):
            for b_ls, b_gs, b_ac, b_lp, b_adv, b_ret in loader:
                mean, std = self.actor(b_ls)
                dist = Normal(mean, std)
                entropy = dist.entropy().mean()
                lp_new = dist.log_prob(b_ac).sum(-1)
                ratio = torch.exp(lp_new - b_lp)
                surr1 = ratio * b_adv
                surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * b_adv
                # Clipped surrogate objective with a small entropy bonus.
                actor_loss = -torch.min(surr1, surr2).mean() - 0.01 * entropy
                self.opt_a.zero_grad()
                actor_loss.backward()
                self.opt_a.step()
                val_pred = self.critic(b_gs)
                critic_loss = nn.MSELoss()(val_pred, b_ret)
                self.opt_c.zero_grad()
                critic_loss.backward()
                self.opt_c.step()
        self.step_idx = 0

    def save(self, path):
        """Serialize actor and critic weights to ``path``."""
        torch.save({'actor': self.actor.state_dict(),
                    'critic': self.critic.state_dict()}, path)

    def load(self, path):
        """Restore actor and critic weights written by :meth:`save`.

        NOTE(review): ``torch.load`` unpickles arbitrary objects — only load
        trusted checkpoints (or pass ``weights_only=True`` on torch >= 1.13).
        """
        data = torch.load(path, map_location=device)
        self.actor.load_state_dict(data['actor'])
        self.critic.load_state_dict(data['critic'])