File size: 3,054 Bytes
9d411a7
 
9930ba9
9d411a7
 
9930ba9
9d411a7
 
 
 
 
 
9930ba9
9d411a7
 
 
9930ba9
9d411a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9930ba9
9d411a7
 
9930ba9
9d411a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84473fd
9d411a7
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import hmac
import logging
import os
from typing import List, Dict, Any

from fastapi import FastAPI, Header, HTTPException, Depends
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel

# --- Core Logic Imports ---
# These imports assume your project structure places the core logic correctly.
from core.database import get_db_connections, close_db_connections
from core.discovery import get_relevant_schemas
from core.graph import find_join_path, get_graph_driver, close_graph_driver
from core.intelligence import execute_federated_query

# --- App Configuration ---
# Configure root logging at INFO; this is a no-op if the root logger already
# has handlers (e.g. installed by the ASGI server).
logging.basicConfig(level=logging.INFO)
# Module-scoped logger used by the lifecycle hooks and endpoints below.
logger = logging.getLogger(name=__name__)

# The ASGI application; every route and lifecycle hook below registers on it.
app = FastAPI(version="2.0", title="MCP Server")

# Accepted API keys, read once at import time from the comma-separated
# MCP_API_KEYS environment variable. Whitespace around each key is stripped
# and empty entries are dropped, so "k1, k2," yields ["k1", "k2"]. An empty
# or trailing-comma MCP_API_KEYS can therefore never make an empty key valid.
# NOTE(review): the "dev-key-123" fallback applies whenever MCP_API_KEYS is
# unset; make sure it is always set outside development.
VALID_API_KEYS = [
    key.strip()
    for key in os.getenv("MCP_API_KEYS", "dev-key-123").split(",")
    if key.strip()
]

# --- Pydantic Models ---
class ToolRequest(BaseModel):
    # Request body pairing a tool name with an arbitrary parameter mapping.
    # NOTE(review): no endpoint in this file accepts this model; confirm it is
    # used elsewhere before relying on or removing it.
    tool: str  # name of the tool to run (semantics not visible here — verify)
    params: Dict[str, Any]  # free-form parameters; no schema is enforced

class SchemaQuery(BaseModel):
    # Request body for POST /mcp/discovery/get_relevant_schemas.
    query: str  # passed unchanged to core.discovery.get_relevant_schemas

class JoinPathRequest(BaseModel):
    # Request body for POST /mcp/graph/find_join_path.
    table1: str  # first argument to core.graph.find_join_path
    table2: str  # second argument to core.graph.find_join_path

class SQLQuery(BaseModel):
    # Request body for POST /mcp/intelligence/execute_query.
    # NOTE(review): the SQL text is not validated here; it is handed as-is to
    # core.intelligence.execute_federated_query.
    sql: str

# --- Dependency for Auth ---
async def verify_api_key(x_api_key: str = Header(...)):
    """Reject the request unless it carries a valid API key.

    Used as a FastAPI dependency. The ``x_api_key`` parameter is a required
    request header (FastAPI maps it to ``x-api-key`` by default) and is
    checked against ``VALID_API_KEYS``.

    Returns:
        The accepted key, so endpoints that take this dependency as a
        parameter can read it.

    Raises:
        HTTPException: 401 if the key does not match any non-empty
            configured key.
    """
    candidate = x_api_key.encode("utf-8")
    matched = False
    # Check every configured key in constant time and without an early exit,
    # so response timing does not reveal how close a guessed key was.
    # Comparing bytes avoids compare_digest's TypeError on non-ASCII str, and
    # skipping empty entries means an empty header can never be accepted.
    for valid_key in VALID_API_KEYS:
        if valid_key and hmac.compare_digest(candidate, valid_key.encode("utf-8")):
            matched = True
    if not matched:
        raise HTTPException(status_code=401, detail="Invalid API Key")
    return x_api_key

# --- Event Handlers ---
@app.on_event("startup")
async def startup_event():
    """Open backend connections when the server starts.

    Calls ``get_db_connections()`` and then ``get_graph_driver()``, in that
    order, discarding their return values, and logs once both have returned.
    """
    # NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in
    # favour of a ``lifespan`` handler; consider migrating.
    for initialise in (get_db_connections, get_graph_driver):
        initialise()
    logger.info("MCP server started and database connections initialized.")

