import io
import json
import os
import re

import openai
import streamlit as st
import timm  # For ConvNeXt Large model
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models
# --- Page Configuration ---
st.set_page_config(
    layout="wide",
    page_icon="๐ฝ๏ธ",
    page_title="EatSmart Pro",
)

# The preferences panel lives in the sidebar; this hint tells mobile users
# how to open it.
_MOBILE_MENU_HINT = "\n๐ฑ Mobile Users: Click the > arrow (top-left) to open preferences menu\n"
st.markdown(_MOBILE_MENU_HINT, unsafe_allow_html=True)
# ==============================================================================
# Food-101 class labels in classifier output order: logit i maps to
# CLASS_NAMES[i]. This order must stay identical to the training script's.
# Note it is not plain ASCII order ('cheesecake' precedes 'cheese_plate').
# ==============================================================================
CLASS_NAMES = """
    apple_pie baby_back_ribs baklava beef_carpaccio beef_tartare
    beet_salad beignets bibimbap bread_pudding breakfast_burrito
    bruschetta caesar_salad cannoli caprese_salad carrot_cake
    ceviche cheesecake cheese_plate chicken_curry chicken_quesadilla
    chicken_wings chocolate_cake chocolate_mousse churros clam_chowder
    club_sandwich crab_cakes creme_brulee croque_madame cup_cakes
    deviled_eggs donuts dumplings edamame eggs_benedict
    escargots falafel filet_mignon fish_and_chips foie_gras
    french_fries french_onion_soup french_toast fried_calamari fried_rice
    frozen_yogurt garlic_bread gnocchi greek_salad grilled_cheese_sandwich
    grilled_salmon guacamole gyoza hamburger hot_and_sour_soup
    hot_dog huevos_rancheros hummus ice_cream lasagna
    lobster_bisque lobster_roll_sandwich macaroni_and_cheese macarons miso_soup
    mussels nachos omelette onion_rings oysters
    pad_thai paella pancakes panna_cotta peking_duck
    pho pizza pork_chop poutine prime_rib
    pulled_pork_sandwich ramen ravioli red_velvet_cake risotto
    samosa sashimi scallops seaweed_salad shrimp_and_grits
    spaghetti_bolognese spaghetti_carbonara spring_rolls steak strawberry_shortcake
    sushi tacos takoyaki tiramisu tuna_tartare
    waffles
""".split()
# ==============================================================================
# HELPER FUNCTIONS: model builders, data loading, analysis and recipe generation (all kept in this single app.py)
# ==============================================================================
def get_convnext_model(num_classes, pretrained=False):
    """Build the ConvNeXt Large classifier architecture used for inference.

    The architecture matches the training script:
    timm's ``convnext_large.fb_in22k_ft_in1k`` with a ``num_classes``-way head.

    Args:
        num_classes: Number of output logits (one per food class).
        pretrained: Whether timm should fetch its ImageNet weights. This
            defaults to False because ``load_model_resources`` immediately
            loads the fine-tuned checkpoint with a strict ``load_state_dict``,
            which overwrites every weight. Downloading them first only adds
            start-up time and needs network access (or a warm cache).

    Returns:
        The timm model. The caller is responsible for loading weights and
        moving it to a device.
    """
    print("๐ Loading ConvNeXt Large model for inference...")
    print("๐ Model: ConvNeXt Large (197M parameters)")
    # Same timm architecture name as the training script.
    model = timm.create_model(
        'convnext_large.fb_in22k_ft_in1k',
        pretrained=pretrained,
        num_classes=num_classes,
    )
    # Model statistics for user feedback.
    total_params = sum(p.numel() for p in model.parameters())
    print("✅ Model architecture loaded:")
    print(f" Total parameters: {total_params:,}")
    # Size estimate assumes 4 bytes (float32) per parameter.
    print(f" Model size: ~{total_params * 4 / 1024**2:.1f} MB")
    return model
