ID56 commited on
Commit
8b0eda0
·
1 Parent(s): dfa0ec6

Updated card to also include usage example.

Browse files
Files changed (1) hide show
  1. README.md +52 -1
README.md CHANGED
@@ -7,6 +7,7 @@ datasets:
7
  - cifar10
8
  metrics:
9
  - accuracy
 
10
  ---
11
 
12
  # CIFAR-10 Upside Down Classifier
@@ -15,4 +16,54 @@ For the Fatima Fellowship 2022 Coding Challenge, DL for Vision track.
15
 
16
  <a href="https://wandb.ai/dealer56/cifar-updown-classifier/reports/CIFAR-10-Upside-Down-Classifier-Fatima-Fellowship-2022-Coding-Challenge-Vision---VmlldzoxODA2MDE4" target="_parent"><img src="https://img.shields.io/badge/weights-%26biases-ffcf40" alt="W&B Report"/></a>
17
 
18
- <img src="https://huggingface.co/ID56/FF-Vision-CIFAR/resolve/main/assets/cover_image.png" alt="Cover Image" width="800"/>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  - cifar10
8
  metrics:
9
  - accuracy
10
+ inference: false
11
  ---
12
 
13
  # CIFAR-10 Upside Down Classifier
 
16
 
17
  <a href="https://wandb.ai/dealer56/cifar-updown-classifier/reports/CIFAR-10-Upside-Down-Classifier-Fatima-Fellowship-2022-Coding-Challenge-Vision---VmlldzoxODA2MDE4" target="_parent"><img src="https://img.shields.io/badge/weights-%26biases-ffcf40" alt="W&B Report"/></a>
18
 
19
+ <img src="https://huggingface.co/ID56/FF-Vision-CIFAR/resolve/main/assets/cover_image.png" alt="Cover Image" width="800"/>
20
+
21
+ ## Usage
22
+
23
+ ### Model Definition
24
+
25
+ ```python
26
+ from torch import nn
27
+ import timm
28
+ from huggingface_hub import PyTorchModelHubMixin
29
+
30
+
31
+ class UpDownEfficientNetB0(nn.Module, PyTorchModelHubMixin):
32
+ """A simple Hub Mixin wrapper for timm EfficientNet-B0. Used to classify whether an image is upright or flipped down, on CIFAR-10."""
33
+
34
+ def __init__(self, **kwargs):
35
+ super().__init__()
36
+ self.base_model = timm.create_model('efficientnet_b0', num_classes=1, drop_rate=0.2, drop_path_rate=0.2)
37
+ self.config = kwargs.pop("config", None)
38
+
39
+ def forward(self, input):
40
+ return self.base_model(input)
41
+ ```
42
+ ### Loading the Model from Hub
43
+
44
+ ```python
45
+ net = UpDownEfficientNetB0.from_pretrained("ID56/FF-Vision-CIFAR")
46
+ ```
47
+
48
+ ### Running Inference
49
+
50
+ ```python
51
+ from torchvision import transforms
52
+
53
+ CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
54
+ CIFAR_STD = (0.247, 0.243, 0.261)
55
+
56
+ transform = transforms.Compose([
57
+ transforms.Resize(40, 40),
58
+ transforms.ToTensor(),
59
+ transforms.Normalize(CIFAR_MEAN, CIFAR_STD)
60
+ ])
61
+
62
+ image = load_some_image() # Load some PIL Image or uint8 HWC image array
63
+ image = transform(image) # Convert to CHW image tensor
64
+ image = image.unsqueeze(0) # Add batch dimension
65
+
66
+ net.eval()
67
+
68
+ pred = net(image)
69
+ ```