| import pickle |
| from copy import deepcopy |
|
|
| import numpy as np |
|
|
|
|
class IndexedDataset:
    """Random-access reader for a dataset of pickled items stored in two files.

    ``{path}.idx`` is a NumPy-saved dict whose ``'offsets'`` entry lists the
    byte offset at which each item starts inside ``{path}.data``, followed by
    one trailing offset marking the end of the last item.  Items are read and
    unpickled lazily on access, and the most recently loaded ones are kept in
    a small in-memory cache.

    NOTE(security): both the index (``allow_pickle=True``) and the items
    (``pickle.loads``) are unpickled, which can execute arbitrary code.  Only
    open dataset files that come from a trusted source.
    """

    def __init__(self, path, num_cache=1):
        """Open the dataset stored at ``{path}.idx`` / ``{path}.data``.

        Args:
            path: Common file prefix (without extension) of the two files.
            num_cache: Maximum number of recently loaded items to keep in
                memory.  ``0`` (or a negative value) disables caching.
        """
        super().__init__()
        self.path = path
        # Assigned before any I/O so that close()/__del__ stay safe even if
        # loading the index below raises.
        self.data_file = None
        self.data_offsets = np.load(f"{path}.idx", allow_pickle=True).item()['offsets']
        self.data_file = open(f"{path}.data", 'rb', buffering=-1)
        # List of (index, item) pairs, most recently loaded first.
        self.cache = []
        self.num_cache = num_cache

    def check_index(self, i):
        """Raise ``IndexError`` unless ``0 <= i < len(self)``.

        Negative indices are rejected rather than wrapped around.
        """
        if i < 0 or i >= len(self.data_offsets) - 1:
            raise IndexError('index out of range')

    def close(self):
        """Close the underlying data file; safe to call more than once."""
        # getattr guards against a partially constructed instance.
        data_file = getattr(self, 'data_file', None)
        if data_file is not None:
            data_file.close()
            self.data_file = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        self.close()

    def __getitem__(self, i):
        """Return item ``i``, reading it from disk unless it is cached.

        On a cache hit the cached object itself is returned.  On a miss the
        freshly unpickled item is returned and a deep copy of it is stored in
        the cache, so later mutation of the returned object does not alter
        the cached copy.

        Raises:
            IndexError: If ``i`` is outside ``[0, len(self))``.
        """
        self.check_index(i)
        if self.num_cache > 0:
            for cached_index, cached_item in self.cache:
                if cached_index == i:
                    return cached_item
        start = self.data_offsets[i]
        end = self.data_offsets[i + 1]
        self.data_file.seek(start)
        # See the class-level security note: the payload must be trusted.
        item = pickle.loads(self.data_file.read(end - start))
        if self.num_cache > 0:
            # Newest entry first; keep at most ``num_cache`` entries.  (The
            # previous ``self.cache[:-1]`` slice dropped an entry on every
            # insert, so the cache could never hold more than one item.)
            self.cache = [(i, deepcopy(item))] + self.cache[:self.num_cache - 1]
        return item

    def __len__(self):
        """Return the number of items (one fewer than the number of offsets)."""
        return len(self.data_offsets) - 1
|
|
class IndexedDatasetBuilder:
    """Writer producing the ``{path}.data`` / ``{path}.idx`` file pair.

    Each added item is pickled and appended to ``{path}.data``.  The running
    list of byte offsets (starting at 0, one extra entry per item) is written
    to ``{path}.idx`` by :meth:`finalize`.
    """

    def __init__(self, path):
        """Create (or truncate) ``{path}.data`` and start an empty index.

        Args:
            path: Common file prefix (without extension) for the output files.
        """
        self.path = path
        self.out_file = open(f"{path}.data", 'wb')
        # byte_offsets[k] is where item k starts; the last entry is the
        # current end of the data file.
        self.byte_offsets = [0]

    def add_item(self, item):
        """Pickle ``item`` and append it to the data file.

        Args:
            item: Any picklable object.
        """
        payload = pickle.dumps(item)
        self.out_file.write(payload)
        self.byte_offsets.append(self.byte_offsets[-1] + len(payload))

    def finalize(self):
        """Close the data file and write the offsets index to ``{path}.idx``."""
        self.out_file.close()
        # An already-open file object is passed so NumPy does not append a
        # ".npy" suffix; the ``with`` block guarantees the index is flushed
        # and closed before this method returns.
        with open(f"{self.path}.idx", 'wb') as idx_file:
            np.save(idx_file, {'offsets': self.byte_offsets})
|
|
|
|
if __name__ == "__main__":
    # Round-trip demo: write random arrays through the builder, then read
    # random entries back and check they match what was written.
    import random
    from tqdm import tqdm

    ds_path = '/tmp/indexed_ds_example'
    size = 100

    # Build the in-memory reference items ("a" is drawn before "b").
    items = []
    for _ in range(size):
        first = np.random.normal(size=[10000, 10])
        second = np.random.normal(size=[10000, 10])
        items.append({"a": first, "b": second})

    # Serialise every item to disk, then write the index.
    builder = IndexedDatasetBuilder(ds_path)
    for position in tqdm(range(size)):
        builder.add_item(items[position])
    builder.finalize()

    # Random reads must reproduce the original "a" arrays exactly.
    ds = IndexedDataset(ds_path)
    for _ in tqdm(range(10000)):
        idx = random.randint(0, size - 1)
        loaded = ds[idx]['a']
        assert (loaded == items[idx]['a']).all()
|
|