shingguy1 commited on
Commit
6ede4a3
·
verified ·
1 Parent(s): 3034552

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +59 -77
src/streamlit_app.py CHANGED
@@ -10,69 +10,28 @@ from transformers import (
10
  T5ForConditionalGeneration
11
  )
12
 
13
- # Set page config with WHOOP branding
14
  st.set_page_config(
15
- page_title="WHOOP Nutrition Estimator",
16
- page_icon="https://www.whoop.com/wp-content/themes/whoop/library/images/whoop-logo-dark.svg",
17
  layout="centered"
18
  )
19
 
20
- # Inject global black-and-white styling
21
- st.markdown(
22
- """
23
- <style>
24
- /* Background and text monochrome */
25
- html, body, [class*="css"] {
26
- background-color: #ffffff !important;
27
- color: #000000 !important;
28
- }
29
- /* Sidebar and main container */
30
- .stSidebar, .stApp {
31
- background-color: #ffffff !important;
32
- color: #000000 !important;
33
- }
34
- /* Buttons styling */
35
- button, .stButton>button {
36
- background-color: #000000 !important;
37
- color: #ffffff !important;
38
- border: 1px solid #000000 !important;
39
- }
40
- /* Sidebar header accent */
41
- .stSidebar .css-1d391kg {
42
- color: #000000 !important;
43
- }
44
- </style>
45
- """, unsafe_allow_html=True
46
- )
47
 
48
- # Load WHOOP logo
49
- WHOOP_LOGO = "https://www.whoop.com/wp-content/themes/whoop/library/images/whoop-logo-dark.svg"
 
 
50
 
51
- # Main application
52
- def main():
53
- # Display WHOOP logo at top
54
- st.image(WHOOP_LOGO, width=200)
55
- st.title("WHOOP 🍽️ Food Nutrition Estimator")
56
- st.markdown(
57
- """
58
- **Powered by WHOOP Nutrition Science**
59
-
60
- Upload a food image to classify it and receive a paraphrased nutritional overview
61
- tailored to your WHOOP goals and recovery insights.
62
-
63
- ⚠️ This demo covers **10 food categories**:
64
- `pizza`, `hamburger`, `sushi`, `caesar_salad`, `spaghetti_bolognese`,
65
- `ice_cream`, `fried_rice`, `tacos`, `steak`, `chocolate_cake`.
66
- """
67
- )
68
-
69
- # Environment setup
70
  hf_token = os.getenv("HF_TOKEN", None)
71
  cache_dir = "/tmp/cache"
72
  os.makedirs(cache_dir, exist_ok=True)
73
  os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
74
 
75
- # Nutrition data
76
  nutritional_info = {
77
  "pizza": {"serving": "100 g (1 slice)", "calories": "270 kcal", "protein": "12 g", "carbs": "34 g", "fat": "10 g", "ingredients": "dough, tomato sauce, mozzarella cheese", "method": "baked", "substitute": "cauliflower crust"},
78
  "hamburger": {"serving": "150 g", "calories": "300 kcal", "protein": "20 g", "carbs": "30 g", "fat": "12 g", "ingredients": "ground beef patty, bun, lettuce, tomato", "method": "grilled or pan-fried", "substitute": "chicken patty"},
@@ -85,55 +44,63 @@ def main():
85
  "steak": {"serving": "113 g (4 oz)", "calories": "250 kcal", "protein": "25 g", "carbs": "0 g", "fat": "15 g", "ingredients": "beef sirloin, salt, pepper", "method": "grilled or pan-seared", "substitute": "leaner cut (filet mignon)"},
86
  "chocolate_cake": {"serving": "100 g (1 slice)", "calories": "350 kcal", "protein": "5 g", "carbs": "50 g", "fat": "15 g", "ingredients": "flour, sugar, cocoa, butter, eggs", "method": "baked", "substitute": "gluten-free flour"}
87
  }
88
- label_mapping = {"caesar_salad": "salad", "spaghetti_bolognese": "pasta"}
89
-
90
- # Sidebar
91
- st.sidebar.image(WHOOP_LOGO, width=150)
92
- st.sidebar.header("WHOOP Model Suite")
93
- st.sidebar.markdown(
94
- """
95
- - 🖼️ **Image Classifier**: `shingguy1/fine_tuned_vit`
96
- - 💬 **Nutrition Paraphraser**: `google/flan-t5-small`
97
- """
98
- )
99
-
100
- # Image transforms
101
  transform = transforms.Compose([
102
  transforms.Resize(256),
103
  transforms.CenterCrop(224),
104
  transforms.Lambda(lambda img: img.convert("RGB")),
105
  transforms.ToTensor(),
106
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 
107
  ])
