shingguy1 committed on
Commit
6926852
·
verified ·
1 Parent(s): 1eac402

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +26 -8
src/streamlit_app.py CHANGED
@@ -53,7 +53,7 @@ def main():
53
  st.sidebar.header("Models Used")
54
  st.sidebar.markdown("""
55
  - 🖼️ **Image Classifier**: `shingguy1/fine_tuned_vit`
56
- - 💬 **Paraphraser**: `google/flan-t5-small` (beam search mode)
57
  """)
58
 
59
  transform = transforms.Compose([
@@ -68,9 +68,21 @@ def main():
68
  @st.cache_resource
69
  def load_models():
70
  device = torch.device("cpu")
71
- vit = ViTForImageClassification.from_pretrained("shingguy1/fine_tuned_vit", cache_dir=cache_dir, use_auth_token=hf_token).to(device)
72
- tok = AutoTokenizer.from_pretrained("google/flan-t5-small", cache_dir=cache_dir, use_auth_token=hf_token)
73
- t5 = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small", cache_dir=cache_dir, use_auth_token=hf_token).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
74
  return vit, tok, t5, device
75
 
76
  model_vit, tokenizer_t5, model_t5, device = load_models()
@@ -97,18 +109,24 @@ def main():
97
  f"Try {data['substitute']} as a healthier swap."
98
  )
99
  prompt = (
100
- f"Rewrite the following nutritional information in a clear and friendly tone. "
101
- f"Do not add or change any facts. Keep the sentence structure simple:\n\n"
 
102
  f"{base_description}"
103
  )
104
  else:
105
- prompt = f"Give the typical calories, macros, and nutrition facts for {label}. Provide realistic values even if estimated."
 
 
 
106
 
107
  inputs = tokenizer_t5(prompt, return_tensors="pt", truncation=True).to(device)
108
  output_ids = model_t5.generate(
109
  inputs["input_ids"],
110
  max_new_tokens=100,
111
- num_beams=4,
 
 
112
  early_stopping=True
113
  )
114
  response = tokenizer_t5.decode(output_ids[0], skip_special_tokens=True)
 
53
  st.sidebar.header("Models Used")
54
  st.sidebar.markdown("""
55
  - 🖼️ **Image Classifier**: `shingguy1/fine_tuned_vit`
56
+ - 💬 **Paraphraser**: `google/flan-t5-small` (sampling mode)
57
  """)
58
 
59
  transform = transforms.Compose([
 
68
@st.cache_resource
def load_models():
    """Load and cache the fine-tuned ViT classifier and the FLAN-T5 paraphraser.

    Relies on ``cache_dir`` and ``hf_token`` from the enclosing scope
    (closure over the surrounding app function — TODO confirm they are
    defined before first call).

    Returns:
        tuple: ``(vit, tok, t5, device)`` — the ViT classification model,
        the T5 tokenizer, the T5 generation model (both models moved to
        ``device``), and the ``torch.device`` used.
    """
    # Inference is pinned to CPU here; no CUDA check is performed.
    device = torch.device("cpu")
    # `use_auth_token` is deprecated in transformers (removed in v5);
    # `token` is the supported replacement and behaves identically.
    vit = ViTForImageClassification.from_pretrained(
        "shingguy1/fine_tuned_vit",
        cache_dir=cache_dir,
        token=hf_token,
    ).to(device)
    tok = AutoTokenizer.from_pretrained(
        "google/flan-t5-small",
        cache_dir=cache_dir,
        token=hf_token,
    )
    t5 = T5ForConditionalGeneration.from_pretrained(
        "google/flan-t5-small",
        cache_dir=cache_dir,
        token=hf_token,
    ).to(device)
    return vit, tok, t5, device
87
 
88
  model_vit, tokenizer_t5, model_t5, device = load_models()
 
109
  f"Try {data['substitute']} as a healthier swap."
110
  )
111
  prompt = (
112
+ f"Paraphrase the following nutritional facts in a friendly, conversational tone. "
113
+ f"Use varied sentence structures and synonyms, and feel free to generalize numeric details "
114
+ f"(e.g., ‘around 250 kcal’). Don’t add any new facts.\n\n"
115
  f"{base_description}"
116
  )
117
  else:
118
+ prompt = (
119
+ f"Provide an approximate nutrition summary for {label}, including calories, "
120
+ f"macronutrients, and a brief description."
121
+ )
122
 
123
  inputs = tokenizer_t5(prompt, return_tensors="pt", truncation=True).to(device)
124
  output_ids = model_t5.generate(
125
  inputs["input_ids"],
126
  max_new_tokens=100,
127
+ do_sample=True,
128
+ top_p=0.9,
129
+ temperature=0.7,
130
  early_stopping=True
131
  )
132
  response = tokenizer_t5.decode(output_ids[0], skip_special_tokens=True)