# gdm-ppo / memory.py
# Uploaded by trixyL — commit d2231f2 ("add: dump first try")
import random
from collections import namedtuple, deque
import torch
# One environment step: the observed state, the action taken, the critic's
# value estimate, the reward received, the resulting state, and whether the
# episode terminated.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'value', 'reward', 'next_state', 'done'],
)
class ReplayMemory:
    """Fixed-capacity replay memory for off-policy agent training.

    A thin wrapper around :class:`collections.deque`: once ``capacity``
    transitions are stored, appending a new one silently evicts the oldest.

    Attributes:
        memory: A deque of Transitions
    """

    def __init__(self, capacity):
        """Create an empty memory that retains at most *capacity* transitions."""
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Save a transition built from *args* (in Transition field order)."""
        transition = Transition(*args)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return *batch_size* transitions drawn uniformly without replacement."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)
class RolloutBuffer:
    """Simple on-policy rollout storage used by PPO-style updates.

    Transitions are appended one step at a time with :meth:`add` and later
    batched into tensors with :meth:`as_tensors`. Call :meth:`clear` to
    reset the buffer between rollouts.
    """

    def __init__(self):
        """Start empty by delegating to clear()."""
        self.clear()

    def clear(self):
        """Reset every per-step list to empty."""
        (
            self.states,
            self.coord_actions,
            self.tile_actions,
            self.rewards,
            self.dones,
            self.log_probs,
            self.values,
        ) = ([] for _ in range(7))

    def add(self, state, coord_action, tile_action, reward, done, log_prob, value):
        """Record one transition; tensor arguments are detached from autograd."""
        tensor_stores = (
            (self.states, state),
            (self.coord_actions, coord_action),
            (self.tile_actions, tile_action),
            (self.log_probs, log_prob),
            (self.values, value),
        )
        for store, tensor in tensor_stores:
            store.append(tensor.detach())
        self.rewards.append(float(reward))
        self.dones.append(float(done))

    def __len__(self):
        """Number of recorded steps."""
        return len(self.rewards)

    def as_tensors(self, device, dtype):
        """Batch the stored lists into tensors on *device*.

        Returns a 7-tuple ``(states, coord_actions, tile_actions, rewards,
        dones, log_probs, values)``; the two action tensors are int64,
        everything else uses *dtype*.
        """
        def _cat(chunks, target_dtype):
            # Per-step tensors are concatenated along the batch dimension.
            return torch.cat(chunks, dim=0).to(device=device, dtype=target_dtype)

        return (
            _cat(self.states, dtype),
            _cat(self.coord_actions, torch.long),
            _cat(self.tile_actions, torch.long),
            torch.tensor(self.rewards, device=device, dtype=dtype),
            torch.tensor(self.dones, device=device, dtype=dtype),
            _cat(self.log_probs, dtype),
            _cat(self.values, dtype),
        )