import random
import torch
class BufferArray:
    """An array of replay buffers indexed by position.

    Each buffer accumulates samples until it can serve a full batch of
    batch_size elements.
    """

    def __init__(self, array_size, batch_size):
        self.array_size = array_size
        self.batch_size = batch_size
        # Number of tensors per batch (e.g. 2 for (x, y)); set on the first append.
        self.batch_n_vectors = None

    def init_buffers(self):
        # One list of samples per tensor in the batch tuple, for each buffer slot.
        self.buffers = [
            [[] for _ in range(self.batch_n_vectors)] for _ in range(self.array_size)
        ]

    def get_batch(self):
        """Return the index of the first buffer holding at least batch_size samples,
        together with a batch drawn from it, or (0, None) if no buffer is full yet."""
        assert hasattr(self, "buffers"), "append() must be called before get_batch()"
        for idx, buffer in enumerate(self.buffers):
            if len(buffer[0]) >= self.batch_size:
                vectors = [[] for _ in range(self.batch_n_vectors)]
                for _ in range(self.batch_size):
                    # Draw without replacement: remove the same random position
                    # from every per-tensor list.
                    pop_idx = random.randrange(len(buffer[0]))
                    for v, b in zip(vectors, buffer):
                        v.append(b.pop(pop_idx))
                return idx, tuple(torch.stack(v, dim=0) for v in vectors)
        return 0, None

    def append(self, idx, batch: tuple):
        """Append a batch to the buffer at index idx.

        The batch is expected to be a tuple of tensors such as (x, y), each with
        the sample dimension first.
        """
        if idx >= self.array_size:
            return
        if self.batch_n_vectors is None:
            self.batch_n_vectors = len(batch)
            self.init_buffers()
        else:
            assert len(batch) == self.batch_n_vectors
        for i, element_vectors in enumerate(batch):
            # Split the incoming tensor into individual samples and store them.
            self.buffers[idx][i].extend(element_vectors)
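# Usage sketch (illustrative only, not part of the original module): the shapes
# and the slot index below are arbitrary assumptions.
#
#     buffer_array = BufferArray(array_size=4, batch_size=32)
#     x, y = torch.randn(8, 16), torch.randn(8, 1)   # mini-batch of 8 samples
#     buffer_array.append(2, (x, y))                 # accumulate into slot 2
#     idx, batch = buffer_array.get_batch()          # (0, None) until a slot holds 32 samples
#     if batch is not None:
#         x_batch, y_batch = batch                   # each stacked to batch_size rows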
class Buffer:
    """A single replay buffer that accumulates samples until it can serve a full
    batch of batch_size elements."""

    def __init__(self, batch_size):
        self.batch_size = batch_size
        # Number of tensors per batch (e.g. 2 for (x, y)); set on the first append.
        self.batch_n_vectors = None

    def init_buffer(self):
        # One list of samples per tensor in the batch tuple.
        self.buffer = [[] for _ in range(self.batch_n_vectors)]

    def get_batch(self):
        """Return a batch of batch_size samples, or None if the buffer is not full yet."""
        if not hasattr(self, "buffer"):
            return None
        if len(self.buffer[0]) >= self.batch_size:
            vectors = [[] for _ in range(self.batch_n_vectors)]
            for _ in range(self.batch_size):
                # Draw without replacement: remove the same random position
                # from every per-tensor list.
                pop_idx = random.randrange(len(self.buffer[0]))
                for v, b in zip(vectors, self.buffer):
                    v.append(b.pop(pop_idx))
            return tuple(torch.stack(v, dim=0) for v in vectors)
        return None

    def append(self, batch: tuple):
        """Append a batch to the buffer.

        The batch is expected to be a tuple of tensors such as (x, y), each with
        the sample dimension first.
        """
        if self.batch_n_vectors is None:
            self.batch_n_vectors = len(batch)
            self.init_buffer()
        else:
            assert len(batch) == self.batch_n_vectors
        for i, element_vectors in enumerate(batch):
            # Split the incoming tensor into individual samples and store them.
            self.buffer[i].extend(element_vectors)
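# Minimal smoke test (an illustrative sketch, not part of the original module);
# the tensor shapes below are arbitrary assumptions.
if __name__ == "__main__":
    buf = Buffer(batch_size=4)
    for _ in range(2):
        # Two mini-batches of 2 samples each fill the buffer to batch_size.
        buf.append((torch.randn(2, 3), torch.randn(2, 1)))
    batch = buf.get_batch()
    if batch is not None:
        x, y = batch
        print(x.shape, y.shape)  # torch.Size([4, 3]) torch.Size([4, 1])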