| """Experience replay buffer for DQN training.""" |
| import random |
| from collections import deque |
|
|
| import numpy as np |
|
|
|
|
class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling.

    Transitions are held in a ``deque`` created with ``maxlen=capacity``, so
    once the buffer is full every new transition displaces the oldest one.
    """

    def __init__(self, capacity: int = 50_000):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of transitions retained; passed straight
                to ``deque(maxlen=...)``.
        """
        self.buffer: deque = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Store one ``(state, action, reward, next_state, done)`` transition.

        ``done`` is converted with ``float(done)`` before being stored; the
        other fields are stored as given.
        """
        self.buffer.append((state, action, reward, next_state, float(done)))

    def sample(self, batch_size: int):
        """Draw ``batch_size`` distinct stored transitions at random.

        Args:
            batch_size: Number of transitions to draw; must be between 1 and
                ``len(self)`` inclusive.

        Returns:
            A 5-tuple of NumPy arrays ``(states, actions, rewards,
            next_states, dones)``, each with ``batch_size`` entries along the
            first axis. ``actions`` is ``int64``; the others are ``float32``.

        Raises:
            ValueError: If ``batch_size`` is not positive or exceeds the
                number of transitions currently stored.
        """
        available = len(self.buffer)
        # Validate up front: without this, a zero-sized request surfaces as an
        # opaque tuple-unpacking error and an oversized one as a generic
        # ``random.sample`` message that omits both numbers.
        if batch_size <= 0:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )
        if batch_size > available:
            raise ValueError(
                f"batch_size ({batch_size}) exceeds the number of stored "
                f"transitions ({available})"
            )

        batch = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into one tuple per field.
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            np.array(states, dtype=np.float32),
            np.array(actions, dtype=np.int64),
            np.array(rewards, dtype=np.float32),
            np.array(next_states, dtype=np.float32),
            np.array(dones, dtype=np.float32),
        )

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return len(self.buffer)
|
|