slm-tiny-stories / dataloader.py
Eric Houzelle
Initial commit
c64cf6f
from torch.utils.data import Dataset
import torch
# class TinyLLMDataset(Dataset):
# def __init__(self, texts, block_size, encode):
# self.data = [torch.tensor(encode(t[:block_size]), dtype=torch.long) # Pour chaque ligne dans texts alors on encode avec le tokenizer la ligne si le texte est de plus de 10 caractères
# for t in texts if len(t) > 10]
# def __len__(self):
# return len(self.data)
# def __getitem__(self, idx):
# x = self.data[idx][:-1] # si idx = 0 on récupère la premiere phrase encodée sans le dernier paramètre
# y = self.data[idx][1:] # si idx = 0 on récupère la premiere phrase encodée sans le premier paramètre
# return x, y
class TinyLLMDataset(Dataset):
def __init__(self, texts, block_size, encode):
self.data = []
for t in texts:
if len(t) > 10:
words = t.split()
for i in range(0, len(words) - block_size):
segment = " ".join(words[i : i + block_size])
encoded = torch.tensor(encode(segment), dtype=torch.long)
self.data.append(encoded)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
x = self.data[idx][:-1]
y = self.data[idx][1:]
return x, y