shingguy1 committed on
Commit
749ea77
·
verified ·
1 Parent(s): 50e8acf

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +124 -124
src/streamlit_app.py CHANGED
@@ -1,4 +1,10 @@
1
  import streamlit as st
 
 
 
 
 
 
2
  import torch
3
  import os
4
  from PIL import Image
@@ -9,130 +15,124 @@ from transformers import (
9
  AutoModelForCausalLM
10
  )
11
 
12
- # 1. Streamlit UI setup
13
- st.set_page_config(
14
- page_title="🍽️ Food Nutrition Estimator",
15
- page_icon="🥗",
16
- layout="centered"
17
- )
18
- st.title("🍽️ Food Nutrition Estimator")
19
- st.markdown("Upload a food image and get a nutritional overview generated by an instruction‐tuned LLM!")
20
-
21
- # 2. Environment & cache
22
- hf_token = os.getenv("HF_TOKEN", None)
23
- cache_dir = "/tmp/cache"
24
- os.makedirs(cache_dir, exist_ok=True)
25
- os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
26
-
27
- # 3. Image transform for ViT
28
- manual_transform = transforms.Compose([
29
- transforms.Resize(256),
30
- transforms.CenterCrop(224),
31
- transforms.Lambda(lambda img: img.convert("RGB")),
32
- transforms.ToTensor(),
33
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
34
- std=[0.229, 0.224, 0.225]),
35
- transforms.ConvertImageDtype(torch.float32)
36
- ])
37
-
38
- # 4. Sidebar info
39
- st.sidebar.header("Models Used")
40
- st.sidebar.markdown("""
41
- - 🖼️ **Image Classifier**: `shingguy1/fine_tuned_vit`
42
- - 💬 **Text Generator**: `tiiuae/falcon-7b-instruct`
43
- """)
44
-
45
- # 5. Load models (cached)
46
- @st.cache_resource
47
- def load_models():
48
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
-
50
- # ViT classifier
51
- model_vit = ViTForImageClassification.from_pretrained(
52
- "shingguy1/fine_tuned_vit",
53
- cache_dir=cache_dir,
54
- use_auth_token=hf_token
55
- ).to(device)
56
-
57
- # Falcon‐7B Instruct LLM
58
- tokenizer_llm = AutoTokenizer.from_pretrained(
59
- "tiiuae/falcon-7b-instruct",
60
- cache_dir=cache_dir,
61
- use_auth_token=hf_token
62
- )
63
- model_llm = AutoModelForCausalLM.from_pretrained(
64
- "tiiuae/falcon-7b-instruct",
65
- cache_dir=cache_dir,
66
- use_auth_token=hf_token,
67
- torch_dtype=torch.float16,
68
- device_map="auto"
69
- )
70
-
71
- return model_vit, tokenizer_llm, model_llm, device
72
-
73
- model_vit, tokenizer_llm, model_llm, device = load_models()
74
-
75
- # 6. Image uploader
76
- uploaded_file = st.file_uploader("Upload a food image...", type=["jpg", "jpeg", "png"])
77
-
78
- if uploaded_file is not None:
79
- try:
80
- # Display image
81
- image = Image.open(uploaded_file)
82
- st.image(image, caption="Uploaded Image", use_column_width=True)
83
-
84
- # Classify with ViT
85
- input_tensor = manual_transform(image).unsqueeze(0).to(device)
86
- with torch.no_grad():
87
- outputs = model_vit(pixel_values=input_tensor)
88
- pred_idx = outputs.logits.argmax(-1).item()
89
- pred_label = model_vit.config.id2label[pred_idx]
90
- st.success(f"🍴 Predicted Food: **{pred_label}**")
91
-
92
- # Build a single, unified instruction prompt
93
- prompt = (
94
- "### Instruction\n"
95
- f"Provide a concise nutritional overview for a {pred_label}, including:\n"
96
- "- Serving size (exact measurements & ingestion guidelines)\n"
97
- "- Calories\n"
98
- "- Protein, carbohydrates, and fat\n"
99
- "- Main ingredients\n"
100
- "- Cooking method\n"
101
- "- One healthy substitution\n"
102
- "### Response"
103
  )
104
- st.subheader("🧾 Nutrition Information")
105
- st.write(f"🤖 Prompt sent to LLM:\n\n{prompt}")
106
-
107
- # Tokenize & generate
108
- inputs = tokenizer_llm(prompt, return_tensors="pt")
109
- inputs = {k: v.to(model_llm.device) for k, v in inputs.items()}
110
- input_len = inputs["input_ids"].shape[1]
111
-
112
- outputs = model_llm.generate(
113
- **inputs,
114
- max_length=input_len + 150,
115
- temperature=0.7,
116
- top_p=0.9,
117
- do_sample=True,
118
- no_repeat_ngram_size=2,
119
- early_stopping=True,
120
- pad_token_id=tokenizer_llm.eos_token_id,
121
- eos_token_id=tokenizer_llm.eos_token_id
122
  )
123
 
124
- # Decode and strip prompt
125
- full = tokenizer_llm.decode(outputs[0], skip_special_tokens=True).strip()
126
- if full.startswith("### Response"):
127
- caption = full.split("### Response", 1)[1].strip()
128
- else:
129
- caption = full[input_len:].strip()
130
-
131
- st.info(caption or "⚠️ The LLM did not generate any text.")
132
-
133
- except Exception as e:
134
- st.error(f"Something went wrong: {e}")
135
-
136
- # Footer
137
- st.markdown("---")
138
- st.markdown("Built with ❤️ using Streamlit and Hugging Face by **shingguy1**")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import streamlit as st