def get_efficientnet_model(num_classes):
    """Build the EfficientNet-V2-S fallback classifier architecture.

    The torchvision backbone is created without pretrained weights
    (``weights=None``). Its classifier is replaced by a small MLP head:
    Dropout(0.2) -> Linear(features, 256) -> ReLU -> Dropout(0.1)
    -> Linear(256, num_classes).
    Trained weights are expected to be loaded by the caller.
    """
    print("โ ๏ธ Loading EfficientNet-V2-S model (fallback)...")
    network = models.efficientnet_v2_s(weights=None)
    # The stock classifier is [Dropout, Linear]; reuse its input width.
    feature_dim = network.classifier[1].in_features
    head_layers = [
        nn.Dropout(p=0.2),
        nn.Linear(feature_dim, 256),
        nn.ReLU(),
        nn.Dropout(p=0.1),
        nn.Linear(256, num_classes),
    ]
    network.classifier = nn.Sequential(*head_layers)
    return network
def load_json_data(path):
    """Load a JSON object from ``path``, returning ``{}`` on any failure.

    The result is ``{}`` in any of these cases:
    - the file is missing or unreadable (including a directory path);
    - the file is not valid UTF-8 or not valid JSON;
    - the file's top level is not a JSON object.
    Callers can therefore call ``.get`` on the result unconditionally.
    """
    try:
        # JSON text is UTF-8; "utf-8-sig" also accepts a leading BOM (common
        # with Windows editors). Never rely on the platform locale encoding.
        with open(path, 'r', encoding='utf-8-sig') as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError: missing/unreadable path.
        # ValueError: covers both json.JSONDecodeError and UnicodeDecodeError.
        return {}
    return data if isinstance(data, dict) else {}
def _first_truthy(mapping, *keys):
    """Return the first truthy ``mapping[key]`` among ``keys``, else None."""
    for key in keys:
        value = mapping.get(key)
        if value:
            return value
    return None


def get_health_info(food_name, health_data):
    """Build the health-information text for ``food_name``.

    ``food_name`` may be a display name ("Apple Pie") or a class label
    ("apple_pie").

    Nutrition facts are looked up in ``health_data["nutrition_info"]`` under
    these keys, in order:
    1. the Title Case name;
    2. the lower-case name;
    3. the snake_case class label.
    If none matches, ``health_data`` itself is checked under the Title Case
    name.

    The health score comes from ``health_data["health_scores"]`` (lower-case
    name, then snake_case). It defaults to 75 / "Good" when absent.

    Returns:
        A short message when no nutrition entry exists. Otherwise, the
        nutrition text followed by the health-score text. The literal
        ``"BENEFITS_SECTION"`` is appended when the entry lists
        ``benefits``; the caller replaces it with a rendered benefits list.
    """
    title_key = food_name.replace('_', ' ').title()  # "Apple Pie"
    lower_key = food_name.lower()                    # "apple pie" / "apple_pie"
    snake_key = lower_key.replace(' ', '_')          # "apple_pie" (CLASS_NAMES form)

    info = (
        _first_truthy(health_data.get("nutrition_info", {}), title_key, lower_key, snake_key)
        or health_data.get(title_key)
    )
    if not info:
        return "No specific health information available for this dish."

    health_score = _first_truthy(health_data.get("health_scores", {}), lower_key, snake_key) or {}
    score = health_score.get("score", 75)
    score_explanation = health_score.get("explanation", "Good")

    nutrition_html = "\n".join([
        "",
        "๐ Nutritional Information",
        "Calories", f'{info.get("calories", "N/A")}', "kcal",
        "Protein", f'{info.get("protein", "N/A")}', "g",
        "Carbs", f'{info.get("carbs", "N/A")}', "g",
        "Fat", f'{info.get("fat", "N/A")}', "g",
        "",
    ])
    score_html = f"\n๐ฅ Health Assessment\nHealth Score: {score}/100 ({score_explanation})\n"
    # Benefits are rendered by the caller with Streamlit widgets, so only a
    # placeholder marks that they exist.
    benefits_section = "BENEFITS_SECTION" if info.get("benefits") else ""
    return nutrition_html + score_html + benefits_section
