| | |
| | |
| |
|
| | import numpy as np |
| | from torch.utils.data import Dataset |
| |
|
| |
|
class RandomDataset(Dataset):
    """
    Deterministic synthetic dataset of integer token sequences.

    Every sample has the layout ``[n, r, 0, 1, ..., n - 1, 0, 1, ...]``:
    the two drawn values ``n`` and ``r`` followed by ``range(n)`` repeated
    ``r`` times, cut to at most ``max_len`` entries. ``n`` is drawn from
    ``[1, 30)`` and ``r`` from ``[10, 200)`` with a generator seeded with
    1999, so the same arguments always produce the same samples.

    Args:
        num_samples (int): The number of samples to generate.
        max_len (int): The maximum length of each sample. Must not be
            negative; ``None`` disables truncation.

    Raises:
        ValueError: If ``max_len`` is negative.
    """

    def __init__(self, num_samples=10000, max_len=1024) -> None:
        super().__init__()
        # A negative bound would make ``d[:max_len]`` trim from the tail
        # instead of capping the length, so reject it explicitly.
        if max_len is not None and max_len < 0:
            raise ValueError(f"max_len must be non-negative, got {max_len}")

        # Fixed seed: the dataset is reproducible across runs.
        rng = np.random.RandomState(1999)
        max_num = rng.randint(1, 30, size=(num_samples,))
        rep_num = rng.randint(10, 200, size=(num_samples,))

        data = []
        for n, r in zip(max_num, rep_num):
            # Header (n, r) first, then the repeated ramp 0..n-1.
            sample = [n, r] + list(range(n)) * r
            data.append(sample[:max_len])

        self.data = data
        self.max_len = max_len
        # Per-sample lengths after truncation.
        self.lengths = np.array([len(sample) for sample in data], dtype=int)

    def __getitem__(self, index):
        """Return sample ``index`` as ``{"tokens": [...], "type_id": 0}``."""
        # Round-trip through an integer array so every token has the same
        # numpy integer type, then hand back a plain list.
        input_ids = np.array(self.data[index], dtype=int)
        return {"tokens": list(input_ids), "type_id": 0}

    def get_dataset_name(self):
        """Return the fixed placeholder path used as this dataset's name."""
        return "dummy_path/dummy_lang/dummy_ds/train.bin"

    def __len__(self):
        """Return the number of generated samples."""
        return len(self.data)
| |
|