Spaces:
Sleeping
Sleeping
| """ | |
| Test Suite for DB Service | |
| Comprehensive tests for the plug-and-play DB Service including: | |
| - Configuration | |
| - Permissions (USER/ADMIN/SYSTEM) | |
| - Filtering (user ownership, soft deletes) | |
| - CRUD operations | |
| - Database initialization | |
| """ | |
| import pytest | |
| import os | |
| from datetime import datetime | |
| from sqlalchemy import select | |
| from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker | |
| from services.db_service import ( | |
| DBServiceConfig, | |
| QueryService, | |
| init_database, | |
| reset_database, | |
| get_registered_models, | |
| ) | |
| from core.models import ( | |
| Base, User, GeminiJob, PaymentTransaction, Contact, | |
| RateLimit, ApiKeyUsage, ClientUser, AuditLog | |
| ) | |
# Connection URL handed to create_async_engine by the test engine:
# an in-memory SQLite database accessed through the aiosqlite driver.
TEST_DB_URL = "sqlite+aiosqlite:///:memory:"
@pytest.fixture
async def engine():
    """Provide an async SQLAlchemy engine for the test database.

    Yields:
        The engine created from ``TEST_DB_URL``. It is disposed after the
        requesting test (and any dependent fixtures) have finished.
    """
    # NOTE(review): a plain ``pytest.fixture`` on an async generator assumes
    # pytest-asyncio is running in auto mode -- confirm against project config.
    db_engine = create_async_engine(TEST_DB_URL, echo=False)
    yield db_engine
    await db_engine.dispose()
@pytest.fixture
async def session(engine):
    """Provide an ``AsyncSession`` bound to the test engine.

    Args:
        engine: The async engine the session factory is bound to.

    Yields:
        A session that is closed when the ``async with`` block exits.
    """
    # expire_on_commit=False keeps attribute values loaded after commit, so
    # tests can read e.g. ``job.id`` without triggering another load.
    session_factory = async_sessionmaker(
        engine, class_=AsyncSession, expire_on_commit=False
    )
    async with session_factory() as db_session:
        yield db_session
@pytest.fixture(autouse=True)
async def setup_db(engine):
    """Register the DB Service configuration and initialise the test database.

    The fixture is autouse because no test requests it by name, yet every
    test relies on the registered configuration and on ``init_database``
    having run against the engine.

    Args:
        engine: The async engine passed to ``init_database`` and, on
            teardown, to ``reset_database``.
    """
    # Permission matrix: which models each role may read/create/update/delete.
    DBServiceConfig.register(
        db_base=Base,
        all_models=[
            User, GeminiJob, PaymentTransaction, Contact,
            RateLimit, ApiKeyUsage, ClientUser, AuditLog,
        ],
        user_filter_column="user_id",
        user_id_column="id",
        soft_delete_column="deleted_at",
        special_user_model=User,
        user_read_scoped=[User, GeminiJob, PaymentTransaction, Contact],
        user_create_scoped=[GeminiJob, PaymentTransaction, Contact],
        user_update_scoped=[User, GeminiJob],
        user_delete_scoped=[GeminiJob, Contact],
        admin_read_only=[RateLimit, ApiKeyUsage, ClientUser, AuditLog],
        admin_create_only=[RateLimit, ApiKeyUsage, ClientUser, AuditLog],
        admin_update_only=[RateLimit, ApiKeyUsage, ClientUser, PaymentTransaction],
        admin_delete_only=[RateLimit, ApiKeyUsage, User],
        system_read_scoped=[
            User, GeminiJob, PaymentTransaction, RateLimit,
            ApiKeyUsage, ClientUser, AuditLog,
        ],
        system_create_scoped=[
            User, ClientUser, AuditLog, PaymentTransaction,
            ApiKeyUsage, GeminiJob, RateLimit,
        ],
        system_update_scoped=[
            User, GeminiJob, PaymentTransaction, ApiKeyUsage,
            RateLimit, ClientUser,
        ],
        system_delete_scoped=[GeminiJob, RateLimit, ApiKeyUsage],
    )

    # Initialise the database before the test body runs.
    await init_database(engine)

    yield

    # Teardown: reset the database so state does not leak between tests.
    await reset_database(engine)
@pytest.fixture
async def regular_user(session):
    """Persist and return a regular test user.

    Args:
        session: The async session used to add, commit and refresh the user.

    Returns:
        The committed ``User`` with email ``user@example.com`` and 100 credits.
    """
    import uuid

    user = User(
        user_id=str(uuid.uuid4()),
        email="user@example.com",
        name="Test User",
        credits=100,
    )
    session.add(user)
    await session.commit()
    # Refresh so database-populated columns are loaded on the instance.
    await session.refresh(user)
    return user
@pytest.fixture
async def admin_user(session):
    """Persist and return an admin test user.

    The email is the first entry of the comma-separated ``ADMIN_EMAILS``
    environment variable, falling back to ``admin@example.com`` when the
    variable is unset or empty.

    Args:
        session: The async session used to add, commit and refresh the user.

    Returns:
        The committed ``User`` with 1000 credits.
    """
    import uuid

    # NOTE(review): presumably the service recognises admins by matching the
    # email against ADMIN_EMAILS -- confirm against the service implementation.
    configured = os.getenv("ADMIN_EMAILS") or "admin@example.com"
    # An empty first entry (e.g. ",a@x.com") would otherwise yield "".
    admin_email = configured.split(",")[0].strip() or "admin@example.com"

    user = User(
        user_id=str(uuid.uuid4()),
        email=admin_email,
        name="Admin User",
        credits=1000,
    )
    session.add(user)
    await session.commit()
    await session.refresh(user)
    return user
@pytest.fixture
async def other_user(session):
    """Persist and return a second, unrelated test user.

    Args:
        session: The async session used to add, commit and refresh the user.

    Returns:
        The committed ``User`` with email ``other@example.com`` and 50 credits.
    """
    import uuid

    user = User(
        user_id=str(uuid.uuid4()),
        email="other@example.com",
        name="Other User",
        credits=50,
    )
    session.add(user)
    await session.commit()
    await session.refresh(user)
    return user
| # ============================================================================ | |
| # Configuration Tests | |
| # ============================================================================ | |
class TestConfiguration:
    """Checks on the values held by the registered DB Service configuration."""

    def test_config_registered(self):
        """The configuration is registered with the expected base and models."""
        config = DBServiceConfig
        assert config.is_registered()
        assert config.db_base == Base
        model_count = len(config.all_models)
        assert model_count == 8

    def test_get_registered_models(self):
        """get_registered_models() returns all eight configured models."""
        registered = get_registered_models()
        assert len(registered) == 8
        for expected_model in (User, GeminiJob):
            assert expected_model in registered

    def test_column_names(self):
        """The configured column names and special user model are exposed."""
        config = DBServiceConfig
        assert config.user_filter_column == "user_id"
        assert config.soft_delete_column == "deleted_at"
        assert config.special_user_model == User
| # ============================================================================ | |
| # Permission Tests | |
| # ============================================================================ | |
class TestPermissions:
    """Exercise the USER/ADMIN/SYSTEM permission hierarchy of QueryService."""

    @staticmethod
    def _queued_text_job(owner, prompt):
        """Build (without persisting) a queued text job owned by *owner*."""
        import uuid

        return GeminiJob(
            job_id=str(uuid.uuid4()),
            user_id=owner.id,
            job_type="text",
            input_data={"prompt": prompt},
            status="queued",
        )

    async def test_user_can_read_own_data(self, session, regular_user):
        """A user sees the job that belongs to them."""
        own_job = self._queued_text_job(regular_user, "Test")
        session.add(own_job)
        await session.commit()

        service = QueryService(regular_user, session)
        visible = await service.select().execute(select(GeminiJob))

        assert len(visible) == 1
        assert visible[0].id == own_job.id

    async def test_user_cannot_read_others_data(self, session, regular_user, other_user):
        """A user does not see jobs owned by somebody else."""
        # The only job in the database belongs to the other user.
        foreign_job = self._queued_text_job(other_user, "Other")
        session.add(foreign_job)
        await session.commit()

        # Query on behalf of the regular user.
        service = QueryService(regular_user, session)
        visible = await service.select().execute(select(GeminiJob))

        assert len(visible) == 0  # Should not see other user's jobs

    async def test_admin_can_read_all_data(self, session, admin_user, regular_user):
        """An admin sees jobs regardless of their owner."""
        # One job per owner.
        pending = [
            self._queued_text_job(regular_user, "User Job"),
            self._queued_text_job(admin_user, "Admin Job"),
        ]
        session.add_all(pending)
        await session.commit()

        service = QueryService(admin_user, session)
        visible = await service.select().execute(select(GeminiJob))

        assert len(visible) == 2  # Admin sees all jobs

    async def test_user_cannot_access_admin_only_models(self, session, regular_user):
        """Selecting an admin-only model as a regular user raises."""
        service = QueryService(regular_user, session)

        with pytest.raises(Exception) as exc_info:
            await service.select().execute(select(RateLimit))

        message = str(exc_info.value)
        assert "403" in message or "administrator" in message.lower()

    async def test_admin_can_access_admin_only_models(self, session, admin_user):
        """An admin can select from an admin-only model."""
        from datetime import datetime, timedelta

        window_start = datetime.now()
        limit_row = RateLimit(
            identifier="test",
            endpoint="/api/test",
            attempts=10,
            window_start=window_start,
            expires_at=window_start + timedelta(hours=1),
        )
        session.add(limit_row)
        await session.commit()

        service = QueryService(admin_user, session)
        limits = await service.select().execute(select(RateLimit))

        assert len(limits) == 1

    async def test_system_can_create_user(self, session, regular_user):
        """A QueryService built with is_system=True reports itself as system."""
        service = QueryService(regular_user, session, is_system=True)

        # System should be able to bypass permissions
        # (actual create would use direct SQLAlchemy, but permission check passes)
        assert service.is_system is True
