amar6de2 committed on
Commit
01a5409
·
verified ·
1 Parent(s): 7a7cf02

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -97
app.py CHANGED
@@ -1,114 +1,65 @@
1
- ### 1. Imports and class names setup ###
2
- import gradio as gr
3
- import os
4
  import torch
5
- import numpy as np
6
- from PIL import Image
7
- from model import create_vit_model # Make sure this function exists in model.py
8
- from timeit import default_timer as timer
9
- from typing import Tuple, Dict
10
-
11
- # Setup class names (or hardcode them if needed)
12
- class_names = ["apple_pie", "baby_back_ribs", "baklava", "beef_carpaccio", "beef_tartare", "beet_salad",
13
- "beignets", "bibimbap", "biryani", "bread_pudding", "breakfast_burrito", "bruschetta",
14
- "caesar_salad", "cannoli", "caprese_salad", "carrot_cake", "ceviche", "chai", "chapati",
15
- "cheese_plate", "cheesecake", "chicken_curry", "chicken_quesadilla", "chicken_wings",
16
- "chocolate_cake", "chocolate_mousse", "chole_bhature", "churros", "clam_chowder",
17
- "club_sandwich", "crab_cakes", "creme_brulee", "croque_madame", "cup_cakes", "dabeli",
18
- "dal", "deviled_eggs", "dhokla", "donuts", "dosa", "dumplings", "edamame", "eggs_benedict",
19
- "escargots", "falafel", "filet_mignon", "fish_and_chips", "foie_gras", "french_fries",
20
- "french_onion_soup", "french_toast", "fried_calamari", "fried_rice", "frozen_yogurt",
21
- "garlic_bread", "gnocchi", "greek_salad", "grilled_cheese_sandwich", "grilled_salmon",
22
- "guacamole", "gyoza", "hamburger", "hot_and_sour_soup", "hot_dog", "huevos_rancheros",
23
- "hummus", "ice_cream", "idli", "jalebi", "kathi_rolls", "kofta", "kulfi", "lasagna",
24
- "lobster_bisque", "lobster_roll_sandwich", "macaroni_and_cheese", "macarons", "miso_soup",
25
- "momos", "mussels", "naan", "nachos", "omelette", "onion_rings", "oysters", "pad_thai",
26
- "paella", "pakoda", "pancakes", "pani_puri", "panna_cotta", "panner_butter_masala",
27
- "pav_bhaji", "peking_duck", "pho", "pizza", "pork_chop", "poutine", "prime_rib",
28
- "pulled_pork_sandwich", "ramen", "ravioli", "red_velvet_cake", "risotto", "samosa",
29
- "sashimi", "scallops", "seaweed_salad", "shrimp_and_grits", "spaghetti_bolognese",
30
- "spaghetti_carbonara", "spring_rolls", "steak", "strawberry_shortcake", "sushi",
31
- "tacos", "takoyaki", "tiramisu", "tuna_tartare", "vadapav", "waffles"]
32
-
33
- ### 2. Model and transforms setup ###
34
-
35
- # Create the model and transforms
36
- vit, vit_transforms = create_vit_model(num_classes=len(class_names))
37
-
38
- # Load saved model weights (assumes model is trained and .pth file is in the correct path)
39
- vit.load_state_dict(torch.load("vit_epoch_2.pth", map_location=torch.device("cpu")))
40
-
41
- ### 3. Prediction function ###
42
 
43
- def predict(img) -> Tuple[Dict[str, float], float]:
44
- """Transforms and performs a prediction on img and returns prediction and time taken."""
45
- from PIL import UnidentifiedImageError
46
-
47
- try:
48
- # Convert ndarray to PIL if needed
49
- if isinstance(img, np.ndarray):
50
- img = Image.fromarray(img)
51
-
52
- # Catch bad image input
53
- if img.mode != "RGB":
54
- img = img.convert("RGB")
55
 
56
- # Start timer
57
- start_time = timer()
 
58
 
59
- # Transform and add batch dimension
60
- img_tensor = vit_transforms(img).unsqueeze(0)
 
 
 
 
 
 
61
 
62
- # Inference
63
- vit.eval()
64
- with torch.inference_mode():
65
- pred_probs = torch.softmax(vit(img_tensor), dim=1)
66
 
67
- pred_labels_and_probs = {
68
- class_names[i]: float(pred_probs[0][i])
69
- for i in range(len(class_names))
70
- }
71
 
72
- pred_time = round(timer() - start_time, 5)
 
 
 
 
 
 
 
73
 
74
- return pred_labels_and_probs, pred_time
 
 
 
 
75
 
76
- except (UnidentifiedImageError, TypeError, ValueError) as e:
77
- return {"Error": f"Invalid image input: {str(e)}"}, 0.0
 
78
 
79
- ### 4. Gradio app setup ###
80
 
