File size: 20,154 Bytes
4bc123e
 
0d47da9
4f4cd9b
4bc123e
0d47da9
bbd6eac
 
 
4bc123e
 
 
561c31a
 
 
5b4675a
95f42b1
bbd6eac
95f42b1
 
 
 
561c31a
 
 
4bc123e
bbd6eac
4bc123e
 
 
 
bbd6eac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f4cd9b
df388cb
4f4cd9b
bbd6eac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
561c31a
 
 
bbd6eac
 
 
 
 
 
 
 
 
561c31a
bbd6eac
 
 
 
a4a272c
bbd6eac
 
 
 
4f4cd9b
561c31a
df388cb
 
 
 
 
 
 
 
 
 
 
 
 
 
63a5099
 
bbd6eac
 
 
 
4f4cd9b
703c3e3
 
 
63a5099
4f4cd9b
 
 
561c31a
703c3e3
 
 
 
63a5099
0dda2e1
 
 
 
63a5099
0dda2e1
63a5099
4bc123e
 
bbd6eac
 
 
703c3e3
 
63a5099
 
703c3e3
0dda2e1
 
63a5099
 
561c31a
4bc123e
 
 
 
 
ea5dae0
4bc123e
 
 
bbd6eac
63a5099
0d47da9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4bc123e
bbd6eac
561c31a
 
970ed2d
 
 
 
561c31a
4bc123e
 
bbd6eac
561c31a
 
 
bbd6eac
561c31a
 
 
 
 
 
4bc123e
 
 
5e547c5
703c3e3
5e547c5
 
 
 
 
 
bbd6eac
5e547c5
bbd6eac
5e547c5
 
bbd6eac
5e547c5
 
 
 
 
 
 
0d47da9
5e547c5
bbd6eac
5e547c5
bbd6eac
 
5e547c5
 
 
bbd6eac
 
 
 
 
 
 
 
5e547c5
 
 
72e2335
 
 
 
 
 
 
 
 
 
0d47da9
5e547c5
 
 
540e53f
 
 
 
 
 
5e547c5
 
 
 
 
 
bbd6eac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e547c5
 
 
 
 
 
 
 
 
 
bbd6eac
5e547c5
bbd6eac
5e547c5
 
 
bbd6eac
5e547c5
 
 
 
 
 
 
 
 
 
 
 
bbd6eac
5e547c5
4bc123e
5e547c5
 
 
 
bbd6eac
5e547c5
 
 
 
 
 
 
de74109
 
 
5e547c5
de74109
0d47da9
5e547c5
de74109
 
 
 
5e547c5
561c31a
4f4cd9b
0dda2e1
bbd6eac
4f4cd9b
 
 
 
 
5e547c5
 
 
 
 
de74109
540e53f
de74109
 
 
103ce2f
de74109
 
540e53f
de74109
 
 
540e53f
de74109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
561c31a
 
4f4cd9b
0dda2e1
4f4cd9b
bbd6eac
 
 
4f4cd9b
 
 
 
 
703c3e3
bbd6eac
4bc123e
4f4cd9b
703c3e3
 
 
 
de74109
703c3e3
 
bbd6eac
4f4cd9b
bbd6eac
4bc123e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
import cohere
import streamlit as st
from streamlit.components.v1 import html
from streamlit_extras.stylable_container import stylable_container
import re
import urllib.parse
from langchain_text_splitters import RecursiveCharacterTextSplitter
import numpy as np
import pypdfium2 as pdfium

st.title("Cohere Chat UI")

# Gate the app behind an API key stored in session state; every later section
# relies on the module-level `client` created in the else-branch below.
if "api_key" not in st.session_state:
    # Pasted keys often carry stray spaces; strip them so they are not stored
    # verbatim and rejected by the API.
    api_key = st.text_input("Enter your API Key", type="password").strip()
    if not api_key:
        st.warning("Please enter your API key to use the app. You can obtain your API key from here: https://dashboard.cohere.com/api-keys")
        st.stop()
    if not api_key.isascii():
        # Only ASCII input is accepted as a key.
        st.warning("Please enter your API key correctly.")
        st.stop()
    st.session_state.api_key = api_key
    # st.rerun() does not return; the next run takes the else-branch and
    # builds the client from the stored key.
    st.rerun()
else:
    client = cohere.ClientV2(api_key=st.session_state.api_key)

