shingguy1 commited on
Commit
3034552
·
verified ·
1 Parent(s): 4840f3d

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +21 -41
src/streamlit_app.py CHANGED
@@ -48,6 +48,7 @@ st.markdown(
48
  # Load WHOOP logo
49
  WHOOP_LOGO = "https://www.whoop.com/wp-content/themes/whoop/library/images/whoop-logo-dark.svg"
50
 
 
51
  def main():
52
  # Display WHOOP logo at top
53
  st.image(WHOOP_LOGO, width=200)
@@ -65,11 +66,13 @@ def main():
65
  """
66
  )
67
 
 
68
  hf_token = os.getenv("HF_TOKEN", None)
69
  cache_dir = "/tmp/cache"
70
  os.makedirs(cache_dir, exist_ok=True)
71
  os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
72
 
 
73
  nutritional_info = {
74
  "pizza": {"serving": "100 g (1 slice)", "calories": "270 kcal", "protein": "12 g", "carbs": "34 g", "fat": "10 g", "ingredients": "dough, tomato sauce, mozzarella cheese", "method": "baked", "substitute": "cauliflower crust"},
75
  "hamburger": {"serving": "150 g", "calories": "300 kcal", "protein": "20 g", "carbs": "30 g", "fat": "12 g", "ingredients": "ground beef patty, bun, lettuce, tomato", "method": "grilled or pan-fried", "substitute": "chicken patty"},
@@ -82,19 +85,19 @@ def main():
82
  "steak": {"serving": "113 g (4 oz)", "calories": "250 kcal", "protein": "25 g", "carbs": "0 g", "fat": "15 g", "ingredients": "beef sirloin, salt, pepper", "method": "grilled or pan-seared", "substitute": "leaner cut (filet mignon)"},
83
  "chocolate_cake": {"serving": "100 g (1 slice)", "calories": "350 kcal", "protein": "5 g", "carbs": "50 g", "fat": "15 g", "ingredients": "flour, sugar, cocoa, butter, eggs", "method": "baked", "substitute": "gluten-free flour"}
84
  }
 
85
 
86
- label_mapping = {
87
- "caesar_salad": "salad",
88
- "spaghetti_bolognese": "pasta"
89
- }
90
-
91
  st.sidebar.image(WHOOP_LOGO, width=150)
92
  st.sidebar.header("WHOOP Model Suite")
93
  st.sidebar.markdown(
94
- "- 🖼️ **Image Classifier**: `shingguy1/fine_tuned_vit`
95
- - 💬 **Nutrition Paraphraser**: `google/flan-t5-small`"
 
 
96
  )
97
 
 
98
  transform = transforms.Compose([
99
  transforms.Resize(256),
100
  transforms.CenterCrop(224),
@@ -107,38 +110,30 @@ def main():
107
  def load_models():
108
  device = torch.device("cpu")
109
  vit = ViTForImageClassification.from_pretrained(
110
- "shingguy1/fine_tuned_vit",
111
- cache_dir=cache_dir,
112
- use_auth_token=hf_token
113
  ).to(device)
114
  tok = AutoTokenizer.from_pretrained(
115
- "google/flan-t5-small",
116
- cache_dir=cache_dir,
117
- use_auth_token=hf_token
118
  )
119
  t5 = T5ForConditionalGeneration.from_pretrained(
120
- "google/flan-t5-small",
121
- cache_dir=cache_dir,
122
- use_auth_token=hf_token
123
  ).to(device)
124
  return vit, tok, t5, device
125
 
126
  model_vit, tokenizer_t5, model_t5, device = load_models()
127
 
128
- uploaded = st.file_uploader("📷 Upload a food image...", type=["jpg", "png", "jpeg"])
 
129
  if uploaded:
130
  img = Image.open(uploaded)
131
  st.image(img, caption="Your Food", use_column_width=True)
132
-
133
  inp = transform(img).unsqueeze(0).to(device)
134
- with torch.no_grad():
135
- out = model_vit(pixel_values=inp)
136
  label = model_vit.config.id2label[out.logits.argmax(-1).item()]
137
  st.success(f"🍽️ Detected: **{label}**")
138
 
139
  true_label = label_mapping.get(label.lower(), label.lower())
140
  data = nutritional_info.get(true_label)
141
-
142
  if data:
143
  base_description = (
144
  f"A typical {true_label} serving ({data['serving']}) contains about {data['calories']}, "
@@ -149,32 +144,17 @@ def main():
149
  prompt = (
150
  f"Paraphrase the following nutritional facts in a friendly, conversational tone. "
151
  f"Use varied sentence structures and synonyms, and feel free to generalize numeric details "
152
- f"(e.g., ‘around 250 kcal’). Don’t add any new facts.
153
-
154
- " + base_description
155
  )
156
  else:
157
- prompt = (
158
- f"Provide an approximate nutrition summary for {label}, including calories, "
159
- f"macronutrients, and a brief description."
160
- )
161
 
162
  inputs = tokenizer_t5(prompt, return_tensors="pt", truncation=True).to(device)
163
- output_ids = model_t5.generate(
164
- inputs["input_ids"],
165
- max_new_tokens=100,
166
- do_sample=True,
167
- top_p=0.9,
168
- temperature=0.7,
169
- early_stopping=True
170
- )
171
  response = tokenizer_t5.decode(output_ids[0], skip_special_tokens=True)
172
-
173
- if "calories" not in response.lower() or len(response.split()) < 10:
174
- response = base_description
175
 
176
  st.subheader("🧾 Nutrition Overview")
177
  st.info(response)
178
 
179
- if __name__ == "__main__":
180
- main()
 
48
  # Load WHOOP logo
49
  WHOOP_LOGO = "https://www.whoop.com/wp-content/themes/whoop/library/images/whoop-logo-dark.svg"
50
 
51
+ # Main application
52
  def main():
53
  # Display WHOOP logo at top
54
  st.image(WHOOP_LOGO, width=200)
 
66
  """
67
  )
