Spaces:
Build error
Build error
| import os | |
| import time | |
| from pathlib import Path | |
| import openai | |
| import pandas as pd | |
| import streamlit as st | |
# Silence Streamlit's "default value + session state" warning. The flag must
# be set on the module that owns it: importing the bare name and rebinding it
# here would only change this file's alias and leave Streamlit's flag alone.
# https://discuss.streamlit.io/t/why-do-default-values-cause-a-session-state-warning/15485/21
try:
    import streamlit.elements.utils as _st_element_utils
except ImportError:
    # Private Streamlit module; presumably not present in every release.
    _st_element_utils = None
if _st_element_utils is not None:
    _st_element_utils._shown_default_value_warning = True

st.set_page_config(page_title="ChatGPT", page_icon="π")
def init_openai_settings():
    """Configure the ``openai`` module from environment variables.

    ``OPENAI_API_KEY`` is always copied to ``openai.api_key`` (``None`` when
    the variable is unset). ``OPENAI_PROXY`` is applied to ``openai.proxy``
    only when it is set to a non-empty value.
    """
    environment = os.environ
    openai.api_key = environment.get("OPENAI_API_KEY")
    proxy_url = environment.get("OPENAI_PROXY")
    if proxy_url:
        openai.proxy = proxy_url
def init_session():
    """Seed ``st.session_state`` with the keys the rest of the app reads.

    ``params`` and ``chats`` are (re)created as empty dicts whenever they are
    missing or empty; ``input`` gets a greeting only when the key is absent.
    """
    state = st.session_state
    for key, make_default in (("params", dict), ("chats", dict)):
        if not state.get(key):
            state[key] = make_default()
    if "input" not in state:
        state["input"] = "Hello, how are you?"
def new_chat(chat_name):
    """Create the chat record for ``chat_name`` unless one already exists.

    A new record starts with empty question/answer lists and a single system
    message taken from the current ``params["prompt"]``.

    Returns:
        ``chat_name`` unchanged, so the call can be nested in ``switch_chat``.
    """
    chats = st.session_state["chats"]
    if chats.get(chat_name):
        return chat_name

    system_message = {"role": "system", "content": st.session_state["params"]["prompt"]}
    chats[chat_name] = {
        "answer": [],
        "question": [],
        "messages": [system_message],
        "is_delete": False,
        "display_name": chat_name,
    }
    return chat_name
def switch_chat(chat_name):
    """Make ``chat_name`` the current chat, render it and stop the script.

    Does nothing when ``chat_name`` is already the current chat.
    """
    already_current = st.session_state.get("current_chat") == chat_name
    if already_current:
        return
    st.session_state["current_chat"] = chat_name
    render_chat(chat_name)
    st.stop()
def switch_chat_name(chat_name):
    """Select ``chat_name``, then redraw the sidebar and the chat and stop.

    Used as the ``on_click`` handler of the sidebar chat buttons. Does nothing
    when ``chat_name`` is already the current chat.
    """
    already_current = st.session_state.get("current_chat") == chat_name
    if already_current:
        return
    st.session_state["current_chat"] = chat_name
    render_sidebar()
    render_chat(chat_name)
    st.stop()
def delete_chat(chat_name):
    """Soft-delete ``chat_name`` and move the selection to a surviving chat.

    The record is only flagged with ``is_delete`` (it stays in the ``chats``
    dict, which keeps ``len(chats)``-based names unique). When no chat
    survives, a fresh one is created and selected. When the deleted chat was
    the current one, the first surviving chat becomes current.
    """
    chats = st.session_state['chats']
    if chat_name in chats:
        chats[chat_name]['is_delete'] = True

    current_chats = [name for name, value in chats.items() if not value['is_delete']]
    if not current_chats:
        switch_chat(new_chat(f"Chat{len(chats)}"))
        # switch_chat only stops when the selection actually changes.
        st.stop()

    # .get(): "current_chat" may be absent; every other reader in this file
    # tolerates that, and indexing directly would raise KeyError here.
    if st.session_state.get("current_chat") == chat_name:
        del st.session_state["current_chat"]
        switch_chat_name(current_chats[0])