# Page config must be the very first Streamlit command executed,
# which is why it sits between the streamlit import and the rest.
st.set_page_config(
    page_title="🍽️ Food Nutrition Estimator",
    page_icon="🥗",
    layout="centered"
)

import torch
import os
from PIL import Image
 
15
  AutoModelForCausalLM
16
  )
17
 
18
def main():
    """Streamlit app body.

    Classifies an uploaded food image with a fine-tuned ViT, then asks
    Falcon-7B-Instruct for a nutritional overview of the predicted dish.
    Runs top-to-bottom on every Streamlit rerun; model loading is cached.
    """
    # 2. Environment & cache
    # HF_TOKEN is optional — only needed for gated/private repos.
    hf_token = os.getenv("HF_TOKEN", None)
    cache_dir = "/tmp/cache"
    os.makedirs(cache_dir, exist_ok=True)
    os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir

    # 3. Image transform for ViT (standard ImageNet preprocessing).
    # The RGB convert runs before ToTensor so grayscale/RGBA uploads
    # still produce the 3 channels Normalize expects.
    manual_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.Lambda(lambda img: img.convert("RGB")),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
        transforms.ConvertImageDtype(torch.float32)
    ])

    # 4. Sidebar info
    st.sidebar.header("Models Used")
    st.sidebar.markdown("""
    - 🖼️ **Image Classifier**: `shingguy1/fine_tuned_vit`
    - 💬 **Text Generator**: `tiiuae/falcon-7b-instruct`
    """)

    # 5. Load models (cached across reruns by st.cache_resource)
    @st.cache_resource
    def load_models():
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # ViT classifier — `token` replaces the deprecated `use_auth_token`.
        model_vit = ViTForImageClassification.from_pretrained(
            "shingguy1/fine_tuned_vit",
            cache_dir=cache_dir,
            token=hf_token
        ).to(device)
        model_vit.eval()  # inference mode: disable dropout etc.

        # Falcon-7B Instruct LLM. device_map="auto" may shard or place the
        # model on a different device than `device` — never assume they match.
        tokenizer_llm = AutoTokenizer.from_pretrained(
            "tiiuae/falcon-7b-instruct",
            cache_dir=cache_dir,
            token=hf_token
        )
        model_llm = AutoModelForCausalLM.from_pretrained(
            "tiiuae/falcon-7b-instruct",
            cache_dir=cache_dir,
            token=hf_token,
            torch_dtype=torch.float16,
            device_map="auto"
        )

        return model_vit, tokenizer_llm, model_llm, device

    model_vit, tokenizer_llm, model_llm, device = load_models()

    # 6. Image uploader
    uploaded_file = st.file_uploader("Upload a food image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        try:
            # Display image
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Image", use_column_width=True)

            # Classify with ViT
            input_tensor = manual_transform(image).unsqueeze(0).to(device)
            with torch.no_grad():
                outputs = model_vit(pixel_values=input_tensor)
            pred_idx = outputs.logits.argmax(-1).item()
            pred_label = model_vit.config.id2label[pred_idx]
            st.success(f"🍴 Predicted Food: **{pred_label}**")

            # Build prompt
            prompt = (
                "### Instruction\n"
                f"Provide a concise nutritional overview for a {pred_label}, including:\n"
                "- Serving size (measurements & ingestion guidelines)\n"
                "- Calories\n"
                "- Protein, carbohydrates, and fat\n"
                "- Main ingredients\n"
                "- Cooking method\n"
                "- One healthy substitution\n"
                "### Response"
            )
            st.subheader("🧾 Nutrition Information")
            st.write(f"🤖 Prompt to LLM:\n\n{prompt}")

            # Tokenize & generate.
            # BUG FIX: with device_map="auto" the LLM may not live on the
            # ViT's `device`; inputs must go to the LLM's own device.
            inputs = tokenizer_llm(prompt, return_tensors="pt")
            inputs = {k: v.to(model_llm.device) for k, v in inputs.items()}
            input_len = inputs["input_ids"].shape[1]

            # max_new_tokens=150 is the clearer equivalent of the former
            # max_length=input_len + 150. early_stopping was dropped: it is
            # a beam-search-only flag and a no-op (warning) under sampling.
            generated = model_llm.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                no_repeat_ngram_size=2,
                pad_token_id=tokenizer_llm.eos_token_id,
                eos_token_id=tokenizer_llm.eos_token_id
            )

            # Decode ONLY the newly generated tokens.
            # BUG FIX: the old code sliced the decoded *string* by a *token*
            # count (full[input_len:]), which cuts at the wrong character
            # offset. Slicing the token ids before decoding is correct.
            caption = tokenizer_llm.decode(
                generated[0][input_len:], skip_special_tokens=True
            ).strip()
            # Drop a leading "### Response" if the model echoed the marker.
            if caption.startswith("### Response"):
                caption = caption[len("### Response"):].strip()

            if caption:
                st.info(caption)
            else:
                st.error("⚠️ The LLM did not generate any text.")

        except Exception as e:
            # Broad catch is intentional: surface any failure in the UI
            # instead of crashing the Streamlit script run.
            st.error(f"Something went wrong: {e}")

if __name__ == "__main__":
    main()