Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import os | |
| import torch | |
| import random | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| from transformers import ( | |
| ViTForImageClassification, | |
| AutoTokenizer, | |
| T5ForConditionalGeneration | |
| ) | |
# Page-level configuration (browser tab title/icon and layout). Kept at module
# level, ahead of every other Streamlit call in the script: Streamlit's docs
# require set_page_config to be the first Streamlit command (at least in
# older releases).
_PAGE_CONFIG = {
    "page_title": "🍽️ Food Nutrition Estimator",
    "page_icon": "🥗",
    "layout": "centered",
}
st.set_page_config(**_PAGE_CONFIG)
# Reference nutrition facts, keyed by the normalized food name that main()
# derives from the classifier label (lower-cased, then passed through
# LABEL_MAPPING).
NUTRITIONAL_INFO = {
    "pizza": {
        "serving": "100 g (1 slice)", "calories": "270 kcal", "protein": "12 g",
        "carbs": "34 g", "fat": "10 g",
        "ingredients": "dough, tomato sauce, mozzarella cheese",
        "method": "baked", "substitute": "cauliflower crust",
    },
    "hamburger": {
        "serving": "150 g", "calories": "300 kcal", "protein": "20 g",
        "carbs": "30 g", "fat": "12 g",
        "ingredients": "ground beef patty, bun, lettuce, tomato",
        "method": "grilled or pan-fried", "substitute": "chicken patty",
    },
    "sushi": {
        "serving": "150 g (6 pieces)", "calories": "200 kcal", "protein": "7 g",
        "carbs": "30 g", "fat": "5 g",
        "ingredients": "sushi rice, nori, crab (or imitation), avocado, cucumber",
        "method": "assembled raw", "substitute": "brown rice",
    },
    "salad": {
        "serving": "200 g", "calories": "50 kcal", "protein": "2 g",
        "carbs": "10 g", "fat": "0.5 g",
        "ingredients": "mixed greens, tomato, cucumber, carrots",
        "method": "raw", "substitute": "vinaigrette instead of ranch",
    },
    "pasta": {
        "serving": "200 g (1 cup)", "calories": "220 kcal", "protein": "7 g",
        "carbs": "43 g", "fat": "2 g",
        "ingredients": "wheat pasta, marinara sauce, olive oil",
        "method": "boiled and simmered", "substitute": "whole-grain pasta",
    },
    "ice_cream": {
        "serving": "100 g (½ cup)", "calories": "200 kcal", "protein": "4 g",
        "carbs": "20 g", "fat": "12 g",
        "ingredients": "cream, sugar, milk, vanilla",
        "method": "churned and frozen", "substitute": "frozen yogurt",
    },
    "fried_rice": {
        "serving": "200 g (1 cup)", "calories": "250 kcal", "protein": "8 g",
        "carbs": "35 g", "fat": "9 g",
        "ingredients": "rice, egg, peas, carrots, soy sauce, oil",
        "method": "stir-fried", "substitute": "brown rice",
    },
    "tacos": {
        "serving": "100 g (1 taco)", "calories": "200 kcal", "protein": "10 g",
        "carbs": "15 g", "fat": "10 g",
        "ingredients": "ground beef, corn tortilla, lettuce, cheese, salsa",
        "method": "beef pan-fried, tortilla warmed", "substitute": "fish filling",
    },
    "steak": {
        "serving": "113 g (4 oz)", "calories": "250 kcal", "protein": "25 g",
        "carbs": "0 g", "fat": "15 g",
        "ingredients": "beef sirloin, salt, pepper",
        "method": "grilled or pan-seared", "substitute": "leaner cut (filet mignon)",
    },
    "chocolate_cake": {
        "serving": "100 g (1 slice)", "calories": "350 kcal", "protein": "5 g",
        "carbs": "50 g", "fat": "15 g",
        "ingredients": "flour, sugar, cocoa, butter, eggs",
        "method": "baked", "substitute": "gluten-free flour",
    },
}

# Classifier labels whose names differ from the keys of NUTRITIONAL_INFO.
LABEL_MAPPING = {
    "caesar_salad": "salad",
    "spaghetti_bolognese": "pasta",
}

VIT_MODEL_ID = "shingguy1/fine_tuned_vit"
T5_MODEL_ID = "google/flan-t5-small"