def get_allergen_info(food_name, allergen_data):
    """Render the allergen tab for ``food_name``, highlighting the user's allergens.

    The allergen list is read from ``allergen_data`` under the Title Case name
    (e.g. "apple_pie" -> "Apple Pie"). It is compared against
    ``st.session_state.user_allergens``, which is set from the sidebar.

    The comparison is case-insensitive, so "dairy" in the data still raises
    the critical alert for a user who ticked "Dairy".

    Renders directly into the current Streamlit container and returns None.
    """
    food_name_key = food_name.replace('_', ' ').title()
    allergens = allergen_data.get(food_name_key, [])
    if not allergens:
        st.success("✅ No common allergens typically found in this dish.")
        return

    # The data file's capitalisation need not match the sidebar labels, and a
    # missed match here would hide a safety alert.
    flagged = {str(a).casefold() for a in st.session_state.user_allergens}

    def is_user_allergen(allergen):
        return str(allergen).casefold() in flagged

    user_allergen_matches = [a for a in allergens if is_user_allergen(a)]
    if user_allergen_matches:
        # High-priority alert for allergens the user asked to avoid.
        st.error(f"๐จ **CRITICAL ALLERGEN ALERT**: This dish contains **{', '.join(user_allergen_matches)}** which you've marked as allergens to avoid!")

    st.markdown("### โ ๏ธ Allergens Detected in This Food")
    # At most four badge columns; extra allergens wrap around.
    cols = st.columns(min(len(allergens), 4))
    for i, allergen in enumerate(allergens):
        with cols[i % len(cols)]:
            if is_user_allergen(allergen):
                st.error(f"๐จ {allergen}")
            else:
                st.info(f"โน๏ธ {allergen}")

    st.markdown("---")
    if user_allergen_matches:
        st.markdown("๐ด **Red alerts** are for allergens you've marked in your preferences")
    else:
        st.markdown("๐ **Blue badges** show allergens present in this food")
def get_trans_fat_analysis(food_name, health_data):
"""Enhanced trans fat analysis with user preferences"""
food_name_lower = food_name.lower().replace('_', ' ')
# Get trans fat ingredients from health data
trans_fat_ingredients = health_data.get("trans_fat_ingredients", [
"hydrogenated oil", "partially hydrogenated oil", "margarine", "shortening"
])
# Foods that commonly contain trans fats
high_trans_fat_foods = [
"donuts", "french_fries", "fried", "margarine", "shortening",
"processed", "packaged", "fast food", "baked goods", "cookies", "crackers"
]
# Check if food likely contains trans fats
likely_trans_fat = any(ingredient in food_name_lower for ingredient in high_trans_fat_foods)
if likely_trans_fat:
alert_level = "HIGH RISK" if st.session_state.avoid_trans_fat else "POTENTIAL RISK"
if st.session_state.avoid_trans_fat:
st.error(f"๐จ **TRANS FAT ALERT**: You've enabled trans fat warnings and this food may contain trans fats!")
trans_fat_html = f"""
๐งช Trans Fat Analysis - {alert_level}
โ ๏ธ This food may contain trans fats from:
{"".join([f'{ingredient.title()}' for ingredient in trans_fat_ingredients])}
๐ก Recommendation: Check ingredient labels and consider healthier alternatives
"""
else:
trans_fat_html = f"""
๐งช Trans Fat Analysis - LOW RISK
โ
Great Choice!
This food typically contains minimal or no artificial trans fats.
Keep enjoying healthy foods like this! ๐
"""
return trans_fat_html
def generate_recipe(food_name):
    """Ask the OpenAI chat API for a recipe tailored to the sidebar preferences.

    Inputs:
    - the API key from ``st.secrets["OPENAI_API_KEY"]``;
    - the user's dietary preferences, allergens and trans-fat setting from
      ``st.session_state``, folded into the prompt.

    Returns the recipe text with bold ``**Ingredients:**`` /
    ``**Instructions:**`` / ``**Chef's Tips:**`` headings unbolded, so the
    caller can split on them. Any Exception is caught and turned into an
    error string instead of propagating.
    """
    try:
        api_key = st.secrets.get("OPENAI_API_KEY")
        if not api_key:
            return "Error: OPENAI_API_KEY not found."
        openai.api_key = api_key

        diets = st.session_state.dietary_preferences
        avoided = st.session_state.user_allergens
        dietary_restrictions = f"Make it {', '.join(diets).lower()}. " if diets else ""
        allergen_restrictions = f"Avoid using {', '.join(avoided).lower()}. " if avoided else ""
        trans_fat_note = "Use healthy oils and avoid trans fats. " if st.session_state.avoid_trans_fat else ""

        prompt = (
            f"Create a recipe for {food_name}. "
            f"{dietary_restrictions}{allergen_restrictions}{trans_fat_note}\n"
            "Format the response with these exact headings:\n"
            "Ingredients:\n"
            "Instructions:\n"
            "Chef's Tips:\n"
            "Make it practical for home cooking and include nutritional benefits."
        )
        system_message = {
            "role": "system",
            "content": "You are a professional chef who creates healthy, personalized recipes based on dietary preferences and restrictions.",
        }
        completion = openai.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[system_message, {"role": "user", "content": prompt}],
            temperature=0.7,
            max_tokens=600,
        )
        text = completion.choices[0].message.content
        # Strip markdown bold from the headings the UI parser looks for.
        for bold_heading, plain_heading in (
            ("**Ingredients:**", "Ingredients:"),
            ("**Instructions:**", "Instructions:"),
            ("**Chef's Tips:**", "Chef's Tips:"),
        ):
            text = text.replace(bold_heading, plain_heading)
        return text
    except Exception as err:
        return f"The AI Chef is busy right now. Error: {err}"
