amar6de2 committed on
Commit
5f0b74a
·
1 Parent(s): b289894
Files changed (1) hide show
  1. app.py +96 -40
app.py CHANGED
@@ -1,49 +1,105 @@
 
1
  import gradio as gr
 
2
  import torch
3
  import numpy as np
4
  from PIL import Image
5
- from model import create_vit_model
6
-
7
- # Load model and transforms
8
- vit, vit_transforms = create_vit_model()
9
- vit.eval()
10
-
11
- # Define class names (replace with your actual class names)
12
- class_names = [f"class_{i}" for i in range(121)] # Replace with real labels if available
13
-
14
- # Prediction function
15
- def predict(img):
16
- try:
17
- # 🛡️ Ensure valid PIL Image
18
- if isinstance(img, np.ndarray):
19
- img = Image.fromarray(img)
20
-
21
- if img is None or not isinstance(img, Image.Image):
22
- raise ValueError("Invalid image input or format.")
23
-
24
- # ✅ Convert to RGB to avoid dtype errors
25
- if img.mode != "RGB":
26
- img = img.convert("RGB")
27
-
28
- # 📦 Transform and predict
29
- img_tensor = vit_transforms(img).unsqueeze(0)
30
- with torch.no_grad():
31
- preds = vit(img_tensor)
32
- probs = torch.softmax(preds, dim=1)[0]
33
- top5 = torch.topk(probs, k=5)
34
- results = {class_names[idx]: float(probs[idx]) for idx in top5.indices}
35
- return results
36
- except Exception as e:
37
- raise RuntimeError(f"Prediction failed: {str(e)}") from e
38
-
39
- # Gradio interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  demo = gr.Interface(
41
  fn=predict,
42
  inputs=gr.Image(type="pil"),
43
- outputs=gr.Label(num_top_classes=5),
44
- title="ViT Image Classifier",
45
- description="Upload an image to classify using Vision Transformer (ViT)."
 
 
 
 
 
46
  )
47
 
48
- if __name__ == "__main__":
49
- demo.launch()
 
1
+ ### 1. Imports and class names setup ###
2
  import gradio as gr
3
+ import os
4
  import torch
5
  import numpy as np
6
  from PIL import Image
7
+ from model import create_vit_model # Make sure this function exists in model.py
8
+ from timeit import default_timer as timer
9
+ from typing import Tuple, Dict
10
+
11
+ # Setup class names (or hardcode them if needed)
12
+ class_names = ["apple_pie", "baby_back_ribs", "baklava", "beef_carpaccio", "beef_tartare", "beet_salad",
13
+ "beignets", "bibimbap", "biryani", "bread_pudding", "breakfast_burrito", "bruschetta",
14
+ "caesar_salad", "cannoli", "caprese_salad", "carrot_cake", "ceviche", "chai", "chapati",
15
+ "cheese_plate", "cheesecake", "chicken_curry", "chicken_quesadilla", "chicken_wings",
16
+ "chocolate_cake", "chocolate_mousse", "chole_bhature", "churros", "clam_chowder",
17
+ "club_sandwich", "crab_cakes", "creme_brulee", "croque_madame", "cup_cakes", "dabeli",
18
+ "dal", "deviled_eggs", "dhokla", "donuts", "dosa", "dumplings", "edamame", "eggs_benedict",
19
+ "escargots", "falafel", "filet_mignon", "fish_and_chips", "foie_gras", "french_fries",
20
+ "french_onion_soup", "french_toast", "fried_calamari", "fried_rice", "frozen_yogurt",
21
+ "garlic_bread", "gnocchi", "greek_salad", "grilled_cheese_sandwich", "grilled_salmon",
22
+ "guacamole", "gyoza", "hamburger", "hot_and_sour_soup", "hot_dog", "huevos_rancheros",
23
+ "hummus", "ice_cream", "idli", "jalebi", "kathi_rolls", "kofta", "kulfi", "lasagna",
24
+ "lobster_bisque", "lobster_roll_sandwich", "macaroni_and_cheese", "macarons", "miso_soup",
25
+ "momos", "mussels", "naan", "nachos", "omelette", "onion_rings", "oysters", "pad_thai",
26
+ "paella", "pakoda", "pancakes", "pani_puri", "panna_cotta", "panner_butter_masala",
27
+ "pav_bhaji", "peking_duck", "pho", "pizza", "pork_chop", "poutine", "prime_rib",
28
+ "pulled_pork_sandwich", "ramen", "ravioli", "red_velvet_cake", "risotto", "samosa",
29
+ "sashimi", "scallops", "seaweed_salad", "shrimp_and_grits", "spaghetti_bolognese",
30
+ "spaghetti_carbonara", "spring_rolls", "steak", "strawberry_shortcake", "sushi",
31
+ "tacos", "takoyaki", "tiramisu", "tuna_tartare", "vadapav", "waffles"]
32
+
33
+ ### 2. Model and transforms setup ###
34
+
35
+ # Create the model and transforms
36
+ vit, vit_transforms = create_vit_model(num_classes=len(class_names))
37
+
38
+ # Load saved model weights (assumes model is trained and .pth file is in the correct path)
39
+ vit.load_state_dict(torch.load("vit_epoch_2.pth", map_location=torch.device("cpu")))
40
+
41
+ ### 3. Prediction function ###
42
+
43
+ def predict(img) -> Tuple[Dict[str, float], float]:
44
+ """Transforms and performs a prediction on img and returns prediction and time taken."""
45
+ # Ensure the image is a PIL image
46
+ if isinstance(img, np.ndarray):
47
+ img = Image.fromarray(img)
48
+
49
+ # Start the timer
50
+ start_time = timer()
51
+
52
+ # Transform the image and add batch dimension
53
+ img = vit_transforms(img).unsqueeze(0)
54
+
55
+ # Run inference
56
+ vit.eval()
57
+ with torch.inference_mode():
58
+ pred_probs = torch.softmax(vit(img), dim=1)
59
+
60
+ # Create label and probability dict
61
+ pred_labels_and_probs = {
62
+ class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))
63
+ }
64
+
65
+ # Calculate prediction time
66
+ pred_time = round(timer() - start_time, 5)
67
+
68
+ return pred_labels_and_probs, pred_time
69
+
70
+ ### 4. Gradio app setup ###
71
+
72
+ # Title, description, and article text
73
+ title = "VisionBite 🍕🥩🍣"
74
+ description = (
75
+ "A Vision Transformer (ViT-Base-16) model trained to classify images of food "
76
+ "into 121 distinct categories. The model uses a transformer-based architecture "
77
+ "to extract visual features and achieve accurate classification across diverse food items."
78
+ )
79
+ article = (
80
+ "Model trained on the [Food121 dataset](https://huggingface.co/datasets/ItsNotRohit/Food121) "
81
+ "with 95% top-5 prediction accuracy."
82
+ )
83
+
84
+ # Setup example images (if available)
85
+ if os.path.exists("examples"):
86
+ example_list = [["examples/" + f] for f in os.listdir("examples") if f.endswith((".jpg", ".jpeg", ".png"))]
87
+ else:
88
+ example_list = []
89
+
90
+ # Create Gradio interface
91
  demo = gr.Interface(
92
  fn=predict,
93
  inputs=gr.Image(type="pil"),
94
+ outputs=[
95
+ gr.Label(num_top_classes=5, label="Top Predictions"),
96
+ gr.Number(label="Prediction time (s)")
97
+ ],
98
+ examples=example_list,
99
+ title=title,
100
+ description=description,
101
+ article=article
102
  )
103
 
104
+ # Launch app
105
+ demo.launch()