File size: 1,771 Bytes
ab0f6ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Copyright (c) Microsoft, Inc. 2020
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Author: penhe@microsoft.com
# Date: 05/15/2019
#

import pdb
from torch.utils.data import Dataset
import random
import mmap
import numpy as np
from bisect import bisect
from ..utils import get_logger
# Module-level logger, obtained from the package's `utils.get_logger` helper.
logger=get_logger()

# Names exported by `from <module> import *`; only the dataset class is public.
__all__ = ['DynamicDataset']

class DynamicDataset(Dataset):
  """Map-style dataset that builds features on the fly from a corpus.

  Every lookup maps the requested index through an (optionally shuffled)
  index table onto a corpus example, fetches that example and hands it to
  `feature_fn`. A `random.Random` seeded with the (offset) index is passed
  to both the corpus and `feature_fn`, so the same index always sees the
  same random stream.

  Attributes set by the constructor: `corpus`, `ds_len`, `feature_fn`,
  `dataset_size`, `shuffle`, `shuffle_idx`, `index_offset`.
  """

  def __init__(self, corpus, feature_fn, dataset_size=None, shuffle=False, **kwargs):
    """Build the index table for the dataset.

    Args:
      corpus: Object supporting `len()` and indexing with a
        `(example_idx, rng, ext_params)` tuple.
      feature_fn: Callable invoked as
        `feature_fn(example, rng, ext_params=ext_params)`.
      dataset_size: Reported length of the dataset. Any falsy value
        (`None`, `0`) falls back to `len(corpus)`; other values are passed
        through `int()`. Sizes larger than the corpus wrap around it.
      shuffle: When true, the index table is permuted with a generator
        seeded with the constant 0, so the order is the same on every run.
      **kwargs: Only `index_offset` is read (default 0); it is added to
        every index passed to `__getitem__`.

    NOTE(review): an effective size of 0 (empty corpus and no explicit
    `dataset_size`) is not handled here - the anonymous mmap presumably
    rejects a zero length; confirm whether that case needs support.
    """
    self.corpus = corpus
    self.ds_len = len(self.corpus)
    logger.info(f'Total corpus examples: {self.ds_len}')
    self.feature_fn = feature_fn

    self.dataset_size = int(dataset_size) if dataset_size else self.ds_len

    self.shuffle = shuffle
    # The index table lives in an anonymous memory map (fileno -1). Size the
    # map from the dtype actually used instead of assuming 8 bytes per entry.
    # NOTE(review): presumably mmap is used so the table is shared with
    # forked worker processes rather than copied - confirm.
    idx_dtype = np.dtype(int)
    index_buf = mmap.mmap(-1, self.dataset_size * idx_dtype.itemsize)
    shuffle_idx = np.ndarray(shape=(self.dataset_size, ), buffer=index_buf, dtype=idx_dtype)
    shuffle_idx[:] = np.arange(self.dataset_size)
    if self.shuffle:
      # Fixed seed keeps the permutation reproducible across runs.
      rng = random.Random(0)
      rng.shuffle(shuffle_idx)
    self.shuffle_idx = shuffle_idx
    self.index_offset = kwargs.get('index_offset', 0)

  def __len__(self):
    """Return the configured dataset size (not necessarily the corpus size)."""
    return self.dataset_size

  def __getitem__(self, idx):
    """Return the features for one example.

    Args:
      idx: Either an integer index, or a 2-item tuple/list
        `(index, ext_params)` whose second item is forwarded to the corpus
        and to `feature_fn`.

    Returns:
      Whatever `feature_fn` returns for the selected example.
    """
    if isinstance(idx, (tuple, list)):
      idx, ext_params = idx
    else:
      ext_params = None
    idx += self.index_offset
    # Seed with a plain Python int: NumPy integer indices are not `int`
    # instances and random.Random rejects them on Python >= 3.11.
    rng = random.Random(int(idx))
    # Wrap the index into the table, then wrap the table entry into the
    # corpus so dataset_size may exceed the number of corpus examples.
    example_idx = self.shuffle_idx[idx % self.dataset_size] % self.ds_len
    example = self.corpus[example_idx, rng, ext_params]
    return self.feature_fn(example, rng, ext_params = ext_params)