nothingworry's picture
improve RAG
9d50a01
raw
history blame
7.81 kB
import os
import httpx
from dotenv import load_dotenv
# Populate os.environ from a .env file (python-dotenv) when this module is
# imported, so RAG_MCP_URL — read in RAGClient.__init__ — can be set there.
load_dotenv()
class RAGClient:
    """
    Communicates with the RAG MCP server over HTTP.

    The base URL comes from the ``RAG_MCP_URL`` environment variable
    (default ``http://localhost:8001``). Every public method is a coroutine
    that opens a short-lived ``httpx.AsyncClient`` for a single request.
    These methods do not raise on transport or server failures. Instead,
    problems are reported through the return value: an empty result for
    ``search`` / ``list_documents``, and an ``{"error": ...}`` dict for the
    other methods.

    The server may wrap its payload as ``{"data": ...}``. It may also report
    a failure with HTTP 200 as ``{"status": "error", "message": ...}``. Every
    method handles both shapes the same way via the private helpers below.
    """

    def __init__(self):
        """
        Resolve the server base URL and the fixed endpoint URLs.

        Surrounding whitespace and trailing slashes are stripped from
        ``RAG_MCP_URL``, so joined paths never contain ``//``. Without this,
        ``http://host:8001/`` would yield ``http://host:8001//search``.

        Raises:
            ValueError: if ``RAG_MCP_URL`` is set but empty after stripping.
        """
        raw_url = os.getenv("RAG_MCP_URL", "http://localhost:8001")
        self.base_url = raw_url.strip().rstrip("/")
        if not self.base_url:
            raise ValueError("RAG_MCP_URL environment variable is not set")
        self.search_endpoint = f"{self.base_url}/search"
        self.ingest_endpoint = f"{self.base_url}/ingest"

    # ------------------------------------------------------------------
    # Response helpers (pure functions of the parsed reply / response)
    # ------------------------------------------------------------------

    @staticmethod
    def _is_error_envelope(data):
        """Return True if *data* is the server's ``{"status": "error", ...}`` reply."""
        return isinstance(data, dict) and data.get("status") == "error"

    @staticmethod
    def _unwrap(data):
        """
        Return ``data["data"]`` for a ``{"data": ...}`` envelope, else *data* unchanged.

        Unwrapped replies are passed through as-is, for backward compatibility
        with servers that do not use the envelope.
        """
        if isinstance(data, dict) and "data" in data:
            return data["data"]
        return data

    @classmethod
    def _result_or_error(cls, data):
        """Map a parsed JSON reply to ``{"error": message}`` or to its unwrapped payload."""
        if cls._is_error_envelope(data):
            return {"error": data.get("message", "Unknown error")}
        return cls._unwrap(data)

    @staticmethod
    def _http_error(response):
        """
        Build ``{"error": "HTTP <status>: <detail>"}`` for a non-200 response.

        Uses the ``detail`` field when the body is a JSON object that has one.
        Otherwise it falls back to the raw response text.
        """
        detail = response.text
        try:
            body = response.json()
        except ValueError:
            # Body is not valid JSON (json.JSONDecodeError subclasses ValueError).
            body = None
        if isinstance(body, dict):
            detail = body.get("detail", detail)
        return {"error": f"HTTP {response.status_code}: {detail}"}

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def search(self, query: str, tenant_id: str):
        """
        Send *query* to ``/search`` and return the matching document chunks.

        Args:
            query: Search text.
            tenant_id: Tenant whose documents are searched.

        Returns:
            The ``results`` value from the (possibly ``data``-wrapped) reply,
            or the payload itself when it is not a dict. Returns ``[]`` in
            these cases:
              - a non-200 status;
              - an error envelope;
              - a null payload or null ``results``;
              - any exception.
        """
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    self.search_endpoint,
                    json={"query": query, "tenant_id": tenant_id},
                )
            if response.status_code != 200:
                return []
            data = response.json()
            if self._is_error_envelope(data):
                print("RAG Client Error:", data.get("message"))
                return []
            payload = self._unwrap(data)
            results = payload.get("results") if isinstance(payload, dict) else payload
            # Never hand None to callers that iterate the result.
            return results if results is not None else []
        except Exception as e:
            print("RAG Client Error:", e)
            return []

    async def ingest(self, content: str, tenant_id: str):
        """
        Send *content* to ``/ingest`` for ingestion under *tenant_id*.

        Returns:
            The unwrapped server result, with ``chunks_ingested`` renamed to
            ``chunks_stored`` when the result is a dict. Returns
            ``{"error": ...}`` on a non-200 status, an error envelope, or any
            exception.
        """
        try:
            # NOTE(review): no timeout is passed, so httpx's default (5 s)
            # applies. Large documents may need longer; confirm against the
            # server's ingest latency.
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    self.ingest_endpoint,
                    json={"tenant_id": tenant_id, "content": content},
                )
            if response.status_code != 200:
                return {"error": f"HTTP {response.status_code}"}
            result = self._result_or_error(response.json())
            # Normalise the server's field name to the one this client exposes.
            # The isinstance guard avoids a TypeError on null/list payloads.
            if isinstance(result, dict) and "chunks_ingested" in result:
                result["chunks_stored"] = result.pop("chunks_ingested")
            return result
        except Exception as e:
            print("RAG Ingest Error:", e)
            return {"error": str(e)}

    async def list_documents(self, tenant_id: str, limit: int = 1000, offset: int = 0):
        """
        List one page of documents for *tenant_id* via ``GET /list``.

        Args:
            tenant_id: Tenant whose documents are listed.
            limit: Page size forwarded to the server.
            offset: Page offset forwarded to the server.

        Returns:
            The unwrapped server reply. On a non-200 status, an error
            envelope, or any exception, it returns the empty listing
            ``{"documents": [], "total": 0, "limit": limit, "offset": offset}``.
        """
        empty_listing = {"documents": [], "total": 0, "limit": limit, "offset": offset}
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{self.base_url}/list",
                    params={"tenant_id": tenant_id, "limit": limit, "offset": offset},
                )
            if response.status_code != 200:
                return empty_listing
            data = response.json()
            if self._is_error_envelope(data):
                print("RAG List Error:", data.get("message", "Unknown error"))
                return empty_listing
            return self._unwrap(data)
        except Exception as e:
            print("RAG List Error:", e)
            return empty_listing

    async def delete_document(self, tenant_id: str, document_id: int):
        """
        Delete document *document_id* for *tenant_id* via ``DELETE /delete/<id>``.

        Returns:
            The unwrapped server result on success. Otherwise an
            ``{"error": ...}`` dict, produced in these cases:
              - 404: a not-found/access-denied message;
              - any other non-200 status: the server's ``detail`` or body text;
              - an error envelope;
              - a connection failure;
              - any other exception.
        """
        try:
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.delete(
                    f"{self.base_url}/delete/{document_id}",
                    params={"tenant_id": tenant_id},
                )
            if response.status_code == 404:
                return {"error": f"Document {document_id} not found or access denied"}
            if response.status_code != 200:
                return self._http_error(response)
            return self._result_or_error(response.json())
        except httpx.ConnectError:
            print(f"RAG Delete Error: Cannot connect to RAG MCP server at {self.base_url}")
            return {"error": f"Cannot connect to RAG MCP server. Is it running at {self.base_url}?"}
        except Exception as e:
            print(f"RAG Delete Error: {e}")
            return {"error": str(e)}

    async def delete_all_documents(self, tenant_id: str):
        """
        Delete every document for *tenant_id* via ``DELETE /delete-all``.

        Returns:
            The unwrapped server result on success. Otherwise an
            ``{"error": ...}`` dict, produced in these cases:
              - any non-200 status: the server's ``detail`` or body text;
              - an error envelope;
              - a connection failure;
              - any other exception.
        """
        try:
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.delete(
                    f"{self.base_url}/delete-all",
                    params={"tenant_id": tenant_id},
                )
            if response.status_code != 200:
                return self._http_error(response)
            return self._result_or_error(response.json())
        except httpx.ConnectError:
            print(f"RAG Delete All Error: Cannot connect to RAG MCP server at {self.base_url}")
            return {"error": f"Cannot connect to RAG MCP server. Is it running at {self.base_url}?"}
        except Exception as e:
            print(f"RAG Delete All Error: {e}")
            return {"error": str(e)}