# pangaeagpt / main.py
# dmpantiu's picture
# Upload 8 files
# 27b66c3 unverified
# main.py
import logging
import uuid
from typing import List
import pandas as pd
from langchain_community.chat_message_histories import ChatMessageHistory
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.runnables.history import RunnableWithMessageHistory
from src.agents import (
create_supervisor_agent,
create_search_agent,
create_visualization_agent,
create_pandas_agent,
create_hard_coded_visualization_agent,
create_oceanographer_agent,
)
from src.search.dataset_utils import fetch_dataset, convert_df_to_csv
from src.memory import CustomMemorySaver
def initialize_session_state(session_state: dict):
    """Seed ``session_state`` with a default for every key it lacks.

    Keys that are already present are left untouched, so repeated calls
    never overwrite existing values. A fresh set of default objects is
    built on every call.

    Args:
        session_state: Mutable mapping holding the session's state; it is
            modified in place.
    """
    # NOTE(review): the default "memory" is a plain MemorySaver, whereas
    # ensure_memory() falls back to CustomMemorySaver -- confirm which of
    # the two is intended.
    fresh_defaults = dict(
        messages_search=[],
        messages_data_agent=[],
        datasets_cache={},
        datasets_info=None,
        active_datasets=[],
        selected_datasets=set(),
        show_dataset=True,
        current_page="search",
        dataset_dfs={},
        dataset_names={},
        saved_plot_paths={},
        memory=MemorySaver(),
        visualization_agent_used=False,
        chat_history=ChatMessageHistory(session_id="search-agent-session"),
        search_method="PANGAEA Search (default)",
        selected_text="",
        new_plot_generated=False,
        execution_history=[],
    )
    for name, initial in fresh_defaults.items():
        if name in session_state:
            continue
        session_state[name] = initial
def get_search_agent(datasets_info, model_name, api_key):
    """Create a search agent through ``create_search_agent``.

    Args:
        datasets_info: Forwarded to ``create_search_agent`` as the
            ``datasets_info`` keyword argument.
        model_name: Accepted but not used by this function.
        api_key: Accepted but not used by this function.

    Returns:
        Whatever ``create_search_agent`` returns.
    """
    agent = create_search_agent(datasets_info=datasets_info)
    return agent
def process_search_query(user_input: str, search_agent, session_data: dict):
    """Invoke ``search_agent`` on ``user_input`` with recent history attached.

    ``session_data["chat_history"]`` is replaced by a new
    ``ChatMessageHistory`` rebuilt from ``session_data["messages_search"]``;
    only entries whose role is "user" or "assistant" are copied. The agent
    is then wrapped in ``RunnableWithMessageHistory`` and invoked with a
    history holding at most the last ten of those messages.

    Args:
        user_input: Text passed to the agent under the ``"input"`` key.
        search_agent: Runnable to wrap and invoke.
        session_data: Session mapping; its ``"chat_history"`` entry is
            overwritten.

    Returns:
        The ``"output"`` entry of the agent's response.
    """
    session_key = "search-agent-session"
    history_window = 10

    rebuilt = ChatMessageHistory(session_id=session_key)
    session_data["chat_history"] = rebuilt
    for entry in session_data["messages_search"]:
        role = entry["role"]
        if role == "user":
            rebuilt.add_user_message(entry["content"])
        elif role == "assistant":
            rebuilt.add_ai_message(entry["content"])

    def _recent_history(session_id):
        # Copy the tail of the stored history into a new history object;
        # human/AI messages are re-added by content, anything else as-is.
        recent = ChatMessageHistory(session_id=session_id)
        for past in session_data["chat_history"].messages[-history_window:]:
            if isinstance(past, HumanMessage):
                recent.add_user_message(past.content)
            elif isinstance(past, AIMessage):
                recent.add_ai_message(past.content)
            else:
                recent.add_message(past)
        return recent

    agent_with_history = RunnableWithMessageHistory(
        search_agent,
        _recent_history,
        input_messages_key="input",
        history_messages_key="chat_history",
    )
    run_config = {"configurable": {"session_id": session_key}}
    response = agent_with_history.invoke({"input": user_input}, run_config)
    return response["output"]
