turtlemb commited on
Commit
ca3a209
·
1 Parent(s): 55b2628

initial commit

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ 09_pretrained_vit_feature_extractor_pizza_steak_sushi_20_percent.pth filter=lfs diff=lfs merge=lfs -text
09_pretrained_vit_feature_extractor_pizza_steak_sushi_20_percent.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3363d8b4126848fe1cc4b6ac1eef130ec2de18da1221b330428e1d44901ec6b0
3
+ size 343271805
__pycache__/model.cpython-312.pyc ADDED
Binary file (1.17 kB). View file
 
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### 1. Imports and class names setup ###
2
+ import gradio as gr
3
+ import os
4
+ import torch
5
+
6
+ from model import create_vit_model
7
+ from timeit import default_timer as timer
8
+ from typing import Tuple, Dict
9
+
10
+ # Setup class names
11
+ class_names = ["pizza", "steak", "sushi"]
12
+
13
+ ### 2. Model and transforms preparation ###
14
+ vit, vit_transforms = create_vit_model(num_classes = 3)
15
+
16
+ # Load saved weights
17
+ vit.load_state_dict(
18
+ torch.load(
19
+ f = "09_pretrained_vit_feature_extractor_pizza_steak_sushi_20_percent.pth",
20
+ map_location = torch.device("cpu") # load the model to the CPU
21
+ )
22
+ )
23
+
24
+ ### 3. Predict function ###
25
+
26
+ def predict(img) -> Tuple[Dict, float]:
27
+ # Start a timer
28
+ start_time = timer()
29
+
30
+ # Transform the input image for use with ViT
31
+ img = vit_transforms(img).unsqueeze(0) # unsqueeze = add batch dimension on 0th index
32
+
33
+ # Put model into eval mode, make prediction
34
+ vit.eval()
35
+ with torch.inference_mode():
36
+ # Pass transformed image through the model and turn the prediction logits into probabilities
37
+ pred_probs = torch.softmax(vit(img), dim = 1)
38
+
39
+ # Create a prediction label and prediction probability dictionary
40
+ pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
41
+
42
+ # Calculate pred time
43
+ end_time = timer()
44
+ pred_time = round(end_time - start_time, 4)
45
+
46
+ # Return pred dict and pred time
47
+ return pred_labels_and_probs, pred_time
48
+
49
+ ### 4. Gradio app ###
50
+
51
+ # Create title, description, and article
52
+ title = "FoodVision Mini 🍕🥩🍣"
53
+ description = "A [ViT transformer feature extractor](https://docs.pytorch.org/vision/main/models/generated/torchvision.models.vit_b_16.html#vit-b-16) computer vision model to classify images as pizza, steak, or sushi."
54
+ article = "Created at [turtlemb's GitHub](https://github.com/turtlemb)."
55
+
56
+ # Create example list
57
+ example_list = [["examples/" + example] for example in os.listdir(examples)]
58
+
59
+ # Create the Gradio demo
60
+ demo = gr.Interface(fn = predict, # maps inputs to outputs
61
+ inputs = gr.Image(type = "pil"),
62
+ outputs = [gr.Label(num_top_classes = 3, label = "Predictions"),
63
+ gr.Number(label = "Prediction time (s)")],
64
+ examples = example_list,
65
+ title = title,
66
+ description = description,
67
+ article = article)
68
+
69
+ # Launch the demo
70
+ demo.launch(debug = False, # print errors locally?
71
+ share = True) # generate a publicly shareable URL
examples/.ipynb_checkpoints/592799-checkpoint.jpg ADDED
examples/2582289.jpg ADDED
examples/3622237.jpg ADDED
examples/592799.jpg ADDED
model.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torchvision
4
+
5
+ from torch import nn
6
+
7
+ def create_vit_model(num_classes: int = 3,
8
+ seed: int = 42):
9
+ # Create ViT_B_16 pre-trained weights, transforms and model
10
+ weights = torchvision.models.ViT_B_16_Weights.DEFAULT
11
+ transforms = weights.transforms()
12
+ model = torchvision.models.vit_b_16(weights = weights)
13
+
14
+ # Freeze all of the base layers
15
+ for param in model.parameters():
16
+ param.requires_grad = False
17
+
18
+ # Change classifier head to suit our needs
19
+ torch.manual_seed(seed)
20
+ model.heads = nn.Sequential(nn.Linear(in_features = 768,
21
+ out_features = num_classes))
22
+
23
+ return model, transforms
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+
2
+ torch==2.7.1
3
+ torchvision==0.22.1
4
+ gradio==6.9.0