Spaces: Runtime error
| import gradio as gr | |
| from openai import OpenAI | |
| import time | |
| import html | |
def predict(message, history, character, api_key, progress=gr.Progress()):
    """Stream a GPT-4 reply to *message*, yielding the growing reply text.

    Args:
        message: The new user message.
        history: Earlier turns as ``[user_text, assistant_text]`` pairs.
            An assistant entry may be ``None`` or empty (e.g. a turn whose
            request failed); such entries are not sent to the API.
        character: Name selected in the "Characters" dropdown.
            TODO(review): currently unused — presumably meant to select a
            system prompt/persona; confirm intended behavior.
        api_key: OpenAI API key used to build the client for this call.
        progress: Gradio progress tracker wrapped around the stream.
            NOTE(review): this function is called directly from ``user``
            rather than registered as an event handler, so Gradio may not
            track this default ``Progress`` — confirm the indicator shows.

    Yields:
        The accumulated reply text after each streamed chunk that carries
        content.
    """
    client = OpenAI(api_key=api_key)

    messages = []
    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        # An unanswered turn leaves assistant=None; the Chat Completions API
        # rejects a null assistant content (no tool calls), so skip it.
        if assistant:
            messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})

    response = client.chat.completions.create(
        model='gpt-4',
        messages=messages,
        temperature=1.0,
        stream=True,
    )

    partial_message = ""
    for chunk in progress.tqdm(response, desc="Generating"):
        # Streamed chunks can arrive with an empty `choices` list (e.g. a
        # trailing usage chunk); indexing [0] on those would raise.
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta.content
        if delta:
            partial_message += delta
            yield partial_message
            # Small pause between UI updates while streaming.
            time.sleep(0.01)
def format_history(history):
    """Render chat history as HTML for the ``chat-display`` component.

    Args:
        history: Sequence of ``[user_text, ai_text]`` pairs. ``ai_text`` may
            be ``None`` or empty (e.g. while a reply is pending); then only
            the user message is rendered for that turn.

    Returns:
        An HTML string with one ``<div class="message ...">`` per rendered
        message. Message text is HTML-escaped (it is user/model supplied)
        and newlines become ``<br>``. An empty history yields ``""``.
    """
    def _to_html(text):
        # Escape first so the only markup in the output is what we add.
        return html.escape(text).replace('\n', '<br>')

    parts = []
    for human, ai in history:
        parts.append(
            f'<div class="message user-message"><strong>You:</strong> {_to_html(human)}</div>'
        )
        if ai:
            parts.append(
                f'<div class="message ai-message"><strong>AI:</strong> {_to_html(ai)}</div>'
            )
    # Join once instead of repeated `+=` (quadratic on long histories).
    return "".join(parts)
