Physics-Tutor-Model / train /data_utils.py
adityashisharma's picture
Upload 6 files
04e4b39 verified
import torch
from torch.utils.data import Dataset
class TextDataset(Dataset):
    """Sliding-window dataset over a flat sequence of token ids.

    Sample ``i`` is the pair ``(x, y)`` where ``x`` is ``ids[i:i + block_size]``
    and ``y`` is the same window shifted one position to the right, i.e. the
    token that follows each position of ``x``.

    Args:
        ids: Indexable, sliceable sequence of integer token ids (anything
            that supports ``len()`` and slicing and that ``torch.tensor``
            can convert).
        block_size: Number of tokens in each input window.
    """

    def __init__(self, ids, block_size: int) -> None:
        self.ids = ids
        self.block = block_size

    def __len__(self) -> int:
        """Return the number of samples.

        There are ``len(ids) - block_size`` full windows that still have a
        next token for every position. The result is floored at 1 so that a
        corpus shorter than one block still yields a single (truncated)
        sample instead of an empty dataset.
        """
        return max(1, len(self.ids) - self.block)

    def __getitem__(self, i: int):
        """Return the ``(input, target)`` pair of ``torch.long`` tensors at ``i``.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: If ``i`` is outside ``[-len(self), len(self))``.
        """
        n_samples = len(self)
        if i < 0:
            i += n_samples
        # Raising IndexError (rather than returning empty slices) is what lets
        # plain ``for x, y in dataset`` iteration terminate: Dataset defines
        # no __iter__, so Python falls back to calling __getitem__ with
        # 0, 1, 2, ... until IndexError is raised.
        if not 0 <= i < n_samples:
            raise IndexError(
                f"index {i} out of range for dataset of length {n_samples}"
            )

        # Clamp the window end so every input position has a target token.
        # For a corpus longer than block_size this is always ``i + block``;
        # it only shortens the window for the single best-effort sample of a
        # too-short corpus, keeping x and y the same length.
        end = max(i, min(i + self.block, len(self.ids) - 1))
        x = self.ids[i:end]
        y = self.ids[i + 1:end + 1]
        return (
            torch.tensor(x, dtype=torch.long),
            torch.tensor(y, dtype=torch.long),
        )