Nyanfa commited on
Commit
8b30aa9
·
verified ·
1 Parent(s): 15fb4e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -38
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import cohere
2
  import streamlit as st
3
  from streamlit.components.v1 import html
4
  import re
@@ -6,43 +6,45 @@ import urllib.parse
6
 
7
  st.title("Prem Chat UI")
8
 
9
- if "api_key" not in st.session_state:
10
  api_key = st.text_input("Enter your API Key", type="password")
11
- if api_key:
 
12
  st.session_state.api_key = api_key
13
- client = cohere.Client(api_key=api_key)
 
14
  st.rerun()
15
  else:
16
- st.warning("Please enter your API key to use the app. You can obtain your API key from here: https://dashboard.cohere.com/api-keys")
17
  st.stop()
18
  else:
19
- client = cohere.Client(api_key=st.session_state.api_key)
20
 
21
  if "messages" not in st.session_state:
22
  st.session_state.messages = []
23
 
24
- def get_ai_response(prompt, chat_history):
25
- with st.chat_message("ai"):
26
  penalty_kwargs = {
27
  "frequency_penalty" if penalty_type == "Frequency Penalty" else "presence_penalty": penalty_value
28
  }
29
 
30
- stream = client.chat_stream(
31
- message=prompt,
 
 
32
  model=model,
33
- preamble=preamble,
34
- chat_history=chat_history,
35
  temperature=temperature,
36
- k=k,
37
- p=p,
38
- **penalty_kwargs
39
  )
40
 
41
  response = ""
42
  placeholder = st.empty()
43
- for event in stream:
44
- if event.event_type == "text-generation":
45
- content = event.text
46
  response += content
47
  placeholder.markdown(response)
48
 
@@ -50,9 +52,8 @@ def get_ai_response(prompt, chat_history):
50
 
51
  def display_messages():
52
  for i, message in enumerate(st.session_state.messages):
53
- name = "user" if message["role"] == "USER" else "ai"
54
- with st.chat_message(name):
55
- st.markdown(message["text"])
56
  col1, col2, col3, col4 = st.columns([1, 1, 1, 1])
57
  with col1:
58
  if st.button("Edit", key=f"edit_{i}_{len(st.session_state.messages)}"):
@@ -63,7 +64,7 @@ def display_messages():
63
  del st.session_state.messages[i]
64
  st.rerun()
65
  with col3:
66
- text_to_copy = message["text"]
67
  # PythonでURLエンコード
68
  text_to_copy_escaped = urllib.parse.quote(text_to_copy)
69
 
