---
license: apache-2.0
tags:
- pytorch
- lightning
- cifar10
---

# CIFAR10 Classifier trained with PyTorch Lightning

## Introduction

A ResNet18 model for CIFAR10 image classification that reaches 94% accuracy. Key features, with the accuracy improvement attributed to each:

1. Data normalization and randomization (10% improvement)
2. Dropout before the FC classifier (1% improvement)
3. Batch normalization in each ResNet block (2% improvement)
4. Cosine learning rate schedule (1% improvement)
5. ResNet18 backbone, which is deeper than a simple CNN

## Usage

### Approach 1: predict with plain PyTorch

```python
import torch

from model import CIFARCNN  # the LightningModule defined in model.py

# Load the trained weights and switch to inference mode
# (disables dropout and uses the running batch-norm statistics).
model = CIFARCNN.load_from_checkpoint("model.ckpt")
model.eval()

# Dummy batch of 4 RGB 32x32 images; replace with real CIFAR10 images
# preprocessed the same way as the training data.
x = torch.randn(4, 3, 32, 32).to(model.device)

with torch.no_grad():
    logits = model(x)  # requires CIFARCNN to implement forward()

print(logits.shape)  # torch.Size([4, 10]): one score per CIFAR10 class
```

### Approach 2: predict with Lightning

```py
from lightning import Trainer
from torch.utils.data import DataLoader

from model import CIFARCNN

test_dataloader = DataLoader(...)

model = CIFARCNN.load_from_checkpoint("model.ckpt")

# Trainer() selects an available accelerator and moves the model to it.
trainer = Trainer()
trainer.test(model, dataloaders=test_dataloader)  # requires CIFARCNN to implement test_step()
```

### Visualize results

```py
import matplotlib.pyplot as plt
import torch

cifar10_labels = {
    0: "airplane",
    1: "automobile",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",
    9: "truck",
}

# One batch (batch size >= 10) from the test loader defined above.
samples, targets = next(iter(test_dataloader))

# Run the model directly on the batch and keep the highest-scoring class.
model.eval()
with torch.no_grad():
    logits = model(samples.to(model.device))
preds = logits.argmax(dim=1).cpu()

fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for i, ax in enumerate(axes.flatten()):
    # CHW -> HWC for imshow. If the images were normalized, undo that first so
    # colors render correctly; clamp keeps values in imshow's [0, 1] range.
    ax.imshow(samples[i].permute(1, 2, 0).clamp(0, 1))
    ax.set_title(f"{cifar10_labels[preds[i].item()]}\n(true: {cifar10_labels[targets[i].item()]})")
    ax.axis("off")
plt.tight_layout()
plt.show()
```
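## Cosine learning rate schedule

Feature 4 in the Introduction credits a cosine learning rate schedule with about 1% of the accuracy gain. The training code is not part of this card, so the following is only a minimal sketch of the standard cosine annealing formula, not this model's actual configuration. The function name `cosine_lr` and the values `base_lr=0.1` and `total_steps=100` are hypothetical choices for illustration.

```python
import math


def cosine_lr(step: int, total_steps: int, base_lr: float, min_lr: float = 0.0) -> float:
    """Cosine-annealed learning rate at `step` (hypothetical helper, for illustration)."""
    # Fraction of the schedule completed, clamped to [0, 1].
    progress = min(max(step / total_steps, 0.0), 1.0)
    # Cosine factor decays smoothly from 1 at the start to 0 at the end.
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (base_lr - min_lr) * cosine


# Hypothetical values: print the learning rate at a few points of a 100-step run.
for step in (0, 25, 50, 75, 100):
    print(step, round(cosine_lr(step, total_steps=100, base_lr=0.1), 4))
```

The rate starts at `base_lr`, falls slowly at first, fastest around the midpoint, and flattens out near `min_lr`. In a LightningModule the same curve is usually obtained by returning `torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=...)` from `configure_optimizers` rather than computing it by hand.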