|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import numpy as np |
|
|
import random |
|
|
from torch.distributions import Normal |
|
|
from torch.amp import autocast |
|
|
from torch.cuda.amp import GradScaler |
|
|
|
|
|
|
|
|
# Choose the compute device once, at import time. The networks and tensors
# created further down in this module are moved onto this device.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA (NVIDIA GPU)")
else:
    device = torch.device("cpu")
    print("Using CPU")
|
|
|
|
|
def set_global_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value handed to every generator.

    Note:
        When CUDA is available, cuDNN is left in non-deterministic mode with
        benchmark auto-tuning switched on. Per the PyTorch reproducibility
        notes this favours speed, so GPU runs may not be bit-for-bit
        reproducible even with a fixed seed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # NOTE(review): these two flags are assumed to belong to the CUDA
        # branch (they only affect cuDNN kernels) - confirm against the
        # original layout.
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True
|
|
|
|
|
# Module-wide seed. It is applied once at import time and is also reused by
# MFAC.update() to seed the minibatch shuffling generator.
SEED = 42
set_global_seed(SEED)
|
|
|
|
|
class MLP(nn.Module):
    """Fully connected network with a ReLU after every hidden layer.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dims: Iterable of hidden-layer widths. May be empty, in which
            case the network is a single linear layer.
        output_dim: Size of the last dimension of the output.
    """

    def __init__(self, input_dim, hidden_dims, output_dim):
        super().__init__()
        widths = [input_dim, *hidden_dims]
        layers = []
        # One Linear + ReLU pair per consecutive pair of widths.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # The output layer has no activation.
        layers.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the un-activated output of the network for ``x``."""
        return self.net(x)
|
|
|
|
|
class Actor(nn.Module):
    """Gaussian policy conditioned on a local observation and a mean field.

    The mean comes from an MLP applied to the concatenated inputs. The log
    standard deviation is a learned parameter of size ``act_dim`` that does
    not depend on the input.

    Args:
        obs_dim: Size of the last dimension of the local observation.
        mean_field_dim: Size of the last dimension of the mean field.
        act_dim: Number of action dimensions.
        hidden: Hidden-layer widths of the mean network.
    """

    # Bounds applied to ``log_std`` before exponentiation, so the standard
    # deviation stays within [exp(-5), exp(2)].
    LOG_STD_MIN = -5
    LOG_STD_MAX = 2

    def __init__(self, obs_dim, mean_field_dim, act_dim, hidden=(64, 64)):
        super().__init__()
        input_dim = obs_dim + mean_field_dim
        self.net = MLP(input_dim, hidden, act_dim)
        self.log_std = nn.Parameter(torch.zeros(act_dim))

    def forward(self, local_obs, mean_field):
        """Return the action distribution for the given inputs.

        Args:
            local_obs: Tensor whose last dimension is ``obs_dim``.
            mean_field: Tensor whose last dimension is ``mean_field_dim``;
                it is concatenated with ``local_obs`` along the last axis.

        Returns:
            A ``Normal`` whose mean is the network output and whose standard
            deviation is ``exp(clamp(log_std))``.
        """
        mean = self.net(torch.cat([local_obs, mean_field], dim=-1))
        bounded_log_std = torch.clamp(
            self.log_std, self.LOG_STD_MIN, self.LOG_STD_MAX
        )
        return Normal(mean, torch.exp(bounded_log_std))
|
|
|
|
|
class Critic(nn.Module):
    """Value network conditioned on a local observation and a mean field.

    Args:
        obs_dim: Size of the last dimension of the local observation.
        mean_field_dim: Size of the last dimension of the mean field.
        hidden: Hidden-layer widths of the value network.
    """

    def __init__(self, obs_dim, mean_field_dim, hidden=(128, 128)):
        super().__init__()
        input_dim = obs_dim + mean_field_dim
        self.net = MLP(input_dim, hidden, 1)

    def forward(self, local_obs, mean_field):
        """Return one scalar value per input row.

        The inputs are concatenated along the last axis and the trailing
        singleton dimension of the network output is squeezed away.
        """
        features = torch.cat([local_obs, mean_field], dim=-1)
        value = self.net(features)
        return value.squeeze(-1)
|
|
|
|
|
class MFAC:
    """Mean-field actor-critic agent with a PPO-style clipped policy update.

    A single actor and a single critic are shared by all agents. Each agent
    sees its own local observation plus a mean field, computed as the average
    of the other agents' local observations. Advantages are estimated with
    GAE and the actor is trained on the clipped surrogate objective.

    Args:
        n_agents: Number of agents stored per time step.
        local_dim: Size of each agent's local observation. The mean field has
            the same size.
        act_dim: Number of action dimensions.
        lr: Adam learning rate used for both actor and critic.
        gamma: Discount factor.
        lam: GAE lambda.
        clip_eps: Clipping range of the probability ratio.
        k_epochs: Number of passes over the collected data per update.
        batch_size: Minibatch size used during an update.
        entropy_coeff: Weight of the entropy bonus in the actor loss.
        episode_len: Capacity of the rollout buffer in time steps.
    """

    def __init__(
        self,
        n_agents,
        local_dim,
        act_dim,
        lr=3e-4,
        gamma=0.99,
        lam=0.95,
        clip_eps=0.2,
        k_epochs=10,
        batch_size=1024,
        entropy_coeff=0.01,
        episode_len=96,
    ):
        self.n_agents = n_agents
        self.local_dim = local_dim
        # The mean field is an average of local observations, so it has the
        # same dimensionality.
        self.mean_field_dim = local_dim
        self.act_dim = act_dim
        self.gamma = gamma
        self.lam = lam
        self.clip_eps = clip_eps
        self.k_epochs = k_epochs
        self.batch_size = batch_size
        self.entropy_coeff = entropy_coeff
        self.episode_len = episode_len

        self.actor = Actor(self.local_dim, self.mean_field_dim, self.act_dim).to(device)
        self.critic = Critic(self.local_dim, self.mean_field_dim).to(device)

        self.opt_a = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.opt_c = torch.optim.Adam(self.critic.parameters(), lr=lr)

        # Mixed precision is only used on CUDA; on CPU the scaler and the
        # autocast contexts are constructed disabled.
        self.use_cuda_amp = device.type == "cuda"
        self.scaler = GradScaler(enabled=self.use_cuda_amp)
        print(f"MFAC CUDA AMP Enabled: {self.use_cuda_amp}")

        self.init_buffer()

    def init_buffer(self):
        """Allocate zero-filled rollout buffers and reset the write index."""
        per_step = (self.episode_len, self.n_agents)
        self.ls_buf = np.zeros((*per_step, self.local_dim), dtype=np.float32)
        self.ac_buf = np.zeros((*per_step, self.act_dim), dtype=np.float32)
        self.lp_buf = np.zeros(per_step, dtype=np.float32)
        self.rw_buf = np.zeros(per_step, dtype=np.float32)
        self.done_buf = np.zeros(per_step, dtype=np.float32)
        self.next_ls_buf = np.zeros((*per_step, self.local_dim), dtype=np.float32)
        self.step_idx = 0

    def clear_buffer(self):
        """Do nothing.

        The buffers are preallocated and overwritten in place, and
        ``update()`` resets the write index itself.
        """
        # NOTE(review): the name suggests this should reset ``step_idx``;
        # kept as a no-op to preserve existing behaviour.
        pass

    def _get_mean_field(self, obs_batch):
        """Return, for every agent, the mean observation of the other agents.

        Args:
            obs_batch: Tensor whose second-to-last axis indexes agents and
                whose last axis holds the local observation.

        Returns:
            Tensor of the same shape as ``obs_batch``. With a single agent
            there are no neighbours, so zeros are returned.
        """
        if self.n_agents <= 1:
            return torch.zeros(
                *obs_batch.shape[:-1],
                self.mean_field_dim,
                device=obs_batch.device,
                dtype=obs_batch.dtype,
            )
        # Sum over agents, remove the agent's own contribution, then average
        # over the remaining n_agents - 1 agents.
        total_obs = torch.sum(obs_batch, dim=-2, keepdim=True)
        return (total_obs - obs_batch) / (self.n_agents - 1)

    @staticmethod
    def _standardize(values):
        """Shift ``values`` to zero mean and, when possible, unit std."""
        centered = values - values.mean()
        if values.numel() < 2:
            # The sample std of a single element is NaN, which would poison
            # the losses and then the parameters; only centre in that case.
            return centered
        return centered / (values.std() + 1e-8)

    def _compute_gae(self, rewards, values, next_values, dones):
        """Return GAE advantages, accumulated backwards over the first axis."""
        masks = 1.0 - dones
        advantages = torch.zeros_like(rewards)
        gae = 0
        for t in reversed(range(rewards.shape[0])):
            delta = rewards[t] + self.gamma * next_values[t] * masks[t] - values[t]
            gae = delta + self.gamma * self.lam * masks[t] * gae
            advantages[t] = gae
        return advantages

    @torch.no_grad()
    def select_action(self, local_obs, evaluate=False):
        """Choose actions for all agents from their local observations.

        Args:
            local_obs: NumPy array whose second-to-last axis indexes agents.
            evaluate: If true, return the distribution mean instead of a
                sample.

        Returns:
            Tuple ``(action, log_prob)`` of float32 NumPy arrays; the log
            probability is summed over the action dimensions.
        """
        obs_tensor = torch.from_numpy(local_obs).float().to(device)
        with autocast(device_type=device.type, dtype=torch.float16, enabled=self.use_cuda_amp):
            mean_field = self._get_mean_field(obs_tensor)
            dist = self.actor(obs_tensor, mean_field)
            action = dist.mean if evaluate else dist.sample()
            log_prob = dist.log_prob(action).sum(-1)
        # Under CUDA autocast the network output may be float16; cast so the
        # returned dtype is the same on every device.
        return action.float().cpu().numpy(), log_prob.float().cpu().numpy()

    def store(self, local_obs, action, logp, reward, done, next_local_obs):
        """Append one time step of transitions for all agents.

        Transitions arriving after the buffer holds ``episode_len`` steps are
        silently dropped.
        """
        if self.step_idx >= self.episode_len:
            return
        idx = self.step_idx
        self.ls_buf[idx] = local_obs
        self.ac_buf[idx] = action
        self.lp_buf[idx] = logp
        self.rw_buf[idx] = np.array(reward, dtype=np.float32)
        self.done_buf[idx] = np.array(done, dtype=np.float32)
        self.next_ls_buf[idx] = next_local_obs
        self.step_idx += 1

    def update(self):
        """Run the actor and critic updates on the stored rollout.

        Does nothing when the buffer is empty. Afterwards the write index is
        reset so the buffer can be refilled.
        """
        T = self.step_idx
        if T == 0:
            return

        ls_tensor = torch.from_numpy(self.ls_buf[:T]).float().to(device)
        ac_tensor = torch.from_numpy(self.ac_buf[:T]).float().to(device)
        lp_tensor = torch.from_numpy(self.lp_buf[:T]).float().to(device)
        rw_tensor = torch.from_numpy(self.rw_buf[:T]).float().to(device)
        done_tensor = torch.from_numpy(self.done_buf[:T]).float().to(device)
        next_ls_tensor = torch.from_numpy(self.next_ls_buf[:T]).float().to(device)

        # Targets are computed once, without gradients, from the current
        # critic.
        with torch.no_grad():
            with autocast(device_type=device.type, dtype=torch.float16, enabled=self.use_cuda_amp):
                mf_all = self._get_mean_field(ls_tensor)
                vals = self.critic(ls_tensor, mf_all)
                next_mf_all = self._get_mean_field(next_ls_tensor)
                next_vals = self.critic(next_ls_tensor, next_mf_all)
            adv = self._compute_gae(rw_tensor, vals, next_vals, done_tensor)
            ret = adv + vals

        # Merge the time and agent axes so every (step, agent) pair becomes
        # one training sample.
        N, D_l = self.n_agents, self.local_dim
        ls_flat = ls_tensor.view(T * N, D_l)
        mf_flat = mf_all.view(T * N, self.mean_field_dim)
        ac_flat = ac_tensor.view(T * N, self.act_dim)
        lp_flat = lp_tensor.view(-1)
        adv_flat = self._standardize(adv.view(-1))
        # NOTE(review): the critic is trained on per-update standardized
        # returns, while the GAE step above combines critic outputs with raw
        # rewards - confirm this scale mismatch is intended.
        ret_flat = self._standardize(ret.view(-1))

        dataset = torch.utils.data.TensorDataset(
            ls_flat, mf_flat, ac_flat, lp_flat, adv_flat, ret_flat
        )
        # The generator is re-seeded on every call, so each update replays the
        # same sequence of shuffles.
        gen = torch.Generator()
        gen.manual_seed(SEED)
        loader = torch.utils.data.DataLoader(
            dataset, batch_size=self.batch_size, shuffle=True, generator=gen
        )

        value_loss_fn = nn.MSELoss()

        for _ in range(self.k_epochs):
            for b_ls, b_mf, b_ac, b_lp, b_adv, b_ret in loader:
                # ---- actor step -------------------------------------------
                self.opt_a.zero_grad(set_to_none=True)
                with autocast(device_type=device.type, dtype=torch.float16, enabled=self.use_cuda_amp):
                    dist_new = self.actor(b_ls, b_mf)
                    lp_new = dist_new.log_prob(b_ac).sum(-1)
                    entropy = dist_new.entropy().sum(-1).mean()
                    # Clamp the log-ratio so exp() cannot overflow.
                    log_ratio = torch.clamp(lp_new - b_lp, -20.0, 20.0)
                    ratio = torch.exp(log_ratio)
                    surr1 = ratio * b_adv
                    surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * b_adv
                    actor_loss = -torch.min(surr1, surr2).mean() - self.entropy_coeff * entropy

                self.scaler.scale(actor_loss).backward()
                # Unscale before clipping so the threshold applies to the
                # true gradient norm.
                self.scaler.unscale_(self.opt_a)
                torch.nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=0.5)
                self.scaler.step(self.opt_a)

                # ---- critic step ------------------------------------------
                self.opt_c.zero_grad(set_to_none=True)
                with autocast(device_type=device.type, dtype=torch.float16, enabled=self.use_cuda_amp):
                    val_pred = self.critic(b_ls, b_mf)
                    critic_loss = value_loss_fn(val_pred, b_ret)

                self.scaler.scale(critic_loss).backward()
                self.scaler.unscale_(self.opt_c)
                torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=0.5)
                self.scaler.step(self.opt_c)

                # One scaler serves both optimizers, so it is updated once per
                # iteration after both have stepped.
                self.scaler.update()

        self.step_idx = 0

    def save(self, path):
        """Write the actor and critic state dicts to ``path``."""
        checkpoint = {
            'actor': self.actor.state_dict(),
            'critic': self.critic.state_dict(),
        }
        torch.save(checkpoint, path)

    def load(self, path):
        """Restore the actor and critic weights from a file written by ``save``."""
        # The checkpoint holds only state dicts, so restrict unpickling to
        # tensors and plain containers instead of arbitrary objects.
        data = torch.load(path, map_location=device, weights_only=True)
        self.actor.load_state_dict(data['actor'])
        self.critic.load_state_dict(data['critic'])