# talk2data / langchain_mcp_client.py
# Author: amirkiarafiei
# Commit 16b2ff9 — feat: remove redundant code and improve visualization tool
import os
import os.path
import json
from typing import Tuple, Any
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from langchain_mcp_adapters.tools import load_mcp_tools
from langgraph.prebuilt import create_react_agent
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage, SystemMessage
from langchain_community.chat_message_histories import FileChatMessageHistory
from langchain.chat_models import init_chat_model
import logging
from langchain.globals import set_debug
from langchain_community.chat_message_histories import ChatMessageHistory
from memory_store import MemoryStore
from dotenv import load_dotenv
load_dotenv()
# set_debug(True)
# Set up logging
logger = logging.getLogger(__name__)
async def lc_mcp_exec(request: str, history=None) -> Tuple[str, list]:
"""
Execute the PostgreSQL MCP pipeline with in-memory chat history.
Returns the response and the updated message history.
"""
try:
# Get the singleton memory store instance
message_history = MemoryStore.get_memory()
# Load table summary and server parameters
table_summary = load_table_summary(os.environ["TABLE_SUMMARY_PATH"])
server_params = get_server_params()
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if OPENAI_API_KEY:
# Initialize the LLM for OpenAI
llm = init_chat_model(
model_provider=os.environ["OPENAI_MODEL_PROVIDER"],
model=os.environ["OPENAI_MODEL"],
api_key=OPENAI_API_KEY
)
else:
# Initialize the LLM for Gemini
llm = init_chat_model(
model_provider=os.environ["GEMINI_MODEL_PROVIDER"],
model=os.environ["GEMINI_MODEL"],
api_key=os.environ["GEMINI_API_KEY"]
)
# Initialize the MCP client
async with stdio_client(server_params) as (read, write):
async with ClientSession(read, write) as session:
await session.initialize()
# Load tools and create the agent
tools = await load_and_enrich_tools(session)
agent = create_react_agent(llm, tools)
# clear the memory
if request == "/clear-cache":
message_history.clear()
return "Memory cleared", []
# Add new user message to memory
message_history.add_user_message(request)
# Get system prompt and create system message
system_prompt = await build_prompt(session, tools, table_summary)
system_message = SystemMessage(content=system_prompt)
# Combine system message with chat history
input_messages = [system_message] + message_history.messages
# Invoke agent
agent_response = await agent.ainvoke(
{"messages": input_messages},
config={"configurable": {"thread_id": "conversation_123"}}
)
# Process agent response
response_content = "No response generated"
if "messages" in agent_response and agent_response["messages"]:
new_messages = agent_response["messages"][len(input_messages):]
# Save new messages to memory
for msg in new_messages:
if isinstance(msg, (AIMessage, ToolMessage)):
message_history.add_message(msg)
else:
logger.debug(f"Skipping unexpected message type: {type(msg)}")
response_content = agent_response["messages"][-1].content
else:
message_history.add_ai_message(response_content)
return response_content, message_history.messages
except Exception as e:
logger.error(f"Error in execution: {str(e)}", exc_info=True)
return f"Error: {str(e)}", []
# ---------------- Helper Functions ---------------- #
def load_table_summary(path: str) -> str:
with open(path, 'r') as file:
return file.read()
def get_server_params() -> StdioServerParameters:
# Prepare the environment dictionary to pass to the subprocess
subprocess_env = {}
# List of environment variables that the postgre_mcp_server.py needs
required_vars_for_server = [
# "TABLE_SUMMARY_PATH",
"DB_URL",
"DB_SCHEMA",
"PANDAS_KEY",
"PANDAS_EXPORTS_PATH",
# "GEMINI_API_KEY",
# "GEMINI_MODEL",
# "GEMINI_MODEL_PROVIDER",
# "OPENAI_MODEL_PROVIDER",
# "OPENAI_MODEL",
"OPENAI_API_KEY",
]
for var_name in required_vars_for_server:
value = os.getenv(var_name)
if value is not None:
subprocess_env[var_name] = value
else:
logger.warning(f"Environment variable {var_name} not found for passing to MCP server subprocess.")
logger.info(f"Passing environment to MCP server subprocess: {subprocess_env.keys()}")
return StdioServerParameters(
command="python",
args=[os.environ["MCP_SERVER_PATH"]], # MCP_SERVER_PATH itself must be available to this client
env=subprocess_env
)
async def load_and_enrich_tools(session: ClientSession):
tools = await load_mcp_tools(session)
return tools
async def build_prompt(session, tools, table_summary):
conversation_prompt = await session.read_resource("resource://base_prompt")
template = conversation_prompt.contents[0].text
tools_str = "\n".join([f"- {tool.name}: {tool.description}" for tool in tools])
return template.format(
tools=tools_str,
descriptions=table_summary,
)