shingguy1 committed on
Commit
df9e1b3
·
verified ·
1 Parent(s): dbca709

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +169 -76
src/streamlit_app.py CHANGED
@@ -5,8 +5,9 @@ st.set_page_config(
5
  layout="centered"
6
  )
7
 
8
- import torch
9
  import os
 
 
10
  from PIL import Image
11
  import torchvision.transforms as transforms
12
  from transformers import (
@@ -22,108 +23,200 @@ def main():
22
  os.makedirs(cache_dir, exist_ok=True)
23
  os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
24
 
25
- # 2. Image transform for ViT
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  manual_transform = transforms.Compose([
27
  transforms.Resize(256),
28
  transforms.CenterCrop(224),
29
  transforms.Lambda(lambda img: img.convert("RGB")),
30
  transforms.ToTensor(),
31
  transforms.Normalize(mean=[0.485, 0.456, 0.406],
32
- std=[0.229, 0.224, 0.225]),
33
- transforms.ConvertImageDtype(torch.float32)
34
  ])
35
 
36
- # 3. Sidebar info
37
  st.sidebar.header("Models Used")
38
  st.sidebar.markdown("""
39
- - 🖼️ **Image Classifier**: `shingguy1/fine_tuned_vit`
40
- - 💬 **Text Generator**: `google/flan-t5-small`
41
  """)
42
 
43
- # 4. Load models (cached)
44
  @st.cache_resource
45
  def load_models():
46
- device = torch.device("cpu") # CPU-only environment
47
-
48
- # ViT classifier
49
- model_vit = ViTForImageClassification.from_pretrained(
50
  "shingguy1/fine_tuned_vit",
51
  cache_dir=cache_dir,
52
  use_auth_token=hf_token
53
  ).to(device)
54
-
55
- # FLAN-T5 Small for generation
56
- tokenizer_llm = AutoTokenizer.from_pretrained(
57
  "google/flan-t5-small",
58
  cache_dir=cache_dir,
59
  use_auth_token=hf_token
60
  )
61
- model_llm = T5ForConditionalGeneration.from_pretrained(
62
  "google/flan-t5-small",
63
  cache_dir=cache_dir,
64
  use_auth_token=hf_token
65
  ).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
- return model_vit, tokenizer_llm, model_llm, device
68
-
69
- model_vit, tokenizer_llm, model_llm, device = load_models()
70
-
71
- # 5. Image uploader
72
- uploaded_file = st.file_uploader("Upload a food image...", type=["jpg", "jpeg", "png"])
73
- if uploaded_file is not None:
74
- try:
75
- # Display image
76
- image = Image.open(uploaded_file)
77
- st.image(image, caption="Uploaded Image", use_column_width=True)
78
-
79
- # Classify with ViT
80
- inputs_vit = manual_transform(image).unsqueeze(0).to(device)
81
- with torch.no_grad():
82
- vit_outputs = model_vit(pixel_values=inputs_vit)
83
- pred_idx = vit_outputs.logits.argmax(-1).item()
84
- pred_label = model_vit.config.id2label[pred_idx]
85
- st.success(f"🍴 Predicted Food: **{pred_label}**")
86
-
87
- # Build FLAN-T5 prompt
88
- prompt = (
89
- "Provide a concise nutritional overview for a taco, including:\n"
90
- "- Serving size (with measurements & ingestion guidelines)\n"
91
- "- Calories\n"
92
- "- Protein, carbohydrates, and fat\n"
93
- "- Main ingredients\n"
94
- "- Cooking method\n"
95
- "- One healthy substitution\n"
96
- "Answer only the overview."
97
- )
98
- st.subheader("🧾 Nutrition Information")
99
- st.write(f"🤖 Prompt:\n\n{prompt}")
100
-
101
- # Tokenize & generate
102
- inputs = tokenizer_llm(
103
- prompt,
104
- return_tensors="pt",
105
- padding="longest",
106
- truncation=True,
107
- ).to(device)
108
-
109
- outputs = model_llm.generate(
110
- input_ids=inputs.input_ids,
111
- attention_mask=inputs.attention_mask,
112
- max_new_tokens=150,
113
- temperature=0.7,
114
- top_p=0.9,
115
- do_sample=True,
116
- no_repeat_ngram_size=2,
117
- early_stopping=True,
118
- pad_token_id=tokenizer_llm.pad_token_id,
119
- eos_token_id=tokenizer_llm.eos_token_id
120
- )
121
-
122
- summary = tokenizer_llm.decode(outputs[0], skip_special_tokens=True).strip()
123
- st.info(summary or "⚠️ The model did not generate any text.")
124
-
125
- except Exception as e:
126
- st.error(f"Something went wrong: {e}")
127
 
128
  if __name__ == "__main__":
129
  main()
 
5
  layout="centered"
6
  )
