File size: 3,054 Bytes
9d411a7 9930ba9 9d411a7 9930ba9 9d411a7 9930ba9 9d411a7 9930ba9 9d411a7 9930ba9 9d411a7 9930ba9 9d411a7 84473fd 9d411a7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 | from fastapi import FastAPI, Header, HTTPException, Depends
from typing import List, Dict, Any
import os
import logging
from pydantic import BaseModel
# --- Core Logic Imports ---
# These imports assume your project structure places the core logic correctly.
from core.database import get_db_connections, close_db_connections
from core.discovery import get_relevant_schemas
from core.graph import find_join_path, get_graph_driver, close_graph_driver
from core.intelligence import execute_federated_query
# --- App Configuration ---
# Root logging at INFO so the logger.info / logger.error calls in this module
# are emitted by default.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The FastAPI application; the event handlers and routes in this module are
# registered on it via decorators.
app = FastAPI(version="2.0", title="MCP Server")
VALID_API_KEYS = os.getenv("MCP_API_KEYS", "dev-key-123").split(",")
# --- Pydantic Models ---
# Request-body models. They are documented with `#` comments rather than class
# docstrings because pydantic copies a class docstring into the model's JSON
# schema "description", which would change the generated OpenAPI document.
class ToolRequest(BaseModel):
    # Generic tool-invocation envelope: a tool name and a free-form mapping of
    # parameters.
    # NOTE(review): not used by any endpoint visible in this file — confirm it
    # is referenced elsewhere before relying on or removing it.
    tool: str
    params: Dict[str, Any]
class SchemaQuery(BaseModel):
    # Body for POST /mcp/discovery/get_relevant_schemas; `query` is passed
    # as-is to core.discovery.get_relevant_schemas.
    # NOTE(review): presumably free text describing the data wanted — confirm.
    query: str
class JoinPathRequest(BaseModel):
    # Body for POST /mcp/graph/find_join_path; the two table identifiers are
    # passed positionally to core.graph.find_join_path(table1, table2).
    # NOTE(review): the expected identifier format (bare vs. qualified names)
    # is defined by core.graph — confirm there.
    table1: str
    table2: str
class SQLQuery(BaseModel):
    # Body for POST /mcp/intelligence/execute_query; `sql` is passed verbatim to
    # core.intelligence.execute_federated_query.
    # NOTE(review): the SQL text is fully client-controlled. Whether the
    # execution layer restricts it (e.g. read-only statements) is not visible
    # here — confirm, since any caller holding a valid API key controls it.
    sql: str
# --- Dependency for Auth ---
async def verify_api_key(x_api_key: str = Header(...)):
    """FastAPI dependency that authenticates a request by its ``x-api-key`` header.

    The header is required (``Header(...)``). A request that omits it is
    rejected by FastAPI's request validation before this function runs.
    The supplied key is compared against every entry of ``VALID_API_KEYS``
    in constant time, so response timing does not reveal how much of a key
    matched.

    Returns:
        The accepted API key.

    Raises:
        HTTPException: 401 if the key is empty or not in ``VALID_API_KEYS``.
    """
    import hmac  # local import keeps this dependency self-contained

    supplied = x_api_key.encode("utf-8")
    matched = False
    for valid_key in VALID_API_KEYS:
        # Compare bytes, not str: compare_digest raises TypeError for
        # non-ASCII str arguments, and header values are client-controlled.
        # Use |= (no short-circuit) so every configured key is always checked.
        matched |= hmac.compare_digest(supplied, valid_key.encode("utf-8"))
    # An empty key must never authenticate, even if "" is configured.
    if not x_api_key or not matched:
        raise HTTPException(status_code=401, detail="Invalid API Key")
    return x_api_key
# --- Event Handlers ---
@app.on_event("startup")
async def startup_event():
    """Open backend connections when the server starts.

    Calls core.database.get_db_connections() first, then
    core.graph.get_graph_driver(), and logs once both have returned. An
    exception from either call propagates and aborts startup.
    """
    for initialise in (get_db_connections, get_graph_driver):
        # Return values are discarded. Presumably each helper caches its
        # connection/driver internally — TODO confirm in core.*.
        initialise()
    logger.info("MCP server started and database connections initialized.")
