Spaces:
Sleeping
Sleeping
| """ | |
| Comprehensive Tests for SQLAlchemy Models | |
| Tests cover all 8 models: | |
| 1. User - Authentication, credits, soft delete | |
| 2. ClientUser - Device tracking, IP mapping | |
| 3. AuditLog - Event logging | |
| 4. GeminiJob - Job queue, priority, credit tracking | |
| 5. PaymentTransaction - Payment processing | |
| 6. Contact - Support tickets | |
| 7. RateLimit - Rate limiting | |
| 8. ApiKeyUsage - API key rotation | |
Exercises CRUD operations, status transitions, relationships, soft deletes, and the unique-email constraint.
| """ | |
| import pytest | |
| from datetime import datetime, timedelta | |
| from sqlalchemy import select | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
| # ============================================================================ | |
| # 1. User Model Tests | |
| # ============================================================================ | |
class TestUserModel:
    """CRUD, uniqueness, token-versioning and soft-delete tests for ``User``.

    NOTE(review): these are bare ``async def`` tests with no asyncio marker;
    presumably pytest-asyncio runs in auto mode — confirm in the pytest config.
    """

    async def test_create_user(self, db_session):
        """Create a new user and check the persisted fields and defaults."""
        from core.models import User

        user = User(
            user_id="usr_test_001",
            email="test@example.com",
            google_id="google_123",
            name="Test User",
            credits=50,
        )
        db_session.add(user)
        await db_session.commit()
        await db_session.refresh(user)

        assert user.id is not None
        assert user.email == "test@example.com"
        assert user.credits == 50
        assert user.token_version == 1  # Default

    async def test_user_unique_email(self, db_session):
        """Email must be unique: a second user with the same email is rejected."""
        from core.models import User
        from sqlalchemy.exc import IntegrityError

        user1 = User(user_id="usr_001", email="duplicate@example.com")
        db_session.add(user1)
        await db_session.commit()

        user2 = User(user_id="usr_002", email="duplicate@example.com")
        db_session.add(user2)
        with pytest.raises(IntegrityError):
            await db_session.commit()
        # A failed flush leaves the session's transaction inactive; SQLAlchemy
        # requires an explicit rollback before the session can be used again
        # (e.g. by fixture teardown).
        await db_session.rollback()

    async def test_user_token_versioning(self, db_session):
        """Incrementing token_version (as logout does) is persisted."""
        from core.models import User

        user = User(user_id="usr_tv", email="tv@example.com")
        db_session.add(user)
        await db_session.commit()
        await db_session.refresh(user)
        assert user.token_version == 1

        # Increment version (simulate logout)
        user.token_version += 1
        await db_session.commit()
        # Reload from the database so the assertion checks the stored value,
        # not just the in-memory attribute.
        await db_session.refresh(user)
        assert user.token_version == 2

    async def test_user_soft_delete(self, db_session):
        """Soft delete sets deleted_at but keeps the row in the table."""
        from core.models import User

        user = User(user_id="usr_del", email="delete@example.com")
        db_session.add(user)
        await db_session.commit()

        # Soft delete
        user.deleted_at = datetime.utcnow()
        await db_session.commit()

        # Select the column straight from the table: scalar_one() raises if the
        # row is gone, so this proves the row survived with deleted_at stored
        # (checking ``user.id`` only inspects the in-memory object).
        result = await db_session.execute(
            select(User.deleted_at).where(User.id == user.id)
        )
        stored_deleted_at = result.scalar_one()
        assert stored_deleted_at is not None