def add_user_message_to_search(user_input: str, session_data: dict):
    """Append ``user_input`` to ``session_data["messages_search"]`` as a "user" entry."""
    entry = {"role": "user", "content": user_input}
    session_data["messages_search"].append(entry)
def add_assistant_message_to_search(content: str, session_data: dict):
    """Append ``content`` to ``session_data["messages_search"]`` as an "assistant" entry."""
    entry = {"role": "assistant", "content": content}
    session_data["messages_search"].append(entry)
def load_selected_datasets_into_cache(selected_datasets, session_data: dict):
    """Fetch each selected dataset that is not cached yet and store it.

    For every DOI missing from ``session_data["datasets_cache"]`` the
    dataset is requested through ``fetch_dataset``. When a dataset comes
    back (i.e. it is not ``None``) it is recorded in ``"datasets_cache"``
    as a ``(dataset, name)`` pair, in ``"dataset_dfs"`` and in
    ``"dataset_names"``.

    Args:
        selected_datasets: Iterable of DOIs to make available.
        session_data: Session mapping holding the three caches; modified
            in place.
    """
    for doi in selected_datasets:
        if doi in session_data["datasets_cache"]:
            continue  # already loaded earlier
        dataset, name = fetch_dataset(doi)
        if dataset is None:
            continue  # nothing to store for this DOI
        session_data["datasets_cache"][doi] = (dataset, name)
        session_data["dataset_dfs"][doi] = dataset
        session_data["dataset_names"][doi] = name
def set_active_datasets_from_selection(session_data: dict):
    """Store a list copy of ``session_data["selected_datasets"]`` under ``"active_datasets"``."""
    selection = session_data["selected_datasets"]
    session_data["active_datasets"] = [*selection]
def get_datasets_info_for_active_datasets(session_data: dict):
    """Describe every active dataset that is present in the cache.

    Args:
        session_data: Session mapping. Reads ``"datasets_info"`` (a table
            with "DOI" and "Short Description" columns, or ``None``),
            ``"active_datasets"`` and ``"datasets_cache"``.

    Returns:
        An empty list when ``"datasets_info"`` is ``None``. Otherwise one
        dict per active DOI whose cached dataset is not ``None``, with the
        keys ``doi``, ``name``, ``description`` (first matching
        "Short Description", or "No description" when no row matches),
        ``df_head`` (``dataset.head()`` rendered as text) and ``dataset``.
    """
    catalog = session_data["datasets_info"]
    if catalog is None:
        return []

    summaries = []
    for doi in session_data["active_datasets"]:
        dataset, name = session_data["datasets_cache"].get(doi, (None, None))
        if dataset is None:
            continue  # active but not cached: skipped

        matches = catalog.loc[catalog["DOI"] == doi, "Short Description"]
        if len(matches) > 0:
            description = matches.values[0]
        else:
            description = "No description"

        preview = dataset.head().to_string()
        summaries.append(
            dict(
                doi=doi,
                name=name,
                description=description,
                df_head=preview,
                dataset=dataset,
            )
        )
    return summaries
def create_and_invoke_supervisor_agent(user_query: str, datasets_info: list, memory, session_data: dict):
    """Build the supervisor graph and run it on the recent conversation.

    The entries of ``session_data["messages_data_agent"]`` are converted to
    messages ("user" entries become ``HumanMessage`` named "User", all
    others ``AIMessage`` named "Assistant" or after their role) and only
    the last seven are handed to the graph.

    Args:
        user_query: Passed to ``create_supervisor_agent`` and stored as
            ``"input"`` in the initial state.
        datasets_info: Passed to ``create_supervisor_agent``.
        memory: Passed to ``create_supervisor_agent``.
        session_data: Session mapping; ``"messages_data_agent"`` is read
            and ``"thread_id"`` is used when present.

    Returns:
        ``None`` when ``create_supervisor_agent`` returns ``None``,
        otherwise the result of ``graph.invoke``.
    """
    graph = create_supervisor_agent(user_query, datasets_info, memory)
    if graph is None:
        return None

    def _to_message(entry):
        # "user" -> HumanMessage; every other role -> AIMessage.
        role = entry["role"]
        if role == "user":
            return HumanMessage(content=entry["content"], name="User")
        speaker = "Assistant" if role == "assistant" else role
        return AIMessage(content=entry["content"], name=speaker)

    transcript = [_to_message(entry) for entry in session_data["messages_data_agent"]]

    initial_state = {
        "messages": transcript[-7:],
        "next": "supervisor",
        "agent_scratchpad": [],
        "input": user_query,
        "plot_images": [],
        "last_agent_message": "",
    }

    # A random thread id is used (and not stored) when the session has none.
    fallback_thread = str(uuid.uuid4())
    thread_id = session_data.get("thread_id", fallback_thread)
    # NOTE(review): "recursion_limit" is nested under "configurable" here;
    # LangGraph documents it as a top-level config key, so this value may
    # have no effect -- confirm the intended limit before moving it.
    run_config = {"configurable": {"thread_id": thread_id, "recursion_limit": 5}}
    return graph.invoke(initial_state, config=run_config)
