Spaces:
Runtime error
Runtime error
import json
import logging
from pathlib import Path

import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
# =========================
# 1. Load Model
# =========================
# Resolve the checkpoint next to this script instead of relative to the
# current working directory, so startup does not depend on where the app
# is launched from.
MODEL_PATH = str(
    Path(__file__).resolve().parent / "model" / "best_food_model.keras"
)
# compile=False: this app only calls model.predict(), so the training-time
# loss/optimizer/metric configuration is not needed. Skipping it can avoid
# load failures when that configuration cannot be deserialized at serve time.
model = tf.keras.models.load_model(MODEL_PATH, compile=False)
# =========================
# 2. Labels
# =========================
# The 101 Food-101 class names, alphabetically ordered. Element i names the
# class scored by model output index i.
# NOTE(review): assumes training assigned class indices alphabetically
# (as directory-based dataset loaders do) -- confirm against training code.
LABELS = """
    apple_pie baby_back_ribs baklava beef_carpaccio beef_tartare
    beet_salad beignets bibimbap bread_pudding breakfast_burrito
    bruschetta caesar_salad cannoli caprese_salad carrot_cake
    ceviche cheese_plate cheesecake chicken_curry
    chicken_quesadilla chicken_wings chocolate_cake
    chocolate_mousse churros clam_chowder club_sandwich
    crab_cakes creme_brulee croque_madame cup_cakes
    deviled_eggs donuts dumplings edamame eggs_benedict
    escargots falafel filet_mignon fish_and_chips foie_gras
    french_fries french_onion_soup french_toast fried_calamari
    fried_rice frozen_yogurt garlic_bread gnocchi greek_salad
    grilled_cheese_sandwich grilled_salmon guacamole gyoza
    hamburger hot_and_sour_soup hot_dog huevos_rancheros
    hummus ice_cream lasagna lobster_bisque
    lobster_roll_sandwich macaroni_and_cheese macarons
    miso_soup mussels nachos omelette onion_rings
    oysters pad_thai paella pancakes panna_cotta
    peking_duck pho pizza pork_chop poutine prime_rib
    pulled_pork_sandwich ramen ravioli red_velvet_cake
    risotto samosa sashimi scallops seaweed_salad
    shrimp_and_grits spaghetti_bolognese spaghetti_carbonara
    spring_rolls steak strawberry_shortcake sushi tacos
    takoyaki tiramisu tuna_tartare waffles
""".split()
# =========================
# 3. Load Nutrition JSON
# =========================
def _load_nutrition_db(path):
    """Return the JSON object stored at *path*, or ``{}`` if it is unusable.

    The nutrition table is optional: a missing file silently yields ``{}``.
    An unreadable file, invalid JSON/UTF-8, or a top-level value that is not
    a JSON object is logged as a warning and also yields ``{}``, so the
    classifier keeps running without nutrition facts instead of crashing at
    startup.
    """
    if not path.exists():
        return {}
    try:
        with path.open("r", encoding="utf-8") as fh:
            data = json.load(fh)
    except (OSError, ValueError) as err:  # ValueError covers JSONDecodeError and bad UTF-8
        logging.getLogger(__name__).warning(
            "Could not load %s (%s); nutrition facts disabled.", path, err
        )
        return {}
    # predict_nutrition() calls .get() on this table, so it must be a mapping.
    if not isinstance(data, dict):
        logging.getLogger(__name__).warning(
            "%s does not contain a JSON object; nutrition facts disabled.", path
        )
        return {}
    return data


# Resolved next to this script rather than the current working directory.
NUTRITION_PATH = Path(__file__).resolve().parent / "nutrition_db.json"
NUTRITION_DB = _load_nutrition_db(NUTRITION_PATH)
# =========================
# 4. Prediction Function
# =========================
def _format_amount(value, unit):
    """Render one nutrient table cell, or ``N/A`` when the value is unknown."""
    return "N/A" if value is None else f"{value} {unit}"


def predict_nutrition(img):
    """Classify a food photo and describe the top prediction's nutrition.

    Args:
        img: The uploaded image as delivered by ``gr.Image(type="numpy")``
            (a NumPy array), a PIL image, or ``None`` when nothing was
            uploaded.

    Returns:
        A ``(confidences, nutrition_md)`` tuple. ``confidences`` maps the
        (up to) three highest-scoring labels to their float scores, suitable
        for ``gr.Label``. ``nutrition_md`` is a Markdown table for the best
        label; nutrients missing from ``NUTRITION_DB`` are shown as ``N/A``.
        For ``None`` input, returns ``({}, "Upload an image.")``.
    """
    if img is None:
        return {}, "Upload an image."

    # Normalise to a 3-channel PIL image so RGBA/grayscale uploads still
    # match the model's RGB input size.
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    img = img.convert("RGB").resize((224, 224))

    # Batch of one (1, 224, 224, 3), float32, scaled to [0, 1].
    # NOTE(review): assumes the model expects 1/255-scaled input (no internal
    # Rescaling/preprocessing layer) -- confirm against the training code.
    batch = np.asarray(img, dtype=np.float32)[np.newaxis] / 255.0
    preds = model.predict(batch, verbose=0)[0]

    # Highest score first. The chart and the nutrition table both use this
    # one ranking so they agree on the top label even when scores tie.
    top_indices = np.argsort(preds)[-3:][::-1]
    confidences = {LABELS[i]: float(preds[i]) for i in top_indices}

    food_name = LABELS[int(top_indices[0])]
    entry = NUTRITION_DB.get(food_name)
    # Tolerate missing foods, missing keys, and malformed entries alike.
    nutri = entry if isinstance(entry, dict) else {}

    clean_name = food_name.replace("_", " ").title()
    # Built line by line so no source indentation leaks into the Markdown;
    # the blank line keeps the table from being merged into the paragraph.
    nutrition_md = "\n".join([
        f"### 🥗 Nutrition Facts — {clean_name}",
        "*(Estimated per 100g)*",
        "",
        "| Nutrient | Amount |",
        "|---|---|",
        f"| Calories | {_format_amount(nutri.get('cal'), 'kcal')} |",
        f"| Protein | {_format_amount(nutri.get('protein'), 'g')} |",
        f"| Carbs | {_format_amount(nutri.get('carbs'), 'g')} |",
        f"| Fat | {_format_amount(nutri.get('fat'), 'g')} |",
    ])
    return confidences, nutrition_md
# =========================
# 5. Gradio UI
# =========================
# Two-column layout: image upload + button on the left; top-3 predictions
# and the nutrition table on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🍔 Food-101 Classifier")
    gr.Markdown("Upload food → get prediction + macros.")
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="numpy", label="Upload Food Image")
            submit_btn = gr.Button("Analyze Meal", variant="primary")
        with gr.Column():
            output_chart = gr.Label(num_top_classes=3)
            output_nutri = gr.Markdown()
    # predict_nutrition returns (confidences, markdown), one value per output.
    submit_btn.click(
        fn=predict_nutrition,
        inputs=input_img,
        outputs=[output_chart, output_nutri],
    )
    gr.Markdown("---")
    gr.Markdown("Educational demo. Not medical advice.")

if __name__ == "__main__":
    # Bind on all interfaces at port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)