Spaces:
No application file
No application file
This fixed the docker container health errors. just there is no mcp connection still.
Browse files- README.md +4 -4
- agent/main.py +18 -13
- agent/requirements.txt +0 -1
- docker-compose.yml +1 -3
- mcp/core/config.py +13 -14
- mcp/core/database.py +30 -17
- mcp/core/discovery.py +56 -84
- mcp/core/graph.py +64 -136
- mcp/core/intelligence.py +58 -146
- mcp/main.py +81 -139
- streamlit/requirements.txt +0 -1
README.md
CHANGED
|
@@ -45,7 +45,7 @@ This project implements an intelligent, multi-step GraphRAG-powered agent that u
|
|
| 45 |
|
| 46 |
### Prerequisites
|
| 47 |
- Docker & Docker Compose
|
| 48 |
-
-
|
| 49 |
|
| 50 |
### Setup
|
| 51 |
1. **Clone and configure**:
|
|
@@ -55,9 +55,9 @@ This project implements an intelligent, multi-step GraphRAG-powered agent that u
|
|
| 55 |
touch .env
|
| 56 |
```
|
| 57 |
|
| 58 |
-
2. **Add your
|
| 59 |
```
|
| 60 |
-
|
| 61 |
```
|
| 62 |
|
| 63 |
3. **Start the system**:
|
|
@@ -104,7 +104,7 @@ To test the agent's logic directly without the full Docker stack, you can run it
|
|
| 104 |
|
| 105 |
2. **Set your API key**:
|
| 106 |
```bash
|
| 107 |
-
export
|
| 108 |
```
|
| 109 |
|
| 110 |
3. **Run the agent**:
|
|
|
|
| 45 |
|
| 46 |
### Prerequisites
|
| 47 |
- Docker & Docker Compose
|
| 48 |
+
- LLM API key (e.g., for OpenAI)
|
| 49 |
|
| 50 |
### Setup
|
| 51 |
1. **Clone and configure**:
|
|
|
|
| 55 |
touch .env
|
| 56 |
```
|
| 57 |
|
| 58 |
+
2. **Add your LLM API key** to the `.env` file.
|
| 59 |
```
|
| 60 |
+
LLM_API_KEY="sk-your-llm-api-key-here"
|
| 61 |
```
|
| 62 |
|
| 63 |
3. **Start the system**:
|
|
|
|
| 104 |
|
| 105 |
2. **Set your API key**:
|
| 106 |
```bash
|
| 107 |
+
export LLM_API_KEY="sk-your-llm-api-key-here"
|
| 108 |
```
|
| 109 |
|
| 110 |
3. **Run the agent**:
|
agent/main.py
CHANGED
|
@@ -3,17 +3,18 @@ import sys
|
|
| 3 |
import logging
|
| 4 |
import json
|
| 5 |
from typing import Annotated, List, TypedDict
|
|
|
|
| 6 |
from fastapi import FastAPI
|
| 7 |
from pydantic import BaseModel
|
| 8 |
import uvicorn
|
| 9 |
from fastapi.responses import StreamingResponse
|
| 10 |
|
| 11 |
from langchain_core.messages import BaseMessage, ToolMessage, AIMessage
|
| 12 |
-
from langchain_openai import
|
| 13 |
from langgraph.graph import StateGraph, START, END
|
| 14 |
from langgraph.prebuilt import ToolNode
|
| 15 |
|
| 16 |
-
from
|
| 17 |
|
| 18 |
# --- Configuration & Logging ---
|
| 19 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
@@ -21,7 +22,7 @@ logger = logging.getLogger(__name__)
|
|
| 21 |
|
| 22 |
MCP_URL = os.getenv("MCP_URL", "http://mcp:8000/mcp")
|
| 23 |
API_KEY = os.getenv("MCP_API_KEY", "dev-key-123")
|
| 24 |
-
|
| 25 |
|
| 26 |
# --- Agent State Definition ---
|
| 27 |
class AgentState(TypedDict):
|
|
@@ -32,10 +33,10 @@ class GraphRAGAgent:
|
|
| 32 |
"""The core agent for handling GraphRAG queries using LangGraph."""
|
| 33 |
|
| 34 |
def __init__(self):
|
| 35 |
-
if not
|
| 36 |
-
raise ValueError("
|
| 37 |
|
| 38 |
-
llm =
|
| 39 |
|
| 40 |
mcp_client = MCPClient(mcp_url=MCP_URL, api_key=API_KEY)
|
| 41 |
tools = [
|
|
@@ -93,21 +94,25 @@ class GraphRAGAgent:
|
|
| 93 |
yield json.dumps({"type": "final_answer", "content": last_message.content}) + "\\n\\n"
|
| 94 |
|
| 95 |
# --- FastAPI Application ---
|
| 96 |
-
app = FastAPI(title="GraphRAG Agent Server")
|
| 97 |
agent = None
|
| 98 |
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
@app.on_event("startup")
|
| 103 |
-
def startup_event():
|
| 104 |
-
"""Initialize the agent on server startup."""
|
| 105 |
global agent
|
|
|
|
| 106 |
try:
|
| 107 |
agent = GraphRAGAgent()
|
| 108 |
logger.info("GraphRAGAgent initialized successfully.")
|
| 109 |
except ValueError as e:
|
| 110 |
logger.error(f"Agent initialization failed: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
@app.post("/query")
|
| 113 |
async def execute_query(request: QueryRequest) -> StreamingResponse:
|
|
|
|
| 3 |
import logging
|
| 4 |
import json
|
| 5 |
from typing import Annotated, List, TypedDict
|
| 6 |
+
from contextlib import asynccontextmanager
|
| 7 |
from fastapi import FastAPI
|
| 8 |
from pydantic import BaseModel
|
| 9 |
import uvicorn
|
| 10 |
from fastapi.responses import StreamingResponse
|
| 11 |
|
| 12 |
from langchain_core.messages import BaseMessage, ToolMessage, AIMessage
|
| 13 |
+
from langchain_openai import ChatOpenAI
|
| 14 |
from langgraph.graph import StateGraph, START, END
|
| 15 |
from langgraph.prebuilt import ToolNode
|
| 16 |
|
| 17 |
+
from tools import MCPClient, SchemaSearchTool, JoinPathFinderTool, QueryExecutorTool
|
| 18 |
|
| 19 |
# --- Configuration & Logging ---
|
| 20 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
| 22 |
|
| 23 |
MCP_URL = os.getenv("MCP_URL", "http://mcp:8000/mcp")
|
| 24 |
API_KEY = os.getenv("MCP_API_KEY", "dev-key-123")
|
| 25 |
+
LLM_API_KEY = os.getenv("LLM_API_KEY")
|
| 26 |
|
| 27 |
# --- Agent State Definition ---
|
| 28 |
class AgentState(TypedDict):
|
|
|
|
| 33 |
"""The core agent for handling GraphRAG queries using LangGraph."""
|
| 34 |
|
| 35 |
def __init__(self):
|
| 36 |
+
if not LLM_API_KEY:
|
| 37 |
+
raise ValueError("LLM_API_KEY environment variable not set.")
|
| 38 |
|
| 39 |
+
llm = ChatOpenAI(api_key=LLM_API_KEY, model="gpt-4o-mini", temperature=0, max_retries=1)
|
| 40 |
|
| 41 |
mcp_client = MCPClient(mcp_url=MCP_URL, api_key=API_KEY)
|
| 42 |
tools = [
|
|
|
|
| 94 |
yield json.dumps({"type": "final_answer", "content": last_message.content}) + "\\n\\n"
|
| 95 |
|
| 96 |
# --- FastAPI Application ---
|
|
|
|
| 97 |
agent = None
|
| 98 |
|
| 99 |
+
@asynccontextmanager
|
| 100 |
+
async def lifespan(app: FastAPI):
|
| 101 |
+
"""Handles agent initialization on startup."""
|
|
|
|
|
|
|
|
|
|
| 102 |
global agent
|
| 103 |
+
logger.info("Agent server startup...")
|
| 104 |
try:
|
| 105 |
agent = GraphRAGAgent()
|
| 106 |
logger.info("GraphRAGAgent initialized successfully.")
|
| 107 |
except ValueError as e:
|
| 108 |
logger.error(f"Agent initialization failed: {e}")
|
| 109 |
+
yield
|
| 110 |
+
logger.info("Agent server shutdown.")
|
| 111 |
+
|
| 112 |
+
app = FastAPI(title="GraphRAG Agent Server", lifespan=lifespan)
|
| 113 |
+
|
| 114 |
+
class QueryRequest(BaseModel):
|
| 115 |
+
question: str
|
| 116 |
|
| 117 |
@app.post("/query")
|
| 118 |
async def execute_query(request: QueryRequest) -> StreamingResponse:
|
agent/requirements.txt
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
requests
|
| 2 |
-
python-dotenv
|
| 3 |
langchain
|
| 4 |
langchain-openai
|
| 5 |
pydantic
|
|
|
|
| 1 |
requests
|
|
|
|
| 2 |
langchain
|
| 3 |
langchain-openai
|
| 4 |
pydantic
|
docker-compose.yml
CHANGED
|
@@ -1,5 +1,3 @@
|
|
| 1 |
-
version: '3.8'
|
| 2 |
-
|
| 3 |
services:
|
| 4 |
neo4j:
|
| 5 |
build: ./neo4j
|
|
@@ -50,7 +48,7 @@ services:
|
|
| 50 |
- MCP_URL=http://mcp:8000/mcp
|
| 51 |
- MCP_API_KEY=dev-key-123
|
| 52 |
- AGENT_POLL_INTERVAL=${AGENT_POLL_INTERVAL}
|
| 53 |
-
-
|
| 54 |
depends_on:
|
| 55 |
mcp:
|
| 56 |
condition: service_healthy
|
|
|
|
|
|
|
|
|
|
| 1 |
services:
|
| 2 |
neo4j:
|
| 3 |
build: ./neo4j
|
|
|
|
| 48 |
- MCP_URL=http://mcp:8000/mcp
|
| 49 |
- MCP_API_KEY=dev-key-123
|
| 50 |
- AGENT_POLL_INTERVAL=${AGENT_POLL_INTERVAL}
|
| 51 |
+
- LLM_API_KEY=${LLM_API_KEY}
|
| 52 |
depends_on:
|
| 53 |
mcp:
|
| 54 |
condition: service_healthy
|
mcp/core/config.py
CHANGED
|
@@ -1,22 +1,21 @@
|
|
| 1 |
import os
|
| 2 |
|
| 3 |
# --- Neo4j Configuration ---
|
| 4 |
-
NEO4J_URI = os.getenv("NEO4J_BOLT_URL", "bolt://
|
| 5 |
-
NEO4J_USER =
|
| 6 |
-
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "password")
|
| 7 |
|
| 8 |
-
#
|
| 9 |
-
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
db_path = os.path.join(SQLITE_DATA_DIR, db_name)
|
| 19 |
-
return f"sqlite:///{db_path}"
|
| 20 |
|
| 21 |
# --- Application Settings ---
|
| 22 |
# You can add other application-wide settings here
|
|
|
|
| 1 |
import os
|
| 2 |
|
| 3 |
# --- Neo4j Configuration ---
|
| 4 |
+
NEO4J_URI = os.getenv("NEO4J_BOLT_URL", "bolt://localhost:7687")
|
| 5 |
+
NEO4J_USER = "neo4j"
|
|
|
|
| 6 |
|
| 7 |
+
# The NEO4J_AUTH env var is in the format 'neo4j/password'
|
| 8 |
+
# We need to extract the password part.
|
| 9 |
+
neo4j_auth = os.getenv("NEO4J_AUTH", "neo4j/password")
|
| 10 |
+
NEO4J_PASSWORD = neo4j_auth.split('/')[1] if '/' in neo4j_auth else neo4j_auth
|
| 11 |
|
| 12 |
+
# --- Database Configuration ---
|
| 13 |
+
# A dictionary of connection strings for the SQLite databases
|
| 14 |
+
DB_CONNECTIONS = {
|
| 15 |
+
"clinical_trials": f"sqlite:////app/data/clinical_trials.db",
|
| 16 |
+
"drug_discovery": f"sqlite:////app/data/drug_discovery.db",
|
| 17 |
+
"laboratory": f"sqlite:////app/data/laboratory.db",
|
| 18 |
+
}
|
|
|
|
|
|
|
| 19 |
|
| 20 |
# --- Application Settings ---
|
| 21 |
# You can add other application-wide settings here
|
mcp/core/database.py
CHANGED
|
@@ -1,26 +1,39 @@
|
|
| 1 |
from sqlalchemy import create_engine
|
| 2 |
from sqlalchemy.engine import Engine
|
| 3 |
import logging
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
logging.basicConfig(level=logging.INFO)
|
| 6 |
logger = logging.getLogger(__name__)
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
Creates a SQLAlchemy engine for a given database connection string.
|
| 11 |
-
|
| 12 |
-
Args:
|
| 13 |
-
connection_string: The database connection string.
|
| 14 |
|
| 15 |
-
|
| 16 |
-
A SQLAlchemy Engine instance, or None if connection fails.
|
| 17 |
"""
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from sqlalchemy import create_engine
|
| 2 |
from sqlalchemy.engine import Engine
|
| 3 |
import logging
|
| 4 |
+
from typing import Dict
|
| 5 |
+
|
| 6 |
+
from . import config
|
| 7 |
|
| 8 |
logging.basicConfig(level=logging.INFO)
|
| 9 |
logger = logging.getLogger(__name__)
|
| 10 |
|
| 11 |
+
# A dictionary to hold the initialized SQLAlchemy engines
|
| 12 |
+
_db_engines: Dict[str, Engine] = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
+
def get_db_connections() -> Dict[str, Engine]:
|
|
|
|
| 15 |
"""
|
| 16 |
+
Initializes and returns a dictionary of SQLAlchemy engines for all configured databases.
|
| 17 |
+
This function is idempotent.
|
| 18 |
+
"""
|
| 19 |
+
global _db_engines
|
| 20 |
+
if not _db_engines:
|
| 21 |
+
logger.info("Initializing database connections...")
|
| 22 |
+
for db_name, conn_str in config.DB_CONNECTIONS.items():
|
| 23 |
+
try:
|
| 24 |
+
engine = create_engine(conn_str)
|
| 25 |
+
# Test the connection
|
| 26 |
+
with engine.connect():
|
| 27 |
+
logger.info(f"Successfully connected to {db_name}")
|
| 28 |
+
_db_engines[db_name] = engine
|
| 29 |
+
except Exception as e:
|
| 30 |
+
logger.error(f"Failed to connect to {db_name}: {e}")
|
| 31 |
+
return _db_engines
|
| 32 |
+
|
| 33 |
+
def close_db_connections():
|
| 34 |
+
"""Closes all active database connections."""
|
| 35 |
+
global _db_engines
|
| 36 |
+
logger.info("Closing database connections...")
|
| 37 |
+
for engine in _db_engines.values():
|
| 38 |
+
engine.dispose()
|
| 39 |
+
_db_engines = {}
|
mcp/core/discovery.py
CHANGED
|
@@ -1,98 +1,70 @@
|
|
| 1 |
-
from sqlalchemy import inspect
|
| 2 |
from sqlalchemy.engine import Engine
|
| 3 |
from typing import Dict, Any, List
|
| 4 |
import logging
|
| 5 |
-
import
|
| 6 |
-
from concurrent.futures import TimeoutError, ThreadPoolExecutor
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
def get_table_schema(inspector, table_name: str) -> Dict[str, Any]:
|
| 11 |
-
"""Extracts schema for a single table."""
|
| 12 |
-
columns = inspector.get_columns(table_name)
|
| 13 |
-
primary_keys = inspector.get_pk_constraint(table_name)['constrained_columns']
|
| 14 |
-
foreign_keys = inspector.get_foreign_keys(table_name)
|
| 15 |
-
|
| 16 |
-
table_schema = {
|
| 17 |
-
"name": table_name,
|
| 18 |
-
"columns": [],
|
| 19 |
-
"primary_keys": primary_keys,
|
| 20 |
-
"foreign_keys": foreign_keys
|
| 21 |
-
}
|
| 22 |
-
|
| 23 |
-
for col in columns:
|
| 24 |
-
table_schema["columns"].append({
|
| 25 |
-
"name": col['name'],
|
| 26 |
-
"type": str(col['type']),
|
| 27 |
-
"nullable": col['nullable'],
|
| 28 |
-
"default": col.get('default'),
|
| 29 |
-
})
|
| 30 |
-
return table_schema
|
| 31 |
|
| 32 |
-
|
| 33 |
-
"""Fetches sample data and distinct values for each column."""
|
| 34 |
-
sample_data = {}
|
| 35 |
-
with engine.connect() as connection:
|
| 36 |
-
# Get row count
|
| 37 |
-
try:
|
| 38 |
-
result = connection.execute(text(f'SELECT COUNT(*) FROM "{table_name}"'))
|
| 39 |
-
sample_data['row_count'] = result.scalar_one()
|
| 40 |
-
except Exception as e:
|
| 41 |
-
logger.warning(f"Could not get row count for table {table_name}: {e}")
|
| 42 |
-
sample_data['row_count'] = -1 # Indicate error or unknown
|
| 43 |
-
|
| 44 |
-
# Get sample rows
|
| 45 |
-
try:
|
| 46 |
-
result = connection.execute(text(f'SELECT * FROM "{table_name}" LIMIT {sample_size}'))
|
| 47 |
-
rows = [dict(row._mapping) for row in result.fetchall()]
|
| 48 |
-
# Attempt to JSON serialize to handle complex types gracefully
|
| 49 |
-
sample_data['sample_rows'] = json.loads(json.dumps(rows, default=str))
|
| 50 |
-
except Exception as e:
|
| 51 |
-
logger.warning(f"Could not get sample rows for table {table_name}: {e}")
|
| 52 |
-
sample_data['sample_rows'] = []
|
| 53 |
-
|
| 54 |
-
return sample_data
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
def discover_schema(engine: Engine, timeout: int = 30) -> Dict[str, Any] | None:
|
| 58 |
-
"""
|
| 59 |
-
Discovers the full schema of a database using SQLAlchemy's inspection API.
|
| 60 |
-
Includes table schemas and sample data.
|
| 61 |
-
"""
|
| 62 |
-
try:
|
| 63 |
-
with ThreadPoolExecutor() as executor:
|
| 64 |
-
future = executor.submit(_discover_schema_task, engine)
|
| 65 |
-
return future.result(timeout=timeout)
|
| 66 |
-
except TimeoutError:
|
| 67 |
-
logger.error(f"Schema discovery for {engine.url.database} timed out after {timeout} seconds.")
|
| 68 |
-
return None
|
| 69 |
-
except Exception as e:
|
| 70 |
-
logger.error(f"An unexpected error occurred during schema discovery for {engine.url.database}: {e}")
|
| 71 |
-
return None
|
| 72 |
|
| 73 |
-
def
|
| 74 |
-
"""
|
| 75 |
inspector = inspect(engine)
|
| 76 |
db_schema = {
|
| 77 |
-
"database_name":
|
| 78 |
-
"dialect": engine.dialect.name,
|
| 79 |
"tables": []
|
| 80 |
}
|
| 81 |
-
|
| 82 |
table_names = inspector.get_table_names()
|
| 83 |
-
|
| 84 |
for table_name in table_names:
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
table_schema.update(sample_info)
|
| 92 |
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sqlalchemy import inspect
|
| 2 |
from sqlalchemy.engine import Engine
|
| 3 |
from typing import Dict, Any, List
|
| 4 |
import logging
|
| 5 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
|
|
| 6 |
|
| 7 |
+
from .database import get_db_connections
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
+
def _discover_single_db_schema(db_name: str, engine: Engine) -> Dict[str, Any]:
|
| 12 |
+
"""Discovers the schema for a single database engine."""
|
| 13 |
inspector = inspect(engine)
|
| 14 |
db_schema = {
|
| 15 |
+
"database_name": db_name,
|
|
|
|
| 16 |
"tables": []
|
| 17 |
}
|
|
|
|
| 18 |
table_names = inspector.get_table_names()
|
|
|
|
| 19 |
for table_name in table_names:
|
| 20 |
+
columns = inspector.get_columns(table_name)
|
| 21 |
+
db_schema["tables"].append({
|
| 22 |
+
"name": table_name,
|
| 23 |
+
"columns": [{"name": c['name'], "type": str(c['type'])} for c in columns]
|
| 24 |
+
})
|
| 25 |
+
return db_schema
|
|
|
|
| 26 |
|
| 27 |
+
async def get_relevant_schemas(query: str) -> List[Dict[str, Any]]:
|
| 28 |
+
"""
|
| 29 |
+
Discovers schemas from all connected databases and performs a simple keyword search.
|
| 30 |
+
A more advanced implementation would use embeddings for semantic search.
|
| 31 |
+
"""
|
| 32 |
+
db_engines = get_db_connections()
|
| 33 |
+
all_schemas = []
|
| 34 |
|
| 35 |
+
with ThreadPoolExecutor() as executor:
|
| 36 |
+
# Discover all schemas in parallel
|
| 37 |
+
future_to_db = {executor.submit(_discover_single_db_schema, name, eng): name for name, eng in db_engines.items()}
|
| 38 |
+
for future in as_completed(future_to_db):
|
| 39 |
+
try:
|
| 40 |
+
all_schemas.append(future.result())
|
| 41 |
+
except Exception as e:
|
| 42 |
+
db_name = future_to_db[future]
|
| 43 |
+
logger.error(f"Failed to discover schema for {db_name}: {e}")
|
| 44 |
+
|
| 45 |
+
if not query:
|
| 46 |
+
return all_schemas
|
| 47 |
+
|
| 48 |
+
# Simple keyword filtering
|
| 49 |
+
keywords = query.lower().split()
|
| 50 |
+
relevant_schemas = []
|
| 51 |
+
for db_schema in all_schemas:
|
| 52 |
+
for table in db_schema.get("tables", []):
|
| 53 |
+
if any(keyword in table['name'].lower() for keyword in keywords):
|
| 54 |
+
relevant_schemas.append({
|
| 55 |
+
"database": db_schema["database_name"],
|
| 56 |
+
"table": table['name'],
|
| 57 |
+
"columns": table['columns']
|
| 58 |
+
})
|
| 59 |
+
else:
|
| 60 |
+
for col in table.get("columns", []):
|
| 61 |
+
if any(keyword in col['name'].lower() for keyword in keywords):
|
| 62 |
+
relevant_schemas.append({
|
| 63 |
+
"database": db_schema["database_name"],
|
| 64 |
+
"table": table['name'],
|
| 65 |
+
"columns": table['columns'] # Return full table if a column matches
|
| 66 |
+
})
|
| 67 |
+
break # Move to next table
|
| 68 |
+
|
| 69 |
+
# Deduplicate results (in case multiple keywords match the same table)
|
| 70 |
+
return [dict(t) for t in {tuple(d.items()) for d in relevant_schemas}]
|
mcp/core/graph.py
CHANGED
|
@@ -1,151 +1,79 @@
|
|
| 1 |
-
from neo4j import GraphDatabase
|
| 2 |
import logging
|
| 3 |
-
import json
|
| 4 |
from typing import List, Dict, Any
|
|
|
|
| 5 |
from . import config
|
| 6 |
|
| 7 |
logger = logging.getLogger(__name__)
|
| 8 |
|
| 9 |
-
|
| 10 |
-
def __init__(self):
|
| 11 |
-
self._driver = GraphDatabase.driver(config.NEO4J_URI, auth=(config.NEO4J_USER, config.NEO4J_PASSWORD))
|
| 12 |
-
self.ensure_constraints()
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def import_schema(self, schema_data: dict):
|
| 26 |
-
"""
|
| 27 |
-
Imports a discovered database schema into the Neo4j graph.
|
| 28 |
-
"""
|
| 29 |
-
db_name = schema_data['database_name']
|
| 30 |
-
|
| 31 |
-
with self._driver.session() as session:
|
| 32 |
-
# Create Database node
|
| 33 |
-
session.run("MERGE (d:Database {name: $db_name})", db_name=db_name)
|
| 34 |
-
|
| 35 |
-
for table in schema_data['tables']:
|
| 36 |
-
table_unique_name = f"{db_name}.{table['name']}"
|
| 37 |
-
table_properties = {
|
| 38 |
-
"name": table['name'],
|
| 39 |
-
"unique_name": table_unique_name,
|
| 40 |
-
"row_count": table.get('row_count', -1),
|
| 41 |
-
"sample_rows": json.dumps(table.get('sample_rows', []))
|
| 42 |
-
}
|
| 43 |
-
|
| 44 |
-
# Create Table node and HAS_TABLE relationship
|
| 45 |
-
session.run(
|
| 46 |
-
"""
|
| 47 |
-
MATCH (d:Database {name: $db_name})
|
| 48 |
-
MERGE (t:Table {unique_name: $unique_name})
|
| 49 |
-
ON CREATE SET t += $props
|
| 50 |
-
ON MATCH SET t += $props
|
| 51 |
-
MERGE (d)-[:HAS_TABLE]->(t)
|
| 52 |
-
""",
|
| 53 |
-
db_name=db_name,
|
| 54 |
-
unique_name=table_unique_name,
|
| 55 |
-
props=table_properties
|
| 56 |
-
)
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
"default": str(column.get('default')) # Ensure default is string
|
| 66 |
-
}
|
| 67 |
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
)
|
| 81 |
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
constrained_columns = fk['constrained_columns']
|
| 88 |
-
referred_table = fk['referred_table']
|
| 89 |
-
referred_columns = fk['referred_columns']
|
| 90 |
-
|
| 91 |
-
referred_table_unique_name = f"{db_name}.{referred_table}"
|
| 92 |
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
to_col_unique_name = f"{referred_table_unique_name}.{referred_columns[i]}"
|
| 96 |
-
|
| 97 |
-
session.run(
|
| 98 |
-
"""
|
| 99 |
-
MATCH (from_col:Column {unique_name: $from_col})
|
| 100 |
-
MATCH (to_col:Column {unique_name: $to_col})
|
| 101 |
-
MERGE (from_col)-[:REFERENCES]->(to_col)
|
| 102 |
-
""",
|
| 103 |
-
from_col=from_col_unique_name,
|
| 104 |
-
to_col=to_col_unique_name
|
| 105 |
-
)
|
| 106 |
-
logger.info(f"Successfully imported schema for database: {db_name}")
|
| 107 |
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
query = """
|
| 114 |
-
MATCH (start {unique_name: $start_name}), (end {unique_name: $end_name})
|
| 115 |
-
CALL apoc.path.shortestPath(start, end, 'REFERENCES|HAS_COLUMN|HAS_TABLE', {maxLevel: 10}) YIELD path
|
| 116 |
-
RETURN path
|
| 117 |
-
"""
|
| 118 |
-
with self._driver.session() as session:
|
| 119 |
-
result = session.run(query, start_name=start_node_name, end_name=end_node_name)
|
| 120 |
-
# The result is complex, we need to parse it into a user-friendly format.
|
| 121 |
-
# For now, returning the raw path objects.
|
| 122 |
-
return [record["path"] for record in result]
|
| 123 |
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
"""
|
| 138 |
-
with self._driver.session() as session:
|
| 139 |
-
result = session.run(query, keyword=keyword)
|
| 140 |
-
return [record.data() for record in result]
|
| 141 |
|
| 142 |
-
|
| 143 |
-
"
|
| 144 |
-
query = """
|
| 145 |
-
MATCH (t:Table {unique_name: $unique_name})
|
| 146 |
-
RETURN t.row_count AS row_count
|
| 147 |
-
"""
|
| 148 |
-
with self._driver.session() as session:
|
| 149 |
-
result = session.run(query, unique_name=table_unique_name)
|
| 150 |
-
record = result.single()
|
| 151 |
-
return record['row_count'] if record else -1
|
|
|
|
| 1 |
+
from neo4j import GraphDatabase, Driver
|
| 2 |
import logging
|
|
|
|
| 3 |
from typing import List, Dict, Any
|
| 4 |
+
|
| 5 |
from . import config
|
| 6 |
|
| 7 |
logger = logging.getLogger(__name__)
|
| 8 |
|
| 9 |
+
_driver: Driver = None
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
+
def get_graph_driver() -> Driver:
|
| 12 |
+
"""Initializes and returns the singleton Neo4j driver instance."""
|
| 13 |
+
global _driver
|
| 14 |
+
if _driver is None:
|
| 15 |
+
logger.info("Initializing Neo4j driver...")
|
| 16 |
+
_driver = GraphDatabase.driver(config.NEO4J_URI, auth=(config.NEO4J_USER, config.NEO4J_PASSWORD))
|
| 17 |
+
_ensure_constraints(_driver)
|
| 18 |
+
return _driver
|
| 19 |
|
| 20 |
+
def close_graph_driver():
|
| 21 |
+
"""Closes the Neo4j driver connection."""
|
| 22 |
+
global _driver
|
| 23 |
+
if _driver:
|
| 24 |
+
logger.info("Closing Neo4j driver.")
|
| 25 |
+
_driver.close()
|
| 26 |
+
_driver = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
+
def _ensure_constraints(driver: Driver):
|
| 29 |
+
"""Ensure uniqueness constraints are set up in Neo4j."""
|
| 30 |
+
with driver.session() as session:
|
| 31 |
+
session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (d:Database) REQUIRE d.name IS UNIQUE")
|
| 32 |
+
session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (t:Table) REQUIRE t.unique_name IS UNIQUE")
|
| 33 |
+
session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (c:Column) REQUIRE c.unique_name IS UNIQUE")
|
| 34 |
+
logger.info("Neo4j constraints ensured.")
|
|
|
|
|
|
|
| 35 |
|
| 36 |
+
def _keyword_search(keyword: str) -> List[Dict[str, Any]]:
|
| 37 |
+
"""Internal helper to search for table nodes by keyword."""
|
| 38 |
+
driver = get_graph_driver()
|
| 39 |
+
query = """
|
| 40 |
+
MATCH (d:Database)-[:HAS_TABLE]->(t:Table)
|
| 41 |
+
WHERE t.name CONTAINS $keyword
|
| 42 |
+
RETURN d.name as database, t.name as table
|
| 43 |
+
LIMIT 5
|
| 44 |
+
"""
|
| 45 |
+
with driver.session() as session:
|
| 46 |
+
result = session.run(query, keyword=keyword)
|
| 47 |
+
return [record.data() for record in result]
|
|
|
|
| 48 |
|
| 49 |
+
def find_join_path(table1_name: str, table2_name: str) -> str:
|
| 50 |
+
"""
|
| 51 |
+
Finds a human-readable join path between two tables using the graph's schema.
|
| 52 |
+
"""
|
| 53 |
+
driver = get_graph_driver()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
+
t1_nodes = _keyword_search(table1_name)
|
| 56 |
+
t2_nodes = _keyword_search(table2_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
+
if not t1_nodes: return f"Could not find a table matching '{table1_name}'."
|
| 59 |
+
if not t2_nodes: return f"Could not find a table matching '{table2_name}'."
|
| 60 |
+
|
| 61 |
+
t1_unique_name = f"{t1_nodes[0]['database']}.{t1_nodes[0]['table']}"
|
| 62 |
+
t2_unique_name = f"{t2_nodes[0]['database']}.{t2_nodes[0]['table']}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
+
path_query = """
|
| 65 |
+
MATCH (start:Table {unique_name: $start_name}), (end:Table {unique_name: $end_name})
|
| 66 |
+
CALL apoc.path.shortestPath(start, end, 'HAS_COLUMN|REFERENCES|<HAS_COLUMN', {maxLevel: 5}) YIELD path
|
| 67 |
+
WITH [n in nodes(path) | COALESCE(n.name, '')] as path_nodes
|
| 68 |
+
RETURN FILTER(name in path_nodes WHERE name <> '') as path
|
| 69 |
+
LIMIT 1
|
| 70 |
+
"""
|
| 71 |
+
with driver.session() as session:
|
| 72 |
+
result = session.run(path_query, start_name=t1_unique_name, end_name=t2_unique_name)
|
| 73 |
+
record = result.single()
|
| 74 |
+
|
| 75 |
+
if not record or not record["path"]:
|
| 76 |
+
return f"No join path found between {table1_name} and {table2_name}."
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
+
path_str = " -> ".join(record["path"])
|
| 79 |
+
return f"Found path: {path_str}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mcp/core/intelligence.py
CHANGED
|
@@ -1,161 +1,73 @@
|
|
| 1 |
import sqlparse
|
| 2 |
import logging
|
| 3 |
from typing import List, Dict, Any
|
| 4 |
-
|
| 5 |
-
from .graph import GraphStore
|
| 6 |
-
from .database import get_db_engine
|
| 7 |
-
from . import config
|
| 8 |
from sqlalchemy import text
|
| 9 |
|
| 10 |
-
|
| 11 |
|
| 12 |
-
|
| 13 |
-
ROW_EXECUTION_THRESHOLD = 100 # Execute queries expected to return fewer rows
|
| 14 |
-
JOIN_CARDINALITY_ESTIMATE = 1000 # A simplistic estimate for joins
|
| 15 |
|
| 16 |
-
|
| 17 |
"""
|
| 18 |
-
|
| 19 |
-
|
| 20 |
"""
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
# Format the path for display
|
| 66 |
-
# This is a complex task. The raw path from Neo4j needs careful parsing.
|
| 67 |
-
# This is a placeholder for that logic.
|
| 68 |
-
return f"Path found (details require parsing): {path_result}"
|
| 69 |
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
Executes a SQL query against the appropriate database if the estimated
|
| 73 |
-
cost is below the threshold.
|
| 74 |
-
"""
|
| 75 |
-
cost_estimate = self.estimate_query_cost(sql)
|
| 76 |
|
| 77 |
-
|
| 78 |
-
|
| 79 |
|
| 80 |
-
|
| 81 |
-
# against is a hard problem (especially for federated queries).
|
| 82 |
-
# We assume the first table found belongs to the correct database.
|
| 83 |
-
parsed_sql = self._parse_sql(sql)
|
| 84 |
-
if not parsed_sql['tables']:
|
| 85 |
-
raise ValueError("No tables found in SQL query.")
|
| 86 |
-
|
| 87 |
-
first_table = parsed_sql['tables'][0]
|
| 88 |
-
search_results = self.graph_store.keyword_search(first_table)
|
| 89 |
-
if not search_results:
|
| 90 |
-
raise ValueError(f"Table '{first_table}' not found in any known database.")
|
| 91 |
-
|
| 92 |
-
db_name = search_results[0]['database']
|
| 93 |
-
engine = self._get_engine_for_db(db_name)
|
| 94 |
-
|
| 95 |
-
if not engine:
|
| 96 |
-
raise ConnectionError(f"Could not connect to database: {db_name}")
|
| 97 |
|
|
|
|
| 98 |
with engine.connect() as connection:
|
| 99 |
-
|
| 100 |
-
safe_sql = f"{sql.strip().rstrip(';')} LIMIT {int(limit)}"
|
| 101 |
-
result = connection.execute(text(safe_sql))
|
| 102 |
return [dict(row._mapping) for row in result.fetchall()]
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
parsed = sqlparse.parse(sql)[0]
|
| 107 |
-
# This is a simplistic parser. A real implementation would need
|
| 108 |
-
# a much more robust SQL parsing library to handle complex queries, CTEs, etc.
|
| 109 |
-
tables = set()
|
| 110 |
-
for token in parsed.tokens:
|
| 111 |
-
if isinstance(token, sqlparse.sql.Identifier):
|
| 112 |
-
tables.add(token.get_real_name())
|
| 113 |
-
elif token.is_group:
|
| 114 |
-
# Look for identifiers within subgroups (e.g., in FROM or JOIN clauses)
|
| 115 |
-
for sub_token in token.tokens:
|
| 116 |
-
if isinstance(sub_token, sqlparse.sql.Identifier):
|
| 117 |
-
tables.add(sub_token.get_real_name())
|
| 118 |
-
|
| 119 |
-
return {"tables": list(tables)}
|
| 120 |
-
|
| 121 |
-
def estimate_query_cost(self, sql: str) -> Dict[str, Any]:
|
| 122 |
-
"""
|
| 123 |
-
Estimates the cost of a query based on row counts from the graph.
|
| 124 |
-
"""
|
| 125 |
-
try:
|
| 126 |
-
parsed_sql = self._parse_sql(sql)
|
| 127 |
-
tables_in_query = parsed_sql['tables']
|
| 128 |
-
|
| 129 |
-
if not tables_in_query:
|
| 130 |
-
return {"estimated_rows": 0, "decision": "execute", "message": "No tables found in query."}
|
| 131 |
-
|
| 132 |
-
# For simplicity, we'll take the max row count of any table in the query.
|
| 133 |
-
# A real system would analyze JOINs and WHERE clauses.
|
| 134 |
-
max_rows = 0
|
| 135 |
-
for table_name in tables_in_query:
|
| 136 |
-
# Need to find the unique name. This assumes table names are unique across DBs for now.
|
| 137 |
-
# A real implementation needs context of which DB is being queried.
|
| 138 |
-
search_result = self.graph_store.keyword_search(table_name)
|
| 139 |
-
if search_result:
|
| 140 |
-
table_unique_name = f"{search_result[0]['database']}.{search_result[0]['table']}"
|
| 141 |
-
row_count = self.graph_store.get_table_row_count(table_unique_name)
|
| 142 |
-
if row_count > max_rows:
|
| 143 |
-
max_rows = row_count
|
| 144 |
-
|
| 145 |
-
estimated_rows = max_rows
|
| 146 |
-
# Crude adjustment for joins
|
| 147 |
-
if len(tables_in_query) > 1:
|
| 148 |
-
# A better estimate would involve graph traversal and statistical models
|
| 149 |
-
estimated_rows *= JOIN_CARDINALITY_ESTIMATE * (len(tables_in_query) - 1)
|
| 150 |
-
|
| 151 |
-
decision = "execute" if estimated_rows < ROW_EXECUTION_THRESHOLD else "return_sql"
|
| 152 |
-
|
| 153 |
-
return {
|
| 154 |
-
"estimated_rows": estimated_rows,
|
| 155 |
-
"decision": decision,
|
| 156 |
-
"tables_found": tables_in_query
|
| 157 |
-
}
|
| 158 |
-
|
| 159 |
-
except Exception as e:
|
| 160 |
-
logger.error(f"Error estimating query cost: {e}")
|
| 161 |
-
return {"estimated_rows": -1, "decision": "error", "message": str(e)}
|
|
|
|
| 1 |
import sqlparse
|
| 2 |
import logging
|
| 3 |
from typing import List, Dict, Any
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
from sqlalchemy import text
|
| 5 |
|
| 6 |
+
from .database import get_db_connections
|
| 7 |
|
| 8 |
+
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
| 9 |
|
| 10 |
+
def _get_database_for_table(table_name: str) -> str | None:
|
| 11 |
"""
|
| 12 |
+
Finds which database a table belongs to by checking the graph.
|
| 13 |
+
(This is a simplified helper; assumes GraphStore is accessible or passed)
|
| 14 |
"""
|
| 15 |
+
# This is a placeholder for the logic to find a table's database.
|
| 16 |
+
# In a real scenario, this would query Neo4j. We'll simulate it.
|
| 17 |
+
# A simple mapping for our known databases:
|
| 18 |
+
if table_name in ["studies", "patients", "adverse_events"]:
|
| 19 |
+
return "clinical_trials"
|
| 20 |
+
if table_name in ["lab_tests", "test_results", "biomarkers"]:
|
| 21 |
+
return "laboratory"
|
| 22 |
+
if table_name in ["compounds", "assay_results", "drug_targets", "compound_targets"]:
|
| 23 |
+
return "drug_discovery"
|
| 24 |
+
return None
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Keywords that introduce a table reference in a SQL statement.
_TABLE_INTRO_KEYWORDS = {
    "FROM", "JOIN", "INNER JOIN", "CROSS JOIN",
    "LEFT JOIN", "LEFT OUTER JOIN",
    "RIGHT JOIN", "RIGHT OUTER JOIN",
    "FULL JOIN", "FULL OUTER JOIN",
}


def _extract_target_table(statement) -> str | None:
    """Return the first table name referenced after a FROM/JOIN keyword.

    Bug fix: the previous scan returned the first ``Identifier`` anywhere in
    the statement, so for ``SELECT id FROM studies`` it resolved the *column*
    (``id``) as the target table. Restricting the scan to tokens that follow
    a FROM/JOIN keyword yields the actual table.
    """
    from_seen = False
    for token in statement.tokens:
        if token.is_whitespace:
            continue
        if from_seen:
            if isinstance(token, sqlparse.sql.Identifier):
                return token.get_real_name()
            if isinstance(token, sqlparse.sql.IdentifierList):
                # Comma-separated table list: take the first named entry.
                for ident in token.get_identifiers():
                    real = getattr(ident, "get_real_name", lambda: None)()
                    if real:
                        return real
            if token.is_keyword:
                # e.g. WHERE / GROUP BY terminates the table clause.
                from_seen = False
        if token.is_keyword and token.normalized in _TABLE_INTRO_KEYWORDS:
            from_seen = True
    return None


async def execute_federated_query(sql: str) -> List[Dict[str, Any]]:
    """
    Executes a SQL query against the correct SQLite database.

    This is a simplified version of a federated query engine. It identifies
    the target database from the first table referenced in the FROM/JOIN
    clause of the SQL query and routes the statement to that database's
    SQLAlchemy engine.

    Args:
        sql: the SQL statement to execute.

    Returns:
        The result rows as a list of plain dicts.

    Raises:
        ValueError: if no target table can be identified, or the table does
            not belong to any known database.
        ConnectionError: if there is no active engine for the database.
    """
    parsed = sqlparse.parse(sql)[0]
    target_table = _extract_target_table(parsed)

    if not target_table:
        raise ValueError("Could not identify a target table in the SQL query.")

    logger.info(f"Identified target table: {target_table}")

    # Determine which database engine to use
    db_name = _get_database_for_table(target_table)
    if not db_name:
        raise ValueError(f"Table '{target_table}' not found in any known database.")

    db_engines = get_db_connections()
    engine = db_engines.get(db_name)

    if not engine:
        raise ConnectionError(f"No active connection for database '{db_name}'.")

    logger.info(f"Executing query against database: {db_name}")

    # NOTE(review): engine.connect()/execute is synchronous inside an async
    # function — under load this blocks the event loop; consider running it
    # in a thread pool. Behavior kept as-is here.
    try:
        with engine.connect() as connection:
            result = connection.execute(text(sql))
            return [dict(row._mapping) for row in result.fetchall()]
    except Exception as e:
        logger.error(f"Failed to execute query on {db_name}: {e}", exc_info=True)
        raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mcp/main.py
CHANGED
|
@@ -1,146 +1,88 @@
|
|
| 1 |
-
from fastapi import FastAPI, Header, HTTPException
|
| 2 |
-
from
|
| 3 |
import os
|
| 4 |
-
import
|
| 5 |
-
from
|
| 6 |
-
import psycopg2
|
| 7 |
-
from psycopg2.extras import RealDictCursor
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
|
|
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
|
|
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
-
|
| 23 |
-
async def
|
| 24 |
-
# Verify API key
|
| 25 |
if x_api_key not in VALID_API_KEYS:
|
| 26 |
-
raise HTTPException(status_code=401, detail="Invalid API
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
return {"error": str(e), "query": query, "parameters": query_params}
|
| 67 |
-
|
| 68 |
-
elif tool == "write_graph":
|
| 69 |
-
# Structured write operation
|
| 70 |
-
action = params.get("action")
|
| 71 |
-
if action == "create_node":
|
| 72 |
-
label = params.get("label")
|
| 73 |
-
properties = params.get("properties", {})
|
| 74 |
-
with driver.session() as session:
|
| 75 |
-
result = session.run(f"CREATE (n:{label} $props) RETURN n", {"props": properties})
|
| 76 |
-
record = result.single()
|
| 77 |
-
if record:
|
| 78 |
-
node = record["n"]
|
| 79 |
-
return {"created": dict(node) if hasattr(node, 'items') else {"id": str(node.id), "labels": list(node.labels), "properties": dict(node)}}
|
| 80 |
-
return {"created": {}}
|
| 81 |
-
|
| 82 |
-
elif tool == "get_next_instruction":
|
| 83 |
-
# Get next pending instruction
|
| 84 |
-
with driver.session() as session:
|
| 85 |
-
result = session.run("""
|
| 86 |
-
MATCH (i:Instruction {status: 'pending'})
|
| 87 |
-
RETURN i ORDER BY i.sequence LIMIT 1
|
| 88 |
-
""")
|
| 89 |
-
record = result.single()
|
| 90 |
-
return {"instruction": dict(record["i"]) if record else None}
|
| 91 |
-
|
| 92 |
-
elif tool == "query_postgres":
|
| 93 |
-
query = params.get("query")
|
| 94 |
-
try:
|
| 95 |
-
conn = psycopg2.connect(POSTGRES_CONN)
|
| 96 |
-
with conn.cursor(cursor_factory=RealDictCursor) as cur:
|
| 97 |
-
cur.execute(query)
|
| 98 |
-
if cur.description: # SELECT query
|
| 99 |
-
results = cur.fetchall()
|
| 100 |
-
return {"data": results, "row_count": len(results)}
|
| 101 |
-
else: # INSERT/UPDATE/DELETE
|
| 102 |
-
conn.commit()
|
| 103 |
-
return {"affected_rows": cur.rowcount}
|
| 104 |
-
except Exception as e:
|
| 105 |
-
return {"error": str(e)}
|
| 106 |
-
finally:
|
| 107 |
-
if 'conn' in locals():
|
| 108 |
-
conn.close()
|
| 109 |
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
WHERE table_schema = 'public'
|
| 119 |
-
AND table_type = 'BASE TABLE'
|
| 120 |
-
""")
|
| 121 |
-
tables = cur.fetchall()
|
| 122 |
-
|
| 123 |
-
schema_info = {}
|
| 124 |
-
for table in tables:
|
| 125 |
-
table_name = table['table_name']
|
| 126 |
-
|
| 127 |
-
# Get columns for each table
|
| 128 |
-
cur.execute("""
|
| 129 |
-
SELECT column_name, data_type, is_nullable,
|
| 130 |
-
column_default, character_maximum_length
|
| 131 |
-
FROM information_schema.columns
|
| 132 |
-
WHERE table_schema = 'public'
|
| 133 |
-
AND table_name = %s
|
| 134 |
-
ORDER BY ordinal_position
|
| 135 |
-
""", (table_name,))
|
| 136 |
-
|
| 137 |
-
schema_info[table_name] = cur.fetchall()
|
| 138 |
-
|
| 139 |
-
return {"schema": schema_info}
|
| 140 |
-
except Exception as e:
|
| 141 |
-
return {"error": str(e)}
|
| 142 |
-
finally:
|
| 143 |
-
if 'conn' in locals():
|
| 144 |
-
conn.close()
|
| 145 |
-
|
| 146 |
-
return {"error": "Unknown tool"}
|
|
|
|
| 1 |
+
from fastapi import FastAPI, Header, HTTPException, Depends
|
| 2 |
+
from typing import List, Dict, Any
|
| 3 |
import os
|
| 4 |
+
import logging
|
| 5 |
+
from pydantic import BaseModel
|
|
|
|
|
|
|
| 6 |
|
| 7 |
+
# --- Core Logic Imports ---
|
| 8 |
+
# These imports assume your project structure places the core logic correctly.
|
| 9 |
+
from core.database import get_db_connections, close_db_connections
|
| 10 |
+
from core.discovery import get_relevant_schemas
|
| 11 |
+
from core.graph import find_join_path, get_graph_driver, close_graph_driver
|
| 12 |
+
from core.intelligence import execute_federated_query
|
| 13 |
|
| 14 |
+
# --- App Configuration ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="MCP Server", version="2.0")

# Comma-separated list of accepted API keys. Falls back to a development
# key so the service works out of the box in local/docker environments;
# set MCP_API_KEYS in production.
VALID_API_KEYS = os.getenv("MCP_API_KEYS", "dev-key-123").split(",")
|
| 21 |
+
|
| 22 |
+
# --- Pydantic Models ---
class ToolRequest(BaseModel):
    """Generic tool-invocation payload.

    NOTE(review): not referenced by any endpoint below — presumably kept
    for backward compatibility; confirm before removing.
    """
    tool: str
    params: Dict[str, Any]

class SchemaQuery(BaseModel):
    """Body for POST /mcp/discovery/get_relevant_schemas."""
    query: str

class JoinPathRequest(BaseModel):
    """Body for POST /mcp/graph/find_join_path: the two tables to connect."""
    table1: str
    table2: str

class SQLQuery(BaseModel):
    """Body for POST /mcp/intelligence/execute_query."""
    sql: str
|
| 36 |
|
| 37 |
+
# --- Dependency for Auth ---
|
| 38 |
+
async def verify_api_key(x_api_key: str = Header(...)):
    """FastAPI dependency validating the ``X-API-Key`` request header.

    Returns the key when it is in the configured list; otherwise rejects
    the request with HTTP 401.
    """
    if x_api_key in VALID_API_KEYS:
        return x_api_key
    raise HTTPException(status_code=401, detail="Invalid API Key")
|
| 42 |
+
|
| 43 |
+
# --- Event Handlers ---
|
| 44 |
+
# NOTE(review): @app.on_event is deprecated in newer FastAPI releases in
# favor of lifespan handlers — consider migrating when upgrading FastAPI.
@app.on_event("startup")
async def startup_event():
    """Initializes the database connection pool on server startup."""
    # Eagerly create the SQL engine pool and the graph driver so the first
    # request does not pay the connection cost and misconfiguration fails fast.
    get_db_connections()
    get_graph_driver()
    logger.info("MCP server started and database connections initialized.")

@app.on_event("shutdown")
def shutdown_event():
    """Closes the database connection pool on server shutdown."""
    close_db_connections()
    close_graph_driver()
    logger.info("MCP server shutting down and database connections closed.")
|
| 57 |
+
|
| 58 |
+
# --- API Endpoints ---
|
| 59 |
+
@app.get("/health")
def health_check():
    """Unauthenticated liveness probe (no API-key dependency).

    Suitable as a container healthcheck target since it needs no
    credentials or database access.
    """
    return {"status": "ok"}
|
| 62 |
+
|
| 63 |
+
@app.post("/mcp/discovery/get_relevant_schemas", dependencies=[Depends(verify_api_key)])
async def discover_schemas(request: SchemaQuery):
    """Return the schemas relevant to the supplied query string.

    Any failure in the discovery layer is logged with its traceback and
    surfaced to the caller as HTTP 500.
    """
    try:
        relevant = await get_relevant_schemas(request.query)
    except Exception as e:
        logger.error(f"Schema discovery failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    return {"status": "success", "schemas": relevant}
|
| 71 |
+
|
| 72 |
+
@app.post("/mcp/graph/find_join_path", dependencies=[Depends(verify_api_key)])
async def get_join_path(request: JoinPathRequest):
    """Find a join path between two tables via the schema graph.

    NOTE(review): find_join_path is invoked synchronously inside this async
    endpoint — if it performs blocking graph I/O it will stall the event
    loop; confirm against core.graph.
    """
    try:
        join_path = find_join_path(request.table1, request.table2)
    except Exception as e:
        logger.error(f"Join path finding failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    return {"status": "success", "path": join_path}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
+
@app.post("/mcp/intelligence/execute_query", dependencies=[Depends(verify_api_key)])
async def execute_query(request: SQLQuery):
    """Execute a SQL statement through the federated query engine.

    Errors are logged with their traceback and reported as HTTP 500.
    """
    try:
        rows = await execute_federated_query(request.sql)
    except Exception as e:
        logger.error(f"Query execution failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    return {"status": "success", "results": rows}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
streamlit/requirements.txt
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
streamlit==1.28.0
|
| 2 |
requests==2.31.0
|
| 3 |
pandas==2.1.0
|
| 4 |
-
python-dotenv==1.0.0
|
|
|
|
| 1 |
streamlit==1.28.0
|
| 2 |
requests==2.31.0
|
| 3 |
pandas==2.1.0
|
|
|