@@ -81,7 +82,7 @@ def display_messages():
81
  """
82
  html(copy_button_html, height=50)
83
 
84
- if i == len(st.session_state.messages) - 1 and message["role"] == "CHATBOT":
85
  with col4:
86
  if st.button("Retry", key=f"retry_{i}_{len(st.session_state.messages)}"):
87
  if len(st.session_state.messages) >= 2:
@@ -95,7 +96,7 @@ def display_messages():
95
  col1, col2 = st.columns([1, 1])
96
  with col1:
97
  if st.form_submit_button("Save"):
98
- st.session_state.messages[i]["text"] = new_content
99
  del st.session_state.edit_index
100
  st.rerun()
101
  with col2:
@@ -105,9 +106,9 @@ def display_messages():
105
 
106
  if "retry_flag" in st.session_state and st.session_state.retry_flag == True:
107
  if len(st.session_state.messages) > 0: # メッセージリストが空でないことを確認
108
- prompt = st.session_state.messages[-1]["text"]
109
  response = get_ai_response(prompt, st.session_state.messages[:-1])
110
- st.session_state.messages.append({"role": "CHATBOT", "text": response})
111
  st.session_state.retry_flag = False
112
  st.rerun()
113
  else:
@@ -118,12 +119,12 @@ with st.sidebar:
118
  # Copy Conversation History button
119
  log_text = ""
120
  for message in st.session_state.messages:
121
- if message["role"] == "USER":
122
  log_text += "<USER>\n"
123
- log_text += message["text"] + "\n\n"
124
  else:
125
  log_text += "<ASSISTANT>\n"
126
- log_text += message["text"] + "\n\n"
127
  log_text = log_text.rstrip("\n")
128
 
129
  # PythonでURLエンコード
@@ -144,11 +145,10 @@ with st.sidebar:
144
  html(copy_log_button_html, height=50)
145
 
146
  st.header("Advanced Settings")
147
- model = st.selectbox("Model", options=["command-r-plus", "command-r"], index=0)
148
- preamble = st.text_area("Preamble", height=100)
149
  temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.3, step=0.1)
150
- k = st.slider("Top-K", min_value=0, max_value=500, value=0, step=1)
151
- p = st.slider("Top-P", min_value=0.01, max_value=0.99, value=0.75, step=0.01)
152
  penalty_type = st.selectbox("Penalty Type", options=["Frequency Penalty", "Presence Penalty"])
153
  penalty_value = st.slider("Penalty Value", min_value=0.0, max_value=1.0, value=0.0, step=0.1)
154
 
@@ -162,9 +162,9 @@ with st.sidebar:
162
  for message in messages:
163
  if message.strip() in ["<USER>", "<ASSISTANT>"]:
164
  if role and text:
165
- st.session_state.messages.append({"role": role, "text": text.strip()})
166
  text = ""
167
- role = "USER" if message.strip() == "<USER>" else "CHATBOT"
168
  else:
169
  text += message
170
  if role and text:
@@ -181,7 +181,7 @@ with st.sidebar:
181
  if st.button("Update API Key"):
182
  if new_api_key:
183
  st.session_state.api_key = new_api_key
184
- client = cohere.Client(api_key=new_api_key)
185
  st.success("API Key updated successfully!")
186
  else:
187
  st.warning("Please enter a valid API Key.")
@@ -196,6 +196,6 @@ if prompt := st.chat_input("What is up?"):
196
 
197
  response = get_ai_response(prompt, chat_history)
198
 
199
- st.session_state.messages.append({"role": "USER", "text": prompt})
200
- st.session_state.messages.append({"role": "CHATBOT", "text": response})
201
  st.rerun()
 
1
+ from premai import Prem
2
  import streamlit as st
3
  from streamlit.components.v1 import html
4
  import re
 
6
 
7
  st.title("Prem Chat UI")
8
 
9
+ if "api_key" not in st.session_state and "project_id" not in st.session_state:
10
  api_key = st.text_input("Enter your API Key", type="password")
11
+ project_id = st.text_input("Enter your project ID")
12
+ if api_key and project_id:
13
  st.session_state.api_key = api_key
14
+ st.session_state.project_id = project_id
15
+ client = Prem(api_key=api_key)
16
  st.rerun()
17
  else:
18
+ st.warning("Please enter your API key and Project ID to use the app.")
19
  st.stop()
20
  else:
21
+ client = Prem(api_key=st.session_state.api_key)
22
 
23
  if "messages" not in st.session_state:
24
  st.session_state.messages = []
25
 
26
+ def get_ai_response(prompt, messages):
27
+ with st.chat_message("assistant"):
28
  penalty_kwargs = {
29
  "frequency_penalty" if penalty_type == "Frequency Penalty" else "presence_penalty": penalty_value
30
  }
31
 
32
+ response = client.chat.completions.create(
33
+ project_id=st.session_state.project_id,
34
+ messages=messages,
35
+ stream=True,
36
  model=model,
37
+ system_prompt=system_prompt,
 
38
  temperature=temperature,
39
+ top_p=top_p,
40
+ **penalty_kwargs,
 
41
  )
42
 
43
  response = ""
44
  placeholder = st.empty()
45
+ for chunk in response:
46
+ if chunk.choices[0].delta["content"]:
47
+ content = chunk.choices[0].delta["content"]
48
  response += content
49
  placeholder.markdown(response)
50
 
 
52
 
53
  def display_messages():
54
  for i, message in enumerate(st.session_state.messages):
55
+ with st.chat_message(message["role"]):
56
+ st.markdown(message["content"])
 
57
  col1, col2, col3, col4 = st.columns([1, 1, 1, 1])
58
  with col1:
59
  if st.button("Edit", key=f"edit_{i}_{len(st.session_state.messages)}"):
 
64
  del st.session_state.messages[i]
65
  st.rerun()
66
  with col3:
67
+ text_to_copy = message["content"]
68
  # PythonでURLエンコード
69
  text_to_copy_escaped = urllib.parse.quote(text_to_copy)
70
 
 
82
  """
