Added trigger examples and updated the form to implement them; also reset st.session_state.trigger_example to False after use
Browse files
app.py
CHANGED
|
@@ -44,7 +44,7 @@ import torch
|
|
| 44 |
# So first run will load and save resources to global cache, and as user interact and causes rerun of load_model_and_tokenizer(), instead of loading again it will directly use cached resources from memory
|
| 45 |
def load_model_and_tokenizer():
|
| 46 |
model_name = "google/gemma-2b" # using gemma-2b for prototype for my GSOC Proposal. Wish me luck.
|
| 47 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 48 |
# Responsible for automatically downloading and loading the tokenizer configuration and vocabulary associated with the specified pre-trained model.
|
| 49 |
# Downloads and loads the tokenizer config and vocab for the given model
|
| 50 |
model = AutoModelForCausalLM.from_pretrained(model_name)
|
|
@@ -64,7 +64,7 @@ def generate_text(prompt, tone, max_length, temperature=0.7, top_p=0.9, repetiti
|
|
| 64 |
}
|
| 65 |
input_text = tone_prompts.get(tone, prompt)
|
| 66 |
|
| 67 |
-
inputs = tokenizer(input_text, return_tensors="pt")
|
| 68 |
outputs = model.generate(
|
| 69 |
inputs["input_ids"],
|
| 70 |
max_length=max_length + len(input_text.split()),
|
|
@@ -110,25 +110,28 @@ with st.expander("\U0001F9E0 How does this work? Click to peek inside."):
|
|
| 110 |
""")
|
| 111 |
|
| 112 |
# One-click examples
|
| 113 |
-
|
|
|
|
|
|
|
| 114 |
col1, col2 = st.columns(2)
|
| 115 |
with col1:
|
| 116 |
if st.button("Try Funny Cat Story"):
|
| 117 |
-
prompt = "The cat hacked my WiFi"
|
| 118 |
-
tone = "Funny"
|
| 119 |
-
|
| 120 |
with col2:
|
| 121 |
if st.button("Try Poetic Goodbye"):
|
| 122 |
-
prompt = "As the sun set on our final day"
|
| 123 |
-
tone = "Poetic"
|
| 124 |
-
|
|
|
|
| 125 |
|
| 126 |
|
| 127 |
|
| 128 |
# User input section
|
| 129 |
with st.form(key="input_form"):
|
| 130 |
-
prompt = st.text_input("Enter a prompt", placeholder="e.g., 'The future of AI is'")
|
| 131 |
-
tone = st.selectbox("Tone", ["Funny", "Serious", "Poetic"])
|
| 132 |
temperature = st.slider("Temperature (Creativity)", 0.2, 1.5, 0.7)
|
| 133 |
top_p = st.slider("Top-p (Nucleus Sampling)", 0.1, 1.0, 0.9)
|
| 134 |
repetition_penalty = st.slider("Repetition Penalty", 1.0, 2.0, 1.0)
|
|
@@ -137,7 +140,8 @@ with st.form(key="input_form"):
|
|
| 137 |
submit_button = st.form_submit_button(label="Generate")
|
| 138 |
|
| 139 |
# Generate and display output
|
| 140 |
-
if submit_button:
|
|
|
|
| 141 |
if not prompt:
|
| 142 |
st.error("Please enter a prompt!")
|
| 143 |
else:
|
|
|
|
| 44 |
# So first run will load and save resources to global cache, and as user interact and causes rerun of load_model_and_tokenizer(), instead of loading again it will directly use cached resources from memory
|
| 45 |
def load_model_and_tokenizer():
|
| 46 |
model_name = "google/gemma-2b" # using gemma-2b for prototype for my GSOC Proposal. Wish me luck.
|
| 47 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)  # NOTE: tokenizers are plain Python objects, not torch modules — they have no .to() method, so .to("cpu") would raise AttributeError; device placement applies to the model and input tensors only
|
| 48 |
# Responsible for automatically downloading and loading the tokenizer configuration and vocabulary associated with the specified pre-trained model.
|
| 49 |
# Downloads and loads the tokenizer config and vocab for the given model
|
| 50 |
model = AutoModelForCausalLM.from_pretrained(model_name)
|
|
|
|
| 64 |
}
|
| 65 |
input_text = tone_prompts.get(tone, prompt)
|
| 66 |
|
| 67 |
+
inputs = tokenizer(input_text, return_tensors="pt").to("cpu")  # BatchEncoding.to() moves the input tensors to CPU (free-tier account is CPU-only)
|
| 68 |
outputs = model.generate(
|
| 69 |
inputs["input_ids"],
|
| 70 |
max_length=max_length + len(input_text.split()),
|
|
|
|
| 110 |
""")
|
| 111 |
|
| 112 |
# One-click examples
|
| 113 |
+
if "trigger_example" not in st.session_state:
|
| 114 |
+
st.session_state.trigger_example = False
|
| 115 |
+
|
| 116 |
col1, col2 = st.columns(2)
|
| 117 |
with col1:
|
| 118 |
if st.button("Try Funny Cat Story"):
|
| 119 |
+
st.session_state.prompt = "The cat hacked my WiFi"
|
| 120 |
+
st.session_state.tone = "Funny"
|
| 121 |
+
st.session_state.trigger_example = True
|
| 122 |
with col2:
|
| 123 |
if st.button("Try Poetic Goodbye"):
|
| 124 |
+
st.session_state.prompt = "As the sun set on our final day"
|
| 125 |
+
st.session_state.tone = "Poetic"
|
| 126 |
+
st.session_state.trigger_example = True
|
| 127 |
+
|
| 128 |
|
| 129 |
|
| 130 |
|
| 131 |
# User input section
|
| 132 |
with st.form(key="input_form"):
|
| 133 |
+
prompt = st.text_input("Enter a prompt", placeholder="e.g., 'The future of AI is'", value=st.session_state.get("prompt", ""))
|
| 134 |
+
tone = st.selectbox("Tone", ["Funny", "Serious", "Poetic"], index=["Funny", "Serious", "Poetic"].index(st.session_state.get("tone", "Funny")))
|
| 135 |
temperature = st.slider("Temperature (Creativity)", 0.2, 1.5, 0.7)
|
| 136 |
top_p = st.slider("Top-p (Nucleus Sampling)", 0.1, 1.0, 0.9)
|
| 137 |
repetition_penalty = st.slider("Repetition Penalty", 1.0, 2.0, 1.0)
|
|
|
|
| 140 |
submit_button = st.form_submit_button(label="Generate")
|
| 141 |
|
| 142 |
# Generate and display output
|
| 143 |
+
if submit_button or st.session_state.trigger_example:
|
| 144 |
+
st.session_state.trigger_example = False # Reset after use
|
| 145 |
if not prompt:
|
| 146 |
st.error("Please enter a prompt!")
|
| 147 |
else:
|