"""
Payload Validation Middleware for AegisLM

Provides middleware to validate job payloads before processing.
"""

import logging
import math
from collections.abc import Mapping
from typing import Any, Dict, List, Optional
from enum import Enum

from fastapi import Request, HTTPException
from pydantic import BaseModel, Field, validator

logger = logging.getLogger(__name__)

# Datasets known to the platform. Other names are still accepted (custom
# datasets) but are logged with a warning by JobPayloadSchema.
_KNOWN_DATASETS = ("advbench", "truthfulqa", "aegislm-harmful-queries")

# Single source of truth for the mutation-depth upper bound, shared by the
# pydantic schema and PayloadValidator.validate_mutation_depth.
_MAX_MUTATION_DEPTH = 5

# Metric keys a scoring-weights mapping must contain (exactly these).
_WEIGHT_KEYS = frozenset({"hallucination", "toxicity", "bias", "confidence"})

# Absolute tolerance when checking that the weights sum to 1.0.
_WEIGHT_SUM_TOLERANCE = 1e-6


class EvaluationMode(str, Enum):
    """Evaluation modes."""

    LIGHTWEIGHT = "lightweight"
    FULL = "full"


class AttackType(str, Enum):
    """Valid attack types."""

    INJECTION = "injection"
    JAILBREAK = "jailbreak"
    BIAS_TRIGGER = "bias_trigger"
    CONTEXT_POISON = "context_poison"
    ROLE_CONFUSION = "role_confusion"
    CHAINING = "chaining"


class MutationType(str, Enum):
    """Valid mutation types."""

    SYNONYM = "synonym"
    ROLE_SWAP = "role_swap"
    CONTEXT_OBFUSCATION = "context_obfuscation"
    MULTI_HOP = "multi_hop"
    PARAPHRASE = "paraphrase"


class JobPayloadSchema(BaseModel):
    """Schema for job payload validation.

    Field constraints (lengths, numeric ranges) are enforced by pydantic;
    the validators below add attack-type and dataset-name checks.
    """

    model_name: str = Field(..., min_length=1, max_length=255)
    model_version: str = Field(..., min_length=1, max_length=100)
    dataset_name: str = Field(..., min_length=1, max_length=255)
    dataset_version: str = Field(..., min_length=1, max_length=100)

    # Evaluation settings
    evaluation_mode: EvaluationMode = EvaluationMode.FULL
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    max_tokens: int = Field(default=512, ge=1, le=4096)

    # Attack settings
    attack_types: Optional[List[str]] = None
    mutation_enabled: bool = True
    mutation_depth: int = Field(default=1, ge=0, le=_MAX_MUTATION_DEPTH)

    # Batch settings
    batch_size: int = Field(default=10, ge=1, le=100)
    max_samples: Optional[int] = Field(default=None, ge=1)

    @validator("attack_types")
    def validate_attack_types(cls, v):
        """Reject any entry that is not the value of an ``AttackType`` member.

        ``None`` (no attack types requested) is passed through unchanged.
        """
        if v is not None:
            valid_attacks = {a.value for a in AttackType}
            for attack in v:
                if attack not in valid_attacks:
                    raise ValueError(f"Invalid attack type: {attack}")
        return v

    @validator("dataset_name")
    def validate_dataset_name(cls, v):
        """Accept any dataset name, logging a warning for unknown ones.

        Custom datasets are allowed, so this never rejects; the warning makes
        typos of a known dataset name visible in the logs.
        """
        if v not in _KNOWN_DATASETS:
            logger.warning(
                "Unknown dataset %r; treating it as a custom dataset", v
            )
        return v


