nothingworry's picture
improve RAG
9d50a01
raw
history blame
10.5 kB
"""
Supabase/PostgreSQL database utilities shared by all MCP tools.
This module provides:
1. Direct PostgreSQL connections (via psycopg2) for pgvector operations
2. A Supabase client for REST-style administrative needs
"""
from __future__ import annotations
import os
from typing import Optional, List, Dict, Any
import psycopg2
import psycopg2.extras
from dotenv import load_dotenv
from supabase import Client, create_client
# Load environment variables
load_dotenv()
# -----------------------------------
# Environment variables
# -----------------------------------
DATABASE_URL = os.getenv("POSTGRESQL_URL") # Direct PostgreSQL connection
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_SERVICE_KEY") # MUST be service role key
# Global Supabase client instance
_supabase_client: Optional[Client] = None
# -----------------------------------
# PostgreSQL Connection (for pgvector)
# -----------------------------------
def get_connection():
    """
    Open a direct PostgreSQL connection for pgvector operations.

    Raises:
        ValueError: If POSTGRESQL_URL is not configured in the environment.
    """
    if DATABASE_URL:
        return psycopg2.connect(DATABASE_URL)
    raise ValueError(
        "PostgreSQL connection string not configured. "
        "Set POSTGRESQL_URL in your .env file."
    )
# -----------------------------------
# Database Schema Initialization
# -----------------------------------
def initialize_database():
    """
    Initialize the database schema:
    - Enable pgvector extension
    - Create documents table with vector support
    - Create similarity-search and tenant_id indexes

    Errors mentioning "already exists" are tolerated so repeated startups
    don't crash; any other error is re-raised. The connection is always
    closed, even when a statement fails (the original leaked it on error).
    """
    conn = None
    try:
        conn = get_connection()
        cur = conn.cursor()
        # Enable pgvector extension
        cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
        print("✅ pgvector extension enabled")
        # Create documents table. embedding is fixed at 384 dims to match
        # the embedding model used elsewhere in the app.
        cur.execute(
            """
            CREATE TABLE IF NOT EXISTS documents (
                id BIGSERIAL PRIMARY KEY,
                tenant_id TEXT NOT NULL,
                chunk_text TEXT NOT NULL,
                embedding vector(384) NOT NULL,
                created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
            );
            """
        )
        print("✅ documents table created")
        # Approximate-nearest-neighbor index for cosine similarity search.
        cur.execute(
            """
            CREATE INDEX IF NOT EXISTS documents_embedding_idx
            ON documents
            USING ivfflat (embedding vector_cosine_ops)
            WITH (lists = 100);
            """
        )
        print("✅ vector index created")
        # B-tree index so per-tenant filtering stays fast.
        cur.execute(
            """
            CREATE INDEX IF NOT EXISTS documents_tenant_id_idx
            ON documents (tenant_id);
            """
        )
        print("✅ tenant_id index created")
        conn.commit()
        cur.close()
        print("✅ Database schema initialized successfully")
    except Exception as e:
        print(f"❌ Database initialization error: {e}")
        # Don't raise - allow the app to continue even if table exists
        if "already exists" not in str(e).lower():
            raise
    finally:
        # Always release the connection, success or failure.
        if conn is not None:
            conn.close()
# -----------------------------------
# Document + Embedding Operations
# -----------------------------------
def insert_document_chunks(tenant_id: str, text: str, embedding: list):
    """
    Insert one document chunk + its embedding for a tenant.

    Args:
        tenant_id: Owning tenant; surrounding whitespace is stripped before insert.
        text: The chunk's raw text.
        embedding: Embedding vector; must match the 384-dim `embedding` column.

    Raises:
        Exception: Re-raises any database error after logging it.
    """
    # Normalize tenant_id to ensure consistency with search/delete paths.
    tenant_id = tenant_id.strip()
    conn = None
    try:
        conn = get_connection()
        cur = conn.cursor()
        cur.execute(
            """
            INSERT INTO documents (tenant_id, chunk_text, embedding)
            VALUES (%s, %s, %s);
            """,
            (tenant_id, text, embedding),
        )
        conn.commit()
        cur.close()
    except Exception as e:
        print("DB INSERT ERROR:", e)
        raise
    finally:
        # Fix: the original leaked the connection when the insert raised.
        if conn is not None:
            conn.close()
