Spaces:
Sleeping
Sleeping
File size: 6,961 Bytes
fda8905 45a1fc8 df9e1b3 404eb80 45a1fc8 cdfccf9 9129b6f dbca709 45a1fc8 fda8905 6ede4a3 a965712 6ede4a3 a965712 6ede4a3 4840f3d 6ede4a3 2548991 749ea77 df9e1b3 4840f3d a965712 6ede4a3 4840f3d 6ede4a3 4840f3d 6ede4a3 4840f3d 6ede4a3 4840f3d 6ede4a3 4840f3d 6ede4a3 4840f3d 6ede4a3 4840f3d 6ede4a3 4840f3d 6ede4a3 4840f3d 6ede4a3 4840f3d 6ede4a3 4840f3d 6ede4a3 4840f3d 6ede4a3 4840f3d 749ea77 6ede4a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import streamlit as st
import os
import torch
import random
from PIL import Image
import torchvision.transforms as transforms
from transformers import (
ViTForImageClassification,
AutoTokenizer,
T5ForConditionalGeneration
)
# Configure the page title, page icon, and centered layout for the app.
st.set_page_config(layout="centered", page_icon="🥗", page_title="🍽️ Food Nutrition Estimator")
# Introductory text rendered at the top of the page.
_INTRO_MARKDOWN = """
Upload a food image to classify it and receive a paraphrased nutritional description.
⚠️ This demo is trained on **10 food categories** only:
pizza, hamburger, sushi, caesar_salad, spaghetti_bolognese,
ice_cream, fried_rice, tacos, steak, chocolate_cake.
"""

# Sidebar text listing the two models the app loads.
_SIDEBAR_MARKDOWN = """
- 🖼️ **Image Classifier**: shingguy1/fine_tuned_vit
- 💬 **Paraphraser**: google/flan-t5-small (sampling mode)
"""

# Curated nutrition facts, keyed by the canonical (lower-case) food name.
_NUTRITIONAL_INFO = {
    "pizza": {"serving": "100 g (1 slice)", "calories": "270 kcal", "protein": "12 g", "carbs": "34 g", "fat": "10 g", "ingredients": "dough, tomato sauce, mozzarella cheese", "method": "baked", "substitute": "cauliflower crust"},
    "hamburger": {"serving": "150 g", "calories": "300 kcal", "protein": "20 g", "carbs": "30 g", "fat": "12 g", "ingredients": "ground beef patty, bun, lettuce, tomato", "method": "grilled or pan-fried", "substitute": "chicken patty"},
    "sushi": {"serving": "150 g (6 pieces)", "calories": "200 kcal", "protein": "7 g", "carbs": "30 g", "fat": "5 g", "ingredients": "sushi rice, nori, crab (or imitation), avocado, cucumber", "method": "assembled raw", "substitute": "brown rice"},
    "salad": {"serving": "200 g", "calories": "50 kcal", "protein": "2 g", "carbs": "10 g", "fat": "0.5 g", "ingredients": "mixed greens, tomato, cucumber, carrots", "method": "raw", "substitute": "vinaigrette instead of ranch"},
    "pasta": {"serving": "200 g (1 cup)", "calories": "220 kcal", "protein": "7 g", "carbs": "43 g", "fat": "2 g", "ingredients": "wheat pasta, marinara sauce, olive oil", "method": "boiled and simmered", "substitute": "whole-grain pasta"},
    "ice_cream": {"serving": "100 g (½ cup)", "calories": "200 kcal", "protein": "4 g", "carbs": "20 g", "fat": "12 g", "ingredients": "cream, sugar, milk, vanilla", "method": "churned and frozen", "substitute": "frozen yogurt"},
    "fried_rice": {"serving": "200 g (1 cup)", "calories": "250 kcal", "protein": "8 g", "carbs": "35 g", "fat": "9 g", "ingredients": "rice, egg, peas, carrots, soy sauce, oil", "method": "stir-fried", "substitute": "brown rice"},
    "tacos": {"serving": "100 g (1 taco)", "calories": "200 kcal", "protein": "10 g", "carbs": "15 g", "fat": "10 g", "ingredients": "ground beef, corn tortilla, lettuce, cheese, salsa", "method": "beef pan-fried, tortilla warmed", "substitute": "fish filling"},
    "steak": {"serving": "113 g (4 oz)", "calories": "250 kcal", "protein": "25 g", "carbs": "0 g", "fat": "15 g", "ingredients": "beef sirloin, salt, pepper", "method": "grilled or pan-seared", "substitute": "leaner cut (filet mignon)"},
    "chocolate_cake": {"serving": "100 g (1 slice)", "calories": "350 kcal", "protein": "5 g", "carbs": "50 g", "fat": "15 g", "ingredients": "flour, sugar, cocoa, butter, eggs", "method": "baked", "substitute": "gluten-free flour"},
}

