Spaces:
Runtime error
Runtime error
import random
from collections import deque

import torch
class CognitiveReplayBuffer:
    """Fixed-capacity FIFO store of ``(state, target)`` pairs with random sampling.

    Once ``max_size`` pairs are held, each new ``push`` evicts the oldest pair.
    ``sample`` draws a batch uniformly at random without replacement and
    stacks the states and the targets into two tensors.
    """

    def __init__(self, max_size: int = 5000):
        """Create an empty buffer.

        Args:
            max_size: Maximum number of pairs retained; must be >= 1.

        Raises:
            ValueError: If ``max_size`` is less than 1.
        """
        if max_size < 1:
            raise ValueError(f"max_size must be >= 1, got {max_size}")
        self.max_size = max_size
        # deque(maxlen=...) discards the oldest entry in O(1) when full,
        # whereas list.pop(0) shifts every remaining element (O(n)).
        self.buffer = deque(maxlen=max_size)

    def push(self, state, target):
        """Append one ``(state, target)`` pair, evicting the oldest if full.

        NOTE(review): the objects are stored as given. If they are tensors
        that still carry an autograd graph (``requires_grad=True``), that graph
        stays alive while the pair is buffered; consider ``.detach()`` at the
        call site if gradients through buffered data are not intended.
        """
        self.buffer.append((state, target))

    def sample(self, batch_size: int = 16):
        """Draw a random batch of stored pairs and stack it into tensors.

        Args:
            batch_size: Requested number of pairs. If fewer are stored, every
                stored pair is returned (in random order).

        Returns:
            ``(states, targets)``: ``torch.stack`` of the sampled states and of
            the sampled targets, each with a new leading dimension of size
            ``min(batch_size, len(self.buffer))``. ``torch.stack`` requires all
            stacked states (and all stacked targets) to share one shape.

        Raises:
            ValueError: If the buffer is empty or ``batch_size`` < 1.
        """
        count = min(batch_size, len(self.buffer))
        if count < 1:
            # Without this guard, zip(*[]) fails later with an opaque
            # "not enough values to unpack" error.
            raise ValueError(
                "cannot sample: buffer holds "
                f"{len(self.buffer)} pairs, batch_size={batch_size}"
            )
        batch = random.sample(self.buffer, count)
        states, targets = zip(*batch)
        return torch.stack(states), torch.stack(targets)