Spaces:
Runtime error
Runtime error
Commit
·
008b175
1
Parent(s):
5db8bb0
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,7 +2,7 @@ import gradio as gr
|
|
| 2 |
import os
|
| 3 |
import torch
|
| 4 |
|
| 5 |
-
from model import
|
| 6 |
from timeit import default_timer as timer
|
| 7 |
from typing import Tuple, Dict
|
| 8 |
|
|
@@ -12,14 +12,14 @@ with open("class_names.txt", "r") as f:
|
|
| 12 |
|
| 13 |
|
| 14 |
# Create model
|
| 15 |
-
|
| 16 |
-
num_classes=
|
| 17 |
)
|
| 18 |
|
| 19 |
# Load saved weights
|
| 20 |
-
|
| 21 |
torch.load(
|
| 22 |
-
f="
|
| 23 |
map_location=torch.device("cpu"),
|
| 24 |
)
|
| 25 |
)
|
|
@@ -31,13 +31,13 @@ def predict(img) -> Tuple[Dict, float]:
|
|
| 31 |
start_time = timer()
|
| 32 |
|
| 33 |
# Transform the target image and add a batch dimension
|
| 34 |
-
img =
|
| 35 |
|
| 36 |
# Put model into evaluation mode and turn on inference mode
|
| 37 |
-
|
| 38 |
with torch.inference_mode():
|
| 39 |
# Pass the transformed image through the model and turn the prediction logits into prediction probabilities
|
| 40 |
-
pred_probs = torch.softmax(
|
| 41 |
|
| 42 |
# Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
|
| 43 |
pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
|
|
@@ -52,7 +52,7 @@ def predict(img) -> Tuple[Dict, float]:
|
|
| 52 |
##GRADIO APP
|
| 53 |
# Create title, description and article strings
|
| 54 |
title = "FoodVision🍔🍟🍦"
|
| 55 |
-
description = "
|
| 56 |
article = "Created by [Rohit](https://github.com/ItsNotRohit02)."
|
| 57 |
|
| 58 |
# Create examples list from "examples/" directory
|
|
|
|
| 2 |
import os
|
| 3 |
import torch
|
| 4 |
|
| 5 |
+
from model import create_ViT
|
| 6 |
from timeit import default_timer as timer
|
| 7 |
from typing import Tuple, Dict
|
| 8 |
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
# Create model
|
| 15 |
+
ViT_model, ViT_transforms = create_ViT(
|
| 16 |
+
num_classes=126,
|
| 17 |
)
|
| 18 |
|
| 19 |
# Load saved weights
|
| 20 |
+
ViT_model.load_state_dict(
|
| 21 |
torch.load(
|
| 22 |
+
f="ViT.pth",
|
| 23 |
map_location=torch.device("cpu"),
|
| 24 |
)
|
| 25 |
)
|
|
|
|
| 31 |
start_time = timer()
|
| 32 |
|
| 33 |
# Transform the target image and add a batch dimension
|
| 34 |
+
img = ViT_transforms(img).unsqueeze(0)
|
| 35 |
|
| 36 |
# Put model into evaluation mode and turn on inference mode
|
| 37 |
+
ViT_model.eval()
|
| 38 |
with torch.inference_mode():
|
| 39 |
# Pass the transformed image through the model and turn the prediction logits into prediction probabilities
|
| 40 |
+
pred_probs = torch.softmax(ViT_model(img), dim=1)
|
| 41 |
|
| 42 |
# Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
|
| 43 |
pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
|
|
|
|
| 52 |
##GRADIO APP
|
| 53 |
# Create title, description and article strings
|
| 54 |
title = "FoodVision🍔🍟🍦"
|
| 55 |
+
description = "A Vision Transformer feature extractor computer vision model to classify images of food into 126 different classes."
|
| 56 |
article = "Created by [Rohit](https://github.com/ItsNotRohit02)."
|
| 57 |
|
| 58 |
# Create examples list from "examples/" directory
|