| import torch | |
| from torch.utils.data import Dataset | |
class TextDataset(Dataset):
    """Sliding-window next-token dataset over a sequence of token ids.

    Sample ``i`` is a pair ``(x, y)`` of ``torch.long`` tensors where
    ``x`` is ``ids[i : i + block_size]`` and ``y`` is the same window
    shifted one position to the right, i.e. ``y[k]`` is the token that
    follows ``x[k]`` in ``ids``.

    Args:
        ids: Indexable, sliceable sequence of integer token ids.
        block_size: Number of tokens per window; must be at least 1.

    Raises:
        ValueError: If ``block_size`` is smaller than 1.
    """

    def __init__(self, ids, block_size):
        if block_size < 1:
            raise ValueError(f"block_size must be >= 1, got {block_size}")
        self.ids = ids
        self.block = block_size

    def __len__(self):
        """Return the number of windows, never less than 1."""
        # A full (x, y) pair needs block + 1 tokens, so there are
        # len(ids) - block valid start offsets.  The floor of 1 is kept from
        # the original code -- presumably so a too-short corpus still yields
        # a single (truncated) sample instead of an empty dataset.
        return max(1, len(self.ids) - self.block)

    def __getitem__(self, i):
        """Return the ``(input, target)`` tensors for window ``i``.

        Negative indices count from the end of the dataset.

        Raises:
            IndexError: If ``i`` is outside ``[-len(self), len(self))``.
        """
        count = len(self)
        if i < 0:
            i += count
        if not 0 <= i < count:
            # Raising here is what lets plain iteration over the dataset
            # (``for sample in ds``) terminate.
            raise IndexError(f"index {i} out of range for {count} samples")

        targets = self.ids[i + 1:i + self.block + 1]
        # When ids is shorter than block + 1 the target window is truncated;
        # trim the input to the same length so x[k] always pairs with y[k].
        inputs = self.ids[i:i + len(targets)]
        return (
            torch.tensor(inputs, dtype=torch.long),
            torch.tensor(targets, dtype=torch.long),
        )