# Source path: wanderlust.ai/src/wanderlust_ai/models/advanced_validation.py
# Page metadata: uploaded by BlakeL ("Upload 115 files"), commit 3f9f85b (verified)
"""
Advanced Pydantic Validation Patterns for Multi-Agent Systems.
This module demonstrates advanced validation techniques that are particularly
useful in multi-agent travel planning systems.
"""
import re
from datetime import datetime, timezone, timedelta
from decimal import Decimal, InvalidOperation
from enum import Enum
from typing import List, Optional, Dict, Any, Union, Literal

from pydantic import (
    BaseModel, Field, field_validator, model_validator,
    ConfigDict, computed_field, validator, ValidationError
)
# =============================================================================
# ADVANCED FIELD VALIDATION PATTERNS
# =============================================================================
class AdvancedFlightOption(BaseModel):
    """
    Flight offer model demonstrating layered Pydantic v2 validation.

    Validation happens in three layers:

    - ``Field`` constraints (string lengths, numeric bounds) on each field.
    - Per-field validators that normalize input (flight-number casing, city
      aliases, price strings, stop-count words) and apply business limits.
    - A model-level validator that cross-checks times, duration, price per
      hour and terminals once every field is individually valid.

    The computed fields (``duration_hours``, ``is_direct_flight``,
    ``price_per_hour``, ``route_summary``, ``flight_summary``) are derived
    from the stored fields.
    """

    model_config = ConfigDict(
        str_strip_whitespace=True,
        validate_assignment=True,
        use_enum_values=True,
        # Unknown fields are rejected with a validation error.
        extra='forbid',
        # Example payload published in the generated JSON schema.
        json_schema_extra={
            "examples": [
                {
                    "airline": "Delta Air Lines",
                    "flight_number": "DL1234",
                    "departure_city": "New York",
                    "arrival_city": "Los Angeles",
                    "departure_time": "2024-01-15T08:30:00Z",
                    "arrival_time": "2024-01-15T11:45:00Z",
                    "price": 299.99,
                    "duration_minutes": 195,
                    "stops": 0,
                    "aircraft_type": "Boeing 737"
                }
            ]
        }
    )

    # Basic fields
    airline: str = Field(..., min_length=2, max_length=50)
    flight_number: str = Field(..., min_length=3, max_length=10)
    departure_city: str = Field(..., min_length=2, max_length=50)
    arrival_city: str = Field(..., min_length=2, max_length=50)
    departure_time: datetime = Field(...)
    arrival_time: datetime = Field(...)
    price: Decimal = Field(..., ge=0, decimal_places=2)
    duration_minutes: int = Field(..., ge=30, le=1440)
    stops: int = Field(..., ge=0, le=3)

    # Optional descriptive fields
    aircraft_type: Optional[str] = Field(None, max_length=20)
    terminal_departure: Optional[str] = Field(None, max_length=10)
    terminal_arrival: Optional[str] = Field(None, max_length=10)
    baggage_allowance: Optional[str] = Field(None, max_length=100)

    # Cabin details. booking_class is declared before seat_pitch on purpose:
    # the seat_pitch validator reads booking_class from already-validated data.
    booking_class: Optional[Literal["Economy", "Premium Economy", "Business", "First"]] = None
    meal_service: Optional[bool] = None
    wifi_available: Optional[bool] = None
    seat_pitch: Optional[int] = Field(None, ge=28, le=40)  # Inches

    # =========================================================================
    # FIELD VALIDATORS
    # =========================================================================

    @field_validator('flight_number')
    @classmethod
    def validate_flight_number_advanced(cls, v):
        """
        Upper-case a flight number and check its shape.

        Accepted shape: a 2-3 letter airline code, 1-4 digits and an optional
        single trailing letter (for example ``AA123``, ``AAA1234``, ``AA123A``).

        Returns:
            The upper-cased, stripped flight number.

        Raises:
            ValueError: If the value is empty or does not match the shape.
        """
        if not v:
            raise ValueError('Flight number cannot be empty')

        v = v.upper().strip()

        # fullmatch anchors both ends, and the quantifiers already bound the
        # code and digit lengths, so no separate length checks are needed.
        if re.fullmatch(r'[A-Z]{2,3}\d{1,4}[A-Z]?', v) is None:
            raise ValueError(
                f'Invalid flight number format: {v}. '
                f'Expected format: AA123 or AAA1234 or AA123A'
            )
        return v

    @field_validator('departure_city', 'arrival_city')
    @classmethod
    def validate_city_names_advanced(cls, v):
        """
        Normalize a city name and check its characters.

        The name is stripped and title-cased, a few known aliases are mapped
        to one canonical spelling, and the result must consist only of ASCII
        letters, whitespace, hyphens, apostrophes and periods.

        Returns:
            The normalized city name.

        Raises:
            ValueError: If the name is blank, contains other characters, or is
                shorter than two characters after normalization.
        """
        if not v or not v.strip():
            raise ValueError('City name cannot be empty')

        v = v.strip().title()

        # Keys are written in title case because the lookup happens after
        # .title(); names without an alias pass through unchanged.
        aliases = {
            'New York City': 'New York',
            'Washington Dc': 'Washington DC',
            'Washington D.C.': 'Washington DC',
        }
        v = aliases.get(v, v)

        if re.fullmatch(r"[A-Za-z\s\-'.]+", v) is None:
            raise ValueError(f'City name contains invalid characters: {v}')
        if len(v) < 2:
            raise ValueError('City name too short after normalization')
        return v

    @field_validator('departure_time', 'arrival_time')
    @classmethod
    def validate_datetime_advanced(cls, v):
        """
        Require a timezone-aware datetime within one year of the current time.

        This validator runs in the default "after" mode, i.e. once Pydantic
        has already parsed the input into a ``datetime``; string parsing is
        therefore left entirely to Pydantic.

        Returns:
            The datetime unchanged.

        Raises:
            ValueError: If the datetime is naive, or lies more than 365 days
                before or after the current UTC time.
        """
        if v.tzinfo is None or v.utcoffset() is None:
            raise ValueError('Datetime must be timezone-aware')

        now = datetime.now(timezone.utc)
        window = timedelta(days=365)
        if v < now - window:
            raise ValueError(f'Datetime too far in the past: {v}')
        if v > now + window:
            raise ValueError(f'Datetime too far in the future: {v}')
        return v

    @field_validator('price', mode='before')
    @classmethod
    def validate_price_advanced(cls, v):
        """
        Coerce a raw price to a two-decimal ``Decimal`` and apply limits.

        Runs in "before" mode so that formatted strings such as
        ``"$1,299.00"`` can be cleaned up before Pydantic's own Decimal
        parsing sees them. Strings have currency symbols ($, €, £, ¥), commas
        and whitespace removed; ints and floats are converted through their
        string form.

        Returns:
            The price rounded to two decimal places.

        Raises:
            ValueError: If the value is not a number, is a "free"/"n/a"
                placeholder, is not finite, is negative, or exceeds 50,000.
        """
        # bool is a subclass of int but is never a meaningful price.
        if isinstance(v, bool):
            raise ValueError(f'Price must be a number, got: {type(v)}')

        if isinstance(v, str):
            cleaned = re.sub(r'[$€£¥,\s]', '', v)
            if cleaned.lower() in ('free', 'n/a', 'na', ''):
                raise ValueError('Price cannot be free or N/A')
            try:
                amount = Decimal(cleaned)
            except (InvalidOperation, ValueError):
                raise ValueError(f'Invalid price format: {cleaned}') from None
        elif isinstance(v, (int, float)):
            # str() first so 299.99 becomes Decimal('299.99') rather than the
            # float's full binary expansion.
            amount = Decimal(str(v))
        elif isinstance(v, Decimal):
            amount = v
        else:
            raise ValueError(f'Price must be a number, got: {type(v)}')

        # Ordering comparisons on NaN raise decimal.InvalidOperation, which
        # would escape as a raw exception, so reject it explicitly.
        if not amount.is_finite():
            raise ValueError('Price must be a finite number')
        if amount < 0:
            raise ValueError('Price cannot be negative')
        if amount > Decimal('50000'):  # $50,000
            raise ValueError('Price seems unreasonably high (>$50,000)')

        return amount.quantize(Decimal('0.01'))

    @field_validator('stops', mode='before')
    @classmethod
    def validate_stops_advanced(cls, v):
        """
        Accept the number of stops as an int, a digit string or a word.

        Runs in "before" mode so that words such as ``"direct"`` or ``"two"``
        can be translated before Pydantic's int parsing sees them. Values
        that are neither strings nor ints are returned untouched and left to
        Pydantic's own coercion and the field's ``ge``/``le`` bounds.

        Returns:
            The stop count as an int, or the untouched value for other types.

        Raises:
            ValueError: If a string is not a recognized word or integer, or
                an integer lies outside 0-3.
        """
        if isinstance(v, str):
            text = v.lower().strip()
            words = {'none': 0, 'direct': 0, 'one': 1, 'two': 2, 'three': 3}
            if text in words:
                v = words[text]
            else:
                try:
                    v = int(text)
                except ValueError:
                    raise ValueError(f'Invalid stops value: {text}') from None

        if isinstance(v, int) and not isinstance(v, bool):
            if v < 0:
                raise ValueError('Stops cannot be negative')
            if v > 3:
                raise ValueError('Too many stops (maximum 3)')
        return v

    # =========================================================================
    # CONDITIONAL VALIDATION
    # =========================================================================

    @field_validator('seat_pitch')
    @classmethod
    def validate_seat_pitch_conditional(cls, v, info):
        """
        Check seat pitch against the minimum for the booking class.

        The booking class is read from ``info.data``, which holds the fields
        validated so far. When no booking class is available the pitch is
        accepted as is.

        Returns:
            The seat pitch unchanged (``None`` is passed through).

        Raises:
            ValueError: If the pitch is below the minimum for the class.
        """
        if v is None:
            return v

        booking_class = info.data.get('booking_class')
        if not booking_class:
            return v

        # Minimum pitch in inches per cabin; unknown classes fall back to 28.
        min_pitch_by_class = {
            'Economy': 28,
            'Premium Economy': 32,
            'Business': 36,
            'First': 40
        }
        min_required = min_pitch_by_class.get(booking_class, 28)
        if v < min_required:
            raise ValueError(
                f'Seat pitch {v}" is too small for {booking_class} class '
                f'(minimum {min_required}")'
            )
        return v

    # =========================================================================
    # MODEL VALIDATOR (cross-field validation)
    # =========================================================================

    @model_validator(mode='after')
    def validate_flight_consistency(self):
        """
        Cross-check fields once each has passed its own validation.

        Rules, in order:
        1. Arrival must be after departure.
        2. ``duration_minutes`` must be within 30 minutes of the difference
           between arrival and departure.
        3. A direct flight may not exceed 600 minutes.
        4. Price per hour may not exceed $1000 (direct) or $800 (connecting).
        5. A connecting flight may not list the same departure and arrival
           terminal when both are given.

        Returns:
            The model instance.

        Raises:
            ValueError: On the first rule that is violated.
        """
        # 1. Time consistency
        if self.arrival_time <= self.departure_time:
            raise ValueError('Arrival time must be after departure time')

        # 2. Duration consistency
        actual_duration = (self.arrival_time - self.departure_time).total_seconds() / 60
        duration_diff = abs(actual_duration - self.duration_minutes)
        if duration_diff > 30:  # 30 minute tolerance
            raise ValueError(
                f'Duration mismatch: stated {self.duration_minutes} min, '
                f'actual {actual_duration:.0f} min (difference: {duration_diff:.0f} min)'
            )

        # 3. Direct flight validation
        if self.stops == 0 and self.duration_minutes > 600:  # 10 hours
            raise ValueError(
                f'Direct flight duration too long: {self.duration_minutes} min '
                f'(>10 hours)'
            )

        # 4. Price reasonableness based on duration and stops
        hourly = self.price_per_hour
        if self.stops == 0:
            if hourly > Decimal('1000'):  # $1000/hour
                raise ValueError(
                    f'Direct flight price per hour too high: ${hourly:.2f}/hour'
                )
        elif hourly > Decimal('800'):  # $800/hour
            raise ValueError(
                f'Connecting flight price per hour too high: ${hourly:.2f}/hour'
            )

        # 5. Terminal validation (if provided)
        if self.terminal_departure and self.terminal_arrival:
            if self.terminal_departure == self.terminal_arrival and self.stops > 0:
                raise ValueError(
                    'Connecting flight cannot have same departure and arrival terminal'
                )
        return self

    # =========================================================================
    # COMPUTED FIELDS
    # =========================================================================

    @computed_field
    @property
    def duration_hours(self) -> float:
        """Duration in hours."""
        return self.duration_minutes / 60

    @computed_field
    @property
    def is_direct_flight(self) -> bool:
        """Whether the flight has no stops."""
        return self.stops == 0

    @computed_field
    @property
    def price_per_hour(self) -> Decimal:
        """Price per hour of flight (``Decimal('0')`` for a zero duration)."""
        if self.duration_minutes == 0:
            return Decimal('0')
        # Stay in Decimal/int arithmetic rather than round-tripping the
        # duration through a float.
        return self.price * 60 / self.duration_minutes

    @computed_field
    @property
    def route_summary(self) -> str:
        """Route as "<departure city> → <arrival city>"."""
        return f"{self.departure_city} → {self.arrival_city}"

    @computed_field
    @property
    def flight_summary(self) -> str:
        """One-line summary: airline, flight number, route, duration, price, stops."""
        return (
            f"{self.airline} {self.flight_number}: {self.route_summary} "
            f"({self.duration_hours:.1f}h, ${self.price}, {self.stops} stops)"
        )

    # =========================================================================
    # UTILITY METHODS
    # =========================================================================

    def to_dict(self) -> Dict[str, Any]:
        """Return the model as a dictionary, computed fields included."""
        # model_dump() serializes computed fields by default and has no
        # include_computed parameter.
        return self.model_dump()

    def to_json(self) -> str:
        """Return the model as a JSON string, computed fields included."""
        return self.model_dump_json()

    def is_cheaper_than(self, other: 'AdvancedFlightOption') -> bool:
        """Return True if this flight's price is strictly lower than other's."""
        return self.price < other.price

    def is_faster_than(self, other: 'AdvancedFlightOption') -> bool:
        """Return True if this flight's duration is strictly shorter than other's."""
        return self.duration_minutes < other.duration_minutes

    def __str__(self) -> str:
        """Return the flight summary."""
        return self.flight_summary
