"""Gradio app that classifies food photos into the Food-101 categories.

A ResNet-50 classifier (weights loaded from ``food_classifier_resnet50.pth``)
predicts the most likely dishes for an uploaded photo, and a static lookup
table supplies a rough per-serving calorie figure for each predicted dish.
"""

import os

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

# Approximate calories per serving for each class. CLASS_NAMES is derived
# from these keys, so they presumably must match the training labels exactly
# — TODO confirm against the training script.
CALORIE_DATA = {
    "apple_pie": 237,
    "baby_back_ribs": 290,
    "baklava": 334,
    "beef_carpaccio": 121,
    "beef_tartare": 200,
    "beet_salad": 70,
    "beignets": 350,
    "bibimbap": 490,
    "bread_pudding": 260,
    "breakfast_burrito": 305,
    "bruschetta": 120,
    "caesar_salad": 180,
    "cannoli": 220,
    "caprese_salad": 250,
    "carrot_cake": 300,
    "ceviche": 130,
    "cheese_plate": 350,
    "cheesecake": 320,
    "chicken_curry": 240,
    "chicken_quesadilla": 330,
    "chicken_wings": 290,
    "chocolate_cake": 370,
    "chocolate_mousse": 210,
    "churros": 230,
    "clam_chowder": 170,
    "club_sandwich": 350,
    "crab_cakes": 220,
    "creme_brulee": 260,
    "croque_madame": 450,
    "cup_cakes": 305,
    "deviled_eggs": 130,
    "donuts": 250,
    "dumplings": 210,
    "edamame": 120,
    "eggs_benedict": 290,
    "escargots": 170,
    "falafel": 330,
    "filet_mignon": 280,
    "fish_and_chips": 590,
    "foie_gras": 460,
    "french_fries": 365,
    "french_onion_soup": 210,
    "french_toast": 260,
    "fried_calamari": 310,
    "fried_rice": 230,
    "frozen_yogurt": 160,
    "garlic_bread": 200,
    "gnocchi": 250,
    "greek_salad": 130,
    "grilled_cheese_sandwich": 370,
    "grilled_salmon": 350,
    "guacamole": 150,
    "gyoza": 200,
    "hamburger": 354,
    "hot_and_sour_soup": 90,
    "hot_dog": 290,
    "huevos_rancheros": 360,
    "hummus": 170,
    "ice_cream": 210,
    "lasagna": 290,
    "lobster_bisque": 240,
    "lobster_roll_sandwich": 290,
    "macaroni_and_cheese": 310,
    "macarons": 100,
    "miso_soup": 40,
    "mussels": 170,
    "nachos": 340,
    "omelette": 150,
    "onion_rings": 330,
    "oysters": 60,
    "pad_thai": 360,
    "paella": 310,
    "pancakes": 230,
    "panna_cotta": 340,
    "peking_duck": 330,
    "pho": 350,
    "pizza": 270,
    "pork_chop": 230,
    "poutine": 510,
    "prime_rib": 350,
    "pulled_pork_sandwich": 390,
    "ramen": 380,
    "ravioli": 220,
    "red_velvet_cake": 360,
    "risotto": 340,
    "samosa": 260,
    "sashimi": 130,
    "scallops": 110,
    "seaweed_salad": 70,
    "shrimp_and_grits": 280,
    "spaghetti_bolognese": 370,
    "spaghetti_carbonara": 390,
    "spring_rolls": 150,
    "steak": 270,
    "strawberry_shortcake": 280,
    "sushi": 200,
    "tacos": 210,
    "takoyaki": 170,
    "tiramisu": 290,
    "tuna_tartare": 180,
    "waffles": 290,
}

# Output index i of the model maps to CLASS_NAMES[i]. Alphabetical order
# presumably mirrors the label order used in training (e.g. torchvision's
# ImageFolder sorts class folders) — TODO confirm.
CLASS_NAMES = sorted(CALORIE_DATA.keys())

# How many predictions to show; clamped to the number of classes at use.
TOP_K = 5

MODEL_FILENAME = "food_classifier_resnet50.pth"


def load_model(model_path=None):
    """Build the ResNet-50 classifier and load its trained weights.

    Args:
        model_path: Path to the saved ``state_dict``. Defaults to
            ``food_classifier_resnet50.pth`` in the same directory as this file.

    Returns:
        The model on the CPU, switched to evaluation mode.

    Raises:
        FileNotFoundError: If the weights file does not exist.
        RuntimeError: If the checkpoint does not match the architecture.
    """
    if model_path is None:
        model_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), MODEL_FILENAME
        )
    net = models.resnet50(weights=None)
    # The replacement head must mirror the one used in training exactly;
    # load_state_dict() is strict and rejects mismatched keys or shapes.
    net.fc = nn.Sequential(
        nn.Dropout(0.4),
        nn.Linear(net.fc.in_features, len(CLASS_NAMES)),
    )
    # The checkpoint is a plain state_dict, so weights_only=True is enough and
    # prevents a tampered file from executing arbitrary code while unpickling.
    state_dict = torch.load(
        model_path, map_location=torch.device("cpu"), weights_only=True
    )
    net.load_state_dict(state_dict)
    net.eval()
    return net


