# main.py
"""Session-state helpers for the dataset search and data-agent pages.

Every function operates on a dict-like ``session_data`` / ``session_state``
mapping supplied by the caller; no module-level state is kept here.
"""
import logging
import uuid
from typing import List

import pandas as pd
from langchain_community.chat_message_histories import ChatMessageHistory
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.runnables.history import RunnableWithMessageHistory

from src.agents import (
    create_supervisor_agent,
    create_search_agent,
    create_visualization_agent,
    create_pandas_agent,
    create_hard_coded_visualization_agent,
    create_oceanographer_agent,
)
from src.search.dataset_utils import fetch_dataset, convert_df_to_csv
from src.memory import CustomMemorySaver


def initialize_session_state(session_state: dict):
    """Fill in every missing session key with its default value.

    Keys already present in ``session_state`` are left untouched. The
    default objects (including the ``MemorySaver`` and the
    ``ChatMessageHistory``) are constructed on every call, whether or not
    they end up being stored.
    """
    defaults = {
        "messages_search": [],
        "messages_data_agent": [],
        "datasets_cache": {},
        "datasets_info": None,
        "active_datasets": [],
        "selected_datasets": set(),
        "show_dataset": True,
        "current_page": "search",
        "dataset_dfs": {},
        "dataset_names": {},
        "saved_plot_paths": {},
        "memory": MemorySaver(),
        "visualization_agent_used": False,
        "chat_history": ChatMessageHistory(session_id="search-agent-session"),
        "search_method": "PANGAEA Search (default)",
        "selected_text": "",
        "new_plot_generated": False,
        "execution_history": [],
    }
    for name, default in defaults.items():
        if name in session_state:
            continue
        session_state[name] = default


def get_search_agent(datasets_info, model_name, api_key):
    """Build the search agent for ``datasets_info``.

    NOTE(review): ``model_name`` and ``api_key`` are accepted but not
    forwarded to ``create_search_agent`` -- confirm that is intended.
    """
    return create_search_agent(datasets_info=datasets_info)


def process_search_query(user_input: str, search_agent, session_data: dict):
    """Run one search-agent turn and return the agent's ``"output"`` value.

    ``session_data["chat_history"]`` is replaced by a fresh history rebuilt
    from ``session_data["messages_search"]``; only entries whose role is
    ``"user"`` or ``"assistant"`` are replayed. The agent is then invoked
    with at most the last 10 messages of that history.
    """
    rebuilt = ChatMessageHistory(session_id="search-agent-session")
    session_data["chat_history"] = rebuilt
    for entry in session_data["messages_search"]:
        role = entry["role"]
        if role == "user":
            rebuilt.add_user_message(entry["content"])
        elif role == "assistant":
            rebuilt.add_ai_message(entry["content"])

    # The factory name and its ``session_id`` parameter are kept as-is:
    # the history wrapper receives this callable and may inspect it.
    def get_truncated_chat_history(session_id):
        """Return a new history holding only the 10 most recent messages."""
        window = ChatMessageHistory(session_id=session_id)
        # Read the history at call time, not at definition time.
        for recent in session_data["chat_history"].messages[-10:]:
            if isinstance(recent, HumanMessage):
                window.add_user_message(recent.content)
            elif isinstance(recent, AIMessage):
                window.add_ai_message(recent.content)
            else:
                window.add_message(recent)
        return window

    agent_with_history = RunnableWithMessageHistory(
        search_agent,
        get_truncated_chat_history,
        input_messages_key="input",
        history_messages_key="chat_history",
    )
    result = agent_with_history.invoke(
        {"input": user_input},
        {"configurable": {"session_id": "search-agent-session"}},
    )
    return result["output"]


def add_user_message_to_search(user_input: str, session_data: dict):
    """Append a ``"user"`` entry to the search conversation log."""
    entry = {"role": "user", "content": user_input}
    session_data["messages_search"].append(entry)


def add_assistant_message_to_search(content: str, session_data: dict):
    """Append an ``"assistant"`` entry to the search conversation log."""
    entry = {"role": "assistant", "content": content}
    session_data["messages_search"].append(entry)


def load_selected_datasets_into_cache(selected_datasets, session_data: dict):
    """Fetch every selected DOI that is not cached yet and store the result.

    A DOI for which ``fetch_dataset`` returns ``None`` as the dataset is
    skipped and nothing is stored for it.
    """
    for doi in selected_datasets:
        if doi in session_data["datasets_cache"]:
            continue
        dataset, name = fetch_dataset(doi)
        if dataset is None:
            continue
        session_data["datasets_cache"][doi] = (dataset, name)
        session_data["dataset_dfs"][doi] = dataset
        session_data["dataset_names"][doi] = name


def set_active_datasets_from_selection(session_data: dict):
    """Copy the current selection into ``active_datasets`` as a list."""
    session_data["active_datasets"] = list(session_data["selected_datasets"])