68
 
69
+ # Environment setup
70
  hf_token = os.getenv("HF_TOKEN", None)
71
  cache_dir = "/tmp/cache"
72
  os.makedirs(cache_dir, exist_ok=True)
73
  os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
74
 
75
+ # Nutrition data
76
  nutritional_info = {
77
  "pizza": {"serving": "100 g (1 slice)", "calories": "270 kcal", "protein": "12 g", "carbs": "34 g", "fat": "10 g", "ingredients": "dough, tomato sauce, mozzarella cheese", "method": "baked", "substitute": "cauliflower crust"},
78
  "hamburger": {"serving": "150 g", "calories": "300 kcal", "protein": "20 g", "carbs": "30 g", "fat": "12 g", "ingredients": "ground beef patty, bun, lettuce, tomato", "method": "grilled or pan-fried", "substitute": "chicken patty"},
 
85
  "steak": {"serving": "113 g (4 oz)", "calories": "250 kcal", "protein": "25 g", "carbs": "0 g", "fat": "15 g", "ingredients": "beef sirloin, salt, pepper", "method": "grilled or pan-seared", "substitute": "leaner cut (filet mignon)"},
86
  "chocolate_cake": {"serving": "100 g (1 slice)", "calories": "350 kcal", "protein": "5 g", "carbs": "50 g", "fat": "15 g", "ingredients": "flour, sugar, cocoa, butter, eggs", "method": "baked", "substitute": "gluten-free flour"}
87
  }
88
+ label_mapping = {"caesar_salad": "salad", "spaghetti_bolognese": "pasta"}
89
 
90
+ # Sidebar
 
 
 
 
91
  st.sidebar.image(WHOOP_LOGO, width=150)
92
  st.sidebar.header("WHOOP Model Suite")
93
  st.sidebar.markdown(
94
+ """
95
+ - 🖼️ **Image Classifier**: `shingguy1/fine_tuned_vit`
96
+ - 💬 **Nutrition Paraphraser**: `google/flan-t5-small`
97
+ """
98
  )
99
 
100
+ # Image transforms
101
  transform = transforms.Compose([
102
  transforms.Resize(256),
103
  transforms.CenterCrop(224),
 
110
  def load_models():
111
  device = torch.device("cpu")
112
  vit = ViTForImageClassification.from_pretrained(
113
+ "shingguy1/fine_tuned_vit", cache_dir=cache_dir, use_auth_token=hf_token
 
 
114
  ).to(device)
115
  tok = AutoTokenizer.from_pretrained(
116
+ "google/flan-t5-small", cache_dir=cache_dir, use_auth_token=hf_token
 
 
117
  )
118
  t5 = T5ForConditionalGeneration.from_pretrained(
119
+ "google/flan-t5-small", cache_dir=cache_dir, use_auth_token=hf_token
 
 
120
  ).to(device)
121
  return vit, tok, t5, device
122
 
123
  model_vit, tokenizer_t5, model_t5, device = load_models()
124
 
125
+ # File uploader and inference loop
126
+ uploaded = st.file_uploader("📷 Upload a food image...", type=["jpg","png","jpeg"])
127
  if uploaded:
128
  img = Image.open(uploaded)
129
  st.image(img, caption="Your Food", use_column_width=True)
 
130
  inp = transform(img).unsqueeze(0).to(device)
131
+ with torch.no_grad(): out = model_vit(pixel_values=inp)
 
132
  label = model_vit.config.id2label[out.logits.argmax(-1).item()]
133
  st.success(f"🍽️ Detected: **{label}**")
134
 
135
  true_label = label_mapping.get(label.lower(), label.lower())
136
  data = nutritional_info.get(true_label)
 
137
  if data:
138
  base_description = (
139
  f"A typical {true_label} serving ({data['serving']}) contains about {data['calories']}, "
 
144
  prompt = (
145
  f"Paraphrase the following nutritional facts in a friendly, conversational tone. "
146
  f"Use varied sentence structures and synonyms, and feel free to generalize numeric details "
147
+ f"(e.g., ‘around 250 kcal’). Don’t add any new facts.\n\n" + base_description
 
 
148
  )
149
  else:
150
+ prompt = f"Provide an approximate nutrition summary for {label}, including calories, macronutrients, and a brief description."
 
 
 
151
 
152
  inputs = tokenizer_t5(prompt, return_tensors="pt", truncation=True).to(device)
153
+ output_ids = model_t5.generate(inputs["input_ids"], max_new_tokens=100, do_sample=True, top_p=0.9, temperature=0.7, early_stopping=True)
 
 
 
 
 
 
 
154
  response = tokenizer_t5.decode(output_ids[0], skip_special_tokens=True)
155
+ if "calories" not in response.lower() or len(response.split()) < 10: response = base_description
 
 
156
 
157
  st.subheader("🧾 Nutrition Overview")
158
  st.info(response)
159
 
160
+ if __name__ == "__main__": main()