# Author: SCREAMIE
# Commit 5b6fe27: "add new model, better acc"
import gradio as gr
import os
import torch
from model import pretrained_vit
from timeit import default_timer as timer
from consts import class_names
from huggingface_hub import hf_hub_download
# Model and transforms.
# pretrained_vit() (local `model` module) returns the network together with the
# transform that `predict` applies to each input image.
model, transforms = pretrained_vit()

# Load the fine-tuned weights, downloaded from the Hugging Face Hub into the
# local "models" directory.
checkpoint_path = hf_hub_download(
    repo_id="SCREAMIE/ViT_Food101",
    filename="ViT_Food101_89.pth",
    local_dir="models",
)
# The checkpoint is fed straight into load_state_dict, so it is a mapping of
# parameter names to tensors. weights_only=True restricts unpickling to such
# data, so a tampered download cannot execute arbitrary code while loading.
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
# Switch to evaluation mode once at start-up (predict also re-asserts it).
model.eval()
# prediction function
def predict(img) -> tuple:
    """Classify one image and report how long inference took.

    Args:
        img: The value from the Gradio ``Image`` input (configured with
            ``type="pil"``), or ``None`` when the user submitted without an image.

    Returns:
        A ``(label_to_prob, seconds)`` tuple. ``label_to_prob`` maps every name
        in ``class_names`` to its softmax probability as a Python float, in the
        shape ``gr.Label`` expects. ``seconds`` is the wall-clock time spent in
        this function, rounded to 4 decimal places.

    Raises:
        gr.Error: If no image was provided. Gradio shows the message to the user.
    """
    if img is None:
        raise gr.Error("No image provided. Please upload an image to classify.")

    start_time = timer()

    # The model consumes a batch; add a leading batch dimension of size 1.
    batch = transforms(img).unsqueeze(0)

    model.eval()
    with torch.inference_mode():
        # Convert the single row of probabilities to Python floats in one call
        # instead of indexing the tensor once per class.
        probs = torch.softmax(model(batch), dim=1)[0].tolist()

    # One entry per known class name. If the model emitted fewer scores than
    # there are names, indexing raises IndexError, as the original did.
    pred_labels_and_probs = {name: probs[i] for i, name in enumerate(class_names)}

    pred_time = round(timer() - start_time, 4)
    return pred_labels_and_probs, pred_time
# Text shown in the Gradio app: page title, short description, and a Markdown
# "article" rendered below the interface.
title = "Fine-Tuned ViT on Food101 🌮🍣🍕🍣🍝"
description = "ViT feature extractor computer vision model that classifies images into the 101 food classes of the Food-101 dataset."
article = """
## Training Details
This model was fine-tuned on the **Food-101** dataset using a **pretrained Vision Transformer (ViT)** in PyTorch.
### Final Results
- **Top-1 Accuracy:** **89%**
- **Total Training Time:** **3:26:16**
- **Test loss:** **1.16616**
- **Train loss:** **1.83015**
- **Batch size:** **128**
- **Num epochs:** **40**
- **Hardware:** **NVIDIA DGX Spark**
"""
# Example images offered below the input widget, one single-input example per
# file. Files are sorted so the gallery order is stable across runs (os.listdir
# returns an arbitrary order). A missing directory yields no examples instead
# of crashing the app at start-up.
foodvision_min_examples_path = "examples"
example_list = []
if os.path.isdir(foodvision_min_examples_path):
    example_list = [
        [os.path.join(foodvision_min_examples_path, file)]
        for file in sorted(os.listdir(foodvision_min_examples_path))
        if file.lower().endswith((".jpg", ".jpeg", ".png"))
    ]
# Build the UI: a single uploaded PIL image goes into `predict`, and the five
# most likely classes plus the measured prediction time come back out.
_image_input = gr.Image(type="pil", sources=["upload"], streaming=False)
_prediction_outputs = [
    gr.Label(num_top_classes=5, label="Predictions"),
    gr.Number(label="Prediction time (s)"),
]

demo = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_prediction_outputs,
    title=title,
    description=description,
    article=article,
    examples=example_list,
)

# Listen on all network interfaces (0.0.0.0), without a public share link.
demo.launch(share=False, server_name="0.0.0.0", debug=False)