Spaces:
Runtime error
Update app.py
Browse filesPrevious
import gradio as gr
import os
import torch
from model import create_ViT
from timeit import default_timer as timer
from typing import Tuple, Dict
# Setup class names
with open("class_names.txt", "r") as f:
class_names = [food_name.strip() for food_name in f.readlines()]
# Create model
ViT_model, ViT_transforms = create_ViT(
num_classes=126,
)
# Load saved weights
ViT_model.load_state_dict(
torch.load(
f="ViT.pth",
map_location=torch.device("cpu"),
)
)
# Create predict function
def predict(img) -> Tuple[Dict, float]:
start_time = timer()
# Transform the target image and add a batch dimension
img = ViT_transforms(img).unsqueeze(0)
# Put model into evaluation mode and turn on inference mode
ViT_model.eval()
with torch.inference_mode():
# Pass the transformed image through the model and turn the prediction logits into prediction probabilities
pred_probs = torch.softmax(ViT_model(img), dim=1)
# Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
# Calculate the prediction time
pred_time = round(timer() - start_time, 5)
# Return the prediction dictionary and prediction time
return pred_labels_and_probs, pred_time
##GRADIO APP
# Create title, description and article strings
title = "FoodVision🍔🍟🍦"
description = "A Vision Transformer feature extractor computer vision model to classify images of food into 126 different classes."
article = "Created by [Rohit](https://github.com/ItsNotRohit02)."
# Create examples list from "examples/" directory
example_list = [["examples/" + example] for example in os.listdir("examples")]
# Create Gradio interface
demo = gr.Interface(
fn=predict,
inputs=gr.Image(type="pil"),
outputs=[
gr.Label(num_top_classes=5, label="Predictions"),
gr.Number(label="Prediction time (s)"),
],
examples=example_list,
title=title,
description=description,
article=article,
)
# Launch the app!
demo.launch()
|
@@ -12,40 +12,37 @@ with open("class_names.txt", "r") as f:
|
|
| 12 |
|
| 13 |
|
| 14 |
# Create model
|
| 15 |
-
|
| 16 |
-
num_classes=126,
|
| 17 |
-
)
|
| 18 |
|
| 19 |
# Load saved weights
|
| 20 |
-
|
| 21 |
torch.load(
|
| 22 |
-
f="
|
| 23 |
map_location=torch.device("cpu"),
|
| 24 |
)
|
| 25 |
)
|
| 26 |
|
| 27 |
|
| 28 |
-
# Create predict function
|
| 29 |
def predict(img) -> Tuple[Dict, float]:
|
| 30 |
|
| 31 |
start_time = timer()
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
-
|
| 37 |
-
ViT_model.eval()
|
| 38 |
with torch.inference_mode():
|
| 39 |
-
|
| 40 |
-
pred_probs = torch.softmax(ViT_model(img), dim=1)
|
| 41 |
|
| 42 |
-
# Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
|
| 43 |
pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
|
| 44 |
|
| 45 |
-
# Calculate the prediction time
|
| 46 |
pred_time = round(timer() - start_time, 5)
|
| 47 |
|
| 48 |
-
# Return the prediction dictionary and prediction time
|
| 49 |
return pred_labels_and_probs, pred_time
|
| 50 |
|
| 51 |
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
# Create model
|
| 15 |
+
model = create_ViT()
|
|
|
|
|
|
|
| 16 |
|
| 17 |
# Load saved weights
|
| 18 |
+
model.load_state_dict(
|
| 19 |
torch.load(
|
| 20 |
+
f="ViTHg.pth",
|
| 21 |
map_location=torch.device("cpu"),
|
| 22 |
)
|
| 23 |
)
|
| 24 |
|
| 25 |
|
|
|
|
| 26 |
def predict(img) -> Tuple[Dict, float]:
|
| 27 |
|
| 28 |
start_time = timer()
|
| 29 |
|
| 30 |
+
preprocess = transforms.Compose([
|
| 31 |
+
transforms.Resize((224, 224)),
|
| 32 |
+
transforms.ToTensor(),
|
| 33 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 34 |
+
])
|
| 35 |
+
|
| 36 |
+
img = preprocess(img).unsqueeze(0) # Add batch dimension
|
| 37 |
|
| 38 |
+
model.eval()
|
|
|
|
| 39 |
with torch.inference_mode():
|
| 40 |
+
pred_probs = torch.softmax(model(img), dim=1)
|
|
|
|
| 41 |
|
|
|
|
| 42 |
pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
|
| 43 |
|
|
|
|
| 44 |
pred_time = round(timer() - start_time, 5)
|
| 45 |
|
|
|
|
| 46 |
return pred_labels_and_probs, pred_time
|
| 47 |
|
| 48 |
|