IntegraChat / check_rag_database.py
nothingworry's picture
all the thing
78b6d7b
raw
history blame
5.02 kB
"""
Diagnostic script to check RAG database tenant isolation.

This script queries the database directly to verify that stored documents
and vector-search results are isolated per tenant_id.
"""
import sys
from pathlib import Path
# Put "<directory of this script>/backend" at the front of the import search
# path (presumably where the mcp_server package imported by check_database()
# lives -- verify against the project layout).
backend_dir = Path(__file__).parent.joinpath("backend")
sys.path.insert(0, str(backend_dir))
def _one_line(text, width):
    """Return at most *width* characters of *text* with newlines flattened.

    ``None`` (e.g. a NULL column value) is treated as an empty string so the
    preview printing never raises.
    """
    return (text or "")[:width].replace("\n", " ")


def _print_tenant_documents(cur, step, tenant_id):
    """Print up to five stored document previews for *tenant_id*.

    Args:
        cur: Open database cursor whose rows support access by column name.
        step: Section number shown in the heading.
        tenant_id: Tenant whose rows in ``documents`` are listed.
    """
    print(f"\n{step}. Checking documents for '{tenant_id}'...")
    cur.execute(
        "SELECT id, tenant_id, LEFT(chunk_text, 50) as preview "
        "FROM documents WHERE tenant_id = %s LIMIT 5",
        (tenant_id,),
    )
    docs = cur.fetchall()
    print(f" Found {len(docs)} documents for {tenant_id}")
    for doc in docs:
        preview = _one_line(doc["preview"], 50)
        print(f" - ID: {doc['id']}, tenant_id: '{doc['tenant_id']}', preview: {preview}...")


def _print_search_results(results, limit=2):
    """Print a one-line preview of the first *limit* vector-search results."""
    for i, result in enumerate(results[:limit], 1):
        print(f" Result {i}: {_one_line(result['text'], 80)}...")


def check_database():
    """Check the database directly for tenant isolation.

    Runs five printed diagnostic steps against the ``documents`` table:

    1. List every tenant_id with its document count (stops early if the
       table is empty).
    2. Show up to five documents stored for ``verify_tenant1``.
    3. Show up to five documents stored for ``verify_tenant2``.
    4. Embed the query ``TENANT1_SECRET`` and run ``search_vectors`` once as
       each tenant; isolation is considered broken only when the results
       returned for ``verify_tenant2`` actually contain the secret text.
    5. Print the document counts of the two verification tenants.

    All output goes to stdout and the function returns ``None``. Import
    errors and any other exception are reported rather than propagated; the
    cursor and connection are closed on every exit path.
    """
    from contextlib import closing

    banner = "=" * 60
    print("\n" + banner)
    print("RAG Database Tenant Isolation Check")
    print(banner)

    try:
        from mcp_server.common.database import get_connection
        import psycopg2.extras

        # closing() guarantees close() even when a later step raises; a plain
        # "with conn:" on a psycopg2 connection only ends the transaction.
        with closing(get_connection()) as conn:
            with closing(
                conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
            ) as cur:
                # Check all tenant_ids in database
                print("\n1. Checking all tenant_ids in database...")
                cur.execute(
                    "SELECT DISTINCT tenant_id, COUNT(*) as count "
                    "FROM documents GROUP BY tenant_id"
                )
                rows = cur.fetchall()
                if not rows:
                    print(" ⚠️ No documents found in database")
                    return

                print(f" Found {len(rows)} unique tenant(s):")
                for row in rows:
                    print(f" - tenant_id: '{row['tenant_id']}' ({row['count']} documents)")

                # Steps 2 and 3: per-tenant document listings
                for step, tenant_id in enumerate(
                    ("verify_tenant1", "verify_tenant2"), start=2
                ):
                    _print_tenant_documents(cur, step, tenant_id)

                # Test search_vectors function directly. Imported lazily so
                # the embedding module is only loaded when documents exist.
                print("\n4. Testing search_vectors function directly...")
                from mcp_server.common.embeddings import embed_text
                from mcp_server.common.database import search_vectors

                # Search for tenant1's secret as tenant1
                query = "TENANT1_SECRET"
                query_vector = embed_text(query)
                results_tenant1 = search_vectors("verify_tenant1", query_vector, limit=5)
                print(f" Searching for '{query}' as verify_tenant1: {len(results_tenant1)} results")
                _print_search_results(results_tenant1)

                # Search for tenant1's secret as tenant2 (should NOT find)
                results_tenant2 = search_vectors("verify_tenant2", query_vector, limit=5)
                print(f" Searching for '{query}' as verify_tenant2: {len(results_tenant2)} results")
                # Getting results back is not a leak by itself (tenant2 may
                # simply be shown its own documents); only the presence of
                # the secret text is. The same flag drives the final verdict
                # so the two messages cannot contradict each other.
                leaked = bool(results_tenant2 and query in str(results_tenant2))
                if leaked:
                    print(" ⚠️ WARNING: tenant2 found tenant1's secret!")
                else:
                    print(" ✅ PASSED: tenant2 cannot see tenant1's secret")
                _print_search_results(results_tenant2)

                # Check for any documents with wrong tenant_id
                print("\n5. Checking for data integrity issues...")
                cur.execute("""
                    SELECT tenant_id, COUNT(*) as count
                    FROM documents
                    WHERE tenant_id IN ('verify_tenant1', 'verify_tenant2')
                    GROUP BY tenant_id
                """)
                integrity_check = cur.fetchall()
                print(" Tenant document counts:")
                for row in integrity_check:
                    print(f" - {row['tenant_id']}: {row['count']} documents")

        print("\n" + banner)
        if leaked:
            print("❌ ISOLATION FAILED: tenant2 can see tenant1's documents")
        else:
            print("✅ Database isolation appears to be working correctly")
        print(banner)
    except ImportError as e:
        print(f"\n❌ Import error: {e}")
        print(" Make sure you're running from the project root directory")
    except Exception as e:
        # Top-level boundary of a diagnostic script: report and show the
        # traceback instead of crashing.
        print(f"\n❌ Error: {e}")
        import traceback
        traceback.print_exc()
# Script entry point: run the diagnostic only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    check_database()