# wanderlust.ai / src/wanderlust_ai/core/trip_validation.py
# Uploaded by BlakeL ("Upload 115 files"), commit 3f9f85b (verified)
"""
Comprehensive Trip Validation System
This module implements sophisticated validation patterns for ensuring AI systems
produce feasible, high-quality travel recommendations that meet user expectations
and business requirements.
Key Features:
1. Multi-level validation (individual, cross-agent, overall)
2. Business rule validation (travel times, sequences, budget)
3. Quality scoring for complete itineraries
4. User satisfaction factor validation
5. Smart conflict detection and resolution
6. Detailed validation reporting
"""
import asyncio
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Dict, List, Optional, Any, Tuple, Set, Union
from decimal import Decimal
import math
from pydantic import BaseModel, Field, validator
from ..models.flight_models import FlightOption
from ..models.hotel_models import HotelOption
from ..models.poi_models import POI
# Validation rules are imported lazily inside TripValidator._initialize_validation_rules to avoid circular imports.
class ValidationLevel(str, Enum):
    """Depth at which a validation check operates.

    Members are declared in the order TripValidator.validate_trip runs them.
    """

    # Checks on a single agent's output in isolation.
    INDIVIDUAL = "individual"
    # Compatibility between the outputs of different agents.
    CROSS_AGENT = "cross_agent"
    # Quality of the complete trip.
    OVERALL = "overall"
    # Business-logic constraints.
    BUSINESS_RULES = "business_rules"
    # User-experience checks.
    USER_SATISFACTION = "user_satisfaction"
class ValidationSeverity(str, Enum):
    """How serious a validation finding is, from blocking to no issue at all."""

    # The trip cannot proceed.
    CRITICAL = "critical"
    # A significant issue that should be addressed.
    WARNING = "warning"
    # A minor issue or a suggestion.
    INFO = "info"
    # No issue was found.
    PASS = "pass"
class ConflictType(str, Enum):
    """Category of a conflict found between trip components."""

    # Timing clashes.
    TEMPORAL = "temporal"
    # Location / distance problems.
    SPATIAL = "spatial"
    # Cost exceeding the budget.
    BUDGET = "budget"
    # Selections that contradict user preferences.
    PREFERENCE = "preference"
    # Practical or logistical problems.
    LOGISTICAL = "logistical"
    # Quality-standard problems.
    QUALITY = "quality"
@dataclass
class ValidationResult:
    """Outcome of a single validation check; serialize it with ``to_dict``."""

    # Fresh UUID4 string generated for each result.
    validation_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    level: ValidationLevel = ValidationLevel.INDIVIDUAL
    severity: ValidationSeverity = ValidationSeverity.PASS
    rule_name: str = ""
    message: str = ""
    details: Dict[str, Any] = field(default_factory=dict)
    suggestions: List[str] = field(default_factory=list)
    # Creation time from datetime.now (naive, local time).
    timestamp: datetime = field(default_factory=datetime.now)
    agent_id: Optional[str] = None
    affected_elements: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain dict of all fields.

        Enum fields are replaced by their ``.value`` and the timestamp by its
        ISO-8601 string; keys appear in field-declaration order.
        """
        ordered_names = (
            "validation_id", "level", "severity", "rule_name", "message",
            "details", "suggestions", "timestamp", "agent_id", "affected_elements",
        )
        payload = {name: getattr(self, name) for name in ordered_names}
        # Overwriting existing keys keeps their original position in the dict.
        payload["level"] = self.level.value
        payload["severity"] = self.severity.value
        payload["timestamp"] = self.timestamp.isoformat()
        return payload
@dataclass
class ConflictResolution:
    """A detected conflict together with candidate ways to resolve it."""

    # Fresh UUID4 string generated for each conflict.
    conflict_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    conflict_type: ConflictType = ConflictType.TEMPORAL
    description: str = ""
    affected_agents: List[str] = field(default_factory=list)
    resolution_options: List[Dict[str, Any]] = field(default_factory=list)
    recommended_resolution: Optional[Dict[str, Any]] = None
    impact_assessment: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain dict of all fields, with ``conflict_type`` as its value string."""
        ordered_names = (
            "conflict_id", "conflict_type", "description", "affected_agents",
            "resolution_options", "recommended_resolution", "impact_assessment",
        )
        payload = {name: getattr(self, name) for name in ordered_names}
        # Replace the enum in place so key order is unchanged.
        payload["conflict_type"] = self.conflict_type.value
        return payload
