Selcan Yukcu committed on
Commit
6a5afc3
·
2 Parent(s): 0d6f96c 0fe55ab

Merge remote-tracking branch 'origin/main' into selcan_test

Browse files

# Conflicts:
# .env.sample
# gradio_app.py
# langchain_mcp_client.py

.env.sample CHANGED
@@ -1,3 +1,2 @@
1
- API_KEY = ""
2
- DSN =
3
- SCHEMA =
 
1
+ API_KEY=
2
+ MCP_SERVER_PATH=/home/andre/andre/postgre_mcp_server.py
 
.gitignore CHANGED
@@ -1,2 +1,125 @@
1
  .idea
2
- .env
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  .idea
2
+ .env
3
+ .vscode
4
+
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ *.cover
46
+ *.py,cover
47
+ .hypothesis/
48
+ .pytest_cache/
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.coveragerc
52
+ *.pylintrc
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # celery
88
+ celerybeat-schedule
89
+ celerybeat.pid
90
+
91
+ # SageMath
92
+ *.sage.py
93
+
94
+ # Environments
95
+ .env/
96
+ .venv/
97
+ env/
98
+ venv/
99
+ ENV/
100
+ env.bak/
101
+ venv.bak/
102
+
103
+ # Spyder project settings
104
+ .spyderproject
105
+ .spyproject
106
+
107
+ # Rope project settings
108
+ .ropeproject
109
+
110
+ # mkdocs documentation
111
+ /site
112
+
113
+ # mypy
114
+ .mypy_cache/
115
+ .dmypy.json
116
+ dmypy.json
117
+
118
+ # Pyre type checker
119
+ .pyre/
120
+
121
+ # pytype static type analyzer
122
+ .pytype/
123
+
124
+ # Cython debug symbols
125
+ cython_debug/
chat_history.json ADDED
@@ -0,0 +1 @@
 
 
1
+ [{"type": "human", "data": {"content": "list all tables", "additional_kwargs": {}, "response_metadata": {}, "type": "human", "name": null, "id": null, "example": false}}, {"type": "ai", "data": {"content": "# Result\nThe tables in the database are:\n* dim\\_agreement\n* dim\\_customer\n* dim\\_product\n* dim\\_product\\_order\\_item\n\n# Explanation\nThe `list_tables` tool was called to retrieve a list of all available tables in the database schema. The result shows the names of these tables.\n\n# Query\n```sql\nN/A\n```", "additional_kwargs": {}, "response_metadata": {}, "type": "ai", "name": null, "id": null, "example": false, "tool_calls": [], "invalid_tool_calls": [], "usage_metadata": null}}, {"type": "human", "data": {"content": "what was my previous request ?", "additional_kwargs": {}, "response_metadata": {}, "type": "human", "name": null, "id": null, "example": false}}, {"type": "ai", "data": {"content": "# Result\nlist all tables\n# Explanation\nThe user's previous request was to list all the tables in the database. This was determined by reviewing the previous turn in the conversation history.\n# Query\n```sql\nN/A\n```", "additional_kwargs": {}, "response_metadata": {}, "type": "ai", "name": null, "id": null, "example": false, "tool_calls": [], "invalid_tool_calls": [], "usage_metadata": null}}, {"type": "human", "data": {"content": "count tables", "additional_kwargs": {}, "response_metadata": {}, "type": "human", "name": null, "id": null, "example": false}}, {"type": "ai", "data": {"content": "```markdown\n# Result\n4\n# Explanation\nThe user requested the number of tables. 
The previous response listed the tables, and counting them gives the answer.\n# Query\n```sql\nN/A\n```", "additional_kwargs": {}, "response_metadata": {}, "type": "ai", "name": null, "id": null, "example": false, "tool_calls": [], "invalid_tool_calls": [], "usage_metadata": null}}, {"type": "human", "data": {"content": "what was my last message ?", "additional_kwargs": {}, "response_metadata": {}, "type": "human", "name": null, "id": null, "example": false}}, {"type": "ai", "data": {"content": "```markdown\n# Result\ncount tables\n# Explanation\nThe user's last message was \"count tables\". This was determined by reviewing the previous turn in the conversation history.\n# Query\n```sql\nN/A\n```", "additional_kwargs": {}, "response_metadata": {}, "type": "ai", "name": null, "id": null, "example": false, "tool_calls": [], "invalid_tool_calls": [], "usage_metadata": null}}, {"type": "human", "data": {"content": "what was my very first message in this conversation?", "additional_kwargs": {}, "response_metadata": {}, "type": "human", "name": null, "id": null, "example": false}}, {"type": "ai", "data": {"content": "```markdown\n# Result\nlist all tables\n# Explanation\nThe user's very first message in this conversation was \"list all tables\". 
This was determined by reviewing the conversation history.\n# Query\n```sql\nN/A\n```", "additional_kwargs": {}, "response_metadata": {}, "type": "ai", "name": null, "id": null, "example": false, "tool_calls": [], "invalid_tool_calls": [], "usage_metadata": null}}, {"type": "human", "data": {"content": "total number of sales in 2024", "additional_kwargs": {}, "response_metadata": {}, "type": "human", "name": null, "id": null, "example": false}}, {"type": "human", "data": {"content": "list all tables", "additional_kwargs": {}, "response_metadata": {}, "type": "human", "name": null, "id": null, "example": false}}, {"type": "human", "data": {"content": "list all tables", "additional_kwargs": {}, "response_metadata": {}, "type": "human", "name": null, "id": null, "example": false}}, {"type": "human", "data": {"content": "list all tables", "additional_kwargs": {}, "response_metadata": {}, "type": "human", "name": null, "id": null, "example": false}}, {"type": "ai", "data": {"content": "```markdown\n# Result\nThe tables in the database are:\n* dim\\_agreement\n* dim\\_customer\n* dim\\_product\n* dim\\_product\\_order\\_item\n\n# Explanation\nThe `list_tables` tool was called to retrieve a list of all available tables in the database schema. The result shows the names of these tables.\n\n# Query\n```sql\nN/A\n```", "additional_kwargs": {}, "response_metadata": {}, "type": "ai", "name": null, "id": null, "example": false, "tool_calls": [], "invalid_tool_calls": [], "usage_metadata": null}}, {"type": "human", "data": {"content": "list all tables", "additional_kwargs": {}, "response_metadata": {}, "type": "human", "name": null, "id": null, "example": false}}, {"type": "ai", "data": {"content": "```markdown\n# Result\nThe tables in the database are:\n* dim_agreement\n* dim_customer\n* dim_product\n* dim_product_order_item\n\n# Explanation\nThe `list_tables` tool was called to retrieve a list of all available tables in the database schema. 
The result shows the names of these tables.\n\n# Query\n```sql\nN/A\n```", "additional_kwargs": {}, "response_metadata": {}, "type": "ai", "name": null, "id": null, "example": false, "tool_calls": [], "invalid_tool_calls": [], "usage_metadata": null}}]
gradio_app.py CHANGED
@@ -19,13 +19,13 @@ def load_db_configs():
19
  return configs["db_configs"]
20
 
21
  # Async-compatible wrapper
22
- async def run_agent(request):
23
  # configs = load_db_configs()
24
  # final_answer, last_tool_answer, = await pg_mcp_exec(request)
25
  # return final_answer, last_tool_answer
26
 
27
- result = await lc_mcp_exec(request)
28
- return result
29
 
30
  # Gradio UI
31
  demo = gr.Interface(
 
19
  return configs["db_configs"]
20
 
21
  # Async-compatible wrapper
22
+ async def run_agent(request, history):
23
  # configs = load_db_configs()
24
  # final_answer, last_tool_answer, = await pg_mcp_exec(request)
25
  # return final_answer, last_tool_answer
26
 
27
+ response, message = await lc_mcp_exec(request, history)
28
+ return response
29
 
30
  # Gradio UI
31
  demo = gr.Interface(
langchain_mcp_client.py CHANGED
@@ -5,66 +5,112 @@ from mcp import ClientSession, StdioServerParameters
5
  from mcp.client.stdio import stdio_client
6
  from langchain_mcp_adapters.tools import load_mcp_tools
7
  from langgraph.prebuilt import create_react_agent
 
 
 
 
8
  from langchain.chat_models import init_chat_model
9
  from utils import parse_mcp_output, classify_intent
10
- from langchain.memory import ConversationBufferMemory
11
- from langchain_core.messages import AIMessage, HumanMessage
12
- import asyncio
13
  import logging
14
- import json
15
  from dotenv import load_dotenv
16
 
 
17
  load_dotenv()
18
  logger = logging.getLogger(__name__)
19
 
20
 
21
  async def lc_mcp_exec(request: str) -> tuple[Any, Any]:
22
  """
23
- Execute the full PostgreSQL MCP pipeline: load summary, connect session,
24
- load memory and tools, build prompt, run agent, update memory.
 
 
25
 
26
- Args:
27
- request (str): User's request input.
28
- llm (Any): Language model for reasoning agent.
 
29
 
30
- Returns:
31
- str: Agent response message.
32
- """
33
- # TODO: give summary file path from config
34
- table_summary = load_table_summary("table_summary.txt")
35
- server_params = get_server_params()
36
 
37
- api_key = os.getenv("API_KEY")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- llm = init_chat_model(
40
- model="gemini-2.0-flash",
41
- model_provider="google_genai",
42
- api_key=api_key,
43
- temperature=0.5,
44
- )
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- async with stdio_client(server_params) as (read, write):
47
- async with ClientSession(read, write) as session:
48
- await session.initialize()
49
- memory = load_or_create_memory()
50
 
51
- tools = await load_and_enrich_tools(session, table_summary)
52
- past_data = get_memory_snapshot(memory)
 
 
53
 
54
- intent = classify_intent(request)
55
- prompt = await build_prompt(session, intent, request, tools, table_summary, past_data)
 
56
 
57
- agent = create_react_agent(llm, tools)
58
- agent_response = await agent.ainvoke({"messages": prompt})
59
 
60
- parsed_steps, final_answer, last_tool_answer, _ = parse_mcp_output(agent_response)
61
- # Add memory update before return
62
- memory.chat_memory.add_message(HumanMessage(content=request))
63
- memory.chat_memory.add_message(AIMessage(content=final_answer))
64
 
65
- await handle_memory_save_or_reset(memory, request)
 
 
 
 
 
 
 
66
 
67
- return final_answer, last_tool_answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
 
70
  # ---------------- Helper Functions ---------------- #
@@ -77,40 +123,18 @@ def get_server_params() -> StdioServerParameters:
77
  # TODO: give server params from config
78
  return StdioServerParameters(
79
  command="python",
80
- args=[r"C:\Users\yukcus\Desktop\query_mcp_server\postgre_mcp_server.py"],
81
  )
82
 
83
- def load_or_create_memory() -> ConversationBufferMemory:
84
- memory = ConversationBufferMemory(return_messages=True)
85
- # You can optionally load from a file or a store if needed
86
- if os.path.exists("memory.json"):
87
- try:
88
- with open("memory.json", "r") as f:
89
- history = json.load(f)
90
- for msg in history:
91
- if msg["type"] == "human":
92
- memory.chat_memory.add_message(HumanMessage(content=msg["content"]))
93
- elif msg["type"] == "ai":
94
- memory.chat_memory.add_message(AIMessage(content=msg["content"]))
95
- except Exception as e:
96
- logger.warning(f"Failed to load memory: {e}")
97
- return memory
98
-
99
-
100
- async def load_and_enrich_tools(session: ClientSession, summary: str):
101
  tools = await load_mcp_tools(session)
102
  return tools
103
 
104
- def get_memory_snapshot(memory: ConversationBufferMemory) -> dict:
105
- return {
106
- "chat_history": "\n".join([f"{m.type}: {m.content}" for m in memory.chat_memory.messages])
107
- }
108
-
109
-
110
- async def build_prompt(session, intent, request, tools, summary, history):
111
  superset_prompt = await session.read_resource("resource://last_prompt")
112
  conversation_prompt = await session.read_resource("resource://base_prompt")
113
 
 
114
  if intent == "superset_request":
115
  template = superset_prompt.contents[0].text
116
  return template.format(
@@ -119,27 +143,25 @@ async def build_prompt(session, intent, request, tools, summary, history):
119
  else:
120
  template = conversation_prompt.contents[0].text
121
  tools_str = "\n".join([f"- {tool.name}: {tool.description}" for tool in tools])
 
 
 
 
 
 
 
 
 
 
 
 
122
  return template.format(
123
  new_request=request,
124
  tools=tools_str,
125
  descriptions=summary,
126
- chat_history = history
 
 
 
127
  )
128
 
129
-
130
- async def handle_memory_save_or_reset(memory: ConversationBufferMemory, request: str):
131
- if request.strip().lower() == "stop":
132
- memory.clear()
133
- if os.path.exists("memory.json"):
134
- os.remove("memory.json")
135
- logger.info("Conversation memory reset.")
136
- else:
137
- history = []
138
- for msg in memory.chat_memory.messages:
139
- if isinstance(msg, HumanMessage):
140
- history.append({"type": "human", "content": msg.content})
141
- elif isinstance(msg, AIMessage):
142
- history.append({"type": "ai", "content": msg.content})
143
- with open("memory.json", "w") as f:
144
- json.dump(history, f, indent=2)
145
- logger.info("Conversation memory saved.")
 
5
  from mcp.client.stdio import stdio_client
6
  from langchain_mcp_adapters.tools import load_mcp_tools
7
  from langgraph.prebuilt import create_react_agent
8
+ from langchain_core.prompts import PromptTemplate
9
+ from langchain_core.messages import AIMessage, HumanMessage
10
+ from langchain.memory import ChatMessageHistory
11
+ from langchain_community.chat_message_histories import FileChatMessageHistory
12
  from langchain.chat_models import init_chat_model
13
  from utils import parse_mcp_output, classify_intent
 
 
 
14
  import logging
 
15
  from dotenv import load_dotenv
16
 
17
+
18
  load_dotenv()
19
  logger = logging.getLogger(__name__)
20
 
21
 
22
  async def lc_mcp_exec(request: str) -> tuple[Any, Any]:
23
  """
24
+ Execute the full PostgreSQL MCP pipeline with persistent memory.
25
+ """
26
+ try:
27
+ history_file = os.path.join(os.path.dirname(__file__), "chat_history.json")
28
 
29
+ # Initialize chat history file if it doesn't exist or is empty
30
+ if not os.path.exists(history_file) or os.path.getsize(history_file) == 0:
31
+ with open(history_file, 'w') as f:
32
+ json.dump({"messages": []}, f)
33
 
34
+ message_history = FileChatMessageHistory(file_path=history_file)
 
 
 
 
 
35
 
36
+ try:
37
+ # Load existing messages or handle bootstrap scenario
38
+ existing_messages = message_history.messages
39
+ except json.JSONDecodeError:
40
+ # If JSON is corrupted, reinitialize the file
41
+ logger.warning("Chat history file corrupted, reinitializing...")
42
+ with open(history_file, 'w') as f:
43
+ json.dump({"messages": []}, f)
44
+ existing_messages = []
45
+
46
+ # Format existing messages properly
47
+ formatted_history = []
48
+ for msg in existing_messages:
49
+ if isinstance(msg, HumanMessage):
50
+ formatted_history.append(HumanMessage(content=msg.content))
51
+ elif isinstance(msg, AIMessage):
52
+ formatted_history.append(AIMessage(content=msg.content))
53
 
54
+ # TODO: give summary file path from config
55
+ table_summary = load_table_summary("table_summary.txt")
56
+ server_params = get_server_params()
57
+
58
+ api_key = os.getenv("API_KEY")
59
+ llm = init_chat_model(model="gemini-2.0-flash", model_provider="google_genai",
60
+ api_key=api_key)
61
+
62
+ async with stdio_client(server_params) as (read, write):
63
+ async with ClientSession(read, write) as session:
64
+ await session.initialize()
65
+
66
+ tools = await load_and_enrich_tools(session)
67
+ intent = classify_intent(request)
68
+
69
+ # Add new user message before processing
70
+ message_history.add_user_message(request)
71
 
72
+ # Create agent and prepare system message
73
+ agent = create_react_agent(llm, tools)
 
 
74
 
75
+ # Create base messages list with system message
76
+ base_message = HumanMessage(content="""You are a PostgreSQL database expert assistant.
77
+ Use the conversation history for context when available.""")
78
+ messages = [base_message]
79
 
80
+ # Add history if exists
81
+ if formatted_history:
82
+ messages.extend(formatted_history)
83
 
84
+ # Add current request
85
+ messages.append(HumanMessage(content=request))
86
 
87
+ # Build prompt with conversation context
88
+ prompt = await build_prompt(session, intent, request, tools, table_summary, formatted_history)
 
 
89
 
90
+ # Invoke agent with proper message structure
91
+ agent_response = await agent.ainvoke(
92
+ {
93
+ "messages": prompt,
94
+ "chat_history": [msg.content for msg in formatted_history]
95
+ },
96
+ config={"configurable": {"thread_id": "conversation_123"}}
97
+ )
98
 
99
+ if "messages" in agent_response:
100
+ response = agent_response["messages"][-1].content
101
+ # Save assistant response
102
+ message_history.add_ai_message(response)
103
+ else:
104
+ response = "No response generated"
105
+ message_history.add_ai_message(response)
106
+
107
+ # Return current response and up-to-date messages
108
+ return response, message_history.messages
109
+
110
+ except Exception as e:
111
+ logger.error(f"Error in chat history handling: {str(e)}", exc_info=True)
112
+ # Fallback to stateless response if history fails
113
+ return f"Error in conversation: {str(e)}", []
114
 
115
 
116
  # ---------------- Helper Functions ---------------- #
 
123
  # TODO: give server params from config
124
  return StdioServerParameters(
125
  command="python",
126
+ args=[os.environ["MCP_SERVER_PATH"]],
127
  )
128
 
129
+ async def load_and_enrich_tools(session: ClientSession):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  tools = await load_mcp_tools(session)
131
  return tools
132
 
133
+ async def build_prompt(session, intent, request, tools, summary, chat_history):
 
 
 
 
 
 
134
  superset_prompt = await session.read_resource("resource://last_prompt")
135
  conversation_prompt = await session.read_resource("resource://base_prompt")
136
 
137
+
138
  if intent == "superset_request":
139
  template = superset_prompt.contents[0].text
140
  return template.format(
 
143
  else:
144
  template = conversation_prompt.contents[0].text
145
  tools_str = "\n".join([f"- {tool.name}: {tool.description}" for tool in tools])
146
+
147
+ # Handle history formatting with proper message access
148
+ history_str = ""
149
+ if chat_history:
150
+ history_sections = []
151
+ for msg in chat_history:
152
+ if isinstance(msg, HumanMessage):
153
+ history_sections.append(f"Previous Human Question:\n{msg.content}\n")
154
+ elif isinstance(msg, AIMessage):
155
+ history_sections.append(f"Previous Assistant Response:\n{msg.content}\n")
156
+ history_str = "\n".join(history_sections)
157
+
158
  return template.format(
159
  new_request=request,
160
  tools=tools_str,
161
  descriptions=summary,
162
+ chat_history=f"\nPrevious Conversation History:\n{history_str}" if history_str else "\nThis is a new conversation.",
163
+ system_instructions="""You are a PostgreSQL database expert assistant.
164
+ Use the conversation history when available to maintain context.
165
+ For new conversations, focus on understanding the initial request."""
166
  )
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
postgre_smolagent_client.py CHANGED
@@ -8,6 +8,8 @@ from conversation_memory import ConversationMemory
8
  from utils import parse_mcp_output, classify_intent
9
  import logging
10
  from smolagents import LiteLLMModel, ToolCollection, CodeAgent
 
 
11
 
12
  logger = logging.getLogger(__name__)
13
 
@@ -63,9 +65,10 @@ def load_table_summary(path: str) -> str:
63
 
64
  def get_server_params() -> StdioServerParameters:
65
  # TODO: give server params from config
 
66
  return StdioServerParameters(
67
  command="python",
68
- args=[r"/home/amirkia/Desktop/query_mcp_server/postgre_mcp_server.py"],
69
  )
70
 
71
  async def load_or_create_memory() -> ConversationMemory:
 
8
  from utils import parse_mcp_output, classify_intent
9
  import logging
10
  from smolagents import LiteLLMModel, ToolCollection, CodeAgent
11
+ from dotenv import load_dotenv
12
+
13
 
14
  logger = logging.getLogger(__name__)
15
 
 
65
 
66
  def get_server_params() -> StdioServerParameters:
67
  # TODO: give server params from config
68
+ load_dotenv()
69
  return StdioServerParameters(
70
  command="python",
71
+ args=[os.environ["MCP_SERVER_PATH"]],
72
  )
73
 
74
  async def load_or_create_memory() -> ConversationMemory:
run.sh CHANGED
@@ -1,6 +1,6 @@
1
  #!/bin/bash
2
 
3
  # Replace 'myenv' with the name of your conda environment
4
- conda activate myenv
5
 
6
  python gradio_app.py
 
1
  #!/bin/bash
2
 
3
  # Replace 'myenv' with the name of your conda environment
4
+ # conda activate myenv
5
 
6
  python gradio_app.py