# aegislm/security/worker_auth.py
# (repository page residue: "ACA050's picture" / "Upload 57 files" / commit f2c6053, verified)
"""
Worker Authentication for AegisLM
Provides worker token authentication for distributed evaluation workers.
Each worker must authenticate with a signed token before processing jobs.
"""
import hashlib
import hmac
import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel

from security.secret_manager import get_jwt_secret, get_jwt_algorithm, get_worker_secret
# Security scheme: HTTP Bearer extractor used as a FastAPI dependency for
# worker tokens. auto_error=False is set so that a missing Authorization
# header yields None rather than an automatic error response; the worker
# dependencies in this module then check the X-Worker-Token header themselves.
worker_scheme = HTTPBearer(auto_error=False)
class WorkerTokenPayload(BaseModel):
    """Claims carried by a decoded worker token."""

    # Identity claims, all stored as strings (the UUID is kept in its
    # string form, as written by the token issuer).
    worker_id: str
    worker_uuid: str
    hostname: str

    # Issued-at and expiry instants of the token.
    iat: datetime
    exp: datetime
class WorkerAuth:
    """
    Worker authentication service.

    Issues and validates worker JWTs and HMAC request signatures so that
    only registered workers can process evaluation jobs.
    """

    @staticmethod
    def _naive_utc(epoch_seconds) -> datetime:
        """Convert an epoch-seconds claim to a naive datetime expressed in UTC."""
        # fromtimestamp() without tz would return *local* time, which does not
        # match the UTC instants written by create_worker_token. The tzinfo is
        # dropped afterwards so the result stays a naive datetime, as before.
        return datetime.fromtimestamp(epoch_seconds, tz=timezone.utc).replace(tzinfo=None)

    @staticmethod
    def create_worker_token(
        worker_id: str,
        worker_uuid: uuid.UUID,
        hostname: str,
        expires_delta: Optional[timedelta] = None,
    ) -> str:
        """
        Create a JWT token for a worker.

        Args:
            worker_id: Unique worker identifier
            worker_uuid: Worker UUID (stored in the token as a string)
            hostname: Worker hostname
            expires_delta: Token lifetime; defaults to 24 hours

        Returns:
            JWT token string
        """
        if expires_delta is None:
            expires_delta = timedelta(hours=24)

        # Timezone-aware UTC replaces the deprecated naive datetime.utcnow();
        # the resulting iat/exp claims denote the same instants.
        now = datetime.now(timezone.utc)
        payload = {
            "worker_id": worker_id,
            "worker_uuid": str(worker_uuid),
            "hostname": hostname,
            "iat": now,
            "exp": now + expires_delta,
        }
        return jwt.encode(
            payload,
            get_jwt_secret(),
            algorithm=get_jwt_algorithm(),
        )

    @staticmethod
    def decode_worker_token(token: str) -> WorkerTokenPayload:
        """
        Decode and validate a worker token.

        Args:
            token: JWT token string

        Returns:
            WorkerTokenPayload whose ``iat``/``exp`` are naive datetimes in UTC

        Raises:
            HTTPException: 401 if the token is expired, fails JWT validation,
                or is validly signed but lacks well-formed worker claims
        """
        try:
            claims = jwt.decode(
                token,
                get_jwt_secret(),
                algorithms=[get_jwt_algorithm()],
            )
        except jwt.ExpiredSignatureError as exc:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Worker token has expired",
                headers={"WWW-Authenticate": "Worker-Token"},
            ) from exc
        except jwt.InvalidTokenError as exc:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid worker token",
                headers={"WWW-Authenticate": "Worker-Token"},
            ) from exc

        # A token can carry a valid signature and still not be a worker token
        # (missing or malformed claims). Report that as 401 instead of letting
        # KeyError / validation errors surface as a server error.
        try:
            return WorkerTokenPayload(
                worker_id=claims["worker_id"],
                worker_uuid=claims["worker_uuid"],
                hostname=claims["hostname"],
                iat=WorkerAuth._naive_utc(claims["iat"]),
                exp=WorkerAuth._naive_utc(claims["exp"]),
            )
        except (KeyError, TypeError, ValueError, OverflowError, OSError) as exc:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid worker token",
                headers={"WWW-Authenticate": "Worker-Token"},
            ) from exc

    @staticmethod
    def verify_worker_signature(
        worker_id: str,
        timestamp: str,
        signature: str,
    ) -> bool:
        """
        Verify worker signature using HMAC.

        Only the signature itself is checked; the timestamp is not examined
        for freshness here.

        Args:
            worker_id: Worker identifier
            timestamp: Request timestamp
            signature: HMAC signature (hex digest) supplied by the worker

        Returns:
            True if signature is valid, False otherwise (including when the
            supplied signature is not an ASCII string)
        """
        # hmac.compare_digest raises TypeError for non-str input and for str
        # containing non-ASCII characters; such a value can never equal a hex
        # digest, so reject it up front instead of crashing.
        if not isinstance(signature, str) or not signature.isascii():
            return False

        expected_signature = WorkerAuth.create_worker_signature(worker_id, timestamp)
        # Constant-time comparison of the two hex digests.
        return hmac.compare_digest(signature, expected_signature)

    @staticmethod
    def create_worker_signature(
        worker_id: str,
        timestamp: str,
    ) -> str:
        """
        Create HMAC signature for worker request.

        Args:
            worker_id: Worker identifier
            timestamp: Request timestamp

        Returns:
            Hex-encoded HMAC-SHA256 of ``"<worker_id>:<timestamp>"`` keyed
            with the worker secret
        """
        worker_secret = get_worker_secret()
        # NOTE(review): the ":" separator is ambiguous if worker_id itself may
        # contain ":"; the format is kept as-is for compatibility.
        message = f"{worker_id}:{timestamp}"
        return hmac.new(
            worker_secret.encode(),
            message.encode(),
            hashlib.sha256,
        ).hexdigest()
