Nyanfa commited on
Commit
eab17f3
·
verified ·
1 Parent(s): e0a8eab

Added the Stop generation button

Browse files
Files changed (1) hide show
  1. app.py +53 -28
app.py CHANGED
@@ -1,6 +1,7 @@
1
  from premai import Prem
2
  import streamlit as st
3
  from streamlit.components.v1 import html
 
4
  import re
5
  import urllib.parse
6
 
@@ -28,6 +29,22 @@ if "messages" not in st.session_state:
28
  st.session_state.messages = []
29
 
30
  def get_ai_response(messages):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  with st.chat_message("assistant"):
32
  penalty_kwargs = {
33
  "frequency_penalty" if penalty_type == "Frequency Penalty" else "presence_penalty": penalty_value
@@ -44,16 +61,15 @@ def get_ai_response(messages):
44
  **penalty_kwargs,
45
  )
46
 
47
- response = ""
48
-
49
  placeholder = st.empty()
50
  for chunk in stream:
51
  if chunk.choices[0].delta["content"]:
52
  content = chunk.choices[0].delta["content"]
53
- response += content
54
- placeholder.markdown(response)
55
-
56
- return response
 
57
 
58
  def display_messages():
59
  for i, message in enumerate(st.session_state.messages):
@@ -97,7 +113,7 @@ def display_messages():
97
 
98
  if "edit_index" in st.session_state and st.session_state.edit_index == i:
99
  with st.form(key=f"edit_form_{i}_{len(st.session_state.messages)}"):
100
- new_content = st.text_area("Edit message", value=st.session_state.messages[i]["content"])
101
  col1, col2 = st.columns([1, 1])
102
  with col1:
103
  if st.form_submit_button("Save"):
@@ -109,24 +125,8 @@ def display_messages():
109
  del st.session_state.edit_index
110
  st.rerun()
111
 
112
- if "retry_flag" in st.session_state and st.session_state.retry_flag == True:
113
- if len(st.session_state.messages) > 0: # メッセージリストが空でないことを確認
114
- response = get_ai_response(st.session_state.messages)
115
- st.session_state.messages.append({"role": "assistant", "content": response})
116
-
117
- if response:
118
- st.session_state.is_error = False
119
- else:
120
- st.session_state.is_error = True
121
-
122
- st.session_state.retry_flag = False
123
- st.rerun()
124
-
125
- else:
126
- st.session_state.retry_flag = False # retry_flagをFalseに設定して処理を続行
127
-
128
  if "is_error" in st.session_state and st.session_state.is_error:
129
- st.warning("""
130
  Something went wrong. To resolve this error:
131
  1. Use the Retry button.
132
  2. Update your API key or Project ID correctly.
@@ -148,7 +148,7 @@ with st.sidebar:
148
  log_text += message["content"] + "\n\n"
149
  log_text = log_text.rstrip("\n")
150
 
151
- # PythonでURLエンコード
152
  log_text_escaped = urllib.parse.quote(log_text)
153
 
154
  copy_log_button_html = f"""
@@ -208,7 +208,7 @@ with st.sidebar:
208
  "remm-slerp-l2-13b",
209
  ]
210
  model = st.selectbox("Model", options=model_list, index=0)
211
- system_prompt = st.text_area("System prompt", height=100)
212
  temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=1.0, step=0.1)
213
  top_p = st.slider("Top-P", min_value=0.01, max_value=0.99, value=0.75, step=0.01)
214
  penalty_type = st.selectbox("Penalty Type", options=["Frequency Penalty", "Presence Penalty"])
@@ -256,11 +256,37 @@ with st.sidebar:
256
  else:
257
  st.warning("Please enter a valid Project ID.")
258
 
 
 
 
 
 
 
 
 
 
259
  display_messages()
260
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  if prompt := st.chat_input("What is up?"):
 
262
  messages = st.session_state.messages.copy()
263
- messages.append({"role": "user", "content": prompt})
264
 
265
  with st.chat_message("user"):
266
  st.write(prompt)
@@ -272,7 +298,6 @@ if prompt := st.chat_input("What is up?"):
272
  else:
273
  st.session_state.is_error = True
274
 
275
- st.session_state.messages.append({"role": "user", "content": prompt})
276
  st.session_state.messages.append({"role": "assistant", "content": response})
277
 
278
  st.rerun()
 
1
  from premai import Prem
2
  import streamlit as st
