# Data-loading utilities: MNIST, FashionMNIST, and IMDB loaders.
| import torch | |
| from torchvision import datasets, transforms | |
| from torch.utils.data import DataLoader | |
def get_mnist_loaders(batch_size=64):
    """Return (train_loader, test_loader) DataLoaders for MNIST.

    Args:
        batch_size: number of samples per batch for both loaders.

    Returns:
        Tuple of ``(train_loader, test_loader)``; the training loader
        shuffles, the test loader does not. Datasets are downloaded to
        ``./data`` on first use.
    """
    # 0.1307 / 0.3081 are the conventional MNIST per-channel mean / std.
    preprocess = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )

    def make_loader(is_train):
        # shuffle only the training split, matching standard practice
        split = datasets.MNIST(
            root='./data', train=is_train, download=True, transform=preprocess
        )
        return DataLoader(split, batch_size=batch_size, shuffle=is_train)

    return make_loader(True), make_loader(False)
def get_fashion_mnist_loaders(batch_size=64):
    """Return (train_loader, test_loader) DataLoaders for FashionMNIST.

    Args:
        batch_size: number of samples per batch for both loaders.

    Returns:
        Tuple of ``(train_loader, test_loader)``; the training loader
        shuffles, the test loader does not. Datasets are downloaded to
        ``./data`` on first use.
    """
    # Simple (0.5, 0.5) normalization maps pixel values into [-1, 1].
    preprocess = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )

    def make_loader(is_train):
        # shuffle only the training split
        split = datasets.FashionMNIST(
            root='./data', train=is_train, download=True, transform=preprocess
        )
        return DataLoader(split, batch_size=batch_size, shuffle=is_train)

    return make_loader(True), make_loader(False)
def get_imdb_loaders(batch_size=64, max_len=256, vocab_size=10000):
    """Return (train_loader, test_loader, vocab_size) for IMDB sentiment data.

    Builds a whitespace/basic-english token vocabulary from the training
    split, then wraps both splits in Datasets that yield fixed-length,
    right-padded token-id tensors.

    Args:
        batch_size: samples per batch for both loaders.
        max_len: sequences are truncated/padded to exactly this many tokens.
        vocab_size: NOTE(review) — currently ignored; the vocabulary is
            built over all training tokens. Wiring it through
            ``build_vocab_from_iterator(max_tokens=...)`` would cap the
            vocab but also change the returned ``len(vocab)``, so it is
            left unchanged here. TODO: confirm intended semantics.

    Returns:
        ``(train_loader, test_loader, len(vocab))`` — the last element is
        the actual vocabulary size (including ``<unk>``/``<pad>``).
    """
    # torchtext is imported lazily so the MNIST helpers work without it.
    from torchtext.datasets import IMDB
    from torchtext.data.utils import get_tokenizer
    from torchtext.vocab import build_vocab_from_iterator
    from torch.utils.data import Dataset  # DataLoader already imported at module level

    tokenizer = get_tokenizer("basic_english")

    def yield_tokens(data_iter):
        # Each record is (label, text); only the text feeds the vocab.
        for _, text in data_iter:
            yield tokenizer(text)

    # Single pass over the training split; index 0/1 reserved for specials.
    vocab = build_vocab_from_iterator(
        yield_tokens(IMDB(split='train')), specials=["<unk>", "<pad>"]
    )
    vocab.set_default_index(vocab["<unk>"])

    def text_pipeline(text):
        return vocab(tokenizer(text))

    class IMDBDataset(Dataset):
        """IMDB split materialized in memory as fixed-length id sequences."""

        def __init__(self, split):
            self.data = list(IMDB(split=split))
            self.max_len = max_len

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            label, text = self.data[idx]
            # Convert label: 1 (neg), 2 (pos) -> 0, 1
            # NOTE(review): assumes integer labels as in torchtext >= 0.12;
            # older releases yielded 'neg'/'pos' strings — confirm against
            # the installed torchtext version.
            label = 0 if label == 1 else 1
            tokens = text_pipeline(text)[:self.max_len]
            # Right-pad short sequences so every item is exactly max_len.
            if len(tokens) < self.max_len:
                tokens += [vocab["<pad>"]] * (self.max_len - len(tokens))
            return torch.tensor(tokens), torch.tensor(label)

    train_loader = DataLoader(IMDBDataset('train'), batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(IMDBDataset('test'), batch_size=batch_size, shuffle=False)
    return train_loader, test_loader, len(vocab)