7
 
 
8
  import os
9
+ import torch
10
+ import random
11
  from PIL import Image
12
  import torchvision.transforms as transforms
13
  from transformers import (
 
23
  os.makedirs(cache_dir, exist_ok=True)
24
  os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
25
 
26
+ # 2. Nutritional lookup table
27
+ nutritional_info = {
28
+ "pizza": {
29
+ "serving": "100 g (1 slice)",
30
+ "calories": "270 kcal",
31
+ "protein": "12 g",
32
+ "carbs": "34 g",
33
+ "fat": "10 g",
34
+ "ingredients": "dough, tomato sauce, mozzarella cheese",
35
+ "method": "baked",
36
+ "substitute": "cauliflower crust"
37
+ },
38
+ "hamburger": {
39
+ "serving": "150 g",
40
+ "calories": "300 kcal",
41
+ "protein": "20 g",
42
+ "carbs": "30 g",
43
+ "fat": "12 g",
44
+ "ingredients": "ground beef patty (80/20), bun, lettuce, tomato",
45
+ "method": "grilled or pan-fried",
46
+ "substitute": "chicken patty"
47
+ },
48
+ "sushi": {
49
+ "serving": "150 g (6 pieces)",
50
+ "calories": "200 kcal",
51
+ "protein": "7 g",
52
+ "carbs": "30 g",
53
+ "fat": "5 g",
54
+ "ingredients": "sushi rice, nori, crab (or imitation), avocado, cucumber",
55
+ "method": "assembled raw",
56
+ "substitute": "brown rice"
57
+ },
58
+ "salad": {
59
+ "serving": "200 g",
60
+ "calories": "50 kcal",
61
+ "protein": "2 g",
62
+ "carbs": "10 g",
63
+ "fat": "0.5 g",
64
+ "ingredients": "mixed greens, tomato, cucumber, carrots",
65
+ "method": "raw",
66
+ "substitute": "vinaigrette instead of ranch"
67
+ },
68
+ "pasta": {
69
+ "serving": "200 g (1 cup)",
70
+ "calories": "220 kcal",
71
+ "protein": "7 g",
72
+ "carbs": "43 g",
73
+ "fat": "2 g",
74
+ "ingredients": "wheat pasta, marinara sauce, olive oil",
75
+ "method": "boiled and simmered",
76
+ "substitute": "whole-grain pasta"
77
+ },
78
+ "ice_cream": {
79
+ "serving": "100 g (½ cup)",
80
+ "calories": "200 kcal",
81
+ "protein": "4 g",
82
+ "carbs": "20 g",
83
+ "fat": "12 g",
84
+ "ingredients": "cream, sugar, milk, vanilla",
85
+ "method": "churned and frozen",
86
+ "substitute": "frozen yogurt"
87
+ },
88
+ "fried_rice": {
89
+ "serving": "200 g (1 cup)",
90
+ "calories": "250 kcal",
91
+ "protein": "8 g",
92
+ "carbs": "35 g",
93
+ "fat": "9 g",
94
+ "ingredients": "rice, egg, peas, carrots, soy sauce, oil",
95
+ "method": "stir-fried",
96
+ "substitute": "brown rice"
97
+ },
98
+ "tacos": {
99
+ "serving": "100 g (1 soft taco)",
100
+ "calories": "200 kcal",
101
+ "protein": "10 g",
102
+ "carbs": "15 g",
103
+ "fat": "10 g",
104
+ "ingredients": "ground beef, corn tortilla, lettuce, cheese, salsa",
105
+ "method": "beef pan-fried, tortilla warmed",
106
+ "substitute": "fish filling"
107
+ },
108
+ "steak": {
109
+ "serving": "113 g (4 oz)",
110
+ "calories": "250 kcal",
111
+ "protein": "25 g",
112
+ "carbs": "0 g",
113
+ "fat": "15 g",
114
+ "ingredients": "beef sirloin, salt, pepper",
115
+ "method": "grilled or pan-seared",
116
+ "substitute": "leaner cut (filet mignon)"
117
+ },
118
+ "chocolate_cake": {
119
+ "serving": "100 g (1 slice)",
120
+ "calories": "350 kcal",
121
+ "protein": "5 g",
122
+ "carbs": "50 g",
123
+ "fat": "15 g",
124
+ "ingredients": "flour, sugar, cocoa, butter, eggs",
125
+ "method": "baked",
126
+ "substitute": "gluten-free flour"
127
+ }
128
+ }
129
+
130
+ # 3. Image transform for ViT
131
  manual_transform = transforms.Compose([
132
  transforms.Resize(256),
133
  transforms.CenterCrop(224),
134
  transforms.Lambda(lambda img: img.convert("RGB")),
135
  transforms.ToTensor(),
136
  transforms.Normalize(mean=[0.485, 0.456, 0.406],
137
+ std=[0.229, 0.224, 0.225])
 
138
  ])
139
 
140
+ # 4. Sidebar info
141
  st.sidebar.header("Models Used")
142
  st.sidebar.markdown("""
143
+ - 🖼️ **Image Classifier**: `shingguy1/fine_tuned_vit`
144
+ - 💬 **Paraphraser**: `google/flan-t5-small`
145
  """)
146
 
147
+ # 5. Load models (cached)
148
  @st.cache_resource
149
  def load_models():
150
+ device = torch.device("cpu")
151
+ vit = ViTForImageClassification.from_pretrained(
 
 
152
  "shingguy1/fine_tuned_vit",
153
  cache_dir=cache_dir,
154
  use_auth_token=hf_token
155
  ).to(device)
156
+ tok = AutoTokenizer.from_pretrained(
 
 
157
  "google/flan-t5-small",
158
  cache_dir=cache_dir,
159
  use_auth_token=hf_token
160
  )
161
+ paraphraser = T5ForConditionalGeneration.from_pretrained(
162
  "google/flan-t5-small",
163
  cache_dir=cache_dir,
164
  use_auth_token=hf_token
165
  ).to(device)
166
+ return vit, tok, paraphraser, device
167
+
168
+ model_vit, tokenizer_t5, model_t5, device = load_models()
169
+
170
+ # 6. Uploader
171
+ uploaded = st.file_uploader("Upload a food image...", type=["jpg","png","jpeg"])
172
+ if uploaded:
173
+ img = Image.open(uploaded)
174
+ st.image(img, caption="Your Food", use_column_width=True)
175
+
176
+ # classify
177
+ inp = manual_transform(img).unsqueeze(0).to(device)
178
+ with torch.no_grad():
179
+ out = model_vit(pixel_values=inp)
180
+ label = model_vit.config.id2label[out.logits.argmax(-1).item()]
181
+ st.success(f"🍽️ Detected: **{label}**")
182
+
183
+ # lookup
184
+ data = nutritional_info.get(label.lower())
185
+ if not data:
186
+ st.error("No nutrition data for this item.")
187
+ return
188
+
189
+ # slot-fill template
190
+ templates = [
191
+ "A typical {label} serving ({serving}) contains about {calories}, with {protein} protein, {carbs} carbs, and {fat} fat. "
192
+ "Made from {ingredients} and usually {method}. Try {substitute} as a healthier swap.",
193
+ "For {label}, one {serving} provides {calories}. It offers {protein} protein, {carbs} carbohydrates, and {fat} fat. "
194
+ "Ingredients include {ingredients}, and it's {method}. You can substitute {substitute}."
195
+ ]
196
+ raw = random.choice(templates).format(label=label,
197
+ serving=data["serving"],
198
+ calories=data["calories"],
199
+ protein=data["protein"],
200
+ carbs=data["carbs"],
201
+ fat=data["fat"],
202
+ ingredients=data["ingredients"],
203
+ method=data["method"],
204
+ substitute=data["substitute"])
205
+
206
+ # paraphrase
207
+ prompt = f"Paraphrase this nutritional info without changing facts:\n\n{raw}"
208
+ inputs = tokenizer_t5(prompt, return_tensors="pt", truncation=True).to(device)
209
+ out_ids = model_t5.generate(
210
+ **inputs,
211
+ max_new_tokens=100,
212
+ do_sample=True,
213
+ temperature=0.8,
214
+ top_p=0.9
215
+ )
216
+ paraphrased = tokenizer_t5.decode(out_ids[0], skip_special_tokens=True)
217
 
218
+ st.subheader("🧾 Nutrition Overview")
219
+ st.info(paraphrased or raw)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
 
221
  if __name__ == "__main__":
222
  main()