"""
Performance and security tests for the AI Backend with RAG + Authentication
"""
import pytest
import asyncio
import time
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
from fastapi.testclient import TestClient
import jwt
from ..main import app
from ..auth.auth import TokenData
from ..config.settings import settings
from ..db import crud
from ..rag.pipeline import search_documents
@pytest.fixture
def client():
    """Provide a ``TestClient`` wrapping the application under test."""
    test_client = TestClient(app)
    return test_client
@pytest.mark.asyncio
async def test_authentication_performance():
    """Check that ten mocked user lookups finish within the time budget.

    Every collaborator is replaced with a mock, so this only measures the
    overhead of awaiting ``crud.get_user_by_email``; it guards against gross
    regressions rather than benchmarking a real database.
    """
    # ``unittest.mock.patch`` needs an absolute, importable dotted path, so
    # leading-dot targets such as '..db.crud.get_user_by_email' cannot be
    # resolved.  Import the modules relatively and patch their attributes
    # with ``patch.object`` instead.
    from ..auth import auth as auth_module
    from ..config import database as database_module

    # Mock the database session and the user record it would return.
    mock_db = AsyncMock()
    mock_user = MagicMock()
    mock_user.id = uuid4()
    mock_user.email = "performance_test@example.com"
    mock_user.hashed_password = "hashed_password"
    mock_user.is_active = True

    session_patch = patch.object(
        database_module, "get_db_session", return_value=mock_db
    )
    # The lookup is awaited below, so the replacement must be an AsyncMock.
    lookup_patch = patch.object(
        crud, "get_user_by_email", new_callable=AsyncMock, return_value=mock_user
    )
    verify_patch = patch.object(auth_module, "verify_password", return_value=True)
    token_patch = patch.object(
        auth_module, "create_user_token", return_value="fake_jwt_token"
    )

    with session_patch, lookup_patch, verify_patch, token_patch:
        # perf_counter is monotonic and high resolution, unlike time.time(),
        # which can jump when the wall clock is adjusted.
        started = time.perf_counter()
        for _ in range(10):  # Several calls to smooth out timer noise
            user = await crud.get_user_by_email(
                mock_db, "performance_test@example.com"
            )
            assert user is mock_user
        elapsed = time.perf_counter() - started

    # Budget: 500 ms for the 10 lookups.
    assert elapsed < 0.5
@pytest.mark.asyncio
async def test_search_performance():
    """Check that five mocked document searches finish within the time budget.

    Embedding generation and the vector search are both mocked, so the
    measurement covers only the work done inside ``search_documents``.
    """
    # ``unittest.mock.patch`` cannot resolve leading-dot targets; import the
    # owners relatively and patch their attributes with ``patch.object``.
    # NOTE(review): this patches the attribute on ``gemini_client`` itself; if
    # ``rag.pipeline`` binds ``generate_embedding`` by name at import time,
    # the patch should target that module instead - confirm.
    from ..embeddings import gemini_client
    from ..qdrant.operations import VectorOperations

    # Mock embedding vector with 1536 dimensions.
    mock_embedding = [0.1, 0.2, 0.3] + [0.0] * (1536 - 3)

    # Five mock hits with decreasing scores.
    mock_search_results = [
        {
            "id": f"point_id_{i}",
            "document_id": str(uuid4()),
            "score": 0.9 - (i * 0.1),
            "payload": {
                "chunk_text": f"This is relevant context #{i}.",
                "user_id": str(uuid4()),
            },
        }
        for i in range(5)
    ]

    embedding_patch = patch.object(
        gemini_client, "generate_embedding", return_value=mock_embedding
    )
    search_patch = patch.object(
        VectorOperations, "search_vectors", return_value=mock_search_results
    )

    with embedding_patch, search_patch:
        user_id = uuid4()
        # perf_counter is monotonic, so the interval cannot be skewed by
        # wall-clock adjustments.
        started = time.perf_counter()
        for _ in range(5):  # Several searches to smooth out timer noise
            result = await search_documents(
                query="Performance test query",
                user_id=user_id,
                top_k=5,
            )
            assert result is not None
            assert len(result) <= 5  # Should not exceed top_k
        elapsed = time.perf_counter() - started

    # Budget: 2 seconds for the 5 searches.
    assert elapsed < 2.0
@pytest.mark.asyncio
async def test_user_isolation_in_search():
    """Check that search results only carry the requesting user's id.

    The vector search is mocked to return a single hit owned by the user
    being queried; the test asserts that every hit coming back from
    ``search_documents`` has that user's id in its payload.
    """
    # ``unittest.mock.patch`` cannot resolve leading-dot targets; import the
    # owners relatively and patch their attributes with ``patch.object``.
    from ..embeddings import gemini_client
    from ..qdrant.operations import VectorOperations

    # Mock embedding vector with 1536 dimensions.
    mock_embedding = [0.4, 0.5, 0.6] + [0.0] * (1536 - 3)

    user_a_id = uuid4()
    user_b_id = uuid4()
    mock_search_results_user_a = [
        {
            "id": "point_id_1",
            "document_id": str(uuid4()),
            "score": 0.9,
            "payload": {"chunk_text": "User A's document", "user_id": str(user_a_id)},
        }
    ]
    mock_search_results_user_b = [
        {
            "id": "point_id_2",
            "document_id": str(uuid4()),
            "score": 0.85,
            "payload": {"chunk_text": "User B's document", "user_id": str(user_b_id)},
        }
    ]

    embedding_patch = patch.object(
        gemini_client, "generate_embedding", return_value=mock_embedding
    )
    search_patch = patch.object(VectorOperations, "search_vectors")

    # NOTE(review): the ownership loops below pass vacuously if
    # ``search_documents`` returns no hits - consider asserting non-empty
    # results once its return shape is confirmed.
    with embedding_patch, search_patch as mock_search:
        # Search as user A - the backend yields only user A's document.
        mock_search.return_value = mock_search_results_user_a
        result_a = await search_documents(
            query="Test query",
            user_id=user_a_id,
            top_k=5,
        )
        for result in result_a:
            assert result["payload"]["user_id"] == str(user_a_id)

        # Search as user B - the backend yields only user B's document.
        mock_search.return_value = mock_search_results_user_b
        result_b = await search_documents(
            query="Test query",
            user_id=user_b_id,
            top_k=5,
        )
        for result in result_b:
            assert result["payload"]["user_id"] == str(user_b_id)