| # ============================================================================ | |
| # 2. ClientUser Model Tests | |
| # ============================================================================ | |
class TestClientUserModel:
    """ClientUser rows map client-side identifiers/devices to an optional User."""

    async def test_create_client_user(self, db_session):
        """A ClientUser linked to an existing User stores that user's id."""
        from core.models import User, ClientUser

        owner = User(user_id="usr_cu", email="cu@example.com")
        db_session.add(owner)
        await db_session.commit()

        mapping = ClientUser(
            user_id=owner.id,
            client_user_id="temp_client_123",
            ip_address="192.168.1.1",
            device_fingerprint="abc123",
        )
        db_session.add(mapping)
        await db_session.commit()

        assert mapping.id is not None
        assert mapping.user_id == owner.id

    async def test_client_user_anonymous(self, db_session):
        """A ClientUser may be stored without an owning User (user_id=None)."""
        from core.models import ClientUser

        anonymous = ClientUser(
            user_id=None,  # no server-side account
            client_user_id="anon_123",
            ip_address="10.0.0.1",
        )
        db_session.add(anonymous)
        await db_session.commit()

        assert anonymous.id is not None
        assert anonymous.user_id is None

    async def test_client_user_relationship(self, db_session):
        """ClientUsers appended to User.client_users are saved with the user's FK."""
        from core.models import User, ClientUser

        owner = User(user_id="usr_rel", email="rel@example.com")
        for client_id, ip in (("c1", "1.1.1.1"), ("c2", "2.2.2.2")):
            owner.client_users.append(
                ClientUser(client_user_id=client_id, ip_address=ip)
            )
        db_session.add(owner)
        await db_session.commit()

        # Look the mappings up by foreign key, independent of the relationship.
        query = select(ClientUser).where(ClientUser.user_id == owner.id)
        mappings = (await db_session.execute(query)).scalars().all()
        assert len(mappings) == 2
| # ============================================================================ | |
| # 3. AuditLog Model Tests | |
| # ============================================================================ | |
class TestAuditLogModel:
    """AuditLog rows for client-originated and server-originated events."""

    async def test_create_client_audit_log(self, db_session):
        """A client-side entry is keyed by client_user_id and carries a details dict."""
        from core.models import AuditLog

        entry = AuditLog(
            log_type="client",
            client_user_id="temp_123",
            action="page_view",
            status="success",
            details={"page": "/home"},
        )
        db_session.add(entry)
        await db_session.commit()

        assert entry.id is not None
        assert entry.log_type == "client"

    async def test_create_server_audit_log(self, db_session):
        """A server-side entry references an existing User by id."""
        from core.models import User, AuditLog

        actor = User(user_id="usr_audit", email="audit@example.com")
        db_session.add(actor)
        await db_session.commit()

        entry = AuditLog(
            log_type="server",
            user_id=actor.id,
            action="credit_deduction",
            status="success",
            details={"amount": 5},
        )
        db_session.add(entry)
        await db_session.commit()

        assert entry.user_id == actor.id
| # ============================================================================ | |
| # 4. GeminiJob Model Tests | |
| # ============================================================================ | |
class TestGeminiJobModel:
    """GeminiJob creation, status lifecycle, and priority filtering."""

    async def test_create_job(self, db_session):
        """A queued job stores its type, priority and reserved credits."""
        from core.models import User, GeminiJob

        user = User(user_id="usr_job", email="job@example.com")
        db_session.add(user)
        await db_session.commit()

        job = GeminiJob(
            job_id="job_001",
            user_id=user.id,
            job_type="video",
            status="queued",
            priority="fast",
            credits_reserved=10,
        )
        db_session.add(job)
        await db_session.commit()

        assert job.id is not None
        assert job.status == "queued"
        assert job.credits_reserved == 10

    async def test_job_status_transitions(self, db_session):
        """Walk a job through queued -> processing -> completed."""
        from core.models import User, GeminiJob

        user = User(user_id="usr_status", email="status@example.com")
        db_session.add(user)
        await db_session.commit()

        job = GeminiJob(
            job_id="job_lifecycle",
            user_id=user.id,
            job_type="video",
            status="queued",
        )
        db_session.add(job)
        await db_session.commit()

        # Transition to processing
        job.status = "processing"
        job.started_at = datetime.utcnow()
        await db_session.commit()

        # Transition to completed
        job.status = "completed"
        job.completed_at = datetime.utcnow()
        await db_session.commit()

        assert job.status == "completed"
        assert job.started_at is not None
        assert job.completed_at is not None

    async def test_job_priority_system(self, db_session):
        """Filtering by priority returns only this user's matching job."""
        from core.models import User, GeminiJob

        user = User(user_id="usr_priority", email="priority@example.com")
        db_session.add(user)
        await db_session.commit()

        fast_job = GeminiJob(
            job_id="job_fast", user_id=user.id, job_type="video", priority="fast"
        )
        medium_job = GeminiJob(
            job_id="job_medium", user_id=user.id, job_type="video", priority="medium"
        )
        slow_job = GeminiJob(
            job_id="job_slow", user_id=user.id, job_type="video", priority="slow"
        )
        db_session.add_all([fast_job, medium_job, slow_job])
        await db_session.commit()

        # Query by priority, scoped to this user: other tests also create
        # "fast" jobs (e.g. job_001), which would inflate the count if the
        # database is shared between tests.
        result = await db_session.execute(
            select(GeminiJob).where(
                GeminiJob.user_id == user.id,
                GeminiJob.priority == "fast",
            )
        )
        jobs = result.scalars().all()
        assert len(jobs) == 1
        assert jobs[0].job_id == "job_fast"
