| import random |
| from collections import namedtuple, deque |
| import torch |
|
|
| """A tuple representing a transition in the shape of a namedtuple""" |
| Transition = namedtuple('Transition', |
| ('state', 'action', 'value', 'reward', 'next_state', 'done')) |
|
|
|
|
class ReplayMemory:
    """Fixed-capacity experience replay buffer for agent training.

    Wraps a ``collections.deque`` so that once ``capacity`` transitions
    have been stored, pushing a new one silently evicts the oldest.

    Attributes:
        memory: A deque holding ``Transition`` namedtuples.
    """

    def __init__(self, capacity):
        """Create an empty buffer that holds at most ``capacity`` items."""
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Store one transition; ``args`` are the ``Transition`` fields."""
        transition = Transition(*args)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` transitions drawn uniformly without replacement."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)
|
|
|
|
class RolloutBuffer:
    """On-policy rollout storage for PPO-style updates.

    Accumulates per-step tensors and scalars in parallel Python lists,
    then batches them into stacked tensors via :meth:`as_tensors`.
    """

    def __init__(self):
        """Start with an empty rollout."""
        self.clear()

    def clear(self):
        """Drop all stored steps, resetting every field to an empty list."""
        self.states, self.coord_actions, self.tile_actions = [], [], []
        self.rewards, self.dones = [], []
        self.log_probs, self.values = [], []

    def add(self, state, coord_action, tile_action, reward, done, log_prob, value):
        """Record one environment step.

        Tensor arguments are detached so the buffer never keeps autograd
        history alive; ``reward`` and ``done`` are coerced to floats.
        """
        for store, tensor in (
            (self.states, state),
            (self.coord_actions, coord_action),
            (self.tile_actions, tile_action),
            (self.log_probs, log_prob),
            (self.values, value),
        ):
            store.append(tensor.detach())
        self.rewards.append(float(reward))
        self.dones.append(float(done))

    def __len__(self):
        """Number of steps recorded since the last :meth:`clear`."""
        return len(self.rewards)

    def as_tensors(self, device, dtype):
        """Return the rollout as batched tensors on ``device``.

        Stored tensors are concatenated along dim 0 (each entry is assumed
        to carry a leading batch dimension); actions are cast to ``long``,
        everything else to ``dtype``.

        Returns:
            Tuple of (states, coord_actions, tile_actions, rewards,
            dones, log_probs, values).
        """
        def _cat(parts, target_dtype):
            return torch.cat(parts, dim=0).to(device=device, dtype=target_dtype)

        return (
            _cat(self.states, dtype),
            _cat(self.coord_actions, torch.long),
            _cat(self.tile_actions, torch.long),
            torch.tensor(self.rewards, device=device, dtype=dtype),
            torch.tensor(self.dones, device=device, dtype=dtype),
            _cat(self.log_probs, dtype),
            _cat(self.values, dtype),
        )
|
|