# Source: RL-Recommendation-System/src/data/replay_buffer.py
# Author: mnoorchenar
# Last update: 2026-03-23 09:06:26 (revision fb62752), 889 bytes
"""Experience replay buffer for DQN training."""
import random
from collections import deque
import numpy as np
class ReplayBuffer:
    """Bounded FIFO store of transitions used for experience replay.

    Transitions live in a ``collections.deque`` created with
    ``maxlen=capacity``: once the buffer is full, every new transition
    pushes the oldest one out.
    """

    def __init__(self, capacity: int = 50_000):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of transitions kept at any time
                (forwarded to ``deque`` as ``maxlen``).
        """
        self.buffer: deque = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` transition.

        ``done`` is converted with ``float()`` before being stored, so a
        boolean terminal flag is kept as ``1.0`` / ``0.0``.
        """
        self.buffer.append((state, action, reward, next_state, float(done)))

    def sample(self, batch_size: int):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Args:
            batch_size: Number of transitions to draw; must be between 1 and
                ``len(self)`` inclusive.

        Returns:
            A 5-tuple ``(states, actions, rewards, next_states, dones)`` of
            NumPy arrays whose first axis has length ``batch_size``.
            ``actions`` is ``int64``; the other four are ``float32``.
            Row ``i`` of every array comes from the same stored transition.

        Raises:
            ValueError: If ``batch_size`` is not in ``[1, len(self)]``.
        """
        available = len(self.buffer)
        # Fail early with a descriptive message. Without this check a zero
        # batch dies on the tuple unpacking below with an unrelated
        # "not enough values to unpack" error, and an oversized or negative
        # one surfaces random.sample's generic message.
        if batch_size <= 0 or batch_size > available:
            raise ValueError(
                f"batch_size must be between 1 and the number of stored "
                f"transitions ({available}), got {batch_size}"
            )

        batch = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into one sequence per field.
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            np.array(states, dtype=np.float32),
            np.array(actions, dtype=np.int64),
            np.array(rewards, dtype=np.float32),
            np.array(next_states, dtype=np.float32),
            np.array(dones, dtype=np.float32),
        )

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return len(self.buffer)