"""
Worker Authentication for AegisLM

Provides worker token authentication for distributed evaluation workers.
Each worker must authenticate with a signed token before processing jobs.
"""

import hashlib
import hmac
import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel

from security.secret_manager import get_jwt_secret, get_jwt_algorithm, get_worker_secret

# Security scheme. auto_error=False lets the dependencies below fall back to
# the X-Worker-Token header instead of failing when Authorization is absent.
worker_scheme = HTTPBearer(auto_error=False)

# Algorithm shared by create_registration_token / verify_registration_token.
_REGISTRATION_ALGORITHM = "HS256"

# Challenge header attached to every 401 raised by this module.
_WWW_AUTHENTICATE = {"WWW-Authenticate": "Worker-Token"}


class WorkerTokenPayload(BaseModel):
    """Worker token payload structure."""

    worker_id: str
    worker_uuid: str
    hostname: str
    # Issued-at and expiry instants; decode_worker_token fills these with
    # timezone-aware UTC datetimes.
    iat: datetime
    exp: datetime


class WorkerAuth:
    """
    Worker authentication service.

    Validates worker tokens and ensures only registered workers
    can process evaluation jobs.
    """

    @staticmethod
    def create_worker_token(
        worker_id: str,
        worker_uuid: uuid.UUID,
        hostname: str,
        expires_delta: Optional[timedelta] = None,
    ) -> str:
        """
        Create a JWT token for a worker.

        Args:
            worker_id: Unique worker identifier
            worker_uuid: Worker UUID (stored in the token as a string)
            hostname: Worker hostname
            expires_delta: Token lifetime; defaults to 24 hours

        Returns:
            JWT token string
        """
        if expires_delta is None:
            expires_delta = timedelta(hours=24)

        # Timezone-aware "now": datetime.utcnow() is deprecated and returns a
        # naive value that is easy to misinterpret as local time.
        issued_at = datetime.now(timezone.utc)

        payload = {
            "worker_id": worker_id,
            "worker_uuid": str(worker_uuid),
            "hostname": hostname,
            "iat": issued_at,
            "exp": issued_at + expires_delta,
        }

        return jwt.encode(
            payload,
            get_jwt_secret(),
            algorithm=get_jwt_algorithm(),
        )

    @staticmethod
    def decode_worker_token(token: str) -> WorkerTokenPayload:
        """
        Decode and validate a worker token.

        Args:
            token: JWT token string

        Returns:
            WorkerTokenPayload whose ``iat``/``exp`` are timezone-aware UTC

        Raises:
            HTTPException: 401 if the token is expired, invalid, or does not
                carry well-formed worker claims
        """
        try:
            claims = jwt.decode(
                token,
                get_jwt_secret(),
                algorithms=[get_jwt_algorithm()],
            )
        except jwt.ExpiredSignatureError as exc:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Worker token has expired",
                headers=dict(_WWW_AUTHENTICATE),
            ) from exc
        except jwt.InvalidTokenError as exc:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid worker token",
                headers=dict(_WWW_AUTHENTICATE),
            ) from exc

        # A correctly signed JWT can still lack the worker claims (e.g. some
        # other token type signed with the same secret). Treat a missing or
        # malformed claim as an invalid token rather than letting KeyError or
        # a validation error surface as a 500.
        try:
            return WorkerTokenPayload(
                worker_id=claims["worker_id"],
                worker_uuid=claims["worker_uuid"],
                hostname=claims["hostname"],
                # Interpret the epoch seconds as UTC; without tz= the result
                # would be shifted into the host's local timezone.
                iat=datetime.fromtimestamp(claims["iat"], tz=timezone.utc),
                exp=datetime.fromtimestamp(claims["exp"], tz=timezone.utc),
            )
        except (KeyError, TypeError, ValueError, OverflowError) as exc:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid worker token",
                headers=dict(_WWW_AUTHENTICATE),
            ) from exc

    @staticmethod
    def verify_worker_signature(
        worker_id: str,
        timestamp: str,
        signature: str,
    ) -> bool:
        """
        Verify worker signature using HMAC.

        Args:
            worker_id: Worker identifier
            timestamp: Request timestamp
            signature: HMAC signature (hex digest) supplied by the caller

        Returns:
            True if signature is valid
        """
        # NOTE(review): the timestamp is only bound into the MAC, its
        # freshness is not checked here -- replay protection, if required,
        # must be enforced by the caller.
        if not isinstance(signature, str):
            return False

        expected_signature = WorkerAuth.create_worker_signature(worker_id, timestamp)

        # Compare as bytes: hmac.compare_digest raises TypeError for str
        # arguments containing non-ASCII characters, which would turn a
        # garbage signature into a server error instead of a rejection.
        return hmac.compare_digest(
            signature.encode(),
            expected_signature.encode(),
        )

    @staticmethod
    def create_worker_signature(
        worker_id: str,
        timestamp: str,
    ) -> str:
        """
        Create HMAC signature for worker request.

        The signed message is ``"{worker_id}:{timestamp}"`` and the key is the
        shared worker secret.

        Args:
            worker_id: Worker identifier
            timestamp: Request timestamp

        Returns:
            HMAC-SHA256 signature as a hex digest
        """
        worker_secret = get_worker_secret()
        message = f"{worker_id}:{timestamp}"

        return hmac.new(
            worker_secret.encode(),
            message.encode(),
            hashlib.sha256,
        ).hexdigest()


