| """ |
| Authentication dependencies for AegisLM SaaS Backend. |
| |
| Production-ready dependency injection for authentication, |
| authorization, and user context management. |
| """ |
|
|
| from typing import Optional |
| from fastapi import Depends, HTTPException, status, Request |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials |
| from sqlalchemy.ext.asyncio import AsyncSession |
| from sqlalchemy import select |
|
|
| from core.database import get_db |
| from services.auth_service import AuthService |
| from db_models.user import User |
| from db_models.api_key import ApiKey |
|
|
|
|
| |
# Shared bearer-token scheme for the dependencies in this module.
# auto_error=False: when no usable Authorization header is present the scheme
# hands the dependency None instead of rejecting the request itself, which is
# what lets the dependencies below fall back to the session cookie.
security = HTTPBearer(auto_error=False)
|
|
|
|
async def get_current_user(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
    db: AsyncSession = Depends(get_db)
) -> User:
    """
    Resolve the authenticated user for the incoming request.

    The bearer token is tried first (when Authorization credentials were
    supplied). If that yields no user, the ``session_id`` cookie is looked
    up through the auth service and the user referenced by the session's
    ``user_id`` entry is loaded.

    Args:
        request: HTTP request object (read for its cookies)
        credentials: HTTP Authorization credentials (optional)
        db: Database session

    Returns:
        User: The authenticated, active user

    Raises:
        HTTPException: 401 if no user could be resolved, or if the
            resolved user is not active
    """
    service = AuthService(db)

    # Step 1: bearer token, only when credentials were provided.
    user = (
        await service.verify_token(credentials.credentials)
        if credentials
        else None
    )

    # Step 2: session cookie fallback when the token produced no user.
    if not user:
        cookie_session_id = request.cookies.get("session_id")
        if cookie_session_id:
            session_data = await service.get_session(cookie_session_id)
            if session_data:
                user = await service.get_user_by_id(session_data.get("user_id"))

    # Both rejection cases share the same status code and challenge header;
    # only the detail message differs.
    failure_detail = None
    if not user:
        failure_detail = "Invalid authentication credentials"
    elif not user.is_active:
        failure_detail = "User account is deactivated"

    if failure_detail is not None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=failure_detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    return user
|
|
|
|
async def get_current_active_user(
    current_user: User = Depends(get_current_user)
) -> User:
    """
    Return the authenticated user, provided the account is active.

    Args:
        current_user: User resolved by ``get_current_user``

    Returns:
        User: The same user object, unchanged

    Raises:
        HTTPException: 400 if ``current_user.is_active`` is falsy
    """
    if current_user.is_active:
        return current_user

    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail="User account is not active"
    )
|
|
|
|
async def get_current_verified_user(
    current_user: User = Depends(get_current_active_user)
) -> User:
    """
    Return the active user, provided the account is verified.

    Args:
        current_user: User resolved by ``get_current_active_user``

    Returns:
        User: The same user object, unchanged

    Raises:
        HTTPException: 400 if ``current_user.is_verified`` is falsy
    """
    if current_user.is_verified:
        return current_user

    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail="User account is not verified"
    )
|
|
|
|
async def get_api_key_user(
    api_key: str,
    db: AsyncSession = Depends(get_db)
) -> User:
    """
    Resolve the user that owns the given API key.

    The key is looked up by exact match on ``ApiKey.key``, checked with
    ``is_valid()``, and its ``user_id`` is used to load the owning user.
    On success ``update_usage()`` is called on the key record and the
    session is committed before the user is returned.

    Args:
        api_key: API key string
        db: Database session

    Returns:
        User: User associated with the API key

    Raises:
        HTTPException: 401 if the key is unknown, fails ``is_valid()``,
            has no matching user, or the user is not active
    """

    def _unauthorized(detail: str) -> HTTPException:
        # Every rejection in this dependency is a plain 401 with a message.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail
        )

    lookup = select(ApiKey).where(ApiKey.key == api_key)
    key_record = (await db.execute(lookup)).scalar_one_or_none()

    if not key_record:
        raise _unauthorized("Invalid API key")
    if not key_record.is_valid():
        raise _unauthorized("API key is expired or inactive")

    owner = await db.get(User, key_record.user_id)
    if not owner:
        raise _unauthorized("User not found")
    if not owner.is_active:
        raise _unauthorized("User account is deactivated")

    # Only a fully accepted key has its usage updated and persisted.
    key_record.update_usage()
    await db.commit()

    return owner
