"""MCP Server for RAG system.

Wraps the project's ingestion (``DocumentProcessor``), indexing
(``IndexManager``), retrieval (``RAGComparator``) and evaluation
(``RAGEvaluator``) components behind a small FastAPI HTTP API.

Typical flow: ``POST /initialize`` -> ``POST /index`` -> ``POST /query`` /
``POST /evaluate``.

Endpoints that call into embedding models, the vector store or the LLM are
declared with plain ``def`` rather than ``async def``: FastAPI runs plain
``def`` endpoints in its threadpool, so long-running blocking work does not
stall the event loop (and e.g. ``/health`` stays responsive while indexing).
"""

import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn

from core.ingest import DocumentProcessor
from core.index import IndexManager
from core.retrieval import RAGComparator
from core.eval import RAGEvaluator

# Initialize FastAPI app
app = FastAPI(title="Hierarchical RAG MCP Server", version="1.0.0")

# File extensions accepted by /upload.
VALID_EXTENSIONS = frozenset({".pdf", ".txt"})

# Minimum number of results retrieved per query in /evaluate.
DEFAULT_EVAL_N_RESULTS = 5

# Accepted values of QueryRequest.pipeline (compared case-insensitively).
_PIPELINES = frozenset({"base", "hier", "both"})

# Global state; populated by /initialize (index_manager, evaluator) and
# /index (rag_comparator).
index_manager: Optional[IndexManager] = None
rag_comparator: Optional[RAGComparator] = None
evaluator: Optional[RAGEvaluator] = None


# Request/Response Models

class InitRequest(BaseModel):
    """Parameters for (re)initializing the index manager and evaluator."""

    persist_directory: str = "./data/chroma"
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"


class UploadRequest(BaseModel):
    """Files to validate before indexing."""

    filepaths: List[str]
    hierarchy: str
    mask_pii: bool = False


class IndexRequest(BaseModel):
    """Files and chunking options for building an index collection."""

    filepaths: List[str]
    hierarchy: str
    chunk_size: int = 512
    chunk_overlap: int = 50
    mask_pii: bool = False
    collection_name: str = "rag_documents"


class QueryRequest(BaseModel):
    """A query against the base pipeline, the hierarchical pipeline, or both."""

    query: str
    n_results: int = 5
    pipeline: str = "both"  # base, hier, or both (case-insensitive)
    # Optional hierarchy / document-type filters, forwarded to the
    # hierarchical pipeline; explicit JSON null is accepted.
    level1: Optional[str] = None
    level2: Optional[str] = None
    level3: Optional[str] = None
    doc_type: Optional[str] = None
    auto_infer: bool = True


class EvaluateRequest(BaseModel):
    """Queries paired one-to-one with their lists of relevant document ids."""

    queries: List[str]
    relevant_ids: List[List[str]]
    k_values: List[int] = [1, 3, 5]  # pydantic copies defaults per instance


# Endpoints

@app.post("/initialize")
def initialize(request: InitRequest) -> Dict[str, Any]:
    """Initialize (or re-initialize) the RAG system.

    Both components are constructed before any global is replaced, so a
    failure leaves the previous state untouched. On success the existing
    RAG comparator is discarded, because it was bound to a vector store of
    the previous index manager; call /index again before /query.

    Raises:
        HTTPException(500): if either component fails to construct.
    """
    global index_manager, rag_comparator, evaluator
    try:
        new_index_manager = IndexManager(
            persist_directory=request.persist_directory,
            embedding_model_name=request.embedding_model,
        )
        new_evaluator = RAGEvaluator(embedding_model_name=request.embedding_model)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    index_manager = new_index_manager
    evaluator = new_evaluator
    rag_comparator = None
    return {
        "status": "success",
        "message": "System initialized successfully",
    }


@app.post("/upload")
async def upload_documents(request: UploadRequest) -> Dict[str, Any]:
    """Split the submitted paths into supported and unsupported files.

    Only the file extension is checked (case-insensitively); files are not
    opened or read here.
    """
    try:
        valid_files = [
            fp for fp in request.filepaths
            if Path(fp).suffix.lower() in VALID_EXTENSIONS
        ]
        invalid_files = [
            fp for fp in request.filepaths
            if Path(fp).suffix.lower() not in VALID_EXTENSIONS
        ]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "status": "success",
        "total_uploaded": len(request.filepaths),
        "valid_files": valid_files,
        "invalid_files": invalid_files,
        "hierarchy": request.hierarchy,
    }


