ErdemAtak commited on
Commit
7e5ec66
·
verified ·
1 Parent(s): ebd529b

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +74 -0
  2. model.py +24 -0
  3. requirements.txt +3 -0
  4. vit_cifar10_state_dict.pth +3 -0
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ### 1. Imports and class names setup ###
3
+
4
+ # Imports
5
+ import gradio as gr
6
+ import os
7
+ import torch
8
+
9
+ from model import create_vit_model
10
+ from timeit import default_timer as timer
11
+ from typing import Tuple, Dict
12
+
13
+ # Setup class names
14
+ class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
15
+
16
+ ### 2. Model and Transforms Preparation ###
17
+ vit_model, vit_transforms = create_vit_model(num_classes = 10)
18
+
19
+ # Load save weights
20
+ vit_model.load_state_dict(
21
+ torch.load(
22
+ f="vit_cifar10_state_dict.pth",
23
+ map_location=torch.device("cpu") # load the model to the cpu
24
+ )
25
+ )
26
+
27
+ ### 3. Predict function ###
28
+ def predict(img) -> Tuple[Dict, float]:
29
+ # Timer
30
+ start_time = timer()
31
+
32
+ # Transform the input image to work with ViT
33
+ img = vit_transforms(img).unsqueeze(0) # unsqueeze = add batch dimension on 0th index
34
+
35
+ # Eval mode and torch inference mode on
36
+ vit_model.eval()
37
+ with torch.inference_mode():
38
+ # Pass transformed image through the model and turn prediction logits into probabilities
39
+ pred_probs = torch.softmax(vit_model(img), dim = 1)
40
+
41
+ # Create prediction label and prediction probability dictionary
42
+ pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
43
+
44
+ # Calculate prediction time
45
+ end_time = timer()
46
+ pred_time = round(end_time - start_time, 3)
47
+
48
+ # Return pred dict and pred time
49
+ return pred_labels_and_probs, pred_time
50
+
51
+ ### 4. Gradio app ###
52
+
53
+ # Create title for the gradio
54
+ title = "Object Classifier - Erdem Atak Version"
55
+ description = "ViT computer vision model to classify CIFAR-10 objects"
56
+ article = "PyTorch Model Deployment"
57
+
58
+ # Create example list
59
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
60
+
61
+ # Create the gradio demo
62
+ demo = gr.Interface(fn = predict, # it maps inputs to outputs
63
+ inputs = gr.Image(type = "pil"),
64
+ outputs = [gr.Label(num_top_classes = 3,
65
+ label = "Predictions"),
66
+ gr.Number(label = "Prediction Time (s)")],
67
+ examples = example_list, # example list above
68
+ title = title,
69
+ description = description,
70
+ article = article)
71
+
72
+ # launch the demo
73
+ demo.launch(debug = False,
74
+ share = True ) # public shareable URL
model.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # import the needed dependencies for this py file
3
+ import torch
4
+ import torchvision
5
+
6
+ from torch import nn
7
+
8
+ #Turning train as a function as 1. setup, 2 getr transforms 3. setup model instance 4. Freeze base layers and adjust output layers
9
+ def create_vit_model(num_classes: int = 10,
10
+ seed: int = 24):
11
+ # 1, 2, 3, Create ViT weights, transform and model
12
+ weights = torchvision.models.ViT_B_16_Weights.DEFAULT
13
+ transforms = weights.transforms()
14
+ model = torchvision.models.vit_b_16(weights=weights)
15
+
16
+ # 4. Freeze the base layers
17
+ for param in model.parameters()
18
+ param.requires_grad = False
19
+
20
+ # 5. Adjust the number of heads(output)
21
+ torch.manual_seed(seed)
22
+ model.heads = nn.Sequential(nn.Linear(in_features=768, out_features=num_classes, bias=True))
23
+
24
+ return model, transforms
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
vit_cifar10_state_dict.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b92b218af10764547dbcb75b3bf32681970fc48bf5dc14339a3537811c5403c
3
+ size 343288529