# aegislm/backend/api/middleware/payload_validation.py
# Repository listing metadata: ACA050 - "Upload 50 files" - 1a4aa87 (verified)
"""
Payload Validation Middleware for AegisLM
Provides middleware to validate job payloads before processing.
"""
import logging
import math
from collections.abc import Mapping
from enum import Enum
from typing import Any, Dict, List, Optional

from fastapi import Request, HTTPException
from pydantic import BaseModel, Field, validator
class EvaluationMode(str, Enum):
    """Closed set of evaluation mode identifiers.

    The ``str`` mix-in makes every member compare equal to its raw string
    value, so a member can be used wherever the plain string is expected.
    """

    LIGHTWEIGHT = "lightweight"
    FULL = "full"
class AttackType(str, Enum):
    """Identifiers of the attack types that are considered valid.

    The ``str`` mix-in makes every member compare equal to its raw string
    value; iterate the enum to obtain the full list of accepted values.
    """

    INJECTION = "injection"
    JAILBREAK = "jailbreak"
    BIAS_TRIGGER = "bias_trigger"
    CONTEXT_POISON = "context_poison"
    ROLE_CONFUSION = "role_confusion"
    CHAINING = "chaining"
class MutationType(str, Enum):
    """Identifiers of the mutation types that are considered valid.

    The ``str`` mix-in makes every member compare equal to its raw string
    value; iterate the enum to obtain the full list of accepted values.
    """

    SYNONYM = "synonym"
    ROLE_SWAP = "role_swap"
    CONTEXT_OBFUSCATION = "context_obfuscation"
    MULTI_HOP = "multi_hop"
    PARAPHRASE = "paraphrase"
class JobPayloadSchema(BaseModel):
    """Schema for job payload validation.

    Length and range limits are declared on the fields through ``Field``;
    the two validators below cover the checks that need custom logic.
    """

    # Model and dataset identification (all four are required, non-empty).
    model_name: str = Field(..., min_length=1, max_length=255)
    model_version: str = Field(..., min_length=1, max_length=100)
    dataset_name: str = Field(..., min_length=1, max_length=255)
    dataset_version: str = Field(..., min_length=1, max_length=100)

    # Evaluation settings
    evaluation_mode: EvaluationMode = EvaluationMode.FULL
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    max_tokens: int = Field(default=512, ge=1, le=4096)

    # Attack settings
    attack_types: Optional[List[str]] = None
    mutation_enabled: bool = True
    mutation_depth: int = Field(default=1, ge=0, le=5)

    # Batch settings
    batch_size: int = Field(default=10, ge=1, le=100)
    max_samples: Optional[int] = Field(default=None, ge=1)

    @validator("attack_types")
    def validate_attack_types(cls, v):
        """Reject any entry that is not one of the ``AttackType`` values.

        ``None`` and an empty list are returned unchanged.

        Raises:
            ValueError: Naming the first entry that is not a valid attack type.
        """
        if v is None:
            return v
        valid_attacks = {a.value for a in AttackType}
        for attack in v:
            if attack not in valid_attacks:
                raise ValueError(f"Invalid attack type: {attack}")
        return v

    @validator("dataset_name")
    def validate_dataset_name(cls, v):
        """Accept any dataset name, logging a warning for unrecognised ones.

        The value is always returned unchanged; this validator never rejects.
        """
        known_datasets = ("advbench", "truthfulqa", "aegislm-harmful-queries")
        if v not in known_datasets:
            # Custom datasets are allowed on purpose, so surface them in the
            # logs instead of failing the request.
            logging.getLogger(__name__).warning(
                "Dataset %r is not one of the known datasets %s; "
                "treating it as a custom dataset",
                v,
                known_datasets,
            )
        return v