@app.post("/index")
def build_index(request: IndexRequest) -> Dict[str, Any]:
    """Chunk the given documents, index them, and build the RAG comparator.

    The LLM model and API key for the comparator come from the
    ``LLM_MODEL`` (default ``gpt-3.5-turbo``) and ``OPENAI_API_KEY``
    environment variables.

    Returns a ``{"status": "error", ...}`` body (HTTP 200) when no chunks
    could be extracted.

    Raises:
        HTTPException(400): if /initialize has not been called.
        HTTPException(500): on any processing/indexing failure.
    """
    global rag_comparator
    # Snapshot the global so a concurrent /initialize cannot swap it mid-call.
    manager = index_manager
    if manager is None:
        raise HTTPException(status_code=400, detail="System not initialized")

    try:
        processor = DocumentProcessor(
            hierarchy_name=request.hierarchy,
            chunk_size=request.chunk_size,
            chunk_overlap=request.chunk_overlap,
            mask_pii=request.mask_pii,
        )
        all_chunks = processor.process_documents(request.filepaths)
        if not all_chunks:
            return {
                "status": "error",
                "message": "No chunks extracted from documents",
            }

        stats = manager.index_documents(all_chunks, request.collection_name)

        vector_store = manager.get_store(request.collection_name)
        rag_comparator = RAGComparator(
            vector_store=vector_store,
            llm_model=os.getenv("LLM_MODEL", "gpt-3.5-turbo"),
            api_key=os.getenv("OPENAI_API_KEY"),
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "status": "success",
        "chunks_indexed": stats.get("chunks_added", 0),
        "collection": request.collection_name,
        "hierarchy": request.hierarchy,
    }


@app.post("/query")
def query_rag(request: QueryRequest) -> Dict[str, Any]:
    """Query the base pipeline, the hierarchical pipeline, or compare both.

    Raises:
        HTTPException(400): if /index has not been run, or ``pipeline`` is
            not one of base / hier / both.
        HTTPException(500): on any retrieval/generation failure.
    """
    comparator = rag_comparator
    if comparator is None:
        raise HTTPException(status_code=400, detail="RAG system not initialized")

    pipeline = request.pipeline.lower()
    if pipeline not in _PIPELINES:
        raise HTTPException(
            status_code=400,
            detail=(
                f"Unknown pipeline {request.pipeline!r}; "
                "expected one of: base, hier, both"
            ),
        )

    # Filters shared by the comparison and hierarchical pipelines.
    hier_kwargs = {
        "level1": request.level1,
        "level2": request.level2,
        "level3": request.level3,
        "doc_type": request.doc_type,
        "auto_infer": request.auto_infer,
    }

    try:
        if pipeline == "both":
            return comparator.compare(
                query=request.query, n_results=request.n_results, **hier_kwargs
            )
        if pipeline == "base":
            return comparator.base_rag.query(request.query, request.n_results)
        return comparator.hier_rag.query(
            query=request.query, n_results=request.n_results, **hier_kwargs
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/evaluate")
def evaluate_rag(request: EvaluateRequest) -> Dict[str, Any]:
    """Evaluate base vs. hierarchical retrieval on labelled queries.

    Each query is run through ``RAGComparator.compare`` and both pipelines'
    outputs are scored against that query's relevant ids at every k in
    ``k_values``. At least ``max(5, *k_values)`` results are retrieved so
    that metrics at large k are not computed on a truncated result list.

    Raises:
        HTTPException(400): if the system is not ready, or ``queries`` and
            ``relevant_ids`` differ in length.
        HTTPException(500): on any retrieval/evaluation failure.
    """
    comparator, rag_evaluator = rag_comparator, evaluator
    if comparator is None or rag_evaluator is None:
        raise HTTPException(status_code=400, detail="System not initialized")

    # zip() would silently drop the unmatched tail; reject instead.
    if len(request.queries) != len(request.relevant_ids):
        raise HTTPException(
            status_code=400,
            detail=(
                "queries and relevant_ids must have the same length "
                f"(got {len(request.queries)} and {len(request.relevant_ids)})"
            ),
        )

    n_results = max([DEFAULT_EVAL_N_RESULTS, *request.k_values])

    try:
        results = []
        for query, relevant_ids in zip(request.queries, request.relevant_ids):
            comparison = comparator.compare(query=query, n_results=n_results)

            base_eval = rag_evaluator.evaluate_rag_pipeline(
                comparison["base_rag"], relevant_ids, k_values=request.k_values
            )
            hier_eval = rag_evaluator.evaluate_rag_pipeline(
                comparison["hier_rag"], relevant_ids, k_values=request.k_values
            )

            results.append({
                "query": query,
                "base_rag": base_eval,
                "hier_rag": hier_eval,
                "speedup": comparison["speedup"],
            })
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "status": "success",
        "results": results,
    }


@app.get("/health")
async def health_check() -> Dict[str, str]:
    """Liveness probe; does not check component state."""
    return {"status": "healthy"}


@app.get("/info")
def system_info() -> Dict[str, Any]:
    """Report initialization state and, when available, the collections."""
    manager = index_manager
    info: Dict[str, Any] = {
        "initialized": manager is not None,
        "rag_ready": rag_comparator is not None,
    }
    if manager is not None:
        info["collections"] = manager.list_collections()
    return info


# Run server
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)