|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import random |
|
|
import numpy as np |
|
|
from torch.distributions import Normal |
|
|
|
|
|
|
|
|
def set_global_seed(seed: int):
    """Seed every RNG the project uses (Python, NumPy, PyTorch, CUDA)
    and force deterministic cuDNN kernels for reproducible runs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
|
|
|
|
|
# Pick the compute device once at import time; all models/tensors below
# are moved onto it explicitly.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _use_cuda else "cpu")
print("Using CUDA (NVIDIA GPU)" if _use_cuda else "Using CPU")

# Global experiment seed, applied to every RNG for reproducibility.
SEED = 42
set_global_seed(SEED)
|
|
|
|
|
|
|
|
class MLP(nn.Module):
    """Plain fully-connected network: Linear + ReLU for each hidden
    width, followed by a final Linear with no activation."""

    def __init__(self, input_dim, hidden_dims, output_dim):
        super().__init__()
        dims = [input_dim, *hidden_dims]
        modules = []
        # One (Linear, ReLU) pair per hidden layer.
        for in_d, out_d in zip(dims[:-1], dims[1:]):
            modules.append(nn.Linear(in_d, out_d))
            modules.append(nn.ReLU())
        # Output head stays linear so callers choose their own activation.
        modules.append(nn.Linear(dims[-1], output_dim))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Map x of shape (..., input_dim) to (..., output_dim)."""
        return self.net(x)
|
|
|
|
|
class Actor(nn.Module):
    """Diagonal-Gaussian policy: an MLP mean head plus a single
    state-independent learnable log-std vector."""

    def __init__(self, obs_dim, act_dim, hidden=(64,64)):
        super().__init__()
        self.net = MLP(obs_dim, hidden, act_dim)
        # One log-std per action dimension, shared across all states.
        self.log_std = nn.Parameter(torch.zeros(act_dim))

    def forward(self, x):
        """Return (mean, std) of the action distribution for observation x."""
        return self.net(x), self.log_std.exp()
|
|
|
|
|
class Critic(nn.Module):
    """Centralized state-value function V(s) over the global state."""

    def __init__(self, state_dim, hidden=(128,128)):
        super().__init__()
        self.net = MLP(state_dim, hidden, 1)

    def forward(self, x):
        """Return V(x) with the trailing singleton dimension removed."""
        value = self.net(x)
        return value.squeeze(-1)
