File size: 7,737 Bytes
27b66c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
# main.py
import logging
import uuid
from typing import List

import pandas as pd
from langchain_community.chat_message_histories import ChatMessageHistory
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.runnables.history import RunnableWithMessageHistory

from src.agents import (
    create_supervisor_agent,
    create_search_agent,
    create_visualization_agent,
    create_pandas_agent,
    create_hard_coded_visualization_agent,
    create_oceanographer_agent,
)
from src.search.dataset_utils import fetch_dataset, convert_df_to_csv
from src.memory import CustomMemorySaver


def initialize_session_state(session_state: dict):
    """Fill in every missing session key with its default value.

    Keys that are already present in ``session_state`` are left untouched.
    A fresh set of default objects is built on every call, so mutable
    defaults are never shared between calls.

    Args:
        session_state: Mapping holding the per-session state; it is mutated
            in place.
    """
    defaults = {
        "messages_search": [],
        "messages_data_agent": [],
        "datasets_cache": {},
        "datasets_info": None,
        "active_datasets": [],
        "selected_datasets": set(),
        "show_dataset": True,
        "current_page": "search",
        "dataset_dfs": {},
        "dataset_names": {},
        "saved_plot_paths": {},
        "memory": MemorySaver(),
        "visualization_agent_used": False,
        "chat_history": ChatMessageHistory(session_id="search-agent-session"),
        "search_method": "PANGAEA Search (default)",
        "selected_text": "",
        "new_plot_generated": False,
        "execution_history": [],
    }

    # Only the keys the session does not know yet are written.
    for name in defaults:
        if name in session_state:
            continue
        session_state[name] = defaults[name]


def get_search_agent(datasets_info, model_name, api_key):
    """Build a search agent for the given dataset information.

    Args:
        datasets_info: Forwarded to ``create_search_agent``.
        model_name: Accepted but not used by this function.
        api_key: Accepted but not used by this function.

    Returns:
        Whatever ``create_search_agent`` returns.
    """
    agent = create_search_agent(datasets_info=datasets_info)
    return agent


def process_search_query(user_input: str, search_agent, session_data: dict):
    """Run one search query through the agent with recent chat history.

    The chat history stored under ``session_data["chat_history"]`` is rebuilt
    from ``session_data["messages_search"]`` before the agent is invoked.

    Args:
        user_input: The query text passed to the agent as ``"input"``.
        search_agent: Runnable wrapped with ``RunnableWithMessageHistory``.
        session_data: Session mapping; its ``"chat_history"`` entry is replaced.

    Returns:
        The ``"output"`` entry of the agent response.
    """
    search_session_id = "search-agent-session"

    # Rebuild the stored history from the plain message records; records
    # whose role is neither "user" nor "assistant" are skipped.
    session_data["chat_history"] = ChatMessageHistory(session_id=search_session_id)
    for entry in session_data["messages_search"]:
        role = entry["role"]
        if role == "user":
            session_data["chat_history"].add_user_message(entry["content"])
        elif role == "assistant":
            session_data["chat_history"].add_ai_message(entry["content"])

    def get_truncated_chat_history(session_id):
        # Hand the agent a copy limited to the 10 most recent messages.
        window = ChatMessageHistory(session_id=session_id)
        for past in session_data["chat_history"].messages[-10:]:
            if isinstance(past, HumanMessage):
                window.add_user_message(past.content)
            elif isinstance(past, AIMessage):
                window.add_ai_message(past.content)
            else:
                window.add_message(past)
        return window

    agent_with_history = RunnableWithMessageHistory(
        search_agent,
        get_truncated_chat_history,
        input_messages_key="input",
        history_messages_key="chat_history",
    )

    invoke_config = {"configurable": {"session_id": search_session_id}}
    result = agent_with_history.invoke({"input": user_input}, invoke_config)
    return result["output"]


def add_user_message_to_search(user_input: str, session_data: dict):
    """Append a user message record to the search conversation.

    Args:
        user_input: Text stored as the record's ``"content"``.
        session_data: Session mapping whose ``"messages_search"`` list grows
            by one entry.
    """
    record = {"role": "user", "content": user_input}
    session_data["messages_search"].append(record)


def add_assistant_message_to_search(content: str, session_data: dict):
    """Append an assistant message record to the search conversation.

    Args:
        content: Text stored as the record's ``"content"``.
        session_data: Session mapping whose ``"messages_search"`` list grows
            by one entry.
    """
    record = {"role": "assistant", "content": content}
    session_data["messages_search"].append(record)


