Spaces:
Runtime error
Runtime error
| import pandas as pd | |
| from os import environ | |
| from time import sleep | |
| import datetime | |
| import streamlit as st | |
| from lib.sessions import SessionManager | |
| from langchain.schema import HumanMessage, FunctionMessage | |
| from helper import ( | |
| build_agents, | |
| MYSCALE_HOST, | |
| MYSCALE_PASSWORD, | |
| MYSCALE_PORT, | |
| MYSCALE_USER, | |
| DEFAULT_SYSTEM_PROMPT, | |
| ) | |
| from login import back_to_main | |
# Copy the OpenAI endpoint from Streamlit secrets into the process environment.
environ.update({"OPENAI_API_BASE": st.secrets["OPENAI_API_BASE"]})

# Human-readable labels keyed by tool identifier.
# NOTE(review): not referenced in the code visible here; presumably used elsewhere.
TOOL_NAMES = dict(
    [
        ("langchain_retriever_tool", "Self-querying retriever"),
        ("vecsql_retriever_tool", "Vector SQL"),
    ]
)
def on_chat_submit():
    """Send the text of the ``chat_input`` widget to the current agent.

    The agent's raw return value is printed to stdout for debugging.
    """
    agent = st.session_state.agent
    result = agent({"input": st.session_state.chat_input})
    print(result)
def clear_history():
    """Wipe the current agent's memory; do nothing when no agent exists yet."""
    if "agent" not in st.session_state:
        return
    st.session_state.agent.memory.clear()
def back_to_main():
    """Drop the user's identity and navigation keys from session state.

    Used as the "Logout" button callback in ``chat_page``.

    NOTE(review): this definition shadows the ``back_to_main`` imported from
    ``login`` at the top of the file, so this local version is the one used.
    """
    for key in ("user_info", "user_name", "jump_query_ask"):
        if key in st.session_state:
            del st.session_state[key]
def parse_added_session_row(row):
    """Validate one entry of the session editor's ``added_rows``.

    Args:
        row: Mapping of the columns the user filled in for a new row.

    Returns:
        ``(session_id, system_prompt)`` exactly as entered.

    Raises:
        ValueError: if ``session_id`` or ``system_prompt`` is missing, or if
            ``session_id`` is empty or contains ``?``. ``?`` is the separator
            in the stored ``"<user>?<session>"`` identifiers.
    """
    if "system_prompt" not in row or row.get("session_id") is None:
        raise ValueError(
            "You should fill both `session_id` and `system_prompt` to add a row!"
        )
    session_id = row["session_id"]
    if session_id == "" or "?" in session_id:
        raise ValueError(
            "`session_id` must be neither empty nor contain question marks."
        )
    return session_id, row["system_prompt"]


def on_session_change_submit():
    """Apply the additions and deletions staged in the ``session_editor`` table.

    New rows are added as ``"<user>?<session_id>"`` sessions and deleted rows
    are removed through ``session_manager``. On success the session list is
    reloaded. Any error is shown with ``st.error``. In every case the staged
    edits are cleared and the agent is rebuilt.
    """
    if "session_manager" not in st.session_state or "session_editor" not in st.session_state:
        return
    print(st.session_state.session_editor)
    try:
        # Validate every new row before writing anything, so one bad row does
        # not leave the earlier rows half-applied in the database.
        new_sessions = [
            parse_added_session_row(row)
            for row in st.session_state.session_editor["added_rows"]
        ]
        for session_id, system_prompt in new_sessions:
            st.session_state.session_manager.add_session(
                user_id=st.session_state.user_name,
                session_id=f"{st.session_state.user_name}?{session_id}",
                system_prompt=system_prompt,
            )
        # deleted_rows holds row indices into the table, which was rendered
        # from current_sessions.
        for row_index in st.session_state.session_editor["deleted_rows"]:
            st.session_state.session_manager.remove_session(
                session_id=f"{st.session_state.user_name}?{st.session_state.current_sessions[row_index]['session_id']}",
            )
        refresh_sessions()
    except Exception as e:
        st.error(f"{type(e)}: {str(e)}")
    finally:
        st.session_state.session_editor["added_rows"] = []
        st.session_state.session_editor["deleted_rows"] = []
        refresh_agent()
def build_session_manager():
    """Create a ``SessionManager`` from the MyScale connection settings imported from ``helper``."""
    connection = dict(
        host=MYSCALE_HOST,
        port=MYSCALE_PORT,
        username=MYSCALE_USER,
        password=MYSCALE_PASSWORD,
    )
    return SessionManager(**connection)
def find_session_index(sessions, session_id, fallback=None):
    """Return the index of the first session whose ``session_id`` matches.

    Args:
        sessions: Sequence of mappings that each have a ``"session_id"`` key.
        session_id: Identifier to look for.
        fallback: Value returned when no entry matches.
    """
    for idx, sess in enumerate(sessions):
        if sess["session_id"] == session_id:
            return idx
    return fallback


