# Source: SolarSys2025 upload ("Upload 30 files", commit 55da406, verified)
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque
from torch.utils.data import Dataset, DataLoader
class ReplayBufferDataset(Dataset):
    """Experience-replay buffer exposed as a map-style torch Dataset.

    Transitions are kept in a bounded deque so the oldest experience is
    evicted automatically once ``max_size`` entries are stored.
    """

    def __init__(self, max_size=100000):
        # deque(maxlen=...) gives O(1) append with automatic eviction.
        self.buffer = deque(maxlen=max_size)

    def add(self, states, actions, rewards, next_states, done):
        """Append one transition.

        Inputs may be numpy arrays or nested lists/tuples; everything is
        normalized to float32 numpy storage here so that ``__getitem__``
        can build tensors cheaply. (The previous version stored raw
        inputs, so list-typed states crashed ``torch.from_numpy`` and
        float64 observations leaked float64 tensors into the float32
        networks downstream.)
        """
        data = (
            np.asarray(states, dtype=np.float32),
            np.asarray(actions, dtype=np.float32),
            np.asarray(rewards, dtype=np.float32),
            np.asarray(next_states, dtype=np.float32),
            np.float32(done),
        )
        self.buffer.append(data)

    def __len__(self):
        return len(self.buffer)

    def __getitem__(self, idx):
        """Return one transition as a tuple of float32 tensors."""
        states, actions, rewards, next_states, done = self.buffer[idx]
        # torch.as_tensor shares memory with the numpy arrays when
        # possible (zero-copy), unlike torch.tensor which always copies.
        return (
            torch.as_tensor(states),
            torch.as_tensor(actions),
            torch.as_tensor(rewards),
            torch.as_tensor(next_states),
            torch.tensor(done, dtype=torch.float32),
        )
class Actor(nn.Module):
    """Policy network mapping a local observation to actions in [0, 1]."""

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super(Actor, self).__init__()
        # Two hidden ReLU layers; the final Sigmoid squashes every action
        # component into the unit interval.
        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, state):
        """Return the action tensor for *state* (leading dims preserved)."""
        action = self.net(state)
        return action
class SharedCritic(nn.Module):
    """Centralized critic scoring a joint state/action pair.

    Emits one Q-value per agent, so the output has shape
    (batch, num_agents).
    """

    def __init__(self, global_state_dim, global_action_dim, hidden_dim=128, num_agents=1):
        super().__init__()
        # The critic conditions on the concatenation of the global state
        # and the joint action of all agents.
        input_dim = global_state_dim + global_action_dim
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            # One output column per agent.
            nn.Linear(hidden_dim, num_agents),
        )

    def forward(self, global_state, global_action):
        """Return per-agent Q-values of shape (batch, num_agents)."""
        joint = torch.cat((global_state, global_action), dim=1)
        return self.net(joint)
class Agent:
    """Bundles one actor network, its target copy, and the actor optimizer."""

    def __init__(self, local_state_dim, action_dim, lr_actor=1e-3, device=torch.device('cpu')):
        self.device = device
        self.actor = Actor(local_state_dim, action_dim).to(device)
        self.target_actor = Actor(local_state_dim, action_dim).to(device)
        # Start the target network as an exact copy of the online network.
        self.target_actor.load_state_dict(self.actor.state_dict())
        self.actor_optim = optim.Adam(self.actor.parameters(), lr=lr_actor)

    def sync_target(self, tau):
        """Polyak-average the online weights into the target network.

        Each target parameter becomes tau * online + (1 - tau) * target.
        """
        online_params = self.actor.parameters()
        target_params = self.target_actor.parameters()
        for target_param, online_param in zip(target_params, online_params):
            blended = online_param.data.mul(tau).add(target_param.data, alpha=1.0 - tau)
            target_param.data.copy_(blended)