def search_vectors(tenant_id: str, vector: list, limit: int = 5) -> List[Dict[str, Any]]:
    """
    Perform semantic vector search using pgvector (cosine distance).

    Results are filtered by tenant_id to ensure data isolation.

    Args:
        tenant_id: Tenant whose documents may be searched; whitespace is trimmed.
        vector: Query embedding; must match the 384-dim `embedding` column.
        limit: Maximum number of rows to return.

    Returns:
        List of {"text", "similarity"} dicts ordered by decreasing similarity,
        or an empty list on invalid input or any database error.
    """
    # Validate tenant_id before touching the database.
    if not tenant_id or not tenant_id.strip():
        print("DB SEARCH ERROR: tenant_id is empty")
        return []
    tenant_id_normalized = tenant_id.strip()
    conn = None
    try:
        conn = get_connection()
        cur = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
        # TRIM() on the stored column tolerates historical rows saved with
        # stray whitespace in tenant_id.
        cur.execute(
            """
            SELECT
                chunk_text,
                tenant_id,
                1 - (embedding <=> %s::vector(384)) AS similarity
            FROM documents
            WHERE TRIM(tenant_id) = %s
            ORDER BY embedding <=> %s::vector(384)
            LIMIT %s;
            """,
            (vector, tenant_id_normalized, vector, limit),
        )
        rows = cur.fetchall()
        cur.close()
        # Defense-in-depth: verify every row belongs to the requested tenant.
        results: List[Dict[str, Any]] = []
        for row in rows:
            row_tenant_id = row.get("tenant_id", "")
            if row_tenant_id and row_tenant_id.strip() != tenant_id_normalized:
                print(
                    f"WARNING: Found document with tenant_id '{row_tenant_id}' when searching for '{tenant_id_normalized}' - skipping"
                )
                continue
            results.append(
                {
                    "text": row["chunk_text"],
                    "similarity": float(row.get("similarity", 0.0)),
                }
            )
        return results
    except Exception as e:
        print(f"DB SEARCH ERROR (tenant_id={tenant_id}): {e}")
        import traceback
        traceback.print_exc()
        return []
    finally:
        # Fix: the original leaked the connection whenever an error occurred.
        if conn is not None:
            conn.close()
def list_all_documents(
    tenant_id: str, limit: int = 1000, offset: int = 0
) -> Dict[str, Any]:
    """
    List all documents for a tenant with pagination.

    tenant_id comparison is normalized via TRIM() to handle historical data.

    Args:
        tenant_id: Owning tenant; surrounding whitespace is trimmed.
        limit: Page size.
        offset: Page start offset.

    Returns:
        {"documents": [...], "total": int, "limit": limit, "offset": offset};
        on any error, an empty page with total 0.
    """
    tenant_id_normalized = tenant_id.strip()
    conn = None
    try:
        conn = get_connection()
        cur = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
        cur.execute(
            """
            SELECT
                id,
                chunk_text,
                created_at
            FROM documents
            WHERE TRIM(tenant_id) = %s
            ORDER BY created_at DESC
            LIMIT %s OFFSET %s;
            """,
            (tenant_id_normalized, limit, offset),
        )
        rows = cur.fetchall()
        # Separate COUNT(*) so the caller gets the total across all pages.
        cur.execute(
            """
            SELECT COUNT(*) as total
            FROM documents
            WHERE TRIM(tenant_id) = %s;
            """,
            (tenant_id_normalized,),
        )
        total_row = cur.fetchone()
        total = total_row["total"] if total_row else 0
        cur.close()
        results: List[Dict[str, Any]] = []
        for row in rows:
            results.append(
                {
                    "id": row["id"],
                    "text": row["chunk_text"],
                    # created_at may be NULL for legacy rows; serialize safely.
                    "created_at": row["created_at"].isoformat()
                    if row["created_at"]
                    else None,
                }
            )
        return {
            "documents": results,
            "total": total,
            "limit": limit,
            "offset": offset,
        }
    except Exception as e:
        print("DB LIST ERROR:", e)
        return {"documents": [], "total": 0, "limit": limit, "offset": offset}
    finally:
        # Fix: the original leaked the connection on the error path.
        if conn is not None:
            conn.close()
