Spaces:
Sleeping
Sleeping
"""
Comprehensive Tests for Auth Router

Tests cover:
1. POST /auth/check-registration endpoint
2. POST /auth/google endpoint (Google Sign-In)
3. GET /auth/me endpoint (Get current user info)
4. POST /auth/refresh endpoint (Token refresh)
5. POST /auth/logout endpoint (User logout)
6. Helper functions (detect_client_type)

Uses mocked Google Auth service and database.
"""
| import pytest | |
| from datetime import datetime | |
| from unittest.mock import patch, MagicMock, AsyncMock | |
| from fastapi.testclient import TestClient | |
| # ============================================================================ | |
| # 1. POST /auth/check-registration Tests | |
| # ============================================================================ | |
class TestCheckRegistration:
    """Test POST /auth/check-registration endpoint."""

    @staticmethod
    def _make_client(client_user=None):
        """Build a TestClient for the auth router backed by a mocked database.

        Args:
            client_user: Value returned by ``scalar_one_or_none()`` for every
                query the router executes; ``None`` simulates "no ClientUser
                row found".

        Returns:
            A ``TestClient`` wrapping a fresh FastAPI app whose ``get_db``
            dependency is overridden.
        """
        # Imported lazily, matching the rest of this file.
        from routers.auth import router
        from fastapi import FastAPI
        from core.database import get_db

        app = FastAPI()

        async def mock_get_db():
            mock_db = AsyncMock()
            mock_result = MagicMock()
            mock_result.scalar_one_or_none.return_value = client_user
            mock_db.execute.return_value = mock_result
            yield mock_db

        app.dependency_overrides[get_db] = mock_get_db
        app.include_router(router)
        return TestClient(app)

    def test_check_registration_not_registered(self):
        """Unregistered temp user returns is_registered=False."""
        client = self._make_client(client_user=None)  # No ClientUser found
        with patch('routers.auth.check_rate_limit', return_value=True):
            response = client.post(
                "/auth/check-registration",
                json={"user_id": "temp_user_123"},
            )
        assert response.status_code == 200
        # Identity check: the field must be the JSON boolean, not merely falsy.
        assert response.json()["is_registered"] is False

    def test_check_registration_is_registered(self):
        """Registered temp user returns is_registered=True."""
        client = self._make_client(client_user=MagicMock())  # ClientUser exists
        with patch('routers.auth.check_rate_limit', return_value=True):
            response = client.post(
                "/auth/check-registration",
                json={"user_id": "temp_user_123"},
            )
        assert response.status_code == 200
        assert response.json()["is_registered"] is True

    def test_check_registration_rate_limited(self):
        """Rate limit blocks excessive requests."""
        client = self._make_client()
        with patch('routers.auth.check_rate_limit', return_value=False):
            response = client.post(
                "/auth/check-registration",
                json={"user_id": "temp_user_123"},
            )
        assert response.status_code == 429
        assert "too many" in response.json()["detail"].lower()
