# Snapshot of backend/mcp_servers/rag_server.py at commit ef83e66
# (file-hosting viewer chrome removed from the capture).
# =============================================================
# File: backend/mcp_servers/rag_server.py
# =============================================================
import math
import os
import sys
from typing import Any, Dict, List

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

# Fix Python module paths
current_dir = os.path.dirname(__file__)
sys.path.insert(0, current_dir)

from embeddings import embed_text
from database import insert_document_chunks, search_vectors
from models.rag import IngestRequest, SearchRequest
# FastAPI application exposing the RAG MCP endpoints.
rag_app = FastAPI(title="RAG MCP Server")

# Enable CORS for browser-based clients.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# maximally permissive — confirm this is intended outside local development.
rag_app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Wrapper functions to match expected interface
def db_insert(tenant_id: str, content: str, vector: list):
    """Adapter around insert_document_chunks so callers see a uniform
    (tenant_id, content, vector) interface."""
    return insert_document_chunks(tenant_id, content, vector)
def db_search(tenant_id: str, vector: list, limit: int = 5):
    """Adapter around search_vectors: run a vector search for the tenant
    and shape each raw row as a {"text": ...} dict."""
    rows = search_vectors(tenant_id, vector, limit)
    shaped = []
    for row in rows:
        shaped.append({"text": row})
    return shaped
@rag_app.post("/ingest")
async def ingest(req: IngestRequest):
    """Embed the submitted content and persist it for the requesting tenant."""
    embedding = embed_text(req.content)
    db_insert(req.tenant_id, req.content, embedding)
    return {"status": "ok"}
def cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
    """Return the cosine similarity of two vectors.

    Behavior notes (preserved from the original implementation):
      * Returns 0.0 if either vector is empty or has zero magnitude.
      * Vectors of unequal length are truncated to the shorter one
        (plain ``zip`` semantics).

    The function-scope ``import math`` was hoisted to the module's
    import block per convention.
    """
    if not vec_a or not vec_b:
        return 0.0
    dot = sum(a * b for a, b in zip(vec_a, vec_b))
    norm_a = math.sqrt(sum(a * a for a in vec_a))
    norm_b = math.sqrt(sum(b * b for b in vec_b))
    denom = norm_a * norm_b
    if denom == 0:
        # Zero-magnitude vector: similarity is undefined; report 0.0.
        return 0.0
    return dot / denom
def rank_chunks(chunks: List[Dict[str, Any]], query_embedding: List[float]):
    """Score chunks against the query embedding and return them best-first.

    Each returned dict is a shallow copy of the input chunk with a
    "relevance" key added. (The original mutated the caller's dicts in
    place; copying removes that side effect while returning the same
    content.) A chunk with no "text" key is scored against the embedding
    of the empty string, matching the original ``chunk.get("text", "")``.
    """
    scored = []
    for chunk in chunks:
        chunk_vector = embed_text(chunk.get("text", ""))
        relevance = cosine_similarity(chunk_vector, query_embedding)
        scored.append({**chunk, "relevance": relevance})
    # sorted() is stable, so equal-relevance chunks keep their input order.
    return sorted(scored, key=lambda c: c["relevance"], reverse=True)
@rag_app.post("/search")
async def search(req: SearchRequest):
    """Vector-search the tenant's chunks, re-rank them, and return the
    top matches above a relevance threshold.

    The 0.55 threshold previously appeared twice (filter and metadata);
    it is now a single named value so the two cannot drift apart.
    """
    relevance_threshold = 0.55  # minimum cosine similarity to include a chunk
    max_results = 3             # cap on chunks returned to the client

    query_vector = embed_text(req.query)
    candidates = db_search(req.tenant_id, query_vector)
    ranked = rank_chunks(candidates, query_vector)
    filtered = [c for c in ranked if c["relevance"] >= relevance_threshold][:max_results]
    return {
        "results": filtered,
        "metadata": {
            "total_retrieved": len(candidates),
            "returned": len(filtered),
            "threshold": relevance_threshold,
        },
    }
if __name__ == "__main__":
    # Run this MCP server standalone (dev entry point).
    import uvicorn

    uvicorn.run(rag_app, host="0.0.0.0", port=8001)