File size: 7,757 Bytes
c54dcef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
"""MCP Server for RAG system."""

import json
from typing import Any, Dict, List, Optional

import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from core.eval import RAGEvaluator
from core.index import IndexManager
from core.ingest import DocumentProcessor
from core.retrieval import RAGComparator

# Initialize FastAPI app
app = FastAPI(title="Hierarchical RAG MCP Server", version="1.0.0")

# Global state
# Process-wide singletons shared by every endpoint. Each stays None until the
# endpoint that builds it has run:
#   index_manager  - built by /initialize; /index refuses to run without it.
#   rag_comparator - built by /index over the collection it just indexed;
#                    required by /query and /evaluate.
#   evaluator      - built by /initialize; required by /evaluate.
# NOTE(review): plain module globals with no locking; concurrent /initialize
# or /index calls simply overwrite each other.
index_manager = None
rag_comparator = None
evaluator = None


# Request/Response Models
class InitRequest(BaseModel):
    """Body of POST /initialize: index storage location and embedding model."""

    # Forwarded to IndexManager(persist_directory=...).
    persist_directory: str = "./data/chroma"
    # Forwarded to both IndexManager and RAGEvaluator as embedding_model_name.
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"


class UploadRequest(BaseModel):
    """Body of POST /upload: file paths to screen before indexing."""

    # Paths to check; /upload only inspects their file extension.
    filepaths: List[str]
    # Echoed back unchanged in the /upload response.
    hierarchy: str
    # NOTE(review): accepted for symmetry with IndexRequest but not read by /upload.
    mask_pii: bool = False


class IndexRequest(BaseModel):
    """Body of POST /index: documents to chunk and store in a collection."""

    # Passed to DocumentProcessor.process_documents().
    filepaths: List[str]
    # Passed to DocumentProcessor as hierarchy_name.
    hierarchy: str
    # Chunking parameters forwarded to DocumentProcessor (units are defined
    # there; presumably characters or tokens - confirm in core.ingest).
    chunk_size: int = 512
    chunk_overlap: int = 50
    # Forwarded to DocumentProcessor(mask_pii=...).
    mask_pii: bool = False
    # Collection written by IndexManager.index_documents() and then used as
    # the vector store for the new RAGComparator.
    collection_name: str = "rag_documents"


class QueryRequest(BaseModel):
    """Body of POST /query.

    ``level1``/``level2``/``level3``/``doc_type`` and ``auto_infer`` are
    forwarded unchanged to the hierarchical pipeline (and to ``compare`` when
    ``pipeline`` is ``"both"``); the base pipeline ignores them.
    """

    query: str
    n_results: int = 5
    pipeline: str = "both"  # "base", "hier", or "both" (matched case-insensitively)
    # Optional[...] rather than bare ``str``: with a ``str`` annotation an
    # explicit JSON ``null`` fails validation even though the default is None.
    level1: Optional[str] = None
    level2: Optional[str] = None
    level3: Optional[str] = None
    doc_type: Optional[str] = None
    auto_infer: bool = True


class EvaluateRequest(BaseModel):
    """Body of POST /evaluate: queries paired with their relevant document ids."""

    queries: List[str]
    # relevant_ids[i] is paired with queries[i] by /evaluate; the two lists
    # are expected to have the same length.
    relevant_ids: List[List[str]]
    # Cutoffs passed to RAGEvaluator.evaluate_rag_pipeline(k_values=...).
    k_values: List[int] = [1, 3, 5]


# Endpoints
@app.post("/initialize")
async def initialize(request: InitRequest) -> Dict[str, Any]:
    """Initialize the RAG system.

    Builds a fresh IndexManager and RAGEvaluator from ``request``. Both are
    constructed before either global is replaced, so a failure leaves the
    previous state untouched instead of half-replaced. Any comparator built by
    an earlier /index call points at the previous index manager's store, so it
    is discarded; call /index again before /query or /evaluate.

    Raises:
        HTTPException: 500 if either component fails to construct.
    """
    global index_manager, rag_comparator, evaluator

    try:
        new_index_manager = IndexManager(
            persist_directory=request.persist_directory,
            embedding_model_name=request.embedding_model,
        )
        new_evaluator = RAGEvaluator(embedding_model_name=request.embedding_model)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Commit only after both components were built successfully.
    index_manager = new_index_manager
    evaluator = new_evaluator
    rag_comparator = None

    return {
        "status": "success",
        "message": "System initialized successfully"
    }


@app.post("/upload")
async def upload_documents(request: UploadRequest) -> Dict[str, Any]:
    """Sort the submitted paths into supported and unsupported files.

    Only the file extension is examined (``.pdf`` or ``.txt``, compared
    case-insensitively); files are not opened or checked for existence.
    Any unexpected error is reported as HTTP 500.
    """
    try:
        from pathlib import Path

        supported = frozenset({'.pdf', '.txt'})
        accepted: List[str] = []
        rejected: List[str] = []

        for path_str in request.filepaths:
            is_supported = Path(path_str).suffix.lower() in supported
            (accepted if is_supported else rejected).append(path_str)

        response = {
            "status": "success",
            "total_uploaded": len(request.filepaths),
            "valid_files": accepted,
            "invalid_files": rejected,
            "hierarchy": request.hierarchy,
        }
        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/index")