# Seed per-session state on the first run; values already present survive reruns.
#   messages      - chat history as {"role", "content"} dicts
#   rag_file_key  - marker of the upload that was last processed for RAG
#   rag_embedded  - whether that upload's chunks have been embedded
for _state_key, _initial_value in (
    ("messages", []),
    ("rag_file_key", None),
    ("rag_embedded", False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value

# Splitter for RAG documents: chunks of up to 512 characters (length measured
# with len), with up to 50 characters shared between consecutive chunks;
# separators are matched literally, not as regexes.
_splitter_settings = {
    "chunk_size": 512,
    "chunk_overlap": 50,
    "length_function": len,
    "is_separator_regex": False,
}
text_splitter = RecursiveCharacterTextSplitter(**_splitter_settings)

def batch_embed(texts, batch_size=96):
    """Embed ``texts`` as search documents, ``batch_size`` texts per API call.

    Uses the module-level ``client`` and ``embed_model``. Returns the
    concatenation of each batch response's ``embeddings.float`` list, in
    batch order.
    """
    batch_responses = (
        client.embed(
            texts=texts[offset:offset + batch_size],
            model=embed_model,
            input_type="search_document",
            embedding_types=['float'],
        )
        for offset in range(0, len(texts), batch_size)
    )
    return [vector for reply in batch_responses for vector in reply.embeddings.float]

def cosine_similarity(a, b):
    """Return the cosine similarity of the vectors ``a`` and ``b``.

    If either vector has zero length the similarity is undefined. In that case
    this returns 0.0 instead of ``nan``, because a ``nan`` would otherwise be
    sorted to the top by the descending ``np.argsort`` used for retrieval.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    denominator = np.linalg.norm(a) * np.linalg.norm(b)
    if denominator == 0:
        return 0.0
    return np.dot(a, b) / denominator

def get_ai_response(chat_history, retrieve_k=10, rerank_n=3):
    """Stream an assistant reply for ``chat_history`` into the chat UI.

    When RAG chunks and embeddings are in session state, the latest message is
    embedded as a search query. The ``retrieve_k`` most cosine-similar chunks
    are retrieved, reranked, and the best ``rerank_n`` are passed to the model
    as documents.

    Model, preamble and sampling settings are read from module-level globals
    set in the sidebar. While streaming, the partial reply accumulates in
    ``st.session_state.response`` and ``st.session_state.is_streaming`` stays
    True, so an interrupted run can still be recovered by the top-level
    is_streaming check.

    Args:
        chat_history: list of ``{"role", "content"}`` dicts; not modified.
        retrieve_k: number of chunks kept after embedding similarity search.
        rerank_n: number of chunks kept after reranking.

    Returns:
        The complete reply text.
    """
    st.session_state.is_streaming = True
    st.session_state.response = ""

    with st.chat_message("assistant", avatar=st.session_state.assistant_avatar):
        # Stays None unless retrieval below actually produces documents, so a
        # RAG file that yielded no chunks falls back to a plain chat request.
        documents = None

        # RAG
        chunks = st.session_state.get("rag_chunks")
        embeddings = st.session_state.get("rag_embeddings")
        if chunks and embeddings:
            query = chat_history[-1]["content"]
            query_embedding = client.embed(
                texts=[query],
                model=embed_model,
                input_type="search_query",
                embedding_types=['float'],
            ).embeddings.float[0]

            similarities = [cosine_similarity(query_embedding, vector) for vector in embeddings]
            top_indices = np.argsort(similarities)[::-1][:retrieve_k]
            candidates = [chunks[i] for i in top_indices]

            rerank_response = client.rerank(
                query=query,
                documents=candidates,
                top_n=min(rerank_n, len(candidates)),
                model=rerank_model,
            )
            reranked = [candidates[result.index] for result in rerank_response.results]
            documents = [
                {"data": {"title": f"chunk {i}", "snippet": chunk}}
                for i, chunk in enumerate(reranked)
            ]

        penalty_name = "frequency_penalty" if penalty_type == "Frequency Penalty" else "presence_penalty"

        # Prepend the preamble on a new list so the caller's history is untouched.
        messages = [{"role": "system", "content": preamble}, *chat_history]

        stream_kwargs = {
            "messages": messages,
            "model": model,
            "temperature": temperature,
            "k": k,
            "p": p,
            penalty_name: penalty_value,
        }

        if documents:
            stream_kwargs["documents"] = documents
            stream_kwargs["citation_options"] = {"mode": "OFF"}
        elif model in ["command-r-08-2024", "command-r-plus-08-2024"]:
            stream_kwargs["safety_mode"] = "OFF"

        stream = client.chat_stream(**stream_kwargs)

        placeholder = st.empty()

        with stylable_container(
            key="stop_generating",
            css_styles="""
                button {
                    position: fixed;
                    bottom: 100px;
                    left: 50%;
                    transform: translateX(-50%);
                    z-index: 1;
                }
                """,
        ):
            # Clicking this makes Streamlit rerun the script, which interrupts
            # the loop below while is_streaming is still True.
            st.button("Stop generating")

        shown_message = ""

        for chunk in stream:
            if chunk.type == "content-delta":
                content = chunk.delta.message.content.text
                st.session_state.response += content
                # Markdown hard line breaks; escape angle brackets for Markdown.
                shown_message += content.replace("\n", "  \n")\
                                        .replace("<", "\\<")\
                                        .replace(">", "\\>")
                placeholder.markdown(shown_message)

    st.session_state.is_streaming = False
    return st.session_state.response

def normalize_code_block(match):
    """Undo display escaping inside a fenced code block matched by ``match``.

    Turns the Markdown hard breaks ("  \\n") back into plain newlines and
    removes the backslash escapes added before ``<`` and ``>``. The
    substitutions are applied in that order.
    """
    block_text = match.group(0)
    for escaped, plain in (("  \n", "\n"), ("\\<", "<"), ("\\>", ">")):
        block_text = block_text.replace(escaped, plain)
    return block_text

def normalize_inline(match):
    """Remove the backslash escapes before ``<`` and ``>`` in an inline code span."""
    return re.sub(r"\\([<>])", r"\1", match.group(0))

# Fenced code block (match with re.DOTALL so it spans lines), non-greedy.
code_block_pattern = r"(```.*?```)"
# Inline code span: backtick-delimited text without backticks or newlines.
inline_pattern = r"`([^`\n]+?)`"

def display_messages():
    """Render every stored chat message together with its action buttons.

    Each message is rendered as Markdown:
    - newlines become hard breaks and ``<``/``>`` are backslash-escaped;
    - the escaping is then undone inside code blocks and inline code;
    - a row of Edit / Delete / Copy buttons follows (Delete only in delete mode);
    - the final assistant message also gets Retry.

    The message whose index equals ``st.session_state.edit_index`` is followed
    by an edit form.
    """
    history = st.session_state.messages
    count = len(history)  # part of every widget key below

    for idx, msg in enumerate(history):
        role = msg["role"]
        if role == "user":
            avatar = st.session_state.user_avatar
        else:
            avatar = st.session_state.assistant_avatar

        with st.chat_message(role, avatar=avatar):
            body = msg["content"]
            for raw, safe in (("\n", "  \n"), ("<", "\\<"), (">", "\\>")):
                body = body.replace(raw, safe)
            if "```" in body:
                # Inside fenced code the hard breaks and escapes would show literally.
                body = re.sub(code_block_pattern, normalize_code_block, body, flags=re.DOTALL)
            if "`" in body:
                body = re.sub(inline_pattern, normalize_inline, body)
            st.markdown(body)

            edit_col, delete_col, copy_col, retry_col = st.columns([1, 1, 1, 1])
            with edit_col:
                if st.button("Edit", key=f"edit_{idx}_{count}"):
                    st.session_state.edit_index = idx
                    st.rerun()
            with delete_col:
                if st.session_state.is_delete_mode and st.button("Delete", key=f"delete_{idx}_{count}"):
                    del st.session_state.messages[idx]
                    st.rerun()
            with copy_col:
                # Percent-encode the raw text so it can sit inside the JS string literal.
                text_to_copy_escaped = urllib.parse.quote(msg["content"])

                copy_button_html = f"""
                <button id="copy-msg-btn-{idx}" style='font-size: 1em; padding: 0.5em;' onclick='copyMessage("{idx}")'>Copy</button>
                
                <script>
                function copyMessage(index) {{
                    navigator.clipboard.writeText(decodeURIComponent("{text_to_copy_escaped}"));
                    let copyBtn = document.getElementById("copy-msg-btn-" + index);
                    copyBtn.innerHTML = "Copied!";
                    setTimeout(function(){{ copyBtn.innerHTML = "Copy"; }}, 2000);
                }}
                </script>
                """
                html(copy_button_html, height=50)

            is_last_reply = idx == count - 1 and role == "assistant"
            if is_last_reply:
                with retry_col:
                    if st.button("Retry", key=f"retry_{idx}_{count}"):
                        if count >= 2:
                            # Drop the reply; the top-level retry handler regenerates it.
                            del st.session_state.messages[-1]
                            st.session_state.retry_flag = True
                            st.rerun()

        if st.session_state.get("edit_index") == idx:
            with st.form(key=f"edit_form_{idx}_{count}"):
                new_content = st.text_area("Edit message", height=200, value=history[idx]["content"])
                save_col, cancel_col = st.columns([1, 1])
                with save_col:
                    if st.form_submit_button("Save"):
                        history[idx]["content"] = new_content
                        del st.session_state.edit_index
                        st.rerun()
                with cancel_col:
                    if st.form_submit_button("Cancel"):
                        del st.session_state.edit_index
                        st.rerun()

# Add sidebar for advanced settings.
# Besides drawing widgets, this section defines module-level settings read
# later in the script: model, preamble, temperature, k, p, penalty_type,
# penalty_value, embed_model, rerank_model, selected_font, font_options.
with st.sidebar:
    settings_tab, appearance_tab = st.tabs(["Settings", "Appearance"])

    with settings_tab:
        st.markdown("Help (Japanese): https://rentry.org/9hgneofz")

        # Copy Conversation History button.
        # The log has one "<USER>\n..." / "<ASSISTANT>\n..." entry per turn,
        # separated by blank lines; "Restore History" below parses this format.
        log_text = "\n\n".join(
            ("<USER>\n" if message["role"] == "user" else "<ASSISTANT>\n") + message["content"]
            for message in st.session_state.messages
        ).rstrip("\n")

        # Percent-encode so the log can be embedded in a JS string literal.
        log_text_escaped = urllib.parse.quote(log_text)

        copy_log_button_html = f"""
        <button id="copy-log-btn" style='font-size: 1em; padding: 0.5em;' onclick='copyLog()'>Copy Conversation History</button>

        <script>
        const log_text_escaped = "{log_text_escaped}";
        function copyLog() {{
            navigator.clipboard.writeText(decodeURIComponent(log_text_escaped));
            const copyBtn = document.getElementById("copy-log-btn");
            copyBtn.innerHTML = "Copied!";
            setTimeout(function(){{ copyBtn.innerHTML = "Copy Conversation History"; }}, 2000);
        }}
        // This script can run many times (the component may be re-rendered on
        // reruns); replace the previous Pause-key listener on the parent page
        // instead of stacking another one each time.
        const parentWin = window.parent;
        if (parentWin.__copyLogKeyHandler) {{
            parentWin.document.removeEventListener('keydown', parentWin.__copyLogKeyHandler, false);
        }}
        parentWin.__copyLogKeyHandler = (e) => {{
            if ( e.code == "Pause" ){{
                parentWin.navigator.clipboard.writeText(decodeURIComponent(log_text_escaped));
                const copyBtn = document.getElementById("copy-log-btn");
                copyBtn.innerHTML = "Copied!";
                setTimeout(function(){{ copyBtn.innerHTML = "Copy Conversation History"; }}, 2000);
            }}
        }};
        parentWin.document.addEventListener('keydown', parentWin.__copyLogKeyHandler, false);
        </script>
        """
        html(copy_log_button_html, height=50)

        if st.session_state.get("is_history_shown") != True:
            if st.button("Display History as Code Block"):
                st.session_state.is_history_shown = True
                st.rerun()
        else:
            if st.button("Hide History"):
                st.session_state.is_history_shown = False
                st.rerun()
            st.code(log_text)

        st.session_state.is_delete_mode = st.toggle("Enable Delete button")

        st.header("Advanced Settings")
        model = st.selectbox("Model", options=["command-a-03-2025",
                                               "command-r-plus",
                                               "command-r",
                                               "command-r-plus-08-2024",
                                               "command-r-08-2024",
                                              ], index=0)
        preamble = st.text_area("Preamble", height=200)
        temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.3, step=0.1)
        k = st.slider("Top-K", min_value=0, max_value=500, value=0, step=1)
        p = st.slider("Top-P", min_value=0.01, max_value=0.99, value=0.75, step=0.01)
        penalty_type = st.selectbox("Penalty Type", options=["Frequency Penalty", "Presence Penalty"])
        penalty_value = st.slider("Penalty Value", min_value=0.0, max_value=1.0, value=0.0, step=0.1)

        st.header("RAG")
        st.markdown("Select the model and encoding before uploading the file.")
        rag_model = st.selectbox("RAG Model", options=["Multilingual", "English"], index=0)
        file_encoding = st.selectbox("Encoding", options=["utf_8", "shift_jis"], index=0)
        st.session_state.rag_file = st.file_uploader("Choose a txt or pdf file", type=["txt", "pdf"], key="rag_file_uploader")

        if rag_model == "Multilingual":
            embed_model = "embed-multilingual-v3.0"
            rerank_model = "rerank-multilingual-v3.0"
        else:
            embed_model = "embed-english-v3.0"
            rerank_model = "rerank-english-v3.0"

        rag_file = st.session_state.rag_file
        if rag_file is not None:
            # Queries are embedded with the currently selected embed_model, so
            # the document must be re-processed whenever the upload, the
            # embedding model or the encoding changes.
            rag_signature = (rag_file, embed_model, file_encoding)
            if st.session_state.rag_file_key != rag_signature:
                st.session_state.rag_file_key = rag_signature
                st.session_state.rag_embedded = False
                for stale_key in ("rag_text", "rag_chunks", "rag_embeddings"):
                    if stale_key in st.session_state:
                        del st.session_state[stale_key]

            if not st.session_state.rag_embedded:
                # getvalue() returns the whole upload regardless of the current
                # read position, so re-processing the same upload still works.
                raw_bytes = rag_file.getvalue()
                rag_text = None
                if rag_file.type == "application/pdf":
                    pdf = pdfium.PdfDocument(raw_bytes)
                    try:
                        # Separate pages with a newline so words at page breaks don't fuse.
                        rag_text = "\n".join(page.get_textpage().get_text_range() for page in pdf)
                    finally:
                        pdf.close()
                else:
                    try:
                        rag_text = raw_bytes.decode(file_encoding)
                    except UnicodeDecodeError:
                        st.error(f"Could not decode the file as {file_encoding}. Please choose the correct encoding.")
                if rag_text is not None:
                    chunks = text_splitter.split_text(rag_text)
                    embeddings = batch_embed(chunks)
                    # Store only after embedding succeeded, so the RAG state is all-or-nothing.
                    st.session_state.rag_text = rag_text
                    st.session_state.rag_chunks = chunks
                    st.session_state.rag_embeddings = embeddings
                    st.session_state.rag_embedded = True
        else:
            st.session_state.rag_file_key = None
            st.session_state.rag_embedded = False
            for stale_key in ("rag_text", "rag_chunks", "rag_embeddings"):
                if stale_key in st.session_state:
                    del st.session_state[stale_key]

        st.header("Restore History")
        history_input = st.text_area("Paste conversation history:", height=200)
        if st.button("Restore History"):
            # With a capturing group, re.split returns
            # [text_before_first_marker, marker, body, marker, body, ...].
            # Text before the first marker belongs to no turn and is dropped;
            # turns whose body is blank are skipped.
            parts = re.split(r"^(<USER>|<ASSISTANT>)\n", history_input, flags=re.MULTILINE)
            restored = []
            for marker, body in zip(parts[1::2], parts[2::2]):
                content = body.strip()
                if content:
                    restored.append({"role": "user" if marker == "<USER>" else "assistant", "content": content})
            st.session_state.messages = restored
            st.rerun()

        st.header("Clear History")
        if st.button("Clear Chat History"):
            st.session_state.messages = []
            st.rerun()

        st.header("Change API Key")
        # Strip surrounding whitespace that often comes along with a pasted key.
        new_api_key = st.text_input("Enter new API Key", type="password").strip()
        if st.button("Update API Key"):
            if new_api_key and new_api_key.isascii():
                st.session_state.api_key = new_api_key
                client = cohere.ClientV2(api_key=new_api_key)
                st.success("API Key updated successfully!")
            else:
                st.warning("Please enter a valid API Key.")

    with appearance_tab:
        st.header("Font Selection")

        font_options = {
            "Zen Maru Gothic": "Zen Maru Gothic",
            "Noto Sans JP": "Noto Sans JP",
            "Sawarabi Mincho": "Sawarabi Mincho"
        }
        selected_font = st.selectbox("Choose a font", ["Default"] + list(font_options.keys()))

        st.header("Change the font size")
        st.session_state.font_size = st.slider("Font size", min_value=16.0, max_value=50.0, value=16.0, step=1.0)

        st.header("Change the user's icon")
        st.session_state.user_avatar = st.file_uploader("Choose an image", type=["png", "jpg", "jpeg", "webp", "gif", "bmp", "svg",], key="user_avatar_uploader")

        st.header("Change the assistant's icon")
        st.session_state.assistant_avatar = st.file_uploader("Choose an image", type=["png", "jpg", "jpeg", "webp", "gif", "bmp", "svg",], key="assistant_avatar_uploader")

        st.header("Change the icon size")
        st.session_state.avatar_size = st.slider("Icon size", min_value=2.0, max_value=20.0, value=2.0, step=0.2)


# After Stop generating:
# is_streaming is still True only when the previous run was interrupted
# mid-stream, so save whatever reply text had arrived.
if st.session_state.get("is_streaming"):
    partial_reply = st.session_state.response
    st.session_state.messages.append({"role": "assistant", "content": partial_reply})
    st.session_state.is_streaming = False
    if st.session_state.get("retry_flag"):
        # The interrupted reply came from a retry; don't regenerate it again.
        st.session_state.retry_flag = False
    st.rerun()

if selected_font != "Default":
    # Inject style.css (presumably the web-font imports - confirm) and then
    # force the chosen family on every element. The file is read as UTF-8
    # explicitly because open()'s default encoding is locale-dependent.
    with open("style.css", encoding="utf-8") as css:
        st.markdown(f'<style>{css.read()}</style>', unsafe_allow_html=True)
    st.markdown(f'<style>body * {{ font-family: "{font_options[selected_font]}", serif !important; }}</style>', unsafe_allow_html=True)

# Change font size of chat paragraphs.
# (.st-emotion-cache-kj6hex is a generated Streamlit class name and may change.)
font_size_css = (
    '<style>[data-testid="stChatMessageContent"] .st-emotion-cache-kj6hex p'
    f"{{font-size: {st.session_state.font_size}px;}}"
    "</style>"
)
st.markdown(font_size_css, unsafe_allow_html=True)

# Change icon size
# (CSS element names may be subject to change.)
# (Contributor: ★31 >>538)
# The avatar containers and .st-emotion-cache-p4micv get the full size
# (in rem); .st-emotion-cache-1pbsqtx gets the size divided by 1.6.
avatar_rem = st.session_state.avatar_size
AVATAR_SIZE_STYLE = f"""
                <style>
                    [data-testid="stChatMessageAvatarUser"] {{
                        width: {avatar_rem}rem;
                        height: {avatar_rem}rem;
                    }}
                    [data-testid="stChatMessageAvatarAssistant"] {{
                        width: {avatar_rem}rem;
                        height: {avatar_rem}rem;
                    }}
                    [data-testid="stChatMessage"] .st-emotion-cache-1pbsqtx {{
                        width: {avatar_rem / 1.6}rem;
                        height: {avatar_rem / 1.6}rem;
                    }}
                    [data-testid="stChatMessage"] .st-emotion-cache-p4micv {{
                        width: {avatar_rem}rem;
                        height: {avatar_rem}rem;
                    }}
                </style>
"""
st.markdown(AVATAR_SIZE_STYLE, unsafe_allow_html=True)

display_messages()

# After Retry: the Retry button removed the last assistant reply and set
# retry_flag; regenerate a reply from the remaining history.
if st.session_state.get("retry_flag"):
    if not st.session_state.messages:
        st.session_state.retry_flag = False
    else:
        regenerated = get_ai_response(st.session_state.messages.copy())
        st.session_state.messages.append({"role": "assistant", "content": regenerated})
        st.session_state.retry_flag = False
        st.rerun()

prompt = st.chat_input("Enter your message here...")
if prompt:
    st.session_state.messages.append({"role": "user", "content": prompt})
    chat_history = list(st.session_state.messages)

    # Echo the prompt right away: Markdown hard breaks, escaped angle brackets.
    escaped_prompt = prompt
    for raw, safe in (("\n", "  \n"), ("<", "\\<"), (">", "\\>")):
        escaped_prompt = escaped_prompt.replace(raw, safe)
    with st.chat_message("user", avatar=st.session_state.user_avatar):
        st.write(escaped_prompt)

    reply = get_ai_response(chat_history)
    st.session_state.messages.append({"role": "assistant", "content": reply})
    st.rerun()