import torch
import numpy as np

from agent_nn import AgentNN
from tensordict import TensorDict
from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
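
# Note: AgentNN (imported above) is assumed here to be a torch.nn.Module that
# takes (input_dims, num_actions, freeze=False), exposes a `.device` attribute,
# and maps a batch of observations to a (batch, num_actions) tensor of
# Q-values; freeze=True is assumed to disable gradients on its parameters.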

class Agent:
    def __init__(self,
                 input_dims,
                 num_actions,
                 lr=0.00025,
                 gamma=0.9,
                 epsilon=1.0,
                 eps_decay=0.99999975,
                 eps_min=0.1,
                 replay_buffer_capacity=100_000,
                 batch_size=32,
                 sync_network_rate=10_000):
        self.num_actions = num_actions
        self.learn_step_counter = 0

        # Hyperparameters
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        self.eps_decay = eps_decay
        self.eps_min = eps_min
        self.batch_size = batch_size
        self.sync_network_rate = sync_network_rate

        # Networks
        self.online_network = AgentNN(input_dims, num_actions)
        self.target_network = AgentNN(input_dims, num_actions, freeze=True)

        # Optimizer and loss
        self.optimizer = torch.optim.Adam(self.online_network.parameters(), lr=self.lr)
        self.loss = torch.nn.MSELoss()
        # self.loss = torch.nn.SmoothL1Loss()  # Optional alternative

        # Replay buffer
        storage = LazyMemmapStorage(replay_buffer_capacity)
        self.replay_buffer = TensorDictReplayBuffer(storage=storage)

    def choose_action(self, observation):
        if np.random.random() < self.epsilon:
            return np.random.randint(self.num_actions)
        # Convert observation safely to float32 tensor
        observation = torch.tensor(np.array(observation), dtype=torch.float32) \
            .unsqueeze(0).to(self.online_network.device)
        return self.online_network(observation).argmax().item()

    def decay_epsilon(self):
        self.epsilon = max(self.epsilon * self.eps_decay, self.eps_min)
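
    # With the defaults above, epsilon is multiplied by eps_decay=0.99999975
    # once per learn() call, so it reaches the eps_min=0.1 floor after roughly
    # ln(0.1) / ln(0.99999975), i.e. about 9.2 million steps; the schedule is
    # sized for long training runs.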

    def store_in_memory(self, state, action, reward, next_state, done):
        # Ensure all tensors are float32 and consistent
        self.replay_buffer.add(TensorDict({
            "state": torch.tensor(np.array(state), dtype=torch.float32),
            "action": torch.tensor(action, dtype=torch.long),
            "reward": torch.tensor(reward, dtype=torch.float32),
            "next_state": torch.tensor(np.array(next_state), dtype=torch.float32),
            "done": torch.tensor(done, dtype=torch.float32)
        }, batch_size=[]))
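
    # The buffer stacks these per-transition TensorDicts, so sample(batch_size)
    # later returns a single TensorDict whose fields are batched along dim 0.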

    def sync_networks(self):
        if self.learn_step_counter % self.sync_network_rate == 0 and self.learn_step_counter > 0:
            self.target_network.load_state_dict(self.online_network.state_dict())
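
    # sync_networks() performs a "hard" target update: a wholesale weight copy
    # every sync_network_rate steps, as opposed to a Polyak/soft update that
    # blends a small fraction of the online weights in at every step.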

    def save_model(self, path):
        torch.save(self.online_network.state_dict(), path)

    def load_model(self, path):
        self.online_network.load_state_dict(torch.load(path, map_location=self.online_network.device))
        self.target_network.load_state_dict(torch.load(path, map_location=self.online_network.device))

    def learn(self):
        if len(self.replay_buffer) < self.batch_size:
            return

        self.sync_networks()
        self.optimizer.zero_grad()

        samples = self.replay_buffer.sample(self.batch_size).to(self.online_network.device)
        keys = ("state", "action", "reward", "next_state", "done")
        states, actions, rewards, next_states, dones = [samples[key] for key in keys]

        # Ensure correct types and device placement
        states = states.float().to(self.online_network.device)
        next_states = next_states.float().to(self.online_network.device)
        rewards = rewards.float().to(self.online_network.device)
        dones = dones.float().to(self.online_network.device)
        actions = actions.long().to(self.online_network.device)

        # Compute predicted Q-values
        predicted_q_values = self.online_network(states)
        predicted_q_values = predicted_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # Compute target Q-values
        target_q_values = self.target_network(next_states).max(dim=1)[0]
        target_q_values = rewards + self.gamma * target_q_values * (1 - dones)

        # Loss and optimization
        loss = self.loss(predicted_q_values, target_q_values.detach())
        loss.backward()
        self.optimizer.step()

        self.learn_step_counter += 1
        self.decay_epsilon()
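
if __name__ == "__main__":
    # Minimal training-loop sketch, not part of the original class: it assumes
    # a Gymnasium-style environment whose observation shape matches what
    # AgentNN expects and whose action space is discrete. The environment name,
    # episode count, and save path below are placeholders.
    import gymnasium as gym

    env = gym.make("CartPole-v1")
    agent = Agent(input_dims=env.observation_space.shape,
                  num_actions=env.action_space.n)

    for episode in range(500):
        state, _ = env.reset()
        done = False
        while not done:
            action = agent.choose_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.store_in_memory(state, action, reward, next_state, done)
            agent.learn()  # samples a batch and takes one gradient step
            state = next_state

    agent.save_model("online_network.pt")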