@st.cache_resource
def load_model_resources():
    """Load the classifier checkpoint and the JSON lookup tables (cached).

    The first checkpoint that exists is used, in this order:
    1. ConvNeXt Large;
    2. EfficientNet-V2-S;
    3. the legacy ``final_model.pth`` (also EfficientNet-V2-S).

    Returns:
        On success, ``(model, health_data, allergen_data, model_info)``.
        ``model_info`` has 'path', 'type' and 'accuracy' keys.
        On failure, ``(None, None, None, None)`` after showing an
        ``st.error``. Failure means no checkpoint exists or anything raises.
    """
    try:
        num_classes = len(CLASS_NAMES)
        # (checkpoint path, architecture factory, log prefix) in priority order.
        candidates = (
            ('models/food_classifier_convnext_large_cpu_full.pth', get_convnext_model,
             "๐ฏ Using NEW ConvNeXt Large model: "),
            ('models/food101_efficientnet_best.pth', get_efficientnet_model,
             "โ ๏ธ Using EfficientNet model: "),
            ('models/final_model.pth', get_efficientnet_model,
             "โ ๏ธ Using old fallback model: "),
        )
        for model_path, build_model, log_prefix in candidates:
            if os.path.exists(model_path):
                model = build_model(num_classes=num_classes)
                print(f"{log_prefix}{model_path}")
                break
        else:
            searched = "".join(f"\n- {path}" for path, _, _ in candidates)
            st.error(f"FATAL: No model file found. Looking for:{searched}")
            return None, None, None, None

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()

        # Summary shown in the UI status banner.
        model_info = {
            'path': model_path,
            'type': 'ConvNeXt Large' if 'convnext' in model_path else 'EfficientNet-V2-S',
            'accuracy': checkpoint.get('accuracy', 'Unknown'),
        }
        health_data = load_json_data('health_data.json')
        allergen_data = load_json_data('allergen_data.json')
        return model, health_data, allergen_data, model_info
    except Exception as err:
        st.error(f"A critical error occurred while loading the model: {err}")
        return None, None, None, None