# Classifier labels whose nutrition entry is stored under a different key.
_LABEL_MAPPING = {
    "caesar_salad": "salad",
    "spaghetti_bolognese": "pasta",
}


def _describe_food(name, facts):
    """Return the factual description sentence(s) for a curated food entry."""
    return (
        f"A typical {name} serving ({facts['serving']}) contains about {facts['calories']}, "
        f"with {facts['protein']} protein, {facts['carbs']} carbs, and {facts['fat']} fat. "
        f"Made from {facts['ingredients']} and usually {facts['method']}. "
        f"Try {facts['substitute']} as a healthier swap."
    )


def _build_prompt(label, base_description):
    """Return the T5 prompt: a paraphrase request, or an open question when no facts exist."""
    if base_description is None:
        return (
            f"Provide an approximate nutrition summary for {label}, including calories, "
            "macronutrients, and a brief description."
        )
    return (
        "Paraphrase the following nutritional facts in a friendly, conversational tone. "
        "Use varied sentence structures and synonyms, and feel free to generalize numeric details "
        "(e.g., ‘around 250 kcal’). Don’t add any new facts.\n\n"
        f"{base_description}"
    )


def _is_unusable(response):
    """Return True when the generated text is too short or never mentions energy content."""
    lowered = response.lower()
    # The curated facts and the prompt both say "kcal", so a faithful paraphrase
    # may never contain the word "calories"; accept either spelling.
    mentions_energy = "calorie" in lowered or "kcal" in lowered
    return not mentions_energy or len(response.split()) < 10


def main():
    """Run the Streamlit page: classify an uploaded food image and show nutrition text.

    The page loads a ViT image classifier and a FLAN-T5 model on CPU, maps the
    predicted label to a curated nutrition entry when one exists, asks T5 to
    paraphrase (or, for labels without an entry, to summarise), and falls back
    to fixed text when the generated output fails a simple quality check.
    """
    st.title("🍽️ Food Nutrition Estimator")
    st.markdown(_INTRO_MARKDOWN)

    hf_token = os.getenv("HF_TOKEN", None)
    cache_dir = "/tmp/cache"
    os.makedirs(cache_dir, exist_ok=True)
    os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir

    st.sidebar.header("Models Used")
    st.sidebar.markdown(_SIDEBAR_MARKDOWN)

    # Preprocessing applied to the uploaded image before it is fed to the classifier.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.Lambda(lambda img: img.convert("RGB")),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    @st.cache_resource
    def load_models():
        """Load the classifier, tokenizer and paraphraser on CPU and return them with the device."""
        device = torch.device("cpu")
        # NOTE(review): newer transformers releases appear to deprecate
        # `use_auth_token` in favour of `token` — confirm against the pinned
        # version before switching.
        vit = ViTForImageClassification.from_pretrained(
            "shingguy1/fine_tuned_vit",
            cache_dir=cache_dir,
            use_auth_token=hf_token,
        ).to(device)
        tok = AutoTokenizer.from_pretrained(
            "google/flan-t5-small",
            cache_dir=cache_dir,
            use_auth_token=hf_token,
        )
        t5 = T5ForConditionalGeneration.from_pretrained(
            "google/flan-t5-small",
            cache_dir=cache_dir,
            use_auth_token=hf_token,
        ).to(device)
        return vit, tok, t5, device

    model_vit, tokenizer_t5, model_t5, device = load_models()

    uploaded = st.file_uploader("📷 Upload a food image...", type=["jpg", "png", "jpeg"])
    if not uploaded:
        return

    img = Image.open(uploaded)
    st.image(img, caption="Your Food", use_column_width=True)

    # Classify the image and take the highest-scoring label.
    inp = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        out = model_vit(pixel_values=inp)
    label = model_vit.config.id2label[out.logits.argmax(-1).item()]
    st.success(f"🍽️ Detected: **{label}**")

    true_label = _LABEL_MAPPING.get(label.lower(), label.lower())
    data = _NUTRITIONAL_INFO.get(true_label)
    # Always bound, so the fallback below cannot hit an unassigned name when
    # the predicted label has no curated entry.
    base_description = _describe_food(true_label, data) if data else None
    prompt = _build_prompt(label, base_description)

    inputs = tokenizer_t5(prompt, return_tensors="pt", truncation=True).to(device)
    # `early_stopping` is a beam-search option and has no effect on single-beam
    # sampling, so it is not passed here.
    output_ids = model_t5.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
    )
    response = tokenizer_t5.decode(output_ids[0], skip_special_tokens=True)

    # Fallback if the output seems too short or misses key phrases.
    if _is_unusable(response):
        response = base_description or f"No nutrition data is available for {label}."

    st.subheader("🧾 Nutrition Overview")
    st.info(response)
# Run the app only when this file is executed as a script.
if __name__ == "__main__":
    main()