File size: 5,955 Bytes
1c0c480
64235d1
676f4cd
64235d1
 
 
 
 
a0d91e2
1acd5e4
64235d1
 
09e2bc4
938a3f9
8ba5d9d
f418e56
09e2bc4
f418e56
8ba5d9d
16b2ff9
09e2bc4
64235d1
6c01c87
676f4cd
6c01c87
 
a0d91e2
64235d1
8ba5d9d
a0d91e2
64235d1
1acd5e4
8ba5d9d
 
64235d1
a0d91e2
6c01c87
1acd5e4
e087162
 
 
ee5c129
 
 
 
e087162
ee5c129
 
 
 
 
 
 
 
6a39607
6c01c87
1acd5e4
 
 
1c0c480
 
1acd5e4
a0d91e2
 
e087162
 
 
 
 
8ba5d9d
1acd5e4
4f11b82
8ba5d9d
a0d91e2
8ba5d9d
64235d1
8ba5d9d
a0d91e2
 
8ba5d9d
1acd5e4
a0d91e2
1acd5e4
 
64235d1
8ba5d9d
a0d91e2
 
 
8ba5d9d
 
a0d91e2
8ba5d9d
a0d91e2
 
 
 
 
1acd5e4
a0d91e2
64235d1
a0d91e2
8ba5d9d
1acd5e4
a0d91e2
8ba5d9d
64235d1
 
 
 
 
 
 
0391cfb
1c0c480
ee5c129
1c0c480
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58f343a
1c0c480
 
 
 
 
 
 
 
 
 
 
64235d1
 
1c0c480
 
64235d1
 
6c01c87
1c0c480
64235d1
 
 
 
6c01c87
a0d91e2
64235d1
86fa3b8
6c01c87
 
 
 
 
 
a0d91e2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import os
import os.path
import json
from typing import Tuple, Any
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from langchain_mcp_adapters.tools import load_mcp_tools
from langgraph.prebuilt import create_react_agent
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage, SystemMessage
from langchain_community.chat_message_histories import FileChatMessageHistory
from langchain.chat_models import init_chat_model
import logging
from langchain.globals import set_debug
from langchain_community.chat_message_histories import ChatMessageHistory
from memory_store import MemoryStore
from dotenv import load_dotenv

# Runs at import time: python-dotenv loads variables from a .env file (if one
# is found) into the process environment, which the functions below read.
load_dotenv()

# Uncomment to enable LangChain's global debug output.
# set_debug(True)


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