class MADDPG:
    """Multi-Agent DDPG trainer with a single shared actor and shared critic.

    Variant note: all agents share one Actor (applied per-agent to its local
    state) and one SharedCritic that maps the concatenated global state and
    joint action to one Q-value per agent. Replay sampling goes through a
    torch DataLoader over ReplayBufferDataset.
    """

    def __init__(self, num_agents, local_state_dim, action_dim,
                 gamma=0.95, tau=0.01, lr_actor=1e-4, lr_critic=1e-3,
                 buffer_size=100000, noise_episodes=100, init_sigma=0.2, final_sigma=0.01,
                 batch_size=128, num_workers=0):
        # gamma: discount factor; tau: Polyak rate for target soft updates.
        # noise_episodes: episodes over which exploration sigma anneals
        # linearly from init_sigma to final_sigma (see select_actions).
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.num_agents = num_agents
        self.gamma = gamma
        self.tau = tau
        self.init_sigma = init_sigma
        self.final_sigma = final_sigma
        self.noise_episodes = noise_episodes
        self.current_episode = 0
        # One actor shared by all agents (parameter sharing), plus its target.
        self.actor = Actor(local_state_dim, action_dim).to(self.device)
        self.target_actor = Actor(local_state_dim, action_dim).to(self.device)
        self.target_actor.load_state_dict(self.actor.state_dict())
        self.actor_optim = optim.Adam(self.actor.parameters(), lr=lr_actor)
        # The centralized critic sees all agents' states and actions concatenated.
        global_state_dim = num_agents * local_state_dim
        global_action_dim = num_agents * action_dim
        self.critic = SharedCritic(global_state_dim, global_action_dim, num_agents=num_agents).to(self.device)
        self.target_critic = SharedCritic(global_state_dim, global_action_dim, num_agents=num_agents).to(self.device)
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.critic_optim = optim.Adam(self.critic.parameters(), lr=lr_critic)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.memory = ReplayBufferDataset(max_size=buffer_size)
        # DataLoader is built lazily on the first train() call.
        self.dataloader = None
        self.loader_iter = None

    def select_actions(self, states, evaluate=False):
        """Return clipped actions for all agents.

        `states` must be indexable as states[i] -> agent i's local state
        (assumes shape (num_agents, local_state_dim) — TODO confirm with the
        caller). When evaluate is False, Gaussian exploration noise with a
        linearly annealed sigma is added before clipping to [0, 1].
        """
        states_t = torch.as_tensor(states, dtype=torch.float32, device=self.device)
        with torch.no_grad():
            # Apply the shared actor to each agent's local state.
            actions_t = torch.stack([
                self.actor(states_t[i]) for i in range(self.num_agents)
            ], dim=0)
        actions = actions_t.cpu().numpy()
        if not evaluate:
            # Linear anneal from init_sigma to final_sigma over noise_episodes.
            # NOTE(review): noise_episodes == 0 would raise ZeroDivisionError here.
            frac = min(float(self.current_episode) / self.noise_episodes, 1.0)
            current_sigma = self.init_sigma - frac * (self.init_sigma - self.final_sigma)
            noise = np.random.normal(0, current_sigma, size=actions.shape)
            actions += noise
        # Actions live in [0, 1] (the actor ends in a Sigmoid); clip the noise back in.
        return np.clip(actions, 0.0, 1.0)

    def store_transition(self, states, actions, rewards, next_states, done):
        """Append one environment transition to the replay buffer."""
        self.memory.add(states, actions, rewards, next_states, done)

    def train(self):
        """Run one gradient step for critic and actor, then soft-update targets.

        No-op until the buffer holds at least one full batch.
        """
        if len(self.memory) < self.batch_size:
            return
        should_pin_memory = self.device.type == 'cuda'
        # NOTE(review): the DataLoader iterator persists across train() calls,
        # so its shuffled index order is fixed when the iterator is created;
        # transitions added since then (and index shifts from deque eviction)
        # are only reflected after the iterator is exhausted and rebuilt below.
        if self.dataloader is None:
            self.dataloader = DataLoader(self.memory, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, pin_memory=should_pin_memory, drop_last=True)
            self.loader_iter = iter(self.dataloader)
        try:
            s, a, r, s2, d = next(self.loader_iter)
        except StopIteration:
            # Epoch exhausted: rebuild the loader over the (possibly grown) buffer.
            self.dataloader = DataLoader(self.memory, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, pin_memory=should_pin_memory, drop_last=True)
            self.loader_iter = iter(self.dataloader)
            s, a, r, s2, d = next(self.loader_iter)
        # d_t gains a trailing dim so it broadcasts against per-agent Q columns.
        s_t, a_t, r_t, s2_t, d_t = s.to(self.device), a.to(self.device), r.to(self.device), s2.to(self.device), d.to(self.device).unsqueeze(-1)
        # Per-batch reward normalization (zero mean, unit variance).
        r_t = (r_t - r_t.mean()) / (r_t.std() + 1e-7)
        batch_len = s_t.shape[0]
        # Flatten per-agent dims into global (joint) state / action / next-state.
        gs, ga, ns = s_t.reshape(batch_len, -1), a_t.reshape(batch_len, -1), s2_t.reshape(batch_len, -1)
        with torch.no_grad():
            # Target-policy actions for every agent on the next states.
            targ_actions = torch.cat([self.target_actor(s2_t[:, i, :]) for i in range(self.num_agents)], dim=1)
            Q_prime = self.target_critic(ns, targ_actions)
            # TD targets; (1 - d_t) zeroes the bootstrap term on terminal steps.
            targets = r_t + self.gamma * (1 - d_t) * Q_prime
        Q = self.critic(gs, ga)
        critic_loss = nn.MSELoss()(Q, targets)
        self.critic_optim.zero_grad()
        critic_loss.backward()
        # Gradient clipping stabilizes both updates.
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 1.0)
        self.critic_optim.step()
        # Actor update: maximize the critic's mean Q over all agents.
        # (This backward also deposits grads on critic params; they are
        # discarded by critic_optim.zero_grad() on the next call.)
        all_actions = torch.cat([self.actor(s_t[:, i, :]) for i in range(self.num_agents)], dim=1)
        actor_loss = -self.critic(gs, all_actions).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 1.0)
        self.actor_optim.step()
        # Polyak soft update of both target networks: tau*online + (1-tau)*target.
        for tp, p in zip(self.target_actor.parameters(), self.actor.parameters()):
            tp.data.copy_(self.tau * p.data + (1.0 - self.tau) * tp.data)
        for tp, p in zip(self.target_critic.parameters(), self.critic.parameters()):
            tp.data.copy_(self.tau * p.data + (1.0 - self.tau) * tp.data)

    def on_episode_end(self):
        """Advance the episode counter that drives exploration-noise annealing."""
        self.current_episode += 1

    def save(self, path: str):
        """Serialize networks, optimizers, and the episode counter to *path*."""
        payload = {
            "critic": self.critic.state_dict(),
            "target_critic": self.target_critic.state_dict(),
            "critic_optim": self.critic_optim.state_dict(),
            "actor": self.actor.state_dict(),
            "target_actor": self.target_actor.state_dict(),
            "actor_optim": self.actor_optim.state_dict(),
            "current_episode": self.current_episode,
        }
        torch.save(payload, path)

    def load(self, path: str):
        """Restore state written by save().

        Tolerates older checkpoints that lack "current_episode" (defaults to 0).
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.critic.load_state_dict(checkpoint["critic"])
        self.target_critic.load_state_dict(checkpoint["target_critic"])
        self.critic_optim.load_state_dict(checkpoint["critic_optim"])
        self.actor.load_state_dict(checkpoint["actor"])
        self.target_actor.load_state_dict(checkpoint["target_actor"])
        self.actor_optim.load_state_dict(checkpoint["actor_optim"])
        self.current_episode = checkpoint.get("current_episode", 0)