shingguy1 commited on
Commit
868c7ab
·
verified ·
1 Parent(s): b0637a8

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +8 -164
src/streamlit_app.py CHANGED
@@ -32,7 +32,7 @@ manual_transform = transforms.Compose([
32
  # Sidebar info
33
  st.sidebar.header("Models Used")
34
  st.sidebar.markdown("""
35
- - 🖼️ **Image Classifier**: `shingguy1/food-calorie-convnext`
36
  - 💬 **Text Generator**: `TinyLlama/TinyLlama-1.1B-Chat-v1.0`
37
  """)
38
 
@@ -41,15 +41,18 @@ st.sidebar.markdown("""
41
  def load_models():
42
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
 
44
- # ConvNeXt for classification
45
  model_convnext = ConvNextForImageClassification.from_pretrained(
46
- "shingguy1/food-calorie-convnext",
47
  cache_dir=cache_dir,
48
  token=hf_token
49
  ).to(device)
50
 
51
  # TinyLlama for nutritional facts
52
- tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", cache_dir=cache_dir)
 
 
 
53
  model_llm = AutoModelForCausalLM.from_pretrained(
54
  "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
55
  cache_dir=cache_dir,
@@ -61,163 +64,4 @@ def load_models():
61
 
62
  model_convnext, tokenizer, model_llm, device = load_models()
63
 
64
- # Fallback nutritional data for all ten ConvNeXt food classes
65
- fallback_nutrition = {
66
- "hamburger": (
67
- "A standard hamburger (single beef patty, bun, lettuce, tomato, no condiments) weighs about 150g per serving. "
68
- "It has approximately 300 calories, 20g protein, 30g carbohydrates, and 12g fat. "
69
- "Main ingredients: ground beef patty (80/20 lean-to-fat ratio), white bun, lettuce, tomato. "
70
- "Cooking method: grilled or pan-fried. Nutritional facts: 500mg sodium, 10% daily value iron, minimal fiber. "
71
- "Add-ins like cheese (+100 calories, 7g fat) or mayo (+90 calories, 10g fat) increase calories and fat. "
72
- "Substitute chicken patty for lower fat (8g)."
73
- ),
74
- "pizza": (
75
- "A standard cheese pizza slice (1/8 of a 14-inch pizza) weighs about 100g per serving. "
76
- "It has approximately 270 calories, 12g protein, 34g carbohydrates, and 10g fat. "
77
- "Main ingredients: dough, tomato sauce, mozzarella cheese. "
78
- "Cooking method: baked in an oven. Nutritional facts: 600mg sodium, 15% daily value calcium, 2g fiber. "
79
- "Add-ins like pepperoni (+50 calories, 5g fat) or extra cheese (+80 calories, 6g fat) increase calories. "
80
- "Substitute cauliflower crust for lower carbs (20g)."
81
- ),
82
- "sushi": (
83
- "A standard sushi roll (6 pieces, e.g., California roll) weighs about 150g per serving. "
84
- "It has approximately 200 calories, 7g protein, 30g carbohydrates, and 5g fat. "
85
- "Main ingredients: sushi rice, nori (seaweed), crab or imitation crab, avocado, cucumber. "
86
- "Cooking method: raw or assembled without cooking. Nutritional facts: 400mg sodium, 10% daily value vitamin A. "
87
- "Add-ins like soy sauce (+200mg sodium) or spicy mayo (+100 calories, 10g fat) increase sodium or fat. "
88
- "Substitute brown rice for higher fiber (2g)."
89
- ),
90
- "salad": (
91
- "A standard garden salad (mixed greens, tomato, cucumber, no dressing) weighs about 200g per serving. "
92
- "It has approximately 50 calories, 2g protein, 10g carbohydrates, and 0.5g fat. "
93
- "Main ingredients: lettuce, spinach, tomato, cucumber, carrots. "
94
- "Cooking method: raw, no cooking. Nutritional facts: 50mg sodium, 50% daily value vitamin C, 4g fiber. "
95
- "Add-ins like ranch dressing (+150 calories, 15g fat) or grilled chicken (+120 calories, 20g protein) increase calories or protein. "
96
- "Substitute vinaigrette for lower fat (5g)."
97
- ),
98
- "pasta": (
99
- "A standard pasta dish (1 cup cooked spaghetti with marinara sauce) weighs about 200g per serving. "
100
- "It has approximately 220 calories, 7g protein, 43g carbohydrates, and 2g fat. "
101
- "Main ingredients: wheat pasta, tomato sauce, olive oil. "
102
- "Cooking method: boiled pasta, sauce simmered. Nutritional facts: 400mg sodium, 10% daily value vitamin A, 3g fiber. "
103
- "Add-ins like meatballs (+200 calories, 15g fat) or parmesan (+50 calories, 4g fat) increase calories. "
104
- "Substitute whole-grain pasta for higher fiber (5g)."
105
- ),
106
- "ice cream": (
107
- "A standard ice cream serving (1/2 cup vanilla) weighs about 100g. "
108
- "It has approximately 200 calories, 4g protein, 20g carbohydrates, and 12g fat. "
109
- "Main ingredients: cream, sugar, milk, vanilla extract. "
110
- "Cooking method: churned and frozen. Nutritional facts: 100mg sodium, 15% daily value calcium, 0g fiber. "
111
- "Add-ins like chocolate syrup (+100 calories, 2g fat) or sprinkles (+50 calories) increase calories. "
112
- "Substitute frozen yogurt for lower fat (7g)."
113
- ),
114
- "fried rice": (
115
- "A standard fried rice serving (1 cup with vegetables and egg) weighs about 200g. "
116
- "It has approximately 250 calories, 8g protein, 35g carbohydrates, and 9g fat. "
117
- "Main ingredients: white rice, egg, peas, carrots, soy sauce, vegetable oil. "
118
- "Cooking method: stir-fried. Nutritional facts: 700mg sodium, 10% daily value vitamin A, 2g fiber. "
119
- "Add-ins like chicken (+100 calories, 15g protein) or shrimp (+80 calories, 12g protein) increase protein. "
120
- "Substitute brown rice for higher fiber (3g)."
121
- ),
122
- "tacos": (
123
- "A standard taco (1 soft corn tortilla with beef, lettuce, cheese) weighs about 100g per serving. "
124
- "It has approximately 200 calories, 10g protein, 15g carbohydrates, and 10g fat. "
125
- "Main ingredients: ground beef, corn tortilla, lettuce, cheddar cheese, salsa. "
126
- "Cooking method: beef pan-fried, tortilla warmed. Nutritional facts: 400mg sodium, 10% daily value calcium, 2g fiber. "
127
- "Add-ins like sour cream (+50 calories, 5g fat) or guacamole (+80 calories, 7g fat) increase fat. "
128
- "Substitute fish for lower fat (6g)."
129
- ),
130
- "steak": (
131
- "A standard steak (4 oz grilled sirloin) weighs about 113g per serving. "
132
- "It has approximately 250 calories, 25g protein, 0g carbohydrates, and 15g fat. "
133
- "Main ingredients: beef sirloin, salt, pepper. "
134
- "Cooking method: grilled or pan-seared. Nutritional facts: 300mg sodium, 20% daily value iron, 0g fiber. "
135
- "Add-ins like butter sauce (+100 calories, 10g fat) or mashed potatoes (+150 calories, 5g fat) increase calories. "
136
- "Substitute leaner cut (e.g., filet mignon) for lower fat (10g)."
137
- ),
138
- "chocolate cake": (
139
- "A standard chocolate cake slice (1/12 of a 9-inch cake) weighs about 100g per serving. "
140
- "It has approximately 350 calories, 5g protein, 50g carbohydrates, and 15g fat. "
141
- "Main ingredients: flour, sugar, cocoa powder, butter, eggs. "
142
- "Cooking method: baked. Nutritional facts: 200mg sodium, 5% daily value iron, 2g fiber. "
143
- "Add-ins like frosting (+100 calories, 5g fat) or whipped cream (+50 calories, 5g fat) increase calories. "
144
- "Substitute gluten-free flour for dietary needs (same calories)."
145
- )
146
- }
147
-
148
- # Upload image
149
- uploaded_file = st.file_uploader("Upload a food image...", type=["jpg", "jpeg", "png"])
150
-
151
- if uploaded_file is not None:
152
- try:
153
- image = Image.open(uploaded_file).convert("RGB")
154
- st.image(image, caption="Uploaded Image", use_column_width=True)
155
-
156
- # Predict with ConvNeXt
157
- input_tensor = manual_transform(image).unsqueeze(0).to(device)
158
- with torch.no_grad():
159
- outputs = model_convnext(pixel_values=input_tensor)
160
- pred_idx = outputs.logits.argmax(-1).item()
161
- pred_label = model_convnext.config.id2label[pred_idx]
162
- st.success(f"🍴 Predicted Food: **{pred_label}**")
163
-
164
- # Generate nutrition caption using TinyLlama
165
- fallback_text = fallback_nutrition.get(pred_label.lower(), "Nutritional facts unavailable for this item.")
166
- # Simplified fallback data for the prompt
167
- simplified_fallback = (
168
- f"Food: {pred_label}\n"
169
- f"Serving size: {fallback_text.split('weighs about ')[1].split(' per serving')[0]}\n"
170
- f"Calories: {fallback_text.split('approximately ')[1].split(' calories')[0]} calories\n"
171
- f"Protein: {fallback_text.split('protein, ')[1].split('g ')[0]}g\n"
172
- f"Carbs: {fallback_text.split('carbohydrates, ')[1].split('g ')[0]}g\n"
173
- f"Fat: {fallback_text.split('and ')[1].split('g fat')[0]}g\n"
174
- f"Ingredients: {fallback_text.split('Main ingredients: ')[1].split('. ')[0]}\n"
175
- f"Cooking method: {fallback_text.split('Cooking method: ')[1].split('. ')[0]}\n"
176
- f"Extras: {fallback_text.split('Add-ins like ')[1].split('. Substitute')[0]}\n"
177
- f"Substitution: {fallback_text.split('Substitute ')[1].split('.')[0]}"
178
- )
179
- prompt = (
180
- f"Below is nutritional information for a {pred_label}:\n"
181
- f"{simplified_fallback}\n\n"
182
- f"Using this data, write a concise, natural description of the {pred_label}'s nutrition in your own words. "
183
- f"Do not copy the text above; create a new description with different wording. "
184
- f"Include calories, macronutrients (protein, carbs, fat), serving size, ingredients, cooking method, add-ins, and substitution. "
185
- f"Example for a taco: 'A classic taco, around 100g, offers 200 calories, 10g protein, 15g carbs, and 10g fat. "
186
- f"It’s prepared with ground beef, corn tortilla, lettuce, cheese, and salsa, with the beef fried and tortilla warmed. "
187
- f"Add sour cream for 50 extra calories or switch to fish for 6g fat.'"
188
- )
189
- st.subheader("🧾 Nutrition Information")
190
- st.write(f"🤖 Prompt: Describe {pred_label} nutrition in your own words based on provided data.")
191
-
192
- input_ids = tokenizer(prompt, return_tensors="pt").to(model_llm.device)
193
- with torch.no_grad():
194
- output = model_llm.generate(
195
- **input_ids,
196
- max_new_tokens=300,
197
- temperature=0.8,
198
- top_p=0.9,
199
- do_sample=True
200
- )
201
- # Decode only the generated tokens
202
- input_length = input_ids.input_ids.shape[1]
203
- caption = tokenizer.decode(output[0][input_length:], skip_special_tokens=True).strip()
204
-
205
- # Debug output
206
- st.write(f"Debug: TinyLlama output: `{caption}`")
207
-
208
- # Check if output is too similar to fallback (simple similarity check)
209
- fallback_words = set(fallback_text.lower().split())
210
- caption_words = set(caption.lower().split())
211
- similarity = len(fallback_words.intersection(caption_words)) / len(fallback_words)
212
-
213
- # Use fallback if output is empty, too short, prompt-like, or too similar
214
- if not caption or len(caption) < 60 or any(kw in caption.lower() for kw in ["rephrase", "below is", "example for"]) or similarity > 0.8:
215
- caption = fallback_text
216
- st.info(caption)
217
-
218
- except Exception as e:
219
- st.error(f"Something went wrong: {e}")
220
-
221
- # Footer
222
- st.markdown("---")
223
- st.markdown("Built with ❤️ using Streamlit and Hugging Face by **shingguy1**")
 
32
  # Sidebar info
33
  st.sidebar.header("Models Used")
34
  st.sidebar.markdown("""
35
+ - 🖼️ **Image Classifier**: `shingguy1/fine_tuned_convnext`
36
  - 💬 **Text Generator**: `TinyLlama/TinyLlama-1.1B-Chat-v1.0`
37
  """)
38
 
 
41
  def load_models():
42
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
 
44
+ # ConvNeXt for classification (updated repo)
45
  model_convnext = ConvNextForImageClassification.from_pretrained(
46
+ "shingguy1/fine_tuned_convnext",
47
  cache_dir=cache_dir,
48
  token=hf_token
49
  ).to(device)
50
 
51
  # TinyLlama for nutritional facts
52
+ tokenizer = AutoTokenizer.from_pretrained(
53
+ "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
54
+ cache_dir=cache_dir
55
+ )
56
  model_llm = AutoModelForCausalLM.from_pretrained(
57
  "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
58
  cache_dir=cache_dir,
 
64
 
65
  model_convnext, tokenizer, model_llm, device = load_models()
66
 
67
+ # ... rest of your code remains unchanged ...