@pytest.mark.asyncio
async def test_jwt_token_security():
    """Exercise JWT creation and decoding, including rejection cases.

    Asserts that a freshly created token round-trips its claims, and that
    both a malformed token and a token signed with a different secret decode
    to ``None``.
    """
    from ..auth.auth import create_access_token, decode_access_token

    user_id = uuid4()
    claims = {"sub": "test_user", "user_id": str(user_id)}

    # Round trip: issue a token, then read the claims back.
    issued = create_access_token(claims)
    assert issued is not None

    token_data = decode_access_token(issued)
    assert token_data is not None
    assert token_data.username == "test_user"
    assert token_data.user_id == str(user_id)

    # A string that is not a valid JWT must be rejected.
    assert decode_access_token("invalid.token.string") is None

    # A structurally valid token signed with another secret must be rejected.
    forged = jwt.encode(claims, "wrong_secret", algorithm=settings.jwt_algorithm)
    assert decode_access_token(forged) is None
@pytest.mark.asyncio
async def test_rate_limiting_simulation():
    """Confirm that the rate-limiting helpers are importable and callable.

    The real rate-limiting behaviour is not exercised here; this only checks
    that the decorator and the rate-limited embedding function exist.
    """
    from ..embeddings.gemini_client import rate_limit, generate_embedding_with_rate_limit

    # First the decorator, then the function wrapped with it.
    for helper in (rate_limit, generate_embedding_with_rate_limit):
        assert callable(helper)
@pytest.mark.asyncio
async def test_password_hashing_security():
    """Check that password hashing hides the plaintext, verifies, and salts.

    Asserts that the hash differs from the plaintext and contains the
    ``$2b$`` bcrypt marker, that the correct password verifies while a wrong
    one does not, and that hashing the same password twice yields different
    hashes.
    """
    from ..auth.auth import get_password_hash, verify_password

    password = "secure_test_password_123!"

    hashed = get_password_hash(password)
    assert hashed is not None
    assert hashed != password  # Should not be plain text
    assert len(hashed) > 0  # Should have content
    assert "$2b$" in hashed  # Should be bcrypt hash

    # Assert truthiness directly instead of comparing with ``== True`` /
    # ``== False``.
    assert verify_password(password, hashed)
    assert not verify_password("wrong_password", hashed)

    # The same password must produce a different hash each time (salting).
    hashed2 = get_password_hash(password)
    assert hashed != hashed2
@pytest.mark.asyncio
async def test_api_response_times():
    """Check that the mocked user lookup behind ``/auth/me`` returns quickly.

    The endpoint itself is not called; the test times the
    ``crud.get_user_by_id`` call that stands in for it, with every
    collaborator mocked.
    """
    # ``unittest.mock.patch`` cannot resolve leading-dot targets; import the
    # modules relatively and patch their attributes with ``patch.object``.
    from ..auth import auth as auth_module
    from ..config import database as database_module

    # Token payload the authentication dependency would yield.
    mock_token_data = TokenData(username="perf_test@example.com", user_id=str(uuid4()))

    # Mock the database session and the user record it would return.
    mock_db = AsyncMock()
    mock_user = MagicMock()
    mock_user.id = uuid4()
    mock_user.email = "perf_test@example.com"
    mock_user.full_name = "Performance Test User"
    mock_user.is_active = True
    mock_user.created_at = MagicMock()

    session_patch = patch.object(
        database_module, "get_db_session", return_value=mock_db
    )
    current_user_patch = patch.object(
        auth_module, "get_current_user", return_value=mock_token_data
    )
    # The lookup is awaited below, so the replacement must be an AsyncMock.
    lookup_patch = patch.object(
        crud, "get_user_by_id", new_callable=AsyncMock, return_value=mock_user
    )

    with session_patch, current_user_patch, lookup_patch:
        # perf_counter is monotonic and high resolution, which matters for a
        # budget this small.
        started = time.perf_counter()
        # Simulate the operation that would happen in the endpoint.
        _ = await crud.get_user_by_id(mock_db, mock_token_data.user_id)
        elapsed = time.perf_counter() - started

    # Operation should complete quickly (under 100ms).
    assert elapsed < 0.1
@pytest.mark.asyncio
async def test_document_content_security():
"""Test that document content doesn't contain dangerous patterns"""
from ..routes.documents import save_document
from ..models.documents import DocumentCreate
# Test document with potentially dangerous content
dangerous_content = ""
# The validation should catch this
try:
# Simulate the validation that happens in the endpoint
dangerous_patterns = ['