from datasets import load_dataset
import numpy as np
import torch

from util import Config, GetDevice


class Dataset:
    """Text corpus fetched with ``load_dataset`` plus a fixed-shape batching helper.

    Every attribute of the supplied config is copied onto the instance. The
    methods below read ``remote_path``, ``seq_length`` and ``batch_size`` from
    ``self``, so the config must provide them.
    """

    def __init__(self, config: Config) -> None:
        """Copy the config's attributes, load the dataset and build ``self.text``.

        Args:
            config: Object whose instance attributes are copied onto this
                instance; ``remote_path`` is passed to ``load_dataset``.
        """
        # Shallow copy so later attribute assignments here do not write
        # through to the config object's own attribute dictionary.
        self.__dict__ = dict(config.__dict__)
        self.dataset = load_dataset(self.remote_path)

        # Concatenate every entry of the train split's 'text' column with no
        # separator, then drop any character that cannot be encoded as ASCII.
        joined = ''.join(self.dataset['train']['text'])
        self.text = joined.encode('ascii', 'ignore').decode('ascii')

    def batch(self, ids):
        """Split ``ids`` into tensors of shape ``(batch_size, seq_length)``.

        Only complete batches are produced: trailing ids that do not fill a
        whole ``batch_size * seq_length`` block are discarded.

        Args:
            ids: Sequence or ``numpy.ndarray`` of integer ids (expected to be
                one-dimensional; the result is reshaped to three dimensions).

        Returns:
            A tuple ``(batches, num_batches)`` where ``batches`` is a list of
            ``torch.long`` tensors moved to the device returned by
            ``GetDevice()`` and ``num_batches`` is the length of that list.

        Raises:
            ZeroDivisionError: If ``seq_length * batch_size`` is zero.
        """
        ids = np.asarray(ids)

        ids_per_batch = self.seq_length * self.batch_size
        num_batches = len(ids) // ids_per_batch

        # Keep only the prefix that divides evenly into complete batches.
        usable = ids[:num_batches * ids_per_batch]
        grouped = usable.reshape((num_batches, self.batch_size, self.seq_length))

        # The target device does not change between batches, so look it up once.
        device = GetDevice()
        batches = [
            torch.tensor(chunk, dtype=torch.long).to(device)
            for chunk in grouped
        ]
        return batches, num_batches