import torch
import numpy as np

from agent_nn import AgentNN
from tensordict import TensorDict
from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage


class Agent:
    def __init__(self,
                 input_dims,
                 num_actions,
                 lr=0.00025,
                 gamma=0.9,
                 epsilon=1.0,
                 eps_decay=0.99999975,
                 eps_min=0.1,
                 replay_buffer_capacity=100_000,
                 batch_size=32,
                 sync_network_rate=10_000):
        self.num_actions = num_actions
        self.learn_step_counter = 0

        # Hyperparameters
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        self.eps_decay = eps_decay
        self.eps_min = eps_min
        self.batch_size = batch_size
        self.sync_network_rate = sync_network_rate

        # Networks
        self.online_network = AgentNN(input_dims, num_actions)
        self.target_network = AgentNN(input_dims, num_actions, freeze=True)

        # Optimizer and loss
        self.optimizer = torch.optim.Adam(self.online_network.parameters(), lr=self.lr)
        self.loss = torch.nn.MSELoss()
        # self.loss = torch.nn.SmoothL1Loss()  # Optional alternative

        # Replay buffer
        storage = LazyMemmapStorage(replay_buffer_capacity)
        self.replay_buffer = TensorDictReplayBuffer(storage=storage)

    def choose_action(self, observation):
        # Epsilon-greedy exploration: random action with probability epsilon
        if np.random.random() < self.epsilon:
            return np.random.randint(self.num_actions)
        # Convert observation safely to a float32 tensor with a batch dimension
        observation = torch.tensor(np.array(observation), dtype=torch.float32) \
            .unsqueeze(0).to(self.online_network.device)
        return self.online_network(observation).argmax().item()

    def decay_epsilon(self):
        self.epsilon = max(self.epsilon * self.eps_decay, self.eps_min)

    def store_in_memory(self, state, action, reward, next_state, done):
        # Ensure all stored tensors have consistent dtypes
        self.replay_buffer.add(TensorDict({
            "state": torch.tensor(np.array(state), dtype=torch.float32),
            "action": torch.tensor(action, dtype=torch.long),
            "reward": torch.tensor(reward, dtype=torch.float32),
            "next_state": torch.tensor(np.array(next_state), dtype=torch.float32),
            "done": torch.tensor(done, dtype=torch.float32)
        }, batch_size=[]))

    def sync_networks(self):
        # Periodically copy online weights into the frozen target network
        if self.learn_step_counter % self.sync_network_rate == 0 and self.learn_step_counter > 0:
            self.target_network.load_state_dict(self.online_network.state_dict())

    def save_model(self, path):
        torch.save(self.online_network.state_dict(), path)

    def load_model(self, path):
        self.online_network.load_state_dict(torch.load(path, map_location=self.online_network.device))
        self.target_network.load_state_dict(torch.load(path, map_location=self.online_network.device))

    def learn(self):
        # Wait until the buffer holds at least one full batch
        if len(self.replay_buffer) < self.batch_size:
            return

        self.sync_networks()
        self.optimizer.zero_grad()

        samples = self.replay_buffer.sample(self.batch_size).to(self.online_network.device)
        keys = ("state", "action", "reward", "next_state", "done")
        states, actions, rewards, next_states, dones = [samples[key] for key in keys]

        # Ensure correct types and device placement
        states = states.float().to(self.online_network.device)
        next_states = next_states.float().to(self.online_network.device)
        rewards = rewards.float().to(self.online_network.device)
        dones = dones.float().to(self.online_network.device)
        actions = actions.long().to(self.online_network.device)

        # Predicted Q-values for the actions actually taken
        predicted_q_values = self.online_network(states)
        predicted_q_values = predicted_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # Target Q-values: r + gamma * max_a' Q_target(s', a'), zeroed on terminal transitions
        target_q_values = self.target_network(next_states).max(dim=1)[0]
        target_q_values = rewards + self.gamma * target_q_values * (1 - dones)

        # Loss and optimization step
        loss = self.loss(predicted_q_values, target_q_values.detach())
        loss.backward()
        self.optimizer.step()

        self.learn_step_counter += 1
        self.decay_epsilon()
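

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal sketch of the training loop this Agent is designed to sit inside.
# Assumptions: a Gymnasium-style environment (here "CartPole-v1" as a stand-in),
# that AgentNN accepts the observation-space shape as input_dims, and placeholder
# values for the episode count and checkpoint path. None of these names come
# from the original source.
if __name__ == "__main__":
    import gymnasium as gym  # assumed available; any Gym-style env works

    env = gym.make("CartPole-v1")  # hypothetical environment for illustration
    agent = Agent(input_dims=env.observation_space.shape,
                  num_actions=env.action_space.n)

    num_episodes = 10  # placeholder value
    for episode in range(num_episodes):
        state, _ = env.reset()
        done = False
        while not done:
            action = agent.choose_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.store_in_memory(state, action, reward, next_state, done)
            agent.learn()
            state = next_state

    agent.save_model("agent_checkpoint.pt")  # hypothetical path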