|
|
|
|
async def get_optional_current_user(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
    db: AsyncSession = Depends(get_db)
) -> Optional[User]:
    """
    Get current user if authenticated, otherwise None.
    Supports both JWT and session-based authentication.

    This dependency is best-effort: any exception raised while verifying
    the token or loading the session is logged and the request is treated
    as unauthenticated instead of failing.

    Args:
        request: HTTP request object (read for its cookies)
        credentials: HTTP Authorization credentials (optional)
        db: Database session

    Returns:
        Optional[User]: The resolved user when one was found and is active,
            otherwise None
    """
    # Function-scope import keeps this block self-contained; getLogger
    # returns the same logger object on every call.
    import logging

    logger = logging.getLogger(__name__)

    auth_service = AuthService(db)
    user = None

    # Bearer token first, when credentials were supplied.
    if credentials:
        try:
            user = await auth_service.verify_token(credentials.credentials)
        except Exception:
            # Stay anonymous, but leave a trace: a silently swallowed error
            # here would make a backend failure look like a logged-out user.
            # The token itself is deliberately not logged.
            logger.warning(
                "Optional authentication: token verification failed",
                exc_info=True,
            )

    # Session cookie fallback when the token produced no user.
    if not user:
        try:
            session_id = request.cookies.get("session_id")
            if session_id:
                session_data = await auth_service.get_session(session_id)
                if session_data:
                    user = await auth_service.get_user_by_id(session_data.get("user_id"))
        except Exception:
            # Same best-effort policy as above; the session id is not logged.
            logger.warning(
                "Optional authentication: session lookup failed",
                exc_info=True,
            )

    # Inactive accounts are treated the same as unauthenticated requests.
    if user and user.is_active:
        return user

    return None
|
|
|
|
class RequirePermission:
    """
    Callable dependency that enforces a single named permission.

    An instance is built with the permission it requires; when called with
    the current user it returns that user if the permission check passes
    and raises 403 otherwise.
    """

    def __init__(self, required_permission: str):
        # Permission name handed to ``User.has_permission`` on every call.
        self.required_permission = required_permission

    def __call__(self, current_user: User = Depends(get_current_active_user)) -> User:
        """
        Verify that the user holds the required permission.

        Args:
            current_user: User resolved by ``get_current_active_user``

        Returns:
            User: The same user, when the permission check passes

        Raises:
            HTTPException: 403 if ``has_permission`` returns a falsy value
        """
        if current_user.has_permission(self.required_permission):
            return current_user

        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Insufficient permissions. Required: {self.required_permission}"
        )
|
|
|
|
| |
# Ready-made dependency markers, usable as parameter defaults in route
# signatures instead of spelling out Depends(...) each time.
RequireActiveUser = Depends(get_current_active_user)
RequireVerifiedUser = Depends(get_current_verified_user)
|
|
|
|
| |
async def get_rate_limit_key(
    current_user: Optional[User] = Depends(get_optional_current_user),
    api_key_user: Optional[User] = None
) -> str:
    """
    Build the rate-limit bucket key for the caller.

    Precedence: an authenticated user wins over an API-key user; when
    neither is present the shared ``"anonymous"`` key is returned.

    Args:
        current_user: User resolved by ``get_optional_current_user``, if any
        api_key_user: User resolved from an API key, if any

    Returns:
        str: ``"user:<id>"``, ``"api_key:<id>"`` or ``"anonymous"``
    """
    if current_user:
        return f"user:{current_user.id}"

    if api_key_user:
        # Note: the id used here is the user's id, not the key's.
        return f"api_key:{api_key_user.id}"

    # No identity available: every such caller shares one key.
    return "anonymous"
|
|