async def build_index(request: IndexRequest) -> Dict[str, Any]:
    """Build RAG index from documents.

    Chunks ``request.filepaths`` with a DocumentProcessor, stores the chunks
    in ``request.collection_name``, then replaces the global RAGComparator
    with one over that collection. The comparator's LLM comes from the
    ``LLM_MODEL`` (default ``"gpt-3.5-turbo"``) and ``OPENAI_API_KEY``
    environment variables.

    Returns:
        ``status="success"`` with the indexed chunk count, or (still HTTP 200)
        ``status="error"`` when no chunks could be extracted.

    Raises:
        HTTPException: 400 if /initialize has not run; 500 on any failure
            while processing or indexing.
    """
    global rag_comparator
    import os

    if not index_manager:
        raise HTTPException(status_code=400, detail="System not initialized")

    try:
        processor = DocumentProcessor(
            hierarchy_name=request.hierarchy,
            chunk_size=request.chunk_size,
            chunk_overlap=request.chunk_overlap,
            mask_pii=request.mask_pii
        )
        all_chunks = processor.process_documents(request.filepaths)

        if not all_chunks:
            return {
                "status": "error",
                "message": "No chunks extracted from documents"
            }

        stats = index_manager.index_documents(all_chunks, request.collection_name)

        # The comparator is replaced only after indexing succeeded, so a
        # failed /index leaves the previous comparator in place.
        rag_comparator = RAGComparator(
            vector_store=index_manager.get_store(request.collection_name),
            llm_model=os.getenv("LLM_MODEL", "gpt-3.5-turbo"),
            api_key=os.getenv("OPENAI_API_KEY")
        )

        return {
            "status": "success",
            "chunks_indexed": stats.get("chunks_added", 0),
            "collection": request.collection_name,
            "hierarchy": request.hierarchy
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/query")
async def query_rag(request: QueryRequest) -> Dict[str, Any]:
    """Query the RAG system with the base pipeline, the hierarchical one, or both.

    ``request.pipeline`` is matched case-insensitively against ``"base"``,
    ``"hier"`` and ``"both"``. Previously any other value silently ran the
    hierarchical pipeline; it is now rejected.

    Raises:
        HTTPException: 400 if /index has not built a comparator yet or the
            pipeline name is unknown; 500 if the selected pipeline fails.
    """
    if not rag_comparator:
        raise HTTPException(status_code=400, detail="RAG system not initialized")

    pipeline = request.pipeline.lower()
    # Validate outside the try block so this 400 is not rewrapped as a 500.
    if pipeline not in ("base", "hier", "both"):
        raise HTTPException(
            status_code=400,
            detail=f"Unknown pipeline '{request.pipeline}'; expected 'base', 'hier' or 'both'"
        )

    # Options shared by the hierarchical pipeline and the comparison.
    hier_options = {
        "level1": request.level1,
        "level2": request.level2,
        "level3": request.level3,
        "doc_type": request.doc_type,
        "auto_infer": request.auto_infer,
    }

    try:
        if pipeline == "both":
            return rag_comparator.compare(
                query=request.query, n_results=request.n_results, **hier_options
            )
        if pipeline == "base":
            return rag_comparator.base_rag.query(request.query, request.n_results)
        return rag_comparator.hier_rag.query(
            query=request.query, n_results=request.n_results, **hier_options
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/evaluate")
async def evaluate_rag(request: EvaluateRequest) -> Dict[str, Any]:
    """Evaluate RAG system performance.

    For each query, runs ``rag_comparator.compare`` and scores the base and
    hierarchical results against that query's relevant ids at every cutoff
    in ``request.k_values``.

    Returns:
        ``{"status": "success", "results": [...]}`` with one entry per query
        holding the base/hier evaluations and the comparison's ``speedup``.

    Raises:
        HTTPException: 400 if the system is not ready or ``queries`` and
            ``relevant_ids`` differ in length; 500 if evaluation fails.
    """
    if not rag_comparator or not evaluator:
        raise HTTPException(status_code=400, detail="System not initialized")

    # zip() would silently drop the unmatched tail, so reject mismatches.
    if len(request.queries) != len(request.relevant_ids):
        raise HTTPException(
            status_code=400,
            detail=(
                f"queries ({len(request.queries)}) and relevant_ids "
                f"({len(request.relevant_ids)}) must have the same length"
            )
        )

    # Retrieve enough results to score the largest cutoff, but never fewer
    # than the previous fixed depth of 5 (so the defaults behave as before).
    n_results = max([5, *request.k_values])

    try:
        results = []

        for query, relevant_ids in zip(request.queries, request.relevant_ids):
            comparison = rag_comparator.compare(query=query, n_results=n_results)

            base_eval = evaluator.evaluate_rag_pipeline(
                comparison['base_rag'],
                relevant_ids,
                k_values=request.k_values
            )
            hier_eval = evaluator.evaluate_rag_pipeline(
                comparison['hier_rag'],
                relevant_ids,
                k_values=request.k_values
            )

            results.append({
                "query": query,
                "base_rag": base_eval,
                "hier_rag": hier_eval,
                "speedup": comparison['speedup']
            })

        return {
            "status": "success",
            "results": results
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/health")
async def health_check() -> Dict[str, str]:
    """Liveness probe: always reports healthy without touching RAG state."""
    return dict(status="healthy")


@app.get("/info")
async def system_info() -> Dict[str, Any]:
    """Report which parts of the RAG system are ready.

    Returns:
        ``initialized`` and ``rag_ready`` flags, plus ``collections`` (from
        the index manager's ``list_collections()``) once /initialize has run.
    """
    manager = index_manager
    info: Dict[str, Any] = dict(
        initialized=manager is not None,
        rag_ready=rag_comparator is not None,
    )

    if manager:
        info["collections"] = manager.list_collections()

    return info


# Run server
# Serves the app directly with uvicorn when this file is executed as a script.
# NOTE(review): binds all interfaces (0.0.0.0) on port 8000; confirm this
# exposure is intended outside local development.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)