| # ============================================================================ | |
| # 2. POST /auth/google Tests | |
| # ============================================================================ | |
class TestGoogleAuth:
    """Test POST /auth/google endpoint."""

    @staticmethod
    def _make_google_user(google_id, email, name, picture="https://example.com/pic.jpg"):
        """Return a stand-in for the Google identity that ``verify_token`` yields."""
        google_user = MagicMock()
        google_user.google_id = google_id
        google_user.email = email
        google_user.name = name
        google_user.picture = picture
        return google_user

    @staticmethod
    def _make_user(user_id, email, name, google_id, credits,
                   picture="https://example.com/pic.jpg"):
        """Return a ``User``-spec'd mock for an account that already exists.

        Every fixture populates the same set of attributes so the web, mobile
        and existing-user tests exercise the router with equivalent users.
        """
        from core.models import User

        user = MagicMock(spec=User)
        user.id = 1
        user.user_id = user_id
        user.email = email
        user.google_id = google_id
        user.name = name
        user.credits = credits
        user.token_version = 1
        user.profile_picture = picture
        return user

    @staticmethod
    def _make_client(existing_user=None):
        """Build a TestClient for the auth router backed by a mocked database.

        Args:
            existing_user: Value returned by ``scalar_one_or_none()`` for every
                query; ``None`` means no account matches the Google identity.
        """
        from routers.auth import router
        from fastapi import FastAPI
        from core.database import get_db

        app = FastAPI()

        async def mock_get_db():
            mock_db = AsyncMock()
            # AsyncSession.add() is synchronous; a plain MagicMock avoids
            # "coroutine was never awaited" warnings if the router calls it.
            mock_db.add = MagicMock()
            mock_result = MagicMock()
            mock_result.scalar_one_or_none.return_value = existing_user
            mock_db.execute.return_value = mock_result
            yield mock_db

        app.dependency_overrides[get_db] = mock_get_db
        app.include_router(router)
        return TestClient(app)

    @staticmethod
    def _patch_sign_in(google_user, client_type=None):
        """Patch every collaborator of a successful ``/auth/google`` call.

        Args:
            google_user: Identity returned by the patched ``verify_token``.
            client_type: If given, ``detect_client_type`` is patched to return
                this value (the tests use "web" and "mobile").

        Returns:
            An ``ExitStack`` holding the active patches. Use it in a ``with``
            statement so every patch is undone when the block exits.
        """
        from contextlib import ExitStack

        with ExitStack() as stack:
            mock_service = stack.enter_context(
                patch('routers.auth.get_google_auth_service')
            )
            mock_service.return_value.verify_token.return_value = google_user
            stack.enter_context(
                patch('routers.auth.check_rate_limit', return_value=True)
            )
            # patch() substitutes an AsyncMock automatically when the target is
            # an async function, so no hand-built return value is needed.
            stack.enter_context(patch('routers.auth.AuditService.log_event'))
            stack.enter_context(patch('services.backup_service.get_backup_service'))
            if client_type is not None:
                stack.enter_context(
                    patch('routers.auth.detect_client_type', return_value=client_type)
                )
            # Hand the entered patches to the caller; if any enter_context()
            # above raised, this with-block undoes the ones already applied.
            return stack.pop_all()

    def test_google_auth_new_user(self):
        """New user sign-in creates user account."""
        google_user = self._make_google_user("123456", "newuser@example.com", "New User")
        # Every user lookup returns None: no account exists yet.
        client = self._make_client(existing_user=None)
        with self._patch_sign_in(google_user):
            response = client.post(
                "/auth/google",
                json={"id_token": "fake-google-token"},
            )
        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True
        assert "access_token" in data
        assert data["email"] == "newuser@example.com"
        assert data["is_new_user"] is True

    def test_google_auth_existing_user(self):
        """Existing user sign-in returns user data."""
        google_user = self._make_google_user("123456", "existing@example.com", "Existing User")
        existing_user = self._make_user(
            "usr_existing", "existing@example.com", "Existing User",
            google_id="123456", credits=100,
        )
        client = self._make_client(existing_user=existing_user)
        with self._patch_sign_in(google_user):
            response = client.post(
                "/auth/google",
                json={"id_token": "fake-google-token"},
            )
        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True
        assert data["user_id"] == "usr_existing"
        assert data["is_new_user"] is False
        assert data["credits"] == 100

    def test_google_auth_web_client_cookie(self):
        """Web client receives refresh token as HttpOnly cookie."""
        google_user = self._make_google_user("web123", "web@example.com", "Web User")
        existing_user = self._make_user(
            "usr_web", "web@example.com", "Web User",
            google_id="web123", credits=50,
        )
        client = self._make_client(existing_user=existing_user)
        with self._patch_sign_in(google_user, client_type="web"):
            response = client.post(
                "/auth/google",
                json={"id_token": "fake-google-token"},
                headers={"User-Agent": "Mozilla/5.0"},
            )
        assert response.status_code == 200
        # Check cookie was set
        assert "refresh_token" in response.cookies
        # The docstring promises an HttpOnly cookie; verify the flag is sent.
        set_cookie = response.headers.get("set-cookie", "")
        assert "refresh_token=" in set_cookie
        assert "httponly" in set_cookie.lower()
        # Refresh token should NOT be in JSON body for web
        assert "refresh_token" not in response.json()

    def test_google_auth_mobile_client_json(self):
        """Mobile client receives refresh token in JSON body."""
        google_user = self._make_google_user("mobile123", "mobile@example.com", "Mobile User")
        existing_user = self._make_user(
            "usr_mobile", "mobile@example.com", "Mobile User",
            google_id="mobile123", credits=50,
        )
        client = self._make_client(existing_user=existing_user)
        with self._patch_sign_in(google_user, client_type="mobile"):
            response = client.post(
                "/auth/google",
                json={"id_token": "fake-google-token"},
                headers={"User-Agent": "MyApp/1.0"},
            )
        assert response.status_code == 200
        # Refresh token SHOULD be in JSON body for mobile
        assert "refresh_token" in response.json()

    def test_google_auth_invalid_token(self):
        """Invalid Google token returns 401."""
        from services.auth_service.google_provider import InvalidTokenError

        client = self._make_client()
        with patch('routers.auth.get_google_auth_service') as mock_service, \
                patch('routers.auth.check_rate_limit', return_value=True):
            mock_service.return_value.verify_token.side_effect = InvalidTokenError("Invalid token")
            response = client.post(
                "/auth/google",
                json={"id_token": "invalid-token"},
            )
        assert response.status_code == 401
        assert "invalid" in response.json()["detail"].lower()

    def test_google_auth_rate_limited(self):
        """Rate limit blocks excessive requests."""
        client = self._make_client()
        with patch('routers.auth.check_rate_limit', return_value=False):
            response = client.post(
                "/auth/google",
                json={"id_token": "any-token"},
            )
        assert response.status_code == 429
