Nyanfa commited on
Commit
4bc123e
·
verified ·
1 Parent(s): 6158f66

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -0
app.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cohere
2
+ import streamlit as st
3
+ import re
4
+
5
+ st.title("Cohere Chat UI")
6
+
7
+ api_key = st.sidebar.text_input("API Key", type="password")
8
+ if api_key:
9
+ client = cohere.Client(api_key=api_key)
10
+ else:
11
+ st.warning("Please enter your API key to use the app.")
12
+ st.stop()
13
+
14
+ if "messages" not in st.session_state:
15
+ st.session_state.messages = []
16
+
17
+ def display_messages():
18
+ for i, message in enumerate(st.session_state.messages):
19
+ name = "user" if message["role"] == "USER" else "ai"
20
+ with st.chat_message(name):
21
+ st.markdown(message["text"])
22
+ col1, col2, col3 = st.columns([1, 1, 1])
23
+ with col1:
24
+ if st.button("Edit", key=f"edit_{i}_{len(st.session_state.messages)}"):
25
+ st.session_state.edit_index = i
26
+ st.rerun()
27
+ with col2:
28
+ if st.button("Delete", key=f"delete_{i}_{len(st.session_state.messages)}"):
29
+ del st.session_state.messages[i]
30
+ st.rerun()
31
+ with col3:
32
+ if f"copy_state_{i}" not in st.session_state:
33
+ st.session_state[f"copy_state_{i}"] = False
34
+ if st.button("Copy", key=f"copy_{i}_{len(st.session_state.messages)}"):
35
+ st.session_state[f"copy_state_{i}"] = not st.session_state[f"copy_state_{i}"]
36
+ if st.session_state[f"copy_state_{i}"]:
37
+ copy_text = message["text"]
38
+ st.code(copy_text, language='plain')
39
+
40
+ if "edit_index" in st.session_state and st.session_state.edit_index == i:
41
+ with st.form(key=f"edit_form_{i}_{len(st.session_state.messages)}"):
42
+ new_content = st.text_area("Edit message", value=st.session_state.messages[i]["text"])
43
+ if st.form_submit_button("Save"):
44
+ st.session_state.messages[i]["text"] = new_content
45
+ del st.session_state.edit_index
46
+ st.rerun()
47
+ display_messages()
48
+
49
+ # Add sidebar for advanced settings
50
+ with st.sidebar:
51
+ st.header("Advanced Settings")
52
+ model = st.selectbox("Model", options=["command-r-plus", "command-r"], index=0)
53
+ preamble = st.text_area("Preamble", height=100)
54
+ temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.3, step=0.1)
55
+ penalty_type = st.selectbox("Penalty Type", options=["Frequency Penalty", "Presence Penalty"])
56
+ penalty_value = st.slider("Penalty Value", min_value=0.0, max_value=1.0, value=0.0, step=0.1)
57
+
58
+ if "show_log" not in st.session_state:
59
+ st.session_state.show_log = False
60
+
61
+ if st.button("Copy Log"):
62
+ st.session_state.show_log = not st.session_state.show_log
63
+
64
+ if st.session_state.show_log:
65
+ log_text = ""
66
+ for message in st.session_state.messages:
67
+ if message["role"] == "USER":
68
+ log_text += "Human\n"
69
+ log_text += message["text"] + "\n\n"
70
+ else:
71
+ log_text += "Assistant\n"
72
+ log_text += message["text"] + "\n\n"
73
+ log_text = log_text.rstrip("\n")
74
+ st.code(log_text, language='plain')
75
+
76
+ st.header("Restore History")
77
+ history_input = st.text_area("Paste conversation history:", height=200)
78
+ if st.button("Restore History"):
79
+ st.session_state.messages = []
80
+ messages = re.split(r"^(Human|Assistant)\n", history_input, flags=re.MULTILINE)
81
+ role = None
82
+ text = ""
83
+ for message in messages:
84
+ if message.strip() in ["Human", "Assistant"]:
85
+ if role and text:
86
+ st.session_state.messages.append({"role": role, "text": text.strip()})
87
+ text = ""
88
+ role = "USER" if message.strip() == "Human" else "CHATBOT"
89
+ else:
90
+ text += message
91
+ if role and text:
92
+ st.session_state.messages.append({"role": role, "text": text.strip()})
93
+ st.rerun()
94
+
95
+ if prompt := st.chat_input("What is up?"):
96
+ chat_history = st.session_state.messages.copy()
97
+
98
+ with st.chat_message("user"):
99
+ st.write(prompt)
100
+
101
+ with st.chat_message("ai"):
102
+ penalty_kwargs = {
103
+ "frequency_penalty" if penalty_type == "Frequency Penalty" else "presence_penalty": penalty_value
104
+ }
105
+ stream = client.chat_stream(
106
+ message=prompt,
107
+ model=model,
108
+ preamble=preamble,
109
+ chat_history=chat_history,
110
+ temperature=temperature,
111
+ **penalty_kwargs
112
+ )
113
+
114
+ response = ""
115
+ placeholder = st.empty()
116
+ for event in stream:
117
+ if event.event_type == "text-generation":
118
+ content = event.text
119
+ response += content
120
+ placeholder.markdown(response)
121
+
122
+ st.session_state.messages.append({"role": "USER", "text": prompt})
123
+ st.session_state.messages.append({"role": "CHATBOT", "text": response})
124
+ st.rerun()