noelsinghsr commited on
Commit
3c94f3b
·
1 Parent(s): 0a7c008

initial commit

Browse files
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ cifar10_effnet_sm_uncompiled.pth filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+
5
+ from model import load_model
6
+ from timeit import default_timer as timer
7
+ from typing import Tuple, Dict
8
+
9
+ # class names
10
+ class_names = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
11
+
12
+ model, transform = load_model()
13
+
14
+ # predict function
15
+ def predict(img):
16
+
17
+ start_time = timer()
18
+
19
+ img = transform(img).unsqueeze(0)
20
+
21
+ model.eval()
22
+ with torch.inference_mode():
23
+ pred_probs = torch.softmax(model(img), dim=1)
24
+
25
+ pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
26
+
27
+ end_time = timer()
28
+
29
+ pred_time = round(end_time - start_time, 4)
30
+
31
+ return pred_labels_and_probs, pred_time
32
+
33
+
34
+
35
+ title = "Noel's Cifar10 - Efficinet Computer Vision Model (PyTorch)"
36
+
37
+ description = "An EfficientNetB0 feature extractor computer vision model to classify Cifar10 dataset"
38
+
39
+ article = "Created in SageMaker Studio"
40
+
41
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
42
+
43
+ # Gradio app
44
+ demo = gr.Interface(fn=predict,
45
+ inputs=gr.Image(type="pil"),
46
+ outputs=[gr.Label(num_top_classes=10, label="Predictions"),
47
+ gr.Number(label="Prediction time (s)")],
48
+ examples=example_list,
49
+ title=title,
50
+ description=description,
51
+ article=article)
cifar10_effnet_sm_uncompiled.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7dab88e7985cf24151df6d2ea73847ea63d1979484a503d40789fad33642990
3
+ size 31400415
examples/dog.jpg ADDED
examples/plane2.jpg ADDED
examples/truck.jpg ADDED
model.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ from torch import nn
4
+
5
+ def load_model():
6
+
7
+ loaded_model = torch.load('demo/cifar10/cifar10_effnet_sm_uncompiled.pth', map_location=torch.device('cpu'))
8
+ model_weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
9
+ transforms = model_weights.transforms()
10
+
11
+ return loaded_model, transforms
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio