shingguy1 committed on
Commit
c5c8acf
·
verified ·
1 Parent(s): caf1197

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +35 -27
src/streamlit_app.py CHANGED
@@ -24,37 +24,40 @@ cache_dir = "/tmp/cache"
24
  os.makedirs(cache_dir, exist_ok=True)
25
  os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
26
 
27
- # 3. Image transform (ViT)
28
  manual_transform = transforms.Compose([
29
  transforms.Resize(256),
30
  transforms.CenterCrop(224),
31
- transforms.Lambda(lambda img: img.convert("RGB")),
32
  transforms.ToTensor(),
33
  transforms.Normalize(mean=[0.485, 0.456, 0.406],
34
  std=[0.229, 0.224, 0.225]),
35
  transforms.ConvertImageDtype(torch.float32)
36
  ])
37
 
38
- # Sidebar
39
  st.sidebar.header("Models Used")
40
  st.sidebar.markdown("""
41
  - 🖼️ **Image Classifier**: `shingguy1/fine_tuned_vit`
42
  - 💬 **Text Generator**: `TinyLlama/TinyLlama-1.1B-Chat-v1.0`
43
  """)
44
 
45
- # 4. Load models
46
  @st.cache_resource
47
  def load_models():
48
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
 
 
50
  model_vit = ViTForImageClassification.from_pretrained(
51
  "shingguy1/fine_tuned_vit",
52
  cache_dir=cache_dir,
53
  use_auth_token=hf_token
54
  ).to(device)
55
 
 
56
  tokenizer = AutoTokenizer.from_pretrained(
57
- "TinyLlama/TinyLlama-1.1B-Chat-v1.0", cache_dir=cache_dir
 
58
  )