83
  html(copy_button_html, height=50)
84
 
85
+ if i == len(st.session_state.messages) - 1 and message["role"] == "assistant":
86
  with col4:
87
  if st.button("Retry", key=f"retry_{i}_{len(st.session_state.messages)}"):
88
  if len(st.session_state.messages) >= 2:
 
96
  col1, col2 = st.columns([1, 1])
97
  with col1:
98
  if st.form_submit_button("Save"):
99
+ st.session_state.messages[i]["content"] = new_content
100
  del st.session_state.edit_index
101
  st.rerun()
102
  with col2:
 
106
 
107
  if "retry_flag" in st.session_state and st.session_state.retry_flag == True:
108
  if len(st.session_state.messages) > 0: # メッセージリストが空でないことを確認
109
+ prompt = st.session_state.messages[-1]["content"]
110
  response = get_ai_response(prompt, st.session_state.messages[:-1])
111
+ st.session_state.messages.append({"role": "assistant", "content": response})
112
  st.session_state.retry_flag = False
113
  st.rerun()
114
  else:
 
119
  # Copy Conversation History button
120
  log_text = ""
121
  for message in st.session_state.messages:
122
+ if message["role"] == "user":
123
  log_text += "<USER>\n"
124
+ log_text += message["content"] + "\n\n"
125
  else:
126
  log_text += "<ASSISTANT>\n"
127
+ log_text += message["content"] + "\n\n"
128
  log_text = log_text.rstrip("\n")
129
 
130
  # PythonでURLエンコード
 
145
  html(copy_log_button_html, height=50)
146
 
147
  st.header("Advanced Settings")
148
+ model = st.selectbox("Model", options=["claude-3-haiku", "command-r-plus"], index=0)
149
+ system_prompt = st.text_area("System prompt", height=100)
150
  temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.3, step=0.1)
151
+ top_p = st.slider("Top-P", min_value=0.01, max_value=0.99, value=0.75, step=0.01)
 
152
  penalty_type = st.selectbox("Penalty Type", options=["Frequency Penalty", "Presence Penalty"])
153
  penalty_value = st.slider("Penalty Value", min_value=0.0, max_value=1.0, value=0.0, step=0.1)
154
 
 
162
  for message in messages:
163
  if message.strip() in ["<USER>", "<ASSISTANT>"]:
164
  if role and text:
165
+ st.session_state.messages.append({"role": role, "content": text.strip()})
166
  text = ""
167
+ role = "user" if message.strip() == "<USER>" else "assistant"
168
  else:
169
  text += message
170
  if role and text:
 
181
  if st.button("Update API Key"):
182
  if new_api_key:
183
  st.session_state.api_key = new_api_key
184
+ client = Prem(api_key=new_api_key)
185
  st.success("API Key updated successfully!")
186
  else:
187
  st.warning("Please enter a valid API Key.")
 
196
 
197
  response = get_ai_response(prompt, chat_history)
198
 
199
+ st.session_state.messages.append({"role": "user", "content": prompt})
200
+ st.session_state.messages.append({"role": "assistant", "content": response})
201
  st.rerun()