| # ============================================================================ | |
| # 3. GET /auth/me Tests | |
| # ============================================================================ | |
class TestGetCurrentUserInfo:
    """Test GET /auth/me endpoint."""

    @staticmethod
    def _make_client(current_user=None):
        """Build a TestClient for the auth router.

        ``get_db`` is always overridden with an ``AsyncMock`` session so the
        tests never resolve the real database dependency. When
        *current_user* is given, ``get_current_user`` is overridden to
        return it; otherwise the real authentication dependency runs.
        """
        from routers.auth import router
        from fastapi import FastAPI
        from core.database import get_db
        from core.dependencies import get_current_user

        app = FastAPI()

        async def mock_get_db():
            yield AsyncMock()

        app.dependency_overrides[get_db] = mock_get_db
        if current_user is not None:
            app.dependency_overrides[get_current_user] = lambda: current_user
        app.include_router(router)
        return TestClient(app)

    def test_get_me_requires_auth(self):
        """GET /me requires authentication."""
        client = self._make_client()
        response = client.get("/auth/me")
        # Should fail with auth error
        assert response.status_code in [401, 403, 422]

    def test_get_me_returns_user_info(self):
        """GET /me returns authenticated user info."""
        from core.models import User

        # Mock authenticated user
        mock_user = MagicMock(spec=User)
        mock_user.user_id = "usr_123"
        mock_user.email = "user@example.com"
        mock_user.name = "Test User"
        mock_user.credits = 75
        mock_user.profile_picture = "https://example.com/pic.jpg"

        client = self._make_client(current_user=mock_user)
        response = client.get("/auth/me")
        assert response.status_code == 200
        data = response.json()
        assert data["user_id"] == "usr_123"
        assert data["email"] == "user@example.com"
        assert data["name"] == "Test User"
        assert data["credits"] == 75