async def lc_mcp_exec(request: str, history=None) -> Tuple[str, list]:
    """
    Run one user request through the MCP-backed ReAct agent with shared chat memory.

    The special request "/clear-cache" wipes the shared memory and returns
    immediately, before any configuration is read, any LLM is created or the
    MCP server subprocess is started.

    Args:
        request: The user's message, or "/clear-cache" to reset the memory.
        history: Not read by this function; kept so callers that pass it
            continue to work.

    Returns:
        A ``(response_text, messages)`` tuple. ``messages`` is the content of
        the shared memory after the turn, ``[]`` after a "/clear-cache", and
        ``[]`` together with ``"Error: <message>"`` when any exception occurs
        (exceptions are logged, never propagated).
    """
    try:
        # Get the singleton memory store instance
        message_history = MemoryStore.get_memory()

        # Handle the clear command first: it needs neither the table summary,
        # the LLM nor a running MCP server, so none of them is set up for it.
        if request == "/clear-cache":
            message_history.clear()
            return "Memory cleared", []

        # Load table summary and server parameters
        table_summary = load_table_summary(os.environ["TABLE_SUMMARY_PATH"])
        server_params = get_server_params()

        # Prefer OpenAI when its key is configured; otherwise use Gemini.
        openai_api_key = os.environ.get("OPENAI_API_KEY")
        if openai_api_key:
            llm = init_chat_model(
                model_provider=os.environ["OPENAI_MODEL_PROVIDER"],
                model=os.environ["OPENAI_MODEL"],
                api_key=openai_api_key,
            )
        else:
            llm = init_chat_model(
                model_provider=os.environ["GEMINI_MODEL_PROVIDER"],
                model=os.environ["GEMINI_MODEL"],
                api_key=os.environ["GEMINI_API_KEY"],
            )

        # Start the MCP server over stdio and open a client session to it
        async with stdio_client(server_params) as (read, write):
            async with ClientSession(read, write) as session:
                await session.initialize()

                # Load tools and create the agent
                tools = await load_and_enrich_tools(session)
                agent = create_react_agent(llm, tools)

                # Add new user message to memory
                message_history.add_user_message(request)

                # Get system prompt and create system message
                system_prompt = await build_prompt(session, tools, table_summary)
                system_message = SystemMessage(content=system_prompt)

                # The system message is prepended per call and is not stored
                # in the memory itself.
                input_messages = [system_message] + message_history.messages

                # Invoke agent
                agent_response = await agent.ainvoke(
                    {"messages": input_messages},
                    config={"configurable": {"thread_id": "conversation_123"}},
                )

                # Process agent response
                response_content = "No response generated"
                if "messages" in agent_response and agent_response["messages"]:
                    # Everything past the messages we sent in is treated as
                    # output produced during this turn.
                    new_messages = agent_response["messages"][len(input_messages):]

                    # Save new AI/tool messages to memory
                    for msg in new_messages:
                        if isinstance(msg, (AIMessage, ToolMessage)):
                            message_history.add_message(msg)
                        else:
                            logger.debug(f"Skipping unexpected message type: {type(msg)}")

                    response_content = agent_response["messages"][-1].content
                else:
                    # Record the placeholder so the stored user message has a reply.
                    message_history.add_ai_message(response_content)

                return response_content, message_history.messages

    except Exception as e:
        # Top-level boundary: log with traceback and report the failure as text.
        logger.exception(f"Error in execution: {str(e)}")
        return f"Error: {str(e)}", []

# ---------------- Helper Functions ---------------- #

def load_table_summary(path: str) -> str:
    """
    Return the full text of the table summary file at ``path``.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default locale encoding.

    Args:
        path: Filesystem path of the summary file.

    Returns:
        The file's entire contents as a string.

    Raises:
        OSError: If the file cannot be opened or read.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    with open(path, 'r', encoding='utf-8') as file:
        return file.read()



def get_server_params() -> StdioServerParameters:
    """
    Build the stdio launch parameters for the MCP server subprocess.

    Only the environment variables named below are copied into the
    subprocess environment; a variable that is unset in this process is
    skipped with a warning. The server script path is read from
    MCP_SERVER_PATH, which must be set in this process.
    """
    # Names copied from this process into the subprocess environment.
    forwarded_names = (
        "DB_URL",
        "DB_SCHEMA",
        "PANDAS_KEY",
        "PANDAS_EXPORTS_PATH",
        "OPENAI_API_KEY",
    )

    child_env = {}
    for name in forwarded_names:
        found = os.environ.get(name)
        if found is None:
            logger.warning(f"Environment variable {name} not found for passing to MCP server subprocess.")
            continue
        child_env[name] = found

    logger.info(f"Passing environment to MCP server subprocess: {child_env.keys()}")

    # Raises KeyError when MCP_SERVER_PATH is not set.
    server_script = os.environ["MCP_SERVER_PATH"]
    return StdioServerParameters(command="python", args=[server_script], env=child_env)



async def load_and_enrich_tools(session: ClientSession):
    """Return the tools that load_mcp_tools produces for the given MCP session."""
    # No enrichment step is applied here; the loaded tools are returned as-is.
    return await load_mcp_tools(session)


async def build_prompt(session, tools, table_summary):
    """
    Render the system prompt from the server's base prompt template.

    The template text is read from the "resource://base_prompt" resource of
    ``session`` and filled in via ``str.format`` with two fields:
    ``tools`` (one "- name: description" line per tool) and
    ``descriptions`` (the given ``table_summary``).
    """
    resource = await session.read_resource("resource://base_prompt")
    # Only the first content entry of the resource is used as the template.
    base_template = resource.contents[0].text

    tool_listing = "\n".join(
        f"- {entry.name}: {entry.description}" for entry in tools
    )

    return base_template.format(tools=tool_listing, descriptions=table_summary)