"""MCP Server for RAG system."""
import json
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn

from core.ingest import DocumentProcessor
from core.index import IndexManager
from core.retrieval import RAGComparator
from core.eval import RAGEvaluator
# Initialize FastAPI app
app = FastAPI(
    title="Hierarchical RAG MCP Server",
    version="1.0.0",
)

# Global state: each handle stays None until the matching endpoint sets it.
index_manager = None   # assigned by initialize()
rag_comparator = None  # assigned by build_index()
evaluator = None       # assigned by initialize()
# Request/Response Models
class InitRequest(BaseModel):
    """Request body for ``initialize``."""

    # Passed to IndexManager as ``persist_directory``.
    persist_directory: str = "./data/chroma"
    # Passed to IndexManager and RAGEvaluator as ``embedding_model_name``.
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
class UploadRequest(BaseModel):
    """Request body for ``upload_documents``."""

    # Paths whose extensions are checked by upload_documents.
    filepaths: List[str]
    # Hierarchy name; upload_documents echoes it back in the response.
    hierarchy: str
    # Accepted for parity with IndexRequest; upload_documents does not read it.
    mask_pii: bool = False
class IndexRequest(BaseModel):
    """Request body for ``build_index``."""

    # Documents handed to DocumentProcessor.process_documents.
    filepaths: List[str]
    # Passed to DocumentProcessor as ``hierarchy_name``.
    hierarchy: str
    # Chunking parameters forwarded unchanged to DocumentProcessor.
    chunk_size: int = 512
    chunk_overlap: int = 50
    # Forwarded to DocumentProcessor as ``mask_pii``.
    mask_pii: bool = False
    # Collection the chunks are indexed into and later queried from.
    collection_name: str = "rag_documents"
class QueryRequest(BaseModel):
    """Request body for ``query_rag``.

    The hierarchy filters and ``doc_type`` are declared ``Optional[str]`` so
    that a client may send an explicit ``null`` as well as omit the field;
    a plain ``str`` annotation with a ``None`` default rejects ``null`` under
    pydantic v2.
    """

    # Text of the question to run against the index.
    query: str
    # Number of results requested from the pipeline(s).
    n_results: int = 5
    pipeline: str = "both"  # base, hier, or both
    # Optional hierarchy filters forwarded to the comparator / hierarchical pipeline.
    level1: Optional[str] = None
    level2: Optional[str] = None
    level3: Optional[str] = None
    # Optional document-type filter, forwarded the same way.
    doc_type: Optional[str] = None
    # Forwarded as ``auto_infer`` to the comparator / hierarchical pipeline.
    auto_infer: bool = True
class EvaluateRequest(BaseModel):
    """Request body for ``evaluate_rag``."""

    # Queries to evaluate, paired positionally with ``relevant_ids``.
    queries: List[str]
    # For each query, the ids considered relevant.
    relevant_ids: List[List[str]]
    # Passed to RAGEvaluator.evaluate_rag_pipeline as ``k_values``.
    k_values: List[int] = [1, 3, 5]
# Endpoints
@app.post("/initialize")
async def initialize(request: InitRequest) -> Dict[str, Any]:
    """Initialize the RAG system.

    Creates the module-level ``index_manager`` and ``evaluator`` from the
    settings in ``request``.

    Args:
        request: Persist directory and embedding model name.

    Returns:
        A dict with ``status`` and ``message`` keys on success.

    Raises:
        HTTPException: 500 if constructing either component fails; the
            original exception is chained as the cause.
    """
    global index_manager, evaluator
    try:
        index_manager = IndexManager(
            persist_directory=request.persist_directory,
            embedding_model_name=request.embedding_model,
        )
        evaluator = RAGEvaluator(embedding_model_name=request.embedding_model)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {
        "status": "success",
        "message": "System initialized successfully",
    }