def add_user_message_to_data_agent(user_input: str, session_data: dict):
    """Append ``user_input``, formatted as text, to ``session_data["messages_data_agent"]`` as a "user" entry."""
    entry = {"role": "user", "content": format(user_input)}
    session_data["messages_data_agent"].append(entry)
def add_assistant_message_to_data_agent(content: str, plot_images, visualization_agent_used, session_data: dict):
    """Append an "assistant" entry to ``session_data["messages_data_agent"]``.

    Args:
        content: Message text.
        plot_images: Stored as given when truthy; otherwise an empty list
            is stored instead.
        visualization_agent_used: Stored unchanged under the same name.
        session_data: Session mapping whose ``"messages_data_agent"`` list
            is appended to.
    """
    session_data["messages_data_agent"].append(
        dict(
            role="assistant",
            content=content,
            plot_images=plot_images or [],
            visualization_agent_used=visualization_agent_used,
        )
    )
def convert_dataset_to_csv(dataset: pd.DataFrame) -> bytes:
    """Return ``dataset`` converted by ``convert_df_to_csv``."""
    csv_payload = convert_df_to_csv(dataset)
    return csv_payload
def has_new_plot(session_data: dict) -> bool:
    """Return the stored ``"new_plot_generated"`` value, or ``False`` when it is absent."""
    flag = session_data.get("new_plot_generated", False)
    return flag
def reset_new_plot_flag(session_data: dict):
    """Set ``session_data["new_plot_generated"]`` to ``False``."""
    session_data.update(new_plot_generated=False)
def get_dataset_csv_name(doi: str) -> str:
    """Return ``dataset_<tail>.csv`` where ``<tail>`` is the text after the last "/" in ``doi``.

    When ``doi`` contains no "/", the whole string is used as the tail.
    """
    tail = doi.rpartition("/")[2]
    return f"dataset_{tail}.csv"
def set_current_page(session_data: dict, page_name: str):
    """Store ``page_name`` under ``session_data["current_page"]``."""
    session_data.update(current_page=page_name)
def set_selected_text(session_data: dict, text: str):
    """Store ``text`` under ``session_data["selected_text"]``."""
    session_data.update(selected_text=text)
def set_show_dataset(session_data: dict, show: bool):
    """Store ``show`` under ``session_data["show_dataset"]``."""
    session_data.update(show_dataset=show)
def set_dataset_for_data_agent(session_data: dict, doi: str, csv_data: bytes, dataset: pd.DataFrame, name: str):
    """Record one dataset in the session and switch the page to "data_agent".

    Args:
        session_data: Session mapping; modified in place.
        doi: Accepted but not used by this function.
        csv_data: Stored under ``"dataset_csv"``.
        dataset: Stored under ``"dataset_df"``.
        name: Stored under ``"dataset_name"``.
    """
    session_data.update(
        {
            "dataset_csv": csv_data,
            "dataset_df": dataset,
            "dataset_name": name,
            "current_page": "data_agent",
        }
    )
def ensure_memory(session_data: dict):
    """Store a new ``CustomMemorySaver`` under ``"memory"`` unless that key already exists."""
    if "memory" in session_data:
        return
    session_data["memory"] = CustomMemorySaver()
def ensure_thread_id(session_data: dict):
    """Store a new random UUID4 string under ``"thread_id"`` unless that key already exists."""
    if "thread_id" in session_data:
        return
    session_data["thread_id"] = str(uuid.uuid4())