"""EatSmart Pro: Streamlit app that classifies a food photo and shows
nutrition, allergen, trans-fat and AI-recipe information for the prediction.

Streamlit re-executes this whole script on every user interaction, so the
model and data files are loaded once through ``st.cache_resource``.
"""

import io
import json
import os
import re

import openai
import streamlit as st
import timm  # For ConvNeXt Large model
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models

# --- Page Configuration ---
# set_page_config must be the first Streamlit call in the script.
st.set_page_config(page_title="EatSmart Pro", page_icon="🍽️", layout="wide")

# Mobile menu indicator
st.markdown(
    "\n\n📱 Mobile Users: Click the > arrow (top-left) to open preferences menu\n\n",
    unsafe_allow_html=True,
)

# ==============================================================================
# The correct, full list of 101 class names from your training script.
# The model's output index i maps to CLASS_NAMES[i], so order matters.
# NOTE(review): if training used an ImageFolder-style ASCII sort, 'cheese_plate'
# would precede 'cheesecake' — confirm this order against the training script.
# ==============================================================================
CLASS_NAMES = [
    'apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare',
    'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito',
    'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake',
    'ceviche', 'cheesecake', 'cheese_plate', 'chicken_curry', 'chicken_quesadilla',
    'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros', 'clam_chowder',
    'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes',
    'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict',
    'escargots', 'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras',
    'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari', 'fried_rice',
    'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad', 'grilled_cheese_sandwich',
    'grilled_salmon', 'guacamole', 'gyoza', 'hamburger', 'hot_and_sour_soup',
    'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream', 'lasagna',
    'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese', 'macarons', 'miso_soup',
    'mussels', 'nachos', 'omelette', 'onion_rings', 'oysters',
    'pad_thai', 'paella', 'pancakes', 'panna_cotta', 'peking_duck',
    'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib',
    'pulled_pork_sandwich', 'ramen', 'ravioli', 'red_velvet_cake', 'risotto',
    'samosa', 'sashimi', 'scallops', 'seaweed_salad', 'shrimp_and_grits',
    'spaghetti_bolognese', 'spaghetti_carbonara', 'spring_rolls', 'steak',
    'strawberry_shortcake', 'sushi', 'tacos', 'takoyaki', 'tiramisu',
    'tuna_tartare', 'waffles',
]

# ==============================================================================
# ALL HELPER FUNCTIONS ARE NOW INSIDE A SINGLE APP.PY
# ==============================================================================


def get_convnext_model(num_classes, pretrained=False):
    """Create the ConvNeXt Large architecture used by the training script.

    Args:
        num_classes: Size of the classification head.
        pretrained: Whether timm should download ImageNet weights. Defaults to
            False because load_model_resources overwrites every parameter with
            the fine-tuned checkpoint (strict ``load_state_dict``), so the
            download would only cost startup time and disk space.

    Returns:
        The (untrained or ImageNet-initialised) timm model.
    """
    print("🚀 Loading ConvNeXt Large model for inference...")
    print("📊 Model: ConvNeXt Large (197M parameters)")
    # Same architecture identifier as the training script.
    model = timm.create_model(
        'convnext_large.fb_in22k_ft_in1k', pretrained=pretrained, num_classes=num_classes
    )
    # Model statistics for console feedback.
    total_params = sum(p.numel() for p in model.parameters())
    print("✅ Model architecture loaded:")
    print(f" Total parameters: {total_params:,}")
    # 4 bytes per float32 parameter.
    print(f" Model size: ~{total_params * 4 / 1024**2:.1f} MB")
    return model


def get_efficientnet_model(num_classes):
    """Create the older EfficientNet-V2-S architecture used as a fallback.

    The stock classifier is replaced by the same two-layer head the fallback
    checkpoints were trained with; weights come from the checkpoint, so no
    pretrained weights are requested.
    """
    print("⚠️ Loading EfficientNet-V2-S model (fallback)...")
    model = models.efficientnet_v2_s(weights=None)
    num_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.2),
        nn.Linear(num_features, 256),
        nn.ReLU(),
        nn.Dropout(p=0.1),
        nn.Linear(256, num_classes),
    )
    return model