# st.cache_resource makes script reruns reuse these objects instead of reloading.
(model, health_data, allergen_data, model_info) = load_model_resources()
def transform_image(image_bytes):
    """Decode uploaded image bytes into a normalized model input batch.

    Processing steps:
    1. Convert the image to RGB.
    2. Resize to 224x224 (aspect ratio is not preserved).
    3. Convert to a tensor.
    4. Normalize with the standard ImageNet mean/std.

    Returns a tensor of shape (1, 3, 224, 224).
    """
    rgb_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return preprocess(rgb_image).unsqueeze(0)
def get_prediction(image_tensor):
    """Classify a preprocessed image with the module-level ``model``.

    Args:
        image_tensor: Batch from ``transform_image``. It must contain a single
            image, because ``.item()`` fails on larger batches.

    Returns:
        ``(label, probability)``. ``label`` is the Title Case class name
        (e.g. "Apple Pie") and ``probability`` is the top softmax score.
        On failure, returns an error message and ``0.0``.
    """
    if model is None:
        return "Error: Model not loaded", 0.0
    target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    with torch.no_grad():
        logits = model(image_tensor.to(target_device))
        probs = torch.nn.functional.softmax(logits, dim=1)
        top_prob, top_idx = probs.max(dim=1)
    class_index = top_idx.item()
    if class_index >= len(CLASS_NAMES):
        return "Prediction Error: Index out of bounds.", 0.0
    label = CLASS_NAMES[class_index].replace('_', ' ').title()
    return label, top_prob.item()
# --- UI Layout ---
# App banner: title plus the three feature highlights.
_BANNER_LINES = (
    "๐ฝ๏ธ EatSmart",
    "Pro",
    "๐ Your AI-Powered",
    "Nutrition Assistant ๐",
    "๐ฅ Healthy Analysis",
    "โ ๏ธ Allergen Alerts",
    "๐ณ Smart Recipes",
)
st.markdown("\n" + "\n".join(_BANNER_LINES) + "\n", unsafe_allow_html=True)

# Model Status Indicator: shows whether ConvNeXt or an EfficientNet fallback
# was loaded. model_info is None when loading failed, so nothing is shown then.
if model_info:
    model_type = model_info['type']
    model_accuracy = model_info['accuracy']
    # status_color / status_bg are not referenced by the markdown below.
    if 'ConvNeXt' not in model_type:
        status_color, status_bg, status_icon, status_text = (
            "#ffc107", "#fff3cd", "โ ๏ธ", "FALLBACK MODEL (Training in Progress)")
    else:
        status_color, status_bg, status_icon, status_text = (
            "#28a745", "#d4edda", "๐", "HIGH ACCURACY MODEL ACTIVE")
    st.markdown(
        f"\n{status_icon} {status_text}\n"
        f"Model: {model_type} | Validation Accuracy: {model_accuracy}\n",
        unsafe_allow_html=True,
    )
def _sync_checkbox_group(options, selected, icon, key_prefix):
    """Render one checkbox per option and mirror its state into ``selected``.

    ``selected`` is the list stored in session state. It is mutated in place:
    ticked options are appended and unticked ones removed, preserving the
    order in which they were ticked.
    """
    for option in options:
        ticked = st.checkbox(
            f"{icon} {option}",
            key=f"{key_prefix}_{option}",
            value=option in selected,
        )
        if ticked and option not in selected:
            selected.append(option)
        elif not ticked and option in selected:
            selected.remove(option)


