File size: 6,961 Bytes
fda8905
45a1fc8
df9e1b3
 
404eb80
45a1fc8
 
cdfccf9
9129b6f
dbca709
45a1fc8
fda8905
6ede4a3
a965712
6ede4a3
 
a965712
 
 
6ede4a3
 
 
 
4840f3d
6ede4a3
 
 
 
2548991
749ea77
 
 
 
 
df9e1b3
4840f3d
 
 
 
 
 
 
 
 
 
a965712
6ede4a3
 
 
 
 
 
 
 
 
 
 
 
4840f3d
 
 
 
 
6ede4a3
 
4840f3d
 
 
 
 
 
6ede4a3
 
 
4840f3d
 
6ede4a3
 
 
4840f3d
 
6ede4a3
 
 
4840f3d
 
 
 
 
6ede4a3
4840f3d
 
 
6ede4a3
4840f3d
6ede4a3
 
4840f3d
 
 
 
 
6ede4a3
4840f3d
 
 
 
 
 
 
 
 
 
6ede4a3
 
4840f3d
 
6ede4a3
 
 
 
4840f3d
 
6ede4a3
 
 
 
 
 
 
 
4840f3d
6ede4a3
 
 
 
4840f3d
 
 
749ea77
6ede4a3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import streamlit as st
import os
import torch
import random
from PIL import Image
import torchvision.transforms as transforms
from transformers import (
    ViTForImageClassification,
    AutoTokenizer,
    T5ForConditionalGeneration
)

# Configure the browser tab (title and icon) and a centered page layout.
# This executes at import time, before main() renders any other element.
st.set_page_config(
    layout="centered",
    page_icon="🥗",
    page_title="🍽️ Food Nutrition Estimator",
)

def _nutrition_key(label, label_mapping):
    """Map a raw classifier label to a key of the nutrition table.

    Args:
        label: Class name taken from ``model.config.id2label``.
        label_mapping: Overrides for labels whose name differs from the
            nutrition-table key (e.g. ``"caesar_salad" -> "salad"``).

    Returns:
        The stripped, lower-cased label with spaces turned into underscores,
        after applying *label_mapping*.
    """
    key = label.strip().lower().replace(" ", "_")
    return label_mapping.get(key, key)


def _base_description(name, data):
    """Build the canonical one-paragraph nutrition summary for *name*.

    *data* is one entry of the nutrition table in ``main``: a dict with the
    keys ``serving``, ``calories``, ``protein``, ``carbs``, ``fat``,
    ``ingredients``, ``method`` and ``substitute``.
    """
    return (
        f"A typical {name} serving ({data['serving']}) contains about {data['calories']}, "
        f"with {data['protein']} protein, {data['carbs']} carbs, and {data['fat']} fat. "
        f"Made from {data['ingredients']} and usually {data['method']}. "
        f"Try {data['substitute']} as a healthier swap."
    )


def _build_prompt(label, base_description):
    """Return the T5 prompt for *label*.

    When a canonical *base_description* exists, the model is asked to
    paraphrase it. Otherwise (``None``) it is asked for an approximate summary
    of the raw *label*.
    """
    if base_description is not None:
        return (
            "Paraphrase the following nutritional facts in a friendly, conversational tone. "
            "Use varied sentence structures and synonyms, and feel free to generalize numeric details "
            "(e.g., ‘around 250 kcal’). Don’t add any new facts.\n\n"
            f"{base_description}"
        )
    return (
        f"Provide an approximate nutrition summary for {label}, including calories, "
        "macronutrients, and a brief description."
    )


def _is_usable_response(response, min_words=10):
    """Return True if generated text looks like a real nutrition summary.

    The text must mention energy content ("calorie"/"calories" or "kcal",
    since the prompt itself suggests phrasing like "around 250 kcal") and
    contain at least *min_words* whitespace-separated words.
    """
    text = response.lower()
    mentions_energy = "calorie" in text or "kcal" in text
    return mentions_energy and len(response.split()) >= min_words


