File size: 4,816 Bytes
6778ab8
01a5409
 
6778ab8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f0b74a
6778ab8
 
5f0b74a
6778ab8
5f0b74a
6778ab8
 
5f0b74a
6778ab8
dd1e270
6778ab8
dd1e270
7a7cf02
dd1e270
6778ab8
 
7a7cf02
6778ab8
7a7cf02
dd1e270
6778ab8
5f0b74a
6778ab8
 
 
5f0b74a
6778ab8
 
 
 
 
 
 
 
 
 
 
 
dd1e270
6778ab8
 
 
 
 
 
 
 
 
 
 
 
 
5f0b74a
6778ab8
 
 
 
 
5f0b74a
6778ab8
8a7643b
 
6778ab8
 
 
 
 
1a55275
6778ab8
 
 
8a7643b
4d37213
6778ab8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
### 1. Imports and class names setup ###
import gradio as gr
import os
import torch
import numpy as np
from PIL import Image
from model import create_vit_model  # Make sure this function exists in model.py
from timeit import default_timer as timer
from typing import Tuple, Dict

# Food121 label vocabulary, in classifier output order: predict() maps
# output index i to class_names[i], so this ordering must match the order
# the checkpoint was trained with (presumably alphabetical, as listed here).
# Written as one whitespace-separated literal and split into a list.
class_names = """
    apple_pie baby_back_ribs baklava beef_carpaccio beef_tartare beet_salad
    beignets bibimbap biryani bread_pudding breakfast_burrito bruschetta
    caesar_salad cannoli caprese_salad carrot_cake ceviche chai chapati
    cheese_plate cheesecake chicken_curry chicken_quesadilla chicken_wings
    chocolate_cake chocolate_mousse chole_bhature churros clam_chowder
    club_sandwich crab_cakes creme_brulee croque_madame cup_cakes dabeli
    dal deviled_eggs dhokla donuts dosa dumplings edamame eggs_benedict
    escargots falafel filet_mignon fish_and_chips foie_gras french_fries
    french_onion_soup french_toast fried_calamari fried_rice frozen_yogurt
    garlic_bread gnocchi greek_salad grilled_cheese_sandwich grilled_salmon
    guacamole gyoza hamburger hot_and_sour_soup hot_dog huevos_rancheros
    hummus ice_cream idli jalebi kathi_rolls kofta kulfi lasagna
    lobster_bisque lobster_roll_sandwich macaroni_and_cheese macarons miso_soup
    momos mussels naan nachos omelette onion_rings oysters pad_thai
    paella pakoda pancakes pani_puri panna_cotta panner_butter_masala
    pav_bhaji peking_duck pho pizza pork_chop poutine prime_rib
    pulled_pork_sandwich ramen ravioli red_velvet_cake risotto samosa
    sashimi scallops seaweed_salad shrimp_and_grits spaghetti_bolognese
    spaghetti_carbonara spring_rolls steak strawberry_shortcake sushi
    tacos takoyaki tiramisu tuna_tartare vadapav waffles
""".split()

### 2. Model and transforms setup ###

# Build the ViT classifier and its matching preprocessing transforms via
# create_vit_model (from model.py), with one output per entry in class_names.
vit, vit_transforms = create_vit_model(num_classes=len(class_names))

# Load the trained weights onto the CPU. The checkpoint is consumed as a
# state_dict, so weights_only=True is sufficient and makes torch refuse to
# unpickle arbitrary Python objects from the file instead of executing them.
vit.load_state_dict(
    torch.load(
        "vit_epoch_2.pth",
        map_location=torch.device("cpu"),
        weights_only=True,
    )
)

# Put the model in evaluation mode once at startup (e.g. disables dropout).
vit.eval()

### 3. Prediction function ###

def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify a food image with the ViT model.

    Args:
        img: A ``PIL.Image.Image`` (what ``gr.Image(type="pil")`` delivers)
            or a NumPy array, which is cast to uint8 and wrapped as a PIL
            image. Non-RGB images are converted to RGB.

    Returns:
        A tuple ``(label_to_prob, seconds)``. ``label_to_prob`` maps every
        class name to its softmax probability. ``seconds`` is the wall-clock
        time for transforms + inference, rounded to 5 decimals.

    Raises:
        gr.Error: If no image was supplied or it cannot be preprocessed.
            Gradio shows the message to the user. This replaces returning an
            error string inside the ``gr.Label`` dict, which expects floats.
    """
    from PIL import UnidentifiedImageError

    # gr.Image hands over None when the user submits without an upload.
    if img is None:
        raise gr.Error("No image provided. Please upload an image.")

    # Only preprocessing is guarded: failures here mean the input is bad,
    # whereas errors from the model itself should not be reported that way.
    try:
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img.astype("uint8"))  # PIL needs uint8 here

        if img.mode != "RGB":
            img = img.convert("RGB")

        start_time = timer()
        img_tensor = vit_transforms(img).unsqueeze(0)  # add a batch dimension
    except (UnidentifiedImageError, AttributeError, TypeError, ValueError) as e:
        raise gr.Error(f"Invalid image input: {str(e)}") from e

    vit.eval()
    with torch.inference_mode():
        pred_probs = torch.softmax(vit(img_tensor), dim=1)

    # Output index i corresponds to class_names[i]; .tolist() converts the
    # whole row to Python floats in one transfer.
    pred_labels_and_probs = dict(zip(class_names, pred_probs[0].tolist()))

    pred_time = round(timer() - start_time, 5)

    return pred_labels_and_probs, pred_time


### 4. Gradio app setup ###

# Title, description, and article text shown in the Gradio UI
title = "VisionBite 🍕🥩🍣"
description = (
    "A Vision Transformer (ViT-Base-16) model trained to classify images of food "
    f"into {len(class_names)} distinct categories. The model uses a transformer-based architecture "
    "to extract visual features and achieve accurate classification across diverse food items."
)
article = (
    "Model trained on the [Food121 dataset](https://huggingface.co/datasets/ItsNotRohit/Food121) "
    "with 95% top-5 prediction accuracy."
)

# Collect example images if an "examples" directory ships with the app.
# Sorted so the examples appear in a stable order, and the extension check
# is case-insensitive so files like "pizza.JPG" are included too.
if os.path.isdir("examples"):
    example_list = [
        [os.path.join("examples", f)]
        for f in sorted(os.listdir("examples"))
        if f.lower().endswith((".jpg", ".jpeg", ".png"))
    ]
else:
    example_list = []

# Create Gradio interface
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=5, label="Top Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    # Pass None when no examples were found so Gradio uses its default
    # (no examples) rather than receiving an empty list.
    examples=example_list or None,
    title=title,
    description=description,
    article=article,
)

# Launch app
demo.launch()