| # ============================================================================ | |
| # Soft Delete Tests | |
| # ============================================================================ | |
class TestSoftDeletes:
    """Test soft delete and restore behaviour of QueryService."""

    @staticmethod
    async def _persisted_job(session, owner, prompt):
        """Create, add and commit a queued text job owned by *owner*.

        Args:
            session: The async session used to persist the job.
            owner: The user whose ``id`` becomes the job's ``user_id``.
            prompt: Text stored under the ``"prompt"`` key of ``input_data``.

        Returns:
            The committed ``GeminiJob`` instance.
        """
        import uuid

        job = GeminiJob(
            job_id=str(uuid.uuid4()),
            user_id=owner.id,
            job_type="text",
            input_data={"prompt": prompt},
            status="queued",
        )
        session.add(job)
        await session.commit()
        return job

    async def test_soft_delete_marks_record(self, session, regular_user):
        """Test that soft delete sets deleted_at."""
        job = await self._persisted_job(session, regular_user, "Delete Me")

        qs = QueryService(regular_user, session)
        await qs.delete().soft_delete_one(job)

        assert job.deleted_at is not None

    async def test_soft_deleted_not_in_query(self, session, regular_user):
        """Test that soft-deleted records don't appear in queries."""
        job = await self._persisted_job(session, regular_user, "Delete Me")
        qs = QueryService(regular_user, session)

        # Before delete: the job is visible to its owner.
        jobs = await qs.select().execute(select(GeminiJob))
        assert len(jobs) == 1

        # After delete: the same query no longer returns it.
        await qs.delete().soft_delete_one(job)
        jobs = await qs.select().execute(select(GeminiJob))
        assert len(jobs) == 0  # Should not appear

    async def test_admin_can_restore(self, session, admin_user, regular_user):
        """Test that admins can restore deleted records."""
        job = await self._persisted_job(session, regular_user, "Restore Me")
        qs = QueryService(admin_user, session)

        # Delete
        await qs.delete().soft_delete_one(job)
        assert job.deleted_at is not None

        # Restore
        await qs.delete().restore_one(job)
        assert job.deleted_at is None

    async def test_user_cannot_restore(self, session, regular_user):
        """Test that regular users cannot restore records."""
        job = await self._persisted_job(session, regular_user, "Deleted")
        qs = QueryService(regular_user, session)
        await qs.delete().soft_delete_one(job)

        # Only the restore call belongs inside the raises block; the message
        # check must run after the context manager has captured the error.
        with pytest.raises(Exception) as exc_info:
            await qs.delete().restore_one(job)

        message = str(exc_info.value)
        assert "403" in message or "administrator" in message.lower()
| # ============================================================================ | |
| # Database Initialization Tests | |
| # ============================================================================ | |
class TestDatabaseInitialization:
    """Checks for the init_database / reset_database utilities."""

    async def test_init_database_creates_tables(self, engine):
        """After init_database, the User table can be queried and is empty."""
        await init_database(engine)

        # A successful select proves the table exists.
        async with AsyncSession(engine) as probe:
            outcome = await probe.execute(select(User))
            rows = outcome.scalars().all()
            assert rows == []  # Empty but table exists

    async def test_reset_database_clears_data(self, engine, session, regular_user):
        """After reset_database, no User rows remain."""
        import uuid

        # Seed one extra row so there is data to clear.
        seeded = User(
            user_id=str(uuid.uuid4()),
            email="test@example.com",
            name="Test",
            credits=10,
        )
        session.add(seeded)
        await session.commit()

        # Reset
        await reset_database(engine)

        # Read through a fresh session to verify the data is gone.
        async with AsyncSession(engine) as new_session:
            outcome = await new_session.execute(select(User))
            remaining = outcome.scalars().all()
            assert len(remaining) == 0
| # ============================================================================ | |
| # Run Tests | |
| # ============================================================================ | |
if __name__ == "__main__":
    # pytest.main returns the run's exit code; propagate it so that running
    # this file directly exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))