# =============================================================
# File: backend/mcp_servers/rag_server.py
# =============================================================
"""RAG MCP server: FastAPI app exposing ``/ingest`` and ``/search``.

``/ingest`` embeds a document's content and stores it for a tenant.
``/search`` embeds a query, retrieves candidate chunks for the tenant,
re-scores them by cosine similarity against the query embedding and
returns the best ones above a relevance threshold.
"""
import math
import os
import sys
from typing import Any, Dict, List

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

# Fix Python module paths: make this file's directory importable so the
# sibling modules below resolve regardless of the working directory.
# abspath() guards against a relative or empty dirname(__file__).
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)

# These imports must stay below the sys.path tweak above.
from embeddings import embed_text  # noqa: E402
from database import insert_document_chunks, search_vectors  # noqa: E402
from models.rag import IngestRequest, SearchRequest  # noqa: E402

# Minimum cosine similarity a chunk needs to be returned by /search.
RELEVANCE_THRESHOLD = 0.55
# Maximum number of chunks returned by /search after filtering.
MAX_RESULTS = 3
# Default number of candidates requested from the vector store.
DEFAULT_SEARCH_LIMIT = 5

rag_app = FastAPI(title="RAG MCP Server")

# Enable CORS
# NOTE(review): FastAPI's CORS documentation states that allow_origins,
# allow_methods and allow_headers should not be ["*"] when
# allow_credentials=True. Values are left unchanged here because the
# intended origin list is not known - confirm and tighten.
rag_app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Wrapper functions to match expected interface
def db_insert(tenant_id: str, content: str, vector: list):
    """Wrapper for insert_document_chunks to match expected interface.

    Forwards ``tenant_id``, ``content`` and ``vector`` unchanged and
    returns whatever ``insert_document_chunks`` returns.
    """
    return insert_document_chunks(tenant_id, content, vector)


def db_search(tenant_id: str, vector: list, limit: int = DEFAULT_SEARCH_LIMIT):
    """Wrapper for search_vectors to match expected interface.

    Calls ``search_vectors(tenant_id, vector, limit)`` and wraps each
    returned item in a dict under the ``"text"`` key.
    """
    hits = search_vectors(tenant_id, vector, limit)
    return [{"text": hit} for hit in hits]


# NOTE(review): embed_text and the database helpers are called without
# ``await``; if they block, they hold up the event loop for the duration
# of the call - confirm whether these handlers should be plain ``def``.
@rag_app.post("/ingest")
async def ingest(req: IngestRequest):
    """Embed ``req.content`` and store it for ``req.tenant_id``."""
    vector = embed_text(req.content)
    db_insert(req.tenant_id, req.content, vector)
    return {"status": "ok"}


def cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
    """Return the cosine similarity of ``vec_a`` and ``vec_b``.

    Returns ``0.0`` when either vector is empty or has zero magnitude,
    so callers never hit a division by zero.
    """
    if not vec_a or not vec_b:
        return 0.0

    # NOTE(review): zip() stops at the shorter vector, so for inputs of
    # different lengths the dot product covers only the common prefix
    # while the norms below use the full vectors.
    dot = sum(a * b for a, b in zip(vec_a, vec_b))
    norm_a = math.sqrt(sum(a * a for a in vec_a))
    norm_b = math.sqrt(sum(b * b for b in vec_b))
    denom = norm_a * norm_b
    if denom == 0:
        return 0.0
    return dot / denom


def rank_chunks(
    chunks: List[Dict[str, Any]], query_embedding: List[float]
) -> List[Dict[str, Any]]:
    """Score each chunk against ``query_embedding`` and sort best-first.

    Each chunk's ``"text"`` (empty string when missing) is embedded with
    ``embed_text`` and compared to ``query_embedding`` by cosine
    similarity. The result is a new list of shallow copies of the input
    dicts with a ``"relevance"`` key added, sorted by descending
    relevance; the caller's dicts are left untouched.
    """
    scored = [
        {
            **chunk,
            "relevance": cosine_similarity(
                embed_text(chunk.get("text", "")), query_embedding
            ),
        }
        for chunk in chunks
    ]
    return sorted(scored, key=lambda item: item["relevance"], reverse=True)


@rag_app.post("/search")
async def search(req: SearchRequest):
    """Return the most relevant stored chunks for ``req.query``.

    The response holds up to ``MAX_RESULTS`` chunks whose relevance is at
    least ``RELEVANCE_THRESHOLD``, plus metadata with the number of
    candidates retrieved, the number returned and the threshold applied.
    """
    query_vector = embed_text(req.query)
    candidates = db_search(req.tenant_id, query_vector)
    ranked = rank_chunks(candidates, query_vector)
    relevant = [
        chunk for chunk in ranked if chunk["relevance"] >= RELEVANCE_THRESHOLD
    ][:MAX_RESULTS]
    return {
        "results": relevant,
        "metadata": {
            "total_retrieved": len(candidates),
            "returned": len(relevant),
            "threshold": RELEVANCE_THRESHOLD,
        },
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(rag_app, host="0.0.0.0", port=8001)