@app.on_event("shutdown")
def shutdown_event():
    """Release backend connections when the server shuts down.

    Calls ``close_db_connections()`` and then ``close_graph_driver()``. The
    graph driver is closed even if closing the database connections raises.
    In that case the error still propagates and the final log line is
    skipped.
    """
    try:
        close_db_connections()
    finally:
        # Without the finally, an exception from the first close would leave
        # the graph driver open.
        close_graph_driver()
    logger.info("MCP server shutting down and database connections closed.")

# --- API Endpoints ---
@app.get("/health")
def health_check():
    """Liveness probe that always reports ``{"status": "ok"}``.

    Takes no API-key dependency and does not touch the database or graph
    backends.
    """
    payload = {"status": "ok"}
    return payload

@app.post("/mcp/discovery/get_relevant_schemas", dependencies=[Depends(verify_api_key)])
async def discover_schemas(request: SchemaQuery):
    """Return the schemas ``get_relevant_schemas`` selects for a text query.

    Requires a valid API key (``verify_api_key`` dependency).

    Returns:
        ``{"status": "success", "schemas": ...}``, where ``schemas`` is
        whatever ``get_relevant_schemas(request.query)`` produced.

    Raises:
        HTTPException: re-raised unchanged if the discovery layer raises one.
            Any other failure is logged with its traceback and becomes a 500
            whose ``detail`` is the error message.
    """
    try:
        schemas = await get_relevant_schemas(request.query)
    except HTTPException:
        # Already carries an intentional status code; don't mask it as a 500.
        raise
    except Exception as e:
        logger.exception("Schema discovery failed: %s", e)
        # NOTE(review): str(e) exposes internal error text to the client.
        # This is kept deliberately, since callers may rely on it to correct
        # their requests.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"status": "success", "schemas": schemas}

@app.post("/mcp/graph/find_join_path", dependencies=[Depends(verify_api_key)])
async def get_join_path(request: JoinPathRequest):
    """Return the join path ``find_join_path`` computes between two tables.

    Requires a valid API key (``verify_api_key`` dependency).

    Returns:
        ``{"status": "success", "path": ...}``, where ``path`` is whatever
        ``find_join_path(request.table1, request.table2)`` returned.

    Raises:
        HTTPException: re-raised unchanged if the graph layer raises one.
            Any other failure is logged with its traceback and becomes a 500
            whose ``detail`` is the error message.
    """
    try:
        # find_join_path is a plain synchronous call (it returns its result
        # directly, not a coroutine). Running it in FastAPI's threadpool keeps
        # the event loop responsive if it blocks, presumably on graph-database
        # I/O; verify in core.graph.
        path = await run_in_threadpool(find_join_path, request.table1, request.table2)
    except HTTPException:
        # Already carries an intentional status code; don't mask it as a 500.
        raise
    except Exception as e:
        logger.exception("Join path finding failed: %s", e)
        # NOTE(review): str(e) exposes internal error text to the client.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"status": "success", "path": path}

@app.post("/mcp/intelligence/execute_query", dependencies=[Depends(verify_api_key)])
async def execute_query(request: SQLQuery):
    """Run client-supplied SQL through ``execute_federated_query``.

    Requires a valid API key (``verify_api_key`` dependency).

    Returns:
        ``{"status": "success", "results": ...}``, where ``results`` is
        whatever ``execute_federated_query(request.sql)`` produced.

    Raises:
        HTTPException: re-raised unchanged if the query layer raises one.
            Any other failure is logged with its traceback and becomes a 500
            whose ``detail`` is the error message.
    """
    # SECURITY NOTE(review): any API-key holder can submit arbitrary SQL, and
    # nothing here restricts it. Confirm that the federated executor enforces
    # read-only access or other limits.
    try:
        results = await execute_federated_query(request.sql)
    except HTTPException:
        # Already carries an intentional status code; don't mask it as a 500.
        raise
    except Exception as e:
        logger.exception("Query execution failed: %s", e)
        # str(e) is returned to the client so it can see why its SQL failed.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"status": "success", "results": results}