bert_imdb_sentiment / data /imdb_data.py
ikkbor's picture
Upload folder using huggingface_hub
fba3401 verified
import torch
from torch.utils.data import Dataset
from datasets import load_dataset
class IMDBDataset(Dataset):
    """Map-style PyTorch dataset over one split of the IMDB reviews corpus.

    The requested split is loaded through ``datasets.load_dataset`` and every
    review is tokenized once, eagerly, in the constructor. Items are returned
    as dictionaries of ``torch.long`` tensors keyed by ``"input_ids"``,
    ``"attention_mask"`` and ``"labels"``.
    """

    def __init__(self, split: str, tokenizer, max_length: int = 256):
        """Load and tokenize the requested IMDB split.

        Args:
            split: Split expression forwarded to ``load_dataset`` (for example
                ``"train"`` or ``"test"``).
            tokenizer: Callable that accepts a list of strings together with
                ``truncation``, ``padding`` and ``max_length`` keyword
                arguments and returns a mapping exposing ``"input_ids"`` and
                ``"attention_mask"`` sequences.
            max_length: Maximum sequence length passed to the tokenizer;
                longer reviews are truncated.
        """
        print(f"Loading IMDB {split} dataset...")
        # Request only the needed split instead of building the full
        # DatasetDict and indexing into it afterwards.
        self.dataset = load_dataset("imdb", split=split)
        print(f"IMDB {split} loaded.")

        # Materialise the column as a plain list: column access may return a
        # lazy column object rather than a list, while tokenizers validate
        # batched input with list/tuple checks.
        texts = list(self.dataset["text"])
        self.encodings = tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=max_length,
        )
        # A plain list keeps per-item label lookup in __getitem__ simple.
        self.labels = list(self.dataset["label"])

    def __len__(self) -> int:
        """Return the number of labelled examples."""
        return len(self.labels)

    def __getitem__(self, idx: int) -> dict:
        """Return the encoded example at position ``idx``.

        Returns:
            A dict with ``"input_ids"``, ``"attention_mask"`` and ``"labels"``
            entries, each a ``torch.long`` tensor.
        """
        return {
            "input_ids": torch.tensor(
                self.encodings["input_ids"][idx], dtype=torch.long
            ),
            "attention_mask": torch.tensor(
                self.encodings["attention_mask"][idx], dtype=torch.long
            ),
            "labels": torch.tensor(self.labels[idx], dtype=torch.long),
        }