# Calorie_Estimator / src/streamlit_app.py
# Author: shingguy1 — commit 6ede4a3 (verified): "Update src/streamlit_app.py"
import streamlit as st
import os
import torch
import random
from PIL import Image
import torchvision.transforms as transforms
from transformers import (
ViTForImageClassification,
AutoTokenizer,
T5ForConditionalGeneration
)
# Browser-tab metadata and page layout. This runs at import time, ahead of
# main(), because Streamlit has historically required set_page_config to be
# the first st.* command executed in a script run.
st.set_page_config(
    layout="centered",
    page_icon="🥗",
    page_title="🍽️ Food Nutrition Estimator",
)
# Markdown shown under the page title; lists the only classes the fine-tuned
# classifier was trained on.
INTRO_MARKDOWN = """
Upload a food image to classify it and receive a paraphrased nutritional description.
⚠️ This demo is trained on **10 food categories** only:
pizza, hamburger, sushi, caesar_salad, spaghetti_bolognese,
ice_cream, fried_rice, tacos, steak, chocolate_cake.
"""

# Sidebar credits for the two Hugging Face models loaded by _load_models().
SIDEBAR_MARKDOWN = """
- 🖼️ **Image Classifier**: shingguy1/fine_tuned_vit
- 💬 **Paraphraser**: google/flan-t5-small (sampling mode)
"""

# Local download cache for Hugging Face model files.
CACHE_DIR = "/tmp/cache"

# Approximate per-serving reference values, keyed by normalized food name
# (the classifier label after _normalize_label). All values are display strings.
NUTRITIONAL_INFO = {
    "pizza": {"serving": "100 g (1 slice)", "calories": "270 kcal", "protein": "12 g", "carbs": "34 g", "fat": "10 g", "ingredients": "dough, tomato sauce, mozzarella cheese", "method": "baked", "substitute": "cauliflower crust"},
    "hamburger": {"serving": "150 g", "calories": "300 kcal", "protein": "20 g", "carbs": "30 g", "fat": "12 g", "ingredients": "ground beef patty, bun, lettuce, tomato", "method": "grilled or pan-fried", "substitute": "chicken patty"},
    "sushi": {"serving": "150 g (6 pieces)", "calories": "200 kcal", "protein": "7 g", "carbs": "30 g", "fat": "5 g", "ingredients": "sushi rice, nori, crab (or imitation), avocado, cucumber", "method": "assembled raw", "substitute": "brown rice"},
    "salad": {"serving": "200 g", "calories": "50 kcal", "protein": "2 g", "carbs": "10 g", "fat": "0.5 g", "ingredients": "mixed greens, tomato, cucumber, carrots", "method": "raw", "substitute": "vinaigrette instead of ranch"},
    "pasta": {"serving": "200 g (1 cup)", "calories": "220 kcal", "protein": "7 g", "carbs": "43 g", "fat": "2 g", "ingredients": "wheat pasta, marinara sauce, olive oil", "method": "boiled and simmered", "substitute": "whole-grain pasta"},
    "ice_cream": {"serving": "100 g (½ cup)", "calories": "200 kcal", "protein": "4 g", "carbs": "20 g", "fat": "12 g", "ingredients": "cream, sugar, milk, vanilla", "method": "churned and frozen", "substitute": "frozen yogurt"},
    "fried_rice": {"serving": "200 g (1 cup)", "calories": "250 kcal", "protein": "8 g", "carbs": "35 g", "fat": "9 g", "ingredients": "rice, egg, peas, carrots, soy sauce, oil", "method": "stir-fried", "substitute": "brown rice"},
    "tacos": {"serving": "100 g (1 taco)", "calories": "200 kcal", "protein": "10 g", "carbs": "15 g", "fat": "10 g", "ingredients": "ground beef, corn tortilla, lettuce, cheese, salsa", "method": "beef pan-fried, tortilla warmed", "substitute": "fish filling"},
    "steak": {"serving": "113 g (4 oz)", "calories": "250 kcal", "protein": "25 g", "carbs": "0 g", "fat": "15 g", "ingredients": "beef sirloin, salt, pepper", "method": "grilled or pan-seared", "substitute": "leaner cut (filet mignon)"},
    "chocolate_cake": {"serving": "100 g (1 slice)", "calories": "350 kcal", "protein": "5 g", "carbs": "50 g", "fat": "15 g", "ingredients": "flour, sugar, cocoa, butter, eggs", "method": "baked", "substitute": "gluten-free flour"},
}

# Classifier labels whose nutrition entry is stored under a shorter key.
LABEL_MAPPING = {
    "caesar_salad": "salad",
    "spaghetti_bolognese": "pasta",
}

# Resize -> 224x224 center crop -> RGB tensor -> normalize with the ImageNet
# mean/std constants.
# NOTE(review): HF ViT checkpoints often use mean=std=0.5 in their image
# processor; confirm these values match how shingguy1/fine_tuned_vit was trained.
TRANSFORM = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.Lambda(lambda img: img.convert("RGB")),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


