amirkiarafiei commited on
Commit
1acd5e4
·
1 Parent(s): 434a9c4

feat: enhance chat history management and update .gitignore

Browse files
Files changed (4) hide show
  1. .gitignore +124 -1
  2. chat_history.json +1 -0
  3. langchain_mcp_client.py +104 -62
  4. run.sh +1 -1
.gitignore CHANGED
@@ -1,2 +1,125 @@
1
  .idea
2
- .env
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  .idea
2
+ .env
3
+ .vscode
4
+
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ *.cover
46
+ *.py,cover
47
+ .hypothesis/
48
+ .pytest_cache/
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.coveragerc
52
+ *.pylintrc
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # celery
88
+ celerybeat-schedule
89
+ celerybeat.pid
90
+
91
+ # SageMath
92
+ *.sage.py
93
+
94
+ # Environments
95
+ .env/
96
+ .venv/
97
+ env/
98
+ venv/
99
+ ENV/
100
+ env.bak/
101
+ venv.bak/
102
+
103
+ # Spyder project settings
104
+ .spyderproject
105
+ .spyproject
106
+
107
+ # Rope project settings
108
+ .ropeproject
109
+
110
+ # mkdocs documentation
111
+ /site
112
+
113
+ # mypy
114
+ .mypy_cache/
115
+ .dmypy.json
116
+ dmypy.json
117
+
118
+ # Pyre type checker
119
+ .pyre/
120
+
121
+ # pytype static type analyzer
122
+ .pytype/
123
+
124
+ # Cython debug symbols
125
+ cython_debug/
chat_history.json ADDED
@@ -0,0 +1 @@
 
 
1
+ [{"type": "human", "data": {"content": "list all tables", "additional_kwargs": {}, "response_metadata": {}, "type": "human", "name": null, "id": null, "example": false}}, {"type": "ai", "data": {"content": "# Result\nThe tables in the database are:\n* dim\\_agreement\n* dim\\_customer\n* dim\\_product\n* dim\\_product\\_order\\_item\n\n# Explanation\nThe `list_tables` tool was called to retrieve a list of all available tables in the database schema. The result shows the names of these tables.\n\n# Query\n```sql\nN/A\n```", "additional_kwargs": {}, "response_metadata": {}, "type": "ai", "name": null, "id": null, "example": false, "tool_calls": [], "invalid_tool_calls": [], "usage_metadata": null}}, {"type": "human", "data": {"content": "what was my previous request ?", "additional_kwargs": {}, "response_metadata": {}, "type": "human", "name": null, "id": null, "example": false}}, {"type": "ai", "data": {"content": "# Result\nlist all tables\n# Explanation\nThe user's previous request was to list all the tables in the database. This was determined by reviewing the previous turn in the conversation history.\n# Query\n```sql\nN/A\n```", "additional_kwargs": {}, "response_metadata": {}, "type": "ai", "name": null, "id": null, "example": false, "tool_calls": [], "invalid_tool_calls": [], "usage_metadata": null}}, {"type": "human", "data": {"content": "count tables", "additional_kwargs": {}, "response_metadata": {}, "type": "human", "name": null, "id": null, "example": false}}, {"type": "ai", "data": {"content": "```markdown\n# Result\n4\n# Explanation\nThe user requested the number of tables. 
The previous response listed the tables, and counting them gives the answer.\n# Query\n```sql\nN/A\n```", "additional_kwargs": {}, "response_metadata": {}, "type": "ai", "name": null, "id": null, "example": false, "tool_calls": [], "invalid_tool_calls": [], "usage_metadata": null}}, {"type": "human", "data": {"content": "what was my last message ?", "additional_kwargs": {}, "response_metadata": {}, "type": "human", "name": null, "id": null, "example": false}}, {"type": "ai", "data": {"content": "```markdown\n# Result\ncount tables\n# Explanation\nThe user's last message was \"count tables\". This was determined by reviewing the previous turn in the conversation history.\n# Query\n```sql\nN/A\n```", "additional_kwargs": {}, "response_metadata": {}, "type": "ai", "name": null, "id": null, "example": false, "tool_calls": [], "invalid_tool_calls": [], "usage_metadata": null}}, {"type": "human", "data": {"content": "what was my very first message in this conversation?", "additional_kwargs": {}, "response_metadata": {}, "type": "human", "name": null, "id": null, "example": false}}, {"type": "ai", "data": {"content": "```markdown\n# Result\nlist all tables\n# Explanation\nThe user's very first message in this conversation was \"list all tables\". 
This was determined by reviewing the conversation history.\n# Query\n```sql\nN/A\n```", "additional_kwargs": {}, "response_metadata": {}, "type": "ai", "name": null, "id": null, "example": false, "tool_calls": [], "invalid_tool_calls": [], "usage_metadata": null}}, {"type": "human", "data": {"content": "total number of sales in 2024", "additional_kwargs": {}, "response_metadata": {}, "type": "human", "name": null, "id": null, "example": false}}, {"type": "human", "data": {"content": "list all tables", "additional_kwargs": {}, "response_metadata": {}, "type": "human", "name": null, "id": null, "example": false}}, {"type": "human", "data": {"content": "list all tables", "additional_kwargs": {}, "response_metadata": {}, "type": "human", "name": null, "id": null, "example": false}}, {"type": "human", "data": {"content": "list all tables", "additional_kwargs": {}, "response_metadata": {}, "type": "human", "name": null, "id": null, "example": false}}, {"type": "ai", "data": {"content": "```markdown\n# Result\nThe tables in the database are:\n* dim\\_agreement\n* dim\\_customer\n* dim\\_product\n* dim\\_product\\_order\\_item\n\n# Explanation\nThe `list_tables` tool was called to retrieve a list of all available tables in the database schema. The result shows the names of these tables.\n\n# Query\n```sql\nN/A\n```", "additional_kwargs": {}, "response_metadata": {}, "type": "ai", "name": null, "id": null, "example": false, "tool_calls": [], "invalid_tool_calls": [], "usage_metadata": null}}, {"type": "human", "data": {"content": "list all tables", "additional_kwargs": {}, "response_metadata": {}, "type": "human", "name": null, "id": null, "example": false}}, {"type": "ai", "data": {"content": "```markdown\n# Result\nThe tables in the database are:\n* dim_agreement\n* dim_customer\n* dim_product\n* dim_product_order_item\n\n# Explanation\nThe `list_tables` tool was called to retrieve a list of all available tables in the database schema. 
The result shows the names of these tables.\n\n# Query\n```sql\nN/A\n```", "additional_kwargs": {}, "response_metadata": {}, "type": "ai", "name": null, "id": null, "example": false, "tool_calls": [], "invalid_tool_calls": [], "usage_metadata": null}}]
langchain_mcp_client.py CHANGED
@@ -1,4 +1,5 @@
1
  import os.path
 
2
  from typing import Tuple, Any
3
 
4
  from mcp import ClientSession, StdioServerParameters
@@ -7,11 +8,12 @@ from langchain_mcp_adapters.tools import load_mcp_tools
7
  from langgraph.prebuilt import create_react_agent
8
  from langchain_core.prompts import PromptTemplate
9
  from langchain_core.messages import AIMessage, HumanMessage
 
 
10
  from langchain.chat_models import init_chat_model
11
  from utils import parse_mcp_output, classify_intent
12
  import logging
13
  from dotenv import load_dotenv
14
- from langgraph.checkpoint.memory import MemorySaver
15
 
16
 
17
 
@@ -19,71 +21,96 @@ logger = logging.getLogger(__name__)
19
  load_dotenv()
20
  async def lc_mcp_exec(request: str, history) -> tuple[Any, Any]:
21
  """
22
- Execute the full PostgreSQL MCP pipeline: load summary, connect session,
23
- load memory and tools, build prompt, run agent, update memory.
24
-
25
- Args:
26
- request (str): User's request input.
27
- llm (Any): Language model for reasoning agent.
28
-
29
- Returns:
30
- str: Agent response message.
31
  """
32
- # TODO: give summary file path from config
33
- table_summary = load_table_summary("table_summary.txt")
34
- server_params = get_server_params()
35
-
36
- api_key = os.getenv("API_KEY")
37
- llm = init_chat_model(model="gemini-2.0-flash", model_provider="google_genai",
38
- api_key=api_key)
39
-
40
- async with stdio_client(server_params) as (read, write):
41
- async with ClientSession(read, write) as session:
42
- await session.initialize()
43
-
44
- tools = await load_and_enrich_tools(session)
45
-
46
- intent = classify_intent(request)
47
-
48
- messages = []
49
- memory = MemorySaver()
50
- agent = create_react_agent(llm, tools, checkpointer=memory)
51
-
52
- messages.append(HumanMessage(content=request))
53
- if history:
54
- # Clear existing messages and rebuild from history
55
- messages = []
56
-
57
- # Process Gradio chat history format
58
- for msg in history:
59
- # Gradio format: {'role': 'user/assistant', 'metadata': None, 'content': 'message', 'options': None}
60
- role = msg.get('role', '')
61
- content = msg.get('content', '')
62
-
63
- if role == 'user' and content:
64
- messages.append(HumanMessage(content=content))
65
- elif role == 'assistant' and content:
66
- messages.append(AIMessage(content=content))
67
-
68
- # Add the current query
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  messages.append(HumanMessage(content=request))
70
 
71
- prompt = await build_prompt(session, intent, request, tools, table_summary, messages)
72
- config = {"configurable": {"thread_id": "conversation_123"}}
73
- agent_response = await agent.ainvoke(
74
- {"messages": prompt},
75
- config
76
- )
77
 
78
- if "messages" in agent_response:
79
- response = agent_response["messages"][-1].content
80
- else:
81
- response = "No response generated"
 
 
 
 
82
 
83
- messages.append(AIMessage(content=response))
 
 
 
 
 
 
84
 
 
 
85
 
86
- return response, messages
 
 
 
87
 
88
 
89
  # ---------------- Helper Functions ---------------- #
@@ -104,10 +131,10 @@ async def load_and_enrich_tools(session: ClientSession):
104
  tools = await load_mcp_tools(session)
105
  return tools
106
 
107
- async def build_prompt(session, intent, request, tools, summary, messages):
108
  superset_prompt = await session.read_resource("resource://last_prompt")
109
  conversation_prompt = await session.read_resource("resource://base_prompt")
110
- # TODO: add uri's from config
111
  if intent == "superset_request":
112
  template = superset_prompt.contents[0].text
113
  return template.format(
@@ -116,10 +143,25 @@ async def build_prompt(session, intent, request, tools, summary, messages):
116
  else:
117
  template = conversation_prompt.contents[0].text
118
  tools_str = "\n".join([f"- {tool.name}: {tool.description}" for tool in tools])
 
 
 
 
 
 
 
 
 
 
 
 
119
  return template.format(
120
  new_request=request,
121
  tools=tools_str,
122
  descriptions=summary,
123
- chat_history = messages
 
 
 
124
  )
125
 
 
1
  import os.path
2
+ import json
3
  from typing import Tuple, Any
4
 
5
  from mcp import ClientSession, StdioServerParameters
 
8
  from langgraph.prebuilt import create_react_agent
9
  from langchain_core.prompts import PromptTemplate
10
  from langchain_core.messages import AIMessage, HumanMessage
11
+ from langchain.memory import ChatMessageHistory
12
+ from langchain_community.chat_message_histories import FileChatMessageHistory
13
  from langchain.chat_models import init_chat_model
14
  from utils import parse_mcp_output, classify_intent
15
  import logging
16
  from dotenv import load_dotenv
 
17
 
18
 
19
 
 
21
  load_dotenv()
22
async def lc_mcp_exec(request: str, history) -> tuple[Any, Any]:
    """
    Execute the full PostgreSQL MCP pipeline with persistent, file-backed memory.

    Loads prior conversation turns from chat_history.json, runs the ReAct agent
    against the MCP session, appends the new human/AI turns to the history file,
    and returns the agent's answer together with the persisted message list.

    Args:
        request: The user's natural-language request.
        history: UI-side chat history (unused here; persistence is file-based).

    Returns:
        tuple[Any, Any]: (response text, list of persisted chat messages).
        On any failure an error string and an empty list are returned instead
        of raising, so the calling UI never crashes.
    """
    try:
        history_file = os.path.join(os.path.dirname(__file__), "chat_history.json")

        # FileChatMessageHistory persists messages as a top-level JSON *list*
        # (the committed chat_history.json confirms this), so seed the file
        # with [] — seeding {"messages": []} would make the first load fail.
        if not os.path.exists(history_file) or os.path.getsize(history_file) == 0:
            with open(history_file, 'w') as f:
                json.dump([], f)

        message_history = FileChatMessageHistory(file_path=history_file)

        try:
            existing_messages = message_history.messages
        except json.JSONDecodeError:
            # A corrupted history file should not take the assistant down;
            # reinitialize it and continue with an empty history.
            logger.warning("Chat history file corrupted, reinitializing...")
            with open(history_file, 'w') as f:
                json.dump([], f)
            existing_messages = []

        # Re-wrap persisted turns as plain Human/AI messages, dropping
        # tool-call metadata so the prompt carries only conversational text.
        formatted_history = []
        for msg in existing_messages:
            if isinstance(msg, HumanMessage):
                formatted_history.append(HumanMessage(content=msg.content))
            elif isinstance(msg, AIMessage):
                formatted_history.append(AIMessage(content=msg.content))

        # TODO: give summary file path from config
        table_summary = load_table_summary("table_summary.txt")
        server_params = get_server_params()

        api_key = os.getenv("API_KEY")
        llm = init_chat_model(model="gemini-2.0-flash",
                              model_provider="google_genai",
                              api_key=api_key)

        async with stdio_client(server_params) as (read, write):
            async with ClientSession(read, write) as session:
                await session.initialize()

                tools = await load_and_enrich_tools(session)
                intent = classify_intent(request)

                # Persist the user's turn before invoking the agent so it is
                # recorded even if the agent call fails.
                message_history.add_user_message(request)

                agent = create_react_agent(llm, tools)

                # The prompt template already embeds the formatted history,
                # the tool descriptions, and the current request.
                prompt = await build_prompt(session, intent, request, tools,
                                            table_summary, formatted_history)

                # NOTE(review): "chat_history" is not part of the prebuilt
                # ReAct agent's state schema — confirm the installed langgraph
                # version tolerates the extra key.
                agent_response = await agent.ainvoke(
                    {
                        "messages": prompt,
                        "chat_history": [msg.content for msg in formatted_history]
                    },
                    config={"configurable": {"thread_id": "conversation_123"}}
                )

                if "messages" in agent_response:
                    response = agent_response["messages"][-1].content
                else:
                    response = "No response generated"
                # Persist the assistant's turn in either case.
                message_history.add_ai_message(response)

                # Return current response and up-to-date messages.
                return response, message_history.messages

    except Exception as e:
        logger.error(f"Error in chat history handling: {str(e)}", exc_info=True)
        # Fallback to stateless response if history fails
        return f"Error in conversation: {str(e)}", []
114
 
115
 
116
  # ---------------- Helper Functions ---------------- #
 
131
  tools = await load_mcp_tools(session)
132
  return tools
133
 
134
+ async def build_prompt(session, intent, request, tools, summary, chat_history):
135
  superset_prompt = await session.read_resource("resource://last_prompt")
136
  conversation_prompt = await session.read_resource("resource://base_prompt")
137
+
138
  if intent == "superset_request":
139
  template = superset_prompt.contents[0].text
140
  return template.format(
 
143
  else:
144
  template = conversation_prompt.contents[0].text
145
  tools_str = "\n".join([f"- {tool.name}: {tool.description}" for tool in tools])
146
+
147
+ # Handle history formatting with proper message access
148
+ history_str = ""
149
+ if chat_history:
150
+ history_sections = []
151
+ for msg in chat_history:
152
+ if isinstance(msg, HumanMessage):
153
+ history_sections.append(f"Previous Human Question:\n{msg.content}\n")
154
+ elif isinstance(msg, AIMessage):
155
+ history_sections.append(f"Previous Assistant Response:\n{msg.content}\n")
156
+ history_str = "\n".join(history_sections)
157
+
158
  return template.format(
159
  new_request=request,
160
  tools=tools_str,
161
  descriptions=summary,
162
+ chat_history=f"\nPrevious Conversation History:\n{history_str}" if history_str else "\nThis is a new conversation.",
163
+ system_instructions="""You are a PostgreSQL database expert assistant.
164
+ Use the conversation history when available to maintain context.
165
+ For new conversations, focus on understanding the initial request."""
166
  )
167
 
run.sh CHANGED
@@ -1,6 +1,6 @@
1
  #!/bin/bash
2
 
3
  # Replace 'myenv' with the name of your conda environment
4
- conda activate myenv
5
 
6
  python gradio_app.py
 
1
  #!/bin/bash
2
 
3
  # Optional: activate your conda environment first (replace 'myenv' below and uncomment)
4
+ # conda activate myenv
5
 
6
  python gradio_app.py