3
  from streamlit.components.v1 import html
4
+ from streamlit_extras.stylable_container import stylable_container
5
  import re
6
  import urllib.parse
7
 
 
29
  st.session_state.messages = []
30
 
31
  def get_ai_response(messages):
32
+ st.session_state.is_streaming = True
33
+ st.session_state.response = ""
34
+
35
+ with stylable_container(
36
+ key="stop_generating",
37
+ css_styles="""
38
+ button {
39
+ position: fixed;
40
+ bottom: 100px;
41
+ left: 50%;
42
+ transform: translateX(-50%);
43
+ }
44
+ """,
45
+ ):
46
+ st.button("Stop generating")
47
+
48
  with st.chat_message("assistant"):
49
  penalty_kwargs = {
50
  "frequency_penalty" if penalty_type == "Frequency Penalty" else "presence_penalty": penalty_value
 
61
  **penalty_kwargs,
62
  )
63
 
 
 
64
  placeholder = st.empty()
65
  for chunk in stream:
66
  if chunk.choices[0].delta["content"]:
67
  content = chunk.choices[0].delta["content"]
68
+ st.session_state.response += content
69
+ placeholder.markdown(st.session_state.response)
70
+
71
+ st.session_state.is_streaming = False
72
+ return st.session_state.response
73
 
74
  def display_messages():
75
  for i, message in enumerate(st.session_state.messages):
 
113
 
114
  if "edit_index" in st.session_state and st.session_state.edit_index == i:
115
  with st.form(key=f"edit_form_{i}_{len(st.session_state.messages)}"):
116
+ new_content = st.text_area("Edit message", height=200, value=st.session_state.messages[i]["content"])
117
  col1, col2 = st.columns([1, 1])
118
  with col1:
119
  if st.form_submit_button("Save"):
 
125
  del st.session_state.edit_index
126
  st.rerun()
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  if "is_error" in st.session_state and st.session_state.is_error:
129
+ st.error("""
130
  Something went wrong. To resolve this error:
131
  1. Use the Retry button.
132
  2. Update your API key or Project ID correctly.
 
148
  log_text += message["content"] + "\n\n"
149
  log_text = log_text.rstrip("\n")
150
 
151
+ # Encode the string to escape
152
  log_text_escaped = urllib.parse.quote(log_text)
153
 
154
  copy_log_button_html = f"""
 
208
  "remm-slerp-l2-13b",
209
  ]
210
  model = st.selectbox("Model", options=model_list, index=0)
211
+ system_prompt = st.text_area("System prompt", height=200)
212
  temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=1.0, step=0.1)
213
  top_p = st.slider("Top-P", min_value=0.01, max_value=0.99, value=0.75, step=0.01)
214
  penalty_type = st.selectbox("Penalty Type", options=["Frequency Penalty", "Presence Penalty"])
 
256
  else:
257
  st.warning("Please enter a valid Project ID.")
258
 
259
+ # After Stop generating
260
+ if "is_streaming" in st.session_state and st.session_state.is_streaming:
261
+ st.session_state.messages.append({"role": "assistant", "content": st.session_state.response})
262
+ st.session_state.is_error = False
263
+ st.session_state.is_streaming = False
264
+ if "retry_flag" in st.session_state and st.session_state.retry_flag:
265
+ st.session_state.retry_flag = False
266
+ st.rerun()
267
+
268
  display_messages()
269
 
270
+ # After Retry
271
+ if "retry_flag" in st.session_state and st.session_state.retry_flag == True:
272
+ if len(st.session_state.messages) > 0:
273
+ response = get_ai_response(st.session_state.messages)
274
+ st.session_state.messages.append({"role": "assistant", "content": response})
275
+
276
+ if response:
277
+ st.session_state.is_error = False
278
+ else:
279
+ st.session_state.is_error = True
280
+
281
+ st.session_state.retry_flag = False
282
+ st.rerun()
283
+
284
+ else:
285
+ st.session_state.retry_flag = False
286
+
287
  if prompt := st.chat_input("What is up?"):
288
+ st.session_state.messages.append({"role": "user", "content": prompt})
289
  messages = st.session_state.messages.copy()
 
290
 
291
  with st.chat_message("user"):
292
  st.write(prompt)
 
298
  else:
299
  st.session_state.is_error = True
300
 
 
301
  st.session_state.messages.append({"role": "assistant", "content": response})
302
 
303
  st.rerun()