model = load_model()

# Resize, center-crop to 224x224, then normalize with the standard ImageNet
# channel statistics. Presumably identical to the evaluation transform used
# in training — TODO confirm.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def _display_name(class_name):
    """Turn a class id such as ``"hot_dog"`` into a label such as ``"Hot Dog"``."""
    return class_name.replace("_", " ").title()


def _predict(image, k):
    """Run the model on ``image`` and return its top-``k`` predictions.

    Returns a list of ``(class_name, probability)`` pairs, most likely first.
    """
    img = Image.fromarray(image).convert("RGB")
    batch = transform(img).unsqueeze(0)  # model expects a batch dimension
    with torch.no_grad():
        logits = model(batch)
    probs = torch.nn.functional.softmax(logits, dim=1)[0]
    top_probs, top_indices = torch.topk(probs, k)
    return [
        (CLASS_NAMES[idx], prob)
        for idx, prob in zip(top_indices.tolist(), top_probs.tolist())
    ]


def _format_details(predictions):
    """Render the Markdown side panel: best guess plus a table of the rest."""
    (top_class, top_prob), *others = predictions
    top_cal = CALORIE_DATA.get(top_class, "N/A")
    parts = [
        f"## {_display_name(top_class)}\n",
        f"**Confidence:** {top_prob * 100:.1f}%\n\n",
        f"**Estimated Calories:** ~{top_cal} kcal per serving\n\n",
    ]
    if others:
        parts.append("---\n\n")
        parts.append("**Other possibilities:**\n\n")
        # Markdown (GFM) only renders a table when a header row and a
        # delimiter row come first; otherwise the rows show as plain text.
        parts.append("| Food | Confidence | Calories |\n")
        parts.append("|---|---|---|\n")
        for cls, prob in others:
            cal = CALORIE_DATA.get(cls, "N/A")
            parts.append(
                f"| {_display_name(cls)} | {prob * 100:.1f}% | ~{cal} kcal |\n"
            )
    return "".join(parts)


def classify_food(image):
    """Classify a food photo and estimate its calories.

    Args:
        image: Image array from ``gr.Image(type="numpy")``, or ``None`` when
            the input has been cleared.

    Returns:
        A ``(confidences, details)`` tuple. ``confidences`` maps display names
        to probabilities for ``gr.Label``; ``details`` is Markdown text. Both
        are empty when ``image`` is ``None``.
    """
    if image is None:
        return {}, ""
    predictions = _predict(image, min(TOP_K, len(CLASS_NAMES)))
    confidences = {_display_name(cls): prob for cls, prob in predictions}
    return confidences, _format_details(predictions)


custom_css = """
.gradio-container { max-width: 900px !important; margin: auto !important; }
h1 { text-align: center; margin-bottom: 0.2em; }
.description { text-align: center; }
"""

theme = gr.themes.Soft(
    primary_hue="orange",
    secondary_hue="amber",
    neutral_hue="gray",
    font=gr.themes.GoogleFont("Inter"),
)

with gr.Blocks(theme=theme, css=custom_css, title="Food Image Classifier") as demo:
    gr.Markdown("# Food Image Classifier")
    gr.Markdown(
        f"Upload a photo of any food — the model identifies it from "
        f"**{len(CLASS_NAMES)} categories** and estimates calories.",
        elem_classes="description",
    )
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            image_input = gr.Image(label="Upload Food Photo", type="numpy", height=350)
            classify_btn = gr.Button("Classify", variant="primary", size="lg")
        with gr.Column(scale=1):
            label_output = gr.Label(num_top_classes=TOP_K, label="Top 5 Predictions")
            calorie_output = gr.Markdown(label="Details")

    # Classify on explicit click and also automatically whenever the image
    # changes (including when it is cleared, which resets the outputs).
    classify_btn.click(
        fn=classify_food, inputs=image_input, outputs=[label_output, calorie_output]
    )
    image_input.change(
        fn=classify_food, inputs=image_input, outputs=[label_output, calorie_output]
    )

    gr.Markdown("---")
    # NOTE(review): sanitize_html=False implies this footer originally held
    # HTML (likely a link to the project's GitHub repository). That markup and
    # URL are not recoverable from this copy of the file — restore them.
    gr.Markdown(
        "Trained on Food-101 dataset (101K images) — "
        "GitHub",
        sanitize_html=False,
    )

if __name__ == "__main__":
    demo.launch()