class PayloadValidator:
    """
    Validates job payloads for security and integrity.

    Checks:
    - Required fields present and field values within acceptable ranges
      (delegated to ``JobPayloadSchema``)
    - Attack types are valid (delegated to ``JobPayloadSchema``)
    - Weights sum to 1.0 (if provided)
    - Mutation depth is within the allowed range

    The model-version and dataset-version checks are placeholders that
    currently accept every input.

    Every rejection is raised as an ``HTTPException`` with status 400 whose
    ``detail`` is ``{"error": <code>, "message": <text>}``.
    """

    # Valid dataset names
    ALLOWED_DATASETS = ["advbench", "truthfulqa", "aegislm-harmful-queries"]
    # Valid attack types
    ALLOWED_ATTACKS = [a.value for a in AttackType]
    # Valid mutation types
    ALLOWED_MUTATIONS = [m.value for m in MutationType]
    # Max mutation depth
    MAX_MUTATION_DEPTH = 5

    @staticmethod
    def _bad_request(error: str, message: str) -> HTTPException:
        """Build the 400 error shared by every validation failure."""
        return HTTPException(
            status_code=400,
            detail={
                "error": error,
                "message": message,
            },
        )

    @classmethod
    def validate_payload(cls, payload: Dict[str, Any]) -> JobPayloadSchema:
        """
        Validate a job payload.

        Args:
            payload: The payload to validate

        Returns:
            Validated payload as JobPayloadSchema

        Raises:
            HTTPException: If validation fails
        """
        # A decoded JSON body may be a list, string, number or null. Unpacking
        # one of those with ** fails with a TypeError whose text exposes
        # internals, so reject it up front with a clear message.
        if not isinstance(payload, Mapping):
            raise cls._bad_request(
                "invalid_payload",
                f"Payload must be a JSON object, got {type(payload).__name__}",
            )
        try:
            return JobPayloadSchema(**payload)
        except Exception as exc:
            # Any failure while building the schema is reported as a client
            # error; the original exception is kept as the cause for logs.
            raise cls._bad_request("invalid_payload", str(exc)) from exc

    @classmethod
    def validate_model_version(cls, model_name: str, model_version: str) -> bool:
        """
        Validate that a model version exists.

        Placeholder: no lookup is performed and every input is accepted.

        Args:
            model_name: Name of the model
            model_version: Version of the model

        Returns:
            Always True in the current implementation
        """
        # In a real implementation, this would check against the model registry
        # For now, we accept any model/version but could add validation
        return True

    @classmethod
    def validate_dataset_version(cls, dataset_name: str, dataset_version: str) -> bool:
        """
        Validate that a dataset version exists.

        Placeholder: no lookup is performed and every input is accepted.

        Args:
            dataset_name: Name of the dataset
            dataset_version: Version of the dataset

        Returns:
            Always True in the current implementation
        """
        # In a real implementation, this would check against the dataset registry
        # For now, we accept any dataset/version but could add validation
        return True

    @classmethod
    def validate_weights(cls, weights: Dict[str, float]) -> bool:
        """
        Validate that scoring weights sum to 1.0.

        Args:
            weights: Dictionary of metric weights

        Returns:
            True if valid

        Raises:
            HTTPException: If ``weights`` is not a mapping with exactly the
                required keys, if any weight is not a finite number, or if
                the weights don't sum to 1.0 (within 1e-6)
        """
        required_keys = {"hallucination", "toxicity", "bias", "confidence"}
        if not isinstance(weights, Mapping) or set(weights.keys()) != required_keys:
            raise cls._bad_request(
                "invalid_weights",
                f"Weights must include exactly: {required_keys}",
            )

        numbers = []
        for name, value in weights.items():
            if not isinstance(value, (int, float)):
                raise cls._bad_request(
                    "invalid_weights",
                    f"Weight '{name}' must be a number, got {type(value).__name__}",
                )
            try:
                number = float(value)
            except OverflowError:
                # An int too large for a float is treated like infinity.
                number = math.inf
            # NaN must be rejected explicitly: every comparison against NaN is
            # False, so a NaN total would slip through the tolerance check.
            if not math.isfinite(number):
                raise cls._bad_request(
                    "invalid_weights",
                    f"Weight '{name}' must be a finite number",
                )
            numbers.append(number)

        total = sum(numbers)
        if abs(total - 1.0) > 1e-6:
            raise cls._bad_request(
                "invalid_weights",
                f"Weights must sum to 1.0, got {total}",
            )
        return True

    @classmethod
    def validate_mutation_depth(cls, depth: int) -> bool:
        """
        Validate mutation depth is within allowed range.

        Args:
            depth: Mutation depth

        Returns:
            True if valid

        Raises:
            HTTPException: If depth is out of range or cannot be compared
                with the numeric bounds
        """
        try:
            # Written as "is inside" rather than "is outside" so that NaN,
            # which fails every comparison, is rejected instead of accepted.
            in_range = 0 <= depth <= cls.MAX_MUTATION_DEPTH
        except TypeError:
            # Values such as None or str cannot be ordered against ints.
            in_range = False
        if not in_range:
            raise cls._bad_request(
                "invalid_mutation_depth",
                f"Mutation depth must be between 0 and {cls.MAX_MUTATION_DEPTH}",
            )
        return True
async def validate_job_payload(request: Request) -> JobPayloadSchema:
    """
    FastAPI dependency to validate job payloads.

    Reads the request body as JSON and passes the decoded value to
    ``PayloadValidator.validate_payload``.

    Usage:
        @router.post("/jobs")
        async def create_job(
            payload: JobPayloadSchema = Depends(validate_job_payload)
        ):
            ...

    Args:
        request: The incoming request whose body holds the job payload

    Returns:
        The validated payload as a ``JobPayloadSchema`` instance (not a dict)

    Raises:
        HTTPException: 400 ``invalid_json`` if the body is not valid JSON, or
            whatever ``PayloadValidator.validate_payload`` raises when the
            decoded payload fails validation
    """
    try:
        body = await request.json()
    except ValueError as exc:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError
        # subclasses. Anything else is not a malformed-JSON problem and is
        # left to propagate instead of being reported as the client's fault.
        raise HTTPException(
            status_code=400,
            detail={
                "error": "invalid_json",
                "message": "Request body must be valid JSON",
            },
        ) from exc
    return PayloadValidator.validate_payload(body)