81
- # Title, description, and article text
82
- title = "VisionBite 🍕🥩🍣"
83
- description = (
84
- "A Vision Transformer (ViT-Base-16) model trained to classify images of food "
85
- "into 121 distinct categories. The model uses a transformer-based architecture "
86
- "to extract visual features and achieve accurate classification across diverse food items."
87
- )
88
- article = (
89
- "Model trained on the [Food121 dataset](https://huggingface.co/datasets/ItsNotRohit/Food121) "
90
- "with 95% top-5 prediction accuracy."
91
- )
92
 
93
- # Setup example images (if available)
94
- if os.path.exists("examples"):
95
- example_list = [["examples/" + f] for f in os.listdir("examples") if f.endswith((".jpg", ".jpeg", ".png"))]
96
- else:
97
- example_list = []
98
 
99
- # Create Gradio interface
100
  demo = gr.Interface(
101
  fn=predict,
102
- inputs=gr.Image(type="pil"),
103
- outputs=[
104
- gr.Label(num_top_classes=5, label="Top Predictions"),
105
- gr.Number(label="Prediction time (s)")
106
- ],
107
- examples=example_list,
108
- title=title,
109
- description=description,
110
- article=article
111
  )
112
 
113
- # Launch app
114
- demo.launch()
 
 
 
 
1
  import torch
2
+ import torch.nn.functional as F
3
+ import torchvision.transforms as transforms
4
+ from torchvision.models import vit_b_16
5
+ from torchvision.transforms import v2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
+ from PIL import Image
8
+ import gradio as gr
9
+ import os
 
 
 
 
 
 
 
 
 
10
 
11
# Load the pretrained ViT-B/16 through its weights enum rather than the bare
# 'DEFAULT' string: the enum member selects the same weights and also exposes
# the metadata (category names) used for `class_labels` below.
from torchvision.models import ViT_B_16_Weights

vit_weights = ViT_B_16_Weights.DEFAULT
model = vit_b_16(weights=vit_weights)
model.eval()  # inference only

# Preprocessing applied to every input image before it reaches the model:
# resize to 224x224, convert to a tensor image, scale to float32, then
# normalise each of the three channels with the mean/std below.
vit_transforms = v2.Compose([
    v2.Resize((224, 224)),
    v2.ToImage(),  # Ensure proper image type
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225]),
])

# Human-readable class names taken from the weights' metadata, replacing the
# previous "Class 0" ... "Class 999" placeholders. The list is indexed by the
# model's output position, exactly as the placeholders were.
class_labels = list(vit_weights.meta["categories"])
 
 
26
 
 
 
 
 
27
 
28
def predict(img):
    """Classify an image and return its five most probable classes.

    Uses the module-level ``model``, ``vit_transforms`` and ``class_labels``.

    Args:
        img: A ``PIL.Image.Image``, or a NumPy array that
            ``Image.fromarray`` can convert to one.

    Returns:
        A dict mapping class label to probability (a Python float) for the
        five highest-probability classes.

    Raises:
        ValueError: If ``img`` is ``None``, a ``torch.Tensor``, or any other
            type that is neither a PIL image nor a NumPy array.
    """
    # Imported locally because the ndarray check below needs numpy and the
    # visible file-level imports do not provide `np`; without this every
    # non-tensor input failed with NameError.
    import numpy as np

    if img is None:
        # Nothing to classify (e.g. the form was submitted without an image).
        raise ValueError("Input is not a valid PIL image")
    if isinstance(img, torch.Tensor):
        raise ValueError("Expected PIL.Image, got torch.Tensor.")
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    elif not isinstance(img, Image.Image):
        raise ValueError("Input is not a valid PIL image")

    # The normalisation in `vit_transforms` uses three-channel statistics, so
    # grayscale / RGBA / palette images are converted to RGB first.
    if img.mode != "RGB":
        img = img.convert("RGB")

    # Transform, add the batch dimension, and run a gradient-free forward pass.
    img_tensor = vit_transforms(img).unsqueeze(0)
    with torch.no_grad():
        outputs = model(img_tensor)
        probs = F.softmax(outputs[0], dim=0)

    # Convert to plain Python numbers once instead of indexing with 0-d tensors.
    top5 = torch.topk(probs, 5)
    return {
        class_labels[index]: probability
        for probability, index in zip(top5.values.tolist(), top5.indices.tolist())
    }
46
 
 
47
 
48
# Gradio wiring: one PIL image input, one top-5 label output, optional example
# images, and the Interface object that ties them to `predict`.
image_input = gr.Image(label="Upload JPEG Image", type="pil")
label_output = gr.Label(num_top_classes=5)

# Only offer the bundled samples that actually exist on disk.
example_images = list(
    filter(os.path.exists, ("images/sample1.jpg", "images/sample2.jpg"))
)

demo = gr.Interface(
    title="ViT Image Classifier",
    description="Upload a JPEG image to classify it using Vision Transformer (ViT-B16).",
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    examples=example_images,
)

# Start the server only when run as a script: listen on all interfaces,
# port 7860, with errors surfaced in the UI.
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_options)