ohmygaugh committed on
Commit
8595be6
·
1 Parent(s): e02fab5

Fixed one more bug found during stress testing.

Browse files
mcp/core/intelligence.py CHANGED
@@ -27,34 +27,43 @@ def _get_database_for_table(table_name: str) -> str | None:
27
  async def execute_federated_query(sql: str) -> List[Dict[str, Any]]:
28
  """
29
  Executes a SQL query against the correct SQLite database.
30
- This is a simplified version of a federated query engine. It identifies the
31
- target database from the first table name in the SQL query.
32
  """
33
  parsed = sqlparse.parse(sql)[0]
34
  target_table = None
35
 
36
- # Find the first table name in the parsed SQL
 
37
  for token in parsed.tokens:
38
- if isinstance(token, sqlparse.sql.Identifier):
 
 
 
39
  target_table = token.get_real_name()
40
  break
41
- elif token.is_group:
42
  for sub_token in token.tokens:
43
  if isinstance(sub_token, sqlparse.sql.Identifier):
44
  target_table = sub_token.get_real_name()
45
  break
46
- if target_table:
47
- break
48
 
49
  if not target_table:
50
  raise ValueError("Could not identify a target table in the SQL query.")
51
 
52
  logger.info(f"Identified target table: {target_table}")
53
 
54
- # Determine which database engine to use
55
  db_name = _get_database_for_table(target_table)
56
  if not db_name:
57
  raise ValueError(f"Table '{target_table}' not found in any known database.")
 
 
 
 
 
 
58
 
59
  db_engines = get_db_connections()
60
  engine = db_engines.get(db_name)
 
27
  async def execute_federated_query(sql: str) -> List[Dict[str, Any]]:
28
  """
29
  Executes a SQL query against the correct SQLite database.
30
+ Strips database prefixes from table names (e.g., clinical_trials.patients → patients).
 
31
  """
32
  parsed = sqlparse.parse(sql)[0]
33
  target_table = None
34
 
35
+ # Find table name from FROM clause
36
+ from_found = False
37
  for token in parsed.tokens:
38
+ if token.ttype is sqlparse.tokens.Keyword and token.value.upper() == 'FROM':
39
+ from_found = True
40
+ continue
41
+ elif from_found and isinstance(token, sqlparse.sql.Identifier):
42
  target_table = token.get_real_name()
43
  break
44
+ elif from_found and token.is_group:
45
  for sub_token in token.tokens:
46
  if isinstance(sub_token, sqlparse.sql.Identifier):
47
  target_table = sub_token.get_real_name()
48
  break
49
+ if target_table:
50
+ break
51
 
52
  if not target_table:
53
  raise ValueError("Could not identify a target table in the SQL query.")
54
 
55
  logger.info(f"Identified target table: {target_table}")
56
 
57
+ # Determine which database this table belongs to
58
  db_name = _get_database_for_table(target_table)
59
  if not db_name:
60
  raise ValueError(f"Table '{target_table}' not found in any known database.")
61
+
62
+ # Strip all database prefixes from SQL (e.g., "clinical_trials.patients" → "patients")
63
+ for known_db in ["clinical_trials", "laboratory", "drug_discovery"]:
64
+ sql = sql.replace(f"{known_db}.", "")
65
+
66
+ logger.info(f"Cleaned SQL for database '{db_name}': {sql}")
67
 
68
  db_engines = get_db_connections()
69
  engine = db_engines.get(db_name)
mcp/requirements.txt CHANGED
@@ -4,5 +4,4 @@ neo4j==5.14.0
4
  pydantic==2.4.0
5
  requests==2.31.0
6
  SQLAlchemy==2.0.29
7
- sqlparse==0.5.0
8
- mcp==1.1.1
 
4
  pydantic==2.4.0
5
  requests==2.31.0
6
  SQLAlchemy==2.0.29
7
+ sqlparse==0.5.0
 
streamlit/app.py CHANGED
@@ -29,6 +29,8 @@ if 'messages' not in st.session_state:
29
  st.session_state.messages = []
30
  if 'schema_info' not in st.session_state:
31
  st.session_state.schema_info = ""
 
 
32
 
33
  # --- Helper Functions ---
34
  def stream_agent_response(question: str):
@@ -125,6 +127,42 @@ def display_sidebar():
125
  st.session_state.messages = []
126
  st.rerun()
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  def main():
129
  display_sidebar()
130
  st.title("💬 GraphRAG Conversational Agent")
@@ -134,6 +172,15 @@ def main():
134
  for message in st.session_state.messages:
135
  with st.chat_message(message["role"]):
136
  st.markdown(message["content"])
 
 
 
 
 
 
 
 
 
137
 
138
  if prompt := st.chat_input("Ask your question here..."):
139
  st.session_state.messages.append({"role": "user", "content": prompt})
