# hierarchical-rag-eval / mcp_server.py
# hh786 — Deployment of Hierarchical RAG system (commit c54dcef)
"""MCP Server for RAG system."""
import json
from typing import Any, Dict, List, Optional

import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from core.ingest import DocumentProcessor
from core.index import IndexManager
from core.retrieval import RAGComparator
from core.eval import RAGEvaluator
# Initialize FastAPI app
app = FastAPI(title="Hierarchical RAG MCP Server", version="1.0.0")

# Global state, populated lazily by the endpoints below:
# /initialize sets index_manager and evaluator; /index sets rag_comparator.
index_manager = None   # IndexManager after a successful /initialize
rag_comparator = None  # RAGComparator after a successful /index
evaluator = None       # RAGEvaluator after a successful /initialize
# Request/Response Models
class InitRequest(BaseModel):
    """Request body for /initialize: vector-store location and embedding model."""
    # Directory where the Chroma vector store persists its data.
    persist_directory: str = "./data/chroma"
    # Sentence-transformers model used for both indexing and evaluation.
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
class UploadRequest(BaseModel):
    """Request body for /upload: file paths to validate before indexing."""
    # Paths of uploaded documents; only .pdf and .txt are accepted downstream.
    filepaths: List[str]
    # Name of the hierarchy configuration these documents belong to.
    hierarchy: str
    # Whether personally identifiable information should be masked.
    mask_pii: bool = False
class IndexRequest(BaseModel):
    """Request body for /index: documents plus chunking and storage options."""
    filepaths: List[str]
    # Hierarchy configuration name passed to the document processor.
    hierarchy: str
    # Chunking parameters (characters/tokens per chunk and overlap between chunks).
    chunk_size: int = 512
    chunk_overlap: int = 50
    mask_pii: bool = False
    # Vector-store collection that receives the chunks.
    collection_name: str = "rag_documents"
class QueryRequest(BaseModel):
    """Request body for /query.

    `pipeline` selects which retrieval pipeline runs: "base", "hier", or
    "both" (compare). The level*/doc_type fields are optional metadata
    filters for the hierarchical pipeline; None means "no filter".
    """
    query: str
    n_results: int = 5
    pipeline: str = "both"  # base, hier, or both
    # Fix: the original `level1: str = None` style is an implicit Optional,
    # which Pydantic v2 (and strict type checkers) reject — a None default
    # on a plain `str` field raises a validation/schema error. Make the
    # Optional explicit; defaults and accepted payloads are unchanged.
    level1: Optional[str] = None
    level2: Optional[str] = None
    level3: Optional[str] = None
    doc_type: Optional[str] = None
    # When True, the hierarchical pipeline may infer missing filters itself.
    auto_infer: bool = True
class EvaluateRequest(BaseModel):
    """Request body for /evaluate: parallel lists of queries and gold ids."""
    queries: List[str]
    # relevant_ids[i] holds the ground-truth relevant chunk ids for queries[i].
    relevant_ids: List[List[str]]
    # Cutoff values k at which retrieval metrics are computed.
    k_values: List[int] = [1, 3, 5]
# Endpoints
@app.post("/initialize")
async def initialize(request: InitRequest) -> Dict[str, Any]:
    """Construct the index manager and evaluator used by the other endpoints.

    Returns a success payload; any construction failure is surfaced as
    HTTP 500 with the underlying error message.
    """
    global index_manager, evaluator
    try:
        index_manager = IndexManager(
            persist_directory=request.persist_directory,
            embedding_model_name=request.embedding_model,
        )
        evaluator = RAGEvaluator(embedding_model_name=request.embedding_model)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"status": "success", "message": "System initialized successfully"}
