ohmygaugh committed on
Commit
9d411a7
·
1 Parent(s): 86cbe3c

This fixes the Docker container health errors; however, there is still no MCP connection.

Browse files
README.md CHANGED
@@ -45,7 +45,7 @@ This project implements an intelligent, multi-step GraphRAG-powered agent that u
45
 
46
  ### Prerequisites
47
  - Docker & Docker Compose
48
- - OpenAI API key
49
 
50
  ### Setup
51
  1. **Clone and configure**:
@@ -55,9 +55,9 @@ This project implements an intelligent, multi-step GraphRAG-powered agent that u
55
  touch .env
56
  ```
57
 
58
- 2. **Add your OpenAI API key** to the `.env` file. This is the only secret you need to provide.
59
  ```
60
- OPENAI_API_KEY="sk-your-openai-key-here"
61
  ```
62
 
63
  3. **Start the system**:
@@ -104,7 +104,7 @@ To test the agent's logic directly without the full Docker stack, you can run it
104
 
105
  2. **Set your API key**:
106
  ```bash
107
- export OPENAI_API_KEY="sk-your-openai-key-here"
108
  ```
109
 
110
  3. **Run the agent**:
 
45
 
46
  ### Prerequisites
47
  - Docker & Docker Compose
48
+ - LLM API key (e.g., for OpenAI)
49
 
50
  ### Setup
51
  1. **Clone and configure**:
 
55
  touch .env
56
  ```
57
 
58
+ 2. **Add your LLM API key** to the `.env` file.
59
  ```
60
+ LLM_API_KEY="sk-your-llm-api-key-here"
61
  ```
62
 
63
  3. **Start the system**:
 
104
 
105
  2. **Set your API key**:
106
  ```bash
107
+ export LLM_API_KEY="sk-your-llm-api-key-here"
108
  ```
109
 
110
  3. **Run the agent**:
agent/main.py CHANGED
@@ -3,17 +3,18 @@ import sys
3
  import logging
4
  import json
5
  from typing import Annotated, List, TypedDict
 
6
  from fastapi import FastAPI
7
  from pydantic import BaseModel
8
  import uvicorn
9
  from fastapi.responses import StreamingResponse
10
 
11
  from langchain_core.messages import BaseMessage, ToolMessage, AIMessage
12
- from langchain_openai import OpenAI
13
  from langgraph.graph import StateGraph, START, END
14
  from langgraph.prebuilt import ToolNode
15
 
16
- from agent.tools import MCPClient, SchemaSearchTool, JoinPathFinderTool, QueryExecutorTool
17
 
18
  # --- Configuration & Logging ---
19
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -21,7 +22,7 @@ logger = logging.getLogger(__name__)
21
 
22
  MCP_URL = os.getenv("MCP_URL", "http://mcp:8000/mcp")
23
  API_KEY = os.getenv("MCP_API_KEY", "dev-key-123")
24
- OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
25
 
26
  # --- Agent State Definition ---
27
  class AgentState(TypedDict):
@@ -32,10 +33,10 @@ class GraphRAGAgent:
32
  """The core agent for handling GraphRAG queries using LangGraph."""
33
 
34
  def __init__(self):
35
- if not OPENAI_API_KEY:
36
- raise ValueError("OPENAI_API_KEY environment variable not set.")
37
 
38
- llm = OpenAI(api_key=OPENAI_API_KEY, temperature=0, max_retries=1)
39
 
40
  mcp_client = MCPClient(mcp_url=MCP_URL, api_key=API_KEY)
