mamba / dataset.py
flpelerin's picture
Update 3 files
6a41ff7
from datasets import load_dataset
import numpy as np
import torch
from util import Config, GetDevice
class Dataset:
    """Text corpus loaded from a remote dataset, plus a helper to batch ids.

    The constructor copies every instance attribute of the supplied config
    onto this object; the code in this class reads ``remote_path``,
    ``seq_length`` and ``batch_size`` from there.

    Attributes:
        dataset: Object returned by ``load_dataset(self.remote_path)``.
        text: Concatenation of every ``dataset['train']['text']`` entry with
            all non-ASCII characters dropped.
    """

    def __init__(self, config: Config):
        """Copy the config fields, load the dataset and build ``text``.

        Args:
            config: Settings object whose instance attributes are copied onto
                this object. It must provide ``remote_path``.
        """
        # Copy into a fresh dict so attributes assigned on this object are
        # not written through to the caller's config object.
        self.__dict__ = dict(config.__dict__)
        self.dataset = load_dataset(self.remote_path)

        # Round-trip through ASCII with errors='ignore' to strip every
        # character outside the ASCII range.
        joined = ''.join(self.dataset['train']['text'])
        self.text = joined.encode('ascii', 'ignore').decode('ascii')

    def batch(self, ids):
        """Split a sequence of ids into fixed-shape batches on the device.

        Trailing ids that do not fill a complete
        ``batch_size * seq_length`` block are discarded.

        Args:
            ids: Sequence of ids, expected to be one-dimensional. It is
                converted to a NumPy array when it is not one already.

        Returns:
            A ``(batches, num_batches)`` tuple, where ``batches`` is a list
            of ``num_batches`` ``torch.long`` tensors, each of shape
            ``(batch_size, seq_length)`` and moved to ``GetDevice()``.

        Raises:
            ZeroDivisionError: If ``seq_length * batch_size`` is zero.
        """
        # Arrays (and ndarray subclasses) pass through unchanged; anything
        # else is converted.
        ids = np.asanyarray(ids)

        block_size = self.seq_length * self.batch_size
        num_batches = len(ids) // block_size

        # Keep only the prefix that divides evenly into whole batches.
        usable = ids[:num_batches * block_size]
        grouped = usable.reshape(
            (num_batches, self.batch_size, self.seq_length)
        )

        # The target device is the same for every batch; look it up once.
        device = GetDevice()
        batches = [
            torch.tensor(chunk, dtype=torch.long).to(device)
            for chunk in grouped
        ]
        return batches, num_batches