| """
|
| Worker Authentication for AegisLM
|
|
|
| Provides worker token authentication for distributed evaluation workers.
|
| Each worker must authenticate with a signed token before processing jobs.
|
| """
|
|
|
import hashlib
import hmac
import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel

from security.secret_manager import get_jwt_secret, get_jwt_algorithm, get_worker_secret
|
|
|
|
|
|
|
# Shared Bearer-token extractor for the worker dependencies in this module.
# auto_error=False makes it yield None (instead of raising 401) when no usable
# Bearer Authorization header is sent, so get_current_worker and
# get_current_worker_optional can fall back to the X-Worker-Token header.
worker_scheme = HTTPBearer(auto_error=False)
|
|
|
|
|
class WorkerTokenPayload(BaseModel):
    """
    Claims carried by a worker JWT.

    Returned by WorkerAuth.decode_worker_token after the token is decoded.
    """

    # Worker identifier ("worker_id" claim).
    worker_id: str
    # String form of the worker's UUID; create_worker_token stores str(worker_uuid).
    worker_uuid: str
    # Worker hostname ("hostname" claim).
    hostname: str
    # Issued-at time, converted from the token's numeric "iat" claim.
    iat: datetime
    # Expiry time, converted from the token's numeric "exp" claim.
    exp: datetime
|
|
|
|
|
class WorkerAuth:
    """
    Worker authentication service.

    Issues and validates the JWTs that evaluation workers present, and
    creates/verifies HMAC-SHA256 request signatures over
    ``"{worker_id}:{timestamp}"`` keyed with the worker secret.
    """

    @staticmethod
    def _unauthorized(detail: str) -> HTTPException:
        """Build the 401 exception used for every worker-token rejection."""
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Worker-Token"},
        )

    @staticmethod
    def create_worker_token(
        worker_id: str,
        worker_uuid: uuid.UUID,
        hostname: str,
        expires_delta: Optional[timedelta] = None,
    ) -> str:
        """
        Create a signed JWT for a worker.

        Args:
            worker_id: Unique worker identifier.
            worker_uuid: Worker UUID; stored in the token as its string form.
            hostname: Worker hostname.
            expires_delta: Token lifetime; defaults to 24 hours.

        Returns:
            Encoded JWT string, signed with the JWT secret using the
            configured JWT algorithm.
        """
        if expires_delta is None:
            expires_delta = timedelta(hours=24)

        # Timezone-aware UTC. datetime.utcnow() is deprecated (Python 3.12)
        # and returns a naive value; PyJWT turns either form into the same
        # NumericDate, so issued tokens are unchanged.
        now = datetime.now(timezone.utc)

        payload = {
            "worker_id": worker_id,
            "worker_uuid": str(worker_uuid),
            "hostname": hostname,
            "iat": now,
            "exp": now + expires_delta,
        }

        return jwt.encode(
            payload,
            get_jwt_secret(),
            algorithm=get_jwt_algorithm(),
        )

    @staticmethod
    def decode_worker_token(token: str) -> WorkerTokenPayload:
        """
        Decode and validate a worker token.

        Args:
            token: Encoded JWT string.

        Returns:
            WorkerTokenPayload whose ``iat``/``exp`` are timezone-aware UTC
            datetimes.

        Raises:
            HTTPException: 401 if the token is expired, fails signature or
                algorithm checks, lacks ``exp``/``iat``, or does not carry
                well-formed worker claims.

        NOTE(review): worker tokens carry no token-type claim, so any JWT
        signed with the same secret that has these claims is accepted.
        """
        try:
            payload = jwt.decode(
                token,
                get_jwt_secret(),
                algorithms=[get_jwt_algorithm()],
                # A token without these time claims would otherwise never
                # expire (and previously crashed on the lookups below).
                options={"require": ["exp", "iat"]},
            )
        except jwt.ExpiredSignatureError as exc:
            # Must precede InvalidTokenError, of which it is a subclass.
            raise WorkerAuth._unauthorized("Worker token has expired") from exc
        except jwt.InvalidTokenError as exc:
            raise WorkerAuth._unauthorized("Invalid worker token") from exc

        try:
            return WorkerTokenPayload(
                worker_id=payload["worker_id"],
                worker_uuid=payload["worker_uuid"],
                hostname=payload["hostname"],
                # Without tz=, fromtimestamp() yields *local* wall-clock time,
                # which disagreed with the UTC times used when issuing tokens.
                iat=datetime.fromtimestamp(payload["iat"], tz=timezone.utc),
                exp=datetime.fromtimestamp(payload["exp"], tz=timezone.utc),
            )
        except (KeyError, TypeError, ValueError, OverflowError, OSError) as exc:
            # A validly signed JWT with missing or malformed worker claims is
            # still not a worker token: answer 401 instead of surfacing a 500.
            raise WorkerAuth._unauthorized("Invalid worker token") from exc

    @staticmethod
    def verify_worker_signature(
        worker_id: str,
        timestamp: str,
        signature: str,
    ) -> bool:
        """
        Verify a worker request signature.

        Recomputes the HMAC-SHA256 of ``"{worker_id}:{timestamp}"`` with the
        worker secret and compares it to ``signature`` in constant time.

        NOTE(review): the timestamp is only bound into the MAC; its freshness
        is not checked here, so callers must reject stale timestamps to
        prevent replay.

        Args:
            worker_id: Worker identifier.
            timestamp: Request timestamp, exactly as it was signed.
            signature: Hex-encoded HMAC signature to check.

        Returns:
            True if the signature matches, otherwise False.
        """
        expected_signature = WorkerAuth.create_worker_signature(worker_id, timestamp)
        # hmac.compare_digest() raises TypeError for str arguments containing
        # non-ASCII characters; comparing bytes makes a malformed signature
        # simply fail verification instead of raising.
        return hmac.compare_digest(
            signature.encode(),
            expected_signature.encode(),
        )

    @staticmethod
    def create_worker_signature(
        worker_id: str,
        timestamp: str,
    ) -> str:
        """
        Create the HMAC signature for a worker request.

        NOTE(review): the message is not length-prefixed, so a worker_id
        containing ':' can produce the same message as a different
        (worker_id, timestamp) split. Changing the format would invalidate
        signatures produced by existing workers.

        Args:
            worker_id: Worker identifier.
            timestamp: Request timestamp.

        Returns:
            Hex-encoded HMAC-SHA256 of ``"{worker_id}:{timestamp}"`` keyed
            with the worker secret.
        """
        message = f"{worker_id}:{timestamp}"
        return hmac.new(
            get_worker_secret().encode(),
            message.encode(),
            hashlib.sha256,
        ).hexdigest()
