|
|
|
|
|
"""
|
|
|
Secure Multi-Tenant RAG MCP Server
|
|
|
"""
|
|
|
import sys
|
|
|
import os
|
|
|
import uuid
|
|
|
import chromadb
|
|
|
from chromadb.config import Settings
|
|
|
from chromadb.utils import embedding_functions
|
|
|
from mcp.server.fastmcp import FastMCP
|
|
|
from typing import List, Dict, Any, Optional
|
|
|
from core.mcp_telemetry import log_usage, log_trace, log_metric
|
|
|
import time
|
|
|
|
|
|
|
|
|
# MCP server instance; the tool functions below are registered on it.
mcp = FastMCP("Secure RAG", host="0.0.0.0")

# The Chroma store is persisted in a "chroma_db" folder located next to
# this source file.
current_dir = os.path.dirname(os.path.abspath(__file__))
persist_directory = os.path.join(current_dir, "chroma_db")

# Persistent client: data written to the collection is kept on disk under
# persist_directory.
client = chromadb.PersistentClient(path=persist_directory)
|
|
|
|
|
|
|
|
|
|
|
|
# Select the embedding function: use a SentenceTransformer model when the
# import succeeds, otherwise fall back to Chroma's default.
try:
    from sentence_transformers import SentenceTransformer

    class SentenceTransformerEmbeddingFunction(embedding_functions.EmbeddingFunction):
        """Embedding function that encodes text with a SentenceTransformer model."""

        def __init__(self, model_name="all-MiniLM-L6-v2"):
            # The model is loaded once, when the function object is created.
            self.model = SentenceTransformer(model_name)

        def __call__(self, input: List[str]) -> List[List[float]]:
            """Encode the given texts and return the embeddings as nested lists."""
            vectors = self.model.encode(input)
            return vectors.tolist()

    emb_fn = SentenceTransformerEmbeddingFunction()
except ImportError:
    # Import failed (e.g. sentence_transformers is not installed).
    emb_fn = embedding_functions.DefaultEmbeddingFunction()
|
|
|
|
|
|
|
|
|
# Single shared collection for every tenant; documents are told apart by
# the "tenant_id" value the tools below store in, and filter on, metadata.
collection = client.get_or_create_collection(name="secure_rag", embedding_function=emb_fn)
|
|
|
|
|
|
@mcp.tool()
def ingest_document(tenant_id: str, content: str, metadata: Optional[Dict[str, Any]] = None) -> str:
    """
    Ingest a document into the RAG system with strict tenant isolation.

    Args:
        tenant_id: Tenant that owns the document. It is stored in the
            document's metadata under the "tenant_id" key.
        content: Document text to add to the collection.
        metadata: Optional extra metadata for the document. A "tenant_id"
            entry supplied here is overwritten with the ``tenant_id``
            argument.

    Returns:
        A confirmation message containing the generated document ID.
    """
    log_usage("mcp-rag-secure", "ingest_document")

    # Work on a copy so the caller's dict is never mutated.
    doc_metadata = dict(metadata) if metadata else {}

    # Assigned after copying so a caller-supplied "tenant_id" value can
    # never replace the real tenant.
    doc_metadata["tenant_id"] = tenant_id

    doc_id = str(uuid.uuid4())
    collection.add(
        documents=[content],
        metadatas=[doc_metadata],
        ids=[doc_id],
    )
    return f"Document ingested with ID: {doc_id}"
|
|
|
|
|
|
@mcp.tool()
def query_knowledge_base(tenant_id: str, query: str, k: int = 3) -> List[Dict[str, Any]]:
    """
    Query the knowledge base. Results are strictly filtered by tenant_id.

    Args:
        tenant_id: Only documents whose metadata "tenant_id" equals this
            value are searched.
        query: Query text passed to the collection.
        k: Number of results requested from the collection.

    Returns:
        One dict per returned document with the keys "content",
        "metadata" and "score". "score" is the value taken from the
        query's ``distances`` and is None when the query returned no
        distances; "metadata" is None when the query returned no metadatas.

    Raises:
        Exception: Any error raised while querying or formatting is
            re-raised unchanged after an "error" trace has been logged.
    """
    start_time = time.time()
    trace_id = str(uuid.uuid4())
    span_id = str(uuid.uuid4())
    log_usage("mcp-rag-secure", "query_knowledge_base")

    try:
        results = collection.query(
            query_texts=[query],
            n_results=k,
            where={"tenant_id": tenant_id},
        )

        formatted_results = []
        if results["documents"]:
            # A single query text was sent, so only the first result row
            # is relevant.
            documents = results["documents"][0]
            filler = [None] * len(documents)
            # Metadatas and distances may be absent from the result;
            # substitute None for each document in that case.
            metadatas = results["metadatas"][0] if results.get("metadatas") else filler
            distances = results["distances"][0] if results.get("distances") else filler

            formatted_results = [
                {"content": doc, "metadata": meta, "score": distance}
                for doc, meta, distance in zip(documents, metadatas, distances)
            ]

        duration = (time.time() - start_time) * 1000
        log_trace("mcp-rag-secure", trace_id, span_id, "query_knowledge_base", duration, "ok")
        return formatted_results
    except Exception:
        duration = (time.time() - start_time) * 1000
        log_trace("mcp-rag-secure", trace_id, span_id, "query_knowledge_base", duration, "error")
        # Bare raise re-raises the active exception with its traceback intact.
        raise
|
|
|
|
|
|
@mcp.tool()
def delete_tenant_data(tenant_id: str) -> str:
    """
    Delete all data associated with a specific tenant.

    Args:
        tenant_id: Documents whose metadata "tenant_id" equals this value
            are removed from the collection.

    Returns:
        A confirmation message naming the tenant.
    """
    # Record usage the same way the other tools in this module do.
    log_usage("mcp-rag-secure", "delete_tenant_data")

    collection.delete(where={"tenant_id": tenant_id})
    return f"All data for tenant {tenant_id} has been deleted."
|
|
|
|
|
|
if __name__ == "__main__":
    # The transport is selected by the MCP_TRANSPORT environment variable:
    # "sse" serves the SSE app through uvicorn, anything else uses mcp.run().
    use_sse = os.environ.get("MCP_TRANSPORT") == "sse"

    if not use_sse:
        mcp.run()
    else:
        # Imported only on the SSE path.
        import uvicorn

        # PORT overrides the listening port; 7860 is the fallback.
        port = int(os.environ.get("PORT", 7860))
        uvicorn.run(mcp.sse_app(), host="0.0.0.0", port=port)
|
|
|
|