""" Secure Multi-Tenant RAG MCP Server """ import sys import os import uuid import chromadb from chromadb.config import Settings from chromadb.utils import embedding_functions from mcp.server.fastmcp import FastMCP from typing import List, Dict, Any, Optional from core.mcp_telemetry import log_usage, log_trace, log_metric import time # Initialize FastMCP Server mcp = FastMCP("Secure RAG", host="0.0.0.0") # Initialize ChromaDB (Persistent) # Store in src/mcp-rag-secure/chroma_db current_dir = os.path.dirname(os.path.abspath(__file__)) persist_directory = os.path.join(current_dir, "chroma_db") client = chromadb.PersistentClient(path=persist_directory) # Use default embedding function (all-MiniLM-L6-v2 usually) # Explicitly use SentenceTransformer if installed, else default try: from sentence_transformers import SentenceTransformer # Custom embedding function wrapper class SentenceTransformerEmbeddingFunction(embedding_functions.EmbeddingFunction): def __init__(self, model_name="all-MiniLM-L6-v2"): self.model = SentenceTransformer(model_name) def __call__(self, input: List[str]) -> List[List[float]]: return self.model.encode(input).tolist() emb_fn = SentenceTransformerEmbeddingFunction() except ImportError: emb_fn = embedding_functions.DefaultEmbeddingFunction() # Create collection collection = client.get_or_create_collection( name="secure_rag", embedding_function=emb_fn ) @mcp.tool() def ingest_document(tenant_id: str, content: str, metadata: Dict[str, Any] = None) -> str: """ Ingest a document into the RAG system with strict tenant isolation. """ log_usage("mcp-rag-secure", "ingest_document") if not metadata: metadata = {} # Enforce tenant_id in metadata metadata["tenant_id"] = tenant_id doc_id = str(uuid.uuid4()) collection.add( documents=[content], metadatas=[metadata], ids=[doc_id] ) return f"Document ingested with ID: {doc_id}" @mcp.tool() def query_knowledge_base(tenant_id: str, query: str, k: int = 3) -> List[Dict[str, Any]]: """ Query the knowledge base. Results are strictly filtered by tenant_id. """ start_time = time.time() trace_id = str(uuid.uuid4()) span_id = str(uuid.uuid4()) log_usage("mcp-rag-secure", "query_knowledge_base") try: results = collection.query( query_texts=[query], n_results=k, where={"tenant_id": tenant_id} # Critical security filter ) formatted_results = [] if results["documents"]: for i, doc in enumerate(results["documents"][0]): meta = results["metadatas"][0][i] formatted_results.append({ "content": doc, "metadata": meta, "score": results["distances"][0][i] if results["distances"] else None }) duration = (time.time() - start_time) * 1000 log_trace("mcp-rag-secure", trace_id, span_id, "query_knowledge_base", duration, "ok") return formatted_results except Exception as e: duration = (time.time() - start_time) * 1000 log_trace("mcp-rag-secure", trace_id, span_id, "query_knowledge_base", duration, "error") raise e @mcp.tool() def delete_tenant_data(tenant_id: str) -> str: """ Delete all data associated with a specific tenant. """ collection.delete( where={"tenant_id": tenant_id} ) return f"All data for tenant {tenant_id} has been deleted." if __name__ == "__main__": import os if os.environ.get("MCP_TRANSPORT") == "sse": import uvicorn port = int(os.environ.get("PORT", 7860)) uvicorn.run(mcp.sse_app(), host="0.0.0.0", port=port) else: mcp.run()