shingguy1 commited on
Commit
068cfde
·
verified ·
1 Parent(s): 868c7ab

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +162 -3
src/streamlit_app.py CHANGED
@@ -10,7 +10,11 @@ from transformers import (
10
  )
11
 
12
  # Set Streamlit UI
13
- st.set_page_config(page_title="🍽️ Food Nutrition Estimator", page_icon="🥗", layout="centered")
 
 
 
 
14
  st.title("🍽️ Food Nutrition Estimator")
15
  st.markdown("Upload a food image and get nutritional information generated by AI!")
16
 
@@ -25,7 +29,8 @@ manual_transform = transforms.Compose([
25
  transforms.Resize(224),
26
  transforms.CenterCrop(196),
27
  transforms.ToTensor(),
28
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 
29
  transforms.ConvertImageDtype(torch.float32)
30
  ])
31
 
@@ -64,4 +69,158 @@ def load_models():
64
 
65
  model_convnext, tokenizer, model_llm, device = load_models()
66
 
67
- # ... rest of your code remains unchanged ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  )
11
 
12
  # Set Streamlit UI
13
+ st.set_page_config(
14
+ page_title="🍽️ Food Nutrition Estimator",
15
+ page_icon="🥗",
16
+ layout="centered"
17
+ )
18
  st.title("🍽️ Food Nutrition Estimator")
19
  st.markdown("Upload a food image and get nutritional information generated by AI!")
20
 
 
29
  transforms.Resize(224),
30
  transforms.CenterCrop(196),
31
  transforms.ToTensor(),
32
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
33
+ std=[0.229, 0.224, 0.225]),
34
  transforms.ConvertImageDtype(torch.float32)
35
  ])
36
 
 
69
 
70
  model_convnext, tokenizer, model_llm, device = load_models()
71
 
72
+ # Fallback nutritional data for all ten ConvNeXt food classes
73
+ fallback_nutrition = {
74
+ "hamburger": (
75
+ "A standard hamburger (single beef patty, bun, lettuce, tomato, no condiments) weighs about 150g per serving. "
76
+ "It has approximately 300 calories, 20g protein, 30g carbohydrates, and 12g fat. "
77
+ "Main ingredients: ground beef patty (80/20 lean-to-fat ratio), white bun, lettuce, tomato. "
78
+ "Cooking method: grilled or pan-fried. Nutritional facts: 500mg sodium, 10% daily value iron, minimal fiber. "
79
+ "Add-ins like cheese (+100 calories, 7g fat) or mayo (+90 calories, 10g fat) increase calories and fat. "
80
+ "Substitute chicken patty for lower fat (8g)."
81
+ ),
82
+ "pizza": (
83
+ "A standard cheese pizza slice (1/8 of a 14-inch pizza) weighs about 100g per serving. "
84
+ "It has approximately 270 calories, 12g protein, 34g carbohydrates, and 10g fat. "
85
+ "Main ingredients: dough, tomato sauce, mozzarella cheese. "
86
+ "Cooking method: baked in an oven. Nutritional facts: 600mg sodium, 15% daily value calcium, 2g fiber. "
87
+ "Add-ins like pepperoni (+50 calories, 5g fat) or extra cheese (+80 calories, 6g fat) increase calories. "
88
+ "Substitute cauliflower crust for lower carbs (20g)."
89
+ ),
90
+ "sushi": (
91
+ "A standard sushi roll (6 pieces, e.g., California roll) weighs about 150g per serving. "
92
+ "It has approximately 200 calories, 7g protein, 30g carbohydrates, and 5g fat. "
93
+ "Main ingredients: sushi rice, nori (seaweed), crab or imitation crab, avocado, cucumber. "
94
+ "Cooking method: raw or assembled without cooking. Nutritional facts: 400mg sodium, 10% daily value vitamin A. "
95
+ "Add-ins like soy sauce (+200mg sodium) or spicy mayo (+100 calories, 10g fat) increase sodium or fat. "
96
+ "Substitute brown rice for higher fiber (2g)."
97
+ ),
98
+ "ceasar_salad": (
99
+ "A standard garden salad (mixed greens, tomato, cucumber, no dressing) weighs about 200g per serving. "
100
+ "It has approximately 50 calories, 2g protein, 10g carbohydrates, and 0.5g fat. "
101
+ "Main ingredients: lettuce, spinach, tomato, cucumber, carrots. "
102
+ "Cooking method: raw, no cooking. Nutritional facts: 50mg sodium, 50% daily value vitamin C, 4g fiber. "
103
+ "Add-ins like ranch dressing (+150 calories, 15g fat) or grilled chicken (+120 calories, 20g protein) increase calories or protein. "
104
+ "Substitute vinaigrette for lower fat (5g)."
105
+ ),
106
+ "pasta": (
107
+ "A standard pasta dish (1 cup cooked spaghetti with marinara sauce) weighs about 200g per serving. "
108
+ "It has approximately 220 calories, 7g protein, 43g carbohydrates, and 2g fat. "
109
+ "Main ingredients: wheat pasta, tomato sauce, olive oil. "
110
+ "Cooking method: boiled pasta, sauce simmered. Nutritional facts: 400mg sodium, 10% daily value vitamin A, 3g fiber. "
111
+ "Add-ins like meatballs (+200 calories, 15g fat) or parmesan (+50 calories, 4g fat) increase calories. "
112
+ "Substitute whole-grain pasta for higher fiber (5g)."
113
+ ),
114
+ "ice_cream": (
115
+ "A standard ice cream serving (1/2 cup vanilla) weighs about 100g. "
116
+ "It has approximately 200 calories, 4g protein, 20g carbohydrates, and 12g fat. "
117
+ "Main ingredients: cream, sugar, milk, vanilla extract. "
118
+ "Cooking method: churned and frozen. Nutritional facts: 100mg sodium, 15% daily value calcium, 0g fiber. "
119
+ "Add-ins like chocolate syrup (+100 calories, 2g fat) or sprinkles (+50 calories) increase calories. "
120
+ "Substitute frozen yogurt for lower fat (7g)."
121
+ ),
122
+ "fried_rice": (
123
+ "A standard fried rice serving (1 cup with vegetables and egg) weighs about 200g. "
124
+ "It has approximately 250 calories, 8g protein, 35g carbohydrates, and 9g fat. "
125
+ "Main ingredients: white rice, egg, peas, carrots, soy sauce, vegetable oil. "
126
+ "Cooking method: stir-fried. Nutritional facts: 700mg sodium, 10% daily value vitamin A, 2g fiber. "
127
+ "Add-ins like chicken (+100 calories, 15g protein) or shrimp (+80 calories, 12g protein) increase protein. "
128
+ "Substitute brown rice for higher fiber (3g)."
129
+ ),
130
+ "tacos": (
131
+ "A standard taco (1 soft corn tortilla with beef, lettuce, cheese) weighs about 100g per serving. "
132
+ "It has approximately 200 calories, 10g protein, 15g carbohydrates, and 10g fat. "
133
+ "Main ingredients: ground beef, corn tortilla, lettuce, cheddar cheese, salsa. "
134
+ "Cooking method: beef pan-fried, tortilla warmed. Nutritional facts: 400mg sodium, 10% daily value calcium, 2g fiber. "
135
+ "Add-ins like sour cream (+50 calories, 5g fat) or guacamole (+80 calories, 7g fat) increase fat. "
136
+ "Substitute fish for lower fat (6g)."
137
+ ),
138
+ "steak": (
139
+ "A standard steak (4 oz grilled sirloin) weighs about 113g per serving. "
140
+ "It has approximately 250 calories, 25g protein, 0g carbohydrates, and 15g fat. "
141
+ "Main ingredients: beef sirloin, salt, pepper. "
142
+ "Cooking method: grilled or pan-seared. Nutritional facts: 300mg sodium, 20% daily value iron, 0g fiber. "
143
+ "Add-ins like butter sauce (+100 calories, 10g fat) or mashed potatoes (+150 calories, 5g fat) increase calories. "
144
+ "Substitute leaner cut (e.g., filet mignon) for lower fat (10g)."
145
+ ),
146
+ "chocolate_cake": (
147
+ "A standard chocolate cake slice (1/12 of a 9-inch cake) weighs about 100g per serving. "
148
+ "It has approximately 350 calories, 5g protein, 50g carbohydrates, and 15g fat. "
149
+ "Main ingredients: flour, sugar, cocoa powder, butter, eggs. "
150
+ "Cooking method: baked. Nutritional facts: 200mg sodium, 5% daily value iron, 2g fiber. "
151
+ "Add-ins like frosting (+100 calories, 5g fat) or whipped cream (+50 calories, 5g fat) increase calories. "
152
+ "Substitute gluten-free flour for dietary needs (same calories)."
153
+ )
154
+ }
155
+
156
+ # Upload image
157
+ uploaded_file = st.file_uploader("Upload a food image...", type=["jpg", "jpeg", "png"])
158
+
159
+ if uploaded_file is not None:
160
+ try:
161
+ image = Image.open(uploaded_file).convert("RGB")
162
+ st.image(image, caption="Uploaded Image", use_column_width=True)
163
+
164
+ # Predict with ConvNeXt
165
+ input_tensor = manual_transform(image).unsqueeze(0).to(device)
166
+ with torch.no_grad():
167
+ outputs = model_convnext(pixel_values=input_tensor)
168
+ pred_idx = outputs.logits.argmax(-1).item()
169
+ pred_label = model_convnext.config.id2label[pred_idx]
170
+ st.success(f"🍴 Predicted Food: **{pred_label}**")
171
+
172
+ # Generate nutrition caption using TinyLlama
173
+ fallback_text = fallback_nutrition.get(pred_label.lower(), "Nutritional facts unavailable for this item.")
174
+ simplified_fallback = (
175
+ f"Food: {pred_label}\n"
176
+ f"Serving size: {fallback_text.split('weighs about ')[1].split(' per serving')[0]}\n"
177
+ f"Calories: {fallback_text.split('approximately ')[1].split(' calories')[0]} calories\n"
178
+ f"Protein: {fallback_text.split('protein, ')[1].split('g ')[0]}g\n"
179
+ f"Carbs: {fallback_text.split('carbohydrates, ')[1].split('g ')[0]}g\n"
180
+ f"Fat: {fallback_text.split('and ')[1].split('g fat')[0]}g\n"
181
+ f"Ingredients: {fallback_text.split('Main ingredients: ')[1].split('. ')[0]}\n"
182
+ f"Cooking method: {fallback_text.split('Cooking method: ')[1].split('. ')[0]}\n"
183
+ f"Extras: {fallback_text.split('Add-ins like ')[1].split('. Substitute')[0]}\n"
184
+ f"Substitution: {fallback_text.split('Substitute ')[1].split('.')[0]}"
185
+ )
186
+ prompt = (
187
+ f"Below is nutritional information for a {pred_label}:\n"
188
+ f"{simplified_fallback}\n\n"
189
+ f"Using this data, write a concise, natural description of the {pred_label}'s nutrition in your own words. "
190
+ f"Do not copy the text above; create a new description with different wording. "
191
+ f"Include calories, macronutrients (protein, carbs, fat), serving size, ingredients, cooking method, add-ins, and substitution. "
192
+ f"Example for a taco: 'A classic taco, around 100g, offers 200 calories, 10g protein, 15g carbs, and 10g fat. "
193
+ f"It’s prepared with ground beef, corn tortilla, lettuce, cheese, and salsa, with the beef fried and tortilla warmed. "
194
+ f"Add sour cream for 50 extra calories or switch to fish for 6g fat.'"
195
+ )
196
+ st.subheader("🧾 Nutrition Information")
197
+ st.write(f"🤖 Prompt: Describe {pred_label} nutrition in your own words based on provided data.")
198
+
199
+ input_ids = tokenizer(prompt, return_tensors="pt").to(model_llm.device)
200
+ with torch.no_grad():
201
+ output = model_llm.generate(
202
+ **input_ids,
203
+ max_new_tokens=300,
204
+ temperature=0.8,
205
+ top_p=0.9,
206
+ do_sample=True
207
+ )
208
+ input_len = input_ids.input_ids.shape[1]
209
+ caption = tokenizer.decode(output[0][input_len:], skip_special_tokens=True).strip()
210
+
211
+ # Check similarity to fallback
212
+ fallback_words = set(fallback_text.lower().split())
213
+ caption_words = set(caption.lower().split())
214
+ similarity = len(fallback_words & caption_words) / max(len(fallback_words), 1)
215
+
216
+ if not caption or len(caption) < 60 or any(kw in caption.lower() for kw in ["rephrase", "below is", "example for"]) or similarity > 0.8:
217
+ caption = fallback_text
218
+
219
+ st.info(caption)
220
+
221
+ except Exception as e:
222
+ st.error(f"Something went wrong: {e}")
223
+
224
+ # Footer
225
+ st.markdown("---")
226
+ st.markdown("Built with ❤️ using Streamlit and Hugging Face by **shingguy1**")