@app.post("/upload")
async def upload_documents(request: UploadRequest) -> Dict[str, Any]:
    """Partition uploaded file paths into supported and unsupported by extension."""
    try:
        from pathlib import Path

        allowed = {'.pdf', '.txt'}
        # Split the incoming paths in order; extension check is case-insensitive.
        valid_files = [
            p for p in request.filepaths if Path(p).suffix.lower() in allowed
        ]
        invalid_files = [
            p for p in request.filepaths if Path(p).suffix.lower() not in allowed
        ]
        return {
            "status": "success",
            "total_uploaded": len(request.filepaths),
            "valid_files": valid_files,
            "invalid_files": invalid_files,
            "hierarchy": request.hierarchy,
        }
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
@app.post("/index")
async def build_index(request: IndexRequest) -> Dict[str, Any]:
    """Chunk the given documents, index them, and wire up the RAG comparator.

    Requires /initialize to have run first (otherwise HTTP 400). On success
    returns the number of chunks indexed and the target collection.
    """
    global index_manager, rag_comparator
    if not index_manager:
        raise HTTPException(status_code=400, detail="System not initialized")
    try:
        import os

        # 1) Split the source documents into chunks.
        processor = DocumentProcessor(
            hierarchy_name=request.hierarchy,
            chunk_size=request.chunk_size,
            chunk_overlap=request.chunk_overlap,
            mask_pii=request.mask_pii,
        )
        chunks = processor.process_documents(request.filepaths)
        if not chunks:
            return {
                "status": "error",
                "message": "No chunks extracted from documents",
            }
        # 2) Persist the chunks into the vector store.
        stats = index_manager.index_documents(chunks, request.collection_name)
        # 3) Build the base-vs-hierarchical comparator over the new collection.
        rag_comparator = RAGComparator(
            vector_store=index_manager.get_store(request.collection_name),
            llm_model=os.getenv("LLM_MODEL", "gpt-3.5-turbo"),
            api_key=os.getenv("OPENAI_API_KEY"),
        )
        return {
            "status": "success",
            "chunks_indexed": stats.get("chunks_added", 0),
            "collection": request.collection_name,
            "hierarchy": request.hierarchy,
        }
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
@app.post("/query")
async def query_rag(request: QueryRequest) -> Dict[str, Any]:
    """Run a retrieval query through the base, hierarchical, or both pipelines.

    Requires /index to have run first (otherwise HTTP 400). An unrecognized
    `pipeline` value falls through to the hierarchical pipeline.
    """
    global rag_comparator
    if not rag_comparator:
        raise HTTPException(status_code=400, detail="RAG system not initialized")
    try:
        mode = request.pipeline.lower()
        # Filter arguments shared by the comparator and the hierarchical path.
        hier_kwargs = dict(
            level1=request.level1,
            level2=request.level2,
            level3=request.level3,
            doc_type=request.doc_type,
            auto_infer=request.auto_infer,
        )
        if mode == "both":
            return rag_comparator.compare(
                query=request.query,
                n_results=request.n_results,
                **hier_kwargs,
            )
        if mode == "base":
            return rag_comparator.base_rag.query(request.query, request.n_results)
        # Anything else is treated as "hier".
        return rag_comparator.hier_rag.query(
            query=request.query,
            n_results=request.n_results,
            **hier_kwargs,
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
@app.post("/evaluate")
async def evaluate_rag(request: EvaluateRequest) -> Dict[str, Any]:
    """Evaluate base vs hierarchical retrieval over a set of labeled queries.

    For each query, runs both pipelines, scores each against the provided
    ground-truth ids at every k in `k_values`, and reports the comparator's
    speedup. Requires /initialize and /index to have run (otherwise HTTP 400).
    """
    global rag_comparator, evaluator
    if not rag_comparator or not evaluator:
        raise HTTPException(status_code=400, detail="System not initialized")
    # Fix: the original zip() silently truncated mismatched inputs, dropping
    # queries or labels without any signal. Reject the malformed request.
    if len(request.queries) != len(request.relevant_ids):
        raise HTTPException(
            status_code=400,
            detail="queries and relevant_ids must have the same length",
        )
    try:
        # Fix: n_results was hard-coded to 5, so metrics at any k > 5 could
        # never find more than 5 hits. Retrieve enough for the largest cutoff
        # (identical to the old behavior for the default k_values=[1, 3, 5]).
        n_results = max(request.k_values, default=5)
        results = []
        for query, relevant_ids in zip(request.queries, request.relevant_ids):
            comparison = rag_comparator.compare(query=query, n_results=n_results)
            base_eval = evaluator.evaluate_rag_pipeline(
                comparison['base_rag'],
                relevant_ids,
                k_values=request.k_values,
            )
            hier_eval = evaluator.evaluate_rag_pipeline(
                comparison['hier_rag'],
                relevant_ids,
                k_values=request.k_values,
            )
            results.append({
                "query": query,
                "base_rag": base_eval,
                "hier_rag": hier_eval,
                "speedup": comparison['speedup'],
            })
        return {
            "status": "success",
            "results": results,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check() -> Dict[str, str]:
    """Liveness probe: always reports the service as healthy."""
    return dict(status="healthy")
@app.get("/info")
async def system_info() -> Dict[str, Any]:
    """Report initialization state and, when available, the known collections."""
    # Read-only access to module globals needs no `global` declaration.
    info: Dict[str, Any] = {
        "initialized": index_manager is not None,
        "rag_ready": rag_comparator is not None,
    }
    if index_manager:
        info["collections"] = index_manager.list_collections()
    return info
# Run server
if __name__ == "__main__":
    # Bind on all interfaces so the server is reachable inside a container.
    uvicorn.run(app, host="0.0.0.0", port=8000)