def load_json_data(path):
    """Load a JSON object from *path*.

    The data files are optional, so a missing/unreadable file, invalid JSON,
    or a top-level value that is not an object all yield ``{}`` instead of an
    exception (callers only use ``dict.get`` on the result).
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError):  # ValueError covers JSONDecodeError and UnicodeDecodeError
        return {}
    return data if isinstance(data, dict) else {}


def _food_key_variants(food_name):
    """Return the dictionary keys to try for *food_name*, most preferred first.

    Predictions arrive title-cased with spaces ("Apple Pie") while CLASS_NAMES
    use snake_case ("apple_pie"); the data files may use either spelling. The
    title-case key comes first to keep the historical lookup priority.
    """
    spaced = food_name.replace('_', ' ')
    candidates = [
        spaced.title(),
        food_name.lower(),
        spaced.lower(),
        spaced.lower().replace(' ', '_'),
    ]
    # dict.fromkeys de-duplicates while preserving order.
    return list(dict.fromkeys(candidates))


def _lookup_food(mapping, food_name, default=None):
    """Return the first truthy value in *mapping* under any key variant of *food_name*."""
    for key in _food_key_variants(food_name):
        value = mapping.get(key)
        if value:
            return value
    return default


def _find_food_info(food_name, health_data):
    """Find the nutrition record for *food_name*, or None.

    Looks in the nested ``nutrition_info`` mapping first, then falls back to a
    top-level entry of *health_data* keyed by the dish name.
    """
    nutrition_info = health_data.get("nutrition_info", {})
    return _lookup_food(nutrition_info, food_name) or _lookup_food(health_data, food_name)


def get_health_info(food_name, health_data):
    """Build the markdown for the Health Info tab.

    Returns the nutrition summary followed by the health score. When the food
    record has ``benefits``, the string ends with the literal placeholder
    ``"BENEFITS_SECTION"``; the caller strips it and renders the benefits list
    with native Streamlit components.
    """
    info = _find_food_info(food_name, health_data)
    if not info:
        return "No specific health information available for this dish."

    health_scores = health_data.get("health_scores", {})
    # Lower-case key first (original priority), then the other spellings.
    health_score = health_scores.get(food_name.lower()) or _lookup_food(health_scores, food_name, {})
    score = health_score.get("score", 75)
    score_explanation = health_score.get("explanation", "Good")

    nutrition_html = (
        "\n\n📊 Nutritional Information\n\n"
        f"Calories\n{info.get('calories', 'N/A')}\nkcal\n"
        f"Protein\n{info.get('protein', 'N/A')}\ng\n"
        f"Carbs\n{info.get('carbs', 'N/A')}\ng\n"
        f"Fat\n{info.get('fat', 'N/A')}\ng\n"
    )
    score_html = (
        "\n\n🏥 Health Assessment\n\n"
        f"Health Score: {score}/100 ({score_explanation})\n"
    )
    # Benefits are rendered by the caller with st.markdown; only flag them here.
    benefits_section = "BENEFITS_SECTION" if info.get("benefits") else ""
    return nutrition_html + score_html + benefits_section


def get_allergen_info(food_name, allergen_data):
    """Render the Allergen Alert tab for *food_name* directly into the page.

    Allergens the user ticked in the sidebar (``st.session_state.user_allergens``)
    are shown as errors plus a critical banner; the rest are informational.
    Matching is case-insensitive so "gluten" in the data matches "Gluten".
    """
    allergens = _lookup_food(allergen_data, food_name, [])
    if not allergens:
        st.success("✅ No common allergens typically found in this dish.")
        return

    user_allergens = {str(a).casefold() for a in st.session_state.user_allergens}

    def is_user_allergen(allergen):
        return str(allergen).casefold() in user_allergens

    user_allergen_matches = [a for a in allergens if is_user_allergen(a)]
    if user_allergen_matches:
        # High priority alert for the user's own allergens.
        st.error(
            f"🚨 **CRITICAL ALLERGEN ALERT**: This dish contains "
            f"**{', '.join(map(str, user_allergen_matches))}** "
            f"which you've marked as allergens to avoid!"
        )

    st.markdown("### ⚠️ Allergens Detected in This Food")
    # At most four badge columns; badges wrap round-robin across them.
    cols = st.columns(min(len(allergens), 4))
    for i, allergen in enumerate(allergens):
        with cols[i % len(cols)]:
            if is_user_allergen(allergen):
                st.error(f"🚨 {allergen}")
            else:
                st.info(f"ℹ️ {allergen}")

    st.markdown("---")
    if user_allergen_matches:
        st.markdown("🔴 **Red alerts** are for allergens you've marked in your preferences")
    else:
        st.markdown("💙 **Blue badges** show allergens present in this food")


def get_trans_fat_analysis(food_name, health_data):
    """Return markdown describing the trans-fat risk for *food_name*.

    Risk is a keyword heuristic on the dish name (e.g. "donuts", anything
    "fried"), not a nutritional lookup. If the user enabled trans-fat alerts
    in the sidebar, a risky dish also triggers an ``st.error`` banner.
    """
    food_name_lower = food_name.lower().replace('_', ' ')

    trans_fat_ingredients = health_data.get("trans_fat_ingredients", [
        "hydrogenated oil", "partially hydrogenated oil", "margarine", "shortening"
    ])

    # Keywords use spaces because they are matched against the space-separated
    # dish name ("french fries" — the old "french_fries" could never match).
    high_trans_fat_foods = [
        "donuts", "french fries", "fried", "margarine", "shortening", "processed",
        "packaged", "fast food", "baked goods", "cookies", "crackers",
    ]
    likely_trans_fat = any(keyword in food_name_lower for keyword in high_trans_fat_foods)

    if not likely_trans_fat:
        return (
            "\n\n🧪 Trans Fat Analysis - LOW RISK\n\n"
            "✅ Great Choice!\n\n"
            "This food typically contains minimal or no artificial trans fats.\n\n"
            "Keep enjoying healthy foods like this! 🌟\n\n"
        )

    alert_level = "HIGH RISK" if st.session_state.avoid_trans_fat else "POTENTIAL RISK"
    if st.session_state.avoid_trans_fat:
        st.error("🚨 **TRANS FAT ALERT**: You've enabled trans fat warnings and this food may contain trans fats!")

    ingredient_list = ", ".join(ingredient.title() for ingredient in trans_fat_ingredients)
    return (
        f"\n\n🧪 Trans Fat Analysis - {alert_level}\n\n"
        "⚠️ This food may contain trans fats from:\n\n"
        f"{ingredient_list}\n\n"
        "💡 Recommendation: Check ingredient labels and consider healthier alternatives\n\n"
    )


def generate_recipe(food_name):
    """Ask the OpenAI chat API for a recipe tailored to the sidebar preferences.

    Returns the recipe text, or a human-readable error string (missing API key
    or any API failure) — this function never raises.
    """
    try:
        api_key = st.secrets.get("OPENAI_API_KEY")
        if not api_key:
            return "Error: OPENAI_API_KEY not found."
        openai.api_key = api_key

        # Build prompt fragments from the user's preferences.
        dietary_restrictions = ""
        if st.session_state.dietary_preferences:
            dietary_restrictions = f"Make it {', '.join(st.session_state.dietary_preferences).lower()}. "
        allergen_restrictions = ""
        if st.session_state.user_allergens:
            allergen_restrictions = f"Avoid using {', '.join(st.session_state.user_allergens).lower()}. "
        trans_fat_note = ""
        if st.session_state.avoid_trans_fat:
            trans_fat_note = "Use healthy oils and avoid trans fats. "

        prompt = (
            f"Create a recipe for {food_name}. "
            f"{dietary_restrictions}{allergen_restrictions}{trans_fat_note}\n"
            "Format the response with these exact headings:\n"
            "Ingredients:\n"
            "Instructions:\n"
            "Chef's Tips:\n"
            "Make it practical for home cooking and include nutritional benefits."
        )

        response = openai.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a professional chef who creates healthy, personalized recipes based on dietary preferences and restrictions."},
                {"role": "user", "content": prompt},
            ],
            temperature=0.7,
            max_tokens=600,
        )
        recipe_text = response.choices[0].message.content
        # Normalise bold headings so the section parser sees plain "Heading:" lines.
        return (
            recipe_text.replace("**Ingredients:**", "Ingredients:")
            .replace("**Instructions:**", "Instructions:")
            .replace("**Chef's Tips:**", "Chef's Tips:")
        )
    except Exception as e:  # UI boundary: report any failure to the user instead of crashing
        return f"The AI Chef is busy right now. Error: {e}"


@st.cache_resource
def load_model_resources():
    """Load the best available classifier checkpoint plus the JSON data files.

    Checkpoints are tried in priority order (ConvNeXt Large, then the two
    EfficientNet fallbacks). Cached by Streamlit so this runs once per process.

    Returns:
        (model, health_data, allergen_data, model_info), or four Nones after
        reporting the problem with ``st.error``.
    """
    try:
        num_classes = len(CLASS_NAMES)
        candidates = (
            ('models/food_classifier_convnext_large_cpu_full.pth', get_convnext_model,
             "🎯 Using NEW ConvNeXt Large model: {}"),
            ('models/food101_efficientnet_best.pth', get_efficientnet_model,
             "⚠️ Using EfficientNet model: {}"),
            ('models/final_model.pth', get_efficientnet_model,
             "⚠️ Using old fallback model: {}"),
        )
        for model_path, build_model, message in candidates:
            if os.path.exists(model_path):
                model = build_model(num_classes=num_classes)
                print(message.format(model_path))
                break
        else:
            looked_for = "".join(f"\n- {path}" for path, _, _ in candidates)
            st.error(f"FATAL: No model file found. Looking for:{looked_for}")
            return None, None, None, None

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        checkpoint = torch.load(model_path, map_location=device)
        # Training checkpoints wrap the weights in a dict; also accept a bare state_dict.
        state_dict = checkpoint.get('model_state_dict', checkpoint)
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()

        # Model info for the status banner.
        model_info = {
            'path': model_path,
            'type': 'ConvNeXt Large' if 'convnext' in model_path else 'EfficientNet-V2-S',
            'accuracy': checkpoint.get('accuracy', 'Unknown'),
        }

        health_data = load_json_data('health_data.json')
        allergen_data = load_json_data('allergen_data.json')
        return model, health_data, allergen_data, model_info
    except Exception as e:  # startup boundary: surface the error in the UI
        st.error(f"A critical error occurred while loading the model: {e}")
        return None, None, None, None


model, health_data, allergen_data, model_info = load_model_resources()

# Built once: resize to 224x224, convert to tensor, normalise with ImageNet mean/std.
_INFERENCE_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def transform_image(image_bytes):
    """Decode *image_bytes* into a normalised (1, 3, 224, 224) batch tensor.

    Raises:
        OSError: (including PIL.UnidentifiedImageError) if the bytes are not
            a readable image.
    """
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    return _INFERENCE_TRANSFORM(image).unsqueeze(0)


def get_prediction(image_tensor):
    """Classify *image_tensor* and return ``(display_name, confidence)``.

    ``display_name`` is the CLASS_NAMES entry title-cased with spaces
    ("Apple Pie"); ``confidence`` is the softmax probability in [0, 1]. On
    failure the name is an error message and confidence is 0.0.
    """
    if model is None:
        return "Error: Model not loaded", 0.0
    # Run on whichever device load_model_resources moved the model to.
    device = next(model.parameters()).device
    with torch.no_grad():
        outputs = model(image_tensor.to(device))
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        confidence, predicted_idx = torch.max(probabilities, 1)

    predicted_idx_item = predicted_idx.item()
    if predicted_idx_item >= len(CLASS_NAMES):
        return "Prediction Error: Index out of bounds.", 0.0
    predicted_class = CLASS_NAMES[predicted_idx_item].replace('_', ' ').title()
    return predicted_class, confidence.item()


def _preference_checkboxes(options, state_key, icon, key_prefix):
    """Render one checkbox per option and keep ``st.session_state[state_key]`` in sync.

    Ticked options are appended (in click order) and unticked ones removed,
    mutating the list stored in session state.
    """
    selected = st.session_state[state_key]
    for option in options:
        ticked = st.checkbox(f"{icon} {option}", key=f"{key_prefix}_{option}", value=option in selected)
        if ticked and option not in selected:
            selected.append(option)
        elif not ticked and option in selected:
            selected.remove(option)


_RECIPE_SECTIONS = (
    ("ingredients", ("ingredients",)),
    ("instructions", ("instructions",)),
    ("tips", ("tips", "chef's tips", "chef’s tips")),
)
# Leading bullet ("-", "*", "•") or step number ("1." / "1)") followed by a
# space; "1.5 cups" is left intact because the number is not followed by a space.
_LIST_MARKER = re.compile(r"^(?:[-*•]+\s*|\d+[.)]\s+)")


def _recipe_heading(line):
    """Return the section key if *line* is a recipe heading, else None.

    Accepts the plain "Ingredients:" style plus markdown-decorated headings
    such as "### Ingredients" or "**Chef's Tips**".
    """
    line_lower = line.strip().lower()
    cleaned = line_lower.strip("#* ").rstrip(":").strip("#* ")
    for key, names in _RECIPE_SECTIONS:
        if line_lower.startswith(names) or cleaned in names:
            return key
    return None


def _parse_recipe(recipe_text):
    """Split recipe text into ingredients / instructions / tips item lists."""
    sections = {"ingredients": [], "instructions": [], "tips": []}
    current_section_key = None
    for line in recipe_text.split('\n'):
        heading = _recipe_heading(line)
        if heading:
            current_section_key = heading
        elif line.strip() and current_section_key:
            item = _LIST_MARKER.sub("", line.strip())
            if item:
                sections[current_section_key].append(item)
    return sections


def _render_recipe(recipe_text):
    """Render recipe text as Ingredients / Instructions / Chef's Tips expanders.

    When no section has any content the text is most likely an error message
    from generate_recipe (or an unexpected format), so it is shown verbatim.
    """
    sections = _parse_recipe(recipe_text)
    if not any(sections.values()):
        st.warning(recipe_text)
        return
    with st.expander("Ingredients", expanded=True):
        st.markdown("\n".join(f"- {item}" for item in sections["ingredients"]) or "No ingredients listed.")
    with st.expander("Instructions", expanded=True):
        st.markdown(
            "\n".join(f"{i}. {item}" for i, item in enumerate(sections["instructions"], start=1))
            or "No instructions provided."
        )
    with st.expander("Chef's Tips"):
        st.markdown("\n".join(f"- {item}" for item in sections["tips"]) or "No special tips provided.")


# --- UI Layout ---
st.markdown(
    "\n\n🍽️ EatSmart Pro\n\n"
    "🌟 Your AI-Powered Nutrition Assistant 🌟\n\n"
    "🥗 Healthy Analysis ⚠️ Allergen Alerts 🍳 Smart Recipes\n",
    unsafe_allow_html=True,
)

# Model Status Indicator
if model_info:
    model_type = model_info['type']
    model_accuracy = model_info['accuracy']
    if 'ConvNeXt' in model_type:
        status_icon, status_text = "🚀", "HIGH ACCURACY MODEL ACTIVE"
    else:
        status_icon, status_text = "⚠️", "FALLBACK MODEL (Training in Progress)"
    st.markdown(
        f"\n{status_icon} {status_text}\nModel: {model_type} | Validation Accuracy: {model_accuracy}\n",
        unsafe_allow_html=True,
    )

# Preference state must exist before the sidebar widgets and analysis helpers read it.
if 'user_allergens' not in st.session_state:
    st.session_state.user_allergens = []
if 'avoid_trans_fat' not in st.session_state:
    st.session_state.avoid_trans_fat = False
if 'dietary_preferences' not in st.session_state:
    st.session_state.dietary_preferences = []

# User Preferences Sidebar
with st.sidebar:
    st.markdown("\n\n⚙️ Your Preferences\n\nSet your dietary needs\n\n", unsafe_allow_html=True)

    # Allergen Preferences
    st.markdown("### 🚨 Allergen Alerts")
    st.markdown("*Select allergens you want to be warned about:*")
    common_allergens = ["Gluten", "Dairy", "Egg", "Fish", "Shellfish", "Nuts", "Peanuts", "Soy", "Sesame"]
    _preference_checkboxes(common_allergens, 'user_allergens', "🛡️", "allergen")

    # Trans Fat Preference
    st.markdown("### 🧪 Trans Fat Settings")
    st.session_state.avoid_trans_fat = st.checkbox(
        "⚠️ Alert me about trans fats",
        value=st.session_state.avoid_trans_fat,
        help="Get warnings about foods that may contain trans fats",
    )

    # Dietary Preferences
    st.markdown("### 🌱 Dietary Preferences")
    dietary_options = ["Vegetarian", "Vegan", "Keto", "Low-Carb", "High-Protein", "Gluten-Free"]
    _preference_checkboxes(dietary_options, 'dietary_preferences', "🥬", "diet")

    # Display current preferences summary
    if (st.session_state.user_allergens or st.session_state.dietary_preferences
            or st.session_state.avoid_trans_fat):
        st.markdown("---")
        st.markdown("### 📋 Active Preferences")
        if st.session_state.user_allergens:
            st.markdown(f"🚨 **Allergen Alerts:** {', '.join(st.session_state.user_allergens)}")
        if st.session_state.dietary_preferences:
            st.markdown(f"🌱 **Diet:** {', '.join(st.session_state.dietary_preferences)}")
        if st.session_state.avoid_trans_fat:
            st.markdown("🧪 **Trans Fat Alerts:** Enabled")

# image_buffer: bytes of the current upload; last_image_buffer: bytes that
# prediction_result was computed for (avoids re-running the model on reruns).
if 'image_buffer' not in st.session_state:
    st.session_state.image_buffer = None
if 'prediction_result' not in st.session_state:
    st.session_state.prediction_result = None
if 'last_image_buffer' not in st.session_state:
    st.session_state.last_image_buffer = None

col1, col2 = st.columns([1, 1.2])

with col1:
    st.markdown("\n\n📸 Upload Food Image\n\nDrag & drop or browse to analyze\n\n", unsafe_allow_html=True)
    uploaded_file = st.file_uploader(
        "Choose your food image", type=["jpg", "jpeg", "png"], label_visibility="collapsed"
    )
    # Mirror the uploader's state into session state.
    st.session_state.image_buffer = uploaded_file.getvalue() if uploaded_file is not None else None

    if st.session_state.image_buffer is not None:
        # NOTE(review): newer Streamlit releases deprecate use_column_width in
        # favour of use_container_width; kept for compatibility with older installs.
        st.image(st.session_state.image_buffer, caption='🍽️ Your Food Image', use_column_width=True)

with col2:
    st.markdown("\n\n🔬 Smart Analysis & Recipes\n\nAI-powered nutrition insights\n\n", unsafe_allow_html=True)

    if model is not None and st.session_state.image_buffer:
        # Only run the model when the uploaded bytes actually changed.
        if st.session_state.image_buffer != st.session_state.last_image_buffer:
            st.session_state.last_image_buffer = st.session_state.image_buffer
            with st.spinner('Analyzing image...'):
                try:
                    image_tensor = transform_image(st.session_state.image_buffer)
                except OSError:  # includes PIL.UnidentifiedImageError and truncated files
                    st.session_state.prediction_result = None
                    st.error("Could not read this file as an image. Please upload a valid JPG or PNG.")
                else:
                    st.session_state.prediction_result = get_prediction(image_tensor)
            # A new image invalidates any recipe generated for the previous one.
            if 'recipe' in st.session_state:
                del st.session_state.recipe

        if st.session_state.prediction_result:
            food_name, confidence = st.session_state.prediction_result
            st.metric(label="Predicted Food", value=food_name)
            st.progress(confidence, text=f"Confidence: {confidence:.2%}")

            tab1, tab2, tab3, tab4 = st.tabs(["Health Info", "Allergen Alert", "Trans Fat Analysis", "AI Recipes"])

            with tab1:
                health_info_html = get_health_info(food_name, health_data)
                if "BENEFITS_SECTION" in health_info_html:
                    st.markdown(health_info_html.replace("BENEFITS_SECTION", ""), unsafe_allow_html=True)
                    # Benefits are rendered with native components for proper list formatting.
                    food_info = _find_food_info(food_name, health_data)
                    if food_info and food_info.get("benefits"):
                        st.markdown("### ✨ Health Benefits")
                        for benefit in food_info["benefits"]:
                            st.markdown(f"• {benefit}")
                else:
                    st.markdown(health_info_html, unsafe_allow_html=True)

            with tab2:
                get_allergen_info(food_name, allergen_data)

            with tab3:
                st.markdown(get_trans_fat_analysis(food_name, health_data), unsafe_allow_html=True)

            with tab4:
                st.subheader(f"AI-Generated Recipe for {food_name}")
                # Generate once per predicted food; reruns reuse the cached text.
                if 'recipe' not in st.session_state or st.session_state.get('recipe_food') != food_name:
                    with st.spinner("Chef AI is thinking of a recipe..."):
                        st.session_state.recipe = generate_recipe(food_name)
                        st.session_state.recipe_food = food_name
                if 'recipe' in st.session_state:
                    _render_recipe(st.session_state.recipe)

    elif model is None:
        st.error("Application has failed to start. Please check the logs for errors.")
    else:
        st.info("Upload an image to get started.")