"""
Payload Validation Middleware for AegisLM

Provides middleware to validate job payloads before processing.
"""
|
|
|
import math
from enum import Enum
from typing import Any, Dict, List, Optional

from fastapi import Request, HTTPException
from pydantic import BaseModel, Field, validator
|
|
|
|
|
class EvaluationMode(str, Enum):
    """Evaluation modes a job payload may select.

    The enum mixes in ``str``, so each member compares equal to its raw
    string value (``EvaluationMode.FULL == "full"``).
    """

    # Member values are the exact strings expected in the payload.
    LIGHTWEIGHT = "lightweight"
    FULL = "full"
|
|
|
|
|
class AttackType(str, Enum):
    """Attack types that are considered valid.

    The enum mixes in ``str``, so each member compares equal to its raw
    string value. The member values are the strings accepted in a
    payload's ``attack_types`` list.
    """

    INJECTION = "injection"
    JAILBREAK = "jailbreak"
    BIAS_TRIGGER = "bias_trigger"
    CONTEXT_POISON = "context_poison"
    ROLE_CONFUSION = "role_confusion"
    CHAINING = "chaining"
|
|
|
|
|
class MutationType(str, Enum):
    """Mutation types that are considered valid.

    The enum mixes in ``str``, so each member compares equal to its raw
    string value.
    """

    SYNONYM = "synonym"
    ROLE_SWAP = "role_swap"
    CONTEXT_OBFUSCATION = "context_obfuscation"
    MULTI_HOP = "multi_hop"
    PARAPHRASE = "paraphrase"
|
|
|
|
|
class JobPayloadSchema(BaseModel):
    """Pydantic model describing a job payload and its field constraints.

    The four name/version fields are required; every other field has a
    default. Length and numeric bounds are enforced through ``Field``.
    """

    # Required identifiers: non-empty strings with an upper length bound.
    model_name: str = Field(..., min_length=1, max_length=255)
    model_version: str = Field(..., min_length=1, max_length=100)
    dataset_name: str = Field(..., min_length=1, max_length=255)
    dataset_version: str = Field(..., min_length=1, max_length=100)

    # Evaluation settings; each numeric field is range-checked.
    evaluation_mode: EvaluationMode = EvaluationMode.FULL
    temperature: float = Field(0.7, ge=0.0, le=2.0)
    max_tokens: int = Field(512, ge=1, le=4096)

    # Attack and mutation settings. ``attack_types`` may be omitted (None).
    attack_types: Optional[List[str]] = None
    mutation_enabled: bool = True
    mutation_depth: int = Field(1, ge=0, le=5)

    # Batch size and an optional cap on the number of samples.
    batch_size: int = Field(10, ge=1, le=100)
    max_samples: Optional[int] = Field(None, ge=1)

    @validator("attack_types")
    def validate_attack_types(cls, v):
        """Reject any entry that is not the value of an ``AttackType`` member.

        ``None`` is passed through untouched.

        Raises:
            ValueError: On the first entry that is not a known attack type.
        """
        if v is None:
            return v

        permitted = tuple(member.value for member in AttackType)
        for candidate in v:
            if candidate not in permitted:
                raise ValueError(f"Invalid attack type: {candidate}")
        return v

    @validator("dataset_name")
    def validate_dataset_name(cls, v):
        """Return the dataset name unchanged.

        The name is compared against a list of known datasets, but an
        unknown name is not rejected.
        """
        known_datasets = ("advbench", "truthfulqa", "aegislm-harmful-queries")
        if v in known_datasets:
            return v
        # NOTE(review): unknown names take no action here and are accepted;
        # confirm whether they should be rejected or logged.
        return v
