File size: 931 Bytes
fba3401
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
from torch.utils.data import Dataset
from datasets import load_dataset


class IMDBDataset(Dataset):
    """Map-style dataset serving pre-tokenized IMDB reviews with their labels.

    The requested split is loaded and tokenized in full inside ``__init__``;
    ``__getitem__`` then only converts the stored rows to tensors.

    Args:
        split: Split passed to ``load_dataset`` (e.g. ``"train"`` or
            ``"test"``). Because it is forwarded as the ``split=`` argument,
            slice expressions such as ``"train[:10%]"`` are accepted as well.
        tokenizer: Callable invoked as
            ``tokenizer(texts, truncation=True, padding=True, max_length=...)``
            whose result supports ``["input_ids"]`` and ``["attention_mask"]``
            lookups yielding one integer sequence per text.
            NOTE(review): presumably a Hugging Face tokenizer -- confirm.
        max_length: Forwarded to the tokenizer as ``max_length``.
    """

    def __init__(self, split: str, tokenizer, max_length: int = 256) -> None:
        print(f"Loading IMDB {split} dataset...")
        # Ask for the split directly rather than building the whole
        # DatasetDict and indexing into it.
        self.dataset = load_dataset("imdb", split=split)
        print(f"IMDB {split} loaded.")

        # Materialise columns as plain lists: depending on the installed
        # `datasets` version, column access may return a lazy column object
        # rather than a list, which tokenizers may not accept as batch input.
        texts = list(self.dataset["text"])
        self.encodings = tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=max_length,
        )
        self.labels = list(self.dataset["label"])

    def __len__(self) -> int:
        """Return the number of examples (one label per example)."""
        return len(self.labels)

    def __getitem__(self, idx: int) -> dict:
        """Return example ``idx`` as a dict of ``torch.long`` tensors.

        Keys are ``"input_ids"``, ``"attention_mask"`` and ``"labels"``.
        """
        return {
            "input_ids": torch.tensor(self.encodings["input_ids"][idx], dtype=torch.long),
            "attention_mask": torch.tensor(self.encodings["attention_mask"][idx], dtype=torch.long),
            "labels": torch.tensor(self.labels[idx], dtype=torch.long),
        }