"""Streamlit demo: classify a food photo and show a nutrition summary.

The app loads a fine-tuned ViT image classifier and a FLAN-T5 model,
predicts a food label for an uploaded image, looks the label up in a small
hard-coded nutrition table, and asks FLAN-T5 to paraphrase the resulting
description. If the generated text looks unusable, the un-paraphrased
description is shown instead.
"""

import os
import random  # noqa: F401  (unused here; kept so the file's import set is unchanged)
from typing import Dict, Optional, Tuple

import streamlit as st
import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import (
    AutoTokenizer,
    T5ForConditionalGeneration,
    ViTForImageClassification,
)

# Set page config
st.set_page_config(
    page_title="🍽️ Food Nutrition Estimator",
    page_icon="🥗",
    layout="centered",
)

# Directory handed to every ``from_pretrained`` call as ``cache_dir``.
CACHE_DIR = "/tmp/cache"

VIT_CHECKPOINT = "shingguy1/fine_tuned_vit"
T5_CHECKPOINT = "google/flan-t5-small"

INTRO_MARKDOWN = """
Upload a food image to classify it and receive a paraphrased nutritional description.

⚠️ This demo is trained on **10 food categories** only:
pizza, hamburger, sushi, caesar_salad, spaghetti_bolognese, ice_cream, fried_rice, tacos, steak, chocolate_cake.
"""

SIDEBAR_MARKDOWN = """
- 🖼️ **Image Classifier**: shingguy1/fine_tuned_vit
- 💬 **Paraphraser**: google/flan-t5-small (sampling mode)
"""

# Nutrition facts keyed by the canonical food name (see LABEL_MAPPING).
NUTRITIONAL_INFO: Dict[str, Dict[str, str]] = {
    "pizza": {
        "serving": "100 g (1 slice)",
        "calories": "270 kcal",
        "protein": "12 g",
        "carbs": "34 g",
        "fat": "10 g",
        "ingredients": "dough, tomato sauce, mozzarella cheese",
        "method": "baked",
        "substitute": "cauliflower crust",
    },
    "hamburger": {
        "serving": "150 g",
        "calories": "300 kcal",
        "protein": "20 g",
        "carbs": "30 g",
        "fat": "12 g",
        "ingredients": "ground beef patty, bun, lettuce, tomato",
        "method": "grilled or pan-fried",
        "substitute": "chicken patty",
    },
    "sushi": {
        "serving": "150 g (6 pieces)",
        "calories": "200 kcal",
        "protein": "7 g",
        "carbs": "30 g",
        "fat": "5 g",
        "ingredients": "sushi rice, nori, crab (or imitation), avocado, cucumber",
        "method": "assembled raw",
        "substitute": "brown rice",
    },
    "salad": {
        "serving": "200 g",
        "calories": "50 kcal",
        "protein": "2 g",
        "carbs": "10 g",
        "fat": "0.5 g",
        "ingredients": "mixed greens, tomato, cucumber, carrots",
        "method": "raw",
        "substitute": "vinaigrette instead of ranch",
    },
    "pasta": {
        "serving": "200 g (1 cup)",
        "calories": "220 kcal",
        "protein": "7 g",
        "carbs": "43 g",
        "fat": "2 g",
        "ingredients": "wheat pasta, marinara sauce, olive oil",
        "method": "boiled and simmered",
        "substitute": "whole-grain pasta",
    },
    "ice_cream": {
        "serving": "100 g (½ cup)",
        "calories": "200 kcal",
        "protein": "4 g",
        "carbs": "20 g",
        "fat": "12 g",
        "ingredients": "cream, sugar, milk, vanilla",
        "method": "churned and frozen",
        "substitute": "frozen yogurt",
    },
    "fried_rice": {
        "serving": "200 g (1 cup)",
        "calories": "250 kcal",
        "protein": "8 g",
        "carbs": "35 g",
        "fat": "9 g",
        "ingredients": "rice, egg, peas, carrots, soy sauce, oil",
        "method": "stir-fried",
        "substitute": "brown rice",
    },
    "tacos": {
        "serving": "100 g (1 taco)",
        "calories": "200 kcal",
        "protein": "10 g",
        "carbs": "15 g",
        "fat": "10 g",
        "ingredients": "ground beef, corn tortilla, lettuce, cheese, salsa",
        "method": "beef pan-fried, tortilla warmed",
        "substitute": "fish filling",
    },
    "steak": {
        "serving": "113 g (4 oz)",
        "calories": "250 kcal",
        "protein": "25 g",
        "carbs": "0 g",
        "fat": "15 g",
        "ingredients": "beef sirloin, salt, pepper",
        "method": "grilled or pan-seared",
        "substitute": "leaner cut (filet mignon)",
    },
    "chocolate_cake": {
        "serving": "100 g (1 slice)",
        "calories": "350 kcal",
        "protein": "5 g",
        "carbs": "50 g",
        "fat": "15 g",
        "ingredients": "flour, sugar, cocoa, butter, eggs",
        "method": "baked",
        "substitute": "gluten-free flour",
    },
}

# Classifier labels that are stored under a different key in NUTRITIONAL_INFO.
LABEL_MAPPING: Dict[str, str] = {
    "caesar_salad": "salad",
    "spaghetti_bolognese": "pasta",
}