# =============================================================================
# MULTI-AGENT COMMUNICATION MODELS
# =============================================================================
class AgentRequest(BaseModel):
    """
    Request envelope sent from one agent to another.

    ``timestamp`` defaults to the current UTC time, ``priority`` to
    ``"medium"`` and ``timeout_seconds`` to 30 (allowed range 5-300).
    """

    request_id: str = Field(..., min_length=1, max_length=50)
    sender_agent: str = Field(..., min_length=2, max_length=50)
    recipient_agent: str = Field(..., min_length=2, max_length=50)
    request_type: str = Field(..., min_length=2, max_length=50)
    data: Dict[str, Any] = Field(default_factory=dict)
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    priority: Literal["low", "medium", "high", "urgent"] = "medium"
    timeout_seconds: int = Field(30, ge=5, le=300)

    @field_validator('request_id')
    @classmethod
    def validate_request_id(cls, v):
        """
        Allow only letters, digits, hyphens and underscores in the request ID.

        Raises:
            ValueError: If any other character is present.
        """
        # fullmatch rather than match with "$": "$" also matches just before
        # a trailing newline, which would let "abc\n" through.
        if re.fullmatch(r'[A-Za-z0-9_-]+', v) is None:
            raise ValueError('Request ID can only contain letters, numbers, hyphens, and underscores')
        return v

    @field_validator('timestamp', mode='before')
    @classmethod
    def parse_timestamp(cls, v):
        """
        Parse an ISO 8601 string (a "Z" suffix is read as UTC) into a datetime.

        Non-string values are returned unchanged for Pydantic to handle.
        """
        if isinstance(v, str):
            return datetime.fromisoformat(v.replace('Z', '+00:00'))
        return v
