import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


def get_mnist_loaders(batch_size=64):
    """Return ``(train_loader, test_loader)`` for MNIST, downloading to ./data if needed.

    Images are converted to tensors and normalized with the conventional
    MNIST statistics (mean 0.1307, std 0.3081). The training loader shuffles;
    the test loader does not.

    Args:
        batch_size: Batch size for both loaders.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_dataset = datasets.MNIST(root='./data', train=True,
                                   download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./data', train=False,
                                  download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader


def get_fashion_mnist_loaders(batch_size=64):
    """Return ``(train_loader, test_loader)`` for FashionMNIST, downloading if needed.

    Images are normalized with mean/std 0.5, mapping pixel values to
    roughly [-1, 1]. The training loader shuffles; the test loader does not.

    Args:
        batch_size: Batch size for both loaders.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    train_dataset = datasets.FashionMNIST(root='./data', train=True,
                                          download=True, transform=transform)
    test_dataset = datasets.FashionMNIST(root='./data', train=False,
                                         download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader


def get_imdb_loaders(batch_size=64, max_len=256, vocab_size=10000):
    """Return ``(train_loader, test_loader, vocab_len)`` for the IMDB sentiment dataset.

    Builds a vocabulary from the training split with a ``basic_english``
    tokenizer, then wraps both splits in datasets that tokenize, numericalize,
    truncate/pad each review to ``max_len`` tokens, and map labels to
    0 (negative) / 1 (positive).

    Args:
        batch_size: Batch size for both loaders.
        max_len: Fixed token length; longer reviews are truncated, shorter
            ones padded with the ``<pad>`` token.
        vocab_size: Cap on vocabulary size (most frequent tokens kept,
            including the specials). Previously this parameter was accepted
            but ignored; it is now passed as ``max_tokens``.

    Returns:
        (train_loader, test_loader, len(vocab)) — the vocab length is what a
        downstream embedding layer should use as ``num_embeddings``.
    """
    # torchtext is imported lazily so the MNIST helpers above work without it.
    from torchtext.datasets import IMDB
    from torchtext.data.utils import get_tokenizer
    from torchtext.vocab import build_vocab_from_iterator
    from torch.utils.data import Dataset

    tokenizer = get_tokenizer("basic_english")

    def yield_tokens(data_iter):
        # IMDB yields (label, text) pairs; only the text feeds the vocab.
        for _, text in data_iter:
            yield tokenizer(text)

    # BUG FIX: the specials list previously read ["", ""] — the "<unk>" and
    # "<pad>" markers had been stripped, which both raises on duplicate
    # specials and leaves no distinct unknown/padding tokens. Restored here.
    # A fresh IMDB(split='train') iterator is used since it is consumed once.
    vocab = build_vocab_from_iterator(
        yield_tokens(IMDB(split='train')),
        specials=["<unk>", "<pad>"],
        max_tokens=vocab_size,  # honor the vocab_size parameter
    )
    # Out-of-vocabulary tokens map to <unk>.
    vocab.set_default_index(vocab["<unk>"])
    pad_idx = vocab["<pad>"]

    def text_pipeline(text):
        return vocab(tokenizer(text))

    class IMDBDataset(Dataset):
        """Materialized IMDB split yielding (token_ids, label) tensors."""

        def __init__(self, split):
            # Materialize the iterable split so __getitem__ supports indexing.
            self.data = list(IMDB(split=split))
            self.max_len = max_len

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            label, text = self.data[idx]
            # Convert label: 1 (neg), 2 (pos) -> 0, 1
            label = 0 if label == 1 else 1
            tokens = text_pipeline(text)[:self.max_len]
            # Right-pad short reviews to the fixed length with <pad>.
            if len(tokens) < self.max_len:
                tokens += [pad_idx] * (self.max_len - len(tokens))
            return torch.tensor(tokens), torch.tensor(label)

    train_dataset = IMDBDataset('train')
    test_dataset = IMDBDataset('test')
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader, len(vocab)