Spaces:
Sleeping
Sleeping
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +26 -8
src/streamlit_app.py
CHANGED
|
@@ -53,7 +53,7 @@ def main():
|
|
| 53 |
st.sidebar.header("Models Used")
|
| 54 |
st.sidebar.markdown("""
|
| 55 |
- 🖼️ **Image Classifier**: `shingguy1/fine_tuned_vit`
|
| 56 |
-
- 💬 **Paraphraser**: `google/flan-t5-small` (
|
| 57 |
""")
|
| 58 |
|
| 59 |
transform = transforms.Compose([
|
|
@@ -68,9 +68,21 @@ def main():
|
|
| 68 |
@st.cache_resource
|
| 69 |
def load_models():
|
| 70 |
device = torch.device("cpu")
|
| 71 |
-
vit = ViTForImageClassification.from_pretrained(
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
return vit, tok, t5, device
|
| 75 |
|
| 76 |
model_vit, tokenizer_t5, model_t5, device = load_models()
|
|
@@ -97,18 +109,24 @@ def main():
|
|
| 97 |
f"Try {data['substitute']} as a healthier swap."
|
| 98 |
)
|
| 99 |
prompt = (
|
| 100 |
-
f"
|
| 101 |
-
f"
|
|
|
|
| 102 |
f"{base_description}"
|
| 103 |
)
|
| 104 |
else:
|
| 105 |
-
prompt =
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
inputs = tokenizer_t5(prompt, return_tensors="pt", truncation=True).to(device)
|
| 108 |
output_ids = model_t5.generate(
|
| 109 |
inputs["input_ids"],
|
| 110 |
max_new_tokens=100,
|
| 111 |
-
|
|
|
|
|
|
|
| 112 |
early_stopping=True
|
| 113 |
)
|
| 114 |
response = tokenizer_t5.decode(output_ids[0], skip_special_tokens=True)
|
|
|
|
| 53 |
st.sidebar.header("Models Used")
|
| 54 |
st.sidebar.markdown("""
|
| 55 |
- 🖼️ **Image Classifier**: `shingguy1/fine_tuned_vit`
|
| 56 |
+
- 💬 **Paraphraser**: `google/flan-t5-small` (sampling mode)
|
| 57 |
""")
|
| 58 |
|
| 59 |
transform = transforms.Compose([
|
|
|
|
| 68 |
@st.cache_resource
|
| 69 |
def load_models():
|
| 70 |
device = torch.device("cpu")
|
| 71 |
+
vit = ViTForImageClassification.from_pretrained(
|
| 72 |
+
"shingguy1/fine_tuned_vit",
|
| 73 |
+
cache_dir=cache_dir,
|
| 74 |
+
use_auth_token=hf_token
|
| 75 |
+
).to(device)
|
| 76 |
+
tok = AutoTokenizer.from_pretrained(
|
| 77 |
+
"google/flan-t5-small",
|
| 78 |
+
cache_dir=cache_dir,
|
| 79 |
+
use_auth_token=hf_token
|
| 80 |
+
)
|
| 81 |
+
t5 = T5ForConditionalGeneration.from_pretrained(
|
| 82 |
+
"google/flan-t5-small",
|
| 83 |
+
cache_dir=cache_dir,
|
| 84 |
+
use_auth_token=hf_token
|
| 85 |
+
).to(device)
|
| 86 |
return vit, tok, t5, device
|
| 87 |
|
| 88 |
model_vit, tokenizer_t5, model_t5, device = load_models()
|
|
|
|
| 109 |
f"Try {data['substitute']} as a healthier swap."
|
| 110 |
)
|
| 111 |
prompt = (
|
| 112 |
+
f"Paraphrase the following nutritional facts in a friendly, conversational tone. "
|
| 113 |
+
f"Use varied sentence structures and synonyms, and feel free to generalize numeric details "
|
| 114 |
+
f"(e.g., ‘around 250 kcal’). Don’t add any new facts.\n\n"
|
| 115 |
f"{base_description}"
|
| 116 |
)
|
| 117 |
else:
|
| 118 |
+
prompt = (
|
| 119 |
+
f"Provide an approximate nutrition summary for {label}, including calories, "
|
| 120 |
+
f"macronutrients, and a brief description."
|
| 121 |
+
)
|
| 122 |
|
| 123 |
inputs = tokenizer_t5(prompt, return_tensors="pt", truncation=True).to(device)
|
| 124 |
output_ids = model_t5.generate(
|
| 125 |
inputs["input_ids"],
|
| 126 |
max_new_tokens=100,
|
| 127 |
+
do_sample=True,
|
| 128 |
+
top_p=0.9,
|
| 129 |
+
temperature=0.7,
|
| 130 |
early_stopping=True
|
| 131 |
)
|
| 132 |
response = tokenizer_t5.decode(output_ids[0], skip_special_tokens=True)
|