Spaces:
Sleeping
Sleeping
File size: 4,816 Bytes
6778ab8 01a5409 6778ab8 5f0b74a 6778ab8 5f0b74a 6778ab8 5f0b74a 6778ab8 5f0b74a 6778ab8 dd1e270 6778ab8 dd1e270 7a7cf02 dd1e270 6778ab8 7a7cf02 6778ab8 7a7cf02 dd1e270 6778ab8 5f0b74a 6778ab8 5f0b74a 6778ab8 dd1e270 6778ab8 5f0b74a 6778ab8 5f0b74a 6778ab8 8a7643b 6778ab8 1a55275 6778ab8 8a7643b 4d37213 6778ab8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
### 1. Imports and class names setup ###
import os
from timeit import default_timer as timer
from typing import Dict, Tuple, Union

import gradio as gr
import numpy as np
import torch
from PIL import Image

from model import create_vit_model # Make sure this function exists in model.py
# Label vocabulary for the classifier. predict() maps model output index i to
# class_names[i], so this order presumably must match the label order used
# when the checkpoint was trained -- do not re-sort or edit casually.
# Entries are grouped by initial letter purely for readability.
class_names = [
    # A-B
    "apple_pie",
    "baby_back_ribs", "baklava", "beef_carpaccio", "beef_tartare", "beet_salad",
    "beignets", "bibimbap", "biryani", "bread_pudding", "breakfast_burrito",
    "bruschetta",
    # C
    "caesar_salad", "cannoli", "caprese_salad", "carrot_cake", "ceviche",
    "chai", "chapati", "cheese_plate", "cheesecake", "chicken_curry",
    "chicken_quesadilla", "chicken_wings", "chocolate_cake", "chocolate_mousse",
    "chole_bhature", "churros", "clam_chowder", "club_sandwich", "crab_cakes",
    "creme_brulee", "croque_madame", "cup_cakes",
    # D-E
    "dabeli", "dal", "deviled_eggs", "dhokla", "donuts", "dosa", "dumplings",
    "edamame", "eggs_benedict", "escargots",
    # F
    "falafel", "filet_mignon", "fish_and_chips", "foie_gras", "french_fries",
    "french_onion_soup", "french_toast", "fried_calamari", "fried_rice",
    "frozen_yogurt",
    # G-H
    "garlic_bread", "gnocchi", "greek_salad", "grilled_cheese_sandwich",
    "grilled_salmon", "guacamole", "gyoza",
    "hamburger", "hot_and_sour_soup", "hot_dog", "huevos_rancheros", "hummus",
    # I-L
    "ice_cream", "idli", "jalebi", "kathi_rolls", "kofta", "kulfi",
    "lasagna", "lobster_bisque", "lobster_roll_sandwich",
    # M-O
    "macaroni_and_cheese", "macarons", "miso_soup", "momos", "mussels",
    "naan", "nachos", "omelette", "onion_rings", "oysters",
    # P
    "pad_thai", "paella", "pakoda", "pancakes", "pani_puri", "panna_cotta",
    "panner_butter_masala", "pav_bhaji", "peking_duck", "pho", "pizza",
    "pork_chop", "poutine", "prime_rib", "pulled_pork_sandwich",
    # R-S
    "ramen", "ravioli", "red_velvet_cake", "risotto",
    "samosa", "sashimi", "scallops", "seaweed_salad", "shrimp_and_grits",
    "spaghetti_bolognese", "spaghetti_carbonara", "spring_rolls", "steak",
    "strawberry_shortcake", "sushi",
    # T-W
    "tacos", "takoyaki", "tiramisu", "tuna_tartare", "vadapav", "waffles",
]
### 2. Model and transforms setup ###
# Build the ViT model and its matching preprocessing transforms; the
# classification head is sized to the number of labels in class_names.
vit, vit_transforms = create_vit_model(num_classes=len(class_names))

# Load the saved weights onto the CPU. The path is relative to the current
# working directory; the checkpoint is assumed to be a state_dict saved from a
# model built the same way (TODO confirm if training code changes).
vit.load_state_dict(torch.load("vit_epoch_2.pth", map_location=torch.device("cpu")))

