# mvi-ai-engine / training/replay_buffer.py
# Commit ef03f56 by Musombi: "Create training/replay_buffer.py"
import random
from collections import deque

import torch
class CognitiveReplayBuffer:
    """Fixed-capacity FIFO store of ``(state, target)`` pairs with uniform sampling.

    Once ``max_size`` pairs are held, each new ``push`` evicts the oldest pair.
    ``self.buffer`` keeps pairs in insertion order (oldest first).
    """

    def __init__(self, max_size=5000):
        """Create an empty buffer.

        Args:
            max_size: Maximum number of pairs retained; must be >= 1.

        Raises:
            ValueError: If ``max_size`` is less than 1.
        """
        if max_size < 1:
            raise ValueError(f"max_size must be >= 1, got {max_size!r}")
        self.max_size = max_size
        # A bounded deque discards its oldest entry in O(1) on append, whereas
        # list.pop(0) shifts every remaining element on each eviction.
        self.buffer = deque(maxlen=max_size)

    def push(self, state, target):
        """Append a ``(state, target)`` pair, evicting the oldest pair if full.

        NOTE(review): the objects are stored exactly as given. If they are
        tensors that still carry autograd history, each stored pair keeps its
        graph alive; consider pushing ``.detach()``-ed tensors if gradients are
        not needed through the buffer — confirm against the training loop.
        """
        self.buffer.append((state, target))

    def sample(self, batch_size=16):
        """Draw distinct pairs uniformly at random, without replacement.

        Args:
            batch_size: Requested batch size; clamped to the number of stored
                pairs, so a partially filled buffer returns a smaller batch.

        Returns:
            A ``(states, targets)`` tuple, each built with ``torch.stack`` over
            the sampled items, so all states (and all targets) must be tensors
            of the same shape.

        Raises:
            ValueError: If the buffer is empty or ``batch_size`` is less than 1.
        """
        if not self.buffer:
            # Without this guard, zip(*[]) fails with an opaque unpacking error.
            raise ValueError("cannot sample from an empty replay buffer")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
        batch = random.sample(self.buffer, min(batch_size, len(self.buffer)))
        states, targets = zip(*batch)
        return torch.stack(states), torch.stack(targets)