import torch
from torch.utils.data import Dataset


class TextDataset(Dataset):
    """Map-style dataset of sliding windows for next-token prediction.

    Sample ``i`` is the pair ``(x, y)`` where ``x = ids[i : i + block_size]``
    and ``y`` is the same window shifted one position to the right, so
    ``y[k]`` is the token that follows ``x[k]``.

    Args:
        ids: Token ids. Must support ``len()`` and slicing, and its slices
            must be convertible by ``torch.tensor`` (e.g. a list of ints).
        block_size: Number of tokens in each window; must be at least 1.

    Raises:
        ValueError: If ``block_size`` is smaller than 1.
    """

    def __init__(self, ids, block_size: int) -> None:
        if block_size < 1:
            raise ValueError(f"block_size must be >= 1, got {block_size}")
        self.ids = ids
        self.block = block_size

    def __len__(self) -> int:
        """Return the number of complete ``(x, y)`` windows.

        A window starting at ``i`` needs ``block + 1`` tokens (the extra one
        is the last target), so a corpus with ``len(ids) <= block`` holds no
        complete window and the dataset is empty rather than exposing a
        truncated sample whose ``x`` and ``y`` differ in length.
        """
        return max(0, len(self.ids) - self.block)

    def __getitem__(self, i):
        """Return the ``i``-th ``(input, target)`` pair as ``torch.long`` tensors.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: If ``i`` is outside ``[-len(self), len(self))``.
        """
        n = len(self)
        if i < 0:
            # Rebind instead of ``+=`` so a tensor index is never mutated in place.
            i = i + n
        # Slicing never fails on its own, so the bounds must be checked
        # explicitly; raising IndexError is also what terminates plain
        # iteration over the dataset (``for x, y in dataset``).
        if not 0 <= i < n:
            raise IndexError(f"index out of range for dataset of length {n}")
        stop = i + self.block
        x = self.ids[i:stop]
        y = self.ids[i + 1:stop + 1]
        return (
            torch.tensor(x, dtype=torch.long),
            torch.tensor(y, dtype=torch.long),
        )