# User Preferences Sidebar
with st.sidebar:
    st.markdown("\nโ๏ธ Your Preferences\nSet your dietary needs\n", unsafe_allow_html=True)

    # First run of a session: create the preference slots.
    for state_key, default in (
        ('user_allergens', []),
        ('avoid_trans_fat', False),
        ('dietary_preferences', []),
    ):
        if state_key not in st.session_state:
            st.session_state[state_key] = default

    # Allergen Preferences
    st.markdown("### ๐จ Allergen Alerts")
    st.markdown("*Select allergens you want to be warned about:*")
    common_allergens = ["Gluten", "Dairy", "Egg", "Fish", "Shellfish", "Nuts", "Peanuts", "Soy", "Sesame"]
    _sync_checkbox_group(common_allergens, st.session_state.user_allergens, "๐ก๏ธ", "allergen")

    # Trans Fat Preference
    st.markdown("### ๐งช Trans Fat Settings")
    st.session_state.avoid_trans_fat = st.checkbox(
        "โ ๏ธ Alert me about trans fats",
        value=st.session_state.avoid_trans_fat,
        help="Get warnings about foods that may contain trans fats",
    )

    # Dietary Preferences
    st.markdown("### ๐ฑ Dietary Preferences")
    dietary_options = ["Vegetarian", "Vegan", "Keto", "Low-Carb", "High-Protein", "Gluten-Free"]
    _sync_checkbox_group(dietary_options, st.session_state.dietary_preferences, "๐ฅฌ", "diet")

    # Summary of whatever is currently active.
    active_allergens = st.session_state.user_allergens
    active_diets = st.session_state.dietary_preferences
    trans_fat_alerts_on = st.session_state.avoid_trans_fat
    if active_allergens or active_diets or trans_fat_alerts_on:
        st.markdown("---")
        st.markdown("### ๐ Active Preferences")
        if active_allergens:
            st.markdown(f"๐จ **Allergen Alerts:** {', '.join(active_allergens)}")
        if active_diets:
            st.markdown(f"๐ฑ **Diet:** {', '.join(active_diets)}")
        if trans_fat_alerts_on:
            st.markdown("๐งช **Trans Fat Alerts:** Enabled")
# Per-session slots for the uploaded image and its cached prediction.
# last_image_buffer lets the analysis column skip inference when the same image
# is still uploaded.
for _state_key in ('image_buffer', 'prediction_result', 'last_image_buffer'):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None

col1, col2 = st.columns([1, 1.2])
with col1:
    st.markdown("\n๐ธ Upload Food Image\nDrag & drop or browse to analyze\n", unsafe_allow_html=True)
    uploaded_file = st.file_uploader(
        "Choose your food image",
        type=["jpg", "jpeg", "png"],
        label_visibility="collapsed",
    )
    # Mirror the uploader into session state; clearing the uploader clears the image.
    st.session_state.image_buffer = None if uploaded_file is None else uploaded_file.getvalue()

    if st.session_state.image_buffer is not None:
        # NOTE(review): recent Streamlit releases deprecate use_column_width in
        # favour of use_container_width; switch once the pinned version allows.
        st.image(st.session_state.image_buffer, caption='๐ฝ๏ธ Your Food Image', use_column_width=True)
def _parse_recipe_sections(recipe_text):
    """Split AI-generated recipe text into its three display sections.

    Heading detection: the line is lower-cased and any leading markdown
    ``#``/``*`` is removed. If it then starts with "ingredients",
    "instructions", "tips" or "chef's tips", it switches the current section.

    Every other non-blank line is added to the current section with its
    leading ``*``/``-`` bullet and ``1.``/``1)`` numbering removed, because
    the expanders add their own list markers. Lines before the first heading
    are ignored.

    Returns:
        dict with "ingredients", "instructions" and "tips" lists of strings.
    """
    sections = {"ingredients": [], "instructions": [], "tips": []}
    current_section_key = None
    for line in recipe_text.split('\n'):
        text = line.strip()
        heading = text.lower().lstrip('#* ')
        if heading.startswith("ingredients"):
            current_section_key = "ingredients"
        elif heading.startswith("instructions"):
            current_section_key = "instructions"
        elif heading.startswith(("tips", "chef's tips")):
            current_section_key = "tips"
        elif text and current_section_key:
            # "\s+" (not "\s*") so quantities like "1.5 cups" are kept intact.
            item = re.sub(r'^\d+[.)]\s+', '', text.lstrip('*- '))
            if item:
                sections[current_section_key].append(item)
    return sections