@dataclass
class QualityScore:
    """Quality score for one trip component or for the trip as a whole."""

    # Which part is scored, e.g. "flight", "hotel", "pois" or "overall".
    component: str = ""
    # Normalized score, intended to lie in 0.0-1.0.
    score: float = 0.0
    max_score: float = 1.0
    # Contribution of each scoring factor.
    factors: Dict[str, float] = field(default_factory=dict)
    explanation: str = ""
    recommendations: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return the score as a plain dict, keys in field-declaration order."""
        return dict(
            component=self.component,
            score=self.score,
            max_score=self.max_score,
            factors=self.factors,
            explanation=self.explanation,
            recommendations=self.recommendations,
        )
class ValidationRule(ABC):
    """Interface every pluggable validation rule implements.

    Concrete rules provide the async ``validate`` check and a human-readable
    ``get_rule_description``.
    """

    def __init__(self, rule_name: str, severity: ValidationSeverity = ValidationSeverity.WARNING):
        # Name used to label results (TripValidator also uses it when a rule raises).
        self.rule_name, self.severity = rule_name, severity

    @abstractmethod
    async def validate(self, context: Dict[str, Any]) -> List[ValidationResult]:
        """Check *context* and return the resulting list of ValidationResult objects."""

    @abstractmethod
    def get_rule_description(self) -> str:
        """Describe what this rule checks, in plain language."""
class TripValidator:
    """
    Comprehensive trip validator that ensures high-quality, feasible travel recommendations.

    ``validate_trip`` runs every registered ValidationRule level by level,
    detects conflicts between components (time, location, budget,
    preferences), scores the flight / hotel / POI selections and the trip as a
    whole, and aggregates everything into a single result dictionary.
    """

    # Relative weight of each scoring factor per component. QualityScore.factors
    # reports ``sub_score * weight``; dividing a factor by its weight recovers the
    # 0-1 sub-score (see _factor_ratio), which drives the recommendations.
    _FLIGHT_WEIGHTS: Dict[str, float] = {"price": 0.3, "duration": 0.3, "stops": 0.2, "airline": 0.2}
    _HOTEL_WEIGHTS: Dict[str, float] = {"rating": 0.4, "location": 0.3, "price": 0.3}
    _POI_WEIGHTS: Dict[str, float] = {"diversity": 0.3, "rating": 0.4, "value": 0.3}
    _OVERALL_WEIGHTS: Dict[str, float] = {"balance": 0.3, "budget": 0.3, "preferences": 0.4}

    # Airline tiers for the flight "airline" factor (equality match on the name).
    _PREMIUM_AIRLINES = ("Singapore Airlines", "Emirates", "Qatar Airways", "Cathay Pacific")
    _GOOD_AIRLINES = ("United", "Delta", "American", "Lufthansa", "British Airways")

    # Nights used to cost the hotel when trip_context has no "nights" value.
    DEFAULT_HOTEL_NIGHTS = 3

    def __init__(self):
        # Registered rules, keyed by the level at which they run.
        self.validation_rules: Dict[ValidationLevel, List[ValidationRule]] = {
            ValidationLevel.INDIVIDUAL: [],
            ValidationLevel.CROSS_AGENT: [],
            ValidationLevel.OVERALL: [],
            ValidationLevel.BUSINESS_RULES: [],
            ValidationLevel.USER_SATISFACTION: []
        }
        # Results from every validate_trip call that completed without error.
        self.validation_history: List[ValidationResult] = []
        # NOTE(review): not populated anywhere in this class; kept for external callers.
        self.conflict_resolutions: List[ConflictResolution] = []
        self._initialize_validation_rules()

    def _initialize_validation_rules(self):
        """Register the default rule set for every validation level."""
        # Imported here rather than at module level to avoid circular imports.
        from .validation_rules import (
            FlightValidationRule, HotelValidationRule, POIValidationRule,
            FlightHotelCompatibilityRule, HotelPOICompatibilityRule, FlightPOICompatibilityRule,
            TripCoherenceRule, BudgetAdherenceRule, QualityThresholdRule,
            TravelTimeRule, CheckInOutSequenceRule, RealisticScheduleRule, AirportTransferRule,
            PreferenceAlignmentRule, TravelStyleRule, EnergyLevelRule, AccessibilityRule
        )
        self.validation_rules[ValidationLevel.INDIVIDUAL].extend([
            FlightValidationRule(),
            HotelValidationRule(),
            POIValidationRule()
        ])
        self.validation_rules[ValidationLevel.CROSS_AGENT].extend([
            FlightHotelCompatibilityRule(),
            HotelPOICompatibilityRule(),
            FlightPOICompatibilityRule()
        ])
        self.validation_rules[ValidationLevel.OVERALL].extend([
            TripCoherenceRule(),
            BudgetAdherenceRule(),
            QualityThresholdRule()
        ])
        self.validation_rules[ValidationLevel.BUSINESS_RULES].extend([
            TravelTimeRule(),
            CheckInOutSequenceRule(),
            RealisticScheduleRule(),
            AirportTransferRule()
        ])
        self.validation_rules[ValidationLevel.USER_SATISFACTION].extend([
            PreferenceAlignmentRule(),
            TravelStyleRule(),
            EnergyLevelRule(),
            AccessibilityRule()
        ])

    async def validate_trip(self, trip_context: Dict[str, Any],
                            agent_outputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Perform comprehensive trip validation across all levels.

        Args:
            trip_context: Trip-level data. Keys read here include "trip_id",
                "total_budget", "preferences" and optionally "nights".
            agent_outputs: Agent results; "flight", "hotel" and "pois" (a list)
                are the keys inspected by the built-in checks.

        Returns:
            A dict with per-level ValidationResult lists, serialized conflicts and
            quality scores, the overall score, recommendations and a summary.
            Any exception is captured in "validation_error" and marks the trip
            invalid instead of propagating.
        """
        validation_results = {
            "trip_id": trip_context.get("trip_id", "unknown"),
            "validation_timestamp": datetime.now().isoformat(),
            "overall_valid": True,
            "overall_score": 0.0,
            "validation_levels": {},
            "conflicts": [],
            "quality_scores": {},
            "recommendations": [],
            "validation_summary": {}
        }
        try:
            validation_context = {
                "trip_context": trip_context,
                "agent_outputs": agent_outputs,
                "validation_results": validation_results
            }
            for level in ValidationLevel:
                level_results = await self._validate_level(level, validation_context)
                validation_results["validation_levels"][level.value] = level_results
                if any(r.severity == ValidationSeverity.CRITICAL for r in level_results):
                    validation_results["overall_valid"] = False

            conflicts = await self._detect_conflicts(validation_context)
            validation_results["conflicts"] = [c.to_dict() for c in conflicts]

            quality_scores = await self._calculate_quality_scores(validation_context)
            validation_results["quality_scores"] = {k: v.to_dict() for k, v in quality_scores.items()}
            validation_results["overall_score"] = self._calculate_overall_score(quality_scores)

            # Reads the conflicts / scores stored above through the shared dict.
            validation_results["recommendations"] = await self._generate_recommendations(validation_context)
            validation_results["validation_summary"] = self._create_validation_summary(validation_results)

            self.validation_history.extend(self._flatten_validation_results(validation_results))
        except Exception as e:
            validation_results["validation_error"] = str(e)
            validation_results["overall_valid"] = False

        # The summary key is pre-seeded with {}, so test for emptiness, not presence.
        if not validation_results["validation_summary"]:
            validation_results["validation_summary"] = self._create_validation_summary(validation_results)
        return validation_results

    async def _validate_level(self, level: ValidationLevel, context: Dict[str, Any]) -> List[ValidationResult]:
        """Run every rule registered for *level*; a rule that raises yields a CRITICAL result."""
        results: List[ValidationResult] = []
        for rule in self.validation_rules[level]:
            try:
                results.extend(await rule.validate(context))
            except Exception as e:
                results.append(ValidationResult(
                    level=level,
                    severity=ValidationSeverity.CRITICAL,
                    rule_name=rule.rule_name,
                    message=f"Validation rule failed: {str(e)}",
                    details={"error": str(e)}
                ))
        return results

    async def _detect_conflicts(self, context: Dict[str, Any]) -> List[ConflictResolution]:
        """Collect temporal, spatial, budget and preference conflicts, in that order."""
        conflicts: List[ConflictResolution] = []
        for detector in (
            self._detect_temporal_conflicts,
            self._detect_spatial_conflicts,
            self._detect_budget_conflicts,
            self._detect_preference_conflicts,
        ):
            conflicts.extend(await detector(context))
        return conflicts

    async def _detect_temporal_conflicts(self, context: Dict[str, Any]) -> List[ConflictResolution]:
        """Flag a hotel check-in less than 2 hours after the flight's arrival time."""
        conflicts: List[ConflictResolution] = []
        agent_outputs = context["agent_outputs"]
        if "flight" not in agent_outputs or "hotel" not in agent_outputs:
            return conflicts
        flight = agent_outputs["flight"]
        hotel = agent_outputs["hotel"]
        if not (hasattr(flight, 'arrival_time') and hasattr(hotel, 'check_in_time')):
            return conflicts

        arrival_time = flight.arrival_time
        check_in_time = hotel.check_in_time
        # NOTE(review): assumes both values support subtraction yielding a timedelta
        # (e.g. datetimes); other types raise and fail the whole validation.
        time_diff = (check_in_time - arrival_time).total_seconds() / 3600  # hours
        if time_diff < 2:
            conflicts.append(ConflictResolution(
                conflict_type=ConflictType.TEMPORAL,
                description=f"Insufficient time between flight arrival ({arrival_time}) and hotel check-in ({check_in_time})",
                affected_agents=["flight_agent", "hotel_agent"],
                resolution_options=[
                    {
                        "option": "delay_hotel_checkin",
                        "description": "Change hotel check-in to later time",
                        "impact": "May incur additional fees"
                    },
                    {
                        "option": "change_flight",
                        "description": "Select earlier flight",
                        "impact": "May increase cost or reduce options"
                    }
                ],
                impact_assessment={
                    "severity": "medium",
                    "user_impact": "May cause stress and inconvenience",
                    "cost_impact": "Potential additional fees"
                }
            ))
        return conflicts

    async def _detect_spatial_conflicts(self, context: Dict[str, Any]) -> List[ConflictResolution]:
        """Flag POIs whose distance-from-centre differs from the hotel's by more than 20 km."""
        conflicts: List[ConflictResolution] = []
        agent_outputs = context["agent_outputs"]
        if "hotel" not in agent_outputs or "pois" not in agent_outputs:
            return conflicts
        hotel = agent_outputs["hotel"]
        if not hasattr(hotel, 'distance_city_center_km'):
            return conflicts

        hotel_distance = hotel.distance_city_center_km
        for poi in agent_outputs["pois"]:
            if not hasattr(poi, 'distance_city_center_km'):
                continue
            # Only distances to the city centre are known, so this difference is a
            # lower bound on the real hotel-POI separation, not the separation itself.
            total_distance = abs(hotel_distance - poi.distance_city_center_km)
            if total_distance > 20:
                conflicts.append(ConflictResolution(
                    conflict_type=ConflictType.SPATIAL,
                    description=f"Hotel and {poi.name} are too far apart ({total_distance:.1f}km)",
                    affected_agents=["hotel_agent", "poi_agent"],
                    resolution_options=[
                        {
                            "option": "change_hotel",
                            "description": "Select hotel closer to activities",
                            "impact": "May change price or quality"
                        },
                        {
                            "option": "change_poi",
                            "description": "Select activities closer to hotel",
                            "impact": "May reduce activity options"
                        }
                    ],
                    impact_assessment={
                        "severity": "low",
                        "user_impact": "Increased travel time and costs",
                        "cost_impact": "Additional transportation costs"
                    }
                ))
        return conflicts

    async def _detect_budget_conflicts(self, context: Dict[str, Any]) -> List[ConflictResolution]:
        """Flag an estimated total cost above trip_context["total_budget"].

        Returns no conflict when no positive budget is given, because the
        overage percentage would otherwise divide by zero.
        """
        trip_context = context["trip_context"]
        agent_outputs = context["agent_outputs"]
        total_budget = float(trip_context.get("total_budget") or 0)
        if total_budget <= 0:
            return []

        total_cost = self._estimate_total_cost(trip_context, agent_outputs)
        if total_cost <= total_budget:
            return []

        overage = total_cost - total_budget
        overage_percentage = (overage / total_budget) * 100
        return [ConflictResolution(
            conflict_type=ConflictType.BUDGET,
            description=f"Trip exceeds budget by ${overage:.2f} ({overage_percentage:.1f}%)",
            affected_agents=list(agent_outputs.keys()),
            resolution_options=[
                {
                    "option": "reduce_flight_cost",
                    "description": "Select cheaper flight options",
                    "impact": "May increase travel time or reduce comfort"
                },
                {
                    "option": "reduce_hotel_cost",
                    "description": "Select cheaper hotel",
                    "impact": "May reduce amenities or location quality"
                },
                {
                    "option": "reduce_activities",
                    "description": "Remove or replace expensive activities",
                    "impact": "May reduce trip experience quality"
                }
            ],
            impact_assessment={
                "severity": "high" if overage_percentage > 20 else "medium",
                "user_impact": "Financial stress or trip cancellation",
                "cost_impact": f"${overage:.2f} over budget"
            }
        )]

    async def _detect_preference_conflicts(self, context: Dict[str, Any]) -> List[ConflictResolution]:
        """Flag a flight with stops when the user prefers direct flights."""
        conflicts: List[ConflictResolution] = []
        trip_context = context["trip_context"]
        agent_outputs = context["agent_outputs"]
        preferences = trip_context.get("preferences") or {}

        if preferences.get("prefer_direct_flights", False) and "flight" in agent_outputs:
            flight = agent_outputs["flight"]
            if hasattr(flight, 'stops') and flight.stops > 0:
                conflicts.append(ConflictResolution(
                    conflict_type=ConflictType.PREFERENCE,
                    description="User prefers direct flights but selected flight has stops",
                    affected_agents=["flight_agent"],
                    resolution_options=[
                        {
                            "option": "find_direct_flight",
                            "description": "Search for direct flight alternatives",
                            "impact": "May increase cost or reduce availability"
                        },
                        {
                            "option": "accept_stops",
                            "description": "Accept stops for cost savings",
                            "impact": "User preference not fully met"
                        }
                    ],
                    impact_assessment={
                        "severity": "low",
                        "user_impact": "Preference not fully satisfied",
                        "cost_impact": "May increase flight cost"
                    }
                ))
        return conflicts

    def _estimate_total_cost(self, trip_context: Dict[str, Any], agent_outputs: Dict[str, Any]) -> float:
        """Estimate the trip's total cost from the agent outputs.

        Adds the flight price, the hotel nightly rate times the number of nights
        (trip_context["nights"] when given, else DEFAULT_HOTEL_NIGHTS) and the
        adult price of every priced POI in a "pois" list.
        """
        nights = trip_context.get("nights")
        if nights is None:
            nights = self.DEFAULT_HOTEL_NIGHTS
        total_cost = 0.0
        for agent_id, output in agent_outputs.items():
            if agent_id == "flight" and hasattr(output, 'price'):
                total_cost += float(output.price)
            elif agent_id == "hotel" and hasattr(output, 'price_per_night'):
                total_cost += float(output.price_per_night) * float(nights)
            elif agent_id == "pois" and isinstance(output, list):
                for poi in output:
                    if hasattr(poi, 'adult_price') and poi.adult_price:
                        total_cost += float(poi.adult_price)
        return total_cost

    async def _calculate_quality_scores(self, context: Dict[str, Any]) -> Dict[str, QualityScore]:
        """Score each present component, plus an "overall" score that is always included."""
        quality_scores: Dict[str, QualityScore] = {}
        agent_outputs = context["agent_outputs"]
        if "flight" in agent_outputs:
            quality_scores["flight"] = await self._calculate_flight_quality(agent_outputs["flight"])
        if "hotel" in agent_outputs:
            quality_scores["hotel"] = await self._calculate_hotel_quality(agent_outputs["hotel"])
        if "pois" in agent_outputs:
            quality_scores["pois"] = await self._calculate_poi_quality(agent_outputs["pois"])
        quality_scores["overall"] = await self._calculate_overall_quality(context)
        return quality_scores

    @staticmethod
    def _clamp_unit(value: float) -> float:
        """Clamp *value* into the closed range [0.0, 1.0]."""
        return max(0.0, min(1.0, value))

    @staticmethod
    def _factor_ratio(factors: Dict[str, float], weights: Dict[str, float], name: str) -> float:
        """Return factor *name* divided by its weight, i.e. its 0-1 sub-score (0.0 if absent)."""
        weight = weights.get(name, 0.0)
        if weight <= 0:
            return 0.0
        return factors.get(name, 0.0) / weight

    async def _calculate_flight_quality(self, flight: FlightOption) -> QualityScore:
        """Score a flight on a 0-1 scale.

        Sub-scores, each clamped to 0-1 and weighted by _FLIGHT_WEIGHTS:
        price (1.0 at or below $500, reaching 0.0 at $1500), duration (1.0 at
        or below 6h, reaching 0.0 at 18h), stops (minus 0.5 per stop) and
        airline tier. A missing attribute adds nothing, but its weight still
        counts toward the maximum.
        """
        weights = self._FLIGHT_WEIGHTS
        factors: Dict[str, float] = {}

        if hasattr(flight, 'price'):
            normalized_price = float(flight.price) / 1000
            factors["price"] = self._clamp_unit(1.0 - (normalized_price - 0.5)) * weights["price"]

        if hasattr(flight, 'duration_hours'):
            normalized_duration = flight.duration_hours / 12
            factors["duration"] = self._clamp_unit(1.0 - (normalized_duration - 0.5)) * weights["duration"]

        if hasattr(flight, 'stops'):
            # 0 stops -> 1.0, 1 stop -> 0.5, 2+ stops -> 0.0
            factors["stops"] = self._clamp_unit(1.0 - (flight.stops * 0.5)) * weights["stops"]

        if hasattr(flight, 'airline'):
            if flight.airline in self._PREMIUM_AIRLINES:
                airline_score = 1.0
            elif flight.airline in self._GOOD_AIRLINES:
                airline_score = 0.8
            else:
                airline_score = 0.6
            factors["airline"] = airline_score * weights["airline"]

        max_score = sum(weights.values())
        final_score = sum(factors.values()) / max_score if max_score > 0 else 0.0

        price_str = f"${flight.price}" if hasattr(flight, 'price') else 'N/A'
        duration_str = f"{flight.duration_hours}h" if hasattr(flight, 'duration_hours') else 'N/A'
        stops_str = str(flight.stops) if hasattr(flight, 'stops') else 'N/A'
        airline_str = flight.airline if hasattr(flight, 'airline') else 'N/A'
        explanation = f"Flight quality score based on price ({price_str}), duration ({duration_str}), stops ({stops_str}), and airline ({airline_str})"
        return QualityScore(
            component="flight",
            score=final_score,
            max_score=1.0,
            factors=factors,
            explanation=explanation,
            recommendations=self._get_flight_recommendations(final_score, factors)
        )

    async def _calculate_hotel_quality(self, hotel: HotelOption) -> QualityScore:
        """Score a hotel on a 0-1 scale.

        Sub-scores, each clamped to 0-1 and weighted by _HOTEL_WEIGHTS: rating
        out of 5, location (1.0 at the centre, falling linearly to 0.0 at 10 km)
        and price (1.0 at or below $150/night, reaching 0.0 at $450/night).
        """
        weights = self._HOTEL_WEIGHTS
        factors: Dict[str, float] = {}

        if hasattr(hotel, 'rating'):
            factors["rating"] = self._clamp_unit(hotel.rating / 5.0) * weights["rating"]

        if hasattr(hotel, 'distance_city_center_km'):
            location_score = self._clamp_unit(1.0 - (hotel.distance_city_center_km / 10.0))
            factors["location"] = location_score * weights["location"]

        if hasattr(hotel, 'price_per_night'):
            normalized_price = float(hotel.price_per_night) / 300
            factors["price"] = self._clamp_unit(1.0 - (normalized_price - 0.5)) * weights["price"]

        max_score = sum(weights.values())
        final_score = sum(factors.values()) / max_score if max_score > 0 else 0.0

        rating_str = f"{hotel.rating}/5" if hasattr(hotel, 'rating') else 'N/A'
        location_str = f"{hotel.distance_city_center_km}km from center" if hasattr(hotel, 'distance_city_center_km') else 'N/A'
        price_str = f"${hotel.price_per_night}/night" if hasattr(hotel, 'price_per_night') else 'N/A'
        explanation = f"Hotel quality based on rating ({rating_str}), location ({location_str}), and price ({price_str})"
        return QualityScore(
            component="hotel",
            score=final_score,
            max_score=1.0,
            factors=factors,
            explanation=explanation,
            recommendations=self._get_hotel_recommendations(final_score, factors)
        )

    async def _calculate_poi_quality(self, pois: List[POI]) -> QualityScore:
        """Score the selected activities on a 0-1 scale.

        Sub-scores weighted by _POI_WEIGHTS: category diversity (full marks at
        3+ distinct categories), average rating out of 5 over POIs with a
        non-None rating, and the share of free activities. An empty selection
        scores 0.0.
        """
        if not pois:
            return QualityScore(
                component="pois",
                score=0.0,
                explanation="No activities selected",
                recommendations=["Add activities to improve trip experience"]
            )
        weights = self._POI_WEIGHTS
        factors: Dict[str, float] = {}

        categories: Set[str] = set()
        for poi in pois:
            if hasattr(poi, 'category'):
                category = poi.category
                categories.add(category.value if hasattr(category, 'value') else str(category))
        factors["diversity"] = min(len(categories) / 3.0, 1.0) * weights["diversity"]

        # Skip missing / None ratings instead of failing on arithmetic with None.
        ratings = [poi.rating for poi in pois if getattr(poi, 'rating', None) is not None]
        avg_rating = sum(ratings) / len(ratings) if ratings else None
        if avg_rating is not None:
            factors["rating"] = self._clamp_unit(avg_rating / 5.0) * weights["rating"]

        free_activities = sum(
            1 for poi in pois
            if hasattr(poi, 'adult_price') and (poi.adult_price == 0 or poi.adult_price is None)
        )
        factors["value"] = min(free_activities / len(pois), 1.0) * weights["value"]

        max_score = sum(weights.values())
        final_score = sum(factors.values()) / max_score if max_score > 0 else 0.0

        avg_rating_str = f"{avg_rating:.1f}/5" if avg_rating is not None else 'N/A'
        free_percentage = f"{(free_activities/len(pois)*100):.0f}%"
        explanation = f"Activity quality based on diversity ({len(categories)} categories), average rating ({avg_rating_str}), and value ({free_percentage} free activities)"
        return QualityScore(
            component="pois",
            score=final_score,
            max_score=1.0,
            factors=factors,
            explanation=explanation,
            recommendations=self._get_poi_recommendations(final_score, factors)
        )

    async def _calculate_overall_quality(self, context: Dict[str, Any]) -> QualityScore:
        """Score the trip as a whole on a 0-1 scale.

        Sub-scores weighted by _OVERALL_WEIGHTS: fraction of flight/hotel/pois
        present, budget utilization (full marks between 70% and 95% of
        "total_budget", which defaults to 1 when absent) and preference alignment.
        """
        trip_context = context["trip_context"]
        agent_outputs = context["agent_outputs"]
        weights = self._OVERALL_WEIGHTS
        factors: Dict[str, float] = {}

        components = ["flight", "hotel", "pois"]
        present_components = sum(1 for comp in components if comp in agent_outputs)
        factors["balance"] = (present_components / len(components)) * weights["balance"]

        total_budget = float(trip_context.get("total_budget", 1))
        total_cost = self._estimate_total_cost(trip_context, agent_outputs)
        # Capped at 100%, so any overspend lands in the steepest penalty band.
        budget_utilization = min(total_cost / total_budget, 1.0) if total_budget > 0 else 0
        if 0.7 <= budget_utilization <= 0.95:
            budget_score = 1.0
        elif budget_utilization < 0.7:
            budget_score = budget_utilization / 0.7  # under-utilization penalty
        else:
            budget_score = max(0, 1.0 - (budget_utilization - 0.95) / 0.05)  # near/over budget
        factors["budget"] = budget_score * weights["budget"]

        preferences = trip_context.get("preferences") or {}
        alignment_score = self._calculate_preference_alignment(preferences, agent_outputs)
        factors["preferences"] = alignment_score * weights["preferences"]

        max_score = sum(weights.values())
        final_score = sum(factors.values()) / max_score if max_score > 0 else 0.0

        explanation = f"Overall trip quality based on component balance ({present_components}/{len(components)}), "
        explanation += f"budget utilization ({(budget_utilization*100):.0f}%), "
        explanation += f"and preference alignment ({(alignment_score*100):.0f}%)"
        return QualityScore(
            component="overall",
            score=final_score,
            max_score=1.0,
            factors=factors,
            explanation=explanation,
            recommendations=self._get_overall_recommendations(final_score, factors)
        )

    def _calculate_preference_alignment(self, preferences: Dict[str, Any], agent_outputs: Dict[str, Any]) -> float:
        """Return the fraction of applicable preference checks satisfied (0.5 if none apply)."""
        alignment_score = 0.0
        total_checks = 0

        if preferences.get("prefer_direct_flights", False) and "flight" in agent_outputs:
            flight = agent_outputs["flight"]
            if hasattr(flight, 'stops') and flight.stops == 0:
                alignment_score += 1.0
            total_checks += 1

        if preferences.get("prefer_central_hotels", False) and "hotel" in agent_outputs:
            hotel = agent_outputs["hotel"]
            if hasattr(hotel, 'distance_city_center_km') and hotel.distance_city_center_km <= 5:
                alignment_score += 1.0
            total_checks += 1

        if preferences.get("prefer_morning_activities", False) and "pois" in agent_outputs:
            # Placeholder: opening hours are not checked yet, so award half credit.
            alignment_score += 0.5
            total_checks += 1

        return alignment_score / total_checks if total_checks > 0 else 0.5

    def _calculate_overall_score(self, quality_scores: Dict[str, QualityScore]) -> float:
        """Weighted mean of the component scores that are present (0.0 if none)."""
        if not quality_scores:
            return 0.0
        weights = {
            "flight": 0.3,
            "hotel": 0.3,
            "pois": 0.2,
            "overall": 0.2
        }
        weighted_score = 0.0
        total_weight = 0.0
        for component, score in quality_scores.items():
            if component in weights:
                weighted_score += score.score * weights[component]
                total_weight += weights[component]
        return weighted_score / total_weight if total_weight > 0 else 0.0

    async def _generate_recommendations(self, context: Dict[str, Any]) -> List[str]:
        """Build recommendations from critical issues, low scores, conflicts and the overall score.

        Reads the serialized conflicts and quality scores already stored in
        context["validation_results"].
        """
        recommendations: List[str] = []
        validation_results = context["validation_results"]

        has_critical = any(
            r.severity == ValidationSeverity.CRITICAL
            for level_results in validation_results["validation_levels"].values()
            for r in level_results
        )
        if has_critical:
            recommendations.append("🚨 Address critical issues before proceeding with trip booking")

        for score in validation_results.get("quality_scores", {}).values():
            if score["score"] < 0.6:  # low quality
                recommendations.extend(score["recommendations"])

        for conflict in validation_results.get("conflicts", []):
            if conflict.get("impact_assessment", {}).get("severity") in ["high", "medium"]:
                recommendations.append(f"⚠️ Resolve {conflict['conflict_type']} conflict: {conflict['description']}")

        if validation_results.get("overall_score", 0.0) < 0.7:
            recommendations.append("💡 Consider adjusting trip components to improve overall quality")
        return recommendations

    def _create_validation_summary(self, validation_results: Dict[str, Any]) -> Dict[str, Any]:
        """Count results by severity and summarize conflicts, grade and recommendations.

        Items without a ``severity`` attribute are counted in the total only,
        so this also works on partially populated results from the error path.
        """
        summary = {
            "total_validations": 0,
            "critical_issues": 0,
            "warnings": 0,
            "passed": 0,
            "conflicts_found": len(validation_results.get("conflicts", [])),
            "quality_grade": self._calculate_quality_grade(validation_results.get("overall_score", 0.0)),
            "recommendation_count": len(validation_results.get("recommendations", []))
        }
        for level_results in validation_results.get("validation_levels", {}).values():
            for result in level_results:
                summary["total_validations"] += 1
                severity = getattr(result, "severity", None)
                if severity == ValidationSeverity.CRITICAL:
                    summary["critical_issues"] += 1
                elif severity == ValidationSeverity.WARNING:
                    summary["warnings"] += 1
                elif severity == ValidationSeverity.PASS:
                    summary["passed"] += 1
        return summary

    def _calculate_quality_grade(self, score: float) -> str:
        """Map a 0-1 score to a letter grade (A+ at >= 0.9 down to F below 0.5)."""
        if score >= 0.9:
            return "A+"
        elif score >= 0.8:
            return "A"
        elif score >= 0.7:
            return "B"
        elif score >= 0.6:
            return "C"
        elif score >= 0.5:
            return "D"
        else:
            return "F"

    def _flatten_validation_results(self, validation_results: Dict[str, Any]) -> List[ValidationResult]:
        """Collect every per-level result into one flat list of ValidationResult objects.

        validate_trip stores ValidationResult instances, which are passed through
        as-is; dict entries (shaped like ValidationResult.to_dict()) are rebuilt.
        """
        flattened: List[ValidationResult] = []
        for level_results in validation_results.get("validation_levels", {}).values():
            for item in level_results:
                if isinstance(item, ValidationResult):
                    flattened.append(item)
                    continue
                flattened.append(ValidationResult(
                    level=ValidationLevel(item["level"]),
                    severity=ValidationSeverity(item["severity"]),
                    rule_name=item["rule_name"],
                    message=item["message"],
                    details=item["details"],
                    suggestions=item["suggestions"],
                    agent_id=item.get("agent_id"),
                    affected_elements=item.get("affected_elements", [])
                ))
        return flattened

    def _get_flight_recommendations(self, score: float, factors: Dict[str, float]) -> List[str]:
        """Suggest flight improvements when *score* < 0.6, for each sub-score below 0.5."""
        if score >= 0.6:
            return []
        weights = self._FLIGHT_WEIGHTS
        recommendations: List[str] = []
        if self._factor_ratio(factors, weights, "price") < 0.5:
            # A low price sub-score means the flight is expensive.
            recommendations.append("Look for cheaper flight options")
        if self._factor_ratio(factors, weights, "duration") < 0.5:
            recommendations.append("Look for shorter flight options")
        if self._factor_ratio(factors, weights, "stops") < 0.5:
            recommendations.append("Consider direct flights to reduce travel time")
        return recommendations

    def _get_hotel_recommendations(self, score: float, factors: Dict[str, float]) -> List[str]:
        """Suggest hotel improvements when *score* < 0.6, for each sub-score below 0.5."""
        if score >= 0.6:
            return []
        weights = self._HOTEL_WEIGHTS
        recommendations: List[str] = []
        if self._factor_ratio(factors, weights, "rating") < 0.5:
            recommendations.append("Consider hotels with higher ratings")
        if self._factor_ratio(factors, weights, "location") < 0.5:
            recommendations.append("Look for hotels closer to city center")
        if self._factor_ratio(factors, weights, "price") < 0.5:
            recommendations.append("Consider adjusting budget for better hotel")
        return recommendations

    def _get_poi_recommendations(self, score: float, factors: Dict[str, float]) -> List[str]:
        """Suggest activity improvements when *score* < 0.6, for each sub-score below 0.5."""
        if score >= 0.6:
            return []
        weights = self._POI_WEIGHTS
        recommendations: List[str] = []
        if self._factor_ratio(factors, weights, "diversity") < 0.5:
            recommendations.append("Add more diverse activity types")
        if self._factor_ratio(factors, weights, "rating") < 0.5:
            recommendations.append("Select higher-rated activities")
        if self._factor_ratio(factors, weights, "value") < 0.5:
            recommendations.append("Include more free or low-cost activities")
        return recommendations

    def _get_overall_recommendations(self, score: float, factors: Dict[str, float]) -> List[str]:
        """Suggest trip-level improvements when *score* < 0.6, for each sub-score below 0.5."""
        if score >= 0.6:
            return []
        weights = self._OVERALL_WEIGHTS
        recommendations: List[str] = []
        if self._factor_ratio(factors, weights, "balance") < 0.5:
            recommendations.append("Ensure all trip components are included")
        if self._factor_ratio(factors, weights, "budget") < 0.5:
            recommendations.append("Optimize budget allocation across components")
        if self._factor_ratio(factors, weights, "preferences") < 0.5:
            recommendations.append("Better align trip with user preferences")
        return recommendations