---
license: apache-2.0
tags:
- pytorch
- lightning
- cifar10
---

# CIFAR10 Classifier trained with PyTorch Lightning

## Introduction

A ResNet18 model that reaches 94% prediction accuracy on CIFAR10. The techniques that contributed, with their approximate accuracy gains:

1. Data normalization and randomization (10% improvement)
2. Dropout before the FC classifier (1% improvement)
3. Batch normalization in `ResNetBlock` (2% improvement)
4. Cosine learning rate schedule (1% improvement)
5. A deeper architecture: ResNet18 instead of a simple CNN
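
To make items 2 to 4 concrete, here is a minimal sketch of a residual block with batch normalization, a classifier head with dropout before the FC layer, and a cosine learning rate schedule. It is not the code from `model.py`: the layer widths, the dropout probability `p=0.5`, the SGD settings, and `T_max=30` are all assumptions chosen for illustration.

```python
import torch
from torch import nn


class ResNetBlock(nn.Module):
    """Basic residual block: two 3x3 convolutions, each followed by batch norm."""

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Project the skip connection whenever the output shape differs from the input.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return torch.relu(out + self.shortcut(x))


# Classifier head with dropout before the fully connected layer (p=0.5 is assumed).
head = nn.Sequential(
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Dropout(p=0.5), nn.Linear(128, 10)
)
# A deliberately tiny network; a real ResNet18 stacks eight such blocks.
net = nn.Sequential(ResNetBlock(3, 64), ResNetBlock(64, 128, stride=2), head)

# Cosine learning rate schedule over an assumed 30 epochs.
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)

logits = net(torch.randn(4, 3, 32, 32))
print(logits.shape)
```

In a training loop, `scheduler.step()` would typically be called once per epoch after the optimizer updates, so the learning rate decays smoothly from its initial value toward zero.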



## Usage

### Approach 1: use PyTorch to predict
```python
import torch
from model import CIFARCNN

# Load the trained weights from a checkpoint and switch to inference mode
model = CIFARCNN.load_from_checkpoint("model.ckpt")
model.eval()

# Dummy batch of four 32x32 RGB images, placed on the same device as the model
x = torch.randn(4, 3, 32, 32).to(model.device)

with torch.no_grad():
    predictions = model(x)  # requires the LightningModule to implement forward()
print(predictions.shape)  # should be [4, 10]
```
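
The `[4, 10]` output holds one score per class for each image. As a rough sketch of how those scores could be turned into class names, the helper below applies a softmax and takes the most likely class. It assumes the model returns raw logits, which this card does not state; `decode_predictions` and `CIFAR10_CLASSES` are hypothetical names, and random values stand in for `model(x)`.

```python
import torch

# Standard CIFAR10 class order (index -> name).
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]


def decode_predictions(logits: torch.Tensor) -> list:
    """Return a (class name, probability) pair for each row of logits."""
    probs = torch.softmax(logits, dim=1)       # assumes raw, unnormalized scores
    confidence, index = probs.max(dim=1)       # best class per image
    return [
        (CIFAR10_CLASSES[i], c)
        for i, c in zip(index.tolist(), confidence.tolist())
    ]


logits = torch.randn(4, 10)  # stand-in for `predictions = model(x)`
results = decode_predictions(logits)
for name, conf in results:
    print(f"{name}: {conf:.2f}")
```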


### Approach 2: use Lightning to predict
```py
from lightning import Trainer
from torch.utils.data import DataLoader

from model import CIFARCNN

test_dataloader = DataLoader(...)  # supply your CIFAR10 test dataset here
model = CIFARCNN.load_from_checkpoint("model.ckpt")
trainer = Trainer()

# The Trainer moves the model to its default device before running the test loop
trainer.test(model, test_dataloader)
```

### Visualize results
```py
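import matplotlib.pyplot as plt
import torch

cifar10_labels = {
    0: "airplane",
    1: "automobile",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",
    9: "truck",
}

# Take one batch from the test dataloader and predict a class for each image
samples, labels = next(iter(test_dataloader))
model.eval()
with torch.no_grad():
    logits = model(samples.to(model.device))
predicted = logits.argmax(dim=1).cpu()

# Show the first ten images with their predicted class names
fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for i, ax in enumerate(axes.flatten()):
    # CHW -> HWC for matplotlib; normalized inputs are clipped and may look off-color
    ax.imshow(samples[i].permute(1, 2, 0))
    ax.set_title(cifar10_labels[predicted[i].item()])
    ax.axis("off")
plt.show()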
```