---
license: apache-2.0
tags:
- pytorch
- lightning
- cifar10
---
# CIFAR10 Classifier trained with PyTorch Lightning
## Introduction
A ResNet18 model that reaches 94% classification accuracy on CIFAR10. The main contributors to that result, with their approximate accuracy gains:
1. Data normalization and random augmentation (10% improvement)
2. Dropout before the fully connected classifier (1% improvement)
3. Batch normalization in each ResNet block (2% improvement)
4. Cosine learning-rate schedule (1% improvement)
5. A deeper architecture: ResNet18 instead of a simple CNN
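The training code is not included in this card, so the following is only a sketch of how the four techniques above are commonly written in plain PyTorch. The normalization statistics, dropout rate, optimizer settings, and the names `normalize_and_flip`, `BasicBlock`, `ClassifierHead` and `make_optimizer` are assumptions for illustration, not taken from this model's `model.py`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# ASSUMPTION: widely quoted CIFAR10 per-channel mean/std, not read from this checkpoint.
CIFAR10_MEAN = torch.tensor([0.4914, 0.4822, 0.4465]).view(1, 3, 1, 1)
CIFAR10_STD = torch.tensor([0.2470, 0.2435, 0.2616]).view(1, 3, 1, 1)


def normalize_and_flip(x: torch.Tensor, p: float = 0.5) -> torch.Tensor:
    """Normalize [N, 3, H, W] images in [0, 1], then flip each image left-right with probability p."""
    x = (x - CIFAR10_MEAN) / CIFAR10_STD  # new tensor, the input is left untouched
    flip = torch.rand(x.size(0)) < p  # one coin toss per image
    x[flip] = torch.flip(x[flip], dims=[3])  # dim 3 is the width axis
    return x


class BasicBlock(nn.Module):
    """ResNet basic block: two 3x3 convolutions, each followed by batch normalization."""

    def __init__(self, in_ch: int, out_ch: int, stride: int = 1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        # 1x1 projection when the spatial size or channel count changes, identity otherwise
        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + self.shortcut(x))


class ClassifierHead(nn.Module):
    """Global average pooling, then dropout immediately before the final linear layer."""

    def __init__(self, in_features: int = 512, num_classes: int = 10, p: float = 0.5):
        super().__init__()
        self.dropout = nn.Dropout(p)
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.adaptive_avg_pool2d(x, 1).flatten(1)  # [N, C, H, W] -> [N, C]
        return self.fc(self.dropout(x))


def make_optimizer(model: nn.Module, epochs: int = 30, lr: float = 0.1):
    """SGD plus a cosine schedule that decays the learning rate from lr to 0 over `epochs` scheduler steps."""
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    return opt, sched
```

With this setup, `sched.step()` is called once per epoch, after that epoch's optimizer steps, so the learning rate follows half a cosine wave from `lr` down to 0.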
## Usage
### Approach 1: predict with plain PyTorch
```python
import torch

from model import CIFARCNN

# Load the trained weights from the Lightning checkpoint
model = CIFARCNN.load_from_checkpoint("model.ckpt")
model.eval()  # disable dropout, use running batch-norm statistics

# Random input standing in for a batch of 4 preprocessed 32x32 RGB images
x = torch.randn(4, 3, 32, 32).to(model.device)
with torch.no_grad():
    predictions = model(x)  # requires CIFARCNN to implement forward()
print(predictions.shape)  # expected: torch.Size([4, 10])
```
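`predictions` holds one row of 10 scores per image. Assuming these are raw logits in the standard CIFAR10 class order (the same order as the label table under "Visualize results"), they can be turned into class names and confidences as sketched below. `decode_predictions` is a hypothetical helper; the top-1 class is correct whatever the score scale, but the confidence value assumes logits.

```python
import torch

# Standard CIFAR10 class order, index 0..9
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]


def decode_predictions(logits: torch.Tensor) -> list[tuple[str, float]]:
    """Map a [N, 10] tensor of logits to (class name, softmax probability) pairs."""
    probs = torch.softmax(logits, dim=1)
    conf, idx = probs.max(dim=1)  # best class and its probability per row
    return [(CIFAR10_CLASSES[i], c) for i, c in zip(idx.tolist(), conf.tolist())]
```

For example, `decode_predictions(predictions)` returns a list of 4 `(name, probability)` pairs for the batch above.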
### Approach 2: evaluate with the Lightning Trainer
```py
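from lightning import Trainer
from torch.utils.data import DataLoader

from model import CIFARCNN

test_dataloader = DataLoader(...)  # e.g. the CIFAR10 test split, preprocessed as during training
model = CIFARCNN.load_from_checkpoint("model.ckpt")
trainer = Trainer()  # picks an available accelerator and moves the model there
trainer.test(model, test_dataloader)  # requires CIFARCNN to implement test_step()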
```
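`trainer.test` reports whatever metrics `CIFARCNN.test_step` logs. To compute top-1 accuracy directly, a plain PyTorch loop is enough. `top1_accuracy` below is a hypothetical helper; it assumes the loader yields `(images, labels)` pairs and the model returns one score per class.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader


@torch.no_grad()
def top1_accuracy(model: nn.Module, loader: DataLoader, device: str = "cpu") -> float:
    """Fraction of samples whose highest-scoring class matches the label."""
    model.eval()  # disable dropout, use running batch-norm statistics
    model.to(device)
    correct, total = 0, 0
    for images, labels in loader:
        logits = model(images.to(device))
        correct += (logits.argmax(dim=1).cpu() == labels).sum().item()
        total += labels.numel()
    return correct / total
```

Usage with the objects above would be `top1_accuracy(model, test_dataloader, device="cuda")` on a GPU machine.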
### Visualize results
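```py
import matplotlib.pyplot as plt
import torch

cifar10_labels = {
    0: "airplane",
    1: "automobile",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",
    9: "truck",
}

# One batch (at least 10 images) from a CIFAR10 DataLoader, e.g. the test_dataloader above
samples, labels = next(iter(test_dataloader))
model.eval()
with torch.no_grad():
    logits = model(samples.to(model.device))
preds = logits.argmax(dim=1).cpu()

fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for i, ax in enumerate(axes.flatten()):
    # CHW -> HWC for matplotlib; clamp because normalized pixels can fall outside [0, 1]
    ax.imshow(samples[i].permute(1, 2, 0).clamp(0, 1))
    ax.set_title(f"pred: {cifar10_labels[preds[i].item()]}\ntrue: {cifar10_labels[labels[i].item()]}")
    ax.axis("off")
plt.tight_layout()
plt.show()
```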
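If the DataLoader returns normalized tensors, `imshow` clips values outside [0, 1] and the colors look distorted. A small helper that undoes the normalization before plotting could look like this; `to_displayable` is hypothetical, and the mean/std values are assumed (commonly quoted CIFAR10 statistics), so they should be replaced with whatever the training pipeline actually used.

```python
import torch

# ASSUMPTION: replace with the mean/std actually used when the model was trained.
MEAN = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
STD = torch.tensor([0.2470, 0.2435, 0.2616]).view(3, 1, 1)


def to_displayable(img: torch.Tensor) -> torch.Tensor:
    """Undo per-channel normalization on a [3, H, W] tensor; return an [H, W, 3] image in [0, 1]."""
    img = img.detach().cpu() * STD + MEAN  # invert (x - mean) / std
    return img.clamp(0, 1).permute(1, 2, 0)  # CHW -> HWC for matplotlib
```

In the plotting loop, `ax.imshow(to_displayable(samples[i]))` would then show the images in their original colors.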