Nyanfa committed
Commit b581f30 · verified · 1 Parent(s): b31d009

Add the Prefill and the @prefill command
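
In short: this change adds a "Prefill" text area to the sidebar and a @prefill chat command (also accepted as ぷれふぃる / プレフィル). The prefill text is appended as a trailing assistant message before the request is sent, so the model continues from it. A minimal sketch of how the command is parsed, using the same regex as in the diff below (the example prompt is made up for illustration):

    import re

    # Same pattern as in app.py; group(3) captures the prefill text,
    # group(1) is the whole command so it can be stripped from the prompt.
    prefill_pattern = r"([@@](prefill|ぷれふぃる|プレフィル)\s?(.*))"

    prompt = "Write a haiku about rain. @prefill Sure, here is a haiku:"
    match = re.search(prefill_pattern, prompt)
    if match:
        prefill = match.group(3)                     # "Sure, here is a haiku:"
        prompt = prompt.replace(match.group(1), "")  # command removed from the user message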

Files changed (1):
  1. app.py (+36 -17)
app.py CHANGED
@@ -33,9 +33,18 @@ else:
 if "messages" not in st.session_state:
     st.session_state.messages = []
 
+if "prefill" not in st.session_state:
+    st.session_state.prefill = ""
+
 def get_ai_response(messages):
     st.session_state.is_streaming = True
     st.session_state.response = ""
+    shown_message = ""
+
+    if st.session_state.prefill:
+        messages.append({"role": "assistant", "content": st.session_state.prefill})
+        st.session_state.response += st.session_state.prefill
+        shown_message = st.session_state.prefill.replace("\n", " \n")
 
     with st.chat_message("assistant"):
         penalty_kwargs = {
@@ -69,8 +78,6 @@ def get_ai_response(messages):
         ):
             st.button("Stop generating")
 
-        shown_message = ""
-
         for chunk in stream:
             if chunk.choices and chunk.choices[0].delta.get("content"):
                 content = chunk.choices[0].delta["content"]
@@ -78,6 +85,11 @@
                 shown_message += content.replace("\n", " \n")
                 placeholder.markdown(shown_message)
 
+    if st.session_state.prefill == st.session_state.response:
+        st.session_state.is_error = True
+    else:
+        st.session_state.is_error = False
+
     st.session_state.is_streaming = False
     return st.session_state.response
 
@@ -205,8 +217,8 @@ with st.sidebar:
         "llama-2-70b-chat",
         "llama-2-7b-chat",
         "llama-2-70b-fast",
-        "pplx-70b-chat",
-        "pplx-70b-online",
+        "llama-3-70b-instruct",
+        "llama-3-8b-instruct",
         "mistral-7b-instruct-v0.1",
         "mixtral-8x7b-instruct-v0.1",
         "mixtral-8x7b-fast",
@@ -216,8 +228,6 @@ with st.sidebar:
         "mistral-tiny",
         "dolphin-mixtral-8x7b",
         "mixtral-8x22b",
-        "pplx-7b-chat",
-        "pplx-7b-online",
         "yi-34-chat",
         "chronos-hermes-13b",
         "mythomax-l2-13b",
@@ -229,6 +239,10 @@ with st.sidebar:
     ]
     model = st.selectbox("Model", options=model_list, index=0)
     system_prompt = st.text_area("System prompt", height=200)
+
+    st.session_state.prefill = st.text_area("Prefill", height=50, value=st.session_state.prefill, placeholder="It only works well with the Claude models.")
+    save_prefill = st.toggle("Save the @prefill command input in the sidebar", value=True)
+
     temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=1.0, step=0.1)
     top_p = st.slider("Top-P", min_value=0.01, max_value=0.99, value=0.75, step=0.01)
     penalty_type = st.selectbox("Penalty Type", options=["Frequency Penalty", "Presence Penalty"])
@@ -290,14 +304,10 @@ display_messages()
 # After Retry
 if "retry_flag" in st.session_state and st.session_state.retry_flag == True:
     if len(st.session_state.messages) > 0:
-        response = get_ai_response(st.session_state.messages)
+        messages = st.session_state.messages.copy()
+        response = get_ai_response(messages)
         st.session_state.messages.append({"role": "assistant", "content": response})
 
-        if response:
-            st.session_state.is_error = False
-        else:
-            st.session_state.is_error = True
-
         st.session_state.retry_flag = False
         st.rerun()
 
@@ -305,6 +315,17 @@ if "retry_flag" in st.session_state and st.session_state.retry_flag == True:
     st.session_state.retry_flag = False
 
 if prompt := st.chat_input("What is up?"):
+    used_prefill = False
+    prefill_pattern = r"([@@](prefill|ぷれふぃる|プレフィル)\s?(.*))"
+    prefill_match = re.search(prefill_pattern, prompt)
+
+    if prefill_match:
+        used_prefill = True
+        if not save_prefill:
+            original_prefill = st.session_state.prefill
+        st.session_state.prefill = prefill_match.group(3)
+        prompt = prompt.replace(prefill_match.group(1), '')
+
     st.session_state.messages.append({"role": "user", "content": prompt})
     messages = st.session_state.messages.copy()
 
@@ -313,11 +334,9 @@ if prompt := st.chat_input("What is up?"):
 
     response = get_ai_response(messages)
 
-    if response:
-        st.session_state.is_error = False
-    else:
-        st.session_state.is_error = True
-
     st.session_state.messages.append({"role": "assistant", "content": response})
+
+    if used_prefill and not save_prefill:
+        st.session_state.prefill = original_prefill
 
     st.rerun()
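
For reference, a minimal sketch (names follow the diff; the streaming call itself is omitted) of what get_ai_response now does when a prefill is set: the prefill is injected as a partial assistant turn, counted as already-produced output, and the turn is flagged as an error if the stream never grows past it.

    # Hypothetical single-turn example; "messages" and "prefill" stand in for
    # the st.session_state values used in app.py.
    messages = [{"role": "user", "content": "Write a haiku about rain."}]
    prefill = "Sure, here is a haiku:"

    if prefill:
        # The model is asked to continue from this partial assistant message.
        messages.append({"role": "assistant", "content": prefill})

    response = prefill                            # prefill counts as output already produced
    shown_message = prefill.replace("\n", " \n")  # what the placeholder starts out showing
    # ... stream chunks, appending each one to response and shown_message ...

    # If nothing was generated beyond the prefill, the turn is treated as an error.
    is_error = (response == prefill)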