def edit_chat(chat_name, zone):
    """Render an inline rename form for ``chat_name`` inside ``zone``.

    Args:
        chat_name: Key of the chat in ``st.session_state['chats']``. The key
            never changes; only the record's ``display_name`` is edited.
        zone: Streamlit container in which the form is drawn.
    """
    def edit():
        """Validate the typed name and store it as the new display name."""
        chats = st.session_state['chats']
        new_name = st.session_state['edited_name']
        if not new_name:
            print('name is empty!')
            return None
        # Compare with the *display* name: after a first rename the key and
        # the visible name differ, and the key is not what the user sees.
        if new_name == chats[chat_name]['display_name']:
            print('name is not modified!')
            return None
        names_in_use = {
            other['display_name']
            for key, other in chats.items()
            if key != chat_name and not other['is_delete']
        }
        if new_name in names_in_use:
            print('name is duplicated!')
            return None
        chats[chat_name]['display_name'] = new_name

    edit_zone = zone.empty()
    # Presumably the same placeholder-refresh workaround used in
    # render_history_answer — confirm before removing.
    time.sleep(0.1)
    with edit_zone.container():
        st.text_input('New Name', st.session_state['chats'][chat_name]['display_name'], key='edited_name')
        column1, _, column2 = st.columns([1, 5, 1])
        column1.button('β ', on_click=edit)
        column2.button('β')
def render_sidebar_chat_management(zone):
    """Draw the "new chat" button and one button per non-deleted chat.

    The current chat is drawn as a primary button with edit and remove
    buttons beside it; other chats get a single full-width button. When the
    "new chat" button was pressed, the chat is created before the list is
    drawn and switched to afterwards.
    """
    state = st.session_state
    create_clicked = zone.button(label="β New Chat", use_container_width=True)
    if create_clicked:
        created_name = f"Chat{len(state['chats'])}"
        state["current_chat"] = created_name
        new_chat(created_name)

    with st.sidebar.container():
        for name, record in state["chats"].items():
            if record['is_delete']:
                continue

            button_label = 'π¬ ' + record['display_name']
            if name != state.get('current_chat'):
                zone.button(
                    label=button_label,
                    on_click=switch_chat_name,
                    key=name,
                    args=(name,),
                    use_container_width=True,
                )
                continue

            select_col, edit_col, remove_col = zone.columns([7, 1, 1])
            select_col.button(
                label=button_label,
                on_click=switch_chat_name,
                key=name,
                args=(name,),
                type='primary',
                use_container_width=True,
            )
            edit_col.button(label='π', key='edit', on_click=edit_chat, args=(name, zone))
            remove_col.button(label='ποΈ', key='remove', on_click=delete_chat, args=(name,))

    if create_clicked:
        switch_chat(created_name)
def render_sidebar_gpt_config_tab(zone):
    """Draw the model settings widgets and store their values in ``params``.

    ``st.session_state["params"]`` is replaced by a fresh dict on every call
    and then filled with ``model``, ``temperature``, ``max_tokens`` and
    ``stream``.
    """
    params = dict()
    st.session_state["params"] = params

    temperature_help = "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic."
    stream_help = "If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message"

    params["model"] = zone.selectbox(
        "Please select a model",
        ["gpt-3.5-turbo"],  # , "gpt-4"
        help="ID of the model to use",
    )
    params["temperature"] = zone.slider(
        "Temperature",
        min_value=0.0,
        max_value=2.0,
        value=1.2,
        step=0.1,
        format="%0.2f",
        help=temperature_help,
    )
    params["max_tokens"] = zone.slider(
        "Max Tokens",
        value=2000,
        step=1,
        min_value=100,
        max_value=2000,
        help="The maximum number of tokens to generate in the completion",
    )
    params["stream"] = zone.checkbox(
        "Streaming output",
        value=True,
        help=stream_help,
    )
    zone.caption('Looking for help at https://platform.openai.com/docs/api-reference/chat')
