"""
Advanced Pydantic Validation Patterns for Multi-Agent Systems.

This module demonstrates advanced validation techniques that are particularly
useful in multi-agent travel planning systems.
"""
import re
from datetime import datetime, timedelta, timezone
from decimal import Decimal, InvalidOperation
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union

from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    ValidationError,
    computed_field,
    field_validator,
    model_validator,
    validator,
)
# =============================================================================
# ADVANCED FIELD VALIDATION PATTERNS
# =============================================================================
class AdvancedFlightOption(BaseModel):
    """
    Advanced flight model demonstrating sophisticated validation patterns.

    This model shows:
    - Conditional validation based on other fields
    - Complex regex patterns
    - Custom type coercion
    - Computed fields
    - Advanced error handling

    Validators declared with ``mode='before'`` receive the raw input and coerce
    it themselves; the others run after Pydantic's own type and constraint
    checks and therefore receive an already-typed value.
    """

    model_config = ConfigDict(
        str_strip_whitespace=True,
        validate_assignment=True,
        use_enum_values=True,
        # Reject unknown fields: any key not declared below is a validation error.
        extra='forbid',
        # Example payload exposed through the generated JSON schema.
        # NOTE(review): these fixed 2024 dates fall outside the +/-365-day window
        # enforced by validate_datetime_advanced once more than a year has
        # passed; they document the payload shape, not a currently-valid instance.
        json_schema_extra={
            "examples": [
                {
                    "airline": "Delta Air Lines",
                    "flight_number": "DL1234",
                    "departure_city": "New York",
                    "arrival_city": "Los Angeles",
                    "departure_time": "2024-01-15T08:30:00Z",
                    "arrival_time": "2024-01-15T11:45:00Z",
                    "price": 299.99,
                    "duration_minutes": 195,
                    "stops": 0,
                    "aircraft_type": "Boeing 737"
                }
            ]
        }
    )

    # Basic fields
    airline: str = Field(..., min_length=2, max_length=50)
    flight_number: str = Field(..., min_length=3, max_length=10)
    departure_city: str = Field(..., min_length=2, max_length=50)
    arrival_city: str = Field(..., min_length=2, max_length=50)
    departure_time: datetime = Field(...)
    arrival_time: datetime = Field(...)
    price: Decimal = Field(..., ge=0, decimal_places=2)
    duration_minutes: int = Field(..., ge=30, le=1440)
    stops: int = Field(..., ge=0, le=3)

    # Optional fields with conditional validation
    aircraft_type: Optional[str] = Field(None, max_length=20)
    terminal_departure: Optional[str] = Field(None, max_length=10)
    terminal_arrival: Optional[str] = Field(None, max_length=10)
    baggage_allowance: Optional[str] = Field(None, max_length=100)

    # Advanced fields
    booking_class: Optional[Literal["Economy", "Premium Economy", "Business", "First"]] = None
    meal_service: Optional[bool] = None
    wifi_available: Optional[bool] = None
    # Inches. Must be declared after booking_class: the seat-pitch validator
    # reads booking_class from info.data, which only holds earlier fields.
    seat_pitch: Optional[int] = Field(None, ge=28, le=40)

    # =========================================================================
    # ADVANCED FIELD VALIDATORS
    # =========================================================================

    @field_validator('flight_number')
    @classmethod
    def validate_flight_number_advanced(cls, v: str) -> str:
        """
        Validate and normalize a flight number.

        Runs after Pydantic's string checks (whitespace stripped, length 3-10).
        The value is upper-cased and must be a 2-3 letter airline code, then
        1-4 digits, then an optional single-letter suffix (e.g. ``DL1234``,
        ``AAA1234``, ``AA123A``).

        Returns:
            The upper-cased flight number.

        Raises:
            ValueError: If the value is empty or does not match the format.
        """
        if not v:
            raise ValueError('Flight number cannot be empty')
        v = v.upper().strip()
        # The pattern alone enforces the 2-3 letter code and 1-4 digit number,
        # so no separate length checks are needed afterwards.
        if not re.fullmatch(r'([A-Z]{2,3})(\d{1,4})([A-Z]?)', v):
            raise ValueError(
                f'Invalid flight number format: {v}. '
                'Expected format: AA123 or AAA1234 or AA123A'
            )
        return v

    @field_validator('departure_city', 'arrival_city')
    @classmethod
    def validate_city_names_advanced(cls, v: str) -> str:
        """
        Normalize and validate a city name.

        The name is stripped and title-cased, then a few common variants are
        mapped to a canonical spelling (e.g. "New York City" -> "New York").
        Letters from any script are allowed, plus whitespace, hyphens,
        apostrophes and periods, so names such as "São Paulo" are accepted.

        Returns:
            The normalized city name.

        Raises:
            ValueError: If the name is empty, contains other characters, or is
                shorter than 2 characters after normalization.
        """
        if not v or not v.strip():
            raise ValueError('City name cannot be empty')
        v = v.strip().title()

        # Keys are in title-cased form because .title() has already run.
        city_mappings = {
            'New York City': 'New York',
            'Washington Dc': 'Washington DC',
            'Washington D.C.': 'Washington DC',
        }
        v = city_mappings.get(v, v)

        # [^\W\d_] means "any Unicode letter" (word chars minus digits and _).
        if not re.match(r"^(?:[^\W\d_]|[\s\-'.])+$", v):
            raise ValueError(f'City name contains invalid characters: {v}')
        if len(v) < 2:
            raise ValueError('City name too short after normalization')
        return v

    @field_validator('departure_time', 'arrival_time', mode='before')
    @classmethod
    def validate_datetime_advanced(cls, v: Any) -> datetime:
        """
        Parse a departure/arrival time and enforce sanity rules.

        Runs before Pydantic's own datetime parsing. Strings are tried against
        a few explicit formats, then ``datetime.fromisoformat``. A trailing
        ``Z`` is treated as UTC.

        Returns:
            A timezone-aware datetime within one year of the current time.

        Raises:
            ValueError: If the string cannot be parsed, the value is not a
                datetime, the result is naive, or it lies more than 365 days
                in the past or future.
        """
        if isinstance(v, str):
            formats = [
                '%Y-%m-%dT%H:%M:%SZ',
                '%Y-%m-%dT%H:%M:%S%z',
                '%Y-%m-%d %H:%M:%S',
                '%Y-%m-%dT%H:%M:%S.%fZ'
            ]
            parsed = None
            for fmt in formats:
                try:
                    parsed = datetime.strptime(v, fmt)
                except ValueError:
                    continue
                # strptime treats a literal 'Z' as plain text and returns a
                # naive datetime; the 'Z' means UTC, so attach it explicitly.
                if fmt.endswith('Z'):
                    parsed = parsed.replace(tzinfo=timezone.utc)
                break
            if parsed is None:
                try:
                    parsed = datetime.fromisoformat(v.replace('Z', '+00:00'))
                except ValueError:
                    raise ValueError(f'Invalid datetime format: {v}') from None
            v = parsed

        if not isinstance(v, datetime):
            raise ValueError(f'Datetime must be a datetime or string, got: {type(v)}')

        if v.tzinfo is None:
            raise ValueError('Datetime must be timezone-aware')

        # Reject values implausibly far from "now" (1-year window either way).
        now = datetime.now(timezone.utc)
        min_date = now - timedelta(days=365)
        max_date = now + timedelta(days=365)
        if v < min_date:
            raise ValueError(f'Datetime too far in the past: {v}')
        if v > max_date:
            raise ValueError(f'Datetime too far in the future: {v}')
        return v

    @field_validator('price', mode='before')
    @classmethod
    def validate_price_advanced(cls, v: Any) -> Decimal:
        """
        Coerce a raw price to a 2-decimal ``Decimal`` and apply business rules.

        Accepts ``str`` (currency symbols $, €, £, ¥, thousands separators and
        whitespace are stripped), ``int``, ``float`` and ``Decimal``.

        Returns:
            The price rounded to 2 decimal places.

        Raises:
            ValueError: If the value is a bool, free/N/A, unparsable,
                non-finite, negative, or above 50,000.
        """
        # bool is a subclass of int, but Decimal(str(True)) is not a number.
        if isinstance(v, bool):
            raise ValueError(f'Price must be a number, got: {type(v)}')
        if isinstance(v, str):
            v = re.sub(r'[$€£¥,\s]', '', v)
            if v.lower() in ('free', 'n/a', 'na', ''):
                raise ValueError('Price cannot be free or N/A')
            try:
                v = Decimal(v)
            except (InvalidOperation, ValueError):
                raise ValueError(f'Invalid price format: {v}') from None
        elif isinstance(v, (int, float)):
            # Go through str() so 299.99 becomes Decimal('299.99'), not the
            # exact binary float expansion.
            v = Decimal(str(v))
        elif not isinstance(v, Decimal):
            raise ValueError(f'Price must be a number, got: {type(v)}')

        # NaN would make the comparisons below raise InvalidOperation.
        if not v.is_finite():
            raise ValueError(f'Price must be a finite number: {v}')
        if v < 0:
            raise ValueError('Price cannot be negative')
        if v > Decimal('50000'):  # $50,000
            raise ValueError('Price seems unreasonably high (>$50,000)')

        return v.quantize(Decimal('0.01'))

    @field_validator('stops', mode='before')
    @classmethod
    def validate_stops_advanced(cls, v: Any) -> int:
        """
        Coerce a raw stop count to an int in the range 0-3.

        Strings may be digits or the words none/direct/one/two/three
        (case-insensitive, surrounding whitespace ignored).

        Raises:
            ValueError: If the value cannot be interpreted, is not an int, or
                is outside 0-3.
        """
        if isinstance(v, str):
            v = v.lower().strip()
            if v in ['none', 'direct', '0']:
                v = 0
            elif v in ['one', '1']:
                v = 1
            elif v in ['two', '2']:
                v = 2
            elif v in ['three', '3']:
                v = 3
            else:
                try:
                    v = int(v)
                except ValueError:
                    raise ValueError(f'Invalid stops value: {v}') from None
        if not isinstance(v, int):
            raise ValueError(f'Stops must be an integer, got: {type(v)}')
        if v < 0:
            raise ValueError('Stops cannot be negative')
        if v > 3:
            raise ValueError('Too many stops (maximum 3)')
        return v

    # =========================================================================
    # CONDITIONAL VALIDATION
    # =========================================================================

    @field_validator('seat_pitch')
    @classmethod
    def validate_seat_pitch_conditional(cls, v, info):
        """
        Enforce a minimum seat pitch that depends on ``booking_class``.

        Reads the already-validated ``booking_class`` from ``info.data``.
        If booking_class is absent or failed validation, only the field's own
        28-40 bounds apply.

        Raises:
            ValueError: If the pitch is below the minimum for the class.
        """
        if v is None:
            return v

        booking_class = info.data.get('booking_class')
        if booking_class:
            min_pitch_by_class = {
                'Economy': 28,
                'Premium Economy': 32,
                'Business': 36,
                'First': 40
            }
            min_required = min_pitch_by_class.get(booking_class, 28)
            if v < min_required:
                raise ValueError(
                    f'Seat pitch {v}" is too small for {booking_class} class '
                    f'(minimum {min_required}")'
                )
        return v

    # =========================================================================
    # MODEL VALIDATORS (Cross-field validation)
    # =========================================================================

    @model_validator(mode='after')
    def validate_flight_consistency(self) -> 'AdvancedFlightOption':
        """
        Check business rules that span several fields.

        Rules:
        1. Arrival must be after departure.
        2. ``duration_minutes`` must be within 30 minutes of the actual gap
           between departure and arrival.
        3. A direct flight may last at most 10 hours.
        4. Price per hour is capped ($1000/h direct, $800/h connecting).
        5. A connecting flight may not list the same departure and arrival
           terminal.

        Raises:
            ValueError: On the first rule that is violated.
        """
        if self.arrival_time <= self.departure_time:
            raise ValueError('Arrival time must be after departure time')

        actual_duration = (self.arrival_time - self.departure_time).total_seconds() / 60
        duration_diff = abs(actual_duration - self.duration_minutes)
        if duration_diff > 30:  # 30 minute tolerance
            raise ValueError(
                f'Duration mismatch: stated {self.duration_minutes} min, '
                f'actual {actual_duration:.0f} min (difference: {duration_diff:.0f} min)'
            )

        if self.stops == 0 and self.duration_minutes > 600:  # 10 hours
            raise ValueError(
                f'Direct flight duration too long: {self.duration_minutes} min '
                f'(>10 hours)'
            )

        price_per_hour = self.price_per_hour
        if self.stops == 0:
            if price_per_hour > Decimal('1000'):
                raise ValueError(
                    f'Direct flight price per hour too high: ${price_per_hour:.2f}/hour'
                )
        elif price_per_hour > Decimal('800'):
            raise ValueError(
                f'Connecting flight price per hour too high: ${price_per_hour:.2f}/hour'
            )

        if self.terminal_departure and self.terminal_arrival:
            if self.terminal_departure == self.terminal_arrival and self.stops > 0:
                raise ValueError(
                    'Connecting flight cannot have same departure and arrival terminal'
                )

        return self

    # =========================================================================
    # COMPUTED FIELDS
    # =========================================================================

    @computed_field
    @property
    def duration_hours(self) -> float:
        """Duration in hours."""
        return self.duration_minutes / 60

    @computed_field
    @property
    def is_direct_flight(self) -> bool:
        """True when the flight has no stops."""
        return self.stops == 0

    @computed_field
    @property
    def price_per_hour(self) -> Decimal:
        """Price per hour of flight (0 when the duration is 0)."""
        if self.duration_minutes == 0:
            return Decimal('0')
        return self.price / Decimal(str(self.duration_hours))

    @computed_field
    @property
    def route_summary(self) -> str:
        """Route summary string, e.g. "New York → Los Angeles"."""
        return f"{self.departure_city} → {self.arrival_city}"

    @computed_field
    @property
    def flight_summary(self) -> str:
        """One-line flight summary: airline, number, route, duration, price, stops."""
        return (
            f"{self.airline} {self.flight_number}: {self.route_summary} "
            f"({self.duration_hours:.1f}h, ${self.price}, {self.stops} stops)"
        )

    # =========================================================================
    # UTILITY METHODS
    # =========================================================================

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to a dictionary that includes the computed fields.

        Pydantic v2 serializes ``@computed_field`` values by default. Because
        ``extra='forbid'``, this dict cannot be fed back into the constructor
        unchanged.
        """
        return self.model_dump()

    def to_json(self) -> str:
        """Convert to a JSON string that includes the computed fields."""
        return self.model_dump_json()

    def is_cheaper_than(self, other: 'AdvancedFlightOption') -> bool:
        """Return True if this flight's price is strictly lower than ``other``'s."""
        return self.price < other.price

    def is_faster_than(self, other: 'AdvancedFlightOption') -> bool:
        """Return True if this flight's duration is strictly shorter than ``other``'s."""
        return self.duration_minutes < other.duration_minutes

    def __str__(self) -> str:
        """String representation (the flight summary)."""
        return self.flight_summary
# =============================================================================
# MULTI-AGENT COMMUNICATION MODELS
# =============================================================================
class AgentRequest(BaseModel):
    """Request model for agent-to-agent communication."""

    request_id: str = Field(..., min_length=1, max_length=50)
    sender_agent: str = Field(..., min_length=2, max_length=50)
    recipient_agent: str = Field(..., min_length=2, max_length=50)
    request_type: str = Field(..., min_length=2, max_length=50)
    data: Dict[str, Any] = Field(default_factory=dict)
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    priority: Literal["low", "medium", "high", "urgent"] = "medium"
    timeout_seconds: int = Field(30, ge=5, le=300)

    @field_validator('request_id')
    @classmethod
    def validate_request_id(cls, v: str) -> str:
        """
        Validate the request ID format.

        Raises:
            ValueError: If the ID contains anything other than ASCII letters,
                digits, hyphens and underscores.
        """
        # fullmatch: re.match with '$' would also accept a trailing newline.
        if not re.fullmatch(r'[A-Za-z0-9_-]+', v):
            raise ValueError('Request ID can only contain letters, numbers, hyphens, and underscores')
        return v

    @field_validator('timestamp', mode='before')
    @classmethod
    def parse_timestamp(cls, v: Any) -> Any:
        """
        Parse an ISO-8601 timestamp string, treating a trailing 'Z' as UTC.

        Non-string values are passed through to Pydantic's own datetime parsing.
        A malformed string raises ValueError from ``fromisoformat``.
        """
        # NOTE(review): a string without an offset yields a naive datetime,
        # while the default is UTC-aware; confirm whether naive input should
        # be rejected or assumed UTC.
        if isinstance(v, str):
            return datetime.fromisoformat(v.replace('Z', '+00:00'))
        return v
class AgentResponse(BaseModel):
    """Response model for agent-to-agent communication."""

    request_id: str = Field(..., min_length=1, max_length=50)
    sender_agent: str = Field(..., min_length=2, max_length=50)
    recipient_agent: str = Field(..., min_length=2, max_length=50)
    response_type: str = Field(..., min_length=2, max_length=50)
    success: bool = Field(...)
    data: Optional[Dict[str, Any]] = None
    error_message: Optional[str] = None
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    processing_time_ms: Optional[int] = Field(None, ge=0)

    @model_validator(mode='after')
    def validate_response_consistency(self) -> 'AgentResponse':
        """
        Require ``error_message`` exactly when ``success`` is False.

        Raises:
            ValueError: If a successful response carries an error message, or
                a failed response has none (an empty string counts as none).
        """
        if self.success and self.error_message:
            raise ValueError('Successful response cannot have error message')
        if not self.success and not self.error_message:
            raise ValueError('Failed response must have error message')
        return self
# =============================================================================
# EXAMPLE USAGE AND TESTING
# =============================================================================
def demonstrate_advanced_validation():
    """
    Demonstrate advanced validation features.

    Builds one valid AdvancedFlightOption and prints its computed fields, then
    constructs several deliberately invalid variants and prints the
    ValidationError raised for each. Output goes to stdout; returns None.
    """
    print("🚀 Advanced Pydantic Validation Demo")
    print("=" * 50)

    # Schedule the sample flight relative to "now": AdvancedFlightOption
    # rejects datetimes more than 365 days in the past or future, so fixed
    # calendar dates eventually stop validating.
    departure = (datetime.now(timezone.utc) + timedelta(days=30)).replace(
        hour=8, minute=30, second=0, microsecond=0
    )
    arrival = departure + timedelta(minutes=195)
    time_fmt = '%Y-%m-%dT%H:%M:%SZ'

    valid_data = {
        "airline": "Delta Air Lines",
        "flight_number": "DL1234",
        "departure_city": "New York",
        "arrival_city": "Los Angeles",
        "departure_time": departure.strftime(time_fmt),
        "arrival_time": arrival.strftime(time_fmt),
        "price": "299.99",
        "duration_minutes": 195,
        "stops": 0,
        "booking_class": "Economy",
        "seat_pitch": 30
    }

    try:
        flight = AdvancedFlightOption(**valid_data)
        print(f"✅ Valid flight created: {flight.flight_summary}")
        print("   Computed fields:")
        print(f"   - Duration: {flight.duration_hours:.1f} hours")
        print(f"   - Price per hour: ${flight.price_per_hour:.2f}")
        print(f"   - Is direct: {flight.is_direct_flight}")
        print(f"   - Route: {flight.route_summary}")
    except ValidationError as e:
        print(f"❌ Validation failed: {e}")

    # Each variant breaks exactly one rule of the valid payload above.
    invalid_examples = [
        {
            "name": "Invalid flight number",
            "data": {**valid_data, "flight_number": "1234"}  # Missing airline code
        },
        {
            "name": "Invalid price",
            "data": {**valid_data, "price": "free"}  # Invalid price
        },
        {
            "name": "Invalid seat pitch for class",
            "data": {**valid_data, "booking_class": "First", "seat_pitch": 30}  # Too small for First
        },
        {
            "name": "Duration mismatch",
            "data": {**valid_data, "duration_minutes": 600}  # Doesn't match actual duration
        }
    ]

    for example in invalid_examples:
        print(f"\n❌ {example['name']}:")
        try:
            AdvancedFlightOption(**example['data'])
            print("   Unexpected: This should have failed!")
        except ValidationError as e:
            print(f"   Caught error: {e}")
if __name__ == "__main__":
    # Run the demo only when executed as a script, not on import.
    demonstrate_advanced_validation()