|
|
|
|
|
async def get_current_worker(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(worker_scheme),
) -> WorkerTokenPayload:
    """
    FastAPI dependency resolving the authenticated worker for a request.

    The token comes from the ``Authorization: Bearer`` header when present,
    otherwise from the ``X-Worker-Token`` header. It is then passed to
    ``WorkerAuth.decode_worker_token``, whose errors propagate unchanged.

    Raises:
        HTTPException: 401 when neither header supplies a token.
    """
    if credentials:
        # Bearer credentials take precedence over the custom header.
        return WorkerAuth.decode_worker_token(credentials.credentials)

    header_token = request.headers.get("X-Worker-Token")
    if header_token:
        return WorkerAuth.decode_worker_token(header_token)

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Worker authentication required",
        headers={"WWW-Authenticate": "Worker-Token"},
    )
|
|
|
|
|
async def get_current_worker_optional(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(worker_scheme),
) -> Optional[WorkerTokenPayload]:
    """
    FastAPI dependency returning the authenticated worker, or None.

    The token is read from the ``Authorization: Bearer`` header first, then
    from ``X-Worker-Token``. A missing or empty token, or one that
    ``WorkerAuth.decode_worker_token`` rejects with an HTTPException, yields
    None.
    """
    token = (
        credentials.credentials
        if credentials
        else request.headers.get("X-Worker-Token")
    )
    if not token:
        return None

    try:
        return WorkerAuth.decode_worker_token(token)
    except HTTPException:
        # Authentication is optional here: treat a rejected token as absent.
        return None
|
|
|
|
|
def require_worker_auth():
    """
    Build a FastAPI dependency that requires an authenticated worker.

    The returned coroutine function resolves ``get_current_worker`` and
    returns its WorkerTokenPayload, so unauthenticated requests fail with
    the 401 raised there.

    This is a dependency factory, not a decorator. Writing
    ``@require_worker_auth()`` above an endpoint would call the checker with
    the endpoint function and hand the router a coroutine object instead of
    a route handler.

    Usage:
        @router.get("/worker/status", dependencies=[Depends(require_worker_auth())])
        async def get_worker_status():
            ...

        @router.get("/worker/me")
        async def get_worker_me(
            worker: WorkerTokenPayload = Depends(require_worker_auth()),
        ):
            ...
    """

    async def checker(worker: WorkerTokenPayload = Depends(get_current_worker)):
        # All validation happens in get_current_worker; just forward its result.
        return worker

    return checker
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_registration_token(
    worker_id: str,
    expires_delta: Optional[timedelta] = None,
) -> str:
    """
    Create a short-lived registration token for a new worker.

    The token is an HS256 JWT signed with the worker secret and tagged with
    ``"type": "registration"``.

    Args:
        worker_id: Worker identifier embedded in the token.
        expires_delta: Token lifetime; defaults to 1 hour.

    Returns:
        Encoded registration JWT string.
    """
    if expires_delta is None:
        expires_delta = timedelta(hours=1)

    # Aware UTC instead of the deprecated, naive datetime.utcnow(); PyJWT
    # encodes both to the same NumericDate.
    now = datetime.now(timezone.utc)

    payload = {
        "type": "registration",
        "worker_id": worker_id,
        "iat": now,
        "exp": now + expires_delta,
    }

    return jwt.encode(
        payload,
        get_worker_secret(),
        algorithm="HS256",
    )
|
|
|
|
|
def verify_registration_token(token: str) -> Optional[str]:
    """
    Verify a registration token.

    A valid token meets all of these conditions:
    - it is an HS256 JWT signed with the worker secret;
    - it carries an ``exp`` claim that has not passed;
    - its ``type`` claim is ``"registration"``.

    Args:
        token: Registration token string.

    Returns:
        The token's ``worker_id`` claim if the token is valid (None if that
        claim is absent), otherwise None.
    """
    try:
        payload = jwt.decode(
            token,
            get_worker_secret(),
            algorithms=["HS256"],
            # Registration tokens are meant to be short-lived; refuse tokens
            # that carry no expiry instead of accepting them forever.
            options={"require": ["exp"]},
        )
    except jwt.InvalidTokenError:
        return None

    if payload.get("type") != "registration":
        return None

    return payload.get("worker_id")
|
|
|