amar6de2 commited on
Commit
6778ab8
·
verified ·
1 Parent(s): 01a5409

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -48
app.py CHANGED
@@ -1,65 +1,114 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import torchvision.transforms as transforms
4
- from torchvision.models import vit_b_16
5
- from torchvision.transforms import v2
6
-
7
- from PIL import Image
8
  import gradio as gr
9
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- # Load pretrained model
12
- model = vit_b_16(weights='DEFAULT')
13
- model.eval()
14
 
15
- # Transformation for ViT
16
- vit_transforms = v2.Compose([
17
- v2.Resize((224, 224)),
18
- v2.ToImage(), # Ensure proper image type
19
- v2.ToDtype(torch.float32, scale=True),
20
- v2.Normalize(mean=[0.485, 0.456, 0.406],
21
- std=[0.229, 0.224, 0.225]),
22
- ])
23
 
24
- # Class labels (example)
25
- class_labels = [f"Class {i}" for i in range(1000)] # Replace with actual class names if you have them
 
26
 
 
 
 
 
27
 
28
- def predict(img):
29
- # Defensive: Ensure image is PIL
30
- if isinstance(img, torch.Tensor):
31
- raise ValueError("Expected PIL.Image, got torch.Tensor.")
32
- elif isinstance(img, np.ndarray):
33
- img = Image.fromarray(img)
34
- elif not isinstance(img, Image.Image):
35
- raise ValueError("Input is not a valid PIL image")
36
 
37
- # Transform and run through model
38
- img_tensor = vit_transforms(img).unsqueeze(0)
39
- with torch.no_grad():
40
- outputs = model(img_tensor)
41
- probs = F.softmax(outputs[0], dim=0)
42
 
43
- top5 = torch.topk(probs, 5)
44
- results = {class_labels[i]: float(probs[i]) for i in top5.indices}
45
- return results
46
 
 
 
 
 
47
 