108
 
109
  @st.cache_resource
110
  def load_models():
111
  device = torch.device("cpu")
112
  vit = ViTForImageClassification.from_pretrained(
113
- "shingguy1/fine_tuned_vit", cache_dir=cache_dir, use_auth_token=hf_token
 
 
114
  ).to(device)
115
  tok = AutoTokenizer.from_pretrained(
116
- "google/flan-t5-small", cache_dir=cache_dir, use_auth_token=hf_token
 
 
117
  )
118
  t5 = T5ForConditionalGeneration.from_pretrained(
119
- "google/flan-t5-small", cache_dir=cache_dir, use_auth_token=hf_token
 
 
120
  ).to(device)
121
  return vit, tok, t5, device
122
 
123
  model_vit, tokenizer_t5, model_t5, device = load_models()
124
 
125
- # File uploader and inference loop
126
- uploaded = st.file_uploader("📷 Upload a food image...", type=["jpg","png","jpeg"])
127
  if uploaded:
128
  img = Image.open(uploaded)
129
  st.image(img, caption="Your Food", use_column_width=True)
 
130
  inp = transform(img).unsqueeze(0).to(device)
131
- with torch.no_grad(): out = model_vit(pixel_values=inp)
 
132
  label = model_vit.config.id2label[out.logits.argmax(-1).item()]
133
  st.success(f"🍽️ Detected: **{label}**")
134
 
135
  true_label = label_mapping.get(label.lower(), label.lower())
136
  data = nutritional_info.get(true_label)
 
137
  if data:
138
  base_description = (
139
  f"A typical {true_label} serving ({data['serving']}) contains about {data['calories']}, "
@@ -144,17 +111,32 @@ def main():
144
  prompt = (
145
  f"Paraphrase the following nutritional facts in a friendly, conversational tone. "
146
  f"Use varied sentence structures and synonyms, and feel free to generalize numeric details "
147
- f"(e.g., ‘around 250 kcal’). Don’t add any new facts.\n\n" + base_description
 
148
  )
149
  else:
150
- prompt = f"Provide an approximate nutrition summary for {label}, including calories, macronutrients, and a brief description."
 
 
 
151
 
152
  inputs = tokenizer_t5(prompt, return_tensors="pt", truncation=True).to(device)
153
- output_ids = model_t5.generate(inputs["input_ids"], max_new_tokens=100, do_sample=True, top_p=0.9, temperature=0.7, early_stopping=True)
 
 
 
 
 
 
 
154
  response = tokenizer_t5.decode(output_ids[0], skip_special_tokens=True)
155
- if "calories" not in response.lower() or len(response.split()) < 10: response = base_description
 
 
 
156
 
157
  st.subheader("🧾 Nutrition Overview")
158
  st.info(response)
159
 
160
- if __name__ == "__main__": main()
 
 
10
  T5ForConditionalGeneration
11
  )
12
 
13
+ # Set page config
14
  st.set_page_config(
15
+ page_title="🍽️ Food Nutrition Estimator",
16
+ page_icon="🥗",
17
  layout="centered"
18
  )
19
 
20
+ def main():
21
+ st.title("🍽️ Food Nutrition Estimator")
22
+ st.markdown("""
23
+ Upload a food image to classify it and receive a paraphrased nutritional description.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ ⚠️ This demo is trained on **10 food categories** only:
26
+ pizza, hamburger, sushi, caesar_salad, spaghetti_bolognese,
27
+ ice_cream, fried_rice, tacos, steak, chocolate_cake.
28
+ """)
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  hf_token = os.getenv("HF_TOKEN", None)
31
  cache_dir = "/tmp/cache"
32
  os.makedirs(cache_dir, exist_ok=True)
33
  os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