def load_selected_datasets_into_cache(selected_datasets, session_data: dict):
    """Fetch every selected dataset that is not cached yet.

    For each DOI missing from ``session_data["datasets_cache"]`` the dataset
    is fetched with ``fetch_dataset``. A fetch whose dataset is ``None`` is
    ignored; otherwise the result is stored in ``"datasets_cache"``,
    ``"dataset_dfs"`` and ``"dataset_names"``.

    Args:
        selected_datasets: Iterable of DOIs to make available.
        session_data: Session mapping holding the three caches; mutated in
            place.
    """
    for doi in selected_datasets:
        if doi in session_data["datasets_cache"]:
            continue  # already loaded earlier

        dataset, name = fetch_dataset(doi)
        if dataset is None:
            continue  # nothing to store for this DOI

        session_data["datasets_cache"][doi] = (dataset, name)
        session_data["dataset_dfs"][doi] = dataset
        session_data["dataset_names"][doi] = name


def set_active_datasets_from_selection(session_data: dict):
    """Copy the current selection into ``"active_datasets"`` as a new list.

    Args:
        session_data: Session mapping; ``"selected_datasets"`` is read and
            ``"active_datasets"`` is overwritten.
    """
    selection = session_data["selected_datasets"]
    session_data["active_datasets"] = [*selection]


def get_datasets_info_for_active_datasets(session_data: dict):
    """Collect a summary record for every active dataset found in the cache.

    Args:
        session_data: Session mapping providing ``"datasets_info"`` (a table
            with ``"DOI"`` and ``"Short Description"`` columns, or ``None``),
            ``"active_datasets"`` and ``"datasets_cache"``.

    Returns:
        A list of dicts with the keys ``doi``, ``name``, ``description``,
        ``df_head`` and ``dataset``. The list is empty when
        ``"datasets_info"`` is ``None``. Active DOIs without a cached dataset
        are left out.
    """
    if session_data["datasets_info"] is None:
        return []

    summaries = []
    for doi in session_data["active_datasets"]:
        dataset, name = session_data["datasets_cache"].get(doi, (None, None))
        if dataset is None:
            continue

        info_table = session_data["datasets_info"]
        matches = info_table.loc[info_table["DOI"] == doi, "Short Description"]
        # Fall back to a placeholder when the DOI has no row in the table.
        if len(matches) > 0:
            description = matches.values[0]
        else:
            description = "No description"

        summaries.append({
            'doi': doi,
            'name': name,
            'description': description,
            'df_head': dataset.head().to_string(),
            'dataset': dataset,
        })
    return summaries


def create_and_invoke_supervisor_agent(user_query: str, datasets_info: list, memory, session_data: dict):
    """Build the supervisor graph and run it on the recent conversation.

    Args:
        user_query: Query passed to ``create_supervisor_agent`` and placed in
            the initial state under ``"input"``.
        datasets_info: Forwarded to ``create_supervisor_agent``.
        memory: Forwarded to ``create_supervisor_agent``.
        session_data: Session mapping; ``"messages_data_agent"`` is read and
            ``"thread_id"`` is used when present.

    Returns:
        ``None`` when no graph could be created, otherwise the result of
        ``graph.invoke``.
    """
    graph = create_supervisor_agent(user_query, datasets_info, memory)
    if graph is None:
        return None

    # Translate stored records into message objects: user records become
    # human messages, everything else becomes an AI message named after
    # its role ("Assistant" for the assistant role).
    history = []
    for entry in session_data["messages_data_agent"]:
        role = entry["role"]
        if role == "user":
            history.append(HumanMessage(content=entry["content"], name="User"))
        else:
            speaker = "Assistant" if role == "assistant" else role
            history.append(AIMessage(content=entry["content"], name=speaker))

    # Only the 7 most recent messages are handed to the graph.
    initial_state = {
        "messages": history[-7:],
        "next": "supervisor",
        "agent_scratchpad": [],
        "input": user_query,
        "plot_images": [],
        "last_agent_message": ""
    }

    # A random thread id is used when the session has none; it is not
    # written back to session_data.
    thread_id = session_data.get('thread_id', str(uuid.uuid4()))
    # NOTE(review): "recursion_limit" is nested under "configurable" here;
    # LangGraph documents it as a top-level config key, so this value may
    # not take effect — confirm the intended limit before moving it.
    config = {"configurable": {"thread_id": thread_id, "recursion_limit": 5}}

    return graph.invoke(initial_state, config=config)


def add_user_message_to_data_agent(user_input: str, session_data: dict):
    """Append a user message record to the data-agent conversation.

    Args:
        user_input: Value formatted into the record's ``"content"`` string.
        session_data: Session mapping whose ``"messages_data_agent"`` list
            grows by one entry.
    """
    record = {"role": "user", "content": f"{user_input}"}
    session_data["messages_data_agent"].append(record)