48
- # Set up Gradio interface
49
- image_input = gr.Image(type="pil", label="Upload JPEG Image")
50
- label_output = gr.Label(num_top_classes=5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- example_images = ["images/sample1.jpg", "images/sample2.jpg"]
53
- example_images = [img for img in example_images if os.path.exists(img)] # filter missing files
 
 
 
54
 
 
55
  demo = gr.Interface(
56
  fn=predict,
57
- inputs=image_input,
58
- outputs=label_output,
59
- examples=example_images,
60
- title="ViT Image Classifier",
61
- description="Upload a JPEG image to classify it using Vision Transformer (ViT-B16)."
 
 
 
 
62
  )
63
 
64
- if __name__ == "__main__":
65
- demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
 
1
+ ### 1. Imports and class names setup ###
 
 
 
 
 
 
2
  import gradio as gr
3
  import os
4
+ import torch
5
+ import numpy as np
6
+ from PIL import Image
7
+ from model import create_vit_model # Make sure this function exists in model.py
8
+ from timeit import default_timer as timer
9
+ from typing import Tuple, Dict
10
+
11
+ # Setup class names (or hardcode them if needed)
12
+ class_names = ["apple_pie", "baby_back_ribs", "baklava", "beef_carpaccio", "beef_tartare", "beet_salad",
13
+ "beignets", "bibimbap", "biryani", "bread_pudding", "breakfast_burrito", "bruschetta",
14
+ "caesar_salad", "cannoli", "caprese_salad", "carrot_cake", "ceviche", "chai", "chapati",
15
+ "cheese_plate", "cheesecake", "chicken_curry", "chicken_quesadilla", "chicken_wings",
16
+ "chocolate_cake", "chocolate_mousse", "chole_bhature", "churros", "clam_chowder",
17
+ "club_sandwich", "crab_cakes", "creme_brulee", "croque_madame", "cup_cakes", "dabeli",
18
+ "dal", "deviled_eggs", "dhokla", "donuts", "dosa", "dumplings", "edamame", "eggs_benedict",
19
+ "escargots", "falafel", "filet_mignon", "fish_and_chips", "foie_gras", "french_fries",
20
+ "french_onion_soup", "french_toast", "fried_calamari", "fried_rice", "frozen_yogurt",
21
+ "garlic_bread", "gnocchi", "greek_salad", "grilled_cheese_sandwich", "grilled_salmon",
22
+ "guacamole", "gyoza", "hamburger", "hot_and_sour_soup", "hot_dog", "huevos_rancheros",
23
+ "hummus", "ice_cream", "idli", "jalebi", "kathi_rolls", "kofta", "kulfi", "lasagna",
24
+ "lobster_bisque", "lobster_roll_sandwich", "macaroni_and_cheese", "macarons", "miso_soup",
25
+ "momos", "mussels", "naan", "nachos", "omelette", "onion_rings", "oysters", "pad_thai",
26
+ "paella", "pakoda", "pancakes", "pani_puri", "panna_cotta", "panner_butter_masala",
27
+ "pav_bhaji", "peking_duck", "pho", "pizza", "pork_chop", "poutine", "prime_rib",
28
+ "pulled_pork_sandwich", "ramen", "ravioli", "red_velvet_cake", "risotto", "samosa",
29
+ "sashimi", "scallops", "seaweed_salad", "shrimp_and_grits", "spaghetti_bolognese",
30
+ "spaghetti_carbonara", "spring_rolls", "steak", "strawberry_shortcake", "sushi",
31
+ "tacos", "takoyaki", "tiramisu", "tuna_tartare", "vadapav", "waffles"]
32
+
33
+ ### 2. Model and transforms setup ###
34
+
35
+ # Create the model and transforms
36
+ vit, vit_transforms = create_vit_model(num_classes=len(class_names))
37
 
38
+ # Load saved model weights (assumes model is trained and .pth file is in the correct path)
39
+ vit.load_state_dict(torch.load("vit_epoch_2.pth", map_location=torch.device("cpu")))
 
40
 
41
+ ### 3. Prediction function ###
 
 
 
 
 
 
 
42
 
43
+ def predict(img) -> Tuple[Dict[str, float], float]:
44
+ """Transforms and performs a prediction on img and returns prediction and time taken."""
45
+ from PIL import UnidentifiedImageError
46
 
47
+ try:
48
+ # Convert ndarray to PIL if needed
49
+ if isinstance(img, np.ndarray):
50
+ img = Image.fromarray(img)
51
 
52
+ # Catch bad image input
53
+ if img.mode != "RGB":
54
+ img = img.convert("RGB")
 
 
 
 
 
55
 
56
+ # Start timer
57
+ start_time = timer()
 
 
 
58
 
59
+ # Transform and add batch dimension
60
+ img_tensor = vit_transforms(img).unsqueeze(0)
 
61
 
62
+ # Inference
63
+ vit.eval()
64
+ with torch.inference_mode():
65
+ pred_probs = torch.softmax(vit(img_tensor), dim=1)
66
 
67
+ pred_labels_and_probs = {
68
+ class_names[i]: float(pred_probs[0][i])
69
+ for i in range(len(class_names))
70
+ }
71
+
72
+ pred_time = round(timer() - start_time, 5)
73
+
74
+ return pred_labels_and_probs, pred_time
75
+
76
+ except (UnidentifiedImageError, TypeError, ValueError) as e:
77
+ return {"Error": f"Invalid image input: {str(e)}"}, 0.0
78
+
79
+ ### 4. Gradio app setup ###
80
+
81
+ # Title, description, and article text
82
+ title = "VisionBite 🍕🥩🍣"
83
+ description = (
84
+ "A Vision Transformer (ViT-Base-16) model trained to classify images of food "
85
+ "into 121 distinct categories. The model uses a transformer-based architecture "
86
+ "to extract visual features and achieve accurate classification across diverse food items."
87
+ )
88
+ article = (
89
+ "Model trained on the [Food121 dataset](https://huggingface.co/datasets/ItsNotRohit/Food121) "
90
+ "with 95% top-5 prediction accuracy."
91
+ )
92
 
93
+ # Setup example images (if available)
94
+ if os.path.exists("examples"):
95
+ example_list = [["examples/" + f] for f in os.listdir("examples") if f.endswith((".jpg", ".jpeg", ".png"))]
96
+ else:
97
+ example_list = []
98
 
99
+ # Create Gradio interface
100
  demo = gr.Interface(
101
  fn=predict,
102
+ inputs=gr.Image(type="pil"),
103
+ outputs=[
104
+ gr.Label(num_top_classes=5, label="Top Predictions"),
105
+ gr.Number(label="Prediction time (s)")
106
+ ],
107
+ examples=example_list,
108
+ title=title,
109
+ description=description,
110
+ article=article
111
  )
112
 
113
+ # Launch app
114
+ demo.launch()