import torch from torch.utils.data import Dataset from datasets import load_dataset class IMDBDataset(Dataset): def __init__(self, split, tokenizer, max_length=256): print(f"Loading IMDB {split} dataset...") self.dataset = load_dataset("imdb")[split] print(f"IMDB {split} loaded.") self.encodings = tokenizer( self.dataset["text"], truncation=True, padding=True, max_length=max_length ) self.labels = self.dataset["label"] def __len__(self): return len(self.labels) def __getitem__(self, idx): return { "input_ids": torch.tensor(self.encodings["input_ids"][idx], dtype=torch.long), "attention_mask": torch.tensor(self.encodings["attention_mask"][idx], dtype=torch.long), "labels": torch.tensor(self.labels[idx], dtype=torch.long) }