def add_assistant_message_to_data_agent(content: str, plot_images, visualization_agent_used, session_data: dict):
    """Append an assistant message record to the data-agent conversation.

    Args:
        content: Text stored as the record's ``"content"``.
        plot_images: Stored under ``"plot_images"``; a falsy value is
            replaced by an empty list.
        visualization_agent_used: Stored under ``"visualization_agent_used"``.
        session_data: Session mapping whose ``"messages_data_agent"`` list
            grows by one entry.
    """
    record = dict(
        role="assistant",
        content=content,
        plot_images=plot_images or [],
        visualization_agent_used=visualization_agent_used,
    )
    session_data["messages_data_agent"].append(record)


def convert_dataset_to_csv(dataset: pd.DataFrame) -> bytes:
    """Convert a dataset to CSV by delegating to ``convert_df_to_csv``.

    Args:
        dataset: The table to convert.

    Returns:
        Whatever ``convert_df_to_csv`` returns for ``dataset``.
    """
    csv_payload = convert_df_to_csv(dataset)
    return csv_payload


def has_new_plot(session_data: dict) -> bool:
    """Report the ``"new_plot_generated"`` flag of the session.

    Args:
        session_data: Session mapping to inspect.

    Returns:
        The stored flag value, or ``False`` when the key is absent.
    """
    flag = session_data.get("new_plot_generated", False)
    return flag


def reset_new_plot_flag(session_data: dict):
    """Set the ``"new_plot_generated"`` flag of the session to ``False``.

    Args:
        session_data: Session mapping; mutated in place.
    """
    cleared = False
    session_data["new_plot_generated"] = cleared


def get_dataset_csv_name(doi: str) -> str:
    """Build a CSV file name from the last path segment of a DOI.

    Surrounding whitespace and trailing slashes are ignored, so a DOI written
    with a trailing ``/`` yields the same name as the one without it instead
    of the degenerate ``dataset_.csv``.

    Args:
        doi: DOI string or URL; segments are separated by ``/``.

    Returns:
        ``"dataset_<last segment>.csv"``.
    """
    # Trim first so the final segment is never empty just because of a
    # trailing separator or stray whitespace.
    trimmed = doi.strip().rstrip("/")
    last_segment = trimmed.rsplit("/", 1)[-1]
    return f"dataset_{last_segment}.csv"


def set_current_page(session_data: dict, page_name: str):
    """Store ``page_name`` under the session's ``"current_page"`` key.

    Args:
        session_data: Session mapping; mutated in place.
        page_name: Value to store.
    """
    target_key = "current_page"
    session_data[target_key] = page_name


def set_selected_text(session_data: dict, text: str):
    """Store ``text`` under the session's ``"selected_text"`` key.

    Args:
        session_data: Session mapping; mutated in place.
        text: Value to store.
    """
    target_key = "selected_text"
    session_data[target_key] = text


def set_show_dataset(session_data: dict, show: bool):
    """Store ``show`` under the session's ``"show_dataset"`` key.

    Args:
        session_data: Session mapping; mutated in place.
        show: Value to store.
    """
    target_key = "show_dataset"
    session_data[target_key] = show


def set_dataset_for_data_agent(session_data: dict, doi: str, csv_data: bytes, dataset: pd.DataFrame, name: str):
    """Hand a dataset over to the data agent and switch to its page.

    Args:
        session_data: Session mapping; mutated in place.
        doi: Accepted but not used by this function.
        csv_data: Stored under ``"dataset_csv"``.
        dataset: Stored under ``"dataset_df"``.
        name: Stored under ``"dataset_name"``.
    """
    session_data.update({
        "dataset_csv": csv_data,
        "dataset_df": dataset,
        "dataset_name": name,
        "current_page": "data_agent",
    })


def ensure_memory(session_data: dict):
    """Make sure the session has a ``"memory"`` entry.

    An existing entry is kept; otherwise a new ``CustomMemorySaver`` is
    stored.

    Args:
        session_data: Session mapping; mutated in place when the key is
            missing.
    """
    # NOTE(review): initialize_session_state seeds "memory" with MemorySaver,
    # while this fallback creates a CustomMemorySaver — confirm which one is
    # intended.
    if "memory" in session_data:
        return
    session_data["memory"] = CustomMemorySaver()


def ensure_thread_id(session_data: dict):
    """Make sure the session has a ``"thread_id"`` entry.

    An existing entry is kept; otherwise a new random UUID4 string is stored.

    Args:
        session_data: Session mapping; mutated in place when the key is
            missing.
    """
    if "thread_id" in session_data:
        return
    new_id = uuid.uuid4()
    session_data["thread_id"] = str(new_id)