@app.post("/upload")
async def upload_documents(request: UploadRequest) -> Dict[str, Any]:
    """Validate uploaded documents.

    Splits ``request.filepaths`` by file extension (case-insensitive): paths
    ending in ``.pdf`` or ``.txt`` are reported as valid, everything else as
    invalid. Only the path strings are inspected; the files are not opened.

    Args:
        request: File paths and the hierarchy name to echo back.

    Returns:
        A dict with ``status``, ``total_uploaded``, ``valid_files``,
        ``invalid_files`` and ``hierarchy`` keys.

    Raises:
        HTTPException: 500 if classification fails unexpectedly.
    """
    from pathlib import Path

    valid_extensions = {'.pdf', '.txt'}
    try:
        valid_files = [
            fp for fp in request.filepaths
            if Path(fp).suffix.lower() in valid_extensions
        ]
        invalid_files = [
            fp for fp in request.filepaths
            if Path(fp).suffix.lower() not in valid_extensions
        ]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {
        "status": "success",
        "total_uploaded": len(request.filepaths),
        "valid_files": valid_files,
        "invalid_files": invalid_files,
        "hierarchy": request.hierarchy,
    }
@app.post("/index")
async def build_index(request: IndexRequest) -> Dict[str, Any]:
    """Build RAG index from documents.

    Chunks the requested files with ``DocumentProcessor``, indexes the chunks
    into ``request.collection_name`` and then creates the module-level
    ``rag_comparator`` over that collection's store. The LLM model name and
    API key are read from the ``LLM_MODEL`` (default ``gpt-3.5-turbo``) and
    ``OPENAI_API_KEY`` environment variables.

    Args:
        request: Files, hierarchy name, chunking options and collection name.

    Returns:
        On success, a dict with ``status``, ``chunks_indexed``, ``collection``
        and ``hierarchy``. If no chunks were produced, a dict with
        ``status == "error"`` and a ``message`` (still an HTTP 200).

    Raises:
        HTTPException: 400 if ``initialize`` has not been called; 500 if
            processing or indexing fails.
    """
    global index_manager, rag_comparator
    # Identity check: an initialized manager must never be mistaken for
    # "missing" just because it happens to evaluate as falsy.
    if index_manager is None:
        raise HTTPException(status_code=400, detail="System not initialized")
    try:
        # Process documents
        processor = DocumentProcessor(
            hierarchy_name=request.hierarchy,
            chunk_size=request.chunk_size,
            chunk_overlap=request.chunk_overlap,
            mask_pii=request.mask_pii,
        )
        all_chunks = processor.process_documents(request.filepaths)
        if not all_chunks:
            return {
                "status": "error",
                "message": "No chunks extracted from documents",
            }
        # Index documents
        stats = index_manager.index_documents(all_chunks, request.collection_name)
        # Initialize RAG comparator
        vector_store = index_manager.get_store(request.collection_name)
        import os
        rag_comparator = RAGComparator(
            vector_store=vector_store,
            llm_model=os.getenv("LLM_MODEL", "gpt-3.5-turbo"),
            api_key=os.getenv("OPENAI_API_KEY"),
        )
        return {
            "status": "success",
            "chunks_indexed": stats.get("chunks_added", 0),
            "collection": request.collection_name,
            "hierarchy": request.hierarchy,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/query")
async def query_rag(request: QueryRequest) -> Dict[str, Any]:
    """Query the RAG system.

    Dispatches on ``request.pipeline`` (case-insensitive):

    * ``"both"`` - ``rag_comparator.compare(...)``
    * ``"base"`` - ``rag_comparator.base_rag.query(query, n_results)``
    * ``"hier"`` - ``rag_comparator.hier_rag.query(...)`` with the hierarchy
      filters, ``doc_type`` and ``auto_infer``

    Args:
        request: Query text, result count, pipeline selector and filters.

    Returns:
        Whatever the selected pipeline call returns.

    Raises:
        HTTPException: 400 if the index has not been built yet or if
            ``pipeline`` is not one of the three supported values; 500 if
            the pipeline call fails.
    """
    global rag_comparator
    if rag_comparator is None:
        raise HTTPException(status_code=400, detail="RAG system not initialized")

    pipeline = request.pipeline.lower()
    # Reject unknown selectors up front rather than silently running the
    # hierarchical pipeline. Done outside the try so the 400 is not
    # re-wrapped as a 500 by the generic handler below.
    if pipeline not in ("both", "base", "hier"):
        raise HTTPException(
            status_code=400,
            detail=f"Unknown pipeline '{request.pipeline}'; expected 'base', 'hier' or 'both'",
        )

    try:
        if pipeline == "base":
            return rag_comparator.base_rag.query(request.query, request.n_results)

        # "both" and "hier" take the same keyword arguments.
        filter_kwargs = {
            "query": request.query,
            "n_results": request.n_results,
            "level1": request.level1,
            "level2": request.level2,
            "level3": request.level3,
            "doc_type": request.doc_type,
            "auto_infer": request.auto_infer,
        }
        if pipeline == "both":
            return rag_comparator.compare(**filter_kwargs)
        return rag_comparator.hier_rag.query(**filter_kwargs)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/evaluate")
async def evaluate_rag(request: EvaluateRequest) -> Dict[str, Any]:
    """Evaluate RAG system performance.

    For each query, runs ``rag_comparator.compare`` and scores both the
    ``base_rag`` and ``hier_rag`` entries of the comparison against that
    query's relevant ids with ``evaluator.evaluate_rag_pipeline``.

    Args:
        request: Queries, per-query relevant ids and the k values to score.

    Returns:
        A dict with ``status`` and ``results``; each result holds ``query``,
        ``base_rag``, ``hier_rag`` and the comparison's ``speedup``.

    Raises:
        HTTPException: 400 if the system is not ready or if ``queries`` and
            ``relevant_ids`` differ in length; 500 if evaluation fails.
    """
    global rag_comparator, evaluator
    if rag_comparator is None or evaluator is None:
        raise HTTPException(status_code=400, detail="System not initialized")
    # zip() would silently drop the unmatched tail, producing results for
    # only part of the request; refuse mismatched input explicitly.
    if len(request.queries) != len(request.relevant_ids):
        raise HTTPException(
            status_code=400,
            detail=(
                "queries and relevant_ids must have the same length "
                f"(got {len(request.queries)} and {len(request.relevant_ids)})"
            ),
        )
    try:
        results = []
        for query, relevant_ids in zip(request.queries, request.relevant_ids):
            # Run comparison
            # NOTE(review): n_results is fixed at 5 while k_values is
            # caller-supplied; confirm how the evaluator treats k > 5.
            comparison = rag_comparator.compare(query=query, n_results=5)
            # Evaluate base RAG
            base_eval = evaluator.evaluate_rag_pipeline(
                comparison['base_rag'],
                relevant_ids,
                k_values=request.k_values,
            )
            # Evaluate hier RAG
            hier_eval = evaluator.evaluate_rag_pipeline(
                comparison['hier_rag'],
                relevant_ids,
                k_values=request.k_values,
            )
            results.append({
                "query": query,
                "base_rag": base_eval,
                "hier_rag": hier_eval,
                "speedup": comparison['speedup'],
            })
        return {
            "status": "success",
            "results": results,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/health")
async def health_check() -> Dict[str, str]:
    """Health check endpoint.

    Returns:
        ``{"status": "healthy"}`` unconditionally; no component is probed.
    """
    return {"status": "healthy"}
@app.get("/info")
async def system_info() -> Dict[str, Any]:
    """Get system information.

    Returns:
        A dict with ``initialized`` (whether ``index_manager`` is set) and
        ``rag_ready`` (whether ``rag_comparator`` is set). When the index
        manager exists, ``collections`` holds the result of its
        ``list_collections()``.
    """
    global index_manager, rag_comparator
    info = {
        "initialized": index_manager is not None,
        "rag_ready": rag_comparator is not None,
    }
    # Same identity test as the "initialized" flag, so the two can never
    # disagree about whether a manager is present.
    if index_manager is not None:
        info["collections"] = index_manager.list_collections()
    return info
# Run server
if __name__ == "__main__":
    # Only when executed as a script: serve `app` on every interface, port 8000.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )