shingguy1 commited on
Commit
898542f
·
verified ·
1 Parent(s): 068cfde

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +10 -121
src/streamlit_app.py CHANGED
@@ -46,7 +46,7 @@ st.sidebar.markdown("""
46
  def load_models():
47
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48
 
49
- # ConvNeXt for classification (updated repo)
50
  model_convnext = ConvNextForImageClassification.from_pretrained(
51
  "shingguy1/fine_tuned_convnext",
52
  cache_dir=cache_dir,
@@ -69,99 +69,16 @@ def load_models():
69
 
70
  model_convnext, tokenizer, model_llm, device = load_models()
71
 
72
- # Fallback nutritional data for all ten ConvNeXt food classes
73
- fallback_nutrition = {
74
- "hamburger": (
75
- "A standard hamburger (single beef patty, bun, lettuce, tomato, no condiments) weighs about 150g per serving. "
76
- "It has approximately 300 calories, 20g protein, 30g carbohydrates, and 12g fat. "
77
- "Main ingredients: ground beef patty (80/20 lean-to-fat ratio), white bun, lettuce, tomato. "
78
- "Cooking method: grilled or pan-fried. Nutritional facts: 500mg sodium, 10% daily value iron, minimal fiber. "
79
- "Add-ins like cheese (+100 calories, 7g fat) or mayo (+90 calories, 10g fat) increase calories and fat. "
80
- "Substitute chicken patty for lower fat (8g)."
81
- ),
82
- "pizza": (
83
- "A standard cheese pizza slice (1/8 of a 14-inch pizza) weighs about 100g per serving. "
84
- "It has approximately 270 calories, 12g protein, 34g carbohydrates, and 10g fat. "
85
- "Main ingredients: dough, tomato sauce, mozzarella cheese. "
86
- "Cooking method: baked in an oven. Nutritional facts: 600mg sodium, 15% daily value calcium, 2g fiber. "
87
- "Add-ins like pepperoni (+50 calories, 5g fat) or extra cheese (+80 calories, 6g fat) increase calories. "
88
- "Substitute cauliflower crust for lower carbs (20g)."
89
- ),
90
- "sushi": (
91
- "A standard sushi roll (6 pieces, e.g., California roll) weighs about 150g per serving. "
92
- "It has approximately 200 calories, 7g protein, 30g carbohydrates, and 5g fat. "
93
- "Main ingredients: sushi rice, nori (seaweed), crab or imitation crab, avocado, cucumber. "
94
- "Cooking method: raw or assembled without cooking. Nutritional facts: 400mg sodium, 10% daily value vitamin A. "
95
- "Add-ins like soy sauce (+200mg sodium) or spicy mayo (+100 calories, 10g fat) increase sodium or fat. "
96
- "Substitute brown rice for higher fiber (2g)."
97
- ),
98
- "ceasar_salad": (
99
- "A standard garden salad (mixed greens, tomato, cucumber, no dressing) weighs about 200g per serving. "
100
- "It has approximately 50 calories, 2g protein, 10g carbohydrates, and 0.5g fat. "
101
- "Main ingredients: lettuce, spinach, tomato, cucumber, carrots. "
102
- "Cooking method: raw, no cooking. Nutritional facts: 50mg sodium, 50% daily value vitamin C, 4g fiber. "
103
- "Add-ins like ranch dressing (+150 calories, 15g fat) or grilled chicken (+120 calories, 20g protein) increase calories or protein. "
104
- "Substitute vinaigrette for lower fat (5g)."
105
- ),
106
- "pasta": (
107
- "A standard pasta dish (1 cup cooked spaghetti with marinara sauce) weighs about 200g per serving. "
108
- "It has approximately 220 calories, 7g protein, 43g carbohydrates, and 2g fat. "
109
- "Main ingredients: wheat pasta, tomato sauce, olive oil. "
110
- "Cooking method: boiled pasta, sauce simmered. Nutritional facts: 400mg sodium, 10% daily value vitamin A, 3g fiber. "
111
- "Add-ins like meatballs (+200 calories, 15g fat) or parmesan (+50 calories, 4g fat) increase calories. "
112
- "Substitute whole-grain pasta for higher fiber (5g)."
113
- ),
114
- "ice_cream": (
115
- "A standard ice cream serving (1/2 cup vanilla) weighs about 100g. "
116
- "It has approximately 200 calories, 4g protein, 20g carbohydrates, and 12g fat. "
117
- "Main ingredients: cream, sugar, milk, vanilla extract. "
118
- "Cooking method: churned and frozen. Nutritional facts: 100mg sodium, 15% daily value calcium, 0g fiber. "
119
- "Add-ins like chocolate syrup (+100 calories, 2g fat) or sprinkles (+50 calories) increase calories. "
120
- "Substitute frozen yogurt for lower fat (7g)."
121
- ),
122
- "fried_rice": (
123
- "A standard fried rice serving (1 cup with vegetables and egg) weighs about 200g. "
124
- "It has approximately 250 calories, 8g protein, 35g carbohydrates, and 9g fat. "
125
- "Main ingredients: white rice, egg, peas, carrots, soy sauce, vegetable oil. "
126
- "Cooking method: stir-fried. Nutritional facts: 700mg sodium, 10% daily value vitamin A, 2g fiber. "
127
- "Add-ins like chicken (+100 calories, 15g protein) or shrimp (+80 calories, 12g protein) increase protein. "
128
- "Substitute brown rice for higher fiber (3g)."
129
- ),
130
- "tacos": (
131
- "A standard taco (1 soft corn tortilla with beef, lettuce, cheese) weighs about 100g per serving. "
132
- "It has approximately 200 calories, 10g protein, 15g carbohydrates, and 10g fat. "
133
- "Main ingredients: ground beef, corn tortilla, lettuce, cheddar cheese, salsa. "
134
- "Cooking method: beef pan-fried, tortilla warmed. Nutritional facts: 400mg sodium, 10% daily value calcium, 2g fiber. "
135
- "Add-ins like sour cream (+50 calories, 5g fat) or guacamole (+80 calories, 7g fat) increase fat. "
136
- "Substitute fish for lower fat (6g)."
137
- ),
138
- "steak": (
139
- "A standard steak (4 oz grilled sirloin) weighs about 113g per serving. "
140
- "It has approximately 250 calories, 25g protein, 0g carbohydrates, and 15g fat. "
141
- "Main ingredients: beef sirloin, salt, pepper. "
142
- "Cooking method: grilled or pan-seared. Nutritional facts: 300mg sodium, 20% daily value iron, 0g fiber. "
143
- "Add-ins like butter sauce (+100 calories, 10g fat) or mashed potatoes (+150 calories, 5g fat) increase calories. "
144
- "Substitute leaner cut (e.g., filet mignon) for lower fat (10g)."
145
- ),
146
- "chocolate_cake": (
147
- "A standard chocolate cake slice (1/12 of a 9-inch cake) weighs about 100g per serving. "
148
- "It has approximately 350 calories, 5g protein, 50g carbohydrates, and 15g fat. "
149
- "Main ingredients: flour, sugar, cocoa powder, butter, eggs. "
150
- "Cooking method: baked. Nutritional facts: 200mg sodium, 5% daily value iron, 2g fiber. "
151
- "Add-ins like frosting (+100 calories, 5g fat) or whipped cream (+50 calories, 5g fat) increase calories. "
152
- "Substitute gluten-free flour for dietary needs (same calories)."
153
- )
154
- }
155
-
156
- # Upload image
157
  uploaded_file = st.file_uploader("Upload a food image...", type=["jpg", "jpeg", "png"])
158
 
159
  if uploaded_file is not None:
160
  try:
 
161
  image = Image.open(uploaded_file).convert("RGB")
162
  st.image(image, caption="Uploaded Image", use_column_width=True)
163
 
164
- # Predict with ConvNeXt
165
  input_tensor = manual_transform(image).unsqueeze(0).to(device)
166
  with torch.no_grad():
167
  outputs = model_convnext(pixel_values=input_tensor)
@@ -169,32 +86,14 @@ if uploaded_file is not None:
169
  pred_label = model_convnext.config.id2label[pred_idx]
170
  st.success(f"🍴 Predicted Food: **{pred_label}**")
171
 
172
- # Generate nutrition caption using TinyLlama
173
- fallback_text = fallback_nutrition.get(pred_label.lower(), "Nutritional facts unavailable for this item.")
174
- simplified_fallback = (
175
- f"Food: {pred_label}\n"
176
- f"Serving size: {fallback_text.split('weighs about ')[1].split(' per serving')[0]}\n"
177
- f"Calories: {fallback_text.split('approximately ')[1].split(' calories')[0]} calories\n"
178
- f"Protein: {fallback_text.split('protein, ')[1].split('g ')[0]}g\n"
179
- f"Carbs: {fallback_text.split('carbohydrates, ')[1].split('g ')[0]}g\n"
180
- f"Fat: {fallback_text.split('and ')[1].split('g fat')[0]}g\n"
181
- f"Ingredients: {fallback_text.split('Main ingredients: ')[1].split('. ')[0]}\n"
182
- f"Cooking method: {fallback_text.split('Cooking method: ')[1].split('. ')[0]}\n"
183
- f"Extras: {fallback_text.split('Add-ins like ')[1].split('. Substitute')[0]}\n"
184
- f"Substitution: {fallback_text.split('Substitute ')[1].split('.')[0]}"
185
- )
186
  prompt = (
187
- f"Below is nutritional information for a {pred_label}:\n"
188
- f"{simplified_fallback}\n\n"
189
- f"Using this data, write a concise, natural description of the {pred_label}'s nutrition in your own words. "
190
- f"Do not copy the text above; create a new description with different wording. "
191
- f"Include calories, macronutrients (protein, carbs, fat), serving size, ingredients, cooking method, add-ins, and substitution. "
192
- f"Example for a taco: 'A classic taco, around 100g, offers 200 calories, 10g protein, 15g carbs, and 10g fat. "
193
- f"It’s prepared with ground beef, corn tortilla, lettuce, cheese, and salsa, with the beef fried and tortilla warmed. "
194
- f"Add sour cream for 50 extra calories or switch to fish for 6g fat.'"
195
  )
196
  st.subheader("🧾 Nutrition Information")
197
- st.write(f"🤖 Prompt: Describe {pred_label} nutrition in your own words based on provided data.")
198
 
199
  input_ids = tokenizer(prompt, return_tensors="pt").to(model_llm.device)
200
  with torch.no_grad():
@@ -205,17 +104,7 @@ if uploaded_file is not None:
205
  top_p=0.9,
206
  do_sample=True
207
  )
208
- input_len = input_ids.input_ids.shape[1]
209
- caption = tokenizer.decode(output[0][input_len:], skip_special_tokens=True).strip()
210
-
211
- # Check similarity to fallback
212
- fallback_words = set(fallback_text.lower().split())
213
- caption_words = set(caption.lower().split())
214
- similarity = len(fallback_words & caption_words) / max(len(fallback_words), 1)
215
-
216
- if not caption or len(caption) < 60 or any(kw in caption.lower() for kw in ["rephrase", "below is", "example for"]) or similarity > 0.8:
217
- caption = fallback_text
218
-
219
  st.info(caption)
220
 
221
  except Exception as e:
 
46
  def load_models():
47
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48
 
49
+ # ConvNeXt for classification
50
  model_convnext = ConvNextForImageClassification.from_pretrained(
51
  "shingguy1/fine_tuned_convnext",
52
  cache_dir=cache_dir,
 
69
 
70
  model_convnext, tokenizer, model_llm, device = load_models()
71
 
72
+ # Image uploader
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  uploaded_file = st.file_uploader("Upload a food image...", type=["jpg", "jpeg", "png"])
74
 
75
  if uploaded_file is not None:
76
  try:
77
+ # Load and display
78
  image = Image.open(uploaded_file).convert("RGB")
79
  st.image(image, caption="Uploaded Image", use_column_width=True)
80
 
81
+ # Predict food label
82
  input_tensor = manual_transform(image).unsqueeze(0).to(device)
83
  with torch.no_grad():
84
  outputs = model_convnext(pixel_values=input_tensor)
 
86
  pred_label = model_convnext.config.id2label[pred_idx]
87
  st.success(f"🍴 Predicted Food: **{pred_label}**")
88
 
89
+ # Generate nutrition description with LLM
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  prompt = (
91
+ f"Please provide a concise nutritional overview for a {pred_label}. "
92
+ "Include typical serving size, approximate calories, macronutrient breakdown "
93
+ "(protein, carbs, fat), main ingredients, common cooking method, and one substitution suggestion."
 
 
 
 
 
94
  )
95
  st.subheader("🧾 Nutrition Information")
96
+ st.write(f"🤖 Prompt to LLM:\n\n{prompt}")
97
 
98
  input_ids = tokenizer(prompt, return_tensors="pt").to(model_llm.device)
99
  with torch.no_grad():
 
104
  top_p=0.9,
105
  do_sample=True
106
  )
107
+ caption = tokenizer.decode(output[0], skip_special_tokens=True).strip()
 
 
 
 
 
 
 
 
 
 
108
  st.info(caption)
109
 
110
  except Exception as e: