Spaces:
Running
Running
| """ | |
| Distributed State Management for AI Travel Agents | |
| This module implements sophisticated distributed state management patterns for AI systems, | |
| enabling agents to share context, maintain consistency, and handle real-time updates | |
| like experienced travel planners working together. | |
| Key Features: | |
| 1. Shared TripContext with versioning and rollback | |
| 2. Context propagation with dependency tracking | |
| 3. Constraint satisfaction and validation | |
| 4. Real-time updates with conflict resolution | |
| 5. Persistent state for session recovery | |
| 6. Event-driven architecture for scalability | |
| """ | |
import asyncio
import copy
import hashlib
import json
import pickle
import threading
import uuid
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta
from decimal import Decimal
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

from pydantic import BaseModel, Field, validator

from ..models.flight_models import FlightOption
from ..models.hotel_models import HotelOption
from ..models.poi_models import POI
class ContextEventType(str, Enum):
    """Kinds of events a TripContext knows how to receive.

    Members are ``str`` subclasses, so each compares equal to its raw value
    (e.g. ``ContextEventType.PLAN_CHANGE == "plan_change"``).
    """

    # Budget / preference edits coming from the user or an agent.
    BUDGET_UPDATE = "budget_update"
    PREFERENCE_CHANGE = "preference_change"
    # Constraint lifecycle.
    CONSTRAINT_ADDED = "constraint_added"
    CONSTRAINT_REMOVED = "constraint_removed"
    # An agent committed to a choice.
    AGENT_DECISION = "agent_decision"
    # Market signals about a selected option.
    AVAILABILITY_CHANGE = "availability_change"
    PRICE_CHANGE = "price_change"
    # Generic plan edit and explicit state restore.
    PLAN_CHANGE = "plan_change"
    ROLLBACK_REQUEST = "rollback_request"
class ConstraintType(str, Enum):
    """Categories a Constraint can belong to.

    Being a ``str`` enum, every member equals its lowercase string value.
    """

    BUDGET = "budget"              # spending limits
    TIME = "time"                  # scheduling windows
    LOCATION = "location"          # geographic restrictions
    AVAILABILITY = "availability"  # whether an option can be booked
    PREFERENCE = "preference"      # user likes / dislikes
    DEPENDENCY = "dependency"      # one decision relies on another
    QUALITY = "quality"            # minimum quality requirements
    FLEXIBILITY = "flexibility"    # how much the plan may bend
class ConstraintSeverity(str, Enum):
    """How strictly a Constraint must be honoured, from strictest to loosest."""

    HARD = "hard"        # mandatory; decisions violating it are rejected
    SOFT = "soft"        # desirable, may be traded off
    WARNING = "warning"  # informational only