class AgentResponse(BaseModel):
    """
    Response envelope returned by an agent for a given request.

    ``success`` and ``error_message`` must agree: a successful response
    carries no error message and a failed one must carry one.
    """

    request_id: str = Field(..., min_length=1, max_length=50)
    sender_agent: str = Field(..., min_length=2, max_length=50)
    recipient_agent: str = Field(..., min_length=2, max_length=50)
    response_type: str = Field(..., min_length=2, max_length=50)
    success: bool = Field(...)
    data: Optional[Dict[str, Any]] = None
    error_message: Optional[str] = None
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    processing_time_ms: Optional[int] = Field(None, ge=0)

    @model_validator(mode='after')
    def validate_response_consistency(self):
        """
        Ensure the success flag and the error message do not contradict.

        Raises:
            ValueError: If a successful response has an error message, or a
                failed response has none (``None`` or empty).
        """
        has_error_text = bool(self.error_message)
        if self.success:
            if has_error_text:
                raise ValueError('Successful response cannot have error message')
        elif not has_error_text:
            raise ValueError('Failed response must have error message')
        return self
# =============================================================================
# EXAMPLE USAGE AND TESTING
# =============================================================================
def demonstrate_advanced_validation():
    """
    Print a walkthrough of AdvancedFlightOption validation.

    Builds one valid flight, prints its computed fields, then tries four
    deliberately invalid variants and prints the validation error for each.
    """
    print("🚀 Advanced Pydantic Validation Demo")
    print("=" * 50)

    # The model only accepts times within a year of "now", so the sample
    # times are derived from the current date instead of being hard-coded.
    departure = (datetime.now(timezone.utc) + timedelta(days=30)).replace(
        hour=8, minute=30, second=0, microsecond=0
    )
    arrival = departure + timedelta(minutes=195)

    # Valid flight data
    valid_data = {
        "airline": "Delta Air Lines",
        "flight_number": "DL1234",
        "departure_city": "New York",
        "arrival_city": "Los Angeles",
        "departure_time": departure.isoformat(),
        "arrival_time": arrival.isoformat(),
        "price": "299.99",
        "duration_minutes": 195,
        "stops": 0,
        "booking_class": "Economy",
        "seat_pitch": 30
    }

    try:
        flight = AdvancedFlightOption(**valid_data)
        print(f"✅ Valid flight created: {flight.flight_summary}")
        print(" Computed fields:")
        print(f" - Duration: {flight.duration_hours:.1f} hours")
        print(f" - Price per hour: ${flight.price_per_hour:.2f}")
        print(f" - Is direct: {flight.is_direct_flight}")
        print(f" - Route: {flight.route_summary}")
    except ValidationError as e:
        print(f"❌ Validation failed: {e}")

    # Each entry overrides the valid payload with one bad value.
    invalid_examples = [
        {
            "name": "Invalid flight number",
            "data": {**valid_data, "flight_number": "1234"}  # Missing airline code
        },
        {
            "name": "Invalid price",
            "data": {**valid_data, "price": "free"}  # Not a number
        },
        {
            "name": "Invalid seat pitch for class",
            "data": {**valid_data, "booking_class": "First", "seat_pitch": 30}  # Too small for First
        },
        {
            "name": "Duration mismatch",
            "data": {**valid_data, "duration_minutes": 600}  # Doesn't match actual duration
        }
    ]

    for example in invalid_examples:
        print(f"\n❌ {example['name']}:")
        try:
            AdvancedFlightOption(**example['data'])
            print(" Unexpected: This should have failed!")
        except ValidationError as e:
            print(f" Caught error: {e}")


# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    demonstrate_advanced_validation()