"""
Distributed State Management for AI Travel Agents

This module implements sophisticated distributed state management patterns for
AI systems, enabling agents to share context, maintain consistency, and handle
real-time updates like experienced travel planners working together.

Key Features:
1. Shared TripContext with versioning and rollback
2. Context propagation with dependency tracking
3. Constraint satisfaction and validation
4. Real-time updates with conflict resolution
5. Persistent state for session recovery
6. Event-driven architecture for scalability
"""

import asyncio
import copy
import json
import uuid
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from enum import Enum
from typing import Dict, List, Optional, Any, Set, Tuple, Union, Callable, Iterator
from dataclasses import dataclass, field, asdict
from decimal import Decimal
import threading
from collections import defaultdict, deque
import pickle
import hashlib

from pydantic import BaseModel, Field, validator

from ..models.flight_models import FlightOption
from ..models.hotel_models import HotelOption
from ..models.poi_models import POI


class ContextEventType(str, Enum):
    """Types of context events that can trigger updates."""

    BUDGET_UPDATE = "budget_update"
    PREFERENCE_CHANGE = "preference_change"
    CONSTRAINT_ADDED = "constraint_added"
    CONSTRAINT_REMOVED = "constraint_removed"
    AGENT_DECISION = "agent_decision"
    AVAILABILITY_CHANGE = "availability_change"
    PRICE_CHANGE = "price_change"
    PLAN_CHANGE = "plan_change"
    ROLLBACK_REQUEST = "rollback_request"


class ConstraintType(str, Enum):
    """Types of constraints that can be enforced."""

    BUDGET = "budget"
    TIME = "time"
    LOCATION = "location"
    AVAILABILITY = "availability"
    PREFERENCE = "preference"
    DEPENDENCY = "dependency"
    QUALITY = "quality"
    FLEXIBILITY = "flexibility"


class ConstraintSeverity(str, Enum):
    """Severity levels for constraints."""

    HARD = "hard"  # Must be satisfied
    SOFT = "soft"  # Preferred but not required
    WARNING = "warning"  # Advisory only