def render_sidebar_prompt_config_tab(zone):
    """Draw the system-prompt editor and the prompt-template pickers.

    The text area value is stored in ``params["prompt"]``. When both a
    template file and an entry in it are selected, the text area is redrawn
    in the same placeholder with that entry's prompt as its value.
    """
    prompt_help = "The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays."
    params = st.session_state["params"]

    prompt_slot = zone.empty()
    params["prompt"] = prompt_slot.text_area(
        "System Prompt",
        "You are a helpful assistant that translates answer from English to Chinese.",
        help=prompt_help,
    )

    chosen_template = zone.selectbox('Loading From Prompt Template', load_prompt_templates())
    if not chosen_template:
        return

    prompts = load_prompts(chosen_template)
    chosen_act = zone.selectbox('Loading Prompts', prompts.index.tolist())
    if not chosen_act:
        return

    params["prompt"] = prompt_slot.text_area(
        "System Prompt",
        prompts.loc[chosen_act].prompt,
        help=prompt_help,
    )
def _extract_markdown_tables(answer):
    """Parse each markdown table found in ``answer`` into a DataFrame.

    Only lines starting with ``"| "`` are kept; blank lines separate tables.
    A chunk that cannot be parsed is reported and skipped on its own.
    """
    from io import StringIO

    stripped_lines = [line.strip() for line in answer.split('\n')]
    table_text = '\n'.join(
        line for line in stripped_lines if line.startswith('| ') or line == ''
    )

    tables = []
    for chunk in table_text.split('\n\n'):
        chunk = chunk.strip()
        if not chunk:
            continue
        try:
            # NOTE(review): removing every space also removes spaces inside
            # cell text ("New York" -> "NewYork"); kept as in the original.
            frame = pd.read_table(StringIO(chunk.replace(' ', '')), sep='|')
        except Exception as e:  # best effort: a bad table must not break the page
            print(e)
            continue
        # Drop the empty edge columns, then the "|---|---|" separator row.
        tables.append(frame.dropna(axis=1, how='all').iloc[1:])
    return tables


def render_download_zone(zone):
    """Offer the current chat as a Markdown file and its tables as Excel.

    Nothing is drawn when no chat is selected. The Excel button appears only
    when at least one table could be parsed out of the answers; each table
    goes to its own sheet, numbered from 1.
    """
    from io import BytesIO

    if not st.session_state.get('current_chat'):
        return
    chat = st.session_state['chats'][st.session_state['current_chat']]
    col1, col2 = zone.columns([1, 1])

    transcript = ['# ' + chat['display_name']]
    answers = chat["answer"]
    for index, question in enumerate(chat["question"]):
        transcript.append(f"""π **YOU:** {question}""")
        if index < len(answers):
            transcript.append(f"""π€ **AI:** {answers[index]}""")
    col1.download_button('π€ Markdown', '\n'.join(transcript).encode('utf-8'), file_name=f"{chat['display_name']}.md", help="Download messages to a markdown file", use_container_width=True)

    tables = []
    for answer in answers:
        tables.extend(_extract_markdown_tables(answer))

    if tables:
        buffer = BytesIO()
        with pd.ExcelWriter(buffer) as writer:
            for index, table in enumerate(tables):
                table.to_excel(writer, sheet_name=str(index + 1), index=False)
        col2.download_button('π Excel', buffer.getvalue(), file_name=f"{chat['display_name']}.xlsx", help="Download tables to a excel file", use_container_width=True)
def render_sidebar():
    """Lay out the sidebar and fill each of its areas.

    The containers are created first, top to bottom (chat list, configuration
    expander with two tabs, download area, GitHub badge), and then filled.
    The configuration tabs are filled before the chat list.
    """
    sidebar = st.sidebar
    sidebar.title("ChatGPT")

    # Placeholders, in display order.
    chats_area = sidebar.container()
    config_area = sidebar.expander('βοΈ Chat configuration', True)
    gpt_tab, prompt_tab = config_area.tabs(['π ChatGPT', 'π₯ Prompt'])
    download_area = sidebar.empty()
    github_area = sidebar.empty()

    # Content, in dependency order.
    render_sidebar_gpt_config_tab(gpt_tab)
    render_sidebar_prompt_config_tab(prompt_tab)
    render_sidebar_chat_management(chats_area)
    render_download_zone(download_area)
    render_github_info(github_area)
def render_user_message(message, zone):
    """Show ``message`` in ``zone`` as a two-column row labelled as the user."""
    label_col, text_col = zone.columns([1,8])
    label_col.markdown("π **YOU:**")
    text_col.markdown(message)