# Put the model in evaluation mode once at startup so any dropout /
# normalization layers behave deterministically from the very first use,
# not only after predict() has been called.
vit.eval()
### 3. Prediction function ###
def predict(img) -> Tuple[Union[Dict[str, float], str], float]:
    """Classify a food image with the ViT model.

    Args:
        img: The uploaded image, as a ``PIL.Image.Image`` (what
            ``gr.Image(type="pil")`` supplies), a ``numpy.ndarray`` (converted
            to uint8 and then to PIL), or ``None`` when no image was provided.

    Returns:
        A ``(labels, seconds)`` pair. On success ``labels`` maps every class
        name to its softmax probability, and ``seconds`` is the preprocessing
        + inference time rounded to 5 decimals. On missing or invalid input
        ``labels`` is a plain message string (``gr.Label`` renders strings as
        text; its dict form requires float confidences) and ``seconds`` is 0.0.
    """
    # Gradio passes None when the user submits without uploading an image.
    if img is None:
        return "No image provided. Please upload an image.", 0.0

    # Only image preparation is guarded: a failure here means bad input,
    # whereas errors from the model itself should surface, not be masked.
    # PIL's UnidentifiedImageError is a subclass of OSError, so OSError also
    # covers truncated/corrupt image data hit during conversion.
    try:
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img.astype("uint8"))  # Ensure correct dtype
        if img.mode != "RGB":
            img = img.convert("RGB")
        start_time = timer()
        # Transforms expect a PIL image; add a batch dimension of 1.
        img_tensor = vit_transforms(img).unsqueeze(0)
    except (OSError, TypeError, ValueError) as e:
        return f"Invalid image input: {e}", 0.0

    # Cheap, idempotent call that guarantees inference behaviour even if the
    # model was switched to training mode elsewhere.
    vit.eval()
    with torch.inference_mode():
        pred_probs = torch.softmax(vit(img_tensor), dim=1)

    # Row 0 is the single image in the batch; output index i -> class_names[i].
    pred_labels_and_probs = dict(zip(class_names, pred_probs[0].tolist()))
    pred_time = round(timer() - start_time, 5)
    return pred_labels_and_probs, pred_time
### 4. Gradio app setup ###
# Title, description, and article text shown by the Gradio interface.
title = "VisionBite 🍕🥩🍣"
description = (
    "A Vision Transformer (ViT-Base-16) model trained to classify images of food "
    "into 121 distinct categories. The model uses a transformer-based architecture "
    "to extract visual features and achieve accurate classification across diverse food items."
)
article = (
    "Model trained on the [Food121 dataset](https://huggingface.co/datasets/ItsNotRohit/Food121) "
    "with 95% top-5 prediction accuracy."
)

# File extensions (lowercase) accepted as example images.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")


def find_example_images(directory="examples", extensions=IMAGE_EXTENSIONS):
    """Return Gradio example rows for the image files found in *directory*.

    Args:
        directory: Folder to scan (non-recursively).
        extensions: Tuple of lowercase file suffixes to accept; matching is
            case-insensitive, so ``photo.JPG`` is included.

    Returns:
        A list of single-element lists ``[[path], ...]`` (one row per example,
        as ``gr.Interface(examples=...)`` expects), sorted by file name so the
        order is deterministic. Returns ``[]`` if *directory* is missing or is
        not a directory.
    """
    if not os.path.isdir(directory):
        return []
    rows = []
    for name in sorted(os.listdir(directory)):
        path = os.path.join(directory, name)
        if name.lower().endswith(extensions) and os.path.isfile(path):
            rows.append([path])
    return rows


# Setup example images (if available)
example_list = find_example_images("examples")
# Create Gradio interface: one PIL image in; the top-5 class probabilities
# and the measured prediction time out (matching predict()'s return pair).
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=5, label="Top Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    # NOTE: example images are collected into `example_list` but are not
    # shown; add `examples=example_list,` here to enable them.
    title=title,
    description=description,
    article=article,
)

# Launch the app only when this file is run as a script, so importing the
# module (e.g. to reuse `predict` or `demo`) does not start a web server.
if __name__ == "__main__":
    demo.launch()
|