@app.on_event("shutdown")
def shutdown_event():
    """Close the database connections and the graph driver on server shutdown.

    The graph driver is closed in a ``finally`` so that an error while closing
    the database connections does not leave it open. Such an error still
    propagates after the graph driver has been closed, and the completion
    message is then not logged.
    """
    try:
        close_db_connections()
    finally:
        close_graph_driver()
    logger.info("MCP server shutting down and database connections closed.")
# --- API Endpoints ---
@app.get("/health")
def health_check():
    """Lightweight health endpoint reporting that the process is serving.

    No API-key dependency is attached. It does not contact the database or
    graph backends; it always returns ``{"status": "ok"}``.
    """
    payload = dict(status="ok")
    return payload
@app.post("/mcp/discovery/get_relevant_schemas", dependencies=[Depends(verify_api_key)])
async def discover_schemas(request: SchemaQuery):
    """Return the schemas that core.discovery selects for ``request.query``.

    Requires a valid ``x-api-key`` header (via ``verify_api_key``).

    Returns:
        ``{"status": "success", "schemas": <get_relevant_schemas result>}``.

    Raises:
        HTTPException: re-raised unchanged if the discovery layer raises one.
            Any other failure is logged with its traceback and reported as a
            500 whose detail is the exception message.
    """
    try:
        schemas = await get_relevant_schemas(request.query)
    except HTTPException:
        # Already a deliberate HTTP error; don't mask it as a 500.
        raise
    except Exception as e:
        logger.exception("Schema discovery failed: %s", e)
        # NOTE(review): str(e) is returned to the client and may expose backend
        # internals; consider a generic message if callers are untrusted.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"status": "success", "schemas": schemas}
@app.post("/mcp/graph/find_join_path", dependencies=[Depends(verify_api_key)])
async def get_join_path(request: JoinPathRequest):
    """Return the join path between two tables as computed by core.graph.

    Requires a valid ``x-api-key`` header (via ``verify_api_key``).

    Returns:
        ``{"status": "success", "path": <find_join_path result>}``.

    Raises:
        HTTPException: re-raised unchanged if the graph layer raises one.
            Any other failure is logged with its traceback and reported as a
            500 whose detail is the exception message.
    """
    try:
        # NOTE(review): find_join_path is called without await, so it runs
        # synchronously on the event loop. If it performs blocking graph I/O,
        # consider a plain `def` endpoint (FastAPI runs those in a threadpool)
        # or offloading the call — confirm in core.graph.
        path = find_join_path(request.table1, request.table2)
    except HTTPException:
        # Already a deliberate HTTP error; don't mask it as a 500.
        raise
    except Exception as e:
        logger.exception("Join path finding failed: %s", e)
        # NOTE(review): str(e) is returned to the client and may expose backend
        # internals; consider a generic message if callers are untrusted.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"status": "success", "path": path}
@app.post("/mcp/intelligence/execute_query", dependencies=[Depends(verify_api_key)])
async def execute_query(request: SQLQuery):
    """Run client-supplied SQL through core.intelligence and return its results.

    Requires a valid ``x-api-key`` header (via ``verify_api_key``).

    Returns:
        ``{"status": "success", "results": <execute_federated_query result>}``.

    Raises:
        HTTPException: re-raised unchanged if the execution layer raises one.
            Any other failure is logged with its traceback and reported as a
            500 whose detail is the exception message.
    """
    try:
        # NOTE(review): request.sql is passed through verbatim. Whether
        # execute_federated_query restricts statement types (e.g. read-only) is
        # not visible here — confirm, since the caller fully controls the SQL.
        results = await execute_federated_query(request.sql)
    except HTTPException:
        # Already a deliberate HTTP error; don't mask it as a 500.
        raise
    except Exception as e:
        logger.exception("Query execution failed: %s", e)
        # NOTE(review): str(e) is returned to the client and may expose backend
        # internals; consider a generic message if callers are untrusted.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"status": "success", "results": results}