@st.cache_resource
def _load_models(cache_dir, hf_token=None):
    """Load the ViT classifier and the FLAN-T5 paraphraser once per process.

    Args:
        cache_dir: Directory passed to ``from_pretrained`` as the download cache.
        hf_token: Optional Hugging Face access token (``None`` = anonymous).

    Returns:
        Tuple ``(vit, tokenizer, t5, device)`` with both models moved to
        ``device`` (always CPU here).
    """
    device = torch.device("cpu")
    vit = ViTForImageClassification.from_pretrained(
        "shingguy1/fine_tuned_vit",
        cache_dir=cache_dir,
        use_auth_token=hf_token,
    ).to(device)
    tok = AutoTokenizer.from_pretrained(
        "google/flan-t5-small",
        cache_dir=cache_dir,
        use_auth_token=hf_token,
    )
    t5 = T5ForConditionalGeneration.from_pretrained(
        "google/flan-t5-small",
        cache_dir=cache_dir,
        use_auth_token=hf_token,
    ).to(device)
    return vit, tok, t5, device


def _normalize_label(label):
    """Map a raw classifier label to its NUTRITIONAL_INFO key.

    Lower-cases, turns spaces into underscores (so "Caesar Salad" and
    "caesar_salad" agree) and applies LABEL_MAPPING.

    >>> _normalize_label("Caesar Salad")
    'salad'
    >>> _normalize_label("pizza")
    'pizza'
    """
    key = label.strip().lower().replace(" ", "_")
    return LABEL_MAPPING.get(key, key)


def _build_prompt(label):
    """Return ``(prompt, fallback_text)`` for a classifier label.

    With reference data, the prompt asks FLAN-T5 to paraphrase a fixed factual
    description, and that description is the fallback. Without data, the
    prompt asks for an approximate summary and the fallback is a short
    "no data" notice, so the caller always has text to display.
    """
    key = _normalize_label(label)
    data = NUTRITIONAL_INFO.get(key)
    if data is None:
        prompt = (
            f"Provide an approximate nutrition summary for {label}, including calories, "
            f"macronutrients, and a brief description."
        )
        return prompt, f"No reference nutrition data is available for {label}."

    # Underscored keys ("fried_rice") read badly in prose.
    name = key.replace("_", " ")
    base_description = (
        f"A typical {name} serving ({data['serving']}) contains about {data['calories']}, "
        f"with {data['protein']} protein, {data['carbs']} carbs, and {data['fat']} fat. "
        f"Made from {data['ingredients']} and usually {data['method']}. "
        f"Try {data['substitute']} as a healthier swap."
    )
    prompt = (
        f"Paraphrase the following nutritional facts in a friendly, conversational tone. "
        f"Use varied sentence structures and synonyms, and feel free to generalize numeric details "
        f"(e.g., ‘around 250 kcal’). Don’t add any new facts.\n\n"
        f"{base_description}"
    )
    return prompt, base_description


def _looks_like_summary(text):
    """Heuristically decide whether generated *text* is a usable summary.

    Requires at least 10 words and an energy mention ("calorie(s)" or "kcal").
    The reference descriptions state energy in kcal, so a faithful paraphrase
    may never contain the word "calories".
    """
    lowered = text.lower()
    return len(text.split()) >= 10 and ("calorie" in lowered or "kcal" in lowered)


def _classify(image, model, device):
    """Return the top-1 ``id2label`` entry predicted by *model* for a PIL *image*."""
    pixel_values = TRANSFORM(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits
    return model.config.id2label[logits.argmax(-1).item()]


def _generate_description(prompt, tokenizer, model, device):
    """Sample one paraphrase of *prompt* from the seq2seq *model* and decode it."""
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        # **inputs also forwards the attention mask. early_stopping is omitted
        # because it only affects beam search, not sampling.
        output_ids = model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=True,
            top_p=0.9,
            temperature=0.7,
        )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


def main():
    """Render the Food Nutrition Estimator page.

    Shows the intro and sidebar, loads the cached models, and, once an image is
    uploaded, classifies it with the ViT model and shows a FLAN-T5 paraphrase
    of the matching nutrition facts. If the generated text looks unusable, it
    shows a deterministic fallback instead.
    """
    st.title("🍽️ Food Nutrition Estimator")
    st.markdown(INTRO_MARKDOWN)

    hf_token = os.getenv("HF_TOKEN", None)
    os.makedirs(CACHE_DIR, exist_ok=True)
    # NOTE(review): huggingface_hub may read this variable only when it is first
    # imported (via the transformers import); the explicit cache_dir= arguments
    # in _load_models are what actually pin the download location.
    os.environ["HUGGINGFACE_HUB_CACHE"] = CACHE_DIR

    st.sidebar.header("Models Used")
    st.sidebar.markdown(SIDEBAR_MARKDOWN)

    model_vit, tokenizer_t5, model_t5, device = _load_models(CACHE_DIR, hf_token)

    uploaded = st.file_uploader("📷 Upload a food image...", type=["jpg", "png", "jpeg"])
    if not uploaded:
        return

    try:
        image = Image.open(uploaded)
        # convert() forces a full decode, so unreadable/truncated files fail
        # here (PIL raises OSError subclasses) instead of deep inside the model.
        rgb_image = image.convert("RGB")
    except OSError:
        st.error("⚠️ Could not read that file as an image. Please upload a valid JPG or PNG.")
        return

    st.image(image, caption="Your Food", use_column_width=True)
    label = _classify(rgb_image, model_vit, device)
    st.success(f"🍽️ Detected: **{label}**")

    prompt, fallback = _build_prompt(label)
    response = _generate_description(prompt, tokenizer_t5, model_t5, device)
    # Small sampled models sometimes emit fragments or drop the numbers.
    if not _looks_like_summary(response):
        response = fallback

    st.subheader("🧾 Nutrition Overview")
    st.info(response)
# Script entry point. `streamlit run` executes this file with
# __name__ == "__main__", so the page is rendered by main().
if __name__ == "__main__":
    main()