shingguy1 committed on
Commit
dbca709
·
verified ·
1 Parent(s): aaf3765

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +41 -49
src/streamlit_app.py CHANGED
@@ -12,17 +12,17 @@ import torchvision.transforms as transforms
12
  from transformers import (
13
  ViTForImageClassification,
14
  AutoTokenizer,
15
- AutoModelForCausalLM
16
  )
17
 
18
  def main():
19
- # Environment & cache
20
  hf_token = os.getenv("HF_TOKEN", None)
21
  cache_dir = "/tmp/cache"
22
  os.makedirs(cache_dir, exist_ok=True)
23
  os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
24
 
25
- # Image transform for ViT
26
  manual_transform = transforms.Compose([
27
  transforms.Resize(256),
28
  transforms.CenterCrop(224),
@@ -33,102 +33,94 @@ def main():
33
  transforms.ConvertImageDtype(torch.float32)
34
  ])
35
 
36
- # Sidebar info
37
  st.sidebar.header("Models Used")
38
  st.sidebar.markdown("""
39
  - 🖼️ **Image Classifier**: `shingguy1/fine_tuned_vit`
40
- - 💬 **Text Generator**: `tiiuae/falcon-7b-instruct`
41
  """)
42
 
43
- # Load models (cached)
44
  @st.cache_resource
45
  def load_models():
46
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47
 
48
- # ViT classifier → GPU/CPU
49
  model_vit = ViTForImageClassification.from_pretrained(
50
  "shingguy1/fine_tuned_vit",
51
  cache_dir=cache_dir,
52
  use_auth_token=hf_token
53
  ).to(device)
54
 
55
- # Falcon-7B Instruct 8-bit quant on GPU
56
  tokenizer_llm = AutoTokenizer.from_pretrained(
57
- "tiiuae/falcon-7b-instruct",
58
  cache_dir=cache_dir,
59
  use_auth_token=hf_token
60
  )
61
- model_llm = AutoModelForCausalLM.from_pretrained(
62
- "tiiuae/falcon-7b-instruct",
63
  cache_dir=cache_dir,
64
- load_in_8bit=True,
65
- device_map="auto",
66
- torch_dtype=torch.float16,
67
  use_auth_token=hf_token
68
- )
69
 
70
  return model_vit, tokenizer_llm, model_llm, device
71
 
72
  model_vit, tokenizer_llm, model_llm, device = load_models()
73
 
74
- # Image uploader
75
  uploaded_file = st.file_uploader("Upload a food image...", type=["jpg", "jpeg", "png"])
76
  if uploaded_file is not None:
77
  try:
 
78
  image = Image.open(uploaded_file)
79
  st.image(image, caption="Uploaded Image", use_column_width=True)
80
 
81
- # Classify
82
- inputs_v = manual_transform(image).unsqueeze(0).to(device)
83
  with torch.no_grad():
84
- out = model_vit(pixel_values=inputs_v)
85
- idx = out.logits.argmax(-1).item()
86
- pred_label = model_vit.config.id2label[idx]
87
  st.success(f"🍴 Predicted Food: **{pred_label}**")
88
 
89
- # Unified instruction prompt
90
  prompt = (
91
- "### Instruction\n"
92
- f"Provide a concise nutritional overview for a {pred_label}, including:\n"
93
- "- Serving size (measurements & ingestion guidelines)\n"
94
  "- Calories\n"
95
  "- Protein, carbohydrates, and fat\n"
96
  "- Main ingredients\n"
97
  "- Cooking method\n"
98
  "- One healthy substitution\n"
99
- "### Response"
100
  )
101
  st.subheader("🧾 Nutrition Information")
102
- st.write(f"🤖 Prompt to LLM:\n\n{prompt}")
103
 
104
  # Tokenize & generate
105
- inputs = tokenizer_llm(prompt, return_tensors="pt")
106
- inputs = {k: v.to(model_llm.device) for k, v in inputs.items()}
107
- inp_len = inputs["input_ids"].shape[1]
108
-
109
- out_ids = model_llm.generate(
110
- **inputs,
111
- max_length=inp_len + 150,
 
 
 
 
112
  temperature=0.7,
113
  top_p=0.9,
114
  do_sample=True,
115
  no_repeat_ngram_size=2,
116
  early_stopping=True,
117
- pad_token_id=tokenizer_llm.eos_token_id,
118
  eos_token_id=tokenizer_llm.eos_token_id
119
- )[0]
120
-
121
- # Decode & strip prompt
122
- decoded = tokenizer_llm.decode(out_ids, skip_special_tokens=True).strip()
123
- if "### Response" in decoded:
124
- caption = decoded.split("### Response", 1)[1].strip()
125
- else:
126
- caption = decoded[inp_len:].strip()
127
-
128
- if caption:
129
- st.info(caption)
130
- else:
131
- st.error("⚠️ The LLM did not generate any text.")
132
 
