import hmac
import logging
import os
from typing import Any, Dict, List

from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel

from core.database import close_db_connections, get_db_connections
from core.discovery import get_relevant_schemas
from core.graph import close_graph_driver, find_join_path, get_graph_driver
from core.intelligence import execute_federated_query
| |
|
| | |
# Module-wide logging: configure the root handler at INFO (a no-op if the root
# logger already has handlers) and create this module's named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
| |
|
# The ASGI application instance; title and version appear in the generated OpenAPI docs.
app = FastAPI(version="2.0", title="MCP Server")
| |
|
| | VALID_API_KEYS = os.getenv("MCP_API_KEYS", "dev-key-123").split(",") |
| |
|
| | |
class ToolRequest(BaseModel):
    """Generic tool-invocation payload: a tool identifier plus a parameter mapping.

    NOTE(review): no endpoint in this module accepts this model; confirm it is
    used elsewhere before changing it.
    """

    tool: str  # presumably the name of the tool to run -- confirm with callers
    params: Dict[str, Any]  # JSON-object arguments, keyed by string
| |
|
class SchemaQuery(BaseModel):
    """Request body for ``/mcp/discovery/get_relevant_schemas``."""

    query: str  # text handed to get_relevant_schemas to select matching schemas
| |
|
class JoinPathRequest(BaseModel):
    """Request body for ``/mcp/graph/find_join_path``: the two tables to connect."""

    table1: str  # first table, passed to find_join_path as its first argument
    table2: str  # second table, passed to find_join_path as its second argument
| |
|
class SQLQuery(BaseModel):
    """Request body for ``/mcp/intelligence/execute_query``."""

    sql: str  # raw SQL text forwarded unchanged to execute_federated_query
| |
|
| | |
async def verify_api_key(x_api_key: str = Header(...)):
    """Authenticate a request by its ``X-API-Key`` header.

    Used as a route dependency. FastAPI maps the ``x_api_key`` parameter to the
    ``x-api-key`` header and rejects requests that omit it before this runs.

    Args:
        x_api_key: The header value supplied by the client.

    Returns:
        The accepted API key.

    Raises:
        HTTPException: 401 if the value is empty or matches no entry in
            ``VALID_API_KEYS``.
    """
    # Compare as bytes: hmac.compare_digest rejects non-ASCII str operands with
    # a TypeError, which would otherwise surface as a 500.
    supplied = x_api_key.encode("utf-8")
    matched = False
    for valid_key in VALID_API_KEYS:
        # Constant-time comparison, and no short-circuit across keys, so
        # response timing does not reveal how much of a key was guessed.
        matched |= hmac.compare_digest(supplied, valid_key.encode("utf-8"))
    # An empty header must never authenticate, even if the key list is misconfigured.
    if not x_api_key or not matched:
        raise HTTPException(status_code=401, detail="Invalid API Key")
    return x_api_key
| |
|
| | |
@app.on_event("startup")
async def startup_event():
    """Initialize the relational connections and the graph driver at server start.

    Each initializer is called once, in order; their return values are not kept
    here (presumably the helpers cache what they create -- confirm in
    core.database / core.graph).
    """
    for initialize in (get_db_connections, get_graph_driver):
        initialize()
    logger.info("MCP server started and database connections initialized.")
| |
|
@app.on_event("shutdown")
def shutdown_event():
    """Close the relational connections and the graph driver on server shutdown.

    The graph driver is closed even if closing the relational connections
    raises. In that case the original exception still propagates and the
    "closed" log line is not emitted.
    """
    try:
        close_db_connections()
    finally:
        # Always run, so a failure closing one resource cannot leak the other.
        close_graph_driver()
    logger.info("MCP server shutting down and database connections closed.")
| |
|
| | |
@app.get("/health")
def health_check():
    """Liveness probe: report that the process is up and serving requests.

    Requires no API key and does not touch the database or the graph driver.
    """
    payload = dict(status="ok")
    return payload
| |
|
@app.post("/mcp/discovery/get_relevant_schemas", dependencies=[Depends(verify_api_key)])
async def discover_schemas(request: SchemaQuery):
    """Return the schemas that ``get_relevant_schemas`` selects for ``request.query``.

    Returns:
        ``{"status": "success", "schemas": <result of get_relevant_schemas>}``.

    Raises:
        HTTPException: Re-raised unchanged if the discovery layer raises one.
            Any other failure is logged with its traceback and reported as a
            500 whose detail is the exception text (or its class name when the
            text is empty).
    """
    try:
        schemas = await get_relevant_schemas(request.query)
    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. a 4xx) instead of masking them as 500s.
        raise
    except Exception as e:
        logger.exception("Schema discovery failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e) or type(e).__name__) from e
    return {"status": "success", "schemas": schemas}
| |
|
@app.post("/mcp/graph/find_join_path", dependencies=[Depends(verify_api_key)])
async def get_join_path(request: JoinPathRequest):
    """Return the join path that ``find_join_path`` computes between two tables.

    Returns:
        ``{"status": "success", "path": <result of find_join_path>}``.

    Raises:
        HTTPException: Re-raised unchanged if the graph layer raises one. Any
            other failure is logged with its traceback and reported as a 500
            whose detail is the exception text (or its class name when the
            text is empty).
    """
    try:
        # find_join_path is synchronous (it is called without await). Run it in
        # FastAPI's worker thread pool so a slow graph query does not block the
        # event loop for every other request.
        # NOTE(review): assumes find_join_path and the graph driver are safe to
        # call from worker threads -- FastAPI does the same for plain `def` routes.
        path = await run_in_threadpool(find_join_path, request.table1, request.table2)
    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. a 4xx) instead of masking them as 500s.
        raise
    except Exception as e:
        logger.exception("Join path finding failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e) or type(e).__name__) from e
    return {"status": "success", "path": path}
| |
|
@app.post("/mcp/intelligence/execute_query", dependencies=[Depends(verify_api_key)])
async def execute_query(request: SQLQuery):
    """Run client-supplied SQL through ``execute_federated_query``.

    NOTE(review): the SQL text is forwarded unchanged, so any API-key holder
    can run arbitrary statements. Confirm that execute_federated_query enforces
    read-only access or statement whitelisting.

    Returns:
        ``{"status": "success", "results": <result of execute_federated_query>}``.

    Raises:
        HTTPException: Re-raised unchanged if the query layer raises one. Any
            other failure is logged with its traceback and reported as a 500
            whose detail is the exception text (or its class name when the
            text is empty).
    """
    try:
        results = await execute_federated_query(request.sql)
    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. a 4xx) instead of masking them as 500s.
        raise
    except Exception as e:
        logger.exception("Query execution failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e) or type(e).__name__) from e
    return {"status": "success", "results": results}