41
  tools = [
@@ -93,21 +94,25 @@ class GraphRAGAgent:
93
  yield json.dumps({"type": "final_answer", "content": last_message.content}) + "\\n\\n"
94
 
95
  # --- FastAPI Application ---
96
- app = FastAPI(title="GraphRAG Agent Server")
97
  agent = None
98
 
99
- class QueryRequest(BaseModel):
100
- question: str
101
-
102
- @app.on_event("startup")
103
- def startup_event():
104
- """Initialize the agent on server startup."""
105
  global agent
 
106
  try:
107
  agent = GraphRAGAgent()
108
  logger.info("GraphRAGAgent initialized successfully.")
109
  except ValueError as e:
110
  logger.error(f"Agent initialization failed: {e}")
 
 
 
 
 
 
 
111
 
112
  @app.post("/query")
113
  async def execute_query(request: QueryRequest) -> StreamingResponse:
 
3
  import logging
4
  import json
5
  from typing import Annotated, List, TypedDict
6
+ from contextlib import asynccontextmanager
7
  from fastapi import FastAPI
8
  from pydantic import BaseModel
9
  import uvicorn
10
  from fastapi.responses import StreamingResponse
11
 
12
  from langchain_core.messages import BaseMessage, ToolMessage, AIMessage
13
+ from langchain_openai import ChatOpenAI
14
  from langgraph.graph import StateGraph, START, END
15
  from langgraph.prebuilt import ToolNode
16
 
17
+ from tools import MCPClient, SchemaSearchTool, JoinPathFinderTool, QueryExecutorTool
18
 
19
  # --- Configuration & Logging ---
20
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
22
 
23
  MCP_URL = os.getenv("MCP_URL", "http://mcp:8000/mcp")
24
  API_KEY = os.getenv("MCP_API_KEY", "dev-key-123")
25
+ LLM_API_KEY = os.getenv("LLM_API_KEY")
26
 
27
  # --- Agent State Definition ---
28
  class AgentState(TypedDict):
 
33
  """The core agent for handling GraphRAG queries using LangGraph."""
34
 
35
  def __init__(self):
36
+ if not LLM_API_KEY:
37
+ raise ValueError("LLM_API_KEY environment variable not set.")
38
 
39
+ llm = ChatOpenAI(api_key=LLM_API_KEY, model="gpt-4o-mini", temperature=0, max_retries=1)
40
 
41
  mcp_client = MCPClient(mcp_url=MCP_URL, api_key=API_KEY)
42
  tools = [
 
94
  yield json.dumps({"type": "final_answer", "content": last_message.content}) + "\\n\\n"
95
 
96
  # --- FastAPI Application ---
 
97
  agent = None
98
 
99
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: build the agent on startup, log on shutdown."""
    global agent
    logger.info("Agent server startup...")
    try:
        agent = GraphRAGAgent()
    except ValueError as e:
        # Log and continue: the server still starts, just without an
        # initialized agent.
        logger.error(f"Agent initialization failed: {e}")
    else:
        logger.info("GraphRAGAgent initialized successfully.")
    yield
    logger.info("Agent server shutdown.")
111
+
112
+ app = FastAPI(title="GraphRAG Agent Server", lifespan=lifespan)
113
+
114
+ class QueryRequest(BaseModel):
115
+ question: str
116
 
117
  @app.post("/query")
118
  async def execute_query(request: QueryRequest) -> StreamingResponse:
agent/requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
  requests
2
- python-dotenv
3
  langchain
4
  langchain-openai
5
  pydantic
 
1
  requests
 
2
  langchain
3
  langchain-openai
4
  pydantic
docker-compose.yml CHANGED
@@ -1,5 +1,3 @@
1
- version: '3.8'
2
-
3
  services:
4
  neo4j:
5
  build: ./neo4j
@@ -50,7 +48,7 @@ services:
50
  - MCP_URL=http://mcp:8000/mcp
51
  - MCP_API_KEY=dev-key-123
52
  - AGENT_POLL_INTERVAL=${AGENT_POLL_INTERVAL}
53
- - OPENAI_API_KEY=${OPENAI_API_KEY}
54
  depends_on:
55
  mcp:
56
  condition: service_healthy
 
 
 
1
  services:
2
  neo4j:
3
  build: ./neo4j
 
48
  - MCP_URL=http://mcp:8000/mcp
49
  - MCP_API_KEY=dev-key-123
50
  - AGENT_POLL_INTERVAL=${AGENT_POLL_INTERVAL}
51
+ - LLM_API_KEY=${LLM_API_KEY}
52
  depends_on:
53
  mcp:
54
  condition: service_healthy
mcp/core/config.py CHANGED
@@ -1,22 +1,21 @@
1
  import os
2
 
3
  # --- Neo4j Configuration ---
4
- NEO4J_URI = os.getenv("NEO4J_BOLT_URL", "bolt://neo4j:7687")
5
- NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
6
- NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "password")
7
 
8
- # --- SQLite Configuration ---
9
- SQLITE_DATA_DIR = os.getenv("SQLITE_DATA_DIR", "/app/data")
 
 
10
 
11
- def get_sqlite_connection_string(db_name: str) -> str:
12
- """
13
- Generates the SQLAlchemy connection string for a given SQLite database file.
14
- Assumes the database file is located in the SQLITE_DATA_DIR.
15
- Example: get_sqlite_connection_string("clinical_trials.db")
16
- -> "sqlite:////app/data/clinical_trials.db"
17
- """
18
- db_path = os.path.join(SQLITE_DATA_DIR, db_name)
19
- return f"sqlite:///{db_path}"
20
 
21
  # --- Application Settings ---
22
  # You can add other application-wide settings here
 
1
  import os
2
 
3
  # --- Neo4j Configuration ---
4
# --- Neo4j Configuration ---
NEO4J_URI = os.getenv("NEO4J_BOLT_URL", "bolt://localhost:7687")
NEO4J_USER = "neo4j"

# The NEO4J_AUTH env var is in the format 'neo4j/password'.
# Split on the FIRST '/' only, so a password that itself contains '/'
# is not truncated; fall back to the raw value when no '/' is present.
neo4j_auth = os.getenv("NEO4J_AUTH", "neo4j/password")
NEO4J_PASSWORD = neo4j_auth.split('/', 1)[1] if '/' in neo4j_auth else neo4j_auth

# --- Database Configuration ---
# Connection strings for the SQLite databases.
# Four slashes after "sqlite:" = absolute path inside the container.
# (The f-prefixes were dropped: the strings contain no placeholders.)
DB_CONNECTIONS = {
    "clinical_trials": "sqlite:////app/data/clinical_trials.db",
    "drug_discovery": "sqlite:////app/data/drug_discovery.db",
    "laboratory": "sqlite:////app/data/laboratory.db",
}
 
 
19
 
20
  # --- Application Settings ---
21
  # You can add other application-wide settings here
mcp/core/database.py CHANGED
@@ -1,26 +1,39 @@
1
  from sqlalchemy import create_engine
2
  from sqlalchemy.engine import Engine
3
  import logging
 
 
 
4
 
5
  logging.basicConfig(level=logging.INFO)
6
  logger = logging.getLogger(__name__)
7
 
8
- def get_db_engine(connection_string: str) -> Engine | None:
9
- """
10
- Creates a SQLAlchemy engine for a given database connection string.
11
-
12
- Args:
13
- connection_string: The database connection string.
14
 
15
- Returns:
16
- A SQLAlchemy Engine instance, or None if connection fails.
17
  """
18
- try:
19
- engine = create_engine(connection_string)
20
- # Test the connection
21
- with engine.connect() as connection:
22
- logger.info(f"Successfully connected to {engine.url.database}")
23
- return engine
24
- except Exception as e:
25
- logger.error(f"Failed to connect to database: {e}")
26
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from sqlalchemy import create_engine
2
  from sqlalchemy.engine import Engine
3
  import logging
4
+ from typing import Dict
5
+
6
+ from . import config
7
 
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
10
 
11
+ # A dictionary to hold the initialized SQLAlchemy engines
12
+ _db_engines: Dict[str, Engine] = {}
 
 
 
 
13
 
14
def get_db_connections() -> Dict[str, Engine]:
    """
    Lazily build and return the SQLAlchemy engines for every configured database.

    The engine cache is module-level, so repeated calls are cheap and
    idempotent; databases that fail to connect are logged and omitted.
    """
    global _db_engines
    if _db_engines:
        return _db_engines

    logger.info("Initializing database connections...")
    for db_name, conn_str in config.DB_CONNECTIONS.items():
        try:
            engine = create_engine(conn_str)
            # Probe the database once so broken connections surface here,
            # not on first query.
            with engine.connect():
                logger.info(f"Successfully connected to {db_name}")
            _db_engines[db_name] = engine
        except Exception as e:
            logger.error(f"Failed to connect to {db_name}: {e}")
    return _db_engines
32
+
33
def close_db_connections():
    """Dispose every cached engine and reset the module-level cache."""
    global _db_engines
    logger.info("Closing database connections...")
    # Swap in a fresh empty cache, then dispose the old engines.
    engines, _db_engines = _db_engines, {}
    for engine in engines.values():
        engine.dispose()
mcp/core/discovery.py CHANGED
@@ -1,98 +1,70 @@
1
- from sqlalchemy import inspect, text
2
  from sqlalchemy.engine import Engine
3
  from typing import Dict, Any, List
4
  import logging
5
- import json
6
- from concurrent.futures import TimeoutError, ThreadPoolExecutor
7
 
8
- logger = logging.getLogger(__name__)
9
-
10
- def get_table_schema(inspector, table_name: str) -> Dict[str, Any]:
11
- """Extracts schema for a single table."""
12
- columns = inspector.get_columns(table_name)
13
- primary_keys = inspector.get_pk_constraint(table_name)['constrained_columns']
14
- foreign_keys = inspector.get_foreign_keys(table_name)
15
-
16
- table_schema = {
17
- "name": table_name,
18
- "columns": [],
19
- "primary_keys": primary_keys,
20
- "foreign_keys": foreign_keys
21
- }
22
-
23
- for col in columns:
24
- table_schema["columns"].append({
25
- "name": col['name'],
26
- "type": str(col['type']),
27
- "nullable": col['nullable'],
28
- "default": col.get('default'),
29
- })
30
- return table_schema
31
 
32
- def get_sample_data(engine: Engine, table_name: str, sample_size: int = 5) -> Dict[str, Any]:
33
- """Fetches sample data and distinct values for each column."""
34
- sample_data = {}
35
- with engine.connect() as connection:
36
- # Get row count
37
- try:
38
- result = connection.execute(text(f'SELECT COUNT(*) FROM "{table_name}"'))
39
- sample_data['row_count'] = result.scalar_one()
40
- except Exception as e:
41
- logger.warning(f"Could not get row count for table {table_name}: {e}")
42
- sample_data['row_count'] = -1 # Indicate error or unknown
43
-
44
- # Get sample rows
45
- try:
46
- result = connection.execute(text(f'SELECT * FROM "{table_name}" LIMIT {sample_size}'))
47
- rows = [dict(row._mapping) for row in result.fetchall()]
48
- # Attempt to JSON serialize to handle complex types gracefully
49
- sample_data['sample_rows'] = json.loads(json.dumps(rows, default=str))
50
- except Exception as e:
51
- logger.warning(f"Could not get sample rows for table {table_name}: {e}")
52
- sample_data['sample_rows'] = []
53
-
54
- return sample_data
55
-
56
-
57
- def discover_schema(engine: Engine, timeout: int = 30) -> Dict[str, Any] | None:
58
- """
59
- Discovers the full schema of a database using SQLAlchemy's inspection API.
60
- Includes table schemas and sample data.
61
- """
62
- try:
63
- with ThreadPoolExecutor() as executor:
64
- future = executor.submit(_discover_schema_task, engine)
65
- return future.result(timeout=timeout)
66
- except TimeoutError:
67
- logger.error(f"Schema discovery for {engine.url.database} timed out after {timeout} seconds.")
68
- return None
69
- except Exception as e:
70
- logger.error(f"An unexpected error occurred during schema discovery for {engine.url.database}: {e}")
71
- return None
72
 
73
- def _discover_schema_task(engine: Engine) -> Dict[str, Any]:
74
- """The actual schema discovery logic to be run with a timeout."""
75
  inspector = inspect(engine)
76
  db_schema = {
77
- "database_name": engine.url.database,
78
- "dialect": engine.dialect.name,
79
  "tables": []
80
  }
81
-
82
  table_names = inspector.get_table_names()
83
-
84
  for table_name in table_names:
85
- try:
86
- logger.info(f"Discovering schema for table: {table_name}")
87
- table_schema = get_table_schema(inspector, table_name)
88
-
89
- logger.info(f"Collecting sample data for table: {table_name}")
90
- sample_info = get_sample_data(engine, table_name)
91
- table_schema.update(sample_info)
92
 
93
- db_schema["tables"].append(table_schema)
94
- except Exception as e:
95
- logger.error(f"Could not inspect table '{table_name}': {e}")
96
- continue
 
 
 
97
 
98
- return db_schema
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy import inspect
2
  from sqlalchemy.engine import Engine
3
  from typing import Dict, Any, List
4
  import logging
5
+ from concurrent.futures import ThreadPoolExecutor, as_completed
 
6
 
7
+ from .database import get_db_connections
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
def _discover_single_db_schema(db_name: str, engine: Engine) -> Dict[str, Any]:
    """Introspect one database engine and return its tables and columns."""
    inspector = inspect(engine)
    tables = [
        {
            "name": tbl,
            "columns": [
                {"name": col['name'], "type": str(col['type'])}
                for col in inspector.get_columns(tbl)
            ],
        }
        for tbl in inspector.get_table_names()
    ]
    return {"database_name": db_name, "tables": tables}
 
26
 
27
async def get_relevant_schemas(query: str) -> List[Dict[str, Any]]:
    """
    Discover schemas from all connected databases and filter them by keyword.

    Every database is introspected in parallel; the query is then split into
    lowercase keywords and matched against table and column names.
    A more advanced implementation would use embeddings for semantic search.

    Args:
        query: Natural-language question; an empty string returns all schemas.

    Returns:
        A list of {"database", "table", "columns"} dicts, one per matching
        table (or the full per-database schemas when query is empty).
    """
    db_engines = get_db_connections()
    all_schemas = []

    # Discover all schemas in parallel; a failure in one database must not
    # prevent the others from being reported.
    with ThreadPoolExecutor() as executor:
        future_to_db = {
            executor.submit(_discover_single_db_schema, name, eng): name
            for name, eng in db_engines.items()
        }
        for future in as_completed(future_to_db):
            db_name = future_to_db[future]
            try:
                all_schemas.append(future.result())
            except Exception as e:
                logger.error(f"Failed to discover schema for {db_name}: {e}")

    if not query:
        return all_schemas

    # Simple keyword filtering: match table names first, then column names.
    keywords = query.lower().split()
    relevant_schemas = []
    for db_schema in all_schemas:
        for table in db_schema.get("tables", []):
            match = {
                "database": db_schema["database_name"],
                "table": table['name'],
                "columns": table['columns'],  # return full table on any match
            }
            if any(keyword in table['name'].lower() for keyword in keywords):
                relevant_schemas.append(match)
            else:
                for col in table.get("columns", []):
                    if any(keyword in col['name'].lower() for keyword in keywords):
                        relevant_schemas.append(match)
                        break  # move to next table

    # Deduplicate by (database, table). The previous implementation hashed
    # tuple(d.items()), which raises TypeError because the "columns" value
    # is an (unhashable) list.
    seen = set()
    deduped = []
    for schema in relevant_schemas:
        key = (schema["database"], schema["table"])
        if key not in seen:
            seen.add(key)
            deduped.append(schema)
    return deduped
mcp/core/graph.py CHANGED
@@ -1,151 +1,79 @@
1
- from neo4j import GraphDatabase
2
  import logging
3
- import json
4
  from typing import List, Dict, Any
 
5
  from . import config
6
 
7
  logger = logging.getLogger(__name__)
8
 
9
- class GraphStore:
10
- def __init__(self):
11
- self._driver = GraphDatabase.driver(config.NEO4J_URI, auth=(config.NEO4J_USER, config.NEO4J_PASSWORD))
12
- self.ensure_constraints()
13
 
14
- def close(self):
15
- self._driver.close()
 
 
 
 
 
 
16
 
17
- def ensure_constraints(self):
18
- """Ensure uniqueness constraints are set up in Neo4j."""
19
- with self._driver.session() as session:
20
- session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (d:Database) REQUIRE d.name IS UNIQUE")
21
- session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (t:Table) REQUIRE t.unique_name IS UNIQUE")
22
- session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (c:Column) REQUIRE c.unique_name IS UNIQUE")
23
- logger.info("Neo4j constraints ensured.")
24
-
25
- def import_schema(self, schema_data: dict):
26
- """
27
- Imports a discovered database schema into the Neo4j graph.
28
- """
29
- db_name = schema_data['database_name']
30
-
31
- with self._driver.session() as session:
32
- # Create Database node
33
- session.run("MERGE (d:Database {name: $db_name})", db_name=db_name)
34
-
35
- for table in schema_data['tables']:
36
- table_unique_name = f"{db_name}.{table['name']}"
37
- table_properties = {
38
- "name": table['name'],
39
- "unique_name": table_unique_name,
40
- "row_count": table.get('row_count', -1),
41
- "sample_rows": json.dumps(table.get('sample_rows', []))
42
- }
43
-
44
- # Create Table node and HAS_TABLE relationship
45
- session.run(
46
- """
47
- MATCH (d:Database {name: $db_name})
48
- MERGE (t:Table {unique_name: $unique_name})
49
- ON CREATE SET t += $props
50
- ON MATCH SET t += $props
51
- MERGE (d)-[:HAS_TABLE]->(t)
52
- """,
53
- db_name=db_name,
54
- unique_name=table_unique_name,
55
- props=table_properties
56
- )
57
 
58
- for column in table['columns']:
59
- column_unique_name = f"{table_unique_name}.{column['name']}"
60
- column_properties = {
61
- "name": column['name'],
62
- "unique_name": column_unique_name,
63
- "type": column['type'],
64
- "nullable": column['nullable'],
65
- "default": str(column.get('default')) # Ensure default is string
66
- }
67
 
68
- # Create Column node and HAS_COLUMN relationship
69
- session.run(
70
- """
71
- MATCH (t:Table {unique_name: $table_unique_name})
72
- MERGE (c:Column {unique_name: $column_unique_name})
73
- ON CREATE SET c += $props
74
- ON MATCH SET c += $props
75
- MERGE (t)-[:HAS_COLUMN]->(c)
76
- """,
77
- table_unique_name=table_unique_name,
78
- column_unique_name=column_unique_name,
79
- props=column_properties
80
- )
81
 
82
- # After all tables and columns are created, create foreign key relationships
83
- for table in schema_data['tables']:
84
- table_unique_name = f"{db_name}.{table['name']}"
85
- if table.get('foreign_keys'):
86
- for fk in table['foreign_keys']:
87
- constrained_columns = fk['constrained_columns']
88
- referred_table = fk['referred_table']
89
- referred_columns = fk['referred_columns']
90
-
91
- referred_table_unique_name = f"{db_name}.{referred_table}"
92
 
93
- for i, col_name in enumerate(constrained_columns):
94
- from_col_unique_name = f"{table_unique_name}.{col_name}"
95
- to_col_unique_name = f"{referred_table_unique_name}.{referred_columns[i]}"
96
-
97
- session.run(
98
- """
99
- MATCH (from_col:Column {unique_name: $from_col})
100
- MATCH (to_col:Column {unique_name: $to_col})
101
- MERGE (from_col)-[:REFERENCES]->(to_col)
102
- """,
103
- from_col=from_col_unique_name,
104
- to_col=to_col_unique_name
105
- )
106
- logger.info(f"Successfully imported schema for database: {db_name}")
107
 
108
- def find_shortest_path(self, start_node_name: str, end_node_name: str) -> List[Dict[str, Any]]:
109
- """
110
- Finds the shortest path between two nodes (Tables or Columns) in the graph.
111
- This is a generic pathfinder.
112
- """
113
- query = """
114
- MATCH (start {unique_name: $start_name}), (end {unique_name: $end_name})
115
- CALL apoc.path.shortestPath(start, end, 'REFERENCES|HAS_COLUMN|HAS_TABLE', {maxLevel: 10}) YIELD path
116
- RETURN path
117
- """
118
- with self._driver.session() as session:
119
- result = session.run(query, start_name=start_node_name, end_name=end_node_name)
120
- # The result is complex, we need to parse it into a user-friendly format.
121
- # For now, returning the raw path objects.
122
- return [record["path"] for record in result]
123
 
124
- def keyword_search(self, keyword: str) -> List[Dict[str, Any]]:
125
- """
126
- Searches for tables and columns matching a keyword.
127
- Returns a list of matching nodes with their database and table context.
128
- """
129
- query = """
130
- MATCH (n)
131
- WHERE (n:Table OR n:Column) AND n.name CONTAINS $keyword
132
- OPTIONAL MATCH (d:Database)-[:HAS_TABLE]->(t:Table)-[:HAS_COLUMN]->(n) WHERE n:Column
133
- OPTIONAL MATCH (d2:Database)-[:HAS_TABLE]->(n) WHERE n:Table
134
- WITH COALESCE(d, d2) AS db, COALESCE(t, n) AS tbl, n AS item
135
- RETURN db.name AS database, tbl.name AS table, item.name AS name, labels(item) AS type
136
- LIMIT 25
137
- """
138
- with self._driver.session() as session:
139
- result = session.run(query, keyword=keyword)
140
- return [record.data() for record in result]
141
 
142
- def get_table_row_count(self, table_unique_name: str) -> int:
143
- """Retrieves the stored row count for a given table."""
144
- query = """
145
- MATCH (t:Table {unique_name: $unique_name})
146
- RETURN t.row_count AS row_count
147
- """
148
- with self._driver.session() as session:
149
- result = session.run(query, unique_name=table_unique_name)
150
- record = result.single()
151
- return record['row_count'] if record else -1
 
1
+ from neo4j import GraphDatabase, Driver
2
  import logging
 
3
  from typing import List, Dict, Any
4
+
5
  from . import config
6
 
7
  logger = logging.getLogger(__name__)
8
 
9
+ _driver: Driver = None
 
 
 
10
 
11
def get_graph_driver() -> Driver:
    """Return the process-wide Neo4j driver, creating it on first use."""
    global _driver
    if _driver is not None:
        return _driver
    logger.info("Initializing Neo4j driver...")
    auth = (config.NEO4J_USER, config.NEO4J_PASSWORD)
    _driver = GraphDatabase.driver(config.NEO4J_URI, auth=auth)
    # Uniqueness constraints are set up once, right after the first connection.
    _ensure_constraints(_driver)
    return _driver
19
 
20
def close_graph_driver():
    """Shut down the singleton Neo4j driver, if one was created."""
    global _driver
    if not _driver:
        return
    logger.info("Closing Neo4j driver.")
    _driver.close()
    _driver = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
def _ensure_constraints(driver: Driver):
    """Ensure uniqueness constraints are set up in Neo4j."""
    constraints = (
        "CREATE CONSTRAINT IF NOT EXISTS FOR (d:Database) REQUIRE d.name IS UNIQUE",
        "CREATE CONSTRAINT IF NOT EXISTS FOR (t:Table) REQUIRE t.unique_name IS UNIQUE",
        "CREATE CONSTRAINT IF NOT EXISTS FOR (c:Column) REQUIRE c.unique_name IS UNIQUE",
    )
    with driver.session() as session:
        for statement in constraints:
            session.run(statement)
        logger.info("Neo4j constraints ensured.")
 
 
35
 
36
def _keyword_search(keyword: str) -> List[Dict[str, Any]]:
    """Internal helper to search for table nodes by keyword."""
    cypher = """
    MATCH (d:Database)-[:HAS_TABLE]->(t:Table)
    WHERE t.name CONTAINS $keyword
    RETURN d.name as database, t.name as table
    LIMIT 5
    """
    with get_graph_driver().session() as session:
        records = session.run(cypher, keyword=keyword)
        return [record.data() for record in records]
 
48
 
49
def find_join_path(table1_name: str, table2_name: str) -> str:
    """
    Find a human-readable join path between two tables using the schema graph.

    Both arguments are fuzzy table names; each is resolved to its best-matching
    Table node via `_keyword_search`, then an APOC shortest-path search walks
    HAS_COLUMN/REFERENCES edges between the two resolved tables.

    Returns:
        A "Found path: a -> b -> c" string, or an explanatory message when a
        table cannot be resolved or no path exists.
    """
    driver = get_graph_driver()

    t1_nodes = _keyword_search(table1_name)
    t2_nodes = _keyword_search(table2_name)

    if not t1_nodes:
        return f"Could not find a table matching '{table1_name}'."
    if not t2_nodes:
        return f"Could not find a table matching '{table2_name}'."

    # Use the first (best) match for each side; unique_name is "<db>.<table>".
    t1_unique_name = f"{t1_nodes[0]['database']}.{t1_nodes[0]['table']}"
    t2_unique_name = f"{t2_nodes[0]['database']}.{t2_nodes[0]['table']}"

    # FIX: the previous query called apoc.path.shortestPath, which is not an
    # APOC procedure (the correct one is apoc.algo.shortestPath), and used the
    # legacy FILTER() function, which was removed in Neo4j 5 — a Cypher list
    # comprehension is the supported replacement.
    path_query = """
    MATCH (start:Table {unique_name: $start_name}), (end:Table {unique_name: $end_name})
    CALL apoc.algo.shortestPath(start, end, 'HAS_COLUMN|REFERENCES|<HAS_COLUMN', 5) YIELD path
    WITH [n in nodes(path) | COALESCE(n.name, '')] as path_nodes
    RETURN [name in path_nodes WHERE name <> ''] as path
    LIMIT 1
    """
    with driver.session() as session:
        result = session.run(path_query, start_name=t1_unique_name, end_name=t2_unique_name)
        record = result.single()

    if not record or not record["path"]:
        return f"No join path found between {table1_name} and {table2_name}."

    path_str = " -> ".join(record["path"])
    return f"Found path: {path_str}"
 
 
 
 
 
 
 
 
mcp/core/intelligence.py CHANGED
@@ -1,161 +1,73 @@
1
  import sqlparse
2
  import logging
3
  from typing import List, Dict, Any
4
-
5
- from .graph import GraphStore
6
- from .database import get_db_engine
7
- from . import config
8
  from sqlalchemy import text
9
 
10
- logger = logging.getLogger(__name__)
11
 
12
- # Constants for query cost estimation
13
- ROW_EXECUTION_THRESHOLD = 100 # Execute queries expected to return fewer rows
14
- JOIN_CARDINALITY_ESTIMATE = 1000 # A simplistic estimate for joins
15
 
16
- class QueryIntelligence:
17
  """
18
- Provides intelligence for handling SQL queries. It estimates query cost
19
- and decides on an execution strategy.
20
  """
21
- def __init__(self, graph_store: GraphStore):
22
- self.graph_store = graph_store
23
- self.db_engines = {}
24
-
25
- def _get_engine_for_db(self, db_name: str):
26
- """Helper to get or create an engine for a specific database."""
27
- if db_name not in self.db_engines:
28
- # Assuming db_name includes the .db extension
29
- connection_string = config.get_sqlite_connection_string(db_name)
30
- self.db_engines[db_name] = get_db_engine(connection_string)
31
- return self.db_engines.get(db_name)
32
-
33
- async def get_relevant_schemas(self, query: str) -> List[Dict[str, Any]]:
34
- """Finds schemas relevant to a natural language query."""
35
- # This is a simplistic keyword search. A real implementation would use
36
- # embedding-based search or an LLM to extract entities.
37
- keywords = query.split()
38
- all_results = []
39
- for keyword in keywords:
40
- if len(keyword) > 2: # Avoid very short keywords
41
- results = self.graph_store.keyword_search(keyword)
42
- all_results.extend(results)
43
- # Deduplicate results
44
- return [dict(t) for t in {tuple(d.items()) for d in all_results}]
45
-
46
- async def find_join_path(self, table1_name: str, table2_name: str) -> str:
47
- """Finds a join path between two tables using the graph."""
48
- # This is a simplification. It requires table names to be unique or requires
49
- # the user to provide fully qualified names (db.table).
50
- t1_nodes = self.graph_store.keyword_search(table1_name)
51
- t2_nodes = self.graph_store.keyword_search(table2_name)
52
-
53
- if not t1_nodes or not t2_nodes:
54
- return "Could not find one or both tables."
55
-
56
- # Assume the first result is correct for simplicity
57
- t1_unique_name = f"{t1_nodes[0]['database']}.{t1_nodes[0]['table']}"
58
- t2_unique_name = f"{t2_nodes[0]['database']}.{t2_nodes[0]['table']}"
59
-
60
- path_result = self.graph_store.find_shortest_path(t1_unique_name, t2_unique_name)
61
-
62
- if not path_result:
63
- return f"No path found between {table1_name} and {table2_name}."
64
-
65
- # Format the path for display
66
- # This is a complex task. The raw path from Neo4j needs careful parsing.
67
- # This is a placeholder for that logic.
68
- return f"Path found (details require parsing): {path_result}"
69
 
70
- async def execute_query(self, sql: str, limit: int) -> List[Dict[str, Any]]:
71
- """
72
- Executes a SQL query against the appropriate database if the estimated
73
- cost is below the threshold.
74
- """
75
- cost_estimate = self.estimate_query_cost(sql)
76
 
77
- if cost_estimate['decision'] != 'execute':
78
- raise PermissionError(f"Query execution denied. Estimated cost is too high ({cost_estimate['estimated_rows']} rows).")
79
 
80
- # This is a major simplification. Determining which database to run the query
81
- # against is a hard problem (especially for federated queries).
82
- # We assume the first table found belongs to the correct database.
83
- parsed_sql = self._parse_sql(sql)
84
- if not parsed_sql['tables']:
85
- raise ValueError("No tables found in SQL query.")
86
-
87
- first_table = parsed_sql['tables'][0]
88
- search_results = self.graph_store.keyword_search(first_table)
89
- if not search_results:
90
- raise ValueError(f"Table '{first_table}' not found in any known database.")
91
-
92
- db_name = search_results[0]['database']
93
- engine = self._get_engine_for_db(db_name)
94
-
95
- if not engine:
96
- raise ConnectionError(f"Could not connect to database: {db_name}")
97
 
 
98
  with engine.connect() as connection:
99
- # Append limit to the query
100
- safe_sql = f"{sql.strip().rstrip(';')} LIMIT {int(limit)}"
101
- result = connection.execute(text(safe_sql))
102
  return [dict(row._mapping) for row in result.fetchall()]
103
-
104
- def _parse_sql(self, sql: str) -> Dict[str, Any]:
105
- """Parses the SQL to identify tables and columns."""
106
- parsed = sqlparse.parse(sql)[0]
107
- # This is a simplistic parser. A real implementation would need
108
- # a much more robust SQL parsing library to handle complex queries, CTEs, etc.
109
- tables = set()
110
- for token in parsed.tokens:
111
- if isinstance(token, sqlparse.sql.Identifier):
112
- tables.add(token.get_real_name())
113
- elif token.is_group:
114
- # Look for identifiers within subgroups (e.g., in FROM or JOIN clauses)
115
- for sub_token in token.tokens:
116
- if isinstance(sub_token, sqlparse.sql.Identifier):
117
- tables.add(sub_token.get_real_name())
118
-
119
- return {"tables": list(tables)}
120
-
121
- def estimate_query_cost(self, sql: str) -> Dict[str, Any]:
122
- """
123
- Estimates the cost of a query based on row counts from the graph.
124
- """
125
- try:
126
- parsed_sql = self._parse_sql(sql)
127
- tables_in_query = parsed_sql['tables']
128
-
129
- if not tables_in_query:
130
- return {"estimated_rows": 0, "decision": "execute", "message": "No tables found in query."}
131
-
132
- # For simplicity, we'll take the max row count of any table in the query.
133
- # A real system would analyze JOINs and WHERE clauses.
134
- max_rows = 0
135
- for table_name in tables_in_query:
136
- # Need to find the unique name. This assumes table names are unique across DBs for now.
137
- # A real implementation needs context of which DB is being queried.
138
- search_result = self.graph_store.keyword_search(table_name)
139
- if search_result:
140
- table_unique_name = f"{search_result[0]['database']}.{search_result[0]['table']}"
141
- row_count = self.graph_store.get_table_row_count(table_unique_name)
142
- if row_count > max_rows:
143
- max_rows = row_count
144
-
145
- estimated_rows = max_rows
146
- # Crude adjustment for joins
147
- if len(tables_in_query) > 1:
148
- # A better estimate would involve graph traversal and statistical models
149
- estimated_rows *= JOIN_CARDINALITY_ESTIMATE * (len(tables_in_query) - 1)
150
-
151
- decision = "execute" if estimated_rows < ROW_EXECUTION_THRESHOLD else "return_sql"
152
-
153
- return {
154
- "estimated_rows": estimated_rows,
155
- "decision": decision,
156
- "tables_found": tables_in_query
157
- }
158
-
159
- except Exception as e:
160
- logger.error(f"Error estimating query cost: {e}")
161
- return {"estimated_rows": -1, "decision": "error", "message": str(e)}
 
1
  import sqlparse
2
  import logging
3
  from typing import List, Dict, Any
 
 
 
 
4
  from sqlalchemy import text
5
 
6
+ from .database import get_db_connections
7
 
8
+ logger = logging.getLogger(__name__)
 
 
9
 
10
+ def _get_database_for_table(table_name: str) -> str | None:
11
  """
12
+ Finds which database a table belongs to by checking the graph.
13
+ (This is a simplified helper; assumes GraphStore is accessible or passed)
14
  """
15
+ # This is a placeholder for the logic to find a table's database.
16
+ # In a real scenario, this would query Neo4j. We'll simulate it.
17
+ # A simple mapping for our known databases:
18
+ if table_name in ["studies", "patients", "adverse_events"]:
19
+ return "clinical_trials"
20
+ if table_name in ["lab_tests", "test_results", "biomarkers"]:
21
+ return "laboratory"
22
+ if table_name in ["compounds", "assay_results", "drug_targets", "compound_targets"]:
23
+ return "drug_discovery"
24
+ return None
25
+
26
+
27
def _extract_target_table(statement) -> str | None:
    """Return the first table name referenced after FROM/JOIN in a parsed statement."""
    seen_from = False
    for token in statement.tokens:
        if token.is_whitespace:
            continue
        if seen_from:
            if isinstance(token, sqlparse.sql.IdentifierList):
                first = next(token.get_identifiers(), None)
                return first.get_real_name() if first is not None else None
            if isinstance(token, sqlparse.sql.Identifier):
                return token.get_real_name()
            if token.ttype is sqlparse.tokens.Name:
                return token.value
        elif token.ttype is sqlparse.tokens.Keyword and token.normalized in ("FROM", "JOIN"):
            seen_from = True
    return None


async def execute_federated_query(sql: str) -> List[Dict[str, Any]]:
    """
    Executes a SQL query against the correct SQLite database.
    This is a simplified version of a federated query engine. It identifies the
    target database from the first table referenced after FROM in the SQL query.

    Raises:
        ValueError: when no table can be identified or the table is unknown.
        ConnectionError: when no engine exists for the resolved database.
    """
    parsed = sqlparse.parse(sql)[0]
    # BUGFIX: the previous logic took the first Identifier anywhere in the
    # statement, so "SELECT id FROM users" resolved the *column* "id" as the
    # target table. Scan for the identifier following FROM/JOIN instead.
    target_table = _extract_target_table(parsed)

    if not target_table:
        raise ValueError("Could not identify a target table in the SQL query.")

    logger.info(f"Identified target table: {target_table}")

    # Determine which database engine to use
    db_name = _get_database_for_table(target_table)
    if not db_name:
        raise ValueError(f"Table '{target_table}' not found in any known database.")

    db_engines = get_db_connections()
    engine = db_engines.get(db_name)

    if not engine:
        raise ConnectionError(f"No active connection for database '{db_name}'.")

    logger.info(f"Executing query against database: {db_name}")

    try:
        with engine.connect() as connection:
            result = connection.execute(text(sql))
            return [dict(row._mapping) for row in result.fetchall()]
    except Exception as e:
        logger.error(f"Failed to execute query on {db_name}: {e}", exc_info=True)
        raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mcp/main.py CHANGED
@@ -1,146 +1,88 @@
1
- from fastapi import FastAPI, Header, HTTPException
2
- from neo4j import GraphDatabase
3
  import os
4
- import json
5
- from datetime import datetime
6
- import psycopg2
7
- from psycopg2.extras import RealDictCursor
8
 
9
- app = FastAPI()
10
- driver = GraphDatabase.driver(
11
- os.getenv("NEO4J_BOLT_URL"),
12
- auth=("neo4j", os.getenv("NEO4J_AUTH").split("/")[1])
13
- )
 
14
 
15
- VALID_API_KEYS = os.getenv("MCP_API_KEYS", "").split(",")
16
- POSTGRES_CONN = os.getenv("POSTGRES_CONNECTION")
 
17
 
18
- @app.get("/health")
19
- def health():
20
- return {"ok": True, "timestamp": datetime.now().isoformat()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- @app.post("/mcp")
23
- async def execute_tool(request: dict, x_api_key: str = Header(None)):
24
- # Verify API key
25
  if x_api_key not in VALID_API_KEYS:
26
- raise HTTPException(status_code=401, detail="Invalid API key")
27
-
28
- print(f"Raw request: {request}")
29
- tool = request.get("tool")
30
- params = request.get("params", {})
31
- print(f"Tool: {tool}, Params: {params}")
32
-
33
- if tool == "get_schema":
34
- # Return node labels and relationships
35
- with driver.session() as session:
36
- result = session.run("CALL db.labels() YIELD label RETURN collect(label) as labels")
37
- return {"labels": result.single()["labels"]}
38
-
39
- elif tool == "query_graph":
40
- # Execute parameterized query
41
- try:
42
- query = params.get("query")
43
- query_params = params.get("parameters", {})
44
-
45
- # Fix parameter substitution issue - replace placeholders with Neo4j parameters
46
- # The $ character gets stripped by environment variable substitution
47
- # So we use $$PARAM$$ as a placeholder and replace it with $PARAM
48
- import re
49
- for param_name in query_params.keys():
50
- # Replace $$param_name$$ with $param_name
51
- query = query.replace(f'$${param_name}$$', f'${param_name}')
52
- # Also handle the case where frontend sends $param (which becomes param)
53
- query = query.replace(f' {param_name} ', f' ${param_name} ')
54
- query = query.replace(f'={param_name})', f'=${param_name})')
55
- query = query.replace(f'({param_name})', f'(${param_name})')
56
-
57
- print(f"Original query: {params.get('query')}")
58
- print(f"Fixed query: {query}")
59
- print(f"With parameters: {query_params}")
60
-
61
- with driver.session() as session:
62
- result = session.run(query, query_params)
63
- return {"data": [dict(record) for record in result]}
64
- except Exception as e:
65
- print(f"Query error: {e}")
66
- return {"error": str(e), "query": query, "parameters": query_params}
67
-
68
- elif tool == "write_graph":
69
- # Structured write operation
70
- action = params.get("action")
71
- if action == "create_node":
72
- label = params.get("label")
73
- properties = params.get("properties", {})
74
- with driver.session() as session:
75
- result = session.run(f"CREATE (n:{label} $props) RETURN n", {"props": properties})
76
- record = result.single()
77
- if record:
78
- node = record["n"]
79
- return {"created": dict(node) if hasattr(node, 'items') else {"id": str(node.id), "labels": list(node.labels), "properties": dict(node)}}
80
- return {"created": {}}
81
-
82
- elif tool == "get_next_instruction":
83
- # Get next pending instruction
84
- with driver.session() as session:
85
- result = session.run("""
86
- MATCH (i:Instruction {status: 'pending'})
87
- RETURN i ORDER BY i.sequence LIMIT 1
88
- """)
89
- record = result.single()
90
- return {"instruction": dict(record["i"]) if record else None}
91
-
92
- elif tool == "query_postgres":
93
- query = params.get("query")
94
- try:
95
- conn = psycopg2.connect(POSTGRES_CONN)
96
- with conn.cursor(cursor_factory=RealDictCursor) as cur:
97
- cur.execute(query)
98
- if cur.description: # SELECT query
99
- results = cur.fetchall()
100
- return {"data": results, "row_count": len(results)}
101
- else: # INSERT/UPDATE/DELETE
102
- conn.commit()
103
- return {"affected_rows": cur.rowcount}
104
- except Exception as e:
105
- return {"error": str(e)}
106
- finally:
107
- if 'conn' in locals():
108
- conn.close()
109
 
110
- elif tool == "discover_postgres_schema":
111
- try:
112
- conn = psycopg2.connect(POSTGRES_CONN)
113
- with conn.cursor(cursor_factory=RealDictCursor) as cur:
114
- # Get all tables
115
- cur.execute("""
116
- SELECT table_name, table_schema
117
- FROM information_schema.tables
118
- WHERE table_schema = 'public'
119
- AND table_type = 'BASE TABLE'
120
- """)
121
- tables = cur.fetchall()
122
-
123
- schema_info = {}
124
- for table in tables:
125
- table_name = table['table_name']
126
-
127
- # Get columns for each table
128
- cur.execute("""
129
- SELECT column_name, data_type, is_nullable,
130
- column_default, character_maximum_length
131
- FROM information_schema.columns
132
- WHERE table_schema = 'public'
133
- AND table_name = %s
134
- ORDER BY ordinal_position
135
- """, (table_name,))
136
-
137
- schema_info[table_name] = cur.fetchall()
138
-
139
- return {"schema": schema_info}
140
- except Exception as e:
141
- return {"error": str(e)}
142
- finally:
143
- if 'conn' in locals():
144
- conn.close()
145
-
146
- return {"error": "Unknown tool"}
 
1
+ from fastapi import FastAPI, Header, HTTPException, Depends
2
+ from typing import List, Dict, Any
3
  import os
4
+ import logging
5
+ from pydantic import BaseModel
 
 
6
 
7
+ # --- Core Logic Imports ---
8
+ # These imports assume your project structure places the core logic correctly.
9
+ from core.database import get_db_connections, close_db_connections
10
+ from core.discovery import get_relevant_schemas
11
+ from core.graph import find_join_path, get_graph_driver, close_graph_driver
12
+ from core.intelligence import execute_federated_query
13
 
14
+ # --- App Configuration ---
15
+ logging.basicConfig(level=logging.INFO)
16
+ logger = logging.getLogger(__name__)
17
 
18
+ app = FastAPI(title="MCP Server", version="2.0")
19
+
20
+ VALID_API_KEYS = os.getenv("MCP_API_KEYS", "dev-key-123").split(",")
21
+
22
+ # --- Pydantic Models ---
23
+ class ToolRequest(BaseModel):
24
+ tool: str
25
+ params: Dict[str, Any]
26
+
27
+ class SchemaQuery(BaseModel):
28
+ query: str
29
+
30
+ class JoinPathRequest(BaseModel):
31
+ table1: str
32
+ table2: str
33
+
34
+ class SQLQuery(BaseModel):
35
+ sql: str
36
 
37
+ # --- Dependency for Auth ---
38
+ async def verify_api_key(x_api_key: str = Header(...)):
 
39
  if x_api_key not in VALID_API_KEYS:
40
+ raise HTTPException(status_code=401, detail="Invalid API Key")
41
+ return x_api_key
42
+
43
# --- Event Handlers ---
@app.on_event("startup")
async def startup_event():
    """Initializes the database connection pool on server startup."""
    # Eagerly create the SQL engines and the Neo4j driver so the first
    # request does not pay the connection cost (assumes these helpers
    # connect/cache eagerly -- TODO confirm in core.database/core.graph).
    get_db_connections()
    get_graph_driver()
    logger.info("MCP server started and database connections initialized.")

@app.on_event("shutdown")
def shutdown_event():
    """Closes the database connection pool on server shutdown."""
    # NOTE(review): @app.on_event is deprecated in recent FastAPI releases in
    # favour of lifespan handlers -- works here, but worth migrating.
    close_db_connections()
    close_graph_driver()
    logger.info("MCP server shutting down and database connections closed.")
57
+
58
# --- API Endpoints ---
@app.get("/health")
def health_check():
    """Liveness probe (no auth); intended for the container healthcheck."""
    return {"status": "ok"}

@app.post("/mcp/discovery/get_relevant_schemas", dependencies=[Depends(verify_api_key)])
async def discover_schemas(request: SchemaQuery):
    """Return schema fragments relevant to a natural-language query.

    Delegates to ``core.discovery.get_relevant_schemas``; any failure is
    logged with a traceback and surfaced as HTTP 500.
    """
    try:
        schemas = await get_relevant_schemas(request.query)
        return {"status": "success", "schemas": schemas}
    except Exception as e:
        logger.error(f"Schema discovery failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/mcp/graph/find_join_path", dependencies=[Depends(verify_api_key)])
async def get_join_path(request: JoinPathRequest):
    """Find a join path between two tables via the schema graph.

    NOTE(review): ``find_join_path`` is a synchronous function, so this call
    blocks the event loop while Neo4j is queried.
    """
    try:
        path = find_join_path(request.table1, request.table2)
        return {"status": "success", "path": path}
    except Exception as e:
        logger.error(f"Join path finding failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
@app.post("/mcp/intelligence/execute_query", dependencies=[Depends(verify_api_key)])
async def execute_query(request: SQLQuery):
    """Execute a raw SQL statement through the federated query engine.

    Delegates to ``core.intelligence.execute_federated_query``; failures are
    logged with a traceback and surfaced as HTTP 500.
    """
    try:
        results = await execute_federated_query(request.sql)
        return {"status": "success", "results": results}
    except Exception as e:
        logger.error(f"Query execution failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
streamlit/requirements.txt CHANGED
@@ -1,4 +1,3 @@
1
  streamlit==1.28.0
2
  requests==2.31.0
3
  pandas==2.1.0
4
- python-dotenv==1.0.0
 
1
  streamlit==1.28.0
2
  requests==2.31.0
3
  pandas==2.1.0