with col2:
    st.markdown("\n๐ฌ Smart Analysis & Recipes\nAI-powered nutrition insights\n", unsafe_allow_html=True)
    if model and st.session_state.image_buffer:
        # Only run inference when a different image arrives; reruns triggered
        # by other widgets reuse the stored prediction.
        if st.session_state.image_buffer != st.session_state.last_image_buffer:
            st.session_state.last_image_buffer = st.session_state.image_buffer
            with st.spinner('Analyzing image...'):
                image_tensor = transform_image(st.session_state.image_buffer)
                st.session_state.prediction_result = get_prediction(image_tensor)
            if 'recipe' in st.session_state:
                del st.session_state.recipe

        if st.session_state.prediction_result:
            food_name, confidence = st.session_state.prediction_result
            st.metric(label="Predicted Food", value=food_name)
            st.progress(confidence, text=f"Confidence: {confidence:.2%}")
            tab1, tab2, tab3, tab4 = st.tabs(["Health Info", "Allergen Alert", "Trans Fat Analysis", "AI Recipes"])

            with tab1:
                health_info_html = get_health_info(food_name, health_data)
                if "BENEFITS_SECTION" in health_info_html:
                    st.markdown(health_info_html.replace("BENEFITS_SECTION", ""), unsafe_allow_html=True)
                    # Benefits are rendered with Streamlit components for proper
                    # list formatting. Look the entry up with the same key order
                    # get_health_info uses (Title, lower, snake_case, then direct).
                    title_key = food_name.replace('_', ' ').title()
                    nutrition_info = health_data.get("nutrition_info", {})
                    food_info = (
                        nutrition_info.get(title_key)
                        or nutrition_info.get(food_name.lower())
                        or nutrition_info.get(food_name.lower().replace(' ', '_'))
                        or health_data.get(title_key)
                    )
                    if food_info and food_info.get("benefits"):
                        st.markdown("### โจ Health Benefits")
                        for benefit in food_info.get("benefits", []):
                            st.markdown(f"โข {benefit}")
                else:
                    st.markdown(health_info_html, unsafe_allow_html=True)

            with tab2:
                get_allergen_info(food_name, allergen_data)

            with tab3:
                st.markdown(get_trans_fat_analysis(food_name, health_data), unsafe_allow_html=True)

            with tab4:
                st.subheader(f"AI-Generated Recipe for {food_name}")
                # Regenerate when the dish or any preference baked into the
                # prompt changes. Keying on the dish alone kept showing a recipe
                # that ignored newly ticked diets/allergens.
                recipe_key = (
                    food_name,
                    tuple(st.session_state.dietary_preferences),
                    tuple(st.session_state.user_allergens),
                    st.session_state.avoid_trans_fat,
                )
                if 'recipe' not in st.session_state or st.session_state.get('recipe_key') != recipe_key:
                    with st.spinner("Chef AI is thinking of a recipe..."):
                        st.session_state.recipe = generate_recipe(food_name)
                    st.session_state.recipe_food = food_name
                    st.session_state.recipe_key = recipe_key

                if 'recipe' in st.session_state:
                    recipe_text = st.session_state.recipe
                    sections = _parse_recipe_sections(recipe_text)
                    if not any(sections.values()):
                        # No recognised headings. generate_recipe's error messages
                        # look like this, so show the text instead of three empty
                        # sections.
                        st.warning(recipe_text)
                    else:
                        with st.expander("Ingredients", expanded=True):
                            st.markdown("\n".join(f"- {item}" for item in sections["ingredients"]) or "No ingredients listed.")
                        with st.expander("Instructions", expanded=True):
                            st.markdown("\n".join(f"{i}. {item}" for i, item in enumerate(sections["instructions"], start=1)) or "No instructions provided.")
                        with st.expander("Chef's Tips"):
                            st.markdown("\n".join(f"- {item}" for item in sections["tips"]) or "No special tips provided.")
    elif not model:
        st.error("Application has failed to start. Please check the logs for errors.")
    else:
        st.info("Upload an image to get started.")