Spaces:
Build error
Build error
| from torchvision import transforms | |
| from PIL import Image | |
| import gradio as gr | |
| from transformers import AutoModelForImageClassification | |
| import torch | |
| import heapq | |
# Hugging Face Hub repository of the image-classification checkpoint
# (presumably fine-tuned on Food-101, judging by its repo name).
_MODEL_REPO = "Mullerjo/food-101-finetuned-model"

# Load the checkpoint once at import time (downloaded from the Hub or read
# from the local cache) so every request reuses the same model instance.
model = AutoModelForImageClassification.from_pretrained(_MODEL_REPO)
# Per-channel (R, G, B) normalisation statistics; these are the standard
# ImageNet values.
# NOTE(review): confirm they match the mean/std the checkpoint was fine-tuned
# with -- its image-processor config may use different values.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# PIL image -> normalised tensor: scale the shorter side to 256 px, take the
# central 224x224 crop, convert to a float tensor, then normalise each channel.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
# The 101 Food-101 class names, in class-id order (position == id).
_FOOD101_LABELS = (
    # 0-9
    "apple_pie", "baby_back_ribs", "baklava", "beef_carpaccio", "beef_tartare",
    "beet_salad", "beignets", "bibimbap", "bread_pudding", "breakfast_burrito",
    # 10-19
    "bruschetta", "caesar_salad", "cannoli", "caprese_salad", "carrot_cake",
    "ceviche", "cheesecake", "cheese_plate", "chicken_curry", "chicken_quesadilla",
    # 20-29
    "chicken_wings", "chocolate_cake", "chocolate_mousse", "churros", "clam_chowder",
    "club_sandwich", "crab_cakes", "creme_brulee", "croque_madame", "cup_cakes",
    # 30-39
    "deviled_eggs", "donuts", "dumplings", "edamame", "eggs_benedict",
    "escargots", "falafel", "filet_mignon", "fish_and_chips", "foie_gras",
    # 40-49
    "french_fries", "french_onion_soup", "french_toast", "fried_calamari", "fried_rice",
    "frozen_yogurt", "garlic_bread", "gnocchi", "greek_salad", "grilled_cheese_sandwich",
    # 50-59
    "grilled_salmon", "guacamole", "gyoza", "hamburger", "hot_and_sour_soup",
    "hot_dog", "huevos_rancheros", "hummus", "ice_cream", "lasagna",
    # 60-69
    "lobster_bisque", "lobster_roll_sandwich", "macaroni_and_cheese", "macarons", "miso_soup",
    "mussels", "nachos", "omelette", "onion_rings", "oysters",
    # 70-79
    "pad_thai", "paella", "pancakes", "panna_cotta", "peking_duck",
    "pho", "pizza", "pork_chop", "poutine", "prime_rib",
    # 80-89
    "pulled_pork_sandwich", "ramen", "ravioli", "red_velvet_cake", "risotto",
    "samosa", "sashimi", "scallops", "seaweed_salad", "shrimp_and_grits",
    # 90-100
    "spaghetti_bolognese", "spaghetti_carbonara", "spring_rolls", "steak", "strawberry_shortcake",
    "sushi", "tacos", "takoyaki", "tiramisu", "tuna_tartare",
    "waffles",
)

# Class name -> integer class id.
label_to_class = {name: index for index, name in enumerate(_FOOD101_LABELS)}
# Inverse lookup: integer class id -> class name.
class_to_label = {index: name for name, index in label_to_class.items()}
# Index -> label mapping stored in the checkpoint's config.
class_labels = model.config.id2label

# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Move the weights as well: choosing a device only helps if the model lives
# there, and a CPU-resident model fed a CUDA tensor fails with a
# device-mismatch RuntimeError.
model.to(device)
def _label_for(idx):
    """Return the Food-101 class name for model output index *idx*."""
    name = class_labels.get(idx) if class_labels else None
    # The checkpoint config may carry real class names or only placeholder
    # names (e.g. "LABEL_0"); trust it only when it yields a known Food-101
    # name, otherwise fall back to this app's own id -> name table.
    if name in label_to_class:
        return name
    return class_to_label[idx]


def classify_image(img):
    """Classify a food photo and describe the three most likely dishes.

    Args:
        img: Image from the Gradio ``Image(type="numpy")`` input as a NumPy
            array, or ``None`` when the user submits without an image.

    Returns:
        A sentence naming the top prediction and two runner-up classes, or a
        prompt to upload an image when *img* is ``None``.
    """
    if img is None:
        return "Please upload an image of food first."

    # Force three channels: RGBA (e.g. PNG with alpha) or grayscale uploads
    # would otherwise fail the 3-channel Normalize step in `preprocess`.
    pil_img = Image.fromarray(img).convert("RGB")
    batch = preprocess(pil_img).unsqueeze(0)  # add batch dimension
    # Put the input wherever the model's weights are, so model and input
    # always agree on device.
    batch = batch.to(model.device)

    with torch.no_grad():
        logits = model(batch).logits[0]  # drop the batch dimension

    top3_indices = torch.topk(logits, k=3).indices.tolist()
    words = [_label_for(idx) for idx in top3_indices]
    return (
        f"The most likely food is {words[0]}! If that seems unlikely it could "
        f"also be {words[1]} or {words[2]} :)"
    )
# Web UI: the user uploads a photo (passed to the classifier as a NumPy
# array) and gets the prediction sentence back in a text box.
_image_input = gr.Image(type="numpy")
_text_output = gr.Textbox(label="Output")

iface = gr.Interface(
    fn=classify_image,
    inputs=_image_input,
    outputs=_text_output,
    title="Food Image Classifier",
    description="Upload an image of food and get the predicted category.",
)
iface.launch()