def refresh_sessions():
    """Reload the user's sessions into ``current_sessions`` and update ``sel_sess``.

    If the user has no sessions, a ``default`` session using
    ``DEFAULT_SYSTEM_PROMPT`` is created. The current selection is kept if its
    ``session_id`` still exists. Otherwise it falls back to the ``default``
    session, and then to the first session.
    """
    manager = st.session_state.session_manager
    user_name = st.session_state.user_name
    st.session_state["current_sessions"] = manager.list_sessions(user_name)
    if not st.session_state.current_sessions:
        # No sessions yet: create a "default" one for this user.
        manager.add_session(
            user_name,
            f"{user_name}?default",
            DEFAULT_SYSTEM_PROMPT,
        )
        st.session_state["current_sessions"] = manager.list_sessions(user_name)
    sessions = st.session_state.current_sessions

    # chat_page() calls this on every rerun, before drawing the selectbox keyed
    # "sel_sess". Resetting the selection unconditionally would snap the
    # selectbox back to "default" while the agent stays on the user's choice.
    current = st.session_state["sel_sess"] if "sel_sess" in st.session_state else None
    current_id = current.get("session_id") if isinstance(current, dict) else None
    index = find_session_index(sessions, current_id)
    if index is None:
        index = find_session_index(sessions, "default", fallback=0)
    st.session_state.sel_sess = sessions[index]
def refresh_agent():
    """Rebuild the chat agent for the currently selected session.

    The selected session (``sel_sess``) and tools (``selected_tools``) are read
    from session state. If either is missing, this falls back to the default
    session/prompt and the Wikipedia self-query retriever. The new agent is
    stored under ``agent``, and ``session_manager`` is replaced with a fresh
    instance.
    """
    # Same default session that chat_page() seeds on first load.
    if "sel_sess" in st.session_state:
        sel_sess = st.session_state.sel_sess
    else:
        sel_sess = {"session_id": "default", "system_prompt": DEFAULT_SYSTEM_PROMPT}
    if "selected_tools" in st.session_state:
        tools = st.session_state.selected_tools
    else:
        tools = ["LangChain Self Query Retriever For Wikipedia"]
    with st.spinner("Initializing session..."):
        session_id = f"{st.session_state.user_name}?{sel_sess['session_id']}"
        print("??? Changed to ", session_id)
        st.session_state["agent"] = build_agents(
            session_id,
            tools,
            system_prompt=sel_sess["system_prompt"],
        )
        st.session_state["session_manager"] = build_session_manager()
def _format_message_timestamp(msg):
    """Return ``msg.additional_kwargs["timestamp"]`` as ISO-8601 local time.

    Returns ``None`` when the message has no ``timestamp`` entry, so such a
    history still renders instead of raising ``KeyError``.
    """
    timestamp = msg.additional_kwargs.get("timestamp")
    if timestamp is None:
        return None
    return datetime.datetime.fromtimestamp(timestamp).isoformat()


def chat_page():
    """Render the chat page.

    The sidebar contains:
    - session management (an editable table of the user's sessions);
    - session selection;
    - knowledge-base tool selection;
    - clear-history and logout buttons.

    The main area replays the current agent's memory and ends with a chat
    input box handled by ``on_chat_submit``.
    """
    if "sel_sess" not in st.session_state:
        st.session_state["sel_sess"] = {
            "session_id": "default",
            "system_prompt": DEFAULT_SYSTEM_PROMPT,
        }
    st.session_state["session_manager"] = build_session_manager()
    with st.sidebar:
        with st.expander("Session Management"):
            refresh_sessions()
            st.data_editor(
                st.session_state.current_sessions,
                num_rows="dynamic",
                key="session_editor",
                use_container_width=True,
            )
            st.button("Submit Change!", on_click=on_session_change_submit)
        with st.expander("Session Selection", expanded=True):
            try:
                dfl_indx = [
                    x["session_id"] for x in st.session_state.current_sessions
                ].index("default")
            except Exception as e:
                print("*** ", str(e))
                dfl_indx = 0
            st.selectbox(
                "Choose a session be chat:",
                options=st.session_state.current_sessions,
                index=dfl_indx,
                key="sel_sess",
                format_func=lambda x: x["session_id"],
                on_change=refresh_agent,
            )
            print(st.session_state.sel_sess)
        with st.expander("Tool Settings", expanded=True):
            st.multiselect(
                "Knowledge Base",
                st.session_state.tools.keys(),
                default=["LangChain Self Query Retriever For Wikipedia"],
                key="selected_tools",
                on_change=refresh_agent,
            )
        st.button("Clear Chat History", on_click=clear_history)
        st.button("Logout", on_click=back_to_main)
    if "agent" not in st.session_state:
        refresh_agent()
    print("!!! ", st.session_state.agent.memory.chat_memory.session_id)
    for msg in st.session_state.agent.memory.chat_memory.messages:
        speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
        if isinstance(msg, FunctionMessage):
            with st.chat_message("Knowledge Base", avatar="π"):
                stamp = _format_message_timestamp(msg)
                if stamp is not None:
                    st.write(f"*{stamp}*")
                st.write("Retrieved from knowledge base:")
                try:
                    # SECURITY(review): `eval` executes whatever text is stored in
                    # the chat history. If that content can be influenced from
                    # outside (e.g. retrieved documents), consider
                    # `ast.literal_eval` after confirming the stored format.
                    records = eval(msg.content)
                    st.dataframe(pd.DataFrame.from_records(map(dict, records)))
                except Exception:
                    # Not a list of records: show the raw tool output instead.
                    # (A bare `except:` would also swallow BaseException-derived
                    # interrupts such as KeyboardInterrupt.)
                    st.write(msg.content)
        elif len(msg.content) > 0:
            with st.chat_message(speaker):
                print(type(msg), msg.dict())
                stamp = _format_message_timestamp(msg)
                if stamp is not None:
                    st.write(f"*{stamp}*")
                st.write(f"{msg.content}")
    st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")