def get_datasets_info_for_active_datasets(session_data: dict):
    """Describe each active dataset that is present in the cache.

    Returns an empty list when ``session_data["datasets_info"]`` is
    ``None``. Otherwise returns one dict per cached active DOI with the
    keys ``doi``, ``name``, ``description``, ``df_head`` and ``dataset``.
    """
    info_table = session_data["datasets_info"]
    if info_table is None:
        return []

    collected = []
    for doi in session_data["active_datasets"]:
        dataset, name = session_data["datasets_cache"].get(doi, (None, None))
        if dataset is None:
            continue

        matches = info_table.loc[info_table["DOI"] == doi, "Short Description"]
        if len(matches) > 0:
            description = matches.values[0]
        else:
            description = "No description"

        collected.append({
            "doi": doi,
            "name": name,
            "description": description,
            "df_head": dataset.head().to_string(),
            "dataset": dataset,
        })
    return collected


def create_and_invoke_supervisor_agent(user_query: str, datasets_info: list, memory, session_data: dict):
    """Build the supervisor graph and invoke it on the recent conversation.

    Returns ``None`` when ``create_supervisor_agent`` returns ``None``;
    otherwise returns whatever ``graph.invoke`` returns. Only the last 7
    entries of ``session_data["messages_data_agent"]`` are passed along.
    """
    graph = create_supervisor_agent(user_query, datasets_info, memory)
    if graph is None:
        return None

    # Convert the stored dict entries to message objects. Anything that is
    # not a "user" entry becomes an AIMessage named after its role.
    converted = []
    for entry in session_data["messages_data_agent"]:
        role = entry["role"]
        text = entry["content"]
        if role == "user":
            converted.append(HumanMessage(content=text, name="User"))
        else:
            speaker = "Assistant" if role == "assistant" else role
            converted.append(AIMessage(content=text, name=speaker))

    initial_state = {
        "messages": converted[-7:],
        "next": "supervisor",
        "agent_scratchpad": [],
        "input": user_query,
        "plot_images": [],
        "last_agent_message": "",
    }

    # The fallback UUID is used for this call only; it is not written back
    # to session_data.
    thread_id = session_data.get("thread_id", str(uuid.uuid4()))
    # NOTE(review): "recursion_limit" sits inside "configurable" here.
    # LangGraph documents it as a top-level config key, so this value may
    # be ignored -- confirm before moving it, since enforcing a limit of 5
    # could change behaviour.
    config = {"configurable": {"thread_id": thread_id, "recursion_limit": 5}}
    return graph.invoke(initial_state, config=config)


def add_user_message_to_data_agent(user_input: str, session_data: dict):
    """Append a ``"user"`` entry (content formatted as text) to the data-agent log."""
    entry = {"role": "user", "content": f"{user_input}"}
    session_data["messages_data_agent"].append(entry)


def add_assistant_message_to_data_agent(content: str, plot_images, visualization_agent_used, session_data: dict):
    """Append an ``"assistant"`` entry with its plots to the data-agent log.

    A falsy ``plot_images`` is stored as an empty list.
    """
    if plot_images:
        images = plot_images
    else:
        images = []
    session_data["messages_data_agent"].append({
        "role": "assistant",
        "content": content,
        "plot_images": images,
        "visualization_agent_used": visualization_agent_used,
    })


def convert_dataset_to_csv(dataset: pd.DataFrame) -> bytes:
    """Delegate to ``convert_df_to_csv`` and return its result."""
    return convert_df_to_csv(dataset)


def has_new_plot(session_data: dict) -> bool:
    """Return the ``new_plot_generated`` flag, or ``False`` when it is unset."""
    return session_data.get("new_plot_generated", False)


def reset_new_plot_flag(session_data: dict):
    """Set the ``new_plot_generated`` flag to ``False``."""
    session_data["new_plot_generated"] = False


def get_dataset_csv_name(doi: str) -> str:
    """Return ``dataset_<last path segment of doi>.csv``."""
    last_segment = doi.rsplit("/", 1)[-1]
    return f"dataset_{last_segment}.csv"


def set_current_page(session_data: dict, page_name: str):
    """Store ``page_name`` under ``current_page``."""
    session_data["current_page"] = page_name


def set_selected_text(session_data: dict, text: str):
    """Store ``text`` under ``selected_text``."""
    session_data["selected_text"] = text


def set_show_dataset(session_data: dict, show: bool):
    """Store ``show`` under ``show_dataset``."""
    session_data["show_dataset"] = show


def set_dataset_for_data_agent(session_data: dict, doi: str, csv_data: bytes, dataset: pd.DataFrame, name: str):
    """Store one dataset for the data agent and switch to the data-agent page.

    ``doi`` is accepted but not used by this function.
    """
    session_data["dataset_csv"] = csv_data
    session_data["dataset_df"] = dataset
    session_data["dataset_name"] = name
    session_data["current_page"] = "data_agent"


def ensure_memory(session_data: dict):
    """Create the ``memory`` entry when it is missing.

    NOTE(review): this installs a ``CustomMemorySaver``, whereas
    ``initialize_session_state`` seeds ``memory`` with a ``MemorySaver`` --
    confirm which checkpointer is intended.
    """
    if "memory" in session_data:
        return
    session_data["memory"] = CustomMemorySaver()


def ensure_thread_id(session_data: dict):
    """Create a random ``thread_id`` (UUID4 string) when none is stored."""
    if "thread_id" in session_data:
        return
    session_data["thread_id"] = str(uuid.uuid4())