# Stylesheet passed to gr.Blocks(css=...).
# - #chat-display: the gr.HTML transcript (elem_id="chat-display"), a fixed
#   600px-tall scrollable box with a custom WebKit scrollbar.
# - .message / .user-message / .ai-message: the per-message <div> classes
#   emitted by format_history; each message box is capped at 300px and
#   scrolls on its own beyond that.
css = """
#chat-display {
height: 600px;
overflow-y: auto;
border: 1px solid #ccc;
padding: 10px;
margin-bottom: 10px;
}
#chat-display::-webkit-scrollbar {
width: 10px;
}
#chat-display::-webkit-scrollbar-track {
background: #f1f1f1;
}
#chat-display::-webkit-scrollbar-thumb {
background: #888;
}
#chat-display::-webkit-scrollbar-thumb:hover {
background: #555;
}
.message {
margin-bottom: 10px;
word-wrap: break-word;
overflow-wrap: break-word;
}
.user-message, .ai-message {
padding: 5px;
border-radius: 5px;
max-height: 300px;
overflow-y: auto;
}
.user-message {
background-color: #e6f3ff;
}
.ai-message {
background-color: #f0f0f0;
}
.user-message::-webkit-scrollbar, .ai-message::-webkit-scrollbar {
width: 5px;
}
.user-message::-webkit-scrollbar-thumb, .ai-message::-webkit-scrollbar-thumb {
background: #888;
}
"""
# Client-side script passed to gr.Blocks(js=...). Gradio runs it once when
# the page loads and expects a single JavaScript function expression, so
# all setup lives inside one arrow function (top-level statements and a
# DOMContentLoaded listener would never run).
# It keeps the #chat-display transcript pinned to the bottom while a reply
# streams in, unless the user has scrolled up to read earlier messages.
js = """
() => {
    const NEAR_BOTTOM_PX = 50;
    let lastScrollTop = 0;
    let isNearBottom = true;

    function attach(chatDisplay) {
        // Record the user's position on every scroll, i.e. *before* new
        // content arrives. Measuring inside the mutation callback would use
        // the already-grown scrollHeight and misreport the position.
        chatDisplay.addEventListener('scroll', () => {
            lastScrollTop = chatDisplay.scrollTop;
            isNearBottom = chatDisplay.scrollTop + chatDisplay.clientHeight
                >= chatDisplay.scrollHeight - NEAR_BOTTOM_PX;
        });

        // Follow new content only if the user was near the bottom;
        // otherwise keep their reading position.
        const observer = new MutationObserver(() => {
            chatDisplay.scrollTop = isNearBottom ? chatDisplay.scrollHeight : lastScrollTop;
        });
        observer.observe(chatDisplay, { childList: true, subtree: true, characterData: true });
    }

    // The chat component may not be mounted yet; retry for up to ~5 s.
    let attempts = 0;
    (function waitForChatDisplay() {
        const chatDisplay = document.getElementById('chat-display');
        if (chatDisplay) {
            attach(chatDisplay);
        } else if (attempts++ < 50) {
            setTimeout(waitForChatDisplay, 100);
        }
    })();

    // Enter-to-send is handled by the Textbox's own submit event
    // (msg.submit in Python); no custom key handler or hard-coded
    // auto-generated component id is needed.
}
"""
def user(user_message, history, character, api_key):
    """Append *user_message* to *history* and stream the assistant's reply.

    Generator event handler wired to the Send button and to submitting the
    message box.

    Args:
        user_message: Text from the message box.
        history: Conversation state as ``[user_text, ai_text]`` pairs;
            mutated in place.
        character: Selected character name, forwarded to ``predict``.
        api_key: OpenAI API key, forwarded to ``predict``.

    Yields:
        ``(message_box_text, history, transcript_html)`` tuples for the
        ``msg``, ``chat_history`` and ``chat_display`` outputs.

    Raises:
        gr.Error: If generating the reply fails. A turn that received no
            reply at all is removed from *history* first, and its text is
            put back in the message box so the user can retry.
    """
    if not user_message or not user_message.strip():
        # `return <value>` inside a generator sends nothing to Gradio, so
        # yield the unchanged state explicitly before stopping.
        yield "", history, format_history(history)
        return

    history.append([user_message, None])
    # Show the user's message (and clear the box) immediately, rather than
    # waiting for the first streamed token.
    yield "", history, format_history(history)

    try:
        for partial_reply in predict(user_message, history[:-1], character, api_key):
            history[-1][1] = partial_reply
            yield "", history, format_history(history)
    except Exception as err:
        # Event-handler boundary: don't leave a `None` reply in the stored
        # history (it would be replayed on the next request), then surface
        # the failure in the UI.
        if history and history[-1][1] is None:
            history.pop()
            yield user_message, history, format_history(history)
        raise gr.Error(f"Failed to get a response: {err}") from err
# Build and launch the UI: title, transcript, message box + Send, Clear,
# character picker and API-key field.
with gr.Blocks(css=css, js=js) as demo:
    gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>My Chatbot</h1>")

    # Per-session conversation as [user_text, ai_text] pairs.
    chat_history = gr.State([])
    # Transcript rendered as HTML; styled by `css` and auto-scrolled by `js`.
    chat_display = gr.HTML(elem_id="chat-display")

    with gr.Row():
        msg = gr.Textbox(
            label="Your message",
            lines=1,
            placeholder="Type your message here... (Press Enter to send)",
            elem_id="user-input",
        )
        send_btn = gr.Button("Send")
    clear = gr.Button("Clear")

    dropdown = gr.Dropdown(
        [f"Character {i}" for i in range(1, 14)],
        label="Characters",
        info="Select the character that you'd like to speak to",
        value="Character 1",
    )
    api_key = gr.Textbox(type="password", label="OpenAI API Key")

    chat_inputs = [msg, chat_history, dropdown, api_key]
    chat_outputs = [msg, chat_history, chat_display]
    send_btn.click(user, chat_inputs, chat_outputs)
    msg.submit(user, chat_inputs, chat_outputs)  # Enter in the message box

    # gr.HTML's value is a string, so reset the transcript with "" (not []).
    clear.click(lambda: ([], ""), None, [chat_history, chat_display], queue=False)
    # Switching character starts a fresh conversation.
    dropdown.change(lambda _character: ([], ""), dropdown, [chat_history, chat_display])

demo.queue()
demo.launch(max_threads=20)