Spaces:
No application file
No application file
import hmac
import logging
import os
from typing import List, Dict, Any

from fastapi import FastAPI, Header, HTTPException, Depends
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel

# --- Core Logic Imports ---
# These imports assume your project structure places the core logic correctly.
from core.database import get_db_connections, close_db_connections
from core.discovery import get_relevant_schemas
from core.graph import find_join_path, get_graph_driver, close_graph_driver
from core.intelligence import execute_federated_query
# --- App Configuration ---
# Module logger first; basicConfig below only configures the root handler and
# level (INFO), so the order of these two statements does not matter.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# The ASGI application object that the handlers in this module belong to.
app = FastAPI(version="2.0", title="MCP Server")
def parse_api_keys(raw: str) -> List[str]:
    """Split a comma-separated API-key string into a list of keys.

    Whitespace around each entry is stripped and empty entries are dropped,
    so ``" a, b ,,"`` yields ``["a", "b"]`` and ``""`` yields ``[]``.

    Args:
        raw: Comma-separated keys, e.g. the value of ``MCP_API_KEYS``.

    Returns:
        The non-empty, stripped keys in their original order.
    """
    return [key for key in (part.strip() for part in raw.split(",")) if key]


_DEV_FALLBACK_API_KEY = "dev-key-123"
_raw_api_keys = os.getenv("MCP_API_KEYS")
if _raw_api_keys is None:
    # Keep the historical development default so local setups keep working,
    # but say so loudly: this key is visible to anyone who can read the source.
    logging.getLogger(__name__).warning(
        "MCP_API_KEYS is not set; falling back to the built-in development API key."
    )
    _raw_api_keys = _DEV_FALLBACK_API_KEY
# An explicitly empty MCP_API_KEYS now yields no valid keys (fail closed)
# instead of [""], which would have accepted an empty X-API-Key header.
VALID_API_KEYS = parse_api_keys(_raw_api_keys)
# --- Pydantic Models ---
class ToolRequest(BaseModel):
    """Request body with a ``tool`` name and a ``params`` mapping of arbitrary values.

    NOTE(review): no handler visible in this chunk accepts this model; confirm
    it is still used elsewhere.
    """

    tool: str  # tool name; presumably used to dispatch the call — confirm with the consumer
    params: Dict[str, Any]  # required (no default); values are not validated beyond being a dict
class SchemaQuery(BaseModel):
    """Request body for ``discover_schemas``: a single free-text ``query``."""

    query: str  # passed unchanged to get_relevant_schemas()
class JoinPathRequest(BaseModel):
    """Request body for ``get_join_path``: the two tables to connect."""

    table1: str  # first argument to find_join_path()
    table2: str  # second argument to find_join_path()
class SQLQuery(BaseModel):
    """Request body for ``execute_query``: raw SQL text.

    NOTE(review): the text is forwarded verbatim to execute_federated_query();
    no validation or restriction is applied at this layer.
    """

    sql: str  # caller-supplied SQL, unvalidated here
# --- Dependency for Auth ---
async def verify_api_key(x_api_key: str = Header(...)):
    """FastAPI dependency that validates the ``X-API-Key`` request header.

    Args:
        x_api_key: Header value; FastAPI maps the parameter name to the
            ``x-api-key`` header and rejects the request if it is missing.

    Returns:
        The accepted key, unchanged.

    Raises:
        HTTPException: 401 ``"Invalid API Key"`` if the value matches none of
            ``VALID_API_KEYS``.
    """
    supplied = x_api_key.encode("utf-8")
    # hmac.compare_digest runs in time independent of where the inputs first
    # differ, unlike ``in``/``==`` which short-circuit and leak a timing signal.
    # Comparing bytes (not str) avoids the TypeError compare_digest raises for
    # non-ASCII str input. No early exit, so timing doesn't reveal which key matched.
    is_valid = False
    for valid_key in VALID_API_KEYS:
        if hmac.compare_digest(supplied, valid_key.encode("utf-8")):
            is_valid = True
    if not is_valid:
        raise HTTPException(status_code=401, detail="Invalid API Key")
    return x_api_key
# --- Event Handlers ---
async def startup_event():
    """Call the startup hooks: get_db_connections(), then get_graph_driver().

    Presumably these helpers create and cache the shared relational and graph
    connections — confirm in core.database / core.graph.

    NOTE(review): no ``@app.on_event("startup")`` decorator (or lifespan hook)
    is visible in this chunk; confirm this coroutine is actually registered.
    """
    # Same call order as before: relational connections first, then the graph driver.
    initializers = (get_db_connections, get_graph_driver)
    for initialize in initializers:
        initialize()  # return values are not used here
    logger.info("MCP server started and database connections initialized.")
