import torch


class EssayDataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes one row of a dataframe per item.

    Each item is built from the ``'train_input'`` and ``'labels'`` columns
    of the wrapped dataframe.

    Args:
        dataframe: Object supporting ``len()`` and positional row access via
            ``.iloc[idx]``, with ``'train_input'`` and ``'labels'`` columns
            (presumably a ``pandas.DataFrame`` -- confirm against callers).
        tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
            padding='max_length', truncation=True, return_tensors='pt')``
            whose result is indexable by ``'input_ids'`` and
            ``'attention_mask'``. NOTE(review): this looks like a Hugging
            Face tokenizer -- confirm.
        max_length: Value forwarded to the tokenizer as ``max_length``.
    """

    def __init__(self, dataframe, tokenizer, max_length):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows in the wrapped dataframe."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize row ``idx`` and return the tensors for that sample.

        Args:
            idx: Positional row index passed to ``self.data.iloc``.

        Returns:
            A dict with three entries:

            * ``'input_ids'``: the tokenizer's ``input_ids``, flattened to 1-D.
            * ``'attention_mask'``: the tokenizer's ``attention_mask``,
              flattened to 1-D.
            * ``'labels'``: the row's ``'labels'`` value converted to a
              ``torch.float`` tensor.
        """
        # Fetch the row once; each ``.iloc[idx]`` call builds a new row object.
        row = self.data.iloc[idx]
        text = row['train_input']
        labels = row['labels']

        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # With return_tensors='pt' the tokenizer output presumably carries a
        # leading batch dimension of 1 (TODO confirm); flatten() removes it so
        # that a DataLoader can stack samples into a batch.
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(labels, dtype=torch.float),
        }