async def get_current_worker(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(worker_scheme),
) -> WorkerTokenPayload:
    """
    FastAPI dependency resolving the authenticated worker for a request.

    The token is taken from the Authorization header when bearer credentials
    were extracted, otherwise from the X-Worker-Token header.

    Raises:
        HTTPException: 401 when neither source supplies a token, or when
            decoding the supplied token fails.
    """
    # Bearer credentials take precedence.
    if credentials:
        return WorkerAuth.decode_worker_token(credentials.credentials)

    # Fall back to the dedicated worker header.
    header_token = request.headers.get("X-Worker-Token")
    if header_token:
        return WorkerAuth.decode_worker_token(header_token)

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Worker authentication required",
        headers={"WWW-Authenticate": "Worker-Token"},
    )
async def get_current_worker_optional(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(worker_scheme),
) -> Optional[WorkerTokenPayload]:
    """
    FastAPI dependency resolving the authenticated worker, if there is one.

    Looks at the Authorization header first and the X-Worker-Token header
    second. Returns None when no token is supplied or when decoding it is
    rejected with an HTTPException.
    """
    candidate = (
        credentials.credentials
        if credentials
        else request.headers.get("X-Worker-Token")
    )
    if not candidate:
        return None

    try:
        return WorkerAuth.decode_worker_token(candidate)
    except HTTPException:
        # Optional variant: a rejected token is treated as "no worker".
        return None
def require_worker_auth():
    """
    Build a FastAPI dependency that requires worker authentication.

    The returned callable is a dependency, not a route decorator: applying
    ``@require_worker_auth()`` directly to a handler would call the
    dependency with the handler as its argument instead of wrapping it.
    Pass it to ``Depends`` instead.

    Usage:
        @router.get("/worker/status", dependencies=[Depends(require_worker_auth())])
        async def get_worker_status():
            ...

        @router.get("/worker/jobs")
        async def get_worker_jobs(worker: WorkerTokenPayload = Depends(require_worker_auth())):
            ...

    Returns:
        Async dependency callable that resolves to the authenticated
        WorkerTokenPayload (via get_current_worker).
    """
    async def checker(
        worker: WorkerTokenPayload = Depends(get_current_worker),
    ) -> WorkerTokenPayload:
        # Authentication is delegated to get_current_worker; the resolved
        # worker is simply passed through.
        return worker

    return checker
# =============================================================================
# Worker registration token (for initial registration)
# =============================================================================
def create_registration_token(
    worker_id: str,
    expires_delta: Optional[timedelta] = None,
) -> str:
    """
    Create a registration token for new workers.

    The token is an HS256 JWT signed with the worker secret and tagged with
    ``"type": "registration"``.

    Args:
        worker_id: Worker identifier
        expires_delta: Token lifetime; defaults to 1 hour

    Returns:
        Registration token
    """
    if expires_delta is None:
        expires_delta = timedelta(hours=1)

    # Timezone-aware UTC replaces the deprecated naive datetime.utcnow();
    # the resulting iat/exp claims denote the same instants.
    issued_at = datetime.now(timezone.utc)
    payload = {
        "type": "registration",
        "worker_id": worker_id,
        "iat": issued_at,
        "exp": issued_at + expires_delta,
    }
    return jwt.encode(
        payload,
        get_worker_secret(),
        algorithm="HS256",
    )
def verify_registration_token(token: str) -> Optional[str]:
    """
    Check a registration token and extract the worker it was issued for.

    Args:
        token: Registration token

    Returns:
        The token's ``worker_id`` claim when the token decodes with the
        worker secret (HS256) and is tagged as a registration token;
        None otherwise.
    """
    try:
        claims = jwt.decode(token, get_worker_secret(), algorithms=["HS256"])
    except jwt.InvalidTokenError:
        return None

    # Only tokens explicitly marked as registration tokens are accepted.
    is_registration = claims.get("type") == "registration"
    return claims.get("worker_id") if is_registration else None