""" Authentication Middleware for AegisLM Provides FastAPI middleware for authentication and authorization. """ import hashlib import jwt import uuid from datetime import datetime, timedelta from typing import Optional from dataclasses import dataclass from fastapi import Depends, HTTPException, Request, status from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from backend.db.models import User, Tenant, APIKey from backend.db.session import get_db_session from security.rbac import RBACContext, Role from security.tenant_scope import TenantScope, set_tenant_scope # Security scheme security = HTTPBearer() # JWT configuration - loaded from secret manager def _get_jwt_secret() -> str: """Get JWT secret from secret manager.""" from security.secret_manager import get_jwt_secret return get_jwt_secret() def _get_jwt_algorithm() -> str: """Get JWT algorithm from secret manager.""" from security.secret_manager import get_secret_manager return get_secret_manager().get_jwt_algorithm() def _get_jwt_expiration_hours() -> int: """Get JWT expiration hours from secret manager.""" from security.secret_manager import get_secret_manager return get_secret_manager().get_jwt_expiration_hours() @dataclass class AuthenticatedUser: """Represents an authenticated user.""" user_id: uuid.UUID tenant_id: uuid.UUID email: str role: Role is_api_client: bool = False def hash_api_key(api_key: str) -> str: """Hash an API key for storage/comparison.""" return hashlib.sha256(api_key.encode()).hexdigest() def create_jwt_token( user_id: uuid.UUID, tenant_id: uuid.UUID, email: str, role: str, expires_delta: Optional[timedelta] = None, ) -> str: """ Create a JWT token for a user. Args: user_id: User ID tenant_id: Tenant ID email: User email role: User role expires_delta: Token expiration time delta Returns: JWT token string """ if expires_delta is None: expires_delta = timedelta(hours=_get_jwt_expiration_hours()) expire = datetime.utcnow() + expires_delta payload = { "sub": str(user_id), "tenant_id": str(tenant_id), "email": email, "role": role, "exp": expire, "iat": datetime.utcnow(), } return jwt.encode(payload, _get_jwt_secret(), algorithm=_get_jwt_algorithm()) def decode_jwt_token(token: str) -> dict: """ Decode and validate a JWT token. Args: token: JWT token string Returns: Decoded token payload Raises: HTTPException: If token is invalid or expired """ try: payload = jwt.decode( token, _get_jwt_secret(), algorithms=[_get_jwt_algorithm()] ) return payload except jwt.ExpiredSignatureError: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Token has expired", headers={"WWW-Authenticate": "Bearer"}, ) except jwt.InvalidTokenError: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token", headers={"WWW-Authenticate": "Bearer"}, ) async def get_db() -> AsyncSession: """Get database session.""" async for session in get_db_session(): yield session async def get_current_user_from_token( token: str, ) -> tuple[uuid.UUID, uuid.UUID, str, Role]: """ Validate JWT token and extract user info WITHOUT database access. This function validates the token signature and expiration ONLY. It does NOT query the database. Returns: Tuple of (user_id, tenant_id, email, role) Raises: HTTPException: If token is invalid or expired """ # Step 1: Validate JWT signature and expiration (NO DB ACCESS) payload = decode_jwt_token(token) # Step 2: Extract claims from validated token user_id = uuid.UUID(payload["sub"]) tenant_id = uuid.UUID(payload["tenant_id"]) email = payload["email"] role = Role(payload["role"]) return user_id, tenant_id, email, role async def get_current_user( credentials: HTTPAuthorizationCredentials = Depends(security), db: AsyncSession = Depends(get_db), ) -> AuthenticatedUser: """ FastAPI dependency to get the current authenticated user. CRITICAL: This validates the JWT signature FIRST (no DB), then only queries DB if token is valid. This ensures 401 is returned BEFORE any database access for unauthenticated requests. """ # Step 1: Validate token WITHOUT database (fails fast for invalid tokens) token = credentials.credentials user_id, tenant_id, email, role = await get_current_user_from_token(token) # Step 2: Only query DB AFTER token validation succeeds query = select(User).where( User.id == user_id, User.tenant_id == tenant_id, User.active == True, ) result = await db.execute(query) user = result.scalar_one_or_none() if user is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive", headers={"WWW-Authenticate": "Bearer"}, ) return AuthenticatedUser( user_id=user.id, tenant_id=user.tenant_id, email=user.email, role=Role(user.role), is_api_client=False, ) async def get_current_user_optional( request: Request, db: AsyncSession = Depends(get_db), ) -> Optional[AuthenticatedUser]: """ FastAPI dependency to get the current authenticated user, optionally. Returns None if no valid authentication is provided. """ # Check for Bearer token auth_header = request.headers.get("Authorization") if auth_header and auth_header.startswith("Bearer "): token = auth_header[7:] try: payload = decode_jwt_token(token) user_id = uuid.UUID(payload["sub"]) tenant_id = uuid.UUID(payload["tenant_id"]) query = select(User).where( User.id == user_id, User.tenant_id == tenant_id, User.active == True, ) result = await db.execute(query) user = result.scalar_one_or_none() if user: return AuthenticatedUser( user_id=user.id, tenant_id=user.tenant_id, email=user.email, role=Role(user.role), is_api_client=False, ) except Exception: pass # Check for API key api_key = request.headers.get("X-API-Key") if api_key: return await verify_api_key(db, api_key) return None async def verify_api_key( db: AsyncSession, api_key: str, ) -> Optional[AuthenticatedUser]: """ Verify an API key and return the associated user. Args: db: Database session api_key: API key to verify Returns: AuthenticatedUser if valid, None otherwise """ key_hash = hash_api_key(api_key) query = select(APIKey).where( APIKey.key_hash == key_hash, APIKey.active == True, ) result = await db.execute(query) api_key_obj = result.scalar_one_or_none() if api_key_obj is None: return None # Update last used api_key_obj.last_used = datetime.utcnow() await db.commit() # Get the tenant query = select(Tenant).where(Tenant.id == api_key_obj.tenant_id) result = await db.execute(query) tenant = result.scalar_one_or_none() if tenant is None or not tenant.active: return None return AuthenticatedUser( user_id=api_key_obj.id, # Use API key ID as user_id for API clients tenant_id=api_key_obj.tenant_id, email=f"api:{api_key_obj.owner}", role=Role.API_CLIENT, is_api_client=True, ) async def get_current_tenant( user: AuthenticatedUser = Depends(get_current_user), ) -> uuid.UUID: """Get the current tenant ID from the authenticated user.""" return user.tenant_id class AuthMiddleware: """ Authentication middleware for FastAPI. Provides request authentication and sets up tenant context. """ @staticmethod def hash_api_key(api_key: str) -> str: """Hash an API key for storage/comparison.""" return hashlib.sha256(api_key.encode()).hexdigest() @staticmethod async def verify_api_key(db: AsyncSession, api_key: str) -> Optional[AuthenticatedUser]: """ Verify an API key and return the associated user. """ return await verify_api_key(db, api_key) @staticmethod async def authenticate_request( request: Request, db: AsyncSession, ) -> Optional[AuthenticatedUser]: """ Authenticate a request using either JWT or API key. Checks Authorization header for Bearer token or X-API-Key. """ # Check for Bearer token auth_header = request.headers.get("Authorization") if auth_header and auth_header.startswith("Bearer "): token = auth_header[7:] try: payload = decode_jwt_token(token) user_id = uuid.UUID(payload["sub"]) tenant_id = uuid.UUID(payload["tenant_id"]) query = select(User).where( User.id == user_id, User.tenant_id == tenant_id, User.active == True, ) result = await db.execute(query) user = result.scalar_one_or_none() if user: return AuthenticatedUser( user_id=user.id, tenant_id=user.tenant_id, email=user.email, role=Role(user.role), is_api_client=False, ) except Exception: pass # Check for API key api_key = request.headers.get("X-API-Key") if api_key: return await verify_api_key(db, api_key) return None class TenantContextMiddleware: """ Middleware to set up tenant context for each request. This ensures all database queries are properly scoped to the current tenant. """ def __init__(self, app): self.app = app async def __call__(self, scope, receive, send): if scope["type"] != "http": await self.app(scope, receive, send) return # TODO: Extract tenant from request and set context # This would typically be done after authentication await self.app(scope, receive, send) async def require_role(required_role: Role): """ FastAPI dependency to require a specific role. Usage: @router.get("/admin/users") async def list_users(user: AuthenticatedUser = Depends(require_role(Role.ADMIN))): ... """ async def role_checker(user: AuthenticatedUser = Depends(get_current_user)): if user.role != required_role: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=f"Requires {required_role.value} role" ) return user return role_checker async def require_permission(permission: str): """ FastAPI dependency to require a specific permission. Usage: @router.post("/jobs") async def create_job( user: AuthenticatedUser = Depends(require_permission("create_job")) ): ... """ async def permission_checker(user: AuthenticatedUser = Depends(get_current_user)): # Import here to avoid circular imports from security.rbac import RBAC, Permission # Convert string permission to enum try: perm = Permission(permission) except ValueError: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid permission: {permission}" ) # Check if user has permission if not RBAC.has_permission(user.role, perm): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=f"Requires {permission} permission" ) return user return permission_checker def create_rbac_context(user: AuthenticatedUser) -> RBACContext: """ Create an RBAC context from an authenticated user. This can be stored in request state for easy access. """ return RBACContext( user_id=user.user_id, tenant_id=user.tenant_id, role=user.role, ) async def setup_tenant_context( user: AuthenticatedUser = Depends(get_current_user), ) -> TenantScope: """ Set up tenant context for the current request. This ensures all subsequent database queries are scoped to the tenant. """ scope = TenantScope(tenant_id=user.tenant_id) set_tenant_scope(scope) return scope