59
  model_llm = AutoModelForCausalLM.from_pretrained(
60
  "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
@@ -67,54 +70,59 @@ def load_models():
67
 
68
  model_vit, tokenizer, model_llm, device = load_models()
69
 
70
- # 5. Image uploader
71
  uploaded_file = st.file_uploader("Upload a food image...", type=["jpg", "jpeg", "png"])
72
 
73
- if uploaded_file:
74
  try:
 
75
  image = Image.open(uploaded_file)
76
  st.image(image, caption="Uploaded Image", use_column_width=True)
77
 
78
- # Predict
79
- batch = manual_transform(image).unsqueeze(0).to(device)
80
  with torch.no_grad():
81
- out = model_vit(pixel_values=batch)
82
- label = out.logits.argmax(-1).item()
83
- pred = model_vit.config.id2label[label]
84
- st.success(f"🍴 Predicted Food: **{pred}**")
85
 
86
- # Build prompt
87
  prompt = (
88
- f"Provide a concise nutritional overview for a {pred}. "
89
- "Include serving size, calories, protein, carbs, fat, main ingredients, cooking method, and one substitution. "
 
90
  "Answer only the overview—do not repeat this instruction."
91
  )
92
  st.subheader("🧾 Nutrition Information")
93
  st.write(f"🤖 Prompt to LLM:\n\n{prompt}")
94
 
95
- # Tokenize & move
96
  inputs = tokenizer(prompt, return_tensors="pt")
97
  inputs = {k: v.to(model_llm.device) for k, v in inputs.items()}
 
98
 
99
- # Generate
100
- max_len = inputs["input_ids"].shape[-1] + 150
101
  outputs = model_llm.generate(
102
  **inputs,
103
- max_length=max_len,
104
- temperature=0.7,
105
- top_p=0.9,
106
  do_sample=True,
 
 
107
  no_repeat_ngram_size=2,
 
108
  pad_token_id=tokenizer.eos_token_id,
109
  eos_token_id=tokenizer.eos_token_id
110
  )
111
 
112
- # Decode all, then strip prompt if echoed
113
- text = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
114
- if text.lower().startswith(prompt.lower()):
115
- text = text[len(prompt):].strip()
116
 
117
- st.info(text or "⚠️ The LLM did not produce any text.")
 
 
 
118
 
119
  except Exception as e:
120
  st.error(f"Something went wrong: {e}")
 
24
  os.makedirs(cache_dir, exist_ok=True)
25
  os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
26
 
27
# 3. Image transform for ViT: standard ImageNet preprocessing pipeline.
manual_transform = transforms.Compose([
    transforms.Resize(256),                 # shorter side to 256 px
    transforms.CenterCrop(224),             # ViT expects 224x224 input
    transforms.Lambda(lambda img: img.convert("RGB")),  # ensure 3 channels
    transforms.ToTensor(),
    # ImageNet channel statistics, matching the ViT pretraining data.
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
    transforms.ConvertImageDtype(torch.float32),
])
37
 
38
# 4. Sidebar info: list the models backing the app so users know the sources.
st.sidebar.header("Models Used")
sidebar_text = """
- 🖼️ **Image Classifier**: `shingguy1/fine_tuned_vit`
- 💬 **Text Generator**: `TinyLlama/TinyLlama-1.1B-Chat-v1.0`
"""
st.sidebar.markdown(sidebar_text)
44
 
45
+ # 5. Load models (cached)
46
  @st.cache_resource
47
  def load_models():
48
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
 
50
+ # ViT for classification
51
  model_vit = ViTForImageClassification.from_pretrained(
52
  "shingguy1/fine_tuned_vit",
53
  cache_dir=cache_dir,
54
  use_auth_token=hf_token
55
  ).to(device)
56
 
57
+ # TinyLlama for nutrition text
58
  tokenizer = AutoTokenizer.from_pretrained(
59
+ "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
60
+ cache_dir=cache_dir
61
  )
62
  model_llm = AutoModelForCausalLM.from_pretrained(
63
  "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
 
70
 
71
  model_vit, tokenizer, model_llm, device = load_models()
72
 
73
# 6. Image uploader: classify the uploaded food photo with ViT, then ask the
# LLM for a nutritional overview of the predicted dish.
uploaded_file = st.file_uploader("Upload a food image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        # Load & display image
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Image", use_column_width=True)

        # Classify with ViT: batch of 1, on the same device as the model.
        input_tensor = manual_transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model_vit(pixel_values=input_tensor)
        pred_idx = outputs.logits.argmax(-1).item()
        pred_label = model_vit.config.id2label[pred_idx]
        st.success(f"🍴 Predicted Food: **{pred_label}**")

        # Prepare LLM prompt.
        # BUG FIX: the prompt previously hard-coded "a tacos", so every upload
        # produced taco nutrition facts regardless of the classification.
        # Interpolate the actual ViT prediction instead.
        prompt = (
            f"Provide a concise nutritional overview for a {pred_label}. "
            "Include serving size, calories, protein, carbs, fat, "
            "main ingredients, cooking method, and one substitution. "
            "Answer only the overview—do not repeat this instruction."
        )
        st.subheader("🧾 Nutrition Information")
        st.write(f"🤖 Prompt to LLM:\n\n{prompt}")

        # Tokenize & move to the LLM's device (may differ from ViT's).
        inputs = tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(model_llm.device) for k, v in inputs.items()}
        input_len = inputs["input_ids"].shape[1]

        # Generate with constraints: cap new tokens at 150 past the prompt,
        # sample with nucleus filtering, and block 2-gram repetition.
        outputs = model_llm.generate(
            **inputs,
            max_length=input_len + 150,
            do_sample=True,
            temperature=0.8,
            top_p=0.9,
            no_repeat_ngram_size=2,
            early_stopping=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

        # Decode generated tokens only (slice off the echoed prompt).
        gen_ids = outputs[0][input_len:]
        caption = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()

        if caption:
            st.info(caption)
        else:
            st.error("⚠️ The LLM failed to generate any text.")

    except Exception as e:
        # Surface any failure (bad image, OOM, model error) in the UI rather
        # than crashing the Streamlit script run.
        st.error(f"Something went wrong: {e}")