@st.cache_resource
def _load_models(cache_dir, hf_token):
    """Load the ViT classifier and the FLAN-T5 paraphraser.

    Streamlit re-executes the whole script on every widget interaction; caching
    the result with ``st.cache_resource`` keeps the models from being
    re-instantiated on each rerun.

    Args:
        cache_dir: Directory passed to ``from_pretrained`` for downloaded files.
        hf_token: Hugging Face access token, or None for anonymous access.

    Returns:
        Tuple ``(vit, tokenizer, t5, device)``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): newer transformers releases deprecate `use_auth_token` in
    # favour of `token`; kept here for compatibility with older installs.
    vit = ViTForImageClassification.from_pretrained(
        VIT_MODEL_ID, cache_dir=cache_dir, use_auth_token=hf_token
    ).to(device)
    tok = AutoTokenizer.from_pretrained(
        T5_MODEL_ID, cache_dir=cache_dir, use_auth_token=hf_token
    )
    t5 = T5ForConditionalGeneration.from_pretrained(
        T5_MODEL_ID, cache_dir=cache_dir, use_auth_token=hf_token
    ).to(device)
    return vit, tok, t5, device


def _build_transform():
    """Return the preprocessing pipeline applied to an uploaded PIL image."""
    return transforms.Compose([
        # Convert first so Resize/CenterCrop always operate on 3-channel RGB,
        # even for palette ("P"), grayscale or RGBA uploads.
        transforms.Lambda(lambda img: img.convert("RGB")),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # NOTE(review): ImageNet statistics. Confirm they match the
        # normalization used when the ViT checkpoint was fine-tuned (ViT image
        # processors often use mean = std = 0.5).
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])


def _build_description(food, data):
    """Render one NUTRITIONAL_INFO entry as a fixed English summary.

    Args:
        food: Normalized food name used in the sentence (e.g. "salad").
        data: The matching value from NUTRITIONAL_INFO.

    Returns:
        The summary string. It is also shown verbatim when the generated
        paraphrase is rejected by ``_needs_fallback``.
    """
    return (
        f"A typical {food} serving ({data['serving']}) contains about {data['calories']}, "
        f"with {data['protein']} protein, {data['carbs']} carbs, and {data['fat']} fat. "
        f"Made from {data['ingredients']} and usually {data['method']}. "
        f"Try {data['substitute']} as a healthier swap."
    )


def _build_prompt(label, base_description):
    """Build the FLAN-T5 prompt.

    When reference facts exist (``base_description`` is not None), ask for a
    paraphrase of them. Otherwise ask for an approximate summary of ``label``.
    """
    if base_description is not None:
        return (
            "Paraphrase the following nutritional facts in a friendly, conversational tone. "
            "Use varied sentence structures and synonyms, and feel free to generalize numeric details "
            "(e.g., ‘around 250 kcal’). Don’t add any new facts.\n\n"
            f"{base_description}"
        )
    return (
        f"Provide an approximate nutrition summary for {label}, including calories, "
        "macronutrients, and a brief description."
    )


def _needs_fallback(response):
    """Return True if generated text is too short (< 10 words) or omits "calories"."""
    return "calories" not in response.lower() or len(response.split()) < 10


def _generate_overview(tokenizer, model, prompt, device):
    """Sample a single response from the seq2seq model for ``prompt``."""
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
    # `**inputs` forwards attention_mask alongside input_ids. `early_stopping`
    # is omitted: it only affects beam search and is ignored when sampling
    # with a single beam.
    output_ids = model.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


def main():
    """Render the Food Nutrition Estimator page.

    Shows the intro text and sidebar and loads the (cached) models. Once an
    image is uploaded, it classifies the image with the ViT model and displays
    a FLAN-T5 paraphrase of the matching nutrition facts. If the paraphrase
    looks unusable, the fixed summary is shown instead.
    """
    st.title("🍽️ Food Nutrition Estimator")
    st.markdown("""
Upload a food image to classify it and receive a paraphrased nutritional description.

⚠️ This demo is trained on **10 food categories** only:
pizza, hamburger, sushi, caesar_salad, spaghetti_bolognese,
ice_cream, fried_rice, tacos, steak, chocolate_cake.
""")

    hf_token = os.getenv("HF_TOKEN")
    cache_dir = "/tmp/cache"
    os.makedirs(cache_dir, exist_ok=True)
    # NOTE(review): huggingface_hub usually reads this variable when it is first
    # imported (already triggered by the transformers import at the top of the
    # file), so this assignment likely has no effect. The explicit cache_dir
    # passed to from_pretrained is what controls the download location.
    os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir

    st.sidebar.header("Models Used")
    st.sidebar.markdown("""
- 🖼️ **Image Classifier**: shingguy1/fine_tuned_vit
- 💬 **Paraphraser**: google/flan-t5-small (sampling mode)
""")

    model_vit, tokenizer_t5, model_t5, device = _load_models(cache_dir, hf_token)

    uploaded = st.file_uploader("📷 Upload a food image...", type=["jpg", "png", "jpeg"])
    if not uploaded:
        return

    try:
        img = Image.open(uploaded)
        # Image.open is lazy; decode now so a corrupt or non-image upload fails
        # here (OSError covers PIL's UnidentifiedImageError) instead of mid-pipeline.
        img.load()
    except OSError:
        st.error("⚠️ Could not read the uploaded file as an image.")
        return
    st.image(img, caption="Your Food", use_column_width=True)

    inp = _build_transform()(img).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model_vit(pixel_values=inp).logits
    label = model_vit.config.id2label[logits.argmax(-1).item()]
    st.success(f"🍽️ Detected: **{label}**")

    food = LABEL_MAPPING.get(label.lower(), label.lower())
    data = NUTRITIONAL_INFO.get(food)
    # None when the label has no reference entry. Previously this name was
    # unbound in that case, and the fallback below raised UnboundLocalError.
    base_description = _build_description(food, data) if data else None

    prompt = _build_prompt(label, base_description)
    response = _generate_overview(tokenizer_t5, model_t5, prompt, device)

    # Fall back to the fixed summary when the sample looks unusable. Without
    # reference facts, keep the model's best effort unless it is empty.
    if _needs_fallback(response):
        if base_description is not None:
            response = base_description
        elif not response.strip():
            response = f"No nutrition summary could be generated for {label}."

    st.subheader("🧾 Nutrition Overview")
    st.info(response)
# Entry point: render the app when this file is executed as the main script.
if __name__ == "__main__":
    main()