| # ============================================================================ | |
| # 5. PaymentTransaction Model Tests | |
| # ============================================================================ | |
class TestPaymentTransactionModel:
    """PaymentTransaction creation and status updates."""

    @staticmethod
    async def _save(session, *instances):
        """Add every instance to *session* and commit."""
        session.add_all(instances)
        await session.commit()

    async def test_create_payment(self, db_session):
        """A new transaction records gateway, package, credits and amount."""
        from core.models import User, PaymentTransaction

        payer = User(user_id="usr_pay", email="pay@example.com")
        await self._save(db_session, payer)

        txn = PaymentTransaction(
            transaction_id="txn_001",
            user_id=payer.id,
            gateway="razorpay",
            package_id="starter",
            credits_amount=100,
            amount_paise=9900,
            status="created",
        )
        await self._save(db_session, txn)

        assert txn.id is not None
        assert txn.amount_paise == 9900

    async def test_payment_status_transitions(self, db_session):
        """Moving a transaction from 'created' to 'paid' records paid_at."""
        from core.models import User, PaymentTransaction

        payer = User(user_id="usr_paystat", email="paystat@example.com")
        await self._save(db_session, payer)

        txn = PaymentTransaction(
            transaction_id="txn_002",
            user_id=payer.id,
            gateway="razorpay",
            package_id="pro",
            credits_amount=1000,
            amount_paise=49900,
            status="created",
        )
        await self._save(db_session, txn)

        # Mark the payment as completed by the gateway.
        txn.status = "paid"
        txn.paid_at = datetime.utcnow()
        txn.gateway_payment_id = "pay_abc123"
        await db_session.commit()

        assert txn.status == "paid"
        assert txn.paid_at is not None
| # ============================================================================ | |
| # 6. Contact Model Tests | |
| # ============================================================================ | |
class TestContactModel:
    """Contact (support form) submissions."""

    async def test_create_contact(self, db_session):
        """A submission tied to a user stores its subject and message."""
        from core.models import User, Contact

        sender = User(user_id="usr_contact", email="contact@example.com")
        db_session.add(sender)
        await db_session.commit()

        ticket = Contact(
            user_id=sender.id,
            email=sender.email,
            subject="Help with credits",
            message="I need assistance with my credit balance.",
            ip_address="192.168.1.100",
        )
        db_session.add(ticket)
        await db_session.commit()

        assert ticket.id is not None
        assert ticket.subject == "Help with credits"
