File size: 440 Bytes
04e4b39
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch
from torch.utils.data import Dataset

class TextDataset(Dataset):
    """Sliding-window dataset of (input, target) pairs over a flat id sequence.

    Item ``i`` is ``(ids[i:i+block_size], ids[i+1:i+block_size+1])`` as
    ``torch.long`` tensors, i.e. the target window is the input window
    shifted one position to the right.

    Args:
        ids: Flat sequence of integer ids supporting ``len()`` and slicing
            (presumably token ids in a list or 1-D array -- confirm with
            callers).
        block_size: Window length; must be at least 1.

    Raises:
        ValueError: If ``block_size`` is less than 1.
    """

    def __init__(self, ids, block_size):
        if block_size < 1:
            raise ValueError(f"block_size must be >= 1, got {block_size}")
        self.ids = ids
        self.block = block_size

    def __len__(self):
        """Return the number of full windows, but never less than 1.

        ``len(ids) - block_size`` full windows fit in ``ids``. Shorter inputs
        are still exposed as a single (truncated) item -- presumably so a
        DataLoader/sampler never sees an empty dataset.
        """
        return max(1, len(self.ids) - self.block)

    def __getitem__(self, i):
        """Return the ``(x, y)`` window pair at index ``i``.

        Negative indices count from the end. When ``ids`` holds
        ``block_size`` or fewer elements, the single available item is
        truncated to ``len(ids) - 1`` elements so ``x`` and ``y`` always
        have equal length.

        Raises:
            IndexError: If ``i`` is outside ``[-len(self), len(self))``.
                Raising it also lets plain ``for`` iteration over the
                dataset (which relies on ``__getitem__`` raising
                ``IndexError``) terminate.
        """
        n = len(self)
        j = i + n if i < 0 else i
        if not 0 <= j < n:
            raise IndexError(f"index {i} out of range for dataset of length {n}")
        # Clamp the window so x and y stay the same length even when ids is
        # too short for a full block (then n == 1, so only j == 0 gets here).
        span = min(self.block, max(len(self.ids) - 1, 0))
        x = self.ids[j:j + span]
        y = self.ids[j + 1:j + span + 1]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)