|
|
from datasets import load_dataset |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
from util import Config, GetDevice |
|
|
|
|
|
|
|
|
class Dataset:
    """Text corpus loaded through ``load_dataset`` plus an id-batching helper.

    Every instance attribute of the supplied config is copied onto this
    object; the code below reads ``remote_path``, ``seq_length`` and
    ``batch_size`` from there.
    """

    def __init__(self, config: Config):
        """Copy the config attributes, load the dataset and build ``self.text``.

        Args:
            config: Settings object whose instance attributes are mirrored
                onto this object. It is expected to provide ``remote_path``
                (passed to ``load_dataset``) as well as the ``seq_length``
                and ``batch_size`` values used by :meth:`batch`.
        """
        # Mirror the config so its values can be read as ``self.<name>``.
        # ``update`` copies the entries without replacing the instance dict.
        self.__dict__.update(vars(config))

        self.dataset = load_dataset(self.remote_path)

        # Concatenate every entry of the train split's "text" column, then
        # drop all non-ASCII characters with an encode/decode round trip.
        joined = ''.join(self.dataset['train']['text'])
        self.text = joined.encode('ascii', 'ignore').decode('ascii')

    def batch(self, ids):
        """Split ``ids`` into tensors of shape ``(batch_size, seq_length)``.

        Trailing ids that do not fill a complete batch are dropped.

        Args:
            ids: Sequence (or ``np.ndarray``) of integer ids. Anything that
                is not already an ``np.ndarray`` is converted with
                ``np.array``.

        Returns:
            A tuple ``(batches, num_batches)`` where ``batches`` is a list of
            ``torch.long`` tensors, each moved to the device returned by
            ``GetDevice()``, and ``num_batches`` is the length of that list.

        Raises:
            ValueError: If ``seq_length`` or ``batch_size`` is not positive.
        """
        if not isinstance(ids, np.ndarray):
            ids = np.array(ids)

        # Fail early with a clear message instead of a ZeroDivisionError
        # (zero size) or an obscure reshape error (negative size).
        if self.seq_length <= 0 or self.batch_size <= 0:
            raise ValueError(
                "seq_length and batch_size must be positive, got "
                f"seq_length={self.seq_length}, batch_size={self.batch_size}"
            )

        ids_per_batch = self.seq_length * self.batch_size
        num_batches = len(ids) // ids_per_batch

        # Not enough ids for a single full batch: nothing to build, and no
        # need to query the device.
        if num_batches == 0:
            return [], 0

        # Keep only the ids that fill complete batches, then view them as
        # (num_batches, batch_size, seq_length).
        usable = ids[:num_batches * ids_per_batch]
        grouped = usable.reshape(
            (num_batches, self.batch_size, self.seq_length)
        )

        # The target device does not change between batches; look it up once.
        device = GetDevice()
        batches = [
            torch.tensor(chunk, dtype=torch.long).to(device)
            for chunk in grouped
        ]

        return batches, num_batches
|
|
|