133
  except Exception as e:
134
  st.error(f"Something went wrong: {e}")
 
12
  from transformers import (
13
  ViTForImageClassification,
14
  AutoTokenizer,
15
+ T5ForConditionalGeneration
16
  )
17
 
18
  def main():
19
+ # 1. Environment & cache
20
  hf_token = os.getenv("HF_TOKEN", None)
21
  cache_dir = "/tmp/cache"
22
  os.makedirs(cache_dir, exist_ok=True)
23
  os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
24
 
25
+ # 2. Image transform for ViT
26
  manual_transform = transforms.Compose([
27
  transforms.Resize(256),
28
  transforms.CenterCrop(224),
 
33
  transforms.ConvertImageDtype(torch.float32)
34
  ])
35
 
36
+ # 3. Sidebar info
37
  st.sidebar.header("Models Used")
38
  st.sidebar.markdown("""
39
  - 🖼️ **Image Classifier**: `shingguy1/fine_tuned_vit`
40
+ - 💬 **Text Generator**: `google/flan-t5-small`
41
  """)
42
 
43
+ # 4. Load models (cached)
44
  @st.cache_resource
45
  def load_models():
46
+ device = torch.device("cpu") # CPU-only environment
47
 
48
+ # ViT classifier
49
  model_vit = ViTForImageClassification.from_pretrained(
50
  "shingguy1/fine_tuned_vit",
51
  cache_dir=cache_dir,
52
  use_auth_token=hf_token
53
  ).to(device)
54
 
55
+ # FLAN-T5 Small for generation
56
  tokenizer_llm = AutoTokenizer.from_pretrained(
57
+ "google/flan-t5-small",
58
  cache_dir=cache_dir,
59
  use_auth_token=hf_token
60
  )
61
+ model_llm = T5ForConditionalGeneration.from_pretrained(
62
+ "google/flan-t5-small",
63
  cache_dir=cache_dir,
 
 
 
64
  use_auth_token=hf_token
65
+ ).to(device)
66
 
67
  return model_vit, tokenizer_llm, model_llm, device
68
 
69
  model_vit, tokenizer_llm, model_llm, device = load_models()
70
 
71
+ # 5. Image uploader
72
  uploaded_file = st.file_uploader("Upload a food image...", type=["jpg", "jpeg", "png"])
73
  if uploaded_file is not None:
74
  try:
75
+ # Display image
76
  image = Image.open(uploaded_file)
77
  st.image(image, caption="Uploaded Image", use_column_width=True)
78
 
79
+ # Classify with ViT
80
+ inputs_vit = manual_transform(image).unsqueeze(0).to(device)
81
  with torch.no_grad():
82
+ vit_outputs = model_vit(pixel_values=inputs_vit)
83
+ pred_idx = vit_outputs.logits.argmax(-1).item()
84
+ pred_label = model_vit.config.id2label[pred_idx]
85
  st.success(f"🍴 Predicted Food: **{pred_label}**")
86
 
87
+ # Build FLAN-T5 prompt
88
  prompt = (
89
+ "Provide a concise nutritional overview for a taco, including:\n"
90
+ "- Serving size (with measurements & ingestion guidelines)\n"
 
91
  "- Calories\n"
92
  "- Protein, carbohydrates, and fat\n"
93
  "- Main ingredients\n"
94
  "- Cooking method\n"
95
  "- One healthy substitution\n"
96
+ "Answer only the overview."
97
  )
98
  st.subheader("🧾 Nutrition Information")
99
+ st.write(f"🤖 Prompt:\n\n{prompt}")
100
 
101
  # Tokenize & generate
102
+ inputs = tokenizer_llm(
103
+ prompt,
104
+ return_tensors="pt",
105
+ padding="longest",
106
+ truncation=True,
107
+ ).to(device)
108
+
109
+ outputs = model_llm.generate(
110
+ input_ids=inputs.input_ids,
111
+ attention_mask=inputs.attention_mask,
112
+ max_new_tokens=150,
113
  temperature=0.7,
114
  top_p=0.9,
115
  do_sample=True,
116
  no_repeat_ngram_size=2,
117
  early_stopping=True,
118
+ pad_token_id=tokenizer_llm.pad_token_id,
119
  eos_token_id=tokenizer_llm.eos_token_id
120
+ )
121
+
122
+ summary = tokenizer_llm.decode(outputs[0], skip_special_tokens=True).strip()
123
+ st.info(summary or "⚠️ The model did not generate any text.")
 
 
 
 
 
 
 
 
 
124
 
125
  except Exception as e:
126
  st.error(f"Something went wrong: {e}")