# Image preprocessing applied before the classifier.
# The RGB conversion runs first so that resizing and cropping always operate
# on a 3-channel image (Pillow documents that mode "P"/"1" images are always
# resized with nearest-neighbour filtering), and so Normalize always sees
# three channels.
PREPROCESS = transforms.Compose([
    transforms.Lambda(lambda img: img.convert("RGB")),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


@st.cache_resource
def _load_models():
    """Load the classifier, tokenizer and paraphraser onto the CPU.

    Wrapped in ``st.cache_resource`` so the checkpoints are not reloaded on
    every Streamlit rerun.

    Returns:
        A ``(vit, tokenizer, t5, device)`` tuple.
    """
    hf_token = os.getenv("HF_TOKEN", None)
    device = torch.device("cpu")
    # NOTE(review): ``use_auth_token`` is kept as in the original; newer
    # transformers releases may prefer ``token`` -- confirm against the
    # pinned version before changing.
    vit = ViTForImageClassification.from_pretrained(
        VIT_CHECKPOINT, cache_dir=CACHE_DIR, use_auth_token=hf_token
    ).to(device)
    tok = AutoTokenizer.from_pretrained(
        T5_CHECKPOINT, cache_dir=CACHE_DIR, use_auth_token=hf_token
    )
    t5 = T5ForConditionalGeneration.from_pretrained(
        T5_CHECKPOINT, cache_dir=CACHE_DIR, use_auth_token=hf_token
    ).to(device)
    return vit, tok, t5, device


def _build_prompt(label: str) -> Tuple[str, Optional[str]]:
    """Build the paraphraser prompt for a predicted label.

    Args:
        label: Class name produced by the image classifier.

    Returns:
        ``(prompt, fallback)``. ``fallback`` is the plain description built
        from NUTRITIONAL_INFO, or ``None`` when the label has no entry there
        (in which case there is nothing to fall back to).
    """
    true_label = LABEL_MAPPING.get(label.lower(), label.lower())
    data = NUTRITIONAL_INFO.get(true_label)

    if not data:
        prompt = (
            f"Provide an approximate nutrition summary for {label}, including calories, "
            f"macronutrients, and a brief description."
        )
        return prompt, None

    base_description = (
        f"A typical {true_label} serving ({data['serving']}) contains about {data['calories']}, "
        f"with {data['protein']} protein, {data['carbs']} carbs, and {data['fat']} fat. "
        f"Made from {data['ingredients']} and usually {data['method']}. "
        f"Try {data['substitute']} as a healthier swap."
    )
    prompt = (
        f"Paraphrase the following nutritional facts in a friendly, conversational tone. "
        f"Use varied sentence structures and synonyms, and feel free to generalize numeric details "
        f"(e.g., ‘around 250 kcal’). Don’t add any new facts.\n\n"
        f"{base_description}"
    )
    return prompt, base_description


def _looks_unusable(response: str) -> bool:
    """Return True when the text is under 10 words or never mentions calories."""
    return "calories" not in response.lower() or len(response.split()) < 10


def _classify(img, model_vit, device) -> str:
    """Run the classifier on a PIL image and return the predicted label."""
    batch = PREPROCESS(img).unsqueeze(0).to(device)
    with torch.no_grad():
        out = model_vit(pixel_values=batch)
    return model_vit.config.id2label[out.logits.argmax(-1).item()]


def _paraphrase(prompt: str, tokenizer_t5, model_t5, device) -> str:
    """Sample a response to ``prompt`` from the T5 model and decode it."""
    inputs = tokenizer_t5(prompt, return_tensors="pt", truncation=True).to(device)
    # Pass the whole encoding so the attention mask reaches generate().
    # ``early_stopping`` is omitted: it is a beam-search option and this call
    # samples with a single beam.
    output_ids = model_t5.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
    )
    return tokenizer_t5.decode(output_ids[0], skip_special_tokens=True)


def main() -> None:
    """Render the page: upload, classify, paraphrase, display."""
    st.title("🍽️ Food Nutrition Estimator")
    st.markdown(INTRO_MARKDOWN)

    os.makedirs(CACHE_DIR, exist_ok=True)
    os.environ["HUGGINGFACE_HUB_CACHE"] = CACHE_DIR

    st.sidebar.header("Models Used")
    st.sidebar.markdown(SIDEBAR_MARKDOWN)

    model_vit, tokenizer_t5, model_t5, device = _load_models()

    uploaded = st.file_uploader("📷 Upload a food image...", type=["jpg", "png", "jpeg"])
    if not uploaded:
        return

    img = Image.open(uploaded)
    st.image(img, caption="Your Food", use_column_width=True)

    label = _classify(img, model_vit, device)
    st.success(f"🍽️ Detected: **{label}**")

    prompt, fallback = _build_prompt(label)
    response = _paraphrase(prompt, tokenizer_t5, model_t5, device)

    # Fallback if the output seems too short or misses key phrases.
    # Only possible when the label has a table entry; previously an unknown
    # label reached this point with no description defined and crashed.
    if fallback is not None and _looks_unusable(response):
        response = fallback

    st.subheader("🧾 Nutrition Overview")
    st.info(response)


if __name__ == "__main__":
    main()