34
 
 
35
  nutritional_info = {
36
  "pizza": {"serving": "100 g (1 slice)", "calories": "270 kcal", "protein": "12 g", "carbs": "34 g", "fat": "10 g", "ingredients": "dough, tomato sauce, mozzarella cheese", "method": "baked", "substitute": "cauliflower crust"},
37
  "hamburger": {"serving": "150 g", "calories": "300 kcal", "protein": "20 g", "carbs": "30 g", "fat": "12 g", "ingredients": "ground beef patty, bun, lettuce, tomato", "method": "grilled or pan-fried", "substitute": "chicken patty"},
 
44
  "steak": {"serving": "113 g (4 oz)", "calories": "250 kcal", "protein": "25 g", "carbs": "0 g", "fat": "15 g", "ingredients": "beef sirloin, salt, pepper", "method": "grilled or pan-seared", "substitute": "leaner cut (filet mignon)"},
45
  "chocolate_cake": {"serving": "100 g (1 slice)", "calories": "350 kcal", "protein": "5 g", "carbs": "50 g", "fat": "15 g", "ingredients": "flour, sugar, cocoa, butter, eggs", "method": "baked", "substitute": "gluten-free flour"}
46
  }
47
+
48
+ label_mapping = {
49
+ "caesar_salad": "salad",
50
+ "spaghetti_bolognese": "pasta"
51
+ }
52
+
53
+ st.sidebar.header("Models Used")
54
+ st.sidebar.markdown("""
55
+ - 🖼️ **Image Classifier**: shingguy1/fine_tuned_vit
56
+ - 💬 **Paraphraser**: google/flan-t5-small (sampling mode)
57
+ """)
58
+
 
59
  transform = transforms.Compose([
60
  transforms.Resize(256),
61
  transforms.CenterCrop(224),
62
  transforms.Lambda(lambda img: img.convert("RGB")),
63
  transforms.ToTensor(),
64
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
65
+ std=[0.229, 0.224, 0.225])
66
  ])
67
 
68
  @st.cache_resource
69
  def load_models():
70
  device = torch.device("cpu")
71
  vit = ViTForImageClassification.from_pretrained(
72
+ "shingguy1/fine_tuned_vit",
73
+ cache_dir=cache_dir,
74
+ use_auth_token=hf_token
75
  ).to(device)
76
  tok = AutoTokenizer.from_pretrained(
77
+ "google/flan-t5-small",
78
+ cache_dir=cache_dir,
79
+ use_auth_token=hf_token
80
  )
81
  t5 = T5ForConditionalGeneration.from_pretrained(
82
+ "google/flan-t5-small",
83
+ cache_dir=cache_dir,
84
+ use_auth_token=hf_token
85
  ).to(device)
86
  return vit, tok, t5, device
87
 
88
  model_vit, tokenizer_t5, model_t5, device = load_models()
89
 
90
+ uploaded = st.file_uploader("📷 Upload a food image...", type=["jpg", "png", "jpeg"])
 
91
  if uploaded:
92
  img = Image.open(uploaded)
93
  st.image(img, caption="Your Food", use_column_width=True)
94
+
95
  inp = transform(img).unsqueeze(0).to(device)
96
+ with torch.no_grad():
97
+ out = model_vit(pixel_values=inp)
98
  label = model_vit.config.id2label[out.logits.argmax(-1).item()]
99
  st.success(f"🍽️ Detected: **{label}**")
100
 
101
  true_label = label_mapping.get(label.lower(), label.lower())
102
  data = nutritional_info.get(true_label)
103
+
104
  if data:
105
  base_description = (
106
  f"A typical {true_label} serving ({data['serving']}) contains about {data['calories']}, "
 
111
  prompt = (
112
  f"Paraphrase the following nutritional facts in a friendly, conversational tone. "
113
  f"Use varied sentence structures and synonyms, and feel free to generalize numeric details "
114
+ f"(e.g., ‘around 250 kcal’). Don’t add any new facts.\n\n"
115
+ f"{base_description}"
116
  )
117
  else:
118
+ prompt = (
119
+ f"Provide an approximate nutrition summary for {label}, including calories, "
120
+ f"macronutrients, and a brief description."
121
+ )
122
 
123
  inputs = tokenizer_t5(prompt, return_tensors="pt", truncation=True).to(device)
124
+ output_ids = model_t5.generate(
125
+ inputs["input_ids"],
126
+ max_new_tokens=100,
127
+ do_sample=True,
128
+ top_p=0.9,
129
+ temperature=0.7,
130
+ early_stopping=True
131
+ )
132
  response = tokenizer_t5.decode(output_ids[0], skip_special_tokens=True)
133
+
134
+ # Fallback if the output seems too short or misses key phrases
135
+ if "calories" not in response.lower() or len(response.split()) < 10:
136
+ response = base_description
137
 
138
  st.subheader("🧾 Nutrition Overview")
139
  st.info(response)
140
 
141
+ if __name__ == "__main__":
142
+ main()