def shutdown_event():
    """Close the relational connections and the graph driver on shutdown.

    ``close_graph_driver()`` runs in a ``finally`` so that an error while
    closing the database connections no longer leaks the graph driver. That
    error still propagates afterwards, in which case the completion log line
    is skipped.

    NOTE(review): no ``@app.on_event("shutdown")`` decorator (or lifespan hook)
    is visible in this chunk; confirm this function is actually registered.
    """
    try:
        close_db_connections()
    finally:
        close_graph_driver()
    logger.info("MCP server shutting down and database connections closed.")
# --- API Endpoints ---
def health_check():
    """Return the fixed payload ``{"status": "ok"}``.

    It does not touch the database or graph connections, so it only shows that
    the process is running request handlers.

    NOTE(review): no route decorator is visible in this chunk; confirm this
    handler is registered on ``app``.
    """
    status_payload = dict(status="ok")
    return status_payload
async def discover_schemas(request: SchemaQuery):
    """Return the schemas ``get_relevant_schemas`` finds for a free-text query.

    NOTE(review): no route decorator (and no ``Depends(verify_api_key)``) is
    visible in this chunk; confirm this handler is registered and protected.

    Args:
        request: Body whose ``query`` string is passed to ``get_relevant_schemas``.

    Returns:
        ``{"status": "success", "schemas": <awaited get_relevant_schemas result>}``.

    Raises:
        HTTPException: an ``HTTPException`` raised by the helper is re-raised
            unchanged; any other exception is logged and turned into a 500
            whose ``detail`` is the exception text.
    """
    try:
        schemas = await get_relevant_schemas(request.query)
    except HTTPException:
        # Keep deliberate HTTP errors (their status and detail) instead of
        # masking them as a generic 500.
        raise
    except Exception as e:
        logger.exception("Schema discovery failed: %s", e)
        # NOTE(review): str(e) is returned to the client and may expose
        # internal details; consider a generic message if clients don't need it.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"status": "success", "schemas": schemas}
async def get_join_path(request: JoinPathRequest):
    """Return the join path ``find_join_path`` computes between two tables.

    NOTE(review): no route decorator (and no ``Depends(verify_api_key)``) is
    visible in this chunk; confirm this handler is registered and protected.

    Args:
        request: Body whose ``table1`` and ``table2`` are passed to ``find_join_path``.

    Returns:
        ``{"status": "success", "path": <find_join_path result>}``.

    Raises:
        HTTPException: an ``HTTPException`` raised by the helper is re-raised
            unchanged; any other exception is logged and turned into a 500
            whose ``detail`` is the exception text.
    """
    try:
        # find_join_path is called without ``await`` in this module, i.e. it is
        # a synchronous function. Calling it directly inside ``async def`` would
        # block the event loop for its whole duration; the threadpool keeps any
        # blocking work it may do (presumably graph queries) off the loop.
        path = await run_in_threadpool(find_join_path, request.table1, request.table2)
    except HTTPException:
        # Keep deliberate HTTP errors instead of masking them as a generic 500.
        raise
    except Exception as e:
        logger.exception("Join path finding failed: %s", e)
        # NOTE(review): str(e) is returned to the client and may expose internal details.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"status": "success", "path": path}
async def execute_query(request: SQLQuery):
    """Run caller-supplied SQL through ``execute_federated_query`` and return the results.

    NOTE(review): ``request.sql`` is forwarded verbatim. Presumably the helper
    executes it against the backing databases — confirm which statements it
    permits (read-only?) and that this route is guarded by ``verify_api_key``.
    No route decorator is visible in this chunk.

    Args:
        request: Body whose ``sql`` text is passed to ``execute_federated_query``.

    Returns:
        ``{"status": "success", "results": <awaited execute_federated_query result>}``.

    Raises:
        HTTPException: an ``HTTPException`` raised by the helper is re-raised
            unchanged; any other exception is logged and turned into a 500
            whose ``detail`` is the exception text.
    """
    try:
        results = await execute_federated_query(request.sql)
    except HTTPException:
        # Keep deliberate HTTP errors instead of masking them as a generic 500.
        raise
    except Exception as e:
        logger.exception("Query execution failed: %s", e)
        # NOTE(review): str(e) is returned to the client; useful for fixing SQL,
        # but it may also expose internal details.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"status": "success", "results": results}