def delete_document(tenant_id: str, document_id: int) -> bool:
    """
    Delete a specific document by ID for a tenant.

    The tenant filter (with TRIM() normalization) ensures one tenant cannot
    delete another tenant's documents by guessing IDs.

    Args:
        tenant_id: Owning tenant; surrounding whitespace is trimmed.
        document_id: Primary key of the document row.

    Returns:
        True if a row was deleted, False if not found or on error.
    """
    tenant_id_normalized = tenant_id.strip()
    conn = None
    try:
        conn = get_connection()
        cur = conn.cursor()
        cur.execute(
            """
            DELETE FROM documents
            WHERE id = %s AND TRIM(tenant_id) = %s;
            """,
            (document_id, tenant_id_normalized),
        )
        # rowcount tells us whether the (id, tenant) pair actually matched.
        deleted = cur.rowcount > 0
        if deleted:
            print(f"DB DELETE: Deleted document {document_id} for tenant '{tenant_id_normalized}'")
        else:
            print(f"DB DELETE: Document {document_id} not found for tenant '{tenant_id_normalized}'")
        conn.commit()
        cur.close()
        return deleted
    except Exception as e:
        print(f"DB DELETE ERROR (document_id={document_id}, tenant_id={tenant_id}): {e}")
        import traceback
        traceback.print_exc()
        return False
    finally:
        # Fix: the original leaked the connection when the delete raised.
        if conn is not None:
            conn.close()
def delete_all_documents(tenant_id: str) -> int:
    """
    Delete all documents for a tenant.

    Handles tenant_id normalization (TRIM on both sides) to match documents
    stored with different whitespace formatting.

    Args:
        tenant_id: Owning tenant; surrounding whitespace is trimmed.

    Returns:
        The number of documents deleted; 0 on error.
    """
    tenant_id_normalized = tenant_id.strip()
    conn = None
    try:
        conn = get_connection()
        cur = conn.cursor()
        cur.execute(
            """
            DELETE FROM documents
            WHERE TRIM(tenant_id) = %s;
            """,
            (tenant_id_normalized,),
        )
        deleted_count = cur.rowcount
        print(f"DB DELETE ALL: Deleted {deleted_count} document(s) for tenant '{tenant_id_normalized}'")
        conn.commit()
        cur.close()
        return deleted_count
    except Exception as e:
        print(f"DB DELETE ALL ERROR (tenant_id={tenant_id}): {e}")
        import traceback
        traceback.print_exc()
        return 0
    finally:
        # Fix: the original leaked the connection on the error path.
        if conn is not None:
            conn.close()
# -----------------------------------
# Supabase Client (for REST operations)
# -----------------------------------
def get_supabase_client() -> Client:
    """
    Return the shared Supabase client, creating it lazily on first use.

    Raises:
        ValueError: If SUPABASE_URL or SUPABASE_SERVICE_KEY is missing.
    """
    global _supabase_client
    if _supabase_client is not None:
        return _supabase_client
    if not (SUPABASE_URL and SUPABASE_KEY):
        raise ValueError(
            "Supabase credentials missing. "
            "Set SUPABASE_URL and SUPABASE_SERVICE_KEY."
        )
    _supabase_client = create_client(SUPABASE_URL, SUPABASE_KEY)
    return _supabase_client
def reset_client():
    """Discard the cached Supabase client so the next call recreates it."""
    global _supabase_client
    _supabase_client = None
# Logical name -> physical table name mapping, used by the REST-style
# Supabase helpers so table names live in one place.
TABLES = {
    "tenants": "tenants",
    "documents": "documents",
    "embeddings": "tenant_embeddings",
    "redflag_rules": "redflag_rules",
    "analytics": "analytics_events",
    "tool_usage": "tool_usage_stats",
}