def render_ai_message(message, zone):
    """Show ``message`` in ``zone`` as a two-column row labelled as the AI."""
    label_col, text_col = zone.columns([1,8])
    label_col.markdown("π€ **AI:**")
    text_col.markdown(message)
def render_history_answer(chat, zone):
    """Redraw the stored conversation of ``chat`` inside ``zone``.

    Shows the system prompt as a caption, then every question followed by its
    answer when one exists at the same index.
    """
    zone.empty()
    time.sleep(0.1)  # https://github.com/streamlit/streamlit/issues/5044
    with zone.container():
        messages = chat['messages']
        if messages:
            st.caption(f"""βΉοΈ Prompt: {messages[0]['content']}""")

        answers = chat["answer"]
        for index, question in enumerate(chat["question"]):
            render_user_message(question, st)
            if index < len(answers):
                render_ai_message(answers[index], st)
def render_last_answer(question, chat, zone):
    """Record ``question`` in ``chat``, ask the model and show the reply.

    With ``params["stream"]`` on, an empty assistant entry is stored first and
    overwritten with the growing text as each chunk arrives. Otherwise the
    full reply is fetched under a spinner and stored once.
    """
    answer_zone = zone.empty()
    chat["messages"].append({"role": "user", "content": question})
    chat["question"].append(question)

    if not st.session_state["params"]["stream"]:
        with st.spinner("Wait for responding..."):
            completion = get_openai_response(chat["messages"])
            answer = completion["choices"][0]["message"]["content"]
            chat["answer"].append(answer)
            chat["messages"].append({"role": "assistant", "content": answer})
            render_ai_message(answer, zone)
        return

    answer = ""
    chat["answer"].append(answer)
    # NOTE(review): the empty assistant placeholder is already part of the
    # messages passed to get_openai_response below.
    chat["messages"].append({"role": "assistant", "content": answer})
    for chunk in get_openai_response(chat["messages"]):
        answer += chunk["choices"][0]['delta'].get("content", '')
        chat["answer"][-1] = answer
        chat["messages"][-1] = {"role": "assistant", "content": answer}
        render_ai_message(answer, answer_zone)
def render_stop_generate_button(zone):
    """Draw a centred "Stop" button that clears the ``regenerate`` flag."""
    def stop():
        st.session_state['regenerate'] = False

    _, centre, _ = zone.columns((2, 1, 2))
    centre.button('β‘ Stop', on_click=stop)
def render_regenerate_button(chat, zone):
    """Draw a centred "Regenerate" button for the last exchange of ``chat``.

    Clicking it removes the last two messages, the last answer and the last
    question from ``chat``, stores that question in ``last_question`` and sets
    the ``regenerate`` flag.
    """
    def regenerate():
        messages = chat["messages"]
        messages.pop(-1)
        messages.pop(-1)
        chat["answer"].pop(-1)
        st.session_state['regenerate'] = True
        st.session_state['last_question'] = chat["question"].pop(-1)

    _, centre, _ = zone.columns((2, 1, 2))
    centre.button('πRegenerate', type='primary', on_click=regenerate)
def render_chat(chat_name):
    """Draw the conversation, the ask form and the action buttons of a chat.

    Shows an error and returns when ``chat_name`` is unknown; shows an error
    and stops the script when the chat is flagged as deleted. While the chat
    holds only its system message, that message follows the current
    ``params["prompt"]``.
    """
    state = st.session_state

    def handle_ask():
        """Submit handler: send the typed text, then clear the input."""
        typed = state['input']
        if not typed:
            return
        regenerate_zone.empty()
        render_user_message(typed, last_question_zone)
        render_stop_generate_button(stop_generate_zone)
        render_last_answer(typed, chat, last_answer_zone)
        state['input'] = ''

    if chat_name not in state["chats"]:
        st.error(f'{chat_name} is not exist')
        return
    chat = state["chats"][chat_name]
    if chat['is_delete']:
        st.error(f"{chat_name} is deleted")
        st.stop()

    # No exchange yet: keep the system message in sync with the sidebar.
    prompt = state["params"]["prompt"]
    if len(chat['messages']) == 1 and prompt:
        chat["messages"][0]['content'] = prompt

    conversation_zone = st.container()
    history_zone = conversation_zone.empty()
    last_question_zone = conversation_zone.empty()
    last_answer_zone = conversation_zone.empty()
    ask_form_zone = st.empty()

    render_history_answer(chat, history_zone)

    ask_form = ask_form_zone.form(chat_name)
    input_col, submit_col = ask_form.columns([10, 1])
    input_col.text_area(
        "π You: ",
        key="input",
        max_chars=2000,
        label_visibility='collapsed',
    )
    with submit_col.container():
        # Two blank writes push the submit button down.
        st.write('\n')
        st.write('\n')
        st.form_submit_button("π", on_click=handle_ask)

    stop_generate_zone = conversation_zone.empty()
    regenerate_zone = conversation_zone.empty()

    if state.get('regenerate'):
        last_question = state['last_question']
        render_user_message(last_question, last_question_zone)
        render_stop_generate_button(stop_generate_zone)
        render_last_answer(last_question, chat, last_answer_zone)
        state['regenerate'] = False

    if chat["answer"]:
        stop_generate_zone.empty()
        render_regenerate_button(chat, regenerate_zone)
    # render_footer()
