shingguy1 committed on
Commit
50e8acf
·
verified ·
1 Parent(s): c5c8acf

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +39 -33
src/streamlit_app.py CHANGED
@@ -16,7 +16,7 @@ st.set_page_config(
16
  layout="centered"
17
  )
18
  st.title("🍽️ Food Nutrition Estimator")
19
- st.markdown("Upload a food image and get nutritional information generated by AI!")
20
 
21
  # 2. Environment & cache
22
  hf_token = os.getenv("HF_TOKEN", None)
@@ -28,7 +28,7 @@ os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
28
  manual_transform = transforms.Compose([
29
  transforms.Resize(256),
30
  transforms.CenterCrop(224),
31
- transforms.Lambda(lambda img: img.convert("RGB")), # ensure 3 channels
32
  transforms.ToTensor(),
33
  transforms.Normalize(mean=[0.485, 0.456, 0.406],
34
  std=[0.229, 0.224, 0.225]),
@@ -39,7 +39,7 @@ manual_transform = transforms.Compose([
39
  st.sidebar.header("Models Used")
40
  st.sidebar.markdown("""
41
  - 🖼️ **Image Classifier**: `shingguy1/fine_tuned_vit`
42
- - 💬 **Text Generator**: `TinyLlama/TinyLlama-1.1B-Chat-v1.0`
43
  """)
44
 
45
  # 5. Load models (cached)
@@ -47,35 +47,37 @@ st.sidebar.markdown("""
47
  def load_models():
48
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
 
50
- # ViT for classification
51
  model_vit = ViTForImageClassification.from_pretrained(
52
  "shingguy1/fine_tuned_vit",
53
  cache_dir=cache_dir,
54
  use_auth_token=hf_token
55
  ).to(device)
56
 
57
- # TinyLlama for nutrition text
58
- tokenizer = AutoTokenizer.from_pretrained(
59
- "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
60
- cache_dir=cache_dir
 
61
  )
62
  model_llm = AutoModelForCausalLM.from_pretrained(
63
- "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
64
  cache_dir=cache_dir,
65
- torch_dtype=torch.float32,
 
66
  device_map="auto"
67
  )
68
 
69
- return model_vit, tokenizer, model_llm, device
70
 
71
- model_vit, tokenizer, model_llm, device = load_models()
72
 
73
  # 6. Image uploader
74
  uploaded_file = st.file_uploader("Upload a food image...", type=["jpg", "jpeg", "png"])
75
 
76
  if uploaded_file is not None:
77
  try:
78
- # Load & display image
79
  image = Image.open(uploaded_file)
80
  st.image(image, caption="Uploaded Image", use_column_width=True)
81
 
@@ -87,42 +89,46 @@ if uploaded_file is not None:
87
  pred_label = model_vit.config.id2label[pred_idx]
88
  st.success(f"🍴 Predicted Food: **{pred_label}**")
89
 
90
- # Prepare LLM prompt
91
  prompt = (
92
- "Provide a concise nutritional overview for a tacos. "
93
- "Include serving size, calories, protein, carbs, fat, "
94
- "main ingredients, cooking method, and one substitution. "
95
- "Answer only the overview—do not repeat this instruction."
 
 
 
 
 
96
  )
97
  st.subheader("🧾 Nutrition Information")
98
- st.write(f"🤖 Prompt to LLM:\n\n{prompt}")
99
 
100
- # Tokenize & move to device
101
- inputs = tokenizer(prompt, return_tensors="pt")
102
  inputs = {k: v.to(model_llm.device) for k, v in inputs.items()}
103
  input_len = inputs["input_ids"].shape[1]
104
 
105
- # Generate with constraints
106
  outputs = model_llm.generate(
107
  **inputs,
108
  max_length=input_len + 150,
109
- do_sample=True,
110
- temperature=0.8,
111
  top_p=0.9,
 
112
  no_repeat_ngram_size=2,
113
  early_stopping=True,
114
- pad_token_id=tokenizer.eos_token_id,
115
- eos_token_id=tokenizer.eos_token_id
116
  )
117
 
118
- # Decode generated tokens only
119
- gen_ids = outputs[0][input_len:]
120
- caption = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
121
-
122
- if caption:
123
- st.info(caption)
124
  else:
125
- st.error("⚠️ The LLM failed to generate any text.")
 
 
126
 
127
  except Exception as e:
128
  st.error(f"Something went wrong: {e}")
 
16
  layout="centered"
17
  )
18
  st.title("🍽️ Food Nutrition Estimator")
19
+ st.markdown("Upload a food image and get a nutritional overview generated by an instruction‐tuned LLM!")
20
 
21
  # 2. Environment & cache
22
  hf_token = os.getenv("HF_TOKEN", None)
 
28
  manual_transform = transforms.Compose([
29
  transforms.Resize(256),
30
  transforms.CenterCrop(224),
31
+ transforms.Lambda(lambda img: img.convert("RGB")),
32
  transforms.ToTensor(),
33
  transforms.Normalize(mean=[0.485, 0.456, 0.406],
34
  std=[0.229, 0.224, 0.225]),
 
39
  st.sidebar.header("Models Used")
40
  st.sidebar.markdown("""
41
  - 🖼️ **Image Classifier**: `shingguy1/fine_tuned_vit`
42
+ - 💬 **Text Generator**: `tiiuae/falcon-7b-instruct`
43
  """)
44
 
45
  # 5. Load models (cached)
 
47
  def load_models():
48
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
 
50
+ # ViT classifier
51
  model_vit = ViTForImageClassification.from_pretrained(
52
  "shingguy1/fine_tuned_vit",
53
  cache_dir=cache_dir,
54
  use_auth_token=hf_token
55
  ).to(device)
56
 
57
+ # Falcon‐7B Instruct LLM
58
+ tokenizer_llm = AutoTokenizer.from_pretrained(
59
+ "tiiuae/falcon-7b-instruct",
60
+ cache_dir=cache_dir,
61
+ use_auth_token=hf_token
62
  )
63
  model_llm = AutoModelForCausalLM.from_pretrained(
64
+ "tiiuae/falcon-7b-instruct",
65
  cache_dir=cache_dir,
66
+ use_auth_token=hf_token,
67
+ torch_dtype=torch.float16,
68
  device_map="auto"
69
  )
70
 
71
+ return model_vit, tokenizer_llm, model_llm, device
72
 
73
+ model_vit, tokenizer_llm, model_llm, device = load_models()
74
 
75
  # 6. Image uploader
76
  uploaded_file = st.file_uploader("Upload a food image...", type=["jpg", "jpeg", "png"])
77
 
78
  if uploaded_file is not None:
79
  try:
80
+ # Display image
81
  image = Image.open(uploaded_file)
82
  st.image(image, caption="Uploaded Image", use_column_width=True)
83
 
 
89
  pred_label = model_vit.config.id2label[pred_idx]
90
  st.success(f"🍴 Predicted Food: **{pred_label}**")
91
 
92
+ # Build a single, unified instruction prompt
93
  prompt = (
94
+ "### Instruction\n"
95
+ f"Provide a concise nutritional overview for a {pred_label}, including:\n"
96
+ "- Serving size (exact measurements & ingestion guidelines)\n"
97
+ "- Calories\n"
98
+ "- Protein, carbohydrates, and fat\n"
99
+ "- Main ingredients\n"
100
+ "- Cooking method\n"
101
+ "- One healthy substitution\n"
102
+ "### Response"
103
  )
104
  st.subheader("🧾 Nutrition Information")
105
+ st.write(f"🤖 Prompt sent to LLM:\n\n{prompt}")
106
 
107
+ # Tokenize & generate
108
+ inputs = tokenizer_llm(prompt, return_tensors="pt")
109
  inputs = {k: v.to(model_llm.device) for k, v in inputs.items()}
110
  input_len = inputs["input_ids"].shape[1]
111
 
 
112
  outputs = model_llm.generate(
113
  **inputs,
114
  max_length=input_len + 150,
115
+ temperature=0.7,
 
116
  top_p=0.9,
117
+ do_sample=True,
118
  no_repeat_ngram_size=2,
119
  early_stopping=True,
120
+ pad_token_id=tokenizer_llm.eos_token_id,
121
+ eos_token_id=tokenizer_llm.eos_token_id
122
  )
123
 
124
+ # Decode and strip prompt
125
+ full = tokenizer_llm.decode(outputs[0], skip_special_tokens=True).strip()
126
+ if full.startswith("### Response"):
127
+ caption = full.split("### Response", 1)[1].strip()
 
 
128
  else:
129
+ caption = full[input_len:].strip()
130
+
131
+ st.info(caption or "⚠️ The LLM did not generate any text.")
132
 
133
  except Exception as e:
134
  st.error(f"Something went wrong: {e}")