| # ============================================================================ | |
| # 4. POST /auth/refresh Tests | |
| # ============================================================================ | |
class TestTokenRefresh:
    """Test POST /auth/refresh endpoint."""

    @staticmethod
    def _make_client(user=None):
        """Build a TestClient for the auth router backed by a mocked database.

        Args:
            user: Object returned by ``scalar_one_or_none()`` for every query
                the router executes (``None`` means "no such user").
        """
        from routers.auth import router
        from fastapi import FastAPI
        from core.database import get_db

        app = FastAPI()

        async def mock_get_db():
            mock_db = AsyncMock()
            mock_result = MagicMock()
            mock_result.scalar_one_or_none.return_value = user
            mock_db.execute.return_value = mock_result
            yield mock_db

        app.dependency_overrides[get_db] = mock_get_db
        app.include_router(router)
        return TestClient(app)

    @staticmethod
    def _make_user(user_id, token_version, email=None):
        """Return a ``User``-spec'd mock with the fields the refresh tests set."""
        from core.models import User

        user = MagicMock(spec=User)
        user.user_id = user_id
        user.token_version = token_version
        if email is not None:
            user.email = email
        return user

    def test_refresh_with_valid_token_in_body(self):
        """Refresh with valid token in body returns new tokens."""
        from services.auth_service.jwt_provider import create_refresh_token

        # Create a valid refresh token
        refresh_token = create_refresh_token("usr_123", "user@example.com", token_version=1)
        client = self._make_client(self._make_user("usr_123", 1, email="user@example.com"))
        with patch('routers.auth.check_rate_limit', return_value=True):
            response = client.post(
                "/auth/refresh",
                json={"token": refresh_token},
            )
        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True
        assert "access_token" in data
        assert "refresh_token" in data  # New refresh token (rotation)

    def test_refresh_with_cookie(self):
        """Refresh with cookie returns new tokens and rotates cookie."""
        from services.auth_service.jwt_provider import create_refresh_token

        refresh_token = create_refresh_token("usr_456", "user2@example.com", token_version=1)
        client = self._make_client(self._make_user("usr_456", 1, email="user2@example.com"))
        # Set refresh token in cookie; the JSON body is deliberately empty.
        client.cookies.set("refresh_token", refresh_token)
        with patch('routers.auth.check_rate_limit', return_value=True):
            response = client.post("/auth/refresh", json={})
        assert response.status_code == 200
        # Cookie should be rotated
        assert "refresh_token" in response.cookies

    def test_refresh_missing_token(self):
        """Refresh without token returns 401."""
        client = self._make_client()
        with patch('routers.auth.check_rate_limit', return_value=True):
            response = client.post("/auth/refresh", json={})  # No token
        assert response.status_code == 401
        assert "missing" in response.json()["detail"].lower()

    def test_refresh_wrong_token_type(self):
        """Refresh with access token (not refresh) returns 401."""
        from services.auth_service.jwt_provider import create_access_token

        # Create access token instead of refresh
        access_token = create_access_token("usr_123", "user@example.com")
        client = self._make_client()
        with patch('routers.auth.check_rate_limit', return_value=True):
            response = client.post(
                "/auth/refresh",
                json={"token": access_token},
            )
        assert response.status_code == 401
        assert "invalid token type" in response.json()["detail"].lower()

    def test_refresh_invalidated_token(self):
        """Refresh with old token version returns 401."""
        from services.auth_service.jwt_provider import create_refresh_token

        # Token carries version 1, but the stored user is already at version 2
        # (e.g. after a logout), so the token must be rejected.
        refresh_token = create_refresh_token("usr_123", "user@example.com", token_version=1)
        client = self._make_client(self._make_user("usr_123", 2))
        with patch('routers.auth.check_rate_limit', return_value=True):
            response = client.post(
                "/auth/refresh",
                json={"token": refresh_token},
            )
        assert response.status_code == 401
        assert "invalidated" in response.json()["detail"].lower()

    def test_refresh_rate_limited(self):
        """Rate limit blocks excessive refresh attempts."""
        client = self._make_client()
        with patch('routers.auth.check_rate_limit', return_value=False):
            response = client.post(
                "/auth/refresh",
                json={"token": "any-token"},
            )
        assert response.status_code == 429