def main():
    """Render the Food Nutrition Estimator Streamlit page.

    Loads (and caches via ``st.cache_resource``) the ViT food classifier and
    the FLAN-T5 paraphraser. It then classifies an uploaded image and shows a
    nutrition overview. For labels found in the built-in nutrition table, the
    overview is a sampled T5 paraphrase of a fixed description. That fixed
    description is shown instead when the paraphrase looks unusable.
    """
    st.title("🍽️ Food Nutrition Estimator")
    st.markdown("""
    Upload a food image to classify it and receive a paraphrased nutritional description.

    ⚠️ This demo is trained on **10 food categories** only:
    pizza, hamburger, sushi, caesar_salad, spaghetti_bolognese,
    ice_cream, fried_rice, tacos, steak, chocolate_cake.
    """)

    hf_token = os.getenv("HF_TOKEN", None)
    cache_dir = "/tmp/cache"
    os.makedirs(cache_dir, exist_ok=True)
    # NOTE(review): huggingface_hub likely reads this variable when it is
    # first imported (before this line runs), so this may have no effect.
    # The explicit cache_dir= arguments below are what select the cache.
    os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir

    nutritional_info = {
        "pizza": {"serving": "100 g (1 slice)", "calories": "270 kcal", "protein": "12 g", "carbs": "34 g", "fat": "10 g", "ingredients": "dough, tomato sauce, mozzarella cheese", "method": "baked", "substitute": "cauliflower crust"},
        "hamburger": {"serving": "150 g", "calories": "300 kcal", "protein": "20 g", "carbs": "30 g", "fat": "12 g", "ingredients": "ground beef patty, bun, lettuce, tomato", "method": "grilled or pan-fried", "substitute": "chicken patty"},
        "sushi": {"serving": "150 g (6 pieces)", "calories": "200 kcal", "protein": "7 g", "carbs": "30 g", "fat": "5 g", "ingredients": "sushi rice, nori, crab (or imitation), avocado, cucumber", "method": "assembled raw", "substitute": "brown rice"},
        "salad": {"serving": "200 g", "calories": "50 kcal", "protein": "2 g", "carbs": "10 g", "fat": "0.5 g", "ingredients": "mixed greens, tomato, cucumber, carrots", "method": "raw", "substitute": "vinaigrette instead of ranch"},
        "pasta": {"serving": "200 g (1 cup)", "calories": "220 kcal", "protein": "7 g", "carbs": "43 g", "fat": "2 g", "ingredients": "wheat pasta, marinara sauce, olive oil", "method": "boiled and simmered", "substitute": "whole-grain pasta"},
        "ice_cream": {"serving": "100 g (½ cup)", "calories": "200 kcal", "protein": "4 g", "carbs": "20 g", "fat": "12 g", "ingredients": "cream, sugar, milk, vanilla", "method": "churned and frozen", "substitute": "frozen yogurt"},
        "fried_rice": {"serving": "200 g (1 cup)", "calories": "250 kcal", "protein": "8 g", "carbs": "35 g", "fat": "9 g", "ingredients": "rice, egg, peas, carrots, soy sauce, oil", "method": "stir-fried", "substitute": "brown rice"},
        "tacos": {"serving": "100 g (1 taco)", "calories": "200 kcal", "protein": "10 g", "carbs": "15 g", "fat": "10 g", "ingredients": "ground beef, corn tortilla, lettuce, cheese, salsa", "method": "beef pan-fried, tortilla warmed", "substitute": "fish filling"},
        "steak": {"serving": "113 g (4 oz)", "calories": "250 kcal", "protein": "25 g", "carbs": "0 g", "fat": "15 g", "ingredients": "beef sirloin, salt, pepper", "method": "grilled or pan-seared", "substitute": "leaner cut (filet mignon)"},
        "chocolate_cake": {"serving": "100 g (1 slice)", "calories": "350 kcal", "protein": "5 g", "carbs": "50 g", "fat": "15 g", "ingredients": "flour, sugar, cocoa, butter, eggs", "method": "baked", "substitute": "gluten-free flour"}
    }

    # Classifier class names whose nutrition-table key is different.
    label_mapping = {
        "caesar_salad": "salad",
        "spaghetti_bolognese": "pasta"
    }

    st.sidebar.header("Models Used")
    st.sidebar.markdown("""
    - 🖼️ **Image Classifier**: shingguy1/fine_tuned_vit
    - 💬 **Paraphraser**: google/flan-t5-small (sampling mode)
    """)

    transform = transforms.Compose([
        # Convert first so palette / RGBA / grayscale uploads are resampled as
        # 3-channel RGB rather than in their native mode.
        transforms.Lambda(lambda img: img.convert("RGB")),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # NOTE(review): ImageNet statistics; confirm they match the
        # normalization used when fine-tuning shingguy1/fine_tuned_vit.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    @st.cache_resource
    def load_models():
        """Load the classifier, tokenizer and T5 model once per process (CPU)."""
        device = torch.device("cpu")
        # NOTE(review): use_auth_token is deprecated in newer transformers
        # releases in favour of token=; kept for compatibility with older ones.
        vit = ViTForImageClassification.from_pretrained(
            "shingguy1/fine_tuned_vit",
            cache_dir=cache_dir,
            use_auth_token=hf_token
        ).to(device)
        tok = AutoTokenizer.from_pretrained(
            "google/flan-t5-small",
            cache_dir=cache_dir,
            use_auth_token=hf_token
        )
        t5 = T5ForConditionalGeneration.from_pretrained(
            "google/flan-t5-small",
            cache_dir=cache_dir,
            use_auth_token=hf_token
        ).to(device)
        return vit, tok, t5, device

    model_vit, tokenizer_t5, model_t5, device = load_models()

    uploaded = st.file_uploader("📷 Upload a food image...", type=["jpg", "png", "jpeg"])
    if not uploaded:
        return

    try:
        img = Image.open(uploaded)
        img.load()  # force decoding so truncated files fail here, not later
    except OSError:  # includes PIL.UnidentifiedImageError
        st.error("Could not read that file as an image. Please upload a valid JPG or PNG.")
        return

    # NOTE(review): use_column_width is deprecated in recent Streamlit
    # releases (use_container_width); kept for older versions.
    st.image(img, caption="Your Food", use_column_width=True)

    inp = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        out = model_vit(pixel_values=inp)
    label = model_vit.config.id2label[out.logits.argmax(-1).item()]
    st.success(f"🍽️ Detected: **{label}**")

    true_label = _nutrition_key(label, label_mapping)
    data = nutritional_info.get(true_label)
    # Only known labels get a canonical description to paraphrase and to fall
    # back on. Unknown labels get None (previously they hit a NameError below).
    base_description = _base_description(true_label, data) if data else None
    prompt = _build_prompt(label, base_description)

    inputs = tokenizer_t5(prompt, return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        # **inputs passes the attention mask along with input_ids.
        # early_stopping is omitted because it only affects beam search.
        output_ids = model_t5.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=True,
            top_p=0.9,
            temperature=0.7,
        )
    response = tokenizer_t5.decode(output_ids[0], skip_special_tokens=True)

    # Fall back to the canonical text when the sampled paraphrase is too short
    # or never mentions energy content.
    if base_description is not None and not _is_usable_response(response):
        response = base_description
    if not response.strip():
        response = "No nutrition summary could be generated for this dish."

    st.subheader("🧾 Nutrition Overview")
    st.info(response)


if __name__ == "__main__":
    main()