Spaces:
Build error
Build error
| import logging | |
| from typing import Any, Dict | |
| from datetime import datetime | |
| from fastapi import Request, HTTPException, status | |
| from sqlalchemy import select | |
| from google_auth_service.fastapi_hooks import AuthHooks | |
| from core.database import async_session_maker | |
| from core.dependencies import check_rate_limit | |
| from services.audit_service import AuditService | |
| from core.models import ClientUser, User | |
# Per-module logger; records are tagged with this module's dotted import path.
logger = logging.getLogger(name=__name__)
class CoreAuthHooks(AuthHooks):
    """
    Custom authentication hooks for the API Gateway.

    Handles rate limiting, audit logging, client-user linking and
    post-login/logout backups.
    """

    # Arguments forwarded to ``check_rate_limit`` for the Google login route.
    # NOTE(review): the unit of RATE_LIMIT_WINDOW is defined by
    # ``core.dependencies.check_rate_limit`` (presumably minutes) -- confirm there.
    RATE_LIMIT_ENDPOINT = "/auth/google"
    RATE_LIMIT_MAX_ATTEMPTS = 10
    RATE_LIMIT_WINDOW = 1

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _client_ip(request: Request) -> str:
        """Return the peer IP of *request*, or ``"unknown"`` when unavailable.

        ``request.client`` is ``None`` when the ASGI server does not report a
        peer address, so the attribute access must be guarded.
        """
        client = request.client
        return client.host if client is not None else "unknown"

    @staticmethod
    def _user_id(user: Any) -> Any:
        """Return the id of *user*.

        Accepts either an object exposing ``.id`` (e.g. the SQLAlchemy ``User``
        model) or a mapping with an ``"id"`` key. Presumably the mapping form
        only comes from test doubles of the user store -- verify against caller.
        """
        if isinstance(user, dict):
            return user.get("id")
        return user.id

    @staticmethod
    async def _read_temp_user_id(request: Request) -> Any:
        """Best-effort read of ``temp_user_id`` from the JSON request body.

        Returns the value as a string, or ``None`` when:
        - the body is absent or not JSON;
        - the body is not a JSON object;
        - the key is missing or holds a non-scalar value.

        A missing or unreadable body is normal (client linking is optional), so
        it is logged at debug level rather than raised.
        """
        try:
            # NOTE(review): relies on the body still being readable (or cached
            # on this Request instance) after the auth endpoint consumed it.
            payload = await request.json()
        except Exception:
            logger.debug("Login request body is not readable JSON; skipping client linking")
            return None

        if not isinstance(payload, dict):
            return None

        temp_user_id = payload.get("temp_user_id")
        # The value is client-controlled: only accept scalar ids, and never
        # pass dicts/lists/bools through to the DB query.
        if isinstance(temp_user_id, bool) or not isinstance(temp_user_id, (str, int)):
            return None
        return str(temp_user_id) or None

    @staticmethod
    async def _link_client_user(db, user_id: Any, temp_user_id: str, ip: str) -> None:
        """Create or refresh the ClientUser row linking *temp_user_id* to *user_id*, then commit."""
        result = await db.execute(
            select(ClientUser).where(
                ClientUser.user_id == user_id,
                ClientUser.client_user_id == temp_user_id,
            )
        )
        existing = result.scalar_one_or_none()
        # Naive UTC timestamp, kept as-is to match the existing column values.
        now = datetime.utcnow()
        if existing is None:
            db.add(
                ClientUser(
                    user_id=user_id,
                    client_user_id=temp_user_id,
                    ip_address=ip,
                    last_seen_at=now,
                )
            )
        else:
            existing.last_seen_at = now
        # Committed on its own so a later audit-log failure cannot roll it back.
        await db.commit()

    @staticmethod
    async def _trigger_backup(reason: str) -> None:
        """Run a backup, logging (not raising) any failure.

        By the time this runs, the login/logout has already been handled and
        committed, so a backup error must not turn it into a failed request.
        """
        # Imported lazily, as before (presumably to avoid an import cycle).
        from services.backup_service import get_backup_service

        try:
            await get_backup_service().backup_async()
        except Exception:
            logger.exception("Backup after %s failed", reason)

    # -------------------------------------------------------------------- hooks

    async def before_login(self, request: Request):
        """Reject the login with HTTP 429 when the client IP exceeds the rate limit.

        Raises:
            HTTPException: 429 Too Many Requests when ``check_rate_limit`` refuses.
        """
        ip = self._client_ip(request)
        async with async_session_maker() as db:
            allowed = await check_rate_limit(
                db,
                ip,
                self.RATE_LIMIT_ENDPOINT,
                self.RATE_LIMIT_MAX_ATTEMPTS,
                self.RATE_LIMIT_WINDOW,
            )
        if not allowed:
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail="Too many authentication attempts"
            )

    async def on_login_success(self, user: Any, tokens: Dict[str, str], request: Request, is_new_user: bool = False):
        """Record a successful Google login.

        1. If the request body carries ``temp_user_id``, create or refresh the
           ClientUser row linking it to *user* (committed separately).
        2. Write a ``google_auth`` / ``success`` audit event and commit it.
        3. Trigger a backup; failures are logged, not raised.

        Args:
            user: The authenticated user (model with ``.id`` or dict with ``"id"``).
            tokens: Issued tokens; not used here.
            request: The login request; its JSON body and client IP are read.
            is_new_user: Not used here.
        """
        ip = self._client_ip(request)
        user_id = self._user_id(user)
        temp_user_id = await self._read_temp_user_id(request)

        async with async_session_maker() as db:
            if temp_user_id:
                await self._link_client_user(db, user_id, temp_user_id, ip)

            await AuditService.log_event(
                db=db,
                log_type="server",
                user_id=user_id,
                client_user_id=temp_user_id,
                action="google_auth",
                status="success",
                request=request
            )
            await db.commit()

        await self._trigger_backup("login")

    async def on_login_error(self, error: Exception, request: Request):
        """Write a ``google_auth`` / ``failed`` audit event describing *error*.

        Audit failures are logged rather than raised so they cannot mask the
        original login error.
        """
        try:
            async with async_session_maker() as db:
                await AuditService.log_event(
                    db=db,
                    log_type="server",
                    action="google_auth",
                    status="failed",
                    error_message=str(error),
                    request=request
                )
                # Closing an AsyncSession rolls back uncommitted work, so commit
                # explicitly (harmless if log_event already committed).
                await db.commit()
        except Exception:
            logger.exception("Failed to record login failure in audit log")

    async def on_logout(self, user: Any, request: Request):
        """Write a ``logout`` audit event (when a user is known), then trigger a backup.

        Args:
            user: The logging-out user (model with ``.id`` or dict with
                ``"id"``), or a falsy value when unknown.
            request: The logout request, passed through to the audit log.
        """
        if user:
            async with async_session_maker() as db:
                await AuditService.log_event(
                    db=db,
                    log_type="server",
                    user_id=self._user_id(user),
                    action="logout",
                    status="success",
                    request=request
                )
                # Commit explicitly; closing the session would otherwise roll back.
                await db.commit()

        await self._trigger_backup("logout")