File size: 440 Bytes
04e4b39 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
import torch
from torch.utils.data import Dataset
class TextDataset(Dataset):
    """Sliding-window dataset yielding (input, target) pairs shifted by one.

    Item ``i`` is the pair ``(ids[i:i + block_size], ids[i + 1:i + 1 + block_size])``
    converted to ``torch.long`` tensors, i.e. the target is the input moved
    one position to the right (presumably for next-token prediction).

    Args:
        ids: Sequence of integer ids supporting ``len()`` and slicing.
        block_size: Number of positions in each window.
    """

    def __init__(self, ids, block_size: int):
        self.ids = ids
        self.block = block_size

    def __len__(self) -> int:
        """Return the number of windows; never less than 1.

        The lower bound of 1 means a sequence that is not longer than
        ``block_size`` still exposes a single (truncated) item instead of
        producing an empty dataset.
        """
        return max(1, len(self.ids) - self.block)

    def __getitem__(self, i):
        """Return the ``(x, y)`` tensor pair for window ``i``.

        Negative indices count from the end. ``x`` and ``y`` always have the
        same length; that length is ``block_size`` except when the sequence
        is too short to fill a whole window.

        Raises:
            IndexError: If ``i`` is outside ``[-len(self), len(self))``.
        """
        count = len(self)
        pos = i + count if i < 0 else i
        # Slices never fail on out-of-range bounds, so the bounds check must
        # be explicit: plain iteration over the dataset relies on IndexError
        # to know where to stop.
        if not 0 <= pos < count:
            raise IndexError(
                f"index {i} out of range for dataset of length {count}"
            )
        targets = self.ids[pos + 1:pos + 1 + self.block]
        # Size the input from the target so both stay aligned even when the
        # sequence is shorter than a full window.
        inputs = self.ids[pos:pos + len(targets)]
        return (
            torch.tensor(inputs, dtype=torch.long),
            torch.tensor(targets, dtype=torch.long),
        )
|