| from utils_app import _update_session, supabase_client, _get_session_messages, _add_footnote_description | |
| from supabase_memory import SupabaseChatMessageHistory | |
| from graph import _get_graph | |
| from langchain_core.messages import AIMessage, HumanMessage | |
| import os | |
| import pandas as pd | |
| from prompts import _AGENT_SYSTEM_TEMPLATE, _ANSWERER_SYSTEM_TEMPLATE | |
async def _run_graph(
    session_id: str,
    input: str,
    agent_model_name: str = "gpt-4o",
    agent_temperature: float = 0.0,
    answerer_model_name: str = "claude-3-5-sonnet-20240620",
    answerer_temperature: float = 0.0,
    collection_index: int = 0,
    use_doctrines: bool = True,
    search_type: str = "similarity",
    k: int = 10,
    similarity_threshold: float = 0.65,
    agent_system_prompt_template: str = _AGENT_SYSTEM_TEMPLATE,
    answerer_system_prompt_template: str = _ANSWERER_SYSTEM_TEMPLATE,
):
    """Answer one user turn with the agent/answerer graph and persist it.

    The user message and an empty AI placeholder message are written to the
    session's message history *before* the graph runs. On success the
    placeholder is updated with the graph's answer, annotated with the
    retrieved documents' metadata, the standalone question and footnote
    descriptions. On failure the placeholder is updated with
    ``error_log=str(e)`` instead.

    Args:
        session_id: Chat session identifier. It is used to open the message
            history, to store this run's settings via ``_update_session``
            and to fetch the messages that are returned.
        input: The user's query text. It shadows the ``input`` builtin; the
            name is kept because callers may pass it by keyword.
        agent_model_name, agent_temperature, agent_system_prompt_template:
            Agent settings, forwarded to ``_get_graph`` and stored in the
            session metadata.
        answerer_model_name, answerer_temperature,
        answerer_system_prompt_template:
            Answerer settings, forwarded to ``_get_graph`` and stored in the
            session metadata.
        collection_index, use_doctrines, search_type, k, similarity_threshold:
            Forwarded unchanged to ``_get_graph`` and stored in the session
            metadata (presumably retrieval settings -- confirm in ``graph``).

    Returns:
        On success, ``_get_session_messages(session_id)``. If running the
        graph or persisting its answer fails, the same call's result with
        ``(input, "Oops! An error occurred: <error>")`` appended.

    Raises:
        KeyError: If the ``MESSAGES_TABLE_NAME`` environment variable is
            unset.
        Exceptions raised while setting up the session/history, while
        fetching the returned messages, or while recording the error log
        propagate to the caller.
    """
    import logging  # local import keeps this function self-contained

    logger = logging.getLogger(__name__)

    # NOTE(review): the memory/_update_session/_get_session_messages calls
    # are synchronous inside this coroutine; if they do network I/O they
    # block the event loop while running -- confirm.
    memory = SupabaseChatMessageHistory(
        session_id=session_id,
        table_name=os.environ["MESSAGES_TABLE_NAME"],
        session_name="chat",
        client=supabase_client,
    )

    # Record the settings used for this run on the session.
    _update_session(
        session_id,
        metadata={
            "agent_model_name": agent_model_name,
            "agent_temperature": agent_temperature,
            "answerer_model_name": answerer_model_name,
            "answerer_temperature": answerer_temperature,
            "collection_index": collection_index,
            "use_doctrines": use_doctrines,
            "search_type": search_type,
            "k": k,
            "similarity_threshold": similarity_threshold,
            "agent_system_prompt_template": agent_system_prompt_template,
            "answerer_system_prompt_template": answerer_system_prompt_template,
        },
    )

    graph = _get_graph(
        agent_model_name=agent_model_name,
        agent_system_template=agent_system_prompt_template,
        agent_temperature=agent_temperature,
        answerer_model_name=answerer_model_name,
        answerer_system_template=answerer_system_prompt_template,
        answerer_temperature=answerer_temperature,
        collection_index=collection_index,
        use_doctrines=use_doctrines,
        search_type=search_type,
        similarity_threshold=similarity_threshold,
        k=k,
    )

    # Read the history *before* appending this turn; the current query is
    # passed to the graph separately as "query".
    chat_history = memory.messages
    input_message_id = memory.add_message(message=HumanMessage(input))
    # Empty AI reply linked to the user message; it is filled in below with
    # either the answer or an error log.
    output_message_id = memory.add_message(
        message=AIMessage(""),
        query_id=input_message_id,
    )

    try:
        final_state = await graph.ainvoke(
            input={
                "query": input,
                "chat_history": chat_history,
            }
        )
        response = final_state["response"]
        response_message = response["answer"]
        # Element 0 of each retrieved item carries `.metadata` (presumably a
        # (Document, score) pair -- confirm in `graph`).
        docs_metadata = [doc[0].metadata for doc in response["docs"]]
        response_message.response_metadata["docs"] = docs_metadata
        response_message.response_metadata["standalone_question"] = response["standalone_question"]
        response_message.content = _add_footnote_description(
            response_message.content, docs_metadata
        )
        memory.update_message(
            message=response_message,
            message_id=output_message_id,
        )
    except Exception as e:
        # Top-level boundary for a chat turn: keep the traceback in the logs,
        # record the failure on the placeholder message, and report it to
        # the caller instead of raising.
        logger.exception("Graph run failed for session %s", session_id)
        memory.update_message(
            message_id=output_message_id,
            error_log=str(e),
        )
        return _get_session_messages(session_id) + [(input, f"Oops! An error occurred: {str(e)}")]

    # Deliberately outside the try: a failure while fetching messages must
    # not flag the already-saved answer with an error log.
    return _get_session_messages(session_id)