|
|
|
|
|
class PayloadValidator:
    """
    Validates job payloads for security and integrity.

    Checks:
    - Required fields present and field values within acceptable ranges
      (delegated to ``JobPayloadSchema``)
    - Attack types are valid (delegated to ``JobPayloadSchema``)
    - Weights sum to 1.0 (if provided)
    - Mutation depth within ``0..MAX_MUTATION_DEPTH``

    The model-version and dataset-version checks are placeholders that
    currently always succeed.

    Every failure is reported as an ``HTTPException`` with status 400 and a
    ``detail`` dict holding ``error`` and ``message`` keys.
    """

    # Dataset names known to the validator.
    ALLOWED_DATASETS = ["advbench", "truthfulqa", "aegislm-harmful-queries"]

    # String values of every AttackType / MutationType member.
    ALLOWED_ATTACKS = [a.value for a in AttackType]
    ALLOWED_MUTATIONS = [m.value for m in MutationType]

    # Inclusive upper bound enforced by validate_mutation_depth.
    MAX_MUTATION_DEPTH = 5

    @staticmethod
    def _bad_request(error: str, message: str) -> HTTPException:
        """Build the 400 error raised for a validation failure in this class."""
        return HTTPException(
            status_code=400,
            detail={"error": error, "message": message},
        )

    @classmethod
    def validate_payload(cls, payload: Dict[str, Any]) -> JobPayloadSchema:
        """
        Validate a job payload.

        Args:
            payload: The payload to validate

        Returns:
            Validated payload as JobPayloadSchema

        Raises:
            HTTPException: If validation fails (status 400,
                error ``invalid_payload``)
        """
        try:
            return JobPayloadSchema(**payload)
        except Exception as exc:
            # Any construction failure is reported to the client as a 400;
            # the original exception stays attached as __cause__.
            raise cls._bad_request("invalid_payload", str(exc)) from exc

    @classmethod
    def validate_model_version(cls, model_name: str, model_version: str) -> bool:
        """
        Validate that a model version exists.

        Placeholder: no lookup is performed and the arguments are not
        inspected.

        Args:
            model_name: Name of the model
            model_version: Version of the model

        Returns:
            Always True in the current implementation
        """
        return True

    @classmethod
    def validate_dataset_version(cls, dataset_name: str, dataset_version: str) -> bool:
        """
        Validate that a dataset version exists.

        Placeholder: no lookup is performed and the arguments are not
        inspected.

        Args:
            dataset_name: Name of the dataset
            dataset_version: Version of the dataset

        Returns:
            Always True in the current implementation
        """
        return True

    @classmethod
    def validate_weights(cls, weights: Dict[str, float]) -> bool:
        """
        Validate that scoring weights sum to 1.0.

        The weights must contain exactly the keys ``hallucination``,
        ``toxicity``, ``bias`` and ``confidence``, and their sum must be
        within 1e-6 of 1.0.

        Args:
            weights: Dictionary of metric weights

        Returns:
            True if valid

        Raises:
            HTTPException: If the keys are wrong, a value is not numeric,
                or the weights don't sum to 1.0
        """
        required_keys = {"hallucination", "toxicity", "bias", "confidence"}

        if set(weights.keys()) != required_keys:
            # Sort the names so the message is identical on every run; a raw
            # set repr changes order with string hash randomisation.
            raise cls._bad_request(
                "invalid_weights",
                f"Weights must include exactly: {sorted(required_keys)}",
            )

        try:
            total = sum(weights.values())
            # isclose is False for NaN and infinities, so a non-finite total
            # is rejected instead of slipping past a ">" comparison.
            sums_to_one = math.isclose(total, 1.0, rel_tol=0.0, abs_tol=1e-6)
        except TypeError as exc:
            # Non-numeric values (strings, None, lists, ...) end up here.
            raise cls._bad_request(
                "invalid_weights",
                "Weights must be numeric",
            ) from exc

        if not sums_to_one:
            raise cls._bad_request(
                "invalid_weights",
                f"Weights must sum to 1.0, got {total}",
            )

        return True

    @classmethod
    def validate_mutation_depth(cls, depth: int) -> bool:
        """
        Validate mutation depth is within allowed range.

        Args:
            depth: Mutation depth

        Returns:
            True if valid

        Raises:
            HTTPException: If depth is out of range
        """
        # Written as "not in range" so that a NaN depth is rejected as well.
        if not 0 <= depth <= cls.MAX_MUTATION_DEPTH:
            raise cls._bad_request(
                "invalid_mutation_depth",
                f"Mutation depth must be between 0 and {cls.MAX_MUTATION_DEPTH}",
            )

        return True
|
|
|
|
|
async def validate_job_payload(request: Request) -> JobPayloadSchema:
    """
    FastAPI dependency to validate job payloads.

    Reads the request body as JSON, requires it to be a JSON object, and
    validates it with ``PayloadValidator.validate_payload``.

    Usage:
        @router.post("/jobs")
        async def create_job(
            payload: JobPayloadSchema = Depends(validate_job_payload)
        ):
            ...

    Args:
        request: The incoming request whose body holds the job payload

    Returns:
        The validated payload as a JobPayloadSchema instance

    Raises:
        HTTPException: Status 400 with error ``invalid_json`` when the body
            cannot be parsed, or ``invalid_payload`` when it is not a JSON
            object or fails schema validation
    """
    try:
        body = await request.json()
    except Exception as exc:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "invalid_json",
                "message": "Request body must be valid JSON",
            },
        ) from exc

    # Valid JSON may still be a list, string, number or null; reject those
    # here with a clear message rather than letting ``**payload`` fail.
    if not isinstance(body, dict):
        raise HTTPException(
            status_code=400,
            detail={
                "error": "invalid_payload",
                "message": "Request body must be a JSON object",
            },
        )

    return PayloadValidator.validate_payload(body)
|
|
|