Spaces:
Sleeping
Sleeping
| """ | |
| Comprehensive Trip Validation System | |
| This module implements sophisticated validation patterns for ensuring AI systems | |
| produce feasible, high-quality travel recommendations that meet user expectations | |
| and business requirements. | |
| Key Features: | |
| 1. Multi-level validation (individual, cross-agent, overall) | |
| 2. Business rule validation (travel times, sequences, budget) | |
| 3. Quality scoring for complete itineraries | |
| 4. User satisfaction factor validation | |
| 5. Smart conflict detection and resolution | |
| 6. Detailed validation reporting | |
| """ | |
| import asyncio | |
| import uuid | |
| from abc import ABC, abstractmethod | |
| from dataclasses import dataclass, field | |
| from datetime import datetime, timedelta | |
| from enum import Enum | |
| from typing import Dict, List, Optional, Any, Tuple, Set, Union | |
| from decimal import Decimal | |
| import math | |
| from pydantic import BaseModel, Field, validator | |
| from ..models.flight_models import FlightOption | |
| from ..models.hotel_models import HotelOption | |
| from ..models.poi_models import POI | |
# Validation rules are imported lazily inside TripValidator._initialize_validation_rules() to avoid circular imports.
class ValidationLevel(str, Enum):
    """Depth at which a validation rule operates.

    Member order is significant: TripValidator iterates this enum to decide
    the sequence in which validation levels are run.
    """

    # Checks on a single agent's output in isolation.
    INDIVIDUAL = "individual"
    # Compatibility checks between outputs of different agents.
    CROSS_AGENT = "cross_agent"
    # Quality checks on the assembled trip as a whole.
    OVERALL = "overall"
    # Business-logic checks (e.g. travel times, check-in/out sequencing).
    BUSINESS_RULES = "business_rules"
    # User-experience checks.
    USER_SATISFACTION = "user_satisfaction"
class ValidationSeverity(str, Enum):
    """How serious a validation finding is, from blocking to informational."""

    # The trip cannot proceed as planned.
    CRITICAL = "critical"
    # Significant problem that should be addressed.
    WARNING = "warning"
    # Minor issue or suggestion.
    INFO = "info"
    # The check found nothing to report.
    PASS = "pass"
class ConflictType(str, Enum):
    """Category of a conflict detected between trip components."""

    # Timing clashes.
    TEMPORAL = "temporal"
    # Location / distance clashes.
    SPATIAL = "spatial"
    # Cost exceeding the budget.
    BUDGET = "budget"
    # Outputs contradicting stated user preferences.
    PREFERENCE = "preference"
    # Practical / logistical problems.
    LOGISTICAL = "logistical"
    # Quality-standard shortfalls.
    QUALITY = "quality"
@dataclass
class ValidationResult:
    """Outcome of a single validation check.

    This must be a dataclass. Callers build it with keyword arguments (e.g.
    ``ValidationResult(level=..., severity=..., rule_name=...)``), and the
    ``field(default_factory=...)`` defaults below only take effect under
    ``@dataclass``. Without the decorator, keyword construction raises
    ``TypeError`` and the defaults remain bare ``Field`` objects.
    """

    # Fresh UUID4 string per instance.
    validation_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    level: ValidationLevel = ValidationLevel.INDIVIDUAL
    severity: ValidationSeverity = ValidationSeverity.PASS
    rule_name: str = ""
    message: str = ""
    # Rule-specific payload, copied as-is into to_dict().
    details: Dict[str, Any] = field(default_factory=dict)
    suggestions: List[str] = field(default_factory=list)
    # Creation time from datetime.now(), i.e. naive local time.
    timestamp: datetime = field(default_factory=datetime.now)
    agent_id: Optional[str] = None
    affected_elements: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain dict for serialization.

        Enum fields are reduced to their string values and the timestamp to
        an ISO-8601 string; container fields are returned by reference.
        """
        return {
            "validation_id": self.validation_id,
            "level": self.level.value,
            "severity": self.severity.value,
            "rule_name": self.rule_name,
            "message": self.message,
            "details": self.details,
            "suggestions": self.suggestions,
            "timestamp": self.timestamp.isoformat(),
            "agent_id": self.agent_id,
            "affected_elements": self.affected_elements
        }
@dataclass
class ConflictResolution:
    """A detected conflict together with its candidate resolutions.

    This must be a dataclass. TripValidator's conflict detectors construct it
    with keyword arguments, and the ``field(default_factory=...)`` defaults
    only work under ``@dataclass``.
    """

    # Fresh UUID4 string per instance.
    conflict_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    conflict_type: ConflictType = ConflictType.TEMPORAL
    description: str = ""
    affected_agents: List[str] = field(default_factory=list)
    # Each option is a dict; TripValidator's detectors use the keys
    # "option", "description" and "impact".
    resolution_options: List[Dict[str, Any]] = field(default_factory=list)
    recommended_resolution: Optional[Dict[str, Any]] = None
    # TripValidator's detectors fill "severity", "user_impact" and
    # "cost_impact"; recommendation generation reads "severity".
    impact_assessment: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain dict for serialization (the enum is reduced to its value)."""
        return {
            "conflict_id": self.conflict_id,
            "conflict_type": self.conflict_type.value,
            "description": self.description,
            "affected_agents": self.affected_agents,
            "resolution_options": self.resolution_options,
            "recommended_resolution": self.recommended_resolution,
            "impact_assessment": self.impact_assessment
        }
