Spaces:
Running
Running
Add the Prefill and the @prefill command
Browse files
app.py
CHANGED
|
@@ -33,9 +33,18 @@ else:
|
|
| 33 |
if "messages" not in st.session_state:
|
| 34 |
st.session_state.messages = []
|
| 35 |
|
|
|
|
|
|
|
|
|
|
| 36 |
def get_ai_response(messages):
|
| 37 |
st.session_state.is_streaming = True
|
| 38 |
st.session_state.response = ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
with st.chat_message("assistant"):
|
| 41 |
penalty_kwargs = {
|
|
@@ -69,8 +78,6 @@ def get_ai_response(messages):
|
|
| 69 |
):
|
| 70 |
st.button("Stop generating")
|
| 71 |
|
| 72 |
-
shown_message = ""
|
| 73 |
-
|
| 74 |
for chunk in stream:
|
| 75 |
if chunk.choices and chunk.choices[0].delta.get("content"):
|
| 76 |
content = chunk.choices[0].delta["content"]
|
|
@@ -78,6 +85,11 @@ def get_ai_response(messages):
|
|
| 78 |
shown_message += content.replace("\n", " \n")
|
| 79 |
placeholder.markdown(shown_message)
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
st.session_state.is_streaming = False
|
| 82 |
return st.session_state.response
|
| 83 |
|
|
@@ -205,8 +217,8 @@ with st.sidebar:
|
|
| 205 |
"llama-2-70b-chat",
|
| 206 |
"llama-2-7b-chat",
|
| 207 |
"llama-2-70b-fast",
|
| 208 |
-
"
|
| 209 |
-
"
|
| 210 |
"mistral-7b-instruct-v0.1",
|
| 211 |
"mixtral-8x7b-instruct-v0.1",
|
| 212 |
"mixtral-8x7b-fast",
|
|
@@ -216,8 +228,6 @@ with st.sidebar:
|
|
| 216 |
"mistral-tiny",
|
| 217 |
"dolphin-mixtral-8x7b",
|
| 218 |
"mixtral-8x22b",
|
| 219 |
-
"pplx-7b-chat",
|
| 220 |
-
"pplx-7b-online",
|
| 221 |
"yi-34-chat",
|
| 222 |
"chronos-hermes-13b",
|
| 223 |
"mythomax-l2-13b",
|
|
@@ -229,6 +239,10 @@ with st.sidebar:
|
|
| 229 |
]
|
| 230 |
model = st.selectbox("Model", options=model_list, index=0)
|
| 231 |
system_prompt = st.text_area("System prompt", height=200)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=1.0, step=0.1)
|
| 233 |
top_p = st.slider("Top-P", min_value=0.01, max_value=0.99, value=0.75, step=0.01)
|
| 234 |
penalty_type = st.selectbox("Penalty Type", options=["Frequency Penalty", "Presence Penalty"])
|
|
@@ -290,14 +304,10 @@ display_messages()
|
|
| 290 |
# After Retry
|
| 291 |
if "retry_flag" in st.session_state and st.session_state.retry_flag == True:
|
| 292 |
if len(st.session_state.messages) > 0:
|
| 293 |
-
|
|
|
|
| 294 |
st.session_state.messages.append({"role": "assistant", "content": response})
|
| 295 |
|
| 296 |
-
if response:
|
| 297 |
-
st.session_state.is_error = False
|
| 298 |
-
else:
|
| 299 |
-
st.session_state.is_error = True
|
| 300 |
-
|
| 301 |
st.session_state.retry_flag = False
|
| 302 |
st.rerun()
|
| 303 |
|
|
@@ -305,6 +315,17 @@ if "retry_flag" in st.session_state and st.session_state.retry_flag == True:
|
|
| 305 |
st.session_state.retry_flag = False
|
| 306 |
|
| 307 |
if prompt := st.chat_input("What is up?"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
st.session_state.messages.append({"role": "user", "content": prompt})
|
| 309 |
messages = st.session_state.messages.copy()
|
| 310 |
|
|
@@ -313,11 +334,9 @@ if prompt := st.chat_input("What is up?"):
|
|
| 313 |
|
| 314 |
response = get_ai_response(messages)
|
| 315 |
|
| 316 |
-
if response:
|
| 317 |
-
st.session_state.is_error = False
|
| 318 |
-
else:
|
| 319 |
-
st.session_state.is_error = True
|
| 320 |
-
|
| 321 |
st.session_state.messages.append({"role": "assistant", "content": response})
|
|
|
|
|
|
|
|
|
|
| 322 |
|
| 323 |
st.rerun()
|
|
|
|
| 33 |
if "messages" not in st.session_state:
|
| 34 |
st.session_state.messages = []
|
| 35 |
|
| 36 |
+
if "prefill" not in st.session_state:
|
| 37 |
+
st.session_state.prefill = ""
|
| 38 |
+
|
| 39 |
def get_ai_response(messages):
|
| 40 |
st.session_state.is_streaming = True
|
| 41 |
st.session_state.response = ""
|
| 42 |
+
shown_message = ""
|
| 43 |
+
|
| 44 |
+
if st.session_state.prefill:
|
| 45 |
+
messages.append({"role": "assistant", "content": st.session_state.prefill})
|
| 46 |
+
st.session_state.response += st.session_state.prefill
|
| 47 |
+
shown_message = st.session_state.prefill.replace("\n", " \n")
|
| 48 |
|
| 49 |
with st.chat_message("assistant"):
|
| 50 |
penalty_kwargs = {
|
|
|
|
| 78 |
):
|
| 79 |
st.button("Stop generating")
|
| 80 |
|
|
|
|
|
|
|
| 81 |
for chunk in stream:
|
| 82 |
if chunk.choices and chunk.choices[0].delta.get("content"):
|
| 83 |
content = chunk.choices[0].delta["content"]
|
|
|
|
| 85 |
shown_message += content.replace("\n", " \n")
|
| 86 |
placeholder.markdown(shown_message)
|
| 87 |
|
| 88 |
+
if st.session_state.prefill == st.session_state.response:
|
| 89 |
+
st.session_state.is_error = True
|
| 90 |
+
else:
|
| 91 |
+
st.session_state.is_error = False
|
| 92 |
+
|
| 93 |
st.session_state.is_streaming = False
|
| 94 |
return st.session_state.response
|
| 95 |
|
|
|
|
| 217 |
"llama-2-70b-chat",
|
| 218 |
"llama-2-7b-chat",
|
| 219 |
"llama-2-70b-fast",
|
| 220 |
+
"llama-3-70b-instruct",
|
| 221 |
+
"llama-3-8b-instruct",
|
| 222 |
"mistral-7b-instruct-v0.1",
|
| 223 |
"mixtral-8x7b-instruct-v0.1",
|
| 224 |
"mixtral-8x7b-fast",
|
|
|
|
| 228 |
"mistral-tiny",
|
| 229 |
"dolphin-mixtral-8x7b",
|
| 230 |
"mixtral-8x22b",
|
|
|
|
|
|
|
| 231 |
"yi-34-chat",
|
| 232 |
"chronos-hermes-13b",
|
| 233 |
"mythomax-l2-13b",
|
|
|
|
| 239 |
]
|
| 240 |
model = st.selectbox("Model", options=model_list, index=0)
|
| 241 |
system_prompt = st.text_area("System prompt", height=200)
|
| 242 |
+
|
| 243 |
+
st.session_state.prefill = st.text_area("Prefill", height=50, value=st.session_state.prefill, placeholder="It only works well with the Claude models.")
|
| 244 |
+
save_prefill = st.toggle("Save the @prefill command input in the sidebar", value=True)
|
| 245 |
+
|
| 246 |
temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=1.0, step=0.1)
|
| 247 |
top_p = st.slider("Top-P", min_value=0.01, max_value=0.99, value=0.75, step=0.01)
|
| 248 |
penalty_type = st.selectbox("Penalty Type", options=["Frequency Penalty", "Presence Penalty"])
|
|
|
|
| 304 |
# After Retry
|
| 305 |
if "retry_flag" in st.session_state and st.session_state.retry_flag == True:
|
| 306 |
if len(st.session_state.messages) > 0:
|
| 307 |
+
messages = st.session_state.messages.copy()
|
| 308 |
+
response = get_ai_response(messages)
|
| 309 |
st.session_state.messages.append({"role": "assistant", "content": response})
|
| 310 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
st.session_state.retry_flag = False
|
| 312 |
st.rerun()
|
| 313 |
|
|
|
|
| 315 |
st.session_state.retry_flag = False
|
| 316 |
|
| 317 |
if prompt := st.chat_input("What is up?"):
|
| 318 |
+
used_prefill = False
|
| 319 |
+
prefill_pattern = r"([@@](prefill|ぷれふぃる|プレフィル)\s?(.*))"
|
| 320 |
+
prefill_match = re.search(prefill_pattern, prompt)
|
| 321 |
+
|
| 322 |
+
if prefill_match:
|
| 323 |
+
used_prefill = True
|
| 324 |
+
if not save_prefill:
|
| 325 |
+
original_prefill = st.session_state.prefill
|
| 326 |
+
st.session_state.prefill = prefill_match.group(3)
|
| 327 |
+
prompt = prompt.replace(prefill_match.group(1), '')
|
| 328 |
+
|
| 329 |
st.session_state.messages.append({"role": "user", "content": prompt})
|
| 330 |
messages = st.session_state.messages.copy()
|
| 331 |
|
|
|
|
| 334 |
|
| 335 |
response = get_ai_response(messages)
|
| 336 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
st.session_state.messages.append({"role": "assistant", "content": response})
|
| 338 |
+
|
| 339 |
+
if used_prefill and not save_prefill:
|
| 340 |
+
st.session_state.prefill = original_prefill
|
| 341 |
|
| 342 |
st.rerun()
|