@dataclass
class Constraint:
    """Represents a constraint in the trip planning context."""

    constraint_id: str
    constraint_type: ConstraintType
    severity: ConstraintSeverity
    source_agent: str
    target_agents: List[str]
    constraint_data: Dict[str, Any]
    created_at: datetime = field(default_factory=datetime.now)
    expires_at: Optional[datetime] = None
    version: int = 1
    is_active: bool = True

    def to_dict(self) -> Dict[str, Any]:
        """Convert constraint to a JSON-friendly dictionary (enums as values, datetimes as ISO strings)."""
        return {
            "constraint_id": self.constraint_id,
            "constraint_type": self.constraint_type.value,
            "severity": self.severity.value,
            "source_agent": self.source_agent,
            "target_agents": self.target_agents,
            "constraint_data": self.constraint_data,
            "created_at": self.created_at.isoformat(),
            "expires_at": self.expires_at.isoformat() if self.expires_at else None,
            "version": self.version,
            "is_active": self.is_active,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Constraint':
        """Create a constraint from a dictionary shaped like :meth:`to_dict` output.

        ``constraint_id``, ``constraint_type``, ``severity``, ``source_agent``,
        ``target_agents`` and ``constraint_data`` are required. ``created_at``,
        ``expires_at``, ``version`` and ``is_active`` fall back to the dataclass
        defaults when absent so partially specified payloads still load.

        Raises:
            KeyError: if a required field is missing.
            ValueError: if an enum value or ISO timestamp is malformed.
        """
        created_at = data.get("created_at")
        expires_at = data.get("expires_at")
        return cls(
            constraint_id=data["constraint_id"],
            constraint_type=ConstraintType(data["constraint_type"]),
            severity=ConstraintSeverity(data["severity"]),
            source_agent=data["source_agent"],
            target_agents=list(data["target_agents"]),
            constraint_data=data["constraint_data"],
            created_at=datetime.fromisoformat(created_at) if created_at else datetime.now(),
            expires_at=datetime.fromisoformat(expires_at) if expires_at else None,
            version=data.get("version", 1),
            is_active=data.get("is_active", True),
        )


@dataclass
class ContextEvent:
    """Represents an event that triggers context updates."""

    event_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    event_type: ContextEventType = ContextEventType.PLAN_CHANGE
    source_agent: str = "system"
    target_agents: List[str] = field(default_factory=list)
    event_data: Dict[str, Any] = field(default_factory=dict)
    timestamp: datetime = field(default_factory=datetime.now)
    priority: int = 1  # 1 = highest, 5 = lowest
    requires_acknowledgment: bool = False
    correlation_id: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert event to a JSON-friendly dictionary (enum as value, timestamp as ISO string)."""
        return {
            "event_id": self.event_id,
            "event_type": self.event_type.value,
            "source_agent": self.source_agent,
            "target_agents": self.target_agents,
            "event_data": self.event_data,
            "timestamp": self.timestamp.isoformat(),
            "priority": self.priority,
            "requires_acknowledgment": self.requires_acknowledgment,
            "correlation_id": self.correlation_id,
        }


@dataclass
class ContextSnapshot:
    """Represents a versioned snapshot of the trip context."""

    snapshot_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    version: int = 1
    timestamp: datetime = field(default_factory=datetime.now)
    context_data: Dict[str, Any] = field(default_factory=dict)
    constraints: List[Constraint] = field(default_factory=list)
    agent_states: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    checksum: str = ""

    def __post_init__(self):
        """Calculate the checksum after initialization unless one was supplied."""
        if not self.checksum:
            self.checksum = self._calculate_checksum()

    def _calculate_checksum(self) -> str:
        """Return an MD5 hex digest of the snapshot's version and state.

        Serialization uses ``sort_keys=True`` so key order does not affect the
        digest; ``default=str`` stringifies values json cannot encode natively.
        Intended for integrity checks, not security.
        """
        data_str = json.dumps({
            "version": self.version,
            "context_data": self.context_data,
            "constraints": [c.to_dict() for c in self.constraints],
            "agent_states": self.agent_states,
        }, sort_keys=True, default=str)
        return hashlib.md5(data_str.encode()).hexdigest()

    def to_dict(self) -> Dict[str, Any]:
        """Convert snapshot to dictionary for serialization."""
        return {
            "snapshot_id": self.snapshot_id,
            "version": self.version,
            "timestamp": self.timestamp.isoformat(),
            "context_data": self.context_data,
            "constraints": [c.to_dict() for c in self.constraints],
            "agent_states": self.agent_states,
            "checksum": self.checksum,
        }


class ContextObserver(ABC):
    """Abstract base class for context observers."""

    @abstractmethod
    async def on_context_update(self, event: ContextEvent, context_data: Dict[str, Any]) -> None:
        """Handle context update events."""
        pass

    @abstractmethod
    async def on_constraint_added(self, constraint: Constraint) -> None:
        """Handle new constraint addition."""
        pass

    @abstractmethod
    async def on_constraint_removed(self, constraint_id: str) -> None:
        """Handle constraint removal."""
        pass


class TripContext:
    """
    Distributed trip context that maintains shared state across all agents.

    This class implements sophisticated state management patterns including:
    - Versioned state snapshots for rollback
    - Event-driven updates with propagation
    - Constraint satisfaction and validation
    - Real-time updates with conflict resolution
    - Persistent state for session recovery

    Events are queued via :meth:`update_context` and applied one at a time by
    a background processor task under ``update_lock``.
    """

    def __init__(self, trip_id: str, initial_budget: Decimal = Decimal('2000.00')):
        self.trip_id = trip_id
        self.version = 1
        self.created_at = datetime.now()
        self.last_updated = datetime.now()

        # Core context data
        self.context_data: Dict[str, Any] = {
            "trip_id": trip_id,
            "total_budget": float(initial_budget),
            "remaining_budget": float(initial_budget),
            "trip_type": "leisure",
            "priority": "experience_first",
            "origin": "",
            "destination": "",
            "travel_dates": {},
            "passengers": 1,
            "guests": 1,
            "interests": [],
            "preferences": {},
            "special_requirements": [],
        }

        # State management
        self.constraints: Dict[str, Constraint] = {}
        self.agent_states: Dict[str, Dict[str, Any]] = {}
        self.snapshots: deque = deque(maxlen=50)  # Keep last 50 snapshots
        self.event_history: deque = deque(maxlen=1000)  # Keep last 1000 events

        # Observers and propagation
        self.observers: Dict[str, ContextObserver] = {}
        self.propagation_rules: Dict[str, List[str]] = {}
        self.dependency_graph: Dict[str, Set[str]] = defaultdict(set)

        # Real-time updates
        self.update_queue: asyncio.Queue = asyncio.Queue()
        self.update_lock = asyncio.Lock()
        self.pending_updates: Set[str] = set()

        # Persistence
        self.persistence_enabled = True
        self.auto_save_interval = 30  # seconds
        self._persistence_timer: Optional[asyncio.Task] = None

        # Background task bookkeeping. Initialised before the first snapshot so
        # the snapshot's persistence task can be tracked as well.
        self._background_tasks: Set[asyncio.Task] = set()
        self._update_task: Optional[asyncio.Task] = None
        self._auto_save_task: Optional[asyncio.Task] = None

        # Create initial snapshot
        self._create_snapshot("initial_state")

        # Start background tasks (deferred to update_context if no loop is running)
        self._start_background_tasks()

    def _track_task(self, coro) -> asyncio.Task:
        """Schedule ``coro`` on the running loop and keep a strong reference to it.

        The event loop only keeps weak references to tasks, so an untracked
        fire-and-forget task may be garbage-collected before it finishes. The
        reference is dropped automatically when the task completes.
        """
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def _start_background_tasks(self) -> bool:
        """Start the update processor and auto-save loop if they are not running.

        Idempotent: a task is only (re)created when it was never started or has
        finished. Without a running event loop (e.g. synchronous construction)
        nothing is started and False is returned; :meth:`update_context` calls
        this again from async code, so queued events are not stranded.

        Returns:
            True if called with a running event loop (tasks ensured), else False.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            print("No running event loop - background tasks will be started when async context is available")
            return False

        if self._update_task is None or self._update_task.done():
            self._update_task = self._track_task(self._process_updates())

        if self.persistence_enabled and (self._auto_save_task is None or self._auto_save_task.done()):
            self._auto_save_task = self._track_task(self._auto_save_loop())

        return True

    async def _process_updates(self):
        """Apply queued events forever, one at a time.

        ``task_done`` is always called (via ``finally``) so ``update_queue.join()``
        cannot hang on an event whose handling failed. Cancellation propagates
        because ``asyncio.CancelledError`` is not an ``Exception`` subclass on
        Python 3.8+.
        """
        while True:
            event = await self.update_queue.get()
            try:
                await self._handle_event(event)
            except Exception as e:
                print(f"Error processing update: {e}")
                await asyncio.sleep(1)
            finally:
                self.update_queue.task_done()

    async def _auto_save_loop(self):
        """Auto-save context state every ``auto_save_interval`` seconds."""
        while True:
            try:
                await asyncio.sleep(self.auto_save_interval)
                await self.save_state()
            except Exception as e:
                print(f"Error in auto-save: {e}")

    async def _handle_event(self, event: ContextEvent):
        """Validate, apply, propagate, record and (if significant) snapshot one event.

        Runs under ``update_lock`` so events are applied serially. Errors are
        printed rather than raised.
        """
        async with self.update_lock:
            try:
                if not self._validate_event(event):
                    print(f"Invalid event: {event.event_id}")
                    return

                await self._apply_event(event)
                await self._propagate_event(event)
                self.event_history.append(event)

                if self._is_significant_change(event):
                    self._create_snapshot(f"event_{event.event_id}")
            except Exception as e:
                print(f"Error handling event {event.event_id}: {e}")

    def _validate_event(self, event: ContextEvent) -> bool:
        """Return False for events older than one hour or rejected constraint additions."""
        if datetime.now() - event.timestamp > timedelta(hours=1):
            return False

        if event.event_type == ContextEventType.CONSTRAINT_ADDED:
            return self._validate_constraint_addition(event.event_data)

        return True

    def _validate_constraint_addition(self, constraint_data: Dict[str, Any]) -> bool:
        """Validate that a new constraint doesn't conflict with existing ones.

        Currently a stub that accepts every constraint.
        """
        # TODO: compare against active HARD constraints per constraint type.
        return True

    async def _apply_event(self, event: ContextEvent):
        """Dispatch an event to its type-specific handler, then bump the version.

        Event types without a handler (e.g. PLAN_CHANGE) still advance
        ``version`` and ``last_updated``.
        """
        handlers = {
            ContextEventType.BUDGET_UPDATE: self._apply_budget_update,
            ContextEventType.PREFERENCE_CHANGE: self._apply_preference_change,
            ContextEventType.CONSTRAINT_ADDED: self._apply_constraint_addition,
            ContextEventType.CONSTRAINT_REMOVED: self._apply_constraint_removal,
            ContextEventType.AGENT_DECISION: self._apply_agent_decision,
            ContextEventType.PRICE_CHANGE: self._apply_price_change,
            ContextEventType.AVAILABILITY_CHANGE: self._apply_availability_change,
            ContextEventType.ROLLBACK_REQUEST: self._apply_rollback,
        }
        handler = handlers.get(event.event_type)
        if handler is not None:
            await handler(event)

        self.last_updated = datetime.now()
        self.version += 1

    async def _apply_budget_update(self, event: ContextEvent):
        """Copy total/remaining budget from ``event_data["budget_data"]`` and refresh budget constraints."""
        budget_data = event.event_data.get("budget_data", {})

        if "total_budget" in budget_data:
            self.context_data["total_budget"] = budget_data["total_budget"]

        if "remaining_budget" in budget_data:
            self.context_data["remaining_budget"] = budget_data["remaining_budget"]

        await self._update_budget_constraints()

    async def _apply_preference_change(self, event: ContextEvent):
        """Merge ``event_data["preference_data"]`` into the preferences and refresh preference constraints."""
        preference_data = event.event_data.get("preference_data", {})
        self.context_data.setdefault("preferences", {}).update(preference_data)

        await self._update_preference_constraints()

    def _rebuild_dependency_graph(self) -> None:
        """Recompute source->target agent edges from the currently active constraints.

        Rebuilding (instead of discarding individual edges) keeps an edge alive
        while any other active constraint still links the same two agents. The
        existing dict object is reused so outside references stay valid.
        """
        self.dependency_graph.clear()
        for constraint in self.constraints.values():
            if constraint.is_active:
                self.dependency_graph[constraint.source_agent].update(constraint.target_agents)

    async def _apply_constraint_addition(self, event: ContextEvent):
        """Store the constraint from ``event_data["constraint_data"]`` and update the dependency graph."""
        constraint_data = event.event_data.get("constraint_data", {})
        constraint = Constraint.from_dict(constraint_data)
        self.constraints[constraint.constraint_id] = constraint
        self._rebuild_dependency_graph()

    async def _apply_constraint_removal(self, event: ContextEvent):
        """Deactivate the constraint named by ``event_data["constraint_id"]`` (kept, not deleted)."""
        constraint_id = event.event_data.get("constraint_id")
        constraint = self.constraints.get(constraint_id)
        if constraint is not None:
            constraint.is_active = False
            self._rebuild_dependency_graph()

    async def _apply_agent_decision(self, event: ContextEvent):
        """Merge the source agent's ``decision_data`` into its state and derive context fields from it."""
        agent_id = event.source_agent
        decision_data = event.event_data.get("decision_data", {})

        agent_state = self.agent_states.setdefault(agent_id, {})
        agent_state.update(decision_data)
        agent_state["last_decision"] = event.timestamp.isoformat()

        await self._update_context_from_decision(agent_id, decision_data)

    def _matching_options(self, option_id: Any) -> Iterator[Tuple[str, Dict[str, Any]]]:
        """Yield ``(agent_id, option)`` for each selected option whose ``"id"`` equals ``option_id``.

        A missing (None) ``option_id`` matches nothing; otherwise every option
        lacking an ``"id"`` key would be treated as a match.
        """
        if option_id is None:
            return
        for agent_id, agent_state in self.agent_states.items():
            for option in agent_state.get("selected_options") or []:
                if option.get("id") == option_id:
                    yield agent_id, option

    async def _apply_price_change(self, event: ContextEvent):
        """Apply a real-time price change to matching selected options, then re-evaluate."""
        price_data = event.event_data.get("price_data", {})

        if "new_price" in price_data:
            for _agent_id, option in self._matching_options(price_data.get("option_id")):
                option["price"] = price_data["new_price"]
                option["price_updated_at"] = event.timestamp.isoformat()

        await self._reevaluate_price_constraints(price_data)

    async def _apply_availability_change(self, event: ContextEvent):
        """Apply an availability change to matching selected options, then re-evaluate."""
        availability_data = event.event_data.get("availability_data", {})

        if "available" in availability_data:
            for _agent_id, option in self._matching_options(availability_data.get("option_id")):
                option["available"] = availability_data["available"]
                option["availability_updated_at"] = event.timestamp.isoformat()

        await self._reevaluate_availability_constraints(availability_data)

    async def _apply_rollback(self, event: ContextEvent):
        """Roll back to ``event_data["target_version"]`` when one is given."""
        target_version = event.event_data.get("target_version")
        if target_version:
            await self.rollback_to_version(target_version)

    async def _update_context_from_decision(self, agent_id: str, decision_data: Dict[str, Any]):
        """Copy flight/hotel/POI selections from known agents into ``context_data``.

        Objects exposing ``to_dict`` are converted; anything else is stored as-is.
        """
        if agent_id == "flight_agent":
            if "selected_flight" in decision_data:
                flight = decision_data["selected_flight"]
                self.context_data["selected_flight"] = flight.to_dict() if hasattr(flight, 'to_dict') else flight

                # Update travel dates if not set.
                # NOTE(review): "return" is taken from the selected flight's
                # arrival_time - confirm this is the intended semantics.
                if not self.context_data.get("travel_dates"):
                    self.context_data["travel_dates"] = {
                        "departure": flight.departure_time.isoformat() if hasattr(flight, 'departure_time') else None,
                        "return": flight.arrival_time.isoformat() if hasattr(flight, 'arrival_time') else None,
                    }

        elif agent_id == "hotel_agent":
            if "selected_hotel" in decision_data:
                hotel = decision_data["selected_hotel"]
                self.context_data["selected_hotel"] = hotel.to_dict() if hasattr(hotel, 'to_dict') else hotel

        elif agent_id == "poi_agent":
            if "selected_pois" in decision_data:
                pois = decision_data["selected_pois"]
                self.context_data["selected_pois"] = [
                    poi.to_dict() if hasattr(poi, 'to_dict') else poi
                    for poi in pois
                ]

    async def _update_budget_constraints(self):
        """Set ``max_budget`` of active budget constraints that define it to 80% of the remaining budget."""
        remaining_budget = self.context_data.get("remaining_budget", 0)

        for constraint in self.constraints.values():
            if constraint.constraint_type == ConstraintType.BUDGET and constraint.is_active:
                if "max_budget" in constraint.constraint_data:
                    constraint.constraint_data["max_budget"] = remaining_budget * 0.8  # 80% of remaining

    async def _update_preference_constraints(self):
        """Merge the current preferences into every active preference constraint."""
        preferences = self.context_data.get("preferences", {})

        for constraint in self.constraints.values():
            if constraint.constraint_type == ConstraintType.PREFERENCE and constraint.is_active:
                constraint.constraint_data.update(preferences)

    async def _reevaluate_price_constraints(self, price_data: Dict[str, Any]):
        """Notify each agent holding the repriced option (once per agent)."""
        affected_agents = dict.fromkeys(
            agent_id for agent_id, _ in self._matching_options(price_data.get("option_id"))
        )
        for agent_id in affected_agents:
            await self._notify_constraint_change(agent_id, "price_change")

    async def _reevaluate_availability_constraints(self, availability_data: Dict[str, Any]):
        """Notify each agent holding the option whose availability changed (once per agent)."""
        affected_agents = dict.fromkeys(
            agent_id for agent_id, _ in self._matching_options(availability_data.get("option_id"))
        )
        for agent_id in affected_agents:
            await self._notify_constraint_change(agent_id, "availability_change")

    async def _notify_constraint_change(self, agent_id: str, change_type: str):
        """Send a synthetic CONSTRAINT_ADDED event carrying ``change_type`` to one registered observer."""
        observer = self.observers.get(agent_id)
        if observer is not None:
            event = ContextEvent(
                event_type=ContextEventType.CONSTRAINT_ADDED,
                source_agent="system",
                target_agents=[agent_id],
                event_data={"change_type": change_type},
            )
            await observer.on_context_update(event, self.context_data)

    async def _propagate_event(self, event: ContextEvent):
        """Notify each relevant registered observer exactly once.

        Recipients are the event's explicit targets plus agents reached through
        ``propagation_rules`` and ``dependency_graph`` from the source agent.
        Observer errors are printed and do not stop the remaining notifications.
        """
        # Build a fresh list: extending event.target_agents in place would
        # permanently mutate the event, which is kept in event_history and may be
        # shared with other contexts via DistributedStateManager.synchronize_contexts.
        recipients: List[str] = list(event.target_agents or [])
        recipients.extend(self.propagation_rules.get(event.source_agent, []))
        # .get on the defaultdict avoids inserting an empty entry for the source.
        recipients.extend(self.dependency_graph.get(event.source_agent, ()))

        # dict.fromkeys de-duplicates while keeping a deterministic order.
        for agent_id in dict.fromkeys(recipients):
            observer = self.observers.get(agent_id)
            if observer is None:
                continue
            try:
                await observer.on_context_update(event, self.context_data)
            except Exception as e:
                print(f"Error notifying observer {agent_id}: {e}")

    def _is_significant_change(self, event: ContextEvent) -> bool:
        """Return True for event types that warrant a new snapshot."""
        significant_events = {
            ContextEventType.BUDGET_UPDATE,
            ContextEventType.PREFERENCE_CHANGE,
            ContextEventType.CONSTRAINT_ADDED,
            ContextEventType.CONSTRAINT_REMOVED,
            ContextEventType.AGENT_DECISION,
        }
        return event.event_type in significant_events

    def _create_snapshot(self, reason: str):
        """Append a deep-copied, versioned snapshot of the current state.

        Deep copies are required: handlers mutate nested state in place
        (preferences, agent states, constraint objects), and a shallow copy would
        let those edits rewrite history and break rollback. ``reason`` is
        currently not recorded. When persistence is enabled and an event loop is
        running, persistence is scheduled as a tracked background task.
        """
        snapshot = ContextSnapshot(
            version=self.version,
            context_data=copy.deepcopy(self.context_data),
            constraints=[copy.deepcopy(c) for c in self.constraints.values() if c.is_active],
            agent_states=copy.deepcopy(self.agent_states),
        )

        self.snapshots.append(snapshot)

        if self.persistence_enabled:
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                # No event loop running; just log it.
                print(f"Scheduling snapshot persistence for {snapshot.snapshot_id}")
            else:
                self._track_task(self._persist_snapshot(snapshot))

    async def _persist_snapshot(self, snapshot: ContextSnapshot):
        """Persist snapshot to storage (placeholder: currently only logs)."""
        try:
            # In a real implementation, this would save to database/file.
            print(f"Persisting snapshot {snapshot.snapshot_id} (version {snapshot.version})")
        except Exception as e:
            print(f"Error persisting snapshot: {e}")

    # Public API Methods

    async def update_context(self, event: ContextEvent) -> bool:
        """Queue an event for asynchronous application.

        Also ensures the background processor is running, so a context that was
        constructed outside an event loop still processes its events.

        Returns:
            True if the event was queued, False if queuing failed.
        """
        try:
            self._start_background_tasks()
            await self.update_queue.put(event)
            return True
        except Exception as e:
            print(f"Error queuing update: {e}")
            return False

    async def add_constraint(self, constraint: Constraint) -> bool:
        """Queue a CONSTRAINT_ADDED event for ``constraint`` (priority 1 if HARD, else 3)."""
        event = ContextEvent(
            event_type=ContextEventType.CONSTRAINT_ADDED,
            source_agent=constraint.source_agent,
            target_agents=constraint.target_agents,
            event_data={"constraint_data": constraint.to_dict()},
            priority=1 if constraint.severity == ConstraintSeverity.HARD else 3,
        )
        return await self.update_context(event)

    async def remove_constraint(self, constraint_id: str, source_agent: str) -> bool:
        """Queue a CONSTRAINT_REMOVED event; return False if the constraint is unknown."""
        if constraint_id not in self.constraints:
            return False

        event = ContextEvent(
            event_type=ContextEventType.CONSTRAINT_REMOVED,
            source_agent=source_agent,
            event_data={"constraint_id": constraint_id},
        )
        return await self.update_context(event)

    def get_constraints_for_agent(self, agent_id: str,
                                  constraint_type: Optional[ConstraintType] = None) -> List[Constraint]:
        """Return active constraints the agent sources or is targeted by, optionally filtered by type."""
        return [
            constraint
            for constraint in self.constraints.values()
            if constraint.is_active
            and (agent_id in constraint.target_agents or constraint.source_agent == agent_id)
            and (constraint_type is None or constraint.constraint_type == constraint_type)
        ]

    def validate_agent_decision(self, agent_id: str, decision_data: Dict[str, Any]) -> Tuple[bool, List[str]]:
        """Check a decision against the agent's HARD constraints.

        Returns:
            ``(is_valid, violations)`` where ``violations`` lists messages for
            each violated hard constraint.
        """
        violations = [
            f"Hard constraint violated: {constraint.constraint_id}"
            for constraint in self.get_constraints_for_agent(agent_id)
            if constraint.severity == ConstraintSeverity.HARD
            and not self._check_constraint_violation(constraint, decision_data)
        ]
        return len(violations) == 0, violations

    def _check_constraint_violation(self, constraint: Constraint, decision_data: Dict[str, Any]) -> bool:
        """Return True if the decision SATISFIES the constraint (despite the method name).

        Only BUDGET is evaluated (``cost <= max_budget``); other types are
        accepted unconditionally for now.
        """
        if constraint.constraint_type == ConstraintType.BUDGET:
            cost = decision_data.get("cost", 0)
            max_budget = constraint.constraint_data.get("max_budget", float('inf'))
            return cost <= max_budget

        # TODO: TIME and LOCATION validation.
        return True

    async def rollback_to_version(self, target_version: int) -> bool:
        """Restore the newest snapshot whose version is <= ``target_version``.

        State is deep-copied out of the snapshot so later in-place edits cannot
        corrupt stored history; the dependency graph is rebuilt from the restored
        constraints. Observers are notified with a ROLLBACK_REQUEST event.

        NOTE: this does not take ``update_lock``; ``_apply_rollback`` calls it
        while the (non-reentrant) asyncio lock is already held.

        Returns:
            True on success, False if no suitable snapshot exists or an error occurred.
        """
        try:
            target_snapshot = next(
                (snapshot for snapshot in reversed(self.snapshots) if snapshot.version <= target_version),
                None,
            )
            if target_snapshot is None:
                return False

            self.context_data = copy.deepcopy(target_snapshot.context_data)
            self.agent_states = copy.deepcopy(target_snapshot.agent_states)

            self.constraints.clear()
            for snapshot_constraint in target_snapshot.constraints:
                restored = copy.deepcopy(snapshot_constraint)
                self.constraints[restored.constraint_id] = restored
            self._rebuild_dependency_graph()

            self.version = target_snapshot.version
            self.last_updated = datetime.now()

            event = ContextEvent(
                event_type=ContextEventType.ROLLBACK_REQUEST,
                source_agent="system",
                event_data={"target_version": target_version, "rollback_reason": "user_request"},
            )
            await self._propagate_event(event)

            return True

        except Exception as e:
            print(f"Error during rollback: {e}")
            return False

    def register_observer(self, agent_id: str, observer: ContextObserver):
        """Register (or replace) the observer for an agent."""
        self.observers[agent_id] = observer

    def unregister_observer(self, agent_id: str):
        """Unregister an agent observer; unknown ids are ignored."""
        self.observers.pop(agent_id, None)

    def set_propagation_rule(self, source_agent: str, target_agents: List[str]):
        """Set the agents that always receive events from ``source_agent``."""
        self.propagation_rules[source_agent] = target_agents

    async def save_state(self) -> bool:
        """Save current state to persistent storage (placeholder: currently only logs).

        Returns:
            False if persistence is disabled or an error occurred, else True.
        """
        try:
            if not self.persistence_enabled:
                return False

            # Payload a real backend would store; building it also validates
            # that the state is serializable for the checksum.
            snapshot = ContextSnapshot(
                version=self.version,
                context_data=self.context_data,
                constraints=list(self.constraints.values()),
                agent_states=self.agent_states,
            )

            print(f"Saving state for trip {self.trip_id} (version {self.version})")
            return True

        except Exception as e:
            print(f"Error saving state: {e}")
            return False

    async def load_state(self, trip_id: str) -> bool:
        """Load state from persistent storage (placeholder: only logs and returns True)."""
        try:
            print(f"Loading state for trip {trip_id}")
            return True
        except Exception as e:
            print(f"Error loading state: {e}")
            return False

    def get_context_summary(self) -> Dict[str, Any]:
        """Get a summary of the current context state."""
        return {
            "trip_id": self.trip_id,
            "version": self.version,
            "created_at": self.created_at.isoformat(),
            "last_updated": self.last_updated.isoformat(),
            "total_budget": self.context_data.get("total_budget", 0),
            "remaining_budget": self.context_data.get("remaining_budget", 0),
            "active_constraints": sum(1 for c in self.constraints.values() if c.is_active),
            "registered_agents": list(self.observers.keys()),
            "snapshots_count": len(self.snapshots),
            "events_count": len(self.event_history),
        }

    async def cleanup(self):
        """Stop the background loops, let pending persistence finish, and save final state."""
        # Cancel only the infinite loops; in-flight snapshot persistence tasks are
        # awaited rather than cancelled so they are not lost.
        for task in (self._update_task, self._auto_save_task):
            if task is not None:
                task.cancel()

        # Iterate a copy: done-callbacks remove tasks from the set.
        await asyncio.gather(*list(self._background_tasks), return_exceptions=True)

        if self.persistence_enabled:
            await self.save_state()


class DistributedStateManager:
    """
    Manages multiple TripContext instances and provides coordination between them.

    This class implements patterns for:
    - Multi-trip coordination
    - Cross-trip constraint propagation
    - Global state synchronization
    - Resource management
    """

    def __init__(self):
        self.trip_contexts: Dict[str, TripContext] = {}
        self.global_constraints: Dict[str, Constraint] = {}
        self.coordination_rules: Dict[str, List[str]] = {}
        self.global_observers: List[ContextObserver] = []

    async def create_trip_context(self, trip_id: str,
                                  initial_budget: Decimal = Decimal('2000.00')) -> TripContext:
        """Create and register a trip context, queuing every global constraint on it."""
        context = TripContext(trip_id, initial_budget)
        self.trip_contexts[trip_id] = context

        for constraint in self.global_constraints.values():
            await context.add_constraint(constraint)

        return context

    async def remove_trip_context(self, trip_id: str) -> bool:
        """Clean up and unregister a trip context; return False if it is unknown."""
        context = self.trip_contexts.get(trip_id)
        if context is None:
            return False

        await context.cleanup()
        del self.trip_contexts[trip_id]
        return True

    def get_trip_context(self, trip_id: str) -> Optional[TripContext]:
        """Get a trip context by ID."""
        return self.trip_contexts.get(trip_id)

    async def add_global_constraint(self, constraint: Constraint) -> bool:
        """Record a constraint for all trips and queue it on every existing context."""
        self.global_constraints[constraint.constraint_id] = constraint

        for context in self.trip_contexts.values():
            await context.add_constraint(constraint)

        return True

    def set_coordination_rule(self, source_trip: str, target_trips: List[str]):
        """Set the trips that receive events forwarded from ``source_trip``."""
        self.coordination_rules[source_trip] = target_trips

    async def synchronize_contexts(self, trip_id: str, event: ContextEvent):
        """Queue ``event`` on every existing trip coordinated with ``trip_id``.

        The same event object is shared by all target contexts.
        """
        for target_trip_id in self.coordination_rules.get(trip_id, []):
            target_context = self.trip_contexts.get(target_trip_id)
            if target_context is not None:
                await target_context.update_context(event)

    def get_global_summary(self) -> Dict[str, Any]:
        """Get a summary of all trip contexts."""
        return {
            "active_trips": len(self.trip_contexts),
            "global_constraints": len(self.global_constraints),
            "coordination_rules": len(self.coordination_rules),
            "trip_summaries": {
                trip_id: context.get_context_summary()
                for trip_id, context in self.trip_contexts.items()
            },
        }