class PayloadValidator:
    """
    Validates job payloads for security and integrity.

    Checks:
    - Required fields present
    - Field values within acceptable ranges
    - Dataset and model versions exist (placeholder: the version checks
      below currently always pass)
    - Attack types are valid
    - Weights are finite, non-negative and sum to 1.0 (if provided)
    """

    # Valid dataset names
    ALLOWED_DATASETS = list(_KNOWN_DATASETS)

    # Valid attack types
    ALLOWED_ATTACKS = [a.value for a in AttackType]

    # Valid mutation types
    ALLOWED_MUTATIONS = [m.value for m in MutationType]

    # Max mutation depth (same bound the schema enforces)
    MAX_MUTATION_DEPTH = _MAX_MUTATION_DEPTH

    @classmethod
    def validate_payload(cls, payload: Dict[str, Any]) -> JobPayloadSchema:
        """
        Validate a job payload.

        Args:
            payload: The payload to validate; must be a mapping of field
                names to values (e.g. a decoded JSON object).

        Returns:
            Validated payload as JobPayloadSchema

        Raises:
            HTTPException: 400 if the payload is not a mapping or fails
                schema validation.
        """
        if not isinstance(payload, Mapping):
            # A JSON array/scalar body would otherwise fail inside ``**``
            # with an opaque TypeError message.
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "invalid_payload",
                    "message": "Payload must be a JSON object",
                },
            )
        try:
            return JobPayloadSchema(**payload)
        except (TypeError, ValueError) as e:
            # pydantic's ValidationError subclasses ValueError; TypeError
            # covers non-string keys passed via ``**``.
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "invalid_payload",
                    "message": str(e),
                },
            ) from e

    @classmethod
    def validate_model_version(cls, model_name: str, model_version: str) -> bool:
        """
        Validate that a model version exists.

        Placeholder: no registry is consulted yet, so this always returns
        True regardless of its arguments.

        Args:
            model_name: Name of the model
            model_version: Version of the model

        Returns:
            True if valid, False otherwise (currently always True)
        """
        # TODO: check against the model registry.
        return True

    @classmethod
    def validate_dataset_version(cls, dataset_name: str, dataset_version: str) -> bool:
        """
        Validate that a dataset version exists.

        Placeholder: no registry is consulted yet, so this always returns
        True regardless of its arguments.

        Args:
            dataset_name: Name of the dataset
            dataset_version: Version of the dataset

        Returns:
            True if valid, False otherwise (currently always True)
        """
        # TODO: check against the dataset registry.
        return True

    @classmethod
    def validate_weights(cls, weights: Dict[str, float]) -> bool:
        """
        Validate scoring weights.

        The mapping must have exactly the keys hallucination, toxicity, bias
        and confidence; every value must be a finite, non-negative real
        number; and the values must sum to 1.0 within an absolute tolerance
        of 1e-6.

        Args:
            weights: Dictionary of metric weights

        Returns:
            True if valid

        Raises:
            HTTPException: 400 if any of the checks above fails.
        """
        if set(weights) != _WEIGHT_KEYS:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "invalid_weights",
                    # sorted() keeps the message stable across processes;
                    # a set's repr order depends on hash randomization.
                    "message": f"Weights must include exactly: {sorted(_WEIGHT_KEYS)}",
                },
            )

        for key, value in weights.items():
            # bool is an int subclass but is never a meaningful weight.
            is_number = isinstance(value, (int, float)) and not isinstance(value, bool)
            # NaN/inf must be rejected explicitly: NaN would otherwise slip
            # through the sum check below (every comparison with NaN is False).
            if not is_number or not math.isfinite(value) or value < 0:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "invalid_weights",
                        "message": (
                            f"Weight {key!r} must be a finite, non-negative "
                            f"number, got {value!r}"
                        ),
                    },
                )

        total = math.fsum(weights.values())
        if not math.isclose(total, 1.0, rel_tol=0.0, abs_tol=_WEIGHT_SUM_TOLERANCE):
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "invalid_weights",
                    "message": f"Weights must sum to 1.0, got {total}",
                },
            )
        return True

    @classmethod
    def validate_mutation_depth(cls, depth: int) -> bool:
        """
        Validate mutation depth is within allowed range.

        Args:
            depth: Mutation depth

        Returns:
            True if valid

        Raises:
            HTTPException: 400 if depth is outside 0..MAX_MUTATION_DEPTH.
        """
        if not 0 <= depth <= cls.MAX_MUTATION_DEPTH:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "invalid_mutation_depth",
                    "message": f"Mutation depth must be between 0 and {cls.MAX_MUTATION_DEPTH}",
                },
            )
        return True


async def validate_job_payload(request: Request) -> JobPayloadSchema:
    """
    FastAPI dependency to validate job payloads.

    Parses the request body as JSON and validates it with
    ``PayloadValidator.validate_payload``.

    Usage:
        @router.post("/jobs")
        async def create_job(
            payload: JobPayloadSchema = Depends(validate_job_payload)
        ):
            ...

    Returns:
        The validated payload as a JobPayloadSchema instance.

    Raises:
        HTTPException: 400 if the body is not valid JSON or fails validation.
    """
    try:
        body = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        raise HTTPException(
            status_code=400,
            detail={
                "error": "invalid_json",
                "message": "Request body must be valid JSON",
            },
        ) from None
    return PayloadValidator.validate_payload(body)