"""Experience replay buffer for DQN training."""
import random
from collections import deque
import numpy as np
class ReplayBuffer:
    """Bounded store of transitions with uniform random sampling.

    Transitions are kept in a ``deque`` built with ``maxlen=capacity``, so
    appending to a full buffer drops the oldest stored transition.
    """

    def __init__(self, capacity: int = 50_000):
        """Create an empty buffer that holds at most ``capacity`` transitions."""
        self.buffer: deque = deque(maxlen=capacity)

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return len(self.buffer)

    def push(self, state, action, reward, next_state, done):
        """Append one transition to the buffer.

        ``done`` is converted with ``float()`` before being stored; the
        other four fields are stored as given.
        """
        transition = (state, action, reward, next_state, float(done))
        self.buffer.append(transition)

    def sample(self, batch_size: int):
        """Draw ``batch_size`` transitions without replacement.

        Returns:
            A 5-tuple ``(states, actions, rewards, next_states, dones)`` of
            NumPy arrays. ``actions`` is ``int64``; the other four are
            ``float32``. Row ``i`` of every array comes from the same
            sampled transition.

        Raises:
            ValueError: Propagated from ``random.sample`` when
                ``batch_size`` is larger than the number of stored
                transitions.
        """

        def as_float32(column):
            return np.array(column, dtype=np.float32)

        picked = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into one sequence per field.
        states, actions, rewards, next_states, dones = zip(*picked)
        action_array = np.array(actions, dtype=np.int64)
        return (
            as_float32(states),
            action_array,
            as_float32(rewards),
            as_float32(next_states),
            as_float32(dones),
        )
|