Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| import os | |
# Approximate calories (kcal) per serving for each Food-101 category, keyed by
# the dataset's snake_case class name. One entry per line, alphabetical.
CALORIE_DATA = {
    "apple_pie": 237,
    "baby_back_ribs": 290,
    "baklava": 334,
    "beef_carpaccio": 121,
    "beef_tartare": 200,
    "beet_salad": 70,
    "beignets": 350,
    "bibimbap": 490,
    "bread_pudding": 260,
    "breakfast_burrito": 305,
    "bruschetta": 120,
    "caesar_salad": 180,
    "cannoli": 220,
    "caprese_salad": 250,
    "carrot_cake": 300,
    "ceviche": 130,
    "cheese_plate": 350,
    "cheesecake": 320,
    "chicken_curry": 240,
    "chicken_quesadilla": 330,
    "chicken_wings": 290,
    "chocolate_cake": 370,
    "chocolate_mousse": 210,
    "churros": 230,
    "clam_chowder": 170,
    "club_sandwich": 350,
    "crab_cakes": 220,
    "creme_brulee": 260,
    "croque_madame": 450,
    "cup_cakes": 305,
    "deviled_eggs": 130,
    "donuts": 250,
    "dumplings": 210,
    "edamame": 120,
    "eggs_benedict": 290,
    "escargots": 170,
    "falafel": 330,
    "filet_mignon": 280,
    "fish_and_chips": 590,
    "foie_gras": 460,
    "french_fries": 365,
    "french_onion_soup": 210,
    "french_toast": 260,
    "fried_calamari": 310,
    "fried_rice": 230,
    "frozen_yogurt": 160,
    "garlic_bread": 200,
    "gnocchi": 250,
    "greek_salad": 130,
    "grilled_cheese_sandwich": 370,
    "grilled_salmon": 350,
    "guacamole": 150,
    "gyoza": 200,
    "hamburger": 354,
    "hot_and_sour_soup": 90,
    "hot_dog": 290,
    "huevos_rancheros": 360,
    "hummus": 170,
    "ice_cream": 210,
    "lasagna": 290,
    "lobster_bisque": 240,
    "lobster_roll_sandwich": 290,
    "macaroni_and_cheese": 310,
    "macarons": 100,
    "miso_soup": 40,
    "mussels": 170,
    "nachos": 340,
    "omelette": 150,
    "onion_rings": 330,
    "oysters": 60,
    "pad_thai": 360,
    "paella": 310,
    "pancakes": 230,
    "panna_cotta": 340,
    "peking_duck": 330,
    "pho": 350,
    "pizza": 270,
    "pork_chop": 230,
    "poutine": 510,
    "prime_rib": 350,
    "pulled_pork_sandwich": 390,
    "ramen": 380,
    "ravioli": 220,
    "red_velvet_cake": 360,
    "risotto": 340,
    "samosa": 260,
    "sashimi": 130,
    "scallops": 110,
    "seaweed_salad": 70,
    "shrimp_and_grits": 280,
    "spaghetti_bolognese": 370,
    "spaghetti_carbonara": 390,
    "spring_rolls": 150,
    "steak": 270,
    "strawberry_shortcake": 280,
    "sushi": 200,
    "tacos": 210,
    "takoyaki": 170,
    "tiramisu": 290,
    "tuna_tartare": 180,
    "waffles": 290,
}

