buffaX commited on
Commit
020580e
·
verified ·
1 Parent(s): a5ba72b

Upload README.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +83 -0
README.md ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - pytorch
5
+ - lightning
6
+ - cifar10
7
+ ---
8
+
9
+ # CIFAR10 Classifier trained with PyTorch Lightning
10
+
11
+ ## Introduction
12
+
13
+ A ResNet18 model that achieves 94% prediction accuracy. Key features include:
14
+
15
+ 1. Data normalization and randomization (10% improvement)
16
+ 2. Dropout before FC classifier (1% improvement)
17
+ 3. Batch normalization in ResNetBlock (2% improvement)
18
+ 4. Cos learning rate schedule (1% improvement)
19
+ 5. ResNet18 is deeper than a simple CNN network.
20
+
21
+
22
+
23
+ ## Usage
24
+
25
+ ### Approach 1: use pytorch to predict
26
+ ```python
27
+
28
+ ## Approach 1: use pytorch to predict
29
+ import torch
30
+ from model import CIFARCNN
31
+
32
+ # Evaluate model checkpoints
33
+ model = CIFARCNN.load_from_checkpoint("model.ckpt")
34
+ model.eval()
35
+ x = torch.randn(4, 3, 32, 32).to(model.device)
36
+
37
+ with torch.no_grad():
38
+ predictions = model(x) # the lightning module should implement forward func
39
+ print(predictions.shape) # should be [4, 10]
40
+ ```
41
+
42
+
43
+ ### Approach 2: use Lightning to predict
44
+ ```py
45
+ import torch
46
+ from model import CIFARCNN
47
+ from lightning import Trainer
48
+
49
+ test_dataloader = DataLoader(...)
50
+ model = CIFARCNN.load_from_checkpoint("model.ckpt") # lightning will move model to default device
51
+ trainer = Trainer()
52
+
53
+ trainer.test(model, test_dataloader)
54
+ ```
55
+
56
+ ### Visualize results
57
+ ```py
58
+ import matplotlib.pyplot as plt
59
+
60
+ cifar10_labels = {
61
+ 0: "airplane",
62
+ 1: "automobile",
63
+ 2: "bird",
64
+ 3: "cat",
65
+ 4: "deer",
66
+ 5: "dog",
67
+ 6: "frog",
68
+ 7: "horse",
69
+ 8: "ship",
70
+ 9: "truck",
71
+ }
72
+
73
+ samples, labels = next(iter(train_loader))
74
+ predicts = trainer.predict(model, samples)
75
+ labels = predicts.argmax(dim=1)
76
+
77
+ fig, axes = plt.subplots(2, 5, figsize=(10, 4))
78
+ for i, ax in enumerate(axes.flatten()):
79
+ ax.imshow(samples[i].permute(1, 2, 0))
80
+ ax.set_title(f"{cifar10_labels[labels[i].item()]}")
81
+ ax.axis("off")
82
+ plt.show()
83
+ ```