---
license: apache-2.0
tags:
- pytorch
- lightning
- cifar10
---

# CIFAR10 Classifier trained with PyTorch Lightning

## Introduction

A ResNet18 model that achieves 94% prediction accuracy on CIFAR10. Key features include:

1. Data normalization and randomization (10% improvement)
2. Dropout before the FC classifier (1% improvement)
3. Batch normalization in ResNetBlock (2% improvement)
4. Cosine learning rate schedule (1% improvement)
5. ResNet18 architecture, which is deeper than a simple CNN
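
The cosine learning rate schedule in item 4 can be illustrated with a small, framework-free sketch. It assumes the common cosine annealing form, which decays the rate from a starting value to a floor over a fixed number of epochs. The function name `cosine_lr` and the values `base_lr=0.1` and `total_epochs=100` are hypothetical, chosen only for illustration; this README does not state the actual training settings.

```python
import math


def cosine_lr(epoch: int, total_epochs: int, base_lr: float, min_lr: float = 0.0) -> float:
    """Learning rate after `epoch` epochs of cosine annealing from base_lr to min_lr."""
    # Fraction of training completed, clamped to [0, 1]
    progress = min(max(epoch / total_epochs, 0.0), 1.0)
    # Half a cosine period: starts at base_lr, ends at min_lr
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Hypothetical settings, for illustration only
for epoch in range(0, 101, 25):
    print(epoch, round(cosine_lr(epoch, total_epochs=100, base_lr=0.1), 4))
```

The rate falls slowly at first, fastest around the midpoint, and flattens out near the end. In PyTorch a schedule of this shape is available as `torch.optim.lr_scheduler.CosineAnnealingLR`; whether this model uses that exact class is not stated here.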

## Usage

### Approach 1: use PyTorch to predict

```python
import torch
from model import CIFARCNN

# Load the trained weights from a Lightning checkpoint
model = CIFARCNN.load_from_checkpoint("model.ckpt")
model.eval()  # disable dropout and use the stored batch-norm statistics

# Dummy batch of four 32x32 RGB images on the same device as the model
x = torch.randn(4, 3, 32, 32).to(model.device)

with torch.no_grad():
    predictions = model(x)  # the LightningModule must implement forward()

print(predictions.shape)  # should be torch.Size([4, 10])
```
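
The `predictions` tensor above holds one score per class for each image. The following is a minimal sketch of turning those scores into class names. It assumes the model returns raw logits (not stated in this README) and that class indices follow the standard CIFAR10 order. `decode_predictions` is a hypothetical helper, not part of `model.py`.

```python
import torch

# Standard CIFAR10 class order (index -> name)
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]


def decode_predictions(logits: torch.Tensor) -> list:
    """Return (class_name, probability) for the top class of each row."""
    # Assumption: `logits` are unnormalized scores of shape [N, 10]
    probs = torch.softmax(logits, dim=1)
    confidence, index = probs.max(dim=1)
    return [(CIFAR10_CLASSES[i], p) for i, p in zip(index.tolist(), confidence.tolist())]


# Stand-in logits; in practice pass the `predictions` tensor from the model
example = torch.zeros(2, 10)
example[0, 3] = 5.0  # strongest score at index 3
example[1, 8] = 5.0  # strongest score at index 8
for name, prob in decode_predictions(example):
    print(f"{name}: {prob:.2f}")
```

If the model already applies softmax in `forward`, skip the `torch.softmax` call and take the `argmax` directly.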

### Approach 2: use Lightning to predict

```py
from lightning import Trainer
from torch.utils.data import DataLoader

from model import CIFARCNN

test_dataloader = DataLoader(...)
model = CIFARCNN.load_from_checkpoint("model.ckpt")
trainer = Trainer()  # the Trainer moves the model to its device when testing

# Runs the module's test_step over the dataloader and reports the logged metrics
trainer.test(model, test_dataloader)
```

### Visualize results

```py
import matplotlib.pyplot as plt
import torch

cifar10_labels = {
    0: "airplane",
    1: "automobile",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",
    9: "truck",
}

# Reuses `model` and `test_dataloader` from Approach 2; needs a batch of at least 10 images
samples, labels = next(iter(test_dataloader))

model.eval()
with torch.no_grad():
    logits = model(samples.to(model.device))
predicted = logits.argmax(dim=1).cpu()  # keep the true `labels` intact

fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for i, ax in enumerate(axes.flatten()):
    # CHW -> HWC for matplotlib; normalized inputs may look off-color
    ax.imshow(samples[i].permute(1, 2, 0))
    ax.set_title(cifar10_labels[predicted[i].item()])
    ax.axis("off")
plt.show()
```