| # ============================================================================ | |
| # 7. RateLimit Model Tests | |
| # ============================================================================ | |
class TestRateLimitModel:
    """RateLimit entries keyed by identifier and endpoint."""

    async def test_create_rate_limit(self, db_session):
        """A fresh entry with a 15-minute window stores its attempt count."""
        from core.models import RateLimit

        started = datetime.utcnow()
        window_end = started + timedelta(minutes=15)
        entry = RateLimit(
            identifier="192.168.1.1",
            endpoint="/auth/google",
            attempts=1,
            window_start=started,
            expires_at=window_end,
        )
        db_session.add(entry)
        await db_session.commit()

        assert entry.id is not None
        assert entry.attempts == 1

    async def test_rate_limit_increment(self, db_session):
        """Bumping the attempt counter is committed like any other update."""
        from core.models import RateLimit

        started = datetime.utcnow()
        window_end = started + timedelta(minutes=15)
        entry = RateLimit(
            identifier="10.0.0.1",
            endpoint="/auth/refresh",
            attempts=1,
            window_start=started,
            expires_at=window_end,
        )
        db_session.add(entry)
        await db_session.commit()

        # One more attempt inside the same window.
        entry.attempts += 1
        await db_session.commit()

        assert entry.attempts == 2
| # ============================================================================ | |
| # 8. ApiKeyUsage Model Tests | |
| # ============================================================================ | |
class TestApiKeyUsageModel:
    """ApiKeyUsage counters per key index."""

    async def test_create_api_key_usage(self, db_session):
        """A usage row can be created with all counters at zero."""
        from core.models import ApiKeyUsage

        stats = ApiKeyUsage(
            key_index=0,
            total_requests=0,
            success_count=0,
            failure_count=0,
        )
        db_session.add(stats)
        await db_session.commit()

        assert stats.id is not None
        assert stats.key_index == 0

    async def test_api_key_usage_tracking(self, db_session):
        """Success and failure updates move the matching counters."""
        from core.models import ApiKeyUsage

        stats = ApiKeyUsage(key_index=1)
        db_session.add(stats)
        await db_session.commit()

        # Record one successful request.
        stats.total_requests += 1
        stats.success_count += 1
        stats.last_used_at = datetime.utcnow()
        await db_session.commit()
        assert stats.total_requests == 1
        assert stats.success_count == 1

        # Record one failed request.
        stats.total_requests += 1
        stats.failure_count += 1
        stats.last_error = "Quota exceeded"
        await db_session.commit()
        assert stats.total_requests == 2
        assert stats.failure_count == 1
| # ============================================================================ | |
| # Relationship Tests | |
| # ============================================================================ | |
class TestModelRelationships:
    """One-to-many links from User, checked by querying on the foreign key."""

    async def test_user_jobs_relationship(self, db_session):
        """User can have multiple jobs."""
        from core.models import User, GeminiJob

        owner = User(user_id="usr_jobs", email="jobs@example.com")
        db_session.add(owner)
        await db_session.commit()

        db_session.add_all([
            GeminiJob(job_id="job_rel_1", user_id=owner.id, job_type="video"),
            GeminiJob(job_id="job_rel_2", user_id=owner.id, job_type="image"),
        ])
        await db_session.commit()

        query = select(GeminiJob).where(GeminiJob.user_id == owner.id)
        owned_jobs = (await db_session.execute(query)).scalars().all()
        assert len(owned_jobs) == 2

    async def test_user_payments_relationship(self, db_session):
        """User can have multiple payments."""
        from core.models import User, PaymentTransaction

        buyer = User(user_id="usr_payments", email="payments@example.com")
        db_session.add(buyer)
        await db_session.commit()

        shared = {"user_id": buyer.id, "gateway": "razorpay"}
        db_session.add_all([
            PaymentTransaction(
                transaction_id="txn_1",
                package_id="starter",
                credits_amount=100,
                amount_paise=9900,
                **shared,
            ),
            PaymentTransaction(
                transaction_id="txn_2",
                package_id="pro",
                credits_amount=1000,
                amount_paise=49900,
                **shared,
            ),
        ])
        await db_session.commit()

        query = select(PaymentTransaction).where(
            PaymentTransaction.user_id == buyer.id
        )
        owned_payments = (await db_session.execute(query)).scalars().all()
        assert len(owned_payments) == 2
if __name__ == "__main__":
    # pytest.main() returns the session exit code rather than exiting; raise it
    # so a direct `python <file>` run exits non-zero when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))