Nyanfa commited on
Commit
f908fa8
·
verified ·
1 Parent(s): 3664474

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +201 -0
  2. requirements.txt +1 -0
app.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cohere
2
+ import streamlit as st
3
+ from streamlit.components.v1 import html
4
+ import re
5
+ import urllib.parse
6
+
7
+ st.title("Prem Chat UI")
8
+
9
+ if "api_key" not in st.session_state:
10
+ api_key = st.text_input("Enter your API Key", type="password")
11
+ if api_key:
12
+ st.session_state.api_key = api_key
13
+ client = cohere.Client(api_key=api_key)
14
+ st.rerun()
15
+ else:
16
+ st.warning("Please enter your API key to use the app. You can obtain your API key from here: https://dashboard.cohere.com/api-keys")
17
+ st.stop()
18
+ else:
19
+ client = cohere.Client(api_key=st.session_state.api_key)
20
+
21
+ if "messages" not in st.session_state:
22
+ st.session_state.messages = []
23
+
24
+ def get_ai_response(prompt, chat_history):
25
+ with st.chat_message("ai"):
26
+ penalty_kwargs = {
27
+ "frequency_penalty" if penalty_type == "Frequency Penalty" else "presence_penalty": penalty_value
28
+ }
29
+
30
+ stream = client.chat_stream(
31
+ message=prompt,
32
+ model=model,
33
+ preamble=preamble,
34
+ chat_history=chat_history,
35
+ temperature=temperature,
36
+ k=k,
37
+ p=p,
38
+ **penalty_kwargs
39
+ )
40
+
41
+ response = ""
42
+ placeholder = st.empty()
43
+ for event in stream:
44
+ if event.event_type == "text-generation":
45
+ content = event.text
46
+ response += content
47
+ placeholder.markdown(response)
48
+
49
+ return response
50
+
51
+ def display_messages():
52
+ for i, message in enumerate(st.session_state.messages):
53
+ name = "user" if message["role"] == "USER" else "ai"
54
+ with st.chat_message(name):
55
+ st.markdown(message["text"])
56
+ col1, col2, col3, col4 = st.columns([1, 1, 1, 1])
57
+ with col1:
58
+ if st.button("Edit", key=f"edit_{i}_{len(st.session_state.messages)}"):
59
+ st.session_state.edit_index = i
60
+ st.rerun()
61
+ with col2:
62
+ if st.button("Delete", key=f"delete_{i}_{len(st.session_state.messages)}"):
63
+ del st.session_state.messages[i]
64
+ st.rerun()
65
+ with col3:
66
+ text_to_copy = message["text"]
67
+ # PythonでURLエンコード
68
+ text_to_copy_escaped = urllib.parse.quote(text_to_copy)
69
+
70
+ copy_button_html = f"""
71
+ <button id="copy-msg-btn-{i}" style='font-size: 1em; padding: 0.5em;' onclick='copyMessage("{i}")'>Copy</button>
72
+
73
+ <script>
74
+ function copyMessage(index) {{
75
+ navigator.clipboard.writeText(decodeURIComponent("{text_to_copy_escaped}"));
76
+ let copyBtn = document.getElementById("copy-msg-btn-" + index);
77
+ copyBtn.innerHTML = "Copied!";
78
+ setTimeout(function(){{ copyBtn.innerHTML = "Copy"; }}, 2000);
79
+ }}
80
+ </script>
81
+ """
82
+ html(copy_button_html, height=50)
83
+
84
+ if i == len(st.session_state.messages) - 1 and message["role"] == "CHATBOT":
85
+ with col4:
86
+ if st.button("Retry", key=f"retry_{i}_{len(st.session_state.messages)}"):
87
+ if len(st.session_state.messages) >= 2:
88
+ del st.session_state.messages[-1]
89
+ st.session_state.retry_flag = True
90
+ st.rerun()
91
+
92
+ if "edit_index" in st.session_state and st.session_state.edit_index == i:
93
+ with st.form(key=f"edit_form_{i}_{len(st.session_state.messages)}"):
94
+ new_content = st.text_area("Edit message", value=st.session_state.messages[i]["text"])
95
+ col1, col2 = st.columns([1, 1])
96
+ with col1:
97
+ if st.form_submit_button("Save"):
98
+ st.session_state.messages[i]["text"] = new_content
99
+ del st.session_state.edit_index
100
+ st.rerun()
101
+ with col2:
102
+ if st.form_submit_button("Cancel"):
103
+ del st.session_state.edit_index
104
+ st.rerun()
105
+
106
+ if "retry_flag" in st.session_state and st.session_state.retry_flag == True:
107
+ if len(st.session_state.messages) > 0: # メッセージリストが空でないことを確認
108
+ prompt = st.session_state.messages[-1]["text"]
109
+ response = get_ai_response(prompt, st.session_state.messages[:-1])
110
+ st.session_state.messages.append({"role": "CHATBOT", "text": response})
111
+ st.session_state.retry_flag = False
112
+ st.rerun()
113
+ else:
114
+ st.session_state.retry_flag = False # retry_flagをFalseに設定して処理を続行
115
+
116
+ # Add sidebar for advanced settings
117
+ with st.sidebar:
118
+ # Copy Conversation History button
119
+ log_text = ""
120
+ for message in st.session_state.messages:
121
+ if message["role"] == "USER":
122
+ log_text += "<USER>\n"
123
+ log_text += message["text"] + "\n\n"
124
+ else:
125
+ log_text += "<ASSISTANT>\n"
126
+ log_text += message["text"] + "\n\n"
127
+ log_text = log_text.rstrip("\n")
128
+
129
+ # PythonでURLエンコード
130
+ log_text_escaped = urllib.parse.quote(log_text)
131
+
132
+ copy_log_button_html = f"""
133
+ <button id="copy-log-btn" style='font-size: 1em; padding: 0.5em;' onclick='copyLog()'>Copy Conversation History</button>
134
+
135
+ <script>
136
+ function copyLog() {{
137
+ navigator.clipboard.writeText(decodeURIComponent("{log_text_escaped}"));
138
+ let copyBtn = document.getElementById("copy-log-btn");
139
+ copyBtn.innerHTML = "Copied!";
140
+ setTimeout(function(){{ copyBtn.innerHTML = "Copy Conversation History"; }}, 2000);
141
+ }}
142
+ </script>
143
+ """
144
+ html(copy_log_button_html, height=50)
145
+
146
+ st.header("Advanced Settings")
147
+ model = st.selectbox("Model", options=["command-r-plus", "command-r"], index=0)
148
+ preamble = st.text_area("Preamble", height=100)
149
+ temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.3, step=0.1)
150
+ k = st.slider("Top-K", min_value=0, max_value=500, value=0, step=1)
151
+ p = st.slider("Top-P", min_value=0.01, max_value=0.99, value=0.75, step=0.01)
152
+ penalty_type = st.selectbox("Penalty Type", options=["Frequency Penalty", "Presence Penalty"])
153
+ penalty_value = st.slider("Penalty Value", min_value=0.0, max_value=1.0, value=0.0, step=0.1)
154
+
155
+ st.header("Restore History")
156
+ history_input = st.text_area("Paste conversation history:", height=200)
157
+ if st.button("Restore History"):
158
+ st.session_state.messages = []
159
+ messages = re.split(r"^(<USER>|<ASSISTANT>)\n", history_input, flags=re.MULTILINE)
160
+ role = None
161
+ text = ""
162
+ for message in messages:
163
+ if message.strip() in ["<USER>", "<ASSISTANT>"]:
164
+ if role and text:
165
+ st.session_state.messages.append({"role": role, "text": text.strip()})
166
+ text = ""
167
+ role = "USER" if message.strip() == "<USER>" else "CHATBOT"
168
+ else:
169
+ text += message
170
+ if role and text:
171
+ st.session_state.messages.append({"role": role, "text": text.strip()})
172
+ st.rerun()
173
+
174
+ st.header("Clear History")
175
+ if st.button("Clear Chat History"):
176
+ st.session_state.messages = []
177
+ st.rerun()
178
+
179
+ st.header("Change API Key")
180
+ new_api_key = st.text_input("Enter new API Key", type="password")
181
+ if st.button("Update API Key"):
182
+ if new_api_key:
183
+ st.session_state.api_key = new_api_key
184
+ client = cohere.Client(api_key=new_api_key)
185
+ st.success("API Key updated successfully!")
186
+ else:
187
+ st.warning("Please enter a valid API Key.")
188
+
189
+ display_messages()
190
+
191
+ if prompt := st.chat_input("What is up?"):
192
+ chat_history = st.session_state.messages.copy()
193
+
194
+ with st.chat_message("user"):
195
+ st.write(prompt)
196
+
197
+ response = get_ai_response(prompt, chat_history)
198
+
199
+ st.session_state.messages.append({"role": "USER", "text": prompt})
200
+ st.session_state.messages.append({"role": "CHATBOT", "text": response})
201
+ st.rerun()
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ cohere