def render_footer():
    """Write the credit line and hide Streamlit's default footer via CSS."""
    credit_html = "<br><hr><center>Made with β€οΈ by ChatGPT and StreamLit.</center><hr>"
    hide_footer_css = "<style> footer {visibility: hidden;} </style>"
    for snippet in (credit_html, hide_footer_css):
        st.markdown(snippet, unsafe_allow_html=True)
def render_github_info(zone):
    """Draw a spacer line and the GitHub badge link inside ``zone``."""
    badge_html = (
        '<a href="https://github.com/haiichuan/chatgpt-streamlit" target="_blank" rel="chatgpt-streamlit">'
        '<img src="https://badgen.net/badge/icon/GitHub?icon=github&label=chatgpt-streamlit" alt="GitHub">'
        '</a>'
    )
    with zone.container():
        st.write("\n")
        st.markdown(badge_html, unsafe_allow_html=True)
def get_openai_response(messages):
    """Send ``messages`` to the chat completion endpoint.

    Model, temperature, streaming flag and token limit are read from
    ``st.session_state["params"]``.

    Args:
        messages: List of ``{"role": ..., "content": ...}`` dicts.

    Returns:
        Whatever ``openai.ChatCompletion.create`` returns for the current
        ``stream`` setting.

    Raises:
        NotImplementedError: The selected model is not a supported chat model.
    """
    params = st.session_state["params"]
    # 'gpt-4' is the real model ID; the previous 'gpt4' entry could never
    # match the ID a selectbox would offer.
    supported_models = {'gpt-3.5-turbo', 'gpt-4'}
    if params["model"] not in supported_models:
        raise NotImplementedError('Not implemented yet!')

    return openai.ChatCompletion.create(
        model=params["model"],
        temperature=params["temperature"],
        messages=messages,
        stream=params["stream"],
        max_tokens=params["max_tokens"],
    )
def load_prompt_templates():
    """Return the file names of the ``*.json`` files in ``templates/``.

    The directory is looked up next to this source file.
    """
    template_dir = Path(__file__).parent.joinpath("templates")
    names = []
    for json_file in template_dir.glob("*.json"):
        names.append(json_file.name)
    return names
def load_prompts(template_name):
    """Load one prompt template file as a DataFrame indexed by ``act``.

    Rows with a repeated ``act`` are dropped (first one wins) before the
    column becomes the index.

    Returns:
        The DataFrame, or ``None`` when ``template_name`` is empty.
    """
    if not template_name:
        return None
    template_path = Path(__file__).parent / "templates" / template_name
    prompts = pd.read_json(template_path)  # act, prompt
    unique_prompts = prompts.drop_duplicates(subset='act')
    return unique_prompts.set_index('act')
if __name__ == "__main__":
    # Runs on every script execution: set up, draw the sidebar, draw the
    # selected chat, and create a first chat when none exists yet.
    print("---- page reloading ----")
    init_openai_settings()
    init_session()
    render_sidebar()

    active_chat = st.session_state.get("current_chat")
    if active_chat:
        render_chat(active_chat)

    if not st.session_state["chats"]:
        switch_chat(new_chat(f"Chat{len(st.session_state['chats'])}"))