@dataclass
class Constraint:
    """A rule that one agent (``source_agent``) imposes on other agents.

    ``to_dict`` / ``from_dict`` round-trip an instance through plain,
    JSON-friendly data. Both copy ``target_agents`` and deep-copy
    ``constraint_data``, so the object and its serialized form never share
    mutable state. This matters because constraint data is mutated in place
    after creation, e.g. when ``max_budget`` is recomputed.
    """

    constraint_id: str
    constraint_type: ConstraintType
    severity: ConstraintSeverity
    source_agent: str
    target_agents: List[str]
    constraint_data: Dict[str, Any]
    created_at: datetime = field(default_factory=datetime.now)
    expires_at: Optional[datetime] = None
    version: int = 1
    is_active: bool = True

    def to_dict(self) -> Dict[str, Any]:
        """Return a serializable, independent copy of this constraint.

        Enums become their string values and datetimes ISO-8601 strings.
        """
        return {
            "constraint_id": self.constraint_id,
            "constraint_type": self.constraint_type.value,
            "severity": self.severity.value,
            "source_agent": self.source_agent,
            # Copies: callers (and contexts built from this dict) must not be
            # able to mutate this constraint through the returned dict.
            "target_agents": list(self.target_agents),
            "constraint_data": copy.deepcopy(self.constraint_data),
            "created_at": self.created_at.isoformat(),
            "expires_at": self.expires_at.isoformat() if self.expires_at else None,
            "version": self.version,
            "is_active": self.is_active,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Constraint':
        """Build a constraint from the output of :meth:`to_dict`.

        Args:
            data: Mapping with at least ``constraint_id``, ``constraint_type``,
                ``severity``, ``source_agent``, ``target_agents`` and
                ``constraint_data``. ``created_at``, ``expires_at``,
                ``version`` and ``is_active`` are optional and fall back to
                the field defaults.

        Raises:
            KeyError: if a required key is missing.
            ValueError: if an enum value or ISO timestamp is invalid.
        """
        created_at = data.get("created_at")
        expires_at = data.get("expires_at")
        return cls(
            constraint_id=data["constraint_id"],
            constraint_type=ConstraintType(data["constraint_type"]),
            severity=ConstraintSeverity(data["severity"]),
            source_agent=data["source_agent"],
            target_agents=list(data["target_agents"]),
            constraint_data=copy.deepcopy(data["constraint_data"]),
            created_at=datetime.fromisoformat(created_at) if created_at else datetime.now(),
            expires_at=datetime.fromisoformat(expires_at) if expires_at else None,
            version=data.get("version", 1),
            is_active=data.get("is_active", True),
        )
@dataclass
class ContextEvent:
    """An event submitted to a TripContext to change or announce state.

    ``priority`` ranges from 1 (highest) to 5 (lowest). ``event_data`` carries
    type-specific payloads; TripContext reads keys such as ``budget_data``,
    ``preference_data``, ``constraint_data``, ``decision_data``,
    ``price_data``, ``availability_data`` and ``target_version`` depending on
    ``event_type``.
    """

    event_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    event_type: ContextEventType = ContextEventType.PLAN_CHANGE
    source_agent: str = "system"
    target_agents: List[str] = field(default_factory=list)
    event_data: Dict[str, Any] = field(default_factory=dict)
    timestamp: datetime = field(default_factory=datetime.now)
    priority: int = 1  # 1 = highest, 5 = lowest
    requires_acknowledgment: bool = False
    correlation_id: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a serializable view of the event.

        ``target_agents`` and the top level of ``event_data`` are copied, so
        editing the result does not alter the event. Nested payload objects
        are still shared.
        """
        return {
            "event_id": self.event_id,
            "event_type": self.event_type.value,
            "source_agent": self.source_agent,
            "target_agents": list(self.target_agents),
            "event_data": dict(self.event_data),
            "timestamp": self.timestamp.isoformat(),
            "priority": self.priority,
            "requires_acknowledgment": self.requires_acknowledgment,
            "correlation_id": self.correlation_id,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ContextEvent':
        """Rebuild an event from :meth:`to_dict` output.

        Missing keys fall back to the field defaults.

        Raises:
            ValueError: if ``event_type`` or ``timestamp`` is invalid.
        """
        timestamp = data.get("timestamp")
        return cls(
            event_id=data.get("event_id") or str(uuid.uuid4()),
            event_type=ContextEventType(data.get("event_type", ContextEventType.PLAN_CHANGE.value)),
            source_agent=data.get("source_agent", "system"),
            target_agents=list(data.get("target_agents") or []),
            event_data=dict(data.get("event_data") or {}),
            timestamp=datetime.fromisoformat(timestamp) if timestamp else datetime.now(),
            priority=data.get("priority", 1),
            requires_acknowledgment=data.get("requires_acknowledgment", False),
            correlation_id=data.get("correlation_id"),
        )
@dataclass
class ContextSnapshot:
    """A versioned capture of a TripContext's state.

    ``checksum`` is an MD5 hex digest of the version, context data,
    constraints and agent states, serialized as JSON with sorted keys and
    ``default=str`` for non-JSON values. It is computed on construction unless
    one is supplied, and can be re-checked with :meth:`verify_checksum`. MD5
    serves here only as a change/corruption detector.
    """

    snapshot_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    version: int = 1
    timestamp: datetime = field(default_factory=datetime.now)
    context_data: Dict[str, Any] = field(default_factory=dict)
    constraints: List[Constraint] = field(default_factory=list)
    agent_states: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    checksum: str = ""

    def __post_init__(self):
        """Fill in ``checksum`` when the caller did not provide one."""
        if not self.checksum:
            self.checksum = self._calculate_checksum()

    def _calculate_checksum(self) -> str:
        """Hash the snapshot's payload (everything except id and timestamp)."""
        data_str = json.dumps({
            "version": self.version,
            "context_data": self.context_data,
            "constraints": [c.to_dict() for c in self.constraints],
            "agent_states": self.agent_states,
        }, sort_keys=True, default=str)
        return hashlib.md5(data_str.encode()).hexdigest()

    def verify_checksum(self) -> bool:
        """Return True if the stored checksum still matches the current payload.

        A False result means the snapshot's data was modified (or loaded with a
        mismatching checksum) after the checksum was recorded.
        """
        return self.checksum == self._calculate_checksum()

    def to_dict(self) -> Dict[str, Any]:
        """Convert the snapshot to plain data for serialization."""
        return {
            "snapshot_id": self.snapshot_id,
            "version": self.version,
            "timestamp": self.timestamp.isoformat(),
            "context_data": self.context_data,
            "constraints": [c.to_dict() for c in self.constraints],
            "agent_states": self.agent_states,
            "checksum": self.checksum,
        }
class ContextObserver(ABC):
    """Base class for components that want to hear about context changes.

    All hooks are coroutines whose default implementation does nothing, so a
    subclass only needs to override the notifications it is interested in.
    """

    async def on_context_update(self, event: ContextEvent, context_data: Dict[str, Any]) -> None:
        """Called with an event and the context's current data dictionary."""
        return None

    async def on_constraint_added(self, constraint: Constraint) -> None:
        """Called when a constraint is added."""
        return None

    async def on_constraint_removed(self, constraint_id: str) -> None:
        """Called with the id of a constraint that was removed."""
        return None
class TripContext:
    """
    Shared, versioned trip-planning state for a group of cooperating agents.

    Provides:
    - Versioned snapshots (the most recent 50) that support rollback
    - Event-driven updates, applied one at a time by a background consumer of
      ``update_queue`` and propagated to registered observers
    - Constraint bookkeeping and validation of agent decisions
    - Persistence hooks (currently log-only placeholders)

    Events submitted via :meth:`update_context` are only applied once the
    background processor runs. It is started in ``__init__`` when an event
    loop is already running, and otherwise lazily on the first
    ``update_context`` call.
    """

    # Fraction of ``remaining_budget`` written into ``max_budget`` of every
    # active budget constraint whenever the budget changes.
    budget_constraint_ratio: float = 0.8

    # Event types that trigger a snapshot after being applied.
    _SIGNIFICANT_EVENTS = frozenset({
        ContextEventType.BUDGET_UPDATE,
        ContextEventType.PREFERENCE_CHANGE,
        ContextEventType.CONSTRAINT_ADDED,
        ContextEventType.CONSTRAINT_REMOVED,
        ContextEventType.AGENT_DECISION,
    })

    def __init__(self, trip_id: str, initial_budget: Decimal = Decimal('2000.00')):
        self.trip_id = trip_id
        self.version = 1
        self.created_at = datetime.now()
        self.last_updated = datetime.now()

        # Core context data. Budgets are stored as float throughout.
        self.context_data: Dict[str, Any] = {
            "trip_id": trip_id,
            "total_budget": float(initial_budget),
            "remaining_budget": float(initial_budget),
            "trip_type": "leisure",
            "priority": "experience_first",
            "origin": "",
            "destination": "",
            "travel_dates": {},
            "passengers": 1,
            "guests": 1,
            "interests": [],
            "preferences": {},
            "special_requirements": []
        }

        # State management
        self.constraints: Dict[str, Constraint] = {}
        self.agent_states: Dict[str, Dict[str, Any]] = {}
        self.snapshots: deque = deque(maxlen=50)  # Keep last 50 snapshots
        self.event_history: deque = deque(maxlen=1000)  # Keep last 1000 events

        # Observers and propagation
        self.observers: Dict[str, ContextObserver] = {}
        self.propagation_rules: Dict[str, List[str]] = {}
        self.dependency_graph: Dict[str, Set[str]] = defaultdict(set)

        # Real-time updates
        self.update_queue: asyncio.Queue = asyncio.Queue()
        self.update_lock = asyncio.Lock()
        self.pending_updates: Set[str] = set()

        # Persistence
        self.persistence_enabled = True
        self.auto_save_interval = 30  # seconds
        self._persistence_timer: Optional[asyncio.Task] = None  # auto-save task

        # Strong references to scheduled tasks. asyncio keeps only weak
        # references, so an unreferenced task can be garbage-collected mid-run.
        # Must exist before _create_snapshot, which may schedule a task.
        self._background_tasks: Set[asyncio.Task] = set()
        self._processor_task: Optional[asyncio.Task] = None

        # Create initial snapshot
        self._create_snapshot("initial_state")

        # Start background tasks (no-op without a running event loop)
        self._start_background_tasks()

    # ------------------------------------------------------------------
    # Background tasks
    # ------------------------------------------------------------------

    def _spawn(self, coro) -> asyncio.Task:
        """Schedule *coro* on the running loop and keep a strong reference to it."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def _start_background_tasks(self):
        """Start the update processor and auto-save loop unless already running.

        Safe to call repeatedly. Without a running event loop it only logs. The
        tasks are then started by the first :meth:`update_context` call.
        """
        try:
            # get_running_loop() (unlike the deprecated get_event_loop())
            # never creates a loop implicitly; it raises RuntimeError instead.
            asyncio.get_running_loop()
        except RuntimeError:
            print("No event loop available - background tasks will be started when async context is available")
            return

        if self._processor_task is None or self._processor_task.done():
            self._processor_task = self._spawn(self._process_updates())

        if self.persistence_enabled and (
            self._persistence_timer is None or self._persistence_timer.done()
        ):
            self._persistence_timer = self._spawn(self._auto_save_loop())

    async def _process_updates(self):
        """Consume ``update_queue`` forever, handling one event at a time."""
        while True:
            event = await self.update_queue.get()
            try:
                await self._handle_event(event)
            except Exception as e:
                print(f"Error processing update: {e}")
                await asyncio.sleep(1)
            finally:
                # Always balance get(); otherwise update_queue.join() never returns.
                self.update_queue.task_done()

    async def _auto_save_loop(self):
        """Call :meth:`save_state` every ``auto_save_interval`` seconds."""
        while True:
            try:
                await asyncio.sleep(self.auto_save_interval)
                await self.save_state()
            except Exception as e:
                print(f"Error in auto-save: {e}")

    # ------------------------------------------------------------------
    # Event handling
    # ------------------------------------------------------------------

    async def _handle_event(self, event: ContextEvent):
        """Validate, apply, propagate, record and possibly snapshot one event.

        Runs under ``update_lock``. Errors are logged, not raised, so one bad
        event cannot stop the processor.
        """
        async with self.update_lock:
            try:
                if not self._validate_event(event):
                    print(f"Invalid event: {event.event_id}")
                    return

                await self._apply_event(event)
                await self._propagate_event(event)
                self.event_history.append(event)

                if self._is_significant_change(event):
                    self._create_snapshot(f"event_{event.event_id}")
            except Exception as e:
                print(f"Error handling event {event.event_id}: {e}")

    def _validate_event(self, event: ContextEvent) -> bool:
        """Reject events older than one hour and invalid constraint additions."""
        # NOTE(review): compares naive datetimes; an aware event.timestamp
        # would raise TypeError here.
        if datetime.now() - event.timestamp > timedelta(hours=1):
            return False

        if event.event_type == ContextEventType.CONSTRAINT_ADDED:
            return self._validate_constraint_addition(event.event_data)

        return True

    def _validate_constraint_addition(self, constraint_data: Dict[str, Any]) -> bool:
        """Check a new constraint against existing ones.

        Currently a placeholder that accepts every constraint.
        """
        # TODO: detect conflicts with active HARD constraints per constraint type.
        return True

    async def _apply_event(self, event: ContextEvent):
        """Dispatch *event* to its handler, then bump ``version`` and ``last_updated``.

        Event types without a handler (e.g. PLAN_CHANGE) only bump the version.
        """
        handlers = {
            ContextEventType.BUDGET_UPDATE: self._apply_budget_update,
            ContextEventType.PREFERENCE_CHANGE: self._apply_preference_change,
            ContextEventType.CONSTRAINT_ADDED: self._apply_constraint_addition,
            ContextEventType.CONSTRAINT_REMOVED: self._apply_constraint_removal,
            ContextEventType.AGENT_DECISION: self._apply_agent_decision,
            ContextEventType.PRICE_CHANGE: self._apply_price_change,
            ContextEventType.AVAILABILITY_CHANGE: self._apply_availability_change,
            ContextEventType.ROLLBACK_REQUEST: self._apply_rollback,
        }
        handler = handlers.get(event.event_type)
        if handler is not None:
            await handler(event)

        self.last_updated = datetime.now()
        self.version += 1

    async def _apply_budget_update(self, event: ContextEvent):
        """Apply ``event_data["budget_data"]`` (total/remaining budget)."""
        budget_data = event.event_data.get("budget_data", {})
        # Coerce to float to match __init__, so later arithmetic never mixes
        # Decimal and float (which raises TypeError).
        for key in ("total_budget", "remaining_budget"):
            if key in budget_data:
                self.context_data[key] = float(budget_data[key])

        await self._update_budget_constraints()

    async def _apply_preference_change(self, event: ContextEvent):
        """Merge ``event_data["preference_data"]`` into ``context_data["preferences"]``."""
        preference_data = event.event_data.get("preference_data", {})
        preferences = self.context_data.setdefault("preferences", {})
        preferences.update(preference_data)

        await self._update_preference_constraints()

    async def _apply_constraint_addition(self, event: ContextEvent):
        """Register the constraint serialized in ``event_data["constraint_data"]``."""
        constraint_data = event.event_data.get("constraint_data", {})
        constraint = Constraint.from_dict(constraint_data)
        self.constraints[constraint.constraint_id] = constraint
        self._rebuild_dependency_graph()

    async def _apply_constraint_removal(self, event: ContextEvent):
        """Deactivate the constraint named by ``event_data["constraint_id"]``."""
        constraint_id = event.event_data.get("constraint_id")
        constraint = self.constraints.get(constraint_id)
        if constraint is not None:
            constraint.is_active = False
            # Rebuild rather than discard edges: another active constraint may
            # still link the same source and target agents.
            self._rebuild_dependency_graph()

    def _rebuild_dependency_graph(self):
        """Recompute source -> target agent edges from the active constraints (in place)."""
        self.dependency_graph.clear()
        for constraint in self.constraints.values():
            if constraint.is_active:
                self.dependency_graph[constraint.source_agent].update(constraint.target_agents)

    async def _apply_agent_decision(self, event: ContextEvent):
        """Merge ``event_data["decision_data"]`` into the source agent's state."""
        agent_id = event.source_agent
        decision_data = event.event_data.get("decision_data", {})

        agent_state = self.agent_states.setdefault(agent_id, {})
        agent_state.update(decision_data)
        agent_state["last_decision"] = event.timestamp.isoformat()

        await self._update_context_from_decision(agent_id, decision_data)

    def _iter_selected_options(self, option_id: Any):
        """Yield ``(agent_id, option)`` for each selected option whose ``id`` equals *option_id*.

        Yields nothing when *option_id* is None, so an event lacking an id
        cannot match every option that lacks one. Options are assumed to be dicts.
        """
        if option_id is None:
            return
        for agent_id, agent_state in self.agent_states.items():
            for option in agent_state.get("selected_options") or ():
                if option.get("id") == option_id:
                    yield agent_id, option

    async def _apply_price_change(self, event: ContextEvent):
        """Update the price of matching selected options from ``event_data["price_data"]``."""
        price_data = event.event_data.get("price_data", {})
        for _, option in self._iter_selected_options(price_data.get("option_id")):
            option["price"] = price_data["new_price"]
            option["price_updated_at"] = event.timestamp.isoformat()

        await self._reevaluate_price_constraints(price_data)

    async def _apply_availability_change(self, event: ContextEvent):
        """Update availability of matching selected options from ``event_data["availability_data"]``."""
        availability_data = event.event_data.get("availability_data", {})
        for _, option in self._iter_selected_options(availability_data.get("option_id")):
            option["available"] = availability_data["available"]
            option["availability_updated_at"] = event.timestamp.isoformat()

        await self._reevaluate_availability_constraints(availability_data)

    async def _apply_rollback(self, event: ContextEvent):
        """Restore the snapshot selected by ``event_data["target_version"]``."""
        target_version = event.event_data.get("target_version")
        if target_version is not None:
            # Restore only: _handle_event propagates this event itself, so going
            # through rollback_to_version() would notify observers twice.
            self._restore_snapshot(target_version)

    async def _update_context_from_decision(self, agent_id: str, decision_data: Dict[str, Any]):
        """Copy selections from known agents into ``context_data``.

        Objects exposing ``to_dict`` are stored in serialized form.
        """
        if agent_id == "flight_agent":
            if "selected_flight" in decision_data:
                flight = decision_data["selected_flight"]
                self.context_data["selected_flight"] = flight.to_dict() if hasattr(flight, 'to_dict') else flight

                # Update travel dates if not set
                if not self.context_data.get("travel_dates"):
                    # NOTE(review): "return" is filled from the flight's
                    # arrival_time, not an actual return date - confirm intent.
                    self.context_data["travel_dates"] = {
                        "departure": flight.departure_time.isoformat() if hasattr(flight, 'departure_time') else None,
                        "return": flight.arrival_time.isoformat() if hasattr(flight, 'arrival_time') else None
                    }

        elif agent_id == "hotel_agent":
            if "selected_hotel" in decision_data:
                hotel = decision_data["selected_hotel"]
                self.context_data["selected_hotel"] = hotel.to_dict() if hasattr(hotel, 'to_dict') else hotel

        elif agent_id == "poi_agent":
            if "selected_pois" in decision_data:
                self.context_data["selected_pois"] = [
                    poi.to_dict() if hasattr(poi, 'to_dict') else poi
                    for poi in decision_data["selected_pois"]
                ]

    async def _update_budget_constraints(self):
        """Reset ``max_budget`` of active budget constraints to a share of the remaining budget.

        Every active BUDGET constraint that has a ``max_budget`` key is
        overwritten, regardless of severity.
        """
        remaining_budget = float(self.context_data.get("remaining_budget") or 0)
        new_max = remaining_budget * self.budget_constraint_ratio

        for constraint in self.constraints.values():
            if constraint.constraint_type == ConstraintType.BUDGET and constraint.is_active:
                if "max_budget" in constraint.constraint_data:
                    constraint.constraint_data["max_budget"] = new_max

    async def _update_preference_constraints(self):
        """Merge current preferences into every active preference constraint."""
        preferences = self.context_data.get("preferences", {})

        for constraint in self.constraints.values():
            if constraint.constraint_type == ConstraintType.PREFERENCE and constraint.is_active:
                constraint.constraint_data.update(preferences)

    async def _reevaluate_price_constraints(self, price_data: Dict[str, Any]):
        """Notify agents holding the repriced option."""
        option_id = price_data.get("option_id")
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        affected_agents = dict.fromkeys(agent_id for agent_id, _ in self._iter_selected_options(option_id))
        for agent_id in affected_agents:
            await self._notify_constraint_change(agent_id, "price_change")

    async def _reevaluate_availability_constraints(self, availability_data: Dict[str, Any]):
        """Notify agents holding the option whose availability changed."""
        option_id = availability_data.get("option_id")
        affected_agents = dict.fromkeys(agent_id for agent_id, _ in self._iter_selected_options(option_id))
        for agent_id in affected_agents:
            await self._notify_constraint_change(agent_id, "availability_change")

    async def _notify_constraint_change(self, agent_id: str, change_type: str):
        """Send a direct (un-applied) notification to one observer.

        The notification reuses the CONSTRAINT_ADDED type with
        ``event_data={"change_type": ...}``. It is delivered straight to the
        observer and never passes through _apply_event.
        """
        observer = self.observers.get(agent_id)
        if observer is None:
            return

        event = ContextEvent(
            event_type=ContextEventType.CONSTRAINT_ADDED,
            source_agent="system",
            target_agents=[agent_id],
            event_data={"change_type": change_type}
        )
        # Same isolation as _propagate_event: a failing observer must not
        # abort handling of the event that triggered this notification.
        try:
            await observer.on_context_update(event, self.context_data)
        except Exception as e:
            print(f"Error notifying observer {agent_id}: {e}")

    async def _propagate_event(self, event: ContextEvent):
        """Deliver *event* to its explicit targets plus rule- and dependency-derived agents."""
        # Build a new list: extending event.target_agents directly would
        # mutate the caller's event (and grow it on every propagation).
        target_agents = list(event.target_agents or [])
        target_agents.extend(self.propagation_rules.get(event.source_agent, ()))
        target_agents.extend(self.dependency_graph.get(event.source_agent, ()))

        # De-duplicate, keeping first-seen order.
        for agent_id in dict.fromkeys(target_agents):
            observer = self.observers.get(agent_id)
            if observer is None:
                continue
            try:
                await observer.on_context_update(event, self.context_data)
            except Exception as e:
                print(f"Error notifying observer {agent_id}: {e}")

    def _is_significant_change(self, event: ContextEvent) -> bool:
        """Return True if *event* should be followed by a snapshot."""
        return event.event_type in self._SIGNIFICANT_EVENTS

    # ------------------------------------------------------------------
    # Snapshots and persistence
    # ------------------------------------------------------------------

    def _create_snapshot(self, reason: str):
        """Append a deep-copied snapshot of the current state to ``snapshots``.

        ``reason`` is accepted for call-site readability but is not stored.
        Deep copies are required: live state is mutated in place (preferences,
        agent_states entries, constraint objects), which would otherwise
        rewrite the history that rollback restores from.
        """
        snapshot = ContextSnapshot(
            version=self.version,
            context_data=copy.deepcopy(self.context_data),
            constraints=[copy.deepcopy(c) for c in self.constraints.values() if c.is_active],
            agent_states=copy.deepcopy(self.agent_states)
        )
        self.snapshots.append(snapshot)

        if self.persistence_enabled:
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                print(f"Scheduling snapshot persistence for {snapshot.snapshot_id}")
            else:
                self._spawn(self._persist_snapshot(snapshot))

    async def _persist_snapshot(self, snapshot: ContextSnapshot):
        """Persist *snapshot* (placeholder: only logs)."""
        try:
            print(f"Persisting snapshot {snapshot.snapshot_id} (version {snapshot.version})")
        except Exception as e:
            print(f"Error persisting snapshot: {e}")

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def update_context(self, event: ContextEvent) -> bool:
        """Queue *event* for asynchronous application.

        Returns True once queued. The event is applied later by the background
        processor, which this call starts if it is not already running.
        """
        # Covers contexts constructed outside a running event loop.
        self._start_background_tasks()
        try:
            await self.update_queue.put(event)
            return True
        except Exception as e:
            print(f"Error queuing update: {e}")
            return False

    async def add_constraint(self, constraint: Constraint) -> bool:
        """Queue a CONSTRAINT_ADDED event for *constraint* (HARD ones get priority 1)."""
        event = ContextEvent(
            event_type=ContextEventType.CONSTRAINT_ADDED,
            source_agent=constraint.source_agent,
            target_agents=list(constraint.target_agents),
            event_data={"constraint_data": constraint.to_dict()},
            priority=1 if constraint.severity == ConstraintSeverity.HARD else 3
        )
        return await self.update_context(event)

    async def remove_constraint(self, constraint_id: str, source_agent: str) -> bool:
        """Queue removal of a known constraint; False if the id is unknown."""
        if constraint_id not in self.constraints:
            return False

        event = ContextEvent(
            event_type=ContextEventType.CONSTRAINT_REMOVED,
            source_agent=source_agent,
            event_data={"constraint_id": constraint_id}
        )
        return await self.update_context(event)

    @staticmethod
    def _is_expired(constraint: Constraint) -> bool:
        """Return True if the constraint has an ``expires_at`` in the past."""
        expires_at = constraint.expires_at
        # now(tzinfo) matches the awareness of expires_at (naive when tzinfo is
        # None), avoiding a naive-vs-aware comparison TypeError.
        return expires_at is not None and expires_at <= datetime.now(expires_at.tzinfo)

    def get_constraints_for_agent(self, agent_id: str, constraint_type: Optional[ConstraintType] = None) -> List[Constraint]:
        """Return active, unexpired constraints that *agent_id* issued or is targeted by.

        Optionally filtered by *constraint_type*.
        """
        return [
            constraint
            for constraint in self.constraints.values()
            if constraint.is_active
            and not self._is_expired(constraint)
            and (agent_id in constraint.target_agents or constraint.source_agent == agent_id)
            and (constraint_type is None or constraint.constraint_type == constraint_type)
        ]

    def validate_agent_decision(self, agent_id: str, decision_data: Dict[str, Any]) -> Tuple[bool, List[str]]:
        """Check *decision_data* against the agent's HARD constraints.

        Returns ``(ok, violations)`` where *violations* lists messages for each
        failed constraint.
        """
        violations = [
            f"Hard constraint violated: {constraint.constraint_id}"
            for constraint in self.get_constraints_for_agent(agent_id)
            if constraint.severity == ConstraintSeverity.HARD
            and not self._check_constraint_violation(constraint, decision_data)
        ]
        return len(violations) == 0, violations

    def _check_constraint_violation(self, constraint: Constraint, decision_data: Dict[str, Any]) -> bool:
        """Return True when *decision_data* SATISFIES *constraint* (despite the name).

        Only BUDGET constraints are checked (``cost <= max_budget``). Other types
        currently always pass.
        """
        if constraint.constraint_type == ConstraintType.BUDGET:
            cost = decision_data.get("cost", 0)
            max_budget = constraint.constraint_data.get("max_budget", float('inf'))
            return cost <= max_budget
        # TODO: TIME and LOCATION validation.
        return True

    def _restore_snapshot(self, target_version: int) -> bool:
        """Restore state from the newest snapshot with version <= *target_version*.

        Deep-copies out of the snapshot so the stored history stays untouched.
        Does not notify observers. Returns False if no snapshot qualifies.
        """
        target_snapshot = next(
            (s for s in reversed(self.snapshots) if s.version <= target_version),
            None,
        )
        if target_snapshot is None:
            return False

        self.context_data = copy.deepcopy(target_snapshot.context_data)
        self.agent_states = copy.deepcopy(target_snapshot.agent_states)

        self.constraints.clear()
        for constraint in target_snapshot.constraints:
            self.constraints[constraint.constraint_id] = copy.deepcopy(constraint)
        # Edges must reflect the restored constraints, not the discarded ones.
        self._rebuild_dependency_graph()

        self.version = target_snapshot.version
        self.last_updated = datetime.now()
        return True

    async def rollback_to_version(self, target_version: int) -> bool:
        """Roll back to *target_version* and notify observers with a ROLLBACK_REQUEST event.

        Returns True on success, False if no suitable snapshot exists or an
        error occurs (errors are logged).
        """
        try:
            if not self._restore_snapshot(target_version):
                return False

            event = ContextEvent(
                event_type=ContextEventType.ROLLBACK_REQUEST,
                source_agent="system",
                event_data={"target_version": target_version, "rollback_reason": "user_request"}
            )
            await self._propagate_event(event)
            return True
        except Exception as e:
            print(f"Error during rollback: {e}")
            return False

    def register_observer(self, agent_id: str, observer: ContextObserver):
        """Register *observer* to receive events addressed to *agent_id*."""
        self.observers[agent_id] = observer

    def unregister_observer(self, agent_id: str):
        """Remove the observer for *agent_id*, if any."""
        self.observers.pop(agent_id, None)

    def set_propagation_rule(self, source_agent: str, target_agents: List[str]):
        """Also forward events from *source_agent* to *target_agents*."""
        # Copy so later mutation of the caller's list does not change the rule.
        self.propagation_rules[source_agent] = list(target_agents)

    async def save_state(self) -> bool:
        """Save current state (placeholder: only logs).

        Returns False when persistence is disabled or an error occurs.
        """
        try:
            if not self.persistence_enabled:
                return False
            # A real implementation would serialize ContextSnapshot(...).to_dict().
            print(f"Saving state for trip {self.trip_id} (version {self.version})")
            return True
        except Exception as e:
            print(f"Error saving state: {e}")
            return False

    async def load_state(self, trip_id: str) -> bool:
        """Load state for *trip_id* (placeholder: only logs and returns True)."""
        try:
            print(f"Loading state for trip {trip_id}")
            return True
        except Exception as e:
            print(f"Error loading state: {e}")
            return False

    def get_context_summary(self) -> Dict[str, Any]:
        """Return a small dictionary describing the current context state."""
        return {
            "trip_id": self.trip_id,
            "version": self.version,
            "created_at": self.created_at.isoformat(),
            "last_updated": self.last_updated.isoformat(),
            "total_budget": self.context_data.get("total_budget", 0),
            "remaining_budget": self.context_data.get("remaining_budget", 0),
            "active_constraints": sum(1 for c in self.constraints.values() if c.is_active),
            "registered_agents": list(self.observers.keys()),
            "snapshots_count": len(self.snapshots),
            "events_count": len(self.event_history)
        }

    async def cleanup(self):
        """Cancel background tasks, wait for them to finish, then save state."""
        # Snapshot the set: done-callbacks discard from it as tasks finish.
        tasks = list(self._background_tasks)
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        self._processor_task = None
        self._persistence_timer = None

        if self.persistence_enabled:
            await self.save_state()
class DistributedStateManager:
    """
    Owns several TripContext instances and coordinates between them.

    Provides:
    - Creation, lookup and teardown of per-trip contexts
    - Global constraints applied to every (current and future) trip
    - Forwarding of events from one trip to coordinated trips
    - An aggregate summary across all trips
    """

    def __init__(self):
        self.trip_contexts: Dict[str, TripContext] = {}
        self.global_constraints: Dict[str, Constraint] = {}
        self.coordination_rules: Dict[str, List[str]] = {}
        self.global_observers: List[ContextObserver] = []

    async def create_trip_context(self, trip_id: str, initial_budget: Decimal = Decimal('2000.00')) -> TripContext:
        """Create and register a context for *trip_id*, seeded with all global constraints.

        An existing context with the same id is cleaned up and replaced.
        Without the cleanup, its background tasks would keep running with no
        owner.
        """
        existing = self.trip_contexts.pop(trip_id, None)
        if existing is not None:
            await existing.cleanup()

        context = TripContext(trip_id, initial_budget)
        self.trip_contexts[trip_id] = context

        # Iterate over a copy: awaiting could let another coroutine add globals.
        for constraint in list(self.global_constraints.values()):
            await context.add_constraint(constraint)

        return context

    async def remove_trip_context(self, trip_id: str) -> bool:
        """Unregister and clean up the context for *trip_id*; False if unknown."""
        context = self.trip_contexts.pop(trip_id, None)
        if context is None:
            return False
        await context.cleanup()
        return True

    def get_trip_context(self, trip_id: str) -> Optional[TripContext]:
        """Return the context for *trip_id*, or None."""
        return self.trip_contexts.get(trip_id)

    async def add_global_constraint(self, constraint: Constraint) -> bool:
        """Record *constraint* globally and queue it on every existing context."""
        self.global_constraints[constraint.constraint_id] = constraint

        # Copy the values: the dict could change while we await.
        for context in list(self.trip_contexts.values()):
            await context.add_constraint(constraint)

        return True

    def set_coordination_rule(self, source_trip: str, target_trips: List[str]):
        """Forward events from *source_trip* to *target_trips* via synchronize_contexts."""
        # Copy so later mutation of the caller's list does not change the rule.
        self.coordination_rules[source_trip] = list(target_trips)

    async def synchronize_contexts(self, trip_id: str, event: ContextEvent):
        """Queue *event* on every trip coordinated with *trip_id*.

        The source trip itself is skipped even if listed as a target, so an
        event is not applied to it a second time. Unknown target ids are ignored.
        """
        for target_trip_id in self.coordination_rules.get(trip_id, ()):
            if target_trip_id == trip_id:
                continue
            target_context = self.trip_contexts.get(target_trip_id)
            if target_context is not None:
                await target_context.update_context(event)

    def get_global_summary(self) -> Dict[str, Any]:
        """Return counts plus each trip's context summary."""
        return {
            "active_trips": len(self.trip_contexts),
            "global_constraints": len(self.global_constraints),
            "coordination_rules": len(self.coordination_rules),
            "trip_summaries": {
                trip_id: context.get_context_summary()
                for trip_id, context in self.trip_contexts.items()
            }
        }