"""MCP Server HTTP API.

Exposes schema discovery, join-path lookup and federated query execution
endpoints. Every ``/mcp/...`` endpoint requires a valid ``X-API-Key`` header.
"""

import logging
import os
import secrets
from typing import List, Dict, Any

from fastapi import FastAPI, Header, HTTPException, Depends
from pydantic import BaseModel

# --- Core Logic Imports ---
# These imports assume your project structure places the core logic correctly.
from core.database import get_db_connections, close_db_connections
from core.discovery import get_relevant_schemas
from core.graph import find_join_path, get_graph_driver, close_graph_driver
from core.intelligence import execute_federated_query

# --- App Configuration ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="MCP Server", version="2.0")

# Key used only when MCP_API_KEYS is not set at all (local development).
_DEFAULT_DEV_API_KEY = "dev-key-123"


def _load_api_keys() -> List[str]:
    """Parse the comma-separated ``MCP_API_KEYS`` env var into a list of keys.

    Whitespace around each key is stripped and empty entries are dropped, so
    ``"a, b,"`` yields ``["a", "b"]``. Dropping empties matters for security:
    an empty entry (from an empty variable or a trailing comma) would
    otherwise make an empty ``X-API-Key`` header authenticate.

    Returns:
        The usable API keys. May be empty, in which case every request to a
        protected endpoint is rejected (fail closed).
    """
    raw = os.getenv("MCP_API_KEYS")
    if raw is None:
        # Preserve the historical development fallback, but make it visible.
        logger.warning(
            "MCP_API_KEYS is not set; falling back to the built-in development "
            "API key. Do not use this configuration in production."
        )
        raw = _DEFAULT_DEV_API_KEY
    keys = [key.strip() for key in raw.split(",") if key.strip()]
    if not keys:
        logger.warning(
            "MCP_API_KEYS contains no usable keys; all authenticated requests "
            "will be rejected."
        )
    return keys


VALID_API_KEYS = _load_api_keys()


# --- Pydantic Models ---
class ToolRequest(BaseModel):
    """Generic tool invocation: a tool name plus free-form parameters."""

    tool: str
    params: Dict[str, Any]


class SchemaQuery(BaseModel):
    """Request body for schema discovery."""

    query: str


class JoinPathRequest(BaseModel):
    """Request body for join-path lookup between two tables."""

    table1: str
    table2: str


class SQLQuery(BaseModel):
    """Request body carrying a SQL statement to execute."""

    sql: str


# --- Dependency for Auth ---
async def verify_api_key(x_api_key: str = Header(...)):
    """Reject the request with HTTP 401 unless ``X-API-Key`` is a configured key.

    Comparison uses ``secrets.compare_digest`` on UTF-8 bytes (str arguments
    must be ASCII-only, and header values may not be) and checks every
    configured key without short-circuiting, so response time does not
    reveal how much of a key matched.

    Returns:
        The accepted API key.

    Raises:
        HTTPException: 401 if the key does not match any configured key.
    """
    presented = x_api_key.encode("utf-8")
    matched = False
    for key in VALID_API_KEYS:
        matched |= secrets.compare_digest(presented, key.encode("utf-8"))
    if not matched:
        raise HTTPException(status_code=401, detail="Invalid API Key")
    return x_api_key


# --- Event Handlers ---
# NOTE(review): `on_event` is deprecated in recent FastAPI releases in favour
# of a `lifespan` context manager; kept here for compatibility with the
# installed version, which is not visible from this file.
@app.on_event("startup")
async def startup_event():
    """Initializes the database connection pool on server startup."""
    get_db_connections()
    get_graph_driver()
    logger.info("MCP server started and database connections initialized.")


@app.on_event("shutdown")
def shutdown_event():
    """Closes the database connection pool on server shutdown."""
    close_db_connections()
    close_graph_driver()
    logger.info("MCP server shutting down and database connections closed.")


# --- API Endpoints ---
# NOTE(review): each endpoint returns `str(e)` as the 500 detail. That helps
# clients (e.g. an agent correcting its SQL) but can expose backend internals;
# confirm this is acceptable for the deployment.
@app.get("/health")
def health_check():
    """Unauthenticated liveness probe."""
    return {"status": "ok"}


@app.post("/mcp/discovery/get_relevant_schemas", dependencies=[Depends(verify_api_key)])
async def discover_schemas(request: SchemaQuery):
    """Return the schemas that `get_relevant_schemas` selects for `request.query`."""
    try:
        schemas = await get_relevant_schemas(request.query)
        return {"status": "success", "schemas": schemas}
    except Exception as e:
        logger.exception("Schema discovery failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/mcp/graph/find_join_path", dependencies=[Depends(verify_api_key)])
async def get_join_path(request: JoinPathRequest):
    """Return the join path that `find_join_path` computes between two tables."""
    try:
        # NOTE(review): find_join_path is called synchronously inside an async
        # handler. If it performs blocking I/O (e.g. through the graph driver)
        # it stalls the event loop; consider a plain `def` handler or
        # `run_in_threadpool` once its thread-safety is confirmed.
        path = find_join_path(request.table1, request.table2)
        return {"status": "success", "path": path}
    except Exception as e:
        logger.exception("Join path finding failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/mcp/intelligence/execute_query", dependencies=[Depends(verify_api_key)])
async def execute_query(request: SQLQuery):
    """Execute the caller-supplied SQL via `execute_federated_query` and return its results."""
    try:
        # NOTE(review): the SQL text comes straight from the client; access
        # control here is only the API key. Confirm the underlying connections
        # are least-privilege (e.g. read-only).
        results = await execute_federated_query(request.sql)
        return {"status": "success", "results": results}
    except Exception as e:
        logger.exception("Query execution failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e