@dataclass
class QualityScore:
    """Quality score for one trip component or for the trip overall.

    This must be a dataclass. It is constructed with keyword arguments, and
    the ``field(default_factory=...)`` defaults only work under ``@dataclass``.
    """

    # Component name: "flight", "hotel", "pois" or "overall".
    component: str = ""
    # Score, intended to lie in 0.0 to max_score.
    score: float = 0.0
    max_score: float = 1.0
    # Per-factor weighted contributions to the score.
    factors: Dict[str, float] = field(default_factory=dict)
    explanation: str = ""
    recommendations: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain dict for serialization."""
        return {
            "component": self.component,
            "score": self.score,
            "max_score": self.max_score,
            "factors": self.factors,
            "explanation": self.explanation,
            "recommendations": self.recommendations
        }
class ValidationRule(ABC):
    """Abstract base class for validation rules.

    Subclasses must implement :meth:`validate`. TripValidator awaits
    ``rule.validate(context)`` and extends its result list with the return
    value, so an implementation must return a list of ``ValidationResult``.
    """

    def __init__(self, rule_name: str, severity: ValidationSeverity = ValidationSeverity.WARNING):
        """Store the rule's name and configured severity.

        Args:
            rule_name: Identifier that TripValidator copies into
                ``ValidationResult.rule_name`` when this rule raises.
            severity: Severity configured for this rule. Presumably applied
                by subclasses to the findings they emit; confirm in
                ``validation_rules``.
        """
        self.rule_name = rule_name
        self.severity = severity

    @abstractmethod
    async def validate(self, context: Dict[str, Any]) -> List[ValidationResult]:
        """Validate ``context`` and return the findings.

        As built by TripValidator.validate_trip, ``context`` holds the keys
        ``"trip_context"``, ``"agent_outputs"`` and ``"validation_results"``.

        This method is abstract so that a subclass which forgets to override
        it fails at instantiation. Otherwise it would silently return
        ``None``.
        """

    def get_rule_description(self) -> str:
        """Return a human-readable description of the rule.

        Falls back to ``rule_name`` so the declared ``str`` return type holds.
        Subclasses may override this with a fuller description.
        """
        return self.rule_name
| class TripValidator: | |
| """ | |
| Comprehensive trip validator that ensures high-quality, feasible travel recommendations. | |
| This class implements multi-level validation patterns that ensure AI systems | |
| produce reliable, user-satisfying outcomes. | |
| """ | |
def __init__(self):
    """Create an empty rule registry for every level, then load the built-in rules."""
    # One initially empty rule list per ValidationLevel, keyed in enum order.
    self.validation_rules: Dict[ValidationLevel, List[ValidationRule]] = {
        level: [] for level in ValidationLevel
    }
    # Results appended by validate_trip() after each successful run.
    self.validation_history: List[ValidationResult] = []
    self.conflict_resolutions: List[ConflictResolution] = []
    self._initialize_validation_rules()
def _initialize_validation_rules(self):
    """Instantiate the built-in rules and register each under its level.

    Rules are imported here rather than at module import time to avoid a
    circular import with the ``validation_rules`` module.
    """
    from .validation_rules import (
        FlightValidationRule, HotelValidationRule, POIValidationRule,
        FlightHotelCompatibilityRule, HotelPOICompatibilityRule, FlightPOICompatibilityRule,
        TripCoherenceRule, BudgetAdherenceRule, QualityThresholdRule,
        TravelTimeRule, CheckInOutSequenceRule, RealisticScheduleRule, AirportTransferRule,
        PreferenceAlignmentRule, TravelStyleRule, EnergyLevelRule, AccessibilityRule
    )

    # Rule classes per level, in registration (and therefore execution) order.
    rule_classes_by_level = {
        ValidationLevel.INDIVIDUAL: (
            FlightValidationRule, HotelValidationRule, POIValidationRule,
        ),
        ValidationLevel.CROSS_AGENT: (
            FlightHotelCompatibilityRule, HotelPOICompatibilityRule, FlightPOICompatibilityRule,
        ),
        ValidationLevel.OVERALL: (
            TripCoherenceRule, BudgetAdherenceRule, QualityThresholdRule,
        ),
        ValidationLevel.BUSINESS_RULES: (
            TravelTimeRule, CheckInOutSequenceRule, RealisticScheduleRule, AirportTransferRule,
        ),
        ValidationLevel.USER_SATISFACTION: (
            PreferenceAlignmentRule, TravelStyleRule, EnergyLevelRule, AccessibilityRule,
        ),
    }
    for level, rule_classes in rule_classes_by_level.items():
        self.validation_rules[level].extend([rule_cls() for rule_cls in rule_classes])
async def validate_trip(self, trip_context: Dict[str, Any],
                        agent_outputs: Dict[str, Any]) -> Dict[str, Any]:
    """Run every validation stage over a trip and return a report dict.

    The stages run in this order:
    1. rule validation for each ``ValidationLevel``;
    2. conflict detection;
    3. quality scoring and the overall score;
    4. recommendations;
    5. summary.

    Any exception raised by a stage is caught and recorded under
    ``"validation_error"``, and the trip is marked invalid. A summary is
    still produced in that case.

    Args:
        trip_context: Trip-level data. ``"trip_id"`` is read here; other keys
            are consumed by the individual stages.
        agent_outputs: Agent outputs keyed by agent name (the stages look up
            ``"flight"``, ``"hotel"`` and ``"pois"``).

    Returns:
        The report dict. ``"validation_levels"`` maps each level's value to
        its list of ``ValidationResult`` objects. ``"conflicts"`` and
        ``"quality_scores"`` hold plain dicts.
    """
    validation_results = {
        "trip_id": trip_context.get("trip_id", "unknown"),
        "validation_timestamp": datetime.now().isoformat(),
        "overall_valid": True,
        "overall_score": 0.0,
        "validation_levels": {},
        "conflicts": [],
        "quality_scores": {},
        "recommendations": [],
        "validation_summary": {}
    }

    try:
        validation_context = {
            "trip_context": trip_context,
            "agent_outputs": agent_outputs,
            "validation_results": validation_results
        }

        for level in ValidationLevel:
            level_results = await self._validate_level(level, validation_context)
            validation_results["validation_levels"][level.value] = level_results
            # A single critical finding at any level invalidates the trip.
            if any(r.severity == ValidationSeverity.CRITICAL for r in level_results):
                validation_results["overall_valid"] = False

        conflicts = await self._detect_conflicts(validation_context)
        validation_results["conflicts"] = [c.to_dict() for c in conflicts]

        quality_scores = await self._calculate_quality_scores(validation_context)
        validation_results["quality_scores"] = {k: v.to_dict() for k, v in quality_scores.items()}
        validation_results["overall_score"] = self._calculate_overall_score(quality_scores)

        # Reads the conflicts / quality_scores / overall_score stored above.
        validation_results["recommendations"] = await self._generate_recommendations(validation_context)
        validation_results["validation_summary"] = self._create_validation_summary(validation_results)

        self.validation_history.extend(self._flatten_validation_results(validation_results))

    except Exception as e:
        # Best-effort: report the failure in the result instead of raising.
        validation_results["validation_error"] = str(e)
        validation_results["overall_valid"] = False

    # The summary key is pre-seeded with {}, so test for emptiness rather
    # than key presence. A key-presence test never fires, which would leave
    # an empty summary whenever a stage fails before the summary step.
    if not validation_results["validation_summary"]:
        try:
            validation_results["validation_summary"] = self._create_validation_summary(validation_results)
        except Exception as e:
            # Keep the best-effort contract: never raise from here.
            validation_results.setdefault("validation_error", str(e))
            validation_results["overall_valid"] = False

    return validation_results
async def _validate_level(self, level: ValidationLevel, context: Dict[str, Any]) -> List[ValidationResult]:
    """Run every rule registered for ``level`` and collect their findings.

    A rule that raises does not abort the level. Its exception is converted
    into a CRITICAL ``ValidationResult`` attributed to that rule, and the
    remaining rules still run.
    """
    collected = []
    for rule in self.validation_rules[level]:
        try:
            collected.extend(await rule.validate(context))
        except Exception as exc:
            reason = str(exc)
            collected.append(
                ValidationResult(
                    level=level,
                    severity=ValidationSeverity.CRITICAL,
                    rule_name=rule.rule_name,
                    message=f"Validation rule failed: {reason}",
                    details={"error": reason},
                )
            )
    return collected
async def _detect_conflicts(self, context: Dict[str, Any]) -> List[ConflictResolution]:
    """Run each conflict detector in turn and concatenate their findings.

    Detectors are awaited sequentially. The returned list holds temporal,
    then spatial, then budget, then preference conflicts.
    """
    detectors = (
        self._detect_temporal_conflicts,
        self._detect_spatial_conflicts,
        self._detect_budget_conflicts,
        self._detect_preference_conflicts,
    )
    found = []
    for detect in detectors:
        found.extend(await detect(context))
    return found
async def _detect_temporal_conflicts(self, context: Dict[str, Any]) -> "List[ConflictResolution]":
    """Flag a hotel check-in that falls less than 2 hours after the flight lands.

    The check runs only when ``agent_outputs`` has both ``"flight"`` and
    ``"hotel"`` and both expose a non-``None`` ``arrival_time`` /
    ``check_in_time``. A missing or ``None`` value skips the check. It no
    longer raises ``TypeError``, which would abort the rest of validate_trip.
    A check-in *before* arrival gives a negative gap and is also flagged.
    """
    conflicts = []
    agent_outputs = context["agent_outputs"]
    if "flight" not in agent_outputs or "hotel" not in agent_outputs:
        return conflicts

    arrival_time = getattr(agent_outputs["flight"], 'arrival_time', None)
    check_in_time = getattr(agent_outputs["hotel"], 'check_in_time', None)
    if arrival_time is None or check_in_time is None:
        return conflicts

    # NOTE(review): assumes both values are datetimes (their difference must
    # expose total_seconds()); confirm against FlightOption / HotelOption.
    time_diff = (check_in_time - arrival_time).total_seconds() / 3600  # hours
    if time_diff < 2:
        conflicts.append(ConflictResolution(
            conflict_type=ConflictType.TEMPORAL,
            description=f"Insufficient time between flight arrival ({arrival_time}) and hotel check-in ({check_in_time})",
            affected_agents=["flight_agent", "hotel_agent"],
            resolution_options=[
                {
                    "option": "delay_hotel_checkin",
                    "description": "Change hotel check-in to later time",
                    "impact": "May incur additional fees"
                },
                {
                    "option": "change_flight",
                    "description": "Select earlier flight",
                    "impact": "May increase cost or reduce options"
                }
            ],
            impact_assessment={
                "severity": "medium",
                "user_impact": "May cause stress and inconvenience",
                "cost_impact": "Potential additional fees"
            }
        ))
    return conflicts
async def _detect_spatial_conflicts(self, context: Dict[str, Any]) -> "List[ConflictResolution]":
    """Flag POIs whose distance from the city centre differs from the hotel's by more than 20 km.

    Only radial distances from the centre are available. The gap
    ``abs(hotel_km - poi_km)`` is therefore a *lower bound* on the true
    hotel-POI distance (triangle inequality). A flagged pair is at least that
    far apart, but pairs on opposite sides of the centre can be missed.

    A hotel or POI whose distance is missing or ``None`` is skipped, as is a
    ``None`` POI list; these previously raised ``TypeError``.
    """
    conflicts = []
    agent_outputs = context["agent_outputs"]
    if "hotel" not in agent_outputs or "pois" not in agent_outputs:
        return conflicts

    hotel_distance = getattr(agent_outputs["hotel"], 'distance_city_center_km', None)
    if hotel_distance is None:
        return conflicts

    for poi in agent_outputs["pois"] or []:
        poi_distance = getattr(poi, 'distance_city_center_km', None)
        if poi_distance is None:
            continue
        total_distance = abs(hotel_distance - poi_distance)
        if total_distance > 20:  # More than 20km apart (lower bound, see docstring)
            conflicts.append(ConflictResolution(
                conflict_type=ConflictType.SPATIAL,
                description=f"Hotel and {poi.name} are too far apart ({total_distance:.1f}km)",
                affected_agents=["hotel_agent", "poi_agent"],
                resolution_options=[
                    {
                        "option": "change_hotel",
                        "description": "Select hotel closer to activities",
                        "impact": "May change price or quality"
                    },
                    {
                        "option": "change_poi",
                        "description": "Select activities closer to hotel",
                        "impact": "May reduce activity options"
                    }
                ],
                impact_assessment={
                    "severity": "low",
                    "user_impact": "Increased travel time and costs",
                    "cost_impact": "Additional transportation costs"
                }
            ))
    return conflicts
async def _detect_budget_conflicts(self, context: Dict[str, Any]) -> "List[ConflictResolution]":
    """Flag a trip whose estimated cost exceeds ``trip_context["total_budget"]``.

    The cost estimate is the sum of:
    - the flight ``price``;
    - hotel ``price_per_night`` x 3 nights (hard-coded placeholder);
    - the truthy POI ``adult_price`` values.

    A missing, ``None`` or non-positive budget skips the check. There is no
    limit to compare against, and the overage percentage would otherwise
    divide by zero (previously ``ZeroDivisionError``, or ``TypeError`` for
    ``None``).
    """
    conflicts = []
    trip_context = context["trip_context"]
    agent_outputs = context["agent_outputs"]

    total_budget = float(trip_context.get("total_budget") or 0)
    if total_budget <= 0:
        return conflicts

    total_cost = 0.0
    for agent_id, output in agent_outputs.items():
        if agent_id == "flight" and hasattr(output, 'price'):
            total_cost += float(output.price)
        elif agent_id == "hotel" and hasattr(output, 'price_per_night'):
            # Assume 3 nights for demo
            nights = 3
            total_cost += float(output.price_per_night) * nights
        elif agent_id == "pois" and isinstance(output, list):
            for poi in output:
                if hasattr(poi, 'adult_price') and poi.adult_price:
                    total_cost += float(poi.adult_price)

    if total_cost > total_budget:
        overage = total_cost - total_budget
        overage_percentage = (overage / total_budget) * 100
        conflicts.append(ConflictResolution(
            conflict_type=ConflictType.BUDGET,
            description=f"Trip exceeds budget by ${overage:.2f} ({overage_percentage:.1f}%)",
            affected_agents=list(agent_outputs.keys()),
            resolution_options=[
                {
                    "option": "reduce_flight_cost",
                    "description": "Select cheaper flight options",
                    "impact": "May increase travel time or reduce comfort"
                },
                {
                    "option": "reduce_hotel_cost",
                    "description": "Select cheaper hotel",
                    "impact": "May reduce amenities or location quality"
                },
                {
                    "option": "reduce_activities",
                    "description": "Remove or replace expensive activities",
                    "impact": "May reduce trip experience quality"
                }
            ],
            impact_assessment={
                "severity": "high" if overage_percentage > 20 else "medium",
                "user_impact": "Financial stress or trip cancellation",
                "cost_impact": f"${overage:.2f} over budget"
            }
        ))
    return conflicts
async def _detect_preference_conflicts(self, context: Dict[str, Any]) -> List[ConflictResolution]:
    """Flag outputs that contradict the user's stated preferences.

    The only check at present: ``preferences["prefer_direct_flights"]`` is set
    while the selected flight reports ``stops > 0``.
    """
    trip_context = context["trip_context"]
    agent_outputs = context["agent_outputs"]
    preferences = trip_context.get("preferences", {})

    wants_direct = preferences.get("prefer_direct_flights", False)
    if not wants_direct or "flight" not in agent_outputs:
        return []

    flight = agent_outputs["flight"]
    has_stops = hasattr(flight, 'stops') and flight.stops > 0
    if not has_stops:
        return []

    return [
        ConflictResolution(
            conflict_type=ConflictType.PREFERENCE,
            description="User prefers direct flights but selected flight has stops",
            affected_agents=["flight_agent"],
            resolution_options=[
                {
                    "option": "find_direct_flight",
                    "description": "Search for direct flight alternatives",
                    "impact": "May increase cost or reduce availability"
                },
                {
                    "option": "accept_stops",
                    "description": "Accept stops for cost savings",
                    "impact": "User preference not fully met"
                }
            ],
            impact_assessment={
                "severity": "low",
                "user_impact": "Preference not fully satisfied",
                "cost_impact": "May increase flight cost"
            }
        )
    ]
async def _calculate_quality_scores(self, context: Dict[str, Any]) -> Dict[str, QualityScore]:
    """Score each present trip component plus the trip as a whole.

    Returns a dict with ``"flight"``, ``"hotel"`` and ``"pois"`` entries (each
    only when that agent output exists), followed by an ``"overall"`` entry
    that is always present.
    """
    agent_outputs = context["agent_outputs"]
    component_scorers = (
        ("flight", self._calculate_flight_quality),
        ("hotel", self._calculate_hotel_quality),
        ("pois", self._calculate_poi_quality),
    )
    scores = {}
    for key, scorer in component_scorers:
        if key in agent_outputs:
            scores[key] = await scorer(agent_outputs[key])
    scores["overall"] = await self._calculate_overall_quality(context)
    return scores
async def _calculate_flight_quality(self, flight: FlightOption) -> QualityScore:
    """Score a flight in [0, 1] from price, duration, stops and airline.

    Weights are price 0.3, duration 0.3, stops 0.2 and airline 0.2. A factor
    whose attribute is missing (or ``None`` for price/duration/stops)
    contributes 0, but its weight still counts towards the maximum, so
    missing data lowers the score. Every sub-score is clamped to [0, 1] so no
    factor can contribute more than its weight.
    """
    factors = {}
    total_score = 0.0
    max_score = 0.0

    price = getattr(flight, 'price', None)
    duration_hours = getattr(flight, 'duration_hours', None)
    stops = getattr(flight, 'stops', None)
    has_airline = hasattr(flight, 'airline')

    # Price (weight 0.3): full marks up to $500, falling linearly to 0 at $1500.
    price_factor = 0.3
    max_score += price_factor
    if price is not None:
        normalized_price = min(float(price) / 1000, 2.0)
        # Upper clamp: without it a flight under $500 scored above 1.0 and
        # could push the final score past 1.0.
        price_score = min(1.0, max(0.0, 1.0 - (normalized_price - 0.5)))
        factors["price"] = price_score * price_factor
        total_score += factors["price"]

    # Duration (weight 0.3): full marks up to 6h, falling linearly to 0 at 18h.
    duration_factor = 0.3
    max_score += duration_factor
    if duration_hours is not None:
        normalized_duration = min(duration_hours / 12, 2.0)
        # Upper clamp: without it a flight under 6h scored above 1.0.
        duration_score = min(1.0, max(0.0, 1.0 - (normalized_duration - 0.5)))
        factors["duration"] = duration_score * duration_factor
        total_score += factors["duration"]

    # Stops (weight 0.2): 0 stops = 1.0, 1 stop = 0.5, 2+ stops = 0.0.
    stops_factor = 0.2
    max_score += stops_factor
    if stops is not None:
        stops_score = min(1.0, max(0.0, 1.0 - (stops * 0.5)))
        factors["stops"] = stops_score * stops_factor
        total_score += factors["stops"]

    # Airline (weight 0.2): fixed tiers by name; unlisted airlines score 0.6.
    airline_factor = 0.2
    max_score += airline_factor
    if has_airline:
        premium_airlines = ["Singapore Airlines", "Emirates", "Qatar Airways", "Cathay Pacific"]
        good_airlines = ["United", "Delta", "American", "Lufthansa", "British Airways"]
        if flight.airline in premium_airlines:
            airline_score = 1.0
        elif flight.airline in good_airlines:
            airline_score = 0.8
        else:
            airline_score = 0.6
        factors["airline"] = airline_score * airline_factor
        total_score += factors["airline"]

    final_score = total_score / max_score if max_score > 0 else 0.0

    price_str = f"${price}" if price is not None else 'N/A'
    duration_str = f"{duration_hours}h" if duration_hours is not None else 'N/A'
    stops_str = str(stops) if stops is not None else 'N/A'
    airline_str = flight.airline if has_airline else 'N/A'
    explanation = f"Flight quality score based on price ({price_str}), duration ({duration_str}), stops ({stops_str}), and airline ({airline_str})"

    return QualityScore(
        component="flight",
        score=final_score,
        max_score=1.0,
        factors=factors,
        explanation=explanation,
        recommendations=self._get_flight_recommendations(final_score, factors)
    )
async def _calculate_hotel_quality(self, hotel: HotelOption) -> QualityScore:
    """Score a hotel in [0, 1] from rating, location and nightly price.

    Weights are rating 0.4, location 0.3 and price 0.3. A factor whose
    attribute is missing or ``None`` contributes 0, but its weight still
    counts towards the maximum, so missing data lowers the score. Each
    sub-score is clamped to [0, 1].
    """
    factors = {}
    total_score = 0.0
    max_score = 0.0

    rating = getattr(hotel, 'rating', None)
    distance = getattr(hotel, 'distance_city_center_km', None)
    price_per_night = getattr(hotel, 'price_per_night', None)

    # Rating (weight 0.4): linear on a 0-5 scale.
    rating_factor = 0.4
    max_score += rating_factor
    if rating is not None:
        rating_score = min(1.0, max(0.0, rating / 5.0))
        factors["rating"] = rating_score * rating_factor
        total_score += factors["rating"]

    # Location (weight 0.3): 1.0 at the city centre, falling linearly to 0 at 10 km.
    location_factor = 0.3
    max_score += location_factor
    if distance is not None:
        location_score = min(1.0, max(0.0, 1.0 - (distance / 10.0)))
        factors["location"] = location_score * location_factor
        total_score += factors["location"]

    # Price (weight 0.3): full marks up to $150/night, falling linearly to 0 at $450/night.
    price_factor = 0.3
    max_score += price_factor
    if price_per_night is not None:
        normalized_price = min(float(price_per_night) / 300, 2.0)
        # Upper clamp: without it rooms under $150 scored above 1.0 and could
        # push the final score past 1.0.
        price_score = min(1.0, max(0.0, 1.0 - (normalized_price - 0.5)))
        factors["price"] = price_score * price_factor
        total_score += factors["price"]

    final_score = total_score / max_score if max_score > 0 else 0.0

    rating_str = f"{rating}/5" if rating is not None else 'N/A'
    location_str = f"{distance}km from center" if distance is not None else 'N/A'
    price_str = f"${price_per_night}/night" if price_per_night is not None else 'N/A'
    explanation = f"Hotel quality based on rating ({rating_str}), location ({location_str}), and price ({price_str})"

    return QualityScore(
        component="hotel",
        score=final_score,
        max_score=1.0,
        factors=factors,
        explanation=explanation,
        recommendations=self._get_hotel_recommendations(final_score, factors)
    )
async def _calculate_poi_quality(self, pois: List[POI]) -> QualityScore:
    """Score the selected activities in [0, 1] from diversity, rating and value.

    Weights are diversity 0.3, rating 0.4 and value 0.3. An empty (or
    ``None``) POI list scores 0 and gets a single recommendation to add
    activities.
    """
    if not pois:
        return QualityScore(
            component="pois",
            score=0.0,
            explanation="No activities selected",
            recommendations=["Add activities to improve trip experience"]
        )

    factors = {}
    total_score = 0.0
    max_score = 0.0
    poi_count = len(pois)

    # Diversity (weight 0.3): distinct categories, full marks at 3 or more.
    # A None category is ignored rather than counted as the category "None".
    diversity_factor = 0.3
    max_score += diversity_factor
    categories = set()
    for poi in pois:
        category = getattr(poi, 'category', None)
        if category is not None:
            categories.add(category.value if hasattr(category, 'value') else str(category))
    diversity_score = min(len(categories) / 3.0, 1.0)
    factors["diversity"] = diversity_score * diversity_factor
    total_score += factors["diversity"]

    # Rating (weight 0.4): mean rating (0-5 scale) over POIs that have one.
    # A None rating is skipped instead of raising TypeError.
    rating_factor = 0.4
    max_score += rating_factor
    ratings = [poi.rating for poi in pois if getattr(poi, 'rating', None) is not None]
    avg_rating = sum(ratings) / len(ratings) if ratings else None
    if avg_rating is not None:
        rating_score = min(1.0, max(0.0, avg_rating / 5.0))
        factors["rating"] = rating_score * rating_factor
        total_score += factors["rating"]

    # Value (weight 0.3): share of free activities (adult_price 0 or None).
    # Note this rewards free activities outright, not a free/paid mix.
    value_factor = 0.3
    max_score += value_factor
    free_activities = sum(
        1 for poi in pois
        if hasattr(poi, 'adult_price') and (poi.adult_price is None or poi.adult_price == 0)
    )
    value_score = min(free_activities / poi_count, 1.0)
    factors["value"] = value_score * value_factor
    total_score += factors["value"]

    final_score = total_score / max_score if max_score > 0 else 0.0

    avg_rating_str = f"{avg_rating:.1f}/5" if avg_rating is not None else 'N/A'
    free_percentage = f"{(free_activities / poi_count * 100):.0f}%"
    explanation = f"Activity quality based on diversity ({len(categories)} categories), average rating ({avg_rating_str}), and value ({free_percentage} free activities)"

    return QualityScore(
        component="pois",
        score=final_score,
        max_score=1.0,
        factors=factors,
        explanation=explanation,
        recommendations=self._get_poi_recommendations(final_score, factors)
    )
async def _calculate_overall_quality(self, context: Dict[str, Any]) -> QualityScore:
    """Score the trip as a whole in [0, 1].

    The factors and their weights are:
    - component balance, 0.3 (share of flight/hotel/pois present);
    - budget utilisation, 0.3;
    - preference alignment, 0.4.

    The cost estimate mirrors the budget conflict check: flight price, hotel
    price x 3 nights, and the truthy POI adult prices.
    """
    trip_context = context["trip_context"]
    agent_outputs = context["agent_outputs"]
    factors = {}
    total_score = 0.0
    max_score = 0.0

    # Component balance (weight 0.3): fraction of the expected components present.
    balance_factor = 0.3
    max_score += balance_factor
    components = ["flight", "hotel", "pois"]
    present_components = sum(1 for comp in components if comp in agent_outputs)
    balance_score = present_components / len(components)
    factors["balance"] = balance_score * balance_factor
    total_score += factors["balance"]

    # Budget utilisation (weight 0.3).
    budget_factor = 0.3
    max_score += budget_factor
    total_budget = float(trip_context.get("total_budget", 1))
    total_cost = 0.0
    for agent_id, output in agent_outputs.items():
        if agent_id == "flight" and hasattr(output, 'price'):
            total_cost += float(output.price)
        elif agent_id == "hotel" and hasattr(output, 'price_per_night'):
            # Assume 3 nights for demo
            nights = 3
            total_cost += float(output.price_per_night) * nights
        elif agent_id == "pois" and isinstance(output, list):
            for poi in output:
                if hasattr(poi, 'adult_price') and poi.adult_price:
                    total_cost += float(poi.adult_price)

    # Uncapped ratio, so the explanation reports real overspend (e.g. 150%)
    # instead of saturating at 100%. The score is unaffected: it is already
    # 0 at or above 100%.
    budget_utilization = total_cost / total_budget if total_budget > 0 else 0
    if 0.7 <= budget_utilization <= 0.95:
        # Sweet spot: 70-95% of the budget used.
        budget_score = 1.0
    elif budget_utilization < 0.7:
        # Under-utilisation: linear ramp from 0 at 0% to 1 at 70%.
        budget_score = budget_utilization / 0.7
    else:
        # Above 95%: falls linearly to 0 at 100% and stays 0 when over budget.
        budget_score = max(0, 1.0 - (budget_utilization - 0.95) / 0.05)
    factors["budget"] = budget_score * budget_factor
    total_score += factors["budget"]

    # Preference alignment (weight 0.4).
    preference_factor = 0.4
    max_score += preference_factor
    preferences = trip_context.get("preferences", {})
    alignment_score = self._calculate_preference_alignment(preferences, agent_outputs)
    factors["preferences"] = alignment_score * preference_factor
    total_score += factors["preferences"]

    final_score = total_score / max_score if max_score > 0 else 0.0

    explanation = f"Overall trip quality based on component balance ({present_components}/{len(components)}), "
    explanation += f"budget utilization ({(budget_utilization*100):.0f}%), "
    explanation += f"and preference alignment ({(alignment_score*100):.0f}%)"

    return QualityScore(
        component="overall",
        score=final_score,
        max_score=1.0,
        factors=factors,
        explanation=explanation,
        recommendations=self._get_overall_recommendations(final_score, factors)
    )
def _calculate_preference_alignment(self, preferences: Dict[str, Any], agent_outputs: Dict[str, Any]) -> float:
    """Return the mean satisfaction over the preference checks that apply.

    A check applies only when its preference flag is set and the relevant
    agent output exists:
    - ``prefer_direct_flights``: 1.0 for a flight with 0 stops, else 0.0;
    - ``prefer_central_hotels``: 1.0 for a hotel within 5 km of the centre,
      else 0.0;
    - ``prefer_morning_activities``: a fixed placeholder of 0.5.

    When no check applies, the neutral value 0.5 is returned.
    """
    check_scores: List[float] = []

    if preferences.get("prefer_direct_flights", False) and "flight" in agent_outputs:
        flight = agent_outputs["flight"]
        is_direct = hasattr(flight, 'stops') and flight.stops == 0
        check_scores.append(1.0 if is_direct else 0.0)

    if preferences.get("prefer_central_hotels", False) and "hotel" in agent_outputs:
        hotel = agent_outputs["hotel"]
        is_central = hasattr(hotel, 'distance_city_center_km') and hotel.distance_city_center_km <= 5
        check_scores.append(1.0 if is_central else 0.0)

    if preferences.get("prefer_morning_activities", False) and "pois" in agent_outputs:
        # Placeholder until POI opening hours are actually inspected.
        check_scores.append(0.5)

    if not check_scores:
        return 0.5
    return sum(check_scores) / len(check_scores)
def _calculate_overall_score(self, quality_scores: Dict[str, QualityScore]) -> float:
    """Combine component scores into one weighted average.

    The weights are flight 0.3, hotel 0.3, pois 0.2 and overall 0.2.
    Components not in that table are ignored. The average is normalised by
    the weights actually present, so a missing component does not drag the
    result down. Returns 0.0 when nothing can be weighted.
    """
    if not quality_scores:
        return 0.0

    component_weights = {
        "flight": 0.3,
        "hotel": 0.3,
        "pois": 0.2,
        "overall": 0.2
    }
    weighted_pairs = [
        (component_weights[name], quality.score)
        for name, quality in quality_scores.items()
        if name in component_weights
    ]
    weight_sum = sum(weight for weight, _ in weighted_pairs)
    if weight_sum <= 0:
        return 0.0
    return sum(quality * weight for weight, quality in weighted_pairs) / weight_sum
async def _generate_recommendations(self, context: Dict[str, Any]) -> List[str]:
    """Build user-facing advice from the in-progress validation report.

    Reads ``context["validation_results"]``. When validate_trip calls this,
    the report's ``validation_levels``, ``quality_scores`` (as dicts),
    ``conflicts`` (as dicts) and ``overall_score`` are already filled in.

    Advice is emitted in this order:
    1. a banner if any level produced a CRITICAL result;
    2. the stored recommendations of every component scoring below 0.6;
    3. one line per conflict of "high" or "medium" severity;
    4. a general hint if the overall score is below 0.7.
    """
    report = context["validation_results"]
    advice = []

    has_critical = any(
        result.severity == ValidationSeverity.CRITICAL
        for level_results in report["validation_levels"].values()
        for result in level_results
    )
    if has_critical:
        advice.append("🚨 Address critical issues before proceeding with trip booking")

    for component_score in report.get("quality_scores", {}).values():
        if component_score["score"] < 0.6:
            advice.extend(component_score["recommendations"])

    for conflict in report.get("conflicts", []):
        if conflict["impact_assessment"]["severity"] in ("high", "medium"):
            advice.append(f"⚠️ Resolve {conflict['conflict_type']} conflict: {conflict['description']}")

    if report.get("overall_score", 0.0) < 0.7:
        advice.append("💡 Consider adjusting trip components to improve overall quality")

    return advice
def _create_validation_summary(self, validation_results: Dict[str, Any]) -> Dict[str, Any]:
    """Summarise validation output into headline counts and a letter grade.

    Args:
        validation_results: Aggregated validation output. The optional keys
            read here are "conflicts", "overall_score", "recommendations" and
            "validation_levels" (mapping of level name -> list of results).
            Each result may be an object with a ``severity`` attribute or a
            dict with a "severity" key.

    Returns:
        A dict with ``total_validations``, ``critical_issues``, ``warnings``,
        ``passed``, ``conflicts_found``, ``quality_grade`` (derived from
        "overall_score", 0.0 when absent) and ``recommendation_count``.
        INFO-severity results count toward ``total_validations`` only.
    """
    summary = {
        "total_validations": 0,
        "critical_issues": 0,
        "warnings": 0,
        "passed": 0,
        "conflicts_found": len(validation_results.get("conflicts", [])),
        "quality_grade": self._calculate_quality_grade(validation_results.get("overall_score", 0.0)),
        "recommendation_count": len(validation_results.get("recommendations", []))
    }

    for level_results in validation_results.get("validation_levels", {}).values():
        for result in level_results:
            # _flatten_validation_results treats these entries as dicts, while
            # attribute access assumes objects; read the severity either way
            # instead of raising AttributeError on the dict form.
            if isinstance(result, dict):
                severity = result.get("severity")
            else:
                severity = getattr(result, "severity", None)

            summary["total_validations"] += 1
            # ValidationSeverity is a str Enum, so plain strings compare equal too.
            if severity == ValidationSeverity.CRITICAL:
                summary["critical_issues"] += 1
            elif severity == ValidationSeverity.WARNING:
                summary["warnings"] += 1
            elif severity == ValidationSeverity.PASS:
                summary["passed"] += 1

    return summary
def _calculate_quality_grade(self, score: float) -> str:
    """Map a 0-1 quality score to a letter grade.

    Thresholds are inclusive lower bounds: 0.9 -> "A+", 0.8 -> "A",
    0.7 -> "B", 0.6 -> "C", 0.5 -> "D". Anything lower, including NaN
    (which fails every comparison), grades "F".
    """
    grade_floors = (
        (0.9, "A+"),
        (0.8, "A"),
        (0.7, "B"),
        (0.6, "C"),
        (0.5, "D"),
    )
    for floor, grade in grade_floors:
        if score >= floor:
            return grade
    return "F"
def _flatten_validation_results(self, validation_results: Dict[str, Any]) -> List[ValidationResult]:
    """Collect every per-level validation result into one flat list.

    Entries under "validation_levels" (mapping of level name -> list) may
    already be ``ValidationResult`` instances, which are kept as-is, or dicts
    in serialized form, which are rebuilt into ``ValidationResult`` objects.
    Dict entries must carry "level", "severity", "rule_name", "message",
    "details" and "suggestions"; "agent_id" and "affected_elements" are
    optional.

    Returns:
        The results in level order, then in each level's list order; empty
        when "validation_levels" is missing.

    Raises:
        KeyError: if a dict entry lacks one of the required keys.
        ValueError: if a dict entry's "level" or "severity" is not a valid
            enum value.
    """
    flattened: List[ValidationResult] = []
    for level_results in validation_results.get("validation_levels", {}).values():
        for result_data in level_results:
            if isinstance(result_data, ValidationResult):
                # Other methods in this class read these entries via attribute
                # access; subscripting such an object would raise TypeError.
                flattened.append(result_data)
                continue
            flattened.append(
                ValidationResult(
                    level=ValidationLevel(result_data["level"]),
                    severity=ValidationSeverity(result_data["severity"]),
                    rule_name=result_data["rule_name"],
                    message=result_data["message"],
                    details=result_data["details"],
                    suggestions=result_data["suggestions"],
                    agent_id=result_data.get("agent_id"),
                    affected_elements=result_data.get("affected_elements", [])
                )
            )
    return flattened
def _get_flight_recommendations(self, score: float, factors: Dict[str, float]) -> List[str]:
    """Suggest flight improvements when the flight score is weak.

    Nothing is suggested unless ``score`` is below 0.6. In that case every
    factor scoring under 0.5 contributes its advice, in the order price,
    duration, stops. A factor missing from ``factors`` counts as 0.
    """
    if not score < 0.6:
        return []
    # NOTE(review): the price advice assumes a low price factor means the
    # flight is too cheap to be good; confirm against how "price" is scored.
    advice_by_factor = (
        ("price", "Consider more expensive flights for better quality"),
        ("duration", "Look for shorter flight options"),
        ("stops", "Consider direct flights to reduce travel time"),
    )
    return [advice for factor, advice in advice_by_factor if factors.get(factor, 0) < 0.5]
def _get_hotel_recommendations(self, score: float, factors: Dict[str, float]) -> List[str]:
    """Suggest hotel improvements when the hotel score is weak.

    Nothing is suggested unless ``score`` is below 0.6. In that case every
    factor scoring under 0.5 contributes its advice, in the order rating,
    location, price. A factor missing from ``factors`` counts as 0.
    """
    if not score < 0.6:
        return []
    advice_by_factor = (
        ("rating", "Consider hotels with higher ratings"),
        ("location", "Look for hotels closer to city center"),
        ("price", "Consider adjusting budget for better hotel"),
    )
    return [advice for factor, advice in advice_by_factor if factors.get(factor, 0) < 0.5]
def _get_poi_recommendations(self, score: float, factors: Dict[str, float]) -> List[str]:
    """Suggest activity (POI) improvements when the POI score is weak.

    Nothing is suggested unless ``score`` is below 0.6. In that case every
    factor scoring under 0.5 contributes its advice, in the order diversity,
    rating, value. A factor missing from ``factors`` counts as 0.
    """
    if not score < 0.6:
        return []
    advice_by_factor = (
        ("diversity", "Add more diverse activity types"),
        ("rating", "Select higher-rated activities"),
        ("value", "Include more free or low-cost activities"),
    )
    return [advice for factor, advice in advice_by_factor if factors.get(factor, 0) < 0.5]
def _get_overall_recommendations(self, score: float, factors: Dict[str, float]) -> List[str]:
    """Suggest whole-trip improvements when the overall score is weak.

    Nothing is suggested unless ``score`` is below 0.6. In that case every
    factor scoring under 0.5 contributes its advice, in the order balance,
    budget, preferences. A factor missing from ``factors`` counts as 0.
    """
    if not score < 0.6:
        return []
    advice_by_factor = (
        ("balance", "Ensure all trip components are included"),
        ("budget", "Optimize budget allocation across components"),
        ("preferences", "Better align trip with user preferences"),
    )
    return [advice for factor, advice in advice_by_factor if factors.get(factor, 0) < 0.5]