@@ -143,6 +190,7 @@ def main():
143
  with st.chat_message("assistant"):
144
  full_response = ""
145
  response_box = st.empty()
 
146
 
147
  for chunk in stream_agent_response(prompt):
148
  if "error" in chunk:
@@ -156,12 +204,34 @@ def main():
156
  full_response += f"🤔 *{content}*\n\n"
157
  elif chunk.get("type") == "observation":
158
  full_response += f"{content}\n\n"
 
 
 
 
159
  elif chunk.get("type") == "final_answer":
160
  full_response += f"**Final Answer:**\n{content}"
161
 
162
  response_box.markdown(full_response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
- st.session_state.messages.append({"role": "assistant", "content": full_response})
 
 
 
 
165
 
166
  if __name__ == "__main__":
167
  main()
 
29
  st.session_state.messages = []
30
  if 'schema_info' not in st.session_state:
31
  st.session_state.schema_info = ""
32
+ if 'current_results' not in st.session_state:
33
+ st.session_state.current_results = None
34
 
35
  # --- Helper Functions ---
36
  def stream_agent_response(question: str):
 
127
  st.session_state.messages = []
128
  st.rerun()
129
 
130
+ def extract_sql_results(observation_content: str) -> pd.DataFrame | None:
131
+ """Extract SQL results from execute_query tool observation."""
132
+ try:
133
+ if "execute_query" not in observation_content or "returned:" not in observation_content:
134
+ return None
135
+
136
+ # Extract the content between triple backticks
137
+ if "```" in observation_content:
138
+ parts = observation_content.split("```")
139
+ if len(parts) >= 2:
140
+ result_text = parts[1].strip()
141
+
142
+ # Parse table format: "column1 | column2 | column3"
143
+ lines = [line.strip() for line in result_text.split('\n') if line.strip()]
144
+
145
+ if len(lines) < 3: # Need headers, separator, and at least one row
146
+ return None
147
+
148
+ # Parse headers
149
+ headers = [h.strip() for h in lines[0].split('|')]
150
+
151
+ # Parse data rows (skip separator line at index 1)
152
+ data_rows = []
153
+ for line in lines[2:]:
154
+ if "and" in line and "more rows" in line:
155
+ break
156
+ row_values = [v.strip() for v in line.split('|')]
157
+ if len(row_values) == len(headers):
158
+ data_rows.append(row_values)
159
+
160
+ if data_rows:
161
+ return pd.DataFrame(data_rows, columns=headers)
162
+ except Exception:
163
+ pass
164
+ return None
165
+
166
  def main():
167
  display_sidebar()
168
  st.title("💬 GraphRAG Conversational Agent")
 
172
  for message in st.session_state.messages:
173
  with st.chat_message(message["role"]):
174
  st.markdown(message["content"])
175
+ if message.get("dataframe") is not None:
176
+ st.dataframe(message["dataframe"], use_container_width=True)
177
+ csv = message["dataframe"].to_csv(index=False)
178
+ st.download_button(
179
+ label="📥 Download CSV",
180
+ data=csv,
181
+ file_name="query_results.csv",
182
+ mime="text/csv"
183
+ )
184
 
185
  if prompt := st.chat_input("Ask your question here..."):
186
  st.session_state.messages.append({"role": "user", "content": prompt})
 
190
  with st.chat_message("assistant"):
191
  full_response = ""
192
  response_box = st.empty()
193
+ sql_results_df = None
194
 
195
  for chunk in stream_agent_response(prompt):
196
  if "error" in chunk:
 
204
  full_response += f"🤔 *{content}*\n\n"
205
  elif chunk.get("type") == "observation":
206
  full_response += f"{content}\n\n"
207
+ # Try to extract SQL results
208
+ df = extract_sql_results(content)
209
+ if df is not None:
210
+ sql_results_df = df
211
  elif chunk.get("type") == "final_answer":
212
  full_response += f"**Final Answer:**\n{content}"
213
 
214
  response_box.markdown(full_response)
215
+
216
+ # Display DataFrame if SQL results were found
217
+ if sql_results_df is not None:
218
+ st.markdown("---")
219
+ st.markdown("### 📊 Query Results")
220
+ st.dataframe(sql_results_df, use_container_width=True)
221
+ csv = sql_results_df.to_csv(index=False)
222
+ st.download_button(
223
+ label="📥 Download CSV",
224
+ data=csv,
225
+ file_name="query_results.csv",
226
+ mime="text/csv",
227
+ key=f"download_{len(st.session_state.messages)}"
228
+ )
229
 
230
+ st.session_state.messages.append({
231
+ "role": "assistant",
232
+ "content": full_response,
233
+ "dataframe": sql_results_df
234
+ })
235
 
236
  if __name__ == "__main__":
237
  main()