| import torch | |
class Memory:
    """Fixed-size table of per-node state vectors held in a single tensor.

    Row ``i`` of ``self.memory`` holds the state for node ``i``.
    """

    def __init__(self, num_nodes, memory_dim, device):
        """Create a zero-initialised ``(num_nodes, memory_dim)`` memory on ``device``.

        The tensor uses torch's default floating dtype.
        """
        self.memory = torch.zeros((num_nodes, memory_dim), device=device)

    def get(self, node_ids):
        """Return the memory rows selected by ``node_ids``, detached from autograd.

        With a tensor/list index (advanced indexing) the result is a copy.
        With an int or slice index it is a view that shares storage with the
        memory, so later ``update`` calls will be visible through it.
        """
        return self.memory[node_ids].detach()

    def update(self, node_ids, values):
        """Overwrite the memory rows for ``node_ids`` with matching rows of ``values``.

        Args:
            node_ids: Tensor, array, or sequence of row indices, flattened to
                1-D. Float ids are truncated toward zero. Negative ids count
                from the end, as in Python indexing.
            values: Tensor whose first ``len(node_ids)`` rows are written.
                Extra rows are ignored. Rows are detached and cast to the
                memory's dtype and device before writing.

        Raises:
            IndexError: If any id lies outside ``[-num_nodes, num_nodes)``,
                or (on CPU) if ``values`` has fewer rows than ``node_ids``.

        If an id appears more than once, the row from its last occurrence
        wins, the same as assigning the pairs one by one in order.
        """
        device = self.memory.device
        ids = torch.as_tensor(node_ids, device=device).reshape(-1).long()
        if ids.numel() == 0:
            return  # nothing to write; `values` is never touched

        num_nodes = self.memory.size(0)
        # One explicit check (a single sync) instead of an opaque device-side
        # assert on CUDA; mirrors the IndexError the scalar loop raised.
        if ((ids < -num_nodes) | (ids >= num_nodes)).any():
            raise IndexError(
                f"node id out of range for memory with {num_nodes} rows"
            )
        # Map negative ids onto their non-negative aliases so duplicates such
        # as -1 and num_nodes - 1 are recognised as the same row below.
        ids = ids.remainder(num_nodes)

        vals = values.detach().to(device=device, dtype=self.memory.dtype)

        # A vectorised write with repeated indices is order-undefined. A stable
        # sort keeps the original order within each id group, so the last
        # element of every group is that id's final occurrence.
        sorted_ids, order = torch.sort(ids, stable=True)
        is_last = torch.ones_like(sorted_ids, dtype=torch.bool)
        is_last[:-1] = sorted_ids[1:] != sorted_ids[:-1]

        self.memory[sorted_ids[is_last]] = vals[order[is_last]]