| # ============================================================================ | |
| # 5. POST /auth/logout Tests | |
| # ============================================================================ | |
class TestLogout:
    """Test POST /auth/logout endpoint."""

    @staticmethod
    def _make_client(current_user=None):
        """Build a TestClient for the auth router.

        ``get_db`` is always overridden with an ``AsyncMock`` session so the
        tests never resolve the real database dependency. When
        *current_user* is given, ``get_current_user`` is overridden to
        return it; otherwise the real authentication dependency runs.
        """
        from routers.auth import router
        from fastapi import FastAPI
        from core.database import get_db
        from core.dependencies import get_current_user

        app = FastAPI()

        async def mock_get_db():
            yield AsyncMock()

        app.dependency_overrides[get_db] = mock_get_db
        if current_user is not None:
            app.dependency_overrides[get_current_user] = lambda: current_user
        app.include_router(router)
        return TestClient(app)

    @staticmethod
    def _make_user():
        """Return a ``User``-spec'd mock at token_version 1."""
        from core.models import User

        user = MagicMock(spec=User)
        user.id = 1
        user.user_id = "usr_123"
        user.token_version = 1
        return user

    def test_logout_requires_auth(self):
        """Logout requires authentication."""
        client = self._make_client()
        response = client.post("/auth/logout")
        assert response.status_code in [401, 403, 422]

    def test_logout_increments_token_version(self):
        """Logout increments user's token version."""
        mock_user = self._make_user()
        client = self._make_client(current_user=mock_user)
        # patch() substitutes an AsyncMock automatically for async targets.
        with patch('routers.auth.AuditService.log_event'), \
                patch('services.backup_service.get_backup_service'):
            response = client.post("/auth/logout")
        assert response.status_code == 200
        # Token version should be incremented
        assert mock_user.token_version == 2

    def test_logout_deletes_cookie(self):
        """Logout deletes refresh token cookie."""
        client = self._make_client(current_user=self._make_user())
        with patch('routers.auth.AuditService.log_event'), \
                patch('services.backup_service.get_backup_service'):
            # Browser User-Agent, so the request looks like a web client that
            # holds a cookie to clear.
            response = client.post(
                "/auth/logout",
                headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/91.0"},
            )
        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True
        assert "logged out" in data["message"].lower()
        # Deleting a cookie means re-sending it already expired; Starlette's
        # delete_cookie() emits Max-Age=0 for that.
        set_cookie = response.headers.get("set-cookie", "")
        assert "refresh_token=" in set_cookie
        assert "max-age=0" in set_cookie.lower()
| # ============================================================================ | |
| # Helper Function Tests | |
| # ============================================================================ | |
class TestHelperFunctions:
    """Test helper functions in auth router."""

    @staticmethod
    def _make_request(user_agent):
        """Return a Request-spec'd mock carrying only a User-Agent header.

        A real, case-insensitive ``Headers`` mapping is used instead of a
        mocked ``headers.get``. A mocked ``get`` would answer *every* header
        lookup with the User-Agent, which could let a test pass for the wrong
        reason. With ``Headers``, other lookups see the header as absent.
        """
        from fastapi import Request
        from fastapi.datastructures import Headers

        mock_request = MagicMock(spec=Request)
        mock_request.headers = Headers({"user-agent": user_agent})
        return mock_request

    def test_detect_client_type_web(self):
        """detect_client_type identifies web browsers."""
        from routers.auth import detect_client_type

        request = self._make_request("Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/91.0")
        assert detect_client_type(request) == "web"

    def test_detect_client_type_mobile(self):
        """detect_client_type identifies mobile apps."""
        from routers.auth import detect_client_type

        request = self._make_request("MyApp/1.0 iOS")
        assert detect_client_type(request) == "mobile"
if __name__ == "__main__":
    # pytest.main() only *returns* the exit status; raise it so a direct run
    # exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))