|
|
|
|
|
class MAPPO:
    """Multi-Agent PPO with a shared actor and a centralized critic
    (CTDE: centralized training, decentralized execution).

    All agents share one Gaussian-policy actor over their local
    observations; a single critic evaluates the shared global state.
    Rollout data is buffered via store() and consumed by update().
    """

    def __init__(
        self,
        n_agents,
        local_dim,
        global_dim,
        act_dim,
        lr=3e-4,
        gamma=0.99,
        lam=0.95,
        clip_eps=0.2,
        k_epochs=10,
        batch_size=1024
    ):
        """
        Args:
            n_agents: number of agents sharing the actor.
            local_dim: per-agent observation dimension (actor input).
            global_dim: global state dimension (critic input).
            act_dim: continuous action dimension per agent.
            lr: Adam learning rate for both networks.
            gamma: discount factor.
            lam: GAE lambda.
            clip_eps: PPO clipping epsilon.
            k_epochs: optimization epochs per update() call.
            batch_size: minibatch size during those epochs.
        """
        self.n_agents = n_agents
        self.gamma = gamma
        self.lam = lam
        self.clip_eps = clip_eps
        self.k_epochs = k_epochs
        self.batch_size = batch_size

        self.actor = Actor(local_dim, act_dim).to(device)
        self.critic = Critic(global_dim).to(device)

        self.opt_a = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.opt_c = torch.optim.Adam(self.critic.parameters(), lr=lr)

        self.local_dim = local_dim
        self.global_dim = global_dim
        self.act_dim = act_dim

        # Persistent shuffle generator, seeded ONCE.  The previous code
        # reseeded a fresh generator with the same constant inside every
        # update(), which replayed the identical minibatch permutation on
        # every update and epoch.  Seeding from torch's initial seed keeps
        # the whole run reproducible under set_global_seed().
        self._loader_gen = torch.Generator()
        self._loader_gen.manual_seed(torch.initial_seed())

        self.clear_buffer()

    def clear_buffer(self):
        """Reset the rollout buffer (one entry per environment step)."""
        self.ls = []       # local obs, [n_agents, local_dim] per step
        self.gs = []       # global state, [global_dim] per step
        self.ac = []       # actions, [n_agents, act_dim] per step
        self.lp = []       # behavior log-probs, [n_agents] per step
        self.rw = []       # rewards, [n_agents] per step
        self.done = []     # episode-termination flag per step
        self.next_gs = []  # successor global state per step

    @torch.no_grad()
    def select_action(self, local_obs, global_obs):
        """Sample actions from the current policy.

        Args:
            local_obs: per-agent observations, shape [n_agents, local_dim].
            global_obs: global state; unused here but kept for a uniform
                agent interface (the critic consumes it during update()).

        Returns:
            (actions, log_probs) as numpy arrays; log-probs are summed
            over the action dimensions.
        """
        obs = torch.as_tensor(
            np.asarray(local_obs, dtype=np.float32), device=device
        )
        mean, std = self.actor(obs)
        dist = Normal(mean, std)
        action = dist.sample()
        logp = dist.log_prob(action).sum(-1)
        return action.cpu().numpy(), logp.cpu().numpy()

    def store(self, local_obs, global_obs, action, logp, reward, done, next_global_obs):
        """Append one environment transition to the rollout buffer."""
        self.ls.append(local_obs)
        self.gs.append(global_obs)
        self.ac.append(action)
        self.lp.append(logp)
        self.rw.append(reward)
        self.done.append(done)
        self.next_gs.append(next_global_obs)

    def compute_gae(self, values):
        """Generalized Advantage Estimation over the stored rollout.

        Args:
            values: torch.Tensor of shape [T] — one centralized V(s) per
                timestep, shared by all agents at that step.

        Returns:
            (adv_flat, ret_flat): tensors of shape [T * n_agents] on
            `device`, flattened timestep-major to match how update()
            flattens the per-agent data.
        """
        vals_1d = values.cpu().numpy()
        T = len(vals_1d)
        N = self.n_agents

        # Broadcast the shared value to every agent: [T] -> [T, N].
        vals_agent = np.tile(vals_1d[:, None], (1, N))

        # V(s_{t+1}) for the TD error; the last slot is bootstrapped below.
        next_vals = np.zeros_like(vals_agent)
        next_vals[:-1] = vals_agent[1:]

        # If the rollout was truncated mid-episode, bootstrap the tail
        # value from the critic on the final successor state.
        if not self.done[-1]:
            with torch.no_grad():
                v_last = self.critic(
                    torch.as_tensor(
                        np.asarray(self.next_gs[-1], dtype=np.float32),
                        device=device,
                    )
                ).cpu().item()
            next_vals[-1, :] = v_last

        # Backward GAE recursion; `mask` cuts credit flow at episode
        # boundaries inside the buffer.
        adv = np.zeros_like(vals_agent, dtype=np.float32)
        prev_adv = np.zeros(N, dtype=np.float32)
        for t in reversed(range(T)):
            mask = 1.0 - float(self.done[t])
            rew_t = np.array(self.rw[t], dtype=np.float32)
            delta = rew_t + self.gamma * next_vals[t] * mask - vals_agent[t]
            prev_adv = delta + self.gamma * self.lam * mask * prev_adv
            adv[t] = prev_adv

        ret = adv + vals_agent
        adv_flat = torch.from_numpy(adv.flatten()).to(device)
        ret_flat = torch.from_numpy(ret.flatten()).to(device)
        return adv_flat, ret_flat

    def update(self):
        """Run one PPO update over the collected rollout, then clear it.

        Evaluates V(s) for every stored global state, computes GAE
        advantages/returns, flattens the per-agent data, and runs
        k_epochs of clipped-surrogate actor updates and MSE critic
        updates over shuffled minibatches.
        """
        # Empty buffer -> nothing to do (guards against a stray call).
        if not self.rw:
            return

        # Stack python lists into contiguous float32 arrays first:
        # torch.FloatTensor(list_of_ndarrays) is slow and deprecated.
        raw_gs = torch.from_numpy(
            np.asarray(self.gs, dtype=np.float32)
        ).to(device)

        # Centralized values, one per timestep (no grad needed here).
        with torch.no_grad():
            vals = self.critic(raw_gs).cpu()

        adv_flat, ret_flat = self.compute_gae(vals)

        # Flatten [T, n_agents, ...] rollout data to [T * n_agents, ...],
        # timestep-major — same order compute_gae() flattens with.
        ls = torch.from_numpy(
            np.asarray(self.ls, dtype=np.float32)
        ).view(-1, self.local_dim).to(device)
        ac = torch.from_numpy(
            np.asarray(self.ac, dtype=np.float32)
        ).view(-1, self.act_dim).to(device)
        old_lp = torch.from_numpy(
            np.asarray(self.lp, dtype=np.float32)
        ).view(-1).to(device)

        # Every agent at timestep t shares the same global state.
        gs = raw_gs.unsqueeze(1).expand(-1, self.n_agents, -1)
        gs = gs.reshape(-1, self.global_dim)

        dataset = torch.utils.data.TensorDataset(
            ls, gs, ac, old_lp, adv_flat, ret_flat
        )
        # NOTE: self._loader_gen is seeded once in __init__ so each
        # update() sees a different (but reproducible) shuffle.
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=0,
            generator=self._loader_gen
        )

        for _ in range(self.k_epochs):
            for b_ls, b_gs, b_ac, b_lp, b_adv, b_ret in loader:
                # --- actor: clipped PPO surrogate objective ---
                mean, std = self.actor(b_ls)
                dist = Normal(mean, std)
                lp_new = dist.log_prob(b_ac).sum(-1)
                ratio = torch.exp(lp_new - b_lp)
                surr1 = ratio * b_adv
                surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * b_adv
                actor_loss = -torch.min(surr1, surr2).mean()

                self.opt_a.zero_grad()
                actor_loss.backward()
                self.opt_a.step()

                # --- critic: regress V(s) onto the GAE returns ---
                val_pred = self.critic(b_gs)
                critic_loss = nn.functional.mse_loss(val_pred, b_ret)

                self.opt_c.zero_grad()
                critic_loss.backward()
                self.opt_c.step()

        self.clear_buffer()

    def save(self, path):
        """Save actor and critic weights to `path`."""
        torch.save({'actor': self.actor.state_dict(),
                    'critic': self.critic.state_dict()}, path)

    def load(self, path):
        """Load actor and critic weights from `path` onto `device`."""
        data = torch.load(path, map_location=device)
        self.actor.load_state_dict(data['actor'])
        self.critic.load_state_dict(data['critic'])
|
|
|