File size: 7,133 Bytes
e44e5dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b86e5c9
e44e5dd
 
b86e5c9
e44e5dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
from __future__ import annotations

import inspect
import logging
import os
from contextlib import asynccontextmanager
from typing import Awaitable, Callable, Dict, Optional, Union

from fastapi import FastAPI, Query
import uvicorn

from backend.mcp_server.admin.rules import admin_add_rule, admin_delete_rule, admin_get_rules
from backend.mcp_server.admin.violations import log_violation as admin_log_violation
from backend.mcp_server.rag.delete import rag_delete
from backend.mcp_server.rag.ingest import rag_ingest
from backend.mcp_server.rag.list import rag_list
from backend.mcp_server.rag.search import rag_search
from backend.mcp_server.web.search import web_search

# Signature of a tool handler: takes the JSON payload dict and returns a result
# dict, either directly or via an awaitable. ``Union`` is used instead of
# ``X | Y`` because this alias is evaluated at runtime, where
# ``from __future__ import annotations`` does not apply and ``|`` between
# typing generics requires Python 3.10+.
ToolHandler = Callable[[Dict], Union[Awaitable[Dict], Dict]]

logger = logging.getLogger("integrachat.mcp.server")
# Attach a handler only once. getLogger() returns the same object for the same
# name, so re-importing this module (e.g. running it as a script, after which
# uvicorn imports it again from the "backend.mcp_server.server:app" string in
# main()) must not stack duplicate handlers and double every log line.
if not logger.handlers:
    handler = logging.StreamHandler()
    formatter = logging.Formatter(
        "[%(asctime)s] %(levelname)s %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    handler.setFormatter(formatter)
    logger.addHandler(handler)
logger.setLevel(logging.INFO)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for startup and shutdown events."""
    # Startup
    try:
        routes = []
        for route in app.routes:
            if hasattr(route, "path") and hasattr(route, "methods"):
                routes.append(f"{', '.join(route.methods)} {route.path}")
        logger.info("Registered routes: %s", ", ".join(sorted(routes)))
    except Exception as e:
        logger.warning("Could not log routes during startup: %s", e)
    yield
    # Shutdown (if needed in the future)


# Application instance; ``lifespan`` runs its startup hook (route logging)
# when the server starts.
app = FastAPI(
    title="IntegraChat MCP",
    version="1.0.0",
    lifespan=lifespan,
)


def _register_tool(tool_name: str, handler: ToolHandler) -> None:
    """
    Register the given tool handler under both a namespaced route
    (/rag/search) and an optional root route (/search) so the server works
    whether clients point to /rag or directly to the namespace port.
    """
    namespace, action = tool_name.split(".", 1)
    namespaced_path = f"/{namespace}/{action}"
    root_path = f"/{action}"

    @app.post(namespaced_path)
    async def namespaced_endpoint(payload: Dict) -> Dict:
        return await handler(payload)  # type: ignore[arg-type]

    @app.post(root_path)
    async def root_endpoint(payload: Dict) -> Dict:
        return await handler(payload)  # type: ignore[arg-type]


# GET variants of the list tool. POST /rag/list and POST /list are added later by
# the _register_tool calls; routing matches on HTTP method as well as path, so
# the GET and POST routes on the same path do not shadow each other.
async def _list_documents(
    route: str,
    tenant_id: str,
    limit: Optional[int],
    offset: Optional[int],
) -> Dict:
    """Shared body of the GET list endpoints.

    Logs the call under ``route``, forwards the query parameters to
    ``rag_list`` and returns its result unchanged. Any exception is logged
    with its traceback (matching the DELETE endpoints) and re-raised.
    """
    logger.info(
        "GET %s called with tenant_id=%s, limit=%s, offset=%s",
        route, tenant_id, limit, offset,
    )
    payload = {"tenant_id": tenant_id, "limit": limit, "offset": offset}
    try:
        return await rag_list(payload)  # type: ignore[arg-type]
    except Exception as e:
        logger.error("Error in GET %s: %s", route, e, exc_info=True)
        raise


@app.get("/rag/list")
async def rag_list_get(
    tenant_id: str = Query(..., description="Tenant ID"),
    limit: Optional[int] = Query(1000, description="Maximum number of documents to return"),
    offset: Optional[int] = Query(0, description="Number of documents to skip")
) -> Dict:
    """GET endpoint for listing RAG documents."""
    return await _list_documents("/rag/list", tenant_id, limit, offset)


@app.get("/list")
async def rag_list_get_root(
    tenant_id: str = Query(..., description="Tenant ID"),
    limit: Optional[int] = Query(1000, description="Maximum number of documents to return"),
    offset: Optional[int] = Query(0, description="Number of documents to skip")
) -> Dict:
    """GET endpoint for listing RAG documents (root path)."""
    return await _list_documents("/list", tenant_id, limit, offset)

# DELETE routes for removing a single document (namespaced and root paths).
async def _delete_document(route_prefix: str, document_id: int, tenant_id: str) -> Dict:
    """Shared body of the per-document DELETE endpoints.

    Logs the call, the result, and any failure under
    ``<route_prefix>/<document_id>``, so both routes produce the same
    diagnostics. Returns ``rag_delete``'s result unchanged; exceptions are
    logged with their traceback and re-raised.
    """
    logger.info("DELETE %s/%s called with tenant_id=%s", route_prefix, document_id, tenant_id)
    payload = {"tenant_id": tenant_id, "document_id": document_id}
    try:
        result = await rag_delete(payload)  # type: ignore[arg-type]
    except Exception as e:
        logger.error("Error in DELETE %s/%s: %s", route_prefix, document_id, e, exc_info=True)
        raise
    logger.info("DELETE %s/%s result: %s", route_prefix, document_id, result)
    return result


@app.delete("/rag/delete/{document_id}")
async def rag_delete_document(
    document_id: int,
    tenant_id: str = Query(..., description="Tenant ID")
) -> Dict:
    """DELETE endpoint for deleting a specific document."""
    return await _delete_document("/rag/delete", document_id, tenant_id)


@app.delete("/delete/{document_id}")
async def rag_delete_document_root(
    document_id: int,
    tenant_id: str = Query(..., description="Tenant ID")
) -> Dict:
    """DELETE endpoint for deleting a specific document (root path)."""
    return await _delete_document("/delete", document_id, tenant_id)

# DELETE routes that ask rag_delete to remove all documents for a tenant
# (namespaced and root paths).
def _delete_all_payload(tenant_id: str) -> Dict:
    """Build the ``rag_delete`` payload with ``delete_all`` set for ``tenant_id``."""
    return dict(tenant_id=tenant_id, delete_all=True)


@app.delete("/rag/delete-all")
async def rag_delete_all(
    tenant_id: str = Query(..., description="Tenant ID")
) -> Dict:
    """DELETE endpoint for deleting all documents."""
    try:
        logger.info("DELETE /rag/delete-all called with tenant_id=%s", tenant_id)
        return await rag_delete(_delete_all_payload(tenant_id))  # type: ignore[arg-type]
    except Exception as exc:
        logger.error("Error in DELETE /rag/delete-all: %s", exc, exc_info=True)
        raise


@app.delete("/delete-all")
async def rag_delete_all_root(
    tenant_id: str = Query(..., description="Tenant ID")
) -> Dict:
    """DELETE endpoint for deleting all documents (root path)."""
    try:
        logger.info("DELETE /delete-all called with tenant_id=%s", tenant_id)
        return await rag_delete(_delete_all_payload(tenant_id))  # type: ignore[arg-type]
    except Exception as exc:
        logger.error("Error in DELETE /delete-all: %s", exc, exc_info=True)
        raise

# Tool table: each entry becomes POST /<namespace>/<action> (plus a root alias).
# Registration order matters when two tools share an action name.
for _tool_name, _tool_handler in (
    ("rag.search", rag_search),
    ("rag.ingest", rag_ingest),
    ("rag.delete", rag_delete),
    ("rag.list", rag_list),
    ("web.search", web_search),
    ("admin.getRules", admin_get_rules),
    ("admin.addRule", admin_add_rule),
    ("admin.deleteRule", admin_delete_rule),
    ("admin.logViolation", admin_log_violation),
):
    _register_tool(_tool_name, _tool_handler)
del _tool_name, _tool_handler


@app.get("/health")
async def health() -> Dict[str, str]:
    # Simple health check: constant payload, no dependency checks.
    status_payload = dict(status="ok", service="mcp")
    return status_payload


def main():
    """Run the MCP HTTP server under uvicorn.

    The bind address comes from ``MCP_HOST`` (default ``0.0.0.0``) and
    ``MCP_PORT`` (default ``8001``). The app is passed to uvicorn as an
    import string, so uvicorn imports ``backend.mcp_server.server`` itself.
    """
    bind_host = os.getenv("MCP_HOST", "0.0.0.0")
    bind_port = int(os.getenv("MCP_PORT", "8001"))
    logger.info("Starting IntegraChat MCP HTTP server on %s:%s", bind_host, bind_port)
    uvicorn.run("backend.mcp_server.server:app", host=bind_host, port=bind_port)


# Script entry point.
if __name__ == "__main__":
    main()