def _extract_worker_token(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials],
) -> Optional[str]:
    """Return the raw worker token from Authorization, else X-Worker-Token."""
    if credentials:
        return credentials.credentials
    return request.headers.get("X-Worker-Token")


async def get_current_worker(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(worker_scheme),
) -> WorkerTokenPayload:
    """
    FastAPI dependency to get the current authenticated worker.

    Reads the token from the Authorization header, falling back to the
    X-Worker-Token header, and validates it.

    Raises:
        HTTPException: 401 if no token is supplied or the token is invalid
    """
    token = _extract_worker_token(request, credentials)

    if not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Worker authentication required",
            headers=dict(_WWW_AUTHENTICATE),
        )

    # Decode and validate token
    return WorkerAuth.decode_worker_token(token)


async def get_current_worker_optional(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(worker_scheme),
) -> Optional[WorkerTokenPayload]:
    """
    FastAPI dependency to get the current authenticated worker, optionally.

    Returns None if no valid authentication is provided.
    """
    token = _extract_worker_token(request, credentials)

    if not token:
        return None

    try:
        return WorkerAuth.decode_worker_token(token)
    except HTTPException:
        # Deliberate best-effort: an invalid token is treated as anonymous.
        return None


def require_worker_auth():
    """
    Dependency factory to require worker authentication.

    The returned callable is a FastAPI dependency, not a function decorator:
    applying ``@require_worker_auth()`` directly to a route handler would
    replace the handler. Use it through ``Depends``.

    Usage:
        @router.get("/worker/status")
        async def get_worker_status(
            worker: WorkerTokenPayload = Depends(require_worker_auth()),
        ):
            ...
    """

    async def checker(worker: WorkerTokenPayload = Depends(get_current_worker)):
        return worker

    return checker


# =============================================================================
# Worker registration token (for initial registration)
# =============================================================================


def create_registration_token(
    worker_id: str,
    expires_delta: Optional[timedelta] = None,
) -> str:
    """
    Create a registration token for new workers.

    The token is signed with the worker secret (not the JWT secret) and is
    tagged with ``type == "registration"``.

    Args:
        worker_id: Worker identifier
        expires_delta: Token lifetime; defaults to 1 hour

    Returns:
        Registration token
    """
    if expires_delta is None:
        expires_delta = timedelta(hours=1)

    # Timezone-aware replacement for the deprecated datetime.utcnow().
    issued_at = datetime.now(timezone.utc)

    payload = {
        "type": "registration",
        "worker_id": worker_id,
        "iat": issued_at,
        "exp": issued_at + expires_delta,
    }

    return jwt.encode(
        payload,
        get_worker_secret(),
        algorithm=_REGISTRATION_ALGORITHM,
    )


def verify_registration_token(token: str) -> Optional[str]:
    """
    Verify a registration token.

    Args:
        token: Registration token

    Returns:
        Worker ID if valid, None otherwise (bad signature, expired, or not a
        registration-type token)
    """
    try:
        payload = jwt.decode(
            token,
            get_worker_secret(),
            algorithms=[_REGISTRATION_ALGORITHM],
        )
    except jwt.InvalidTokenError:
        return None

    if payload.get("type") != "registration":
        return None

    return payload.get("worker_id")