# Class labels in lexicographic order. The model's output index i is mapped to
# CLASS_NAMES[i], so this order must match the label order the checkpoint was
# trained with (presumably alphabetical, as in Food-101 — verify against the
# training code).
CLASS_NAMES = sorted(CALORIE_DATA)
def load_model(model_path=None):
    """Build the ResNet-50 food classifier and load its trained weights.

    The architecture must mirror the checkpoint. It is a ResNet-50 backbone
    created without pretrained weights. Its final ``fc`` layer is replaced by
    ``Dropout(0.4)`` followed by a ``Linear`` layer with one output per entry
    in ``CLASS_NAMES``.

    Args:
        model_path: Path to the ``state_dict`` checkpoint. Defaults to
            ``food_classifier_resnet50.pth`` in the same directory as this file.

    Returns:
        The model with weights loaded onto the CPU, switched to eval mode.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        pickle.UnpicklingError: If the checkpoint contains objects other than
            tensors and plain containers (rejected by ``weights_only`` loading).
        RuntimeError: If the checkpoint's keys or shapes do not match the
            architecture (raised by ``load_state_dict``).
    """
    model = models.resnet50(weights=None)
    model.fc = nn.Sequential(
        nn.Dropout(0.4),
        nn.Linear(model.fc.in_features, len(CLASS_NAMES)),
    )

    if model_path is None:
        model_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "food_classifier_resnet50.pth",
        )

    # The checkpoint is only ever used as a state_dict (tensors), so restrict
    # unpickling to weights. Full pickle loading, the default before torch 2.6,
    # can execute arbitrary code embedded in the file.
    state_dict = torch.load(
        model_path,
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model


model = load_model()
# Per-channel normalization statistics (the commonly used ImageNet mean/std).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Inference preprocessing steps:
# 1. Force-resize to 256x256, which does not preserve aspect ratio.
# 2. Take the central 224x224 crop.
# 3. Convert to a float tensor and normalize.
# This presumably mirrors the validation transform used at training time;
# keep the two in sync.
_preprocessing_steps = [
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
transform = transforms.Compose(_preprocessing_steps)
def _format_food_name(class_name):
    """Turn a snake_case class name into a display title ('hot_dog' -> 'Hot Dog')."""
    return class_name.replace("_", " ").title()


def classify_food(image):
    """Classify a food photo and build a calorie summary for the UI.

    Args:
        image: The uploaded image as a NumPy array (from the
            ``gr.Image(type="numpy")`` input), or ``None`` when the input is
            empty.

    Returns:
        A ``(confidences, calorie_text)`` tuple:

        - ``confidences`` maps the display names of the top (up to 5)
          predictions to their softmax probabilities, for ``gr.Label``.
        - ``calorie_text`` is Markdown with the best guess, its confidence and
          estimated calories, followed by a table of the runner-up predictions.

        Returns ``({}, "")`` when ``image`` is ``None``.
    """
    if image is None:
        return {}, ""

    img = Image.fromarray(image).convert("RGB")
    input_tensor = transform(img).unsqueeze(0)  # add a batch dimension of 1

    with torch.no_grad():
        outputs = model(input_tensor)
        probs = torch.nn.functional.softmax(outputs, dim=1)

    # torch.topk raises if k exceeds the number of classes; cap it.
    k = min(5, probs.shape[1])
    top_probs, top_indices = torch.topk(probs, k)

    # (class_name, probability) pairs for the single batch item, most likely first.
    predictions = [
        (CLASS_NAMES[idx], prob)
        for idx, prob in zip(top_indices[0].tolist(), top_probs[0].tolist())
    ]

    confidences = {
        _format_food_name(cls): float(prob) for cls, prob in predictions
    }

    top_class, top_prob = predictions[0]
    top_cal = CALORIE_DATA.get(top_class, "N/A")
    parts = [
        f"## {_format_food_name(top_class)}\n",
        f"**Confidence:** {top_prob * 100:.1f}%\n\n",
        f"**Estimated Calories:** ~{top_cal} kcal per serving\n\n",
    ]

    runners_up = predictions[1:]
    if runners_up:
        parts.append("---\n\n")
        parts.append("**Other possibilities:**\n\n")
        # GFM tables need a header row plus a delimiter row. Without them the
        # rows below render as one run-on paragraph of pipe characters.
        parts.append("| Food | Confidence | Calories |\n")
        parts.append("|---|---|---|\n")
        for cls, prob in runners_up:
            cal = CALORIE_DATA.get(cls, "N/A")
            parts.append(
                f"| {_format_food_name(cls)} | {prob * 100:.1f}% | ~{cal} kcal |\n"
            )

    return confidences, "".join(parts)
# Page-level CSS overrides passed to gr.Blocks(css=...). The rules:
# - center the app container at a maximum width of 900px;
# - center the h1 title;
# - center text carrying the "description" elem class.
_CSS_LINES = (
    ".gradio-container {",
    "max-width: 900px !important;",
    "margin: auto !important;",
    "}",
    "h1 {",
    "text-align: center;",
    "margin-bottom: 0.2em;",
    "}",
    ".description {",
    "text-align: center;",
    "}",
)
custom_css = "\n" + "\n".join(_CSS_LINES) + "\n"
# Soft base theme with orange/amber accents, gray neutrals, and the Inter
# Google web font; passed to gr.Blocks(theme=...).
_theme_font = gr.themes.GoogleFont("Inter")
theme = gr.themes.Soft(
    font=_theme_font,
    neutral_hue="gray",
    primary_hue="orange",
    secondary_hue="amber",
)
# Single-page UI layout:
# - left column: image upload and a Classify button;
# - right column: the top-5 label chart and the Markdown calorie summary;
# - bottom: a footer linking to the source repository.
_FOOTER_HTML = (
    "<center><small>Trained on Food-101 dataset (101K images) — "
    "<a href='https://github.com/ahmedamr022/Food-Classification'>GitHub</a></small></center>"
)

with gr.Blocks(theme=theme, css=custom_css, title="Food Image Classifier") as demo:
    gr.Markdown("# Food Image Classifier")
    gr.Markdown(
        "Upload a photo of any food — the model identifies it from **101 categories** and estimates calories.",
        elem_classes="description",
    )

    with gr.Row(equal_height=True):
        # Input side.
        with gr.Column(scale=1):
            image_input = gr.Image(label="Upload Food Photo", type="numpy", height=350)
            classify_btn = gr.Button("Classify", variant="primary", size="lg")
        # Output side.
        with gr.Column(scale=1):
            label_output = gr.Label(num_top_classes=5, label="Top 5 Predictions")
            calorie_output = gr.Markdown(label="Details")

    # Classify both on an explicit button click and whenever the image value
    # changes. classify_food returns empty outputs when the image is None.
    for _register in (classify_btn.click, image_input.change):
        _register(
            fn=classify_food,
            inputs=image_input,
            outputs=[label_output, calorie_output],
        )

    gr.Markdown("---")
    # sanitize_html=False lets the raw <center>/<small>/<a> markup through.
    gr.Markdown(_FOOTER_HTML, sanitize_html=False)
# Start the web server only when this file is run as a script (e.g.
# `python app.py`, which is how the app is launched). This lets the module be
# imported by tests or tooling without blocking on launch().
if __name__ == "__main__":
    demo.launch()