Harsh72AI committed on
Commit
deb67e9
·
1 Parent(s): edde255

Uploaded Project Files

Browse files
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import torch
4
+ import os
5
+
6
+ from PIL import Image
7
+ from typing import Tuple, Dict, List
8
+ from timeit import default_timer as timer
9
+
10
+ from model import create_vit_b_16_swag
11
+
12
# Class labels, in the index order the trained classifier head emits.
class_names = ['Pizza', 'Steak', 'Sushi']

# Creating new instance of saved model's architecture and pre-trained model data transformation pipeline
vit_swag_model, vit_swag_transforms = create_vit_b_16_swag(num_classes=len(class_names))

# Load weights from trained and saved model
# map_location='cpu' so the checkpoint loads on CPU-only Spaces hardware.
# NOTE(review): torch.load without weights_only=True unpickles arbitrary objects;
# fine for a self-authored checkpoint, but consider weights_only=True — confirm
# the checkpoint contains only a plain state_dict.
vit_swag_model.load_state_dict(torch.load('foodvision_mini_vit_swag_model.pt',
                                          map_location=torch.device('cpu')))
20
+
21
+
22
+ # -------------- Model Predicting Function --------------
23
+
24
# Prediction function wired into the Gradio interface
def predict(img) -> Tuple[Dict, float]:
    """Classify a food image as pizza, steak or sushi.

    Args:
        img: Input image (PIL.Image supplied by Gradio).

    Returns:
        A dict mapping each class name to its predicted probability,
        and the elapsed prediction time in seconds (rounded to 5 dp).
    """
    # Start timing the whole inference pass
    start_time = timer()

    # Preprocess with the model's own transform pipeline,
    # move to CPU and add a batch dimension
    batch = vit_swag_transforms(img).to("cpu").unsqueeze(dim=0)

    # Forward pass in inference mode, then convert logits to probabilities
    vit_swag_model.eval()
    with torch.inference_mode():
        probabilities = torch.softmax(vit_swag_model(batch), dim=1)[0]

    # Pair each class name with its probability for the Gradio Label output
    pred_probs = {name: float(p) for name, p in zip(class_names, probabilities)}

    # Elapsed time, rounded for display
    pred_time = round(timer() - start_time, 5)

    return pred_probs, pred_time
44
+
45
+
46
# -------------- Building Gradio App --------------

# Create title, description and article strings shown around the interface
title = "FoodVision Mini 🍕🥩🍣"
description = "A ViT (Vision Transformer) SWAG weighted feature extractor computer vision model to classify images of food as pizza, steak or sushi."
article = "Created by Harsh Singh [-Github-](https://github.com/HarshSingh2009/)"

# Example images bundled with the app.
# Bug fix: original had a duplicated assignment (`example_list = example_list = [...]`).
example_list = ['example-pizza_img.jpeg', 'example-steak-img.jpeg', 'example-sushi-img.jpeg']

# Create the Gradio demo: one image input, two outputs matching predict()'s
# (probability dict, prediction time) return tuple.
demo = gr.Interface(fn=predict,  # mapping function from input to output
                    inputs=gr.Image(type="pil"),  # what are the inputs?
                    outputs=[gr.Label(num_top_classes=3, label="Predictions"),  # what are the outputs?
                             gr.Number(label="Prediction time (s)")],  # our fn has two outputs, therefore we have two outputs
                    examples=example_list,
                    title=title,
                    description=description,
                    article=article)

# Launch the demo!
demo.launch()
example-pizza_img.jpeg ADDED
example-steak-img.jpeg ADDED
example-sushi-img.jpeg ADDED
foodvision_mini_vit_swag_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21120fb5ccf7e768de4b8b51629032b45f93d882da6df2384e5951ee6669afdd
3
+ size 344435830
model.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Creates ViT pre-trained base model with SWAG weights
2
+
3
+ import torch
4
+ import torchvision
5
+
6
+ from torch import nn
7
+
8
def create_vit_b_16_swag(num_classes: int = 1000):
    """
    Creates ViT SWAG pre-trained base model from torchvision.models

    Args:
        num_classes: int = 1000 - Number of classes in data.

    Returns:
        model: torch.nn.Module - Pre-trained ViT SWAG base model.
        transforms: torchvision.transforms._presets.ImageClassification - Data Transformation Pipeline required by pre-trained model.
    """

    # Get ViT weights and data transformation pipeline
    model_weights = torchvision.models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
    model_transforms = model_weights.transforms()

    # Load in ViT Base model with patch size 16
    model = torchvision.models.vit_b_16(weights=model_weights)

    # Freeze the pre-trained backbone so only the new head trains
    for param_swag in model.parameters():
        param_swag.requires_grad = False

    # Replace the 1000-class ImageNet head with a custom classifier.
    # Its freshly created parameters have requires_grad=True by default,
    # so the original loop that un-froze the old (discarded) head's
    # parameters was dead code and has been removed.
    model.heads = torch.nn.Sequential(
        nn.Linear(in_features=768, out_features=num_classes, bias=True)
    )

    return model, model_transforms
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch==2.1.0
2
+ torchvision==0.16.0
3
+ gradio==4.7.1