File size: 1,098 Bytes
fd0332a 7f3b4e4 fd0332a 52a257c fd0332a 52a257c 1907275 52a257c 6a41ff7 52a257c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
from datasets import load_dataset
import numpy as np
import torch
from util import Config, GetDevice
class Dataset:
    """Text corpus loaded via ``datasets.load_dataset`` plus a batching helper.

    Every instance attribute of the supplied config is copied onto this
    object, so the settings read here (``remote_path``, ``seq_length``,
    ``batch_size``) are expected to be present on the config.
    """

    def __init__(self, config: "Config"):
        """Copy the config's attributes, load the dataset and build ``self.text``.

        Args:
            config: Settings object. Its instance attributes are copied onto
                this object; ``remote_path`` is passed to ``load_dataset``.
        """
        # Shallow copy of the attribute dict: rebinding an attribute on this
        # object afterwards does not write through to ``config``.
        self.__dict__ = dict(config.__dict__)
        self.dataset = load_dataset(self.remote_path)
        # Concatenate the 'text' column of the 'train' split with no
        # separator, then drop every character that is not ASCII.
        self.text = (
            ''.join(self.dataset['train']['text'])
            .encode('ascii', 'ignore')
            .decode('ascii')
        )

    def batch(self, ids, device=None):
        """Split a sequence of ids into equally shaped ``torch.long`` batches.

        Trailing ids that do not fill a complete batch are discarded.

        Args:
            ids: Sequence of integer ids, expected to be one-dimensional.
                Anything that is not already a ``numpy.ndarray`` is converted
                with ``np.array``.
            device: Device the tensors are moved to. When ``None`` (the
                default) the result of ``GetDevice()`` is used.

        Returns:
            A ``(batches, num_batches)`` tuple where ``batches`` is a list of
            ``num_batches`` tensors, each of shape
            ``(self.batch_size, self.seq_length)``.

        Raises:
            ZeroDivisionError: If ``self.seq_length * self.batch_size`` is 0.
        """
        if not isinstance(ids, np.ndarray):
            ids = np.array(ids)
        if device is None:
            # Resolve the device once instead of once per batch.
            device = GetDevice()

        ids_per_batch = self.seq_length * self.batch_size
        num_batches = len(ids) // ids_per_batch
        # Keep only the prefix that divides evenly into whole batches.
        usable = ids[:num_batches * ids_per_batch]
        grouped = usable.reshape(
            (num_batches, self.batch_size, self.seq_length)
        )
        batches = [
            torch.tensor(chunk, dtype=torch.long).to(device)
            for chunk in grouped
        ]
        return batches, num_batches
|