# File: wanderlust.ai/src/wanderlust_ai/core/distributed_state.py
# Hosting-page residue kept for provenance: "BlakeL's picture",
# "Upload 115 files", "3f9f85b verified".
"""
Distributed State Management for AI Travel Agents
This module implements sophisticated distributed state management patterns for AI systems,
enabling agents to share context, maintain consistency, and handle real-time updates
like experienced travel planners working together.
Key Features:
1. Shared TripContext with versioning and rollback
2. Context propagation with dependency tracking
3. Constraint satisfaction and validation
4. Real-time updates with conflict resolution
5. Persistent state for session recovery
6. Event-driven architecture for scalability
"""
import asyncio
import copy
import hashlib
import json
import pickle
import threading
import uuid
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from dataclasses import dataclass, field, asdict
from datetime import datetime, timedelta
from decimal import Decimal
from enum import Enum
from typing import Dict, List, Optional, Any, Set, Tuple, Union, Callable

from pydantic import BaseModel, Field, validator

from ..models.flight_models import FlightOption
from ..models.hotel_models import HotelOption
from ..models.poi_models import POI
class ContextEventType(str, Enum):
    """Kinds of events that may be submitted to a trip context.

    Members are ``str`` subclasses, so each one compares equal to (and
    serializes as) its lowercase wire value.
    """

    # State changes requested by agents or users
    BUDGET_UPDATE = "budget_update"
    PREFERENCE_CHANGE = "preference_change"
    CONSTRAINT_ADDED = "constraint_added"
    CONSTRAINT_REMOVED = "constraint_removed"
    AGENT_DECISION = "agent_decision"

    # Real-time market changes
    AVAILABILITY_CHANGE = "availability_change"
    PRICE_CHANGE = "price_change"

    # Plan-level operations
    PLAN_CHANGE = "plan_change"
    ROLLBACK_REQUEST = "rollback_request"
class ConstraintType(str, Enum):
    """Categories a ``Constraint`` can belong to.

    Members are ``str`` subclasses whose value is the lowercase category
    name used when a constraint is serialized.
    """

    BUDGET = "budget"
    TIME = "time"
    LOCATION = "location"
    AVAILABILITY = "availability"
    PREFERENCE = "preference"
    DEPENDENCY = "dependency"
    QUALITY = "quality"
    FLEXIBILITY = "flexibility"
class ConstraintSeverity(str, Enum):
    """How strictly a constraint is meant to be enforced."""

    # Must be satisfied
    HARD = "hard"
    # Preferred but not required
    SOFT = "soft"
    # Advisory only
    WARNING = "warning"
@dataclass
class Constraint:
    """Represents a constraint in the trip planning context."""

    # Unique identifier; used as the key when a context stores the constraint.
    constraint_id: str
    constraint_type: ConstraintType
    severity: ConstraintSeverity
    # Agent that issued the constraint and the agents it is addressed to.
    source_agent: str
    target_agents: List[str]
    # Free-form payload; its schema depends on ``constraint_type``.
    constraint_data: Dict[str, Any]
    created_at: datetime = field(default_factory=datetime.now)
    expires_at: Optional[datetime] = None
    version: int = 1
    is_active: bool = True

    def to_dict(self) -> Dict[str, Any]:
        """Convert constraint to dictionary for serialization.

        Enum fields are emitted as their string values and datetimes as
        ISO-8601 strings; ``expires_at`` is ``None`` when unset.
        """
        return {
            "constraint_id": self.constraint_id,
            "constraint_type": self.constraint_type.value,
            "severity": self.severity.value,
            "source_agent": self.source_agent,
            "target_agents": self.target_agents,
            "constraint_data": self.constraint_data,
            "created_at": self.created_at.isoformat(),
            "expires_at": self.expires_at.isoformat() if self.expires_at else None,
            "version": self.version,
            "is_active": self.is_active
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Constraint':
        """Create constraint from dictionary.

        Args:
            data: Mapping in the shape produced by ``to_dict``. The keys
                ``constraint_id``, ``constraint_type``, ``severity``,
                ``source_agent``, ``target_agents`` and ``constraint_data``
                are required; ``created_at``, ``expires_at``, ``version``
                and ``is_active`` fall back to the dataclass defaults when
                missing.

        Returns:
            A new ``Constraint`` that owns its own ``target_agents`` list and
            ``constraint_data`` dict.

        Raises:
            KeyError: If a required key is missing.
            ValueError: If an enum value or ISO timestamp is invalid.
        """
        created_raw = data.get("created_at")
        expires_raw = data.get("expires_at")
        return cls(
            constraint_id=data["constraint_id"],
            constraint_type=ConstraintType(data["constraint_type"]),
            severity=ConstraintSeverity(data["severity"]),
            source_agent=data["source_agent"],
            # Copy the containers: the same dict is handed to several trip
            # contexts, and they mutate constraint_data in place.
            target_agents=list(data["target_agents"]),
            constraint_data=copy.deepcopy(data["constraint_data"]),
            created_at=datetime.fromisoformat(created_raw) if created_raw else datetime.now(),
            expires_at=datetime.fromisoformat(expires_raw) if expires_raw else None,
            version=data.get("version", 1),
            is_active=data.get("is_active", True)
        )
@dataclass
class ContextEvent:
    """An event submitted to a trip context to request a state update."""

    # Random UUID string generated per event.
    event_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    event_type: ContextEventType = ContextEventType.PLAN_CHANGE
    source_agent: str = "system"
    target_agents: List[str] = field(default_factory=list)
    # Payload; the keys that are read depend on ``event_type``.
    event_data: Dict[str, Any] = field(default_factory=dict)
    timestamp: datetime = field(default_factory=datetime.now)
    # 1 = highest, 5 = lowest
    priority: int = 1
    requires_acknowledgment: bool = False
    correlation_id: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a serializable dictionary view of the event.

        The event type is emitted as its string value and the timestamp as
        an ISO-8601 string; all other fields are passed through unchanged.
        """
        return dict(
            event_id=self.event_id,
            event_type=self.event_type.value,
            source_agent=self.source_agent,
            target_agents=self.target_agents,
            event_data=self.event_data,
            timestamp=self.timestamp.isoformat(),
            priority=self.priority,
            requires_acknowledgment=self.requires_acknowledgment,
            correlation_id=self.correlation_id,
        )
@dataclass
class ContextSnapshot:
    """A versioned, checksummed capture of trip context state."""

    # Random UUID string generated per snapshot.
    snapshot_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    version: int = 1
    timestamp: datetime = field(default_factory=datetime.now)
    context_data: Dict[str, Any] = field(default_factory=dict)
    constraints: List[Constraint] = field(default_factory=list)
    agent_states: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    # MD5 hex digest of the snapshot content; computed when left empty.
    checksum: str = ""

    def __post_init__(self):
        """Fill in ``checksum`` unless the caller supplied one."""
        if self.checksum:
            return
        self.checksum = self._calculate_checksum()

    def _calculate_checksum(self) -> str:
        """Return the MD5 hex digest of the snapshot's content.

        Only ``version``, ``context_data``, ``constraints`` and
        ``agent_states`` are hashed; ``snapshot_id`` and ``timestamp`` are
        not. Values JSON cannot encode are stringified via ``default=str``.
        """
        payload = {
            "version": self.version,
            "context_data": self.context_data,
            "constraints": [item.to_dict() for item in self.constraints],
            "agent_states": self.agent_states
        }
        # sort_keys keeps the digest independent of dict insertion order.
        encoded = json.dumps(payload, sort_keys=True, default=str).encode()
        digest = hashlib.md5(encoded)
        return digest.hexdigest()

    def to_dict(self) -> Dict[str, Any]:
        """Return a serializable dictionary view of the snapshot."""
        serialized_constraints = [item.to_dict() for item in self.constraints]
        return dict(
            snapshot_id=self.snapshot_id,
            version=self.version,
            timestamp=self.timestamp.isoformat(),
            context_data=self.context_data,
            constraints=serialized_constraints,
            agent_states=self.agent_states,
            checksum=self.checksum,
        )
class ContextObserver(ABC):
    """Interface implemented by anything that wants context notifications.

    All three hooks are abstract coroutines, so a concrete observer must
    override every one of them before it can be instantiated.
    """

    @abstractmethod
    async def on_context_update(self, event: ContextEvent, context_data: Dict[str, Any]) -> None:
        """Called with the triggering event and the context data."""

    @abstractmethod
    async def on_constraint_added(self, constraint: Constraint) -> None:
        """Called with a newly added constraint."""

    @abstractmethod
    async def on_constraint_removed(self, constraint_id: str) -> None:
        """Called with the id of a removed constraint."""
class TripContext:
    """
    Distributed trip context that maintains shared state across all agents.

    Updates are submitted as ``ContextEvent`` objects through
    ``update_context``; a background worker drains the queue and applies
    them one at a time under ``update_lock``.

    This class implements state management patterns including:
    - Versioned state snapshots for rollback
    - Event-driven updates with propagation
    - Constraint satisfaction and validation
    - Real-time updates with conflict resolution
    - Persistent state for session recovery

    NOTE: persistence (``save_state``, ``load_state``, ``_persist_snapshot``)
    is currently a placeholder that only prints a message.
    """

    def __init__(self, trip_id: str, initial_budget: Decimal = Decimal('2000.00')):
        """Create a context for ``trip_id`` with the given starting budget.

        Args:
            trip_id: Identifier of the trip this context belongs to.
            initial_budget: Used for both ``total_budget`` and
                ``remaining_budget``; stored as ``float``.
        """
        self.trip_id = trip_id
        self.version = 1
        self.created_at = datetime.now()
        self.last_updated = datetime.now()

        # Core context data
        self.context_data: Dict[str, Any] = {
            "trip_id": trip_id,
            "total_budget": float(initial_budget),
            "remaining_budget": float(initial_budget),
            "trip_type": "leisure",
            "priority": "experience_first",
            "origin": "",
            "destination": "",
            "travel_dates": {},
            "passengers": 1,
            "guests": 1,
            "interests": [],
            "preferences": {},
            "special_requirements": []
        }

        # State management
        self.constraints: Dict[str, Constraint] = {}
        self.agent_states: Dict[str, Dict[str, Any]] = {}
        self.snapshots: deque = deque(maxlen=50)  # Keep last 50 snapshots
        self.event_history: deque = deque(maxlen=1000)  # Keep last 1000 events

        # Observers and propagation
        self.observers: Dict[str, ContextObserver] = {}
        self.propagation_rules: Dict[str, List[str]] = {}
        self.dependency_graph: Dict[str, Set[str]] = defaultdict(set)

        # Real-time updates
        self.update_queue: asyncio.Queue = asyncio.Queue()
        self.update_lock = asyncio.Lock()
        self.pending_updates: Set[str] = set()

        # Persistence
        self.persistence_enabled = True
        self.auto_save_interval = 30  # seconds
        self._persistence_timer: Optional[asyncio.Task] = None

        # Task bookkeeping must exist before the first snapshot, because
        # _create_snapshot may schedule a persistence task.
        self._background_tasks: Set[asyncio.Task] = set()
        self._background_started = False

        # Create initial snapshot
        self._create_snapshot("initial_state")

        # Start background tasks
        self._start_background_tasks()

    def _track_task(self, task: asyncio.Task) -> None:
        """Hold a strong reference to ``task`` until it finishes.

        The event loop keeps only weak references to tasks, so an
        unreferenced task may be garbage-collected before it completes.
        """
        if not hasattr(self, '_background_tasks'):
            self._background_tasks = set()
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    def _start_background_tasks(self):
        """Start the update processor and auto-save loop, once.

        Does nothing if the tasks were already started. When no event loop
        is running, a message is printed and the start is retried on the
        next ``update_context`` call.
        """
        if not hasattr(self, '_background_tasks'):
            self._background_tasks = set()
        if getattr(self, '_background_started', False):
            return

        # Only start background tasks if we're in an async context
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            print("No running event loop - background tasks will be started when async context is available")
            return

        self._background_started = True

        # Start update processor
        self._track_task(loop.create_task(self._process_updates()))

        # Start auto-save if persistence enabled
        if self.persistence_enabled:
            self._track_task(loop.create_task(self._auto_save_loop()))

    async def _process_updates(self):
        """Process queued context updates until the task is cancelled."""
        while True:
            try:
                event = await self.update_queue.get()
                try:
                    await self._handle_event(event)
                finally:
                    # Balance every successful get(), even when handling is
                    # interrupted, so update_queue.join() cannot hang.
                    self.update_queue.task_done()
            except Exception as e:
                print(f"Error processing update: {e}")
                await asyncio.sleep(1)

    async def _auto_save_loop(self):
        """Auto-save context state every ``auto_save_interval`` seconds."""
        while True:
            try:
                await asyncio.sleep(self.auto_save_interval)
                await self.save_state()
            except Exception as e:
                print(f"Error in auto-save: {e}")

    async def _handle_event(self, event: ContextEvent):
        """Validate, apply, propagate and record one event under the lock.

        Errors are printed rather than raised so one bad event cannot stop
        the update processor.
        """
        async with self.update_lock:
            try:
                # Validate event
                if not self._validate_event(event):
                    print(f"Invalid event: {event.event_id}")
                    return

                # Apply event
                await self._apply_event(event)

                # Propagate to observers
                await self._propagate_event(event)

                # Record in history
                self.event_history.append(event)

                # Create snapshot if significant change
                if self._is_significant_change(event):
                    self._create_snapshot(f"event_{event.event_id}")
            except Exception as e:
                print(f"Error handling event {event.event_id}: {e}")

    def _validate_event(self, event: ContextEvent) -> bool:
        """Return True if ``event`` may be applied.

        Events whose timestamp is more than one hour old are rejected.
        """
        # Check if event is not too old
        if datetime.now() - event.timestamp > timedelta(hours=1):
            return False

        # Check if event conflicts with existing constraints
        if event.event_type == ContextEventType.CONSTRAINT_ADDED:
            return self._validate_constraint_addition(event.event_data)

        return True

    def _validate_constraint_addition(self, constraint_data: Dict[str, Any]) -> bool:
        """Validate that a new constraint doesn't conflict with existing ones.

        Placeholder: no conflict rules are implemented yet, so every
        addition is accepted.
        """
        # Add validation logic here based on constraint type
        return True

    async def _apply_event(self, event: ContextEvent):
        """Apply an event to the context state and bump the version.

        Event types without a handler (e.g. ``PLAN_CHANGE``) still update
        ``last_updated`` and increment ``version``.
        """
        handlers = {
            ContextEventType.BUDGET_UPDATE: self._apply_budget_update,
            ContextEventType.PREFERENCE_CHANGE: self._apply_preference_change,
            ContextEventType.CONSTRAINT_ADDED: self._apply_constraint_addition,
            ContextEventType.CONSTRAINT_REMOVED: self._apply_constraint_removal,
            ContextEventType.AGENT_DECISION: self._apply_agent_decision,
            ContextEventType.PRICE_CHANGE: self._apply_price_change,
            ContextEventType.AVAILABILITY_CHANGE: self._apply_availability_change,
            ContextEventType.ROLLBACK_REQUEST: self._apply_rollback,
        }
        handler = handlers.get(event.event_type)
        if handler is not None:
            await handler(event)

        self.last_updated = datetime.now()
        self.version += 1

    async def _apply_budget_update(self, event: ContextEvent):
        """Apply ``event_data["budget_data"]`` to the budget fields."""
        budget_data = event.event_data.get("budget_data", {})

        if "total_budget" in budget_data:
            self.context_data["total_budget"] = budget_data["total_budget"]
        if "remaining_budget" in budget_data:
            self.context_data["remaining_budget"] = budget_data["remaining_budget"]

        # Update budget constraints
        await self._update_budget_constraints()

    async def _apply_preference_change(self, event: ContextEvent):
        """Merge ``event_data["preference_data"]`` into the preferences."""
        preference_data = event.event_data.get("preference_data", {})

        for key, value in preference_data.items():
            self.context_data["preferences"][key] = value

        # Update preference constraints
        await self._update_preference_constraints()

    async def _apply_constraint_addition(self, event: ContextEvent):
        """Store the constraint in ``event_data["constraint_data"]``.

        Also records a source-agent -> target-agent edge in the dependency
        graph for each target.
        """
        constraint_data = event.event_data.get("constraint_data", {})
        constraint = Constraint.from_dict(constraint_data)

        self.constraints[constraint.constraint_id] = constraint

        # Update dependency graph
        for target_agent in constraint.target_agents:
            self.dependency_graph[constraint.source_agent].add(target_agent)

    async def _apply_constraint_removal(self, event: ContextEvent):
        """Deactivate the constraint named by ``event_data["constraint_id"]``.

        The constraint stays in ``self.constraints`` with
        ``is_active=False``; its dependency-graph edges are removed.
        """
        constraint_id = event.event_data.get("constraint_id")

        if constraint_id in self.constraints:
            constraint = self.constraints[constraint_id]
            constraint.is_active = False

            # Update dependency graph
            for target_agent in constraint.target_agents:
                self.dependency_graph[constraint.source_agent].discard(target_agent)

    async def _apply_agent_decision(self, event: ContextEvent):
        """Merge ``event_data["decision_data"]`` into the source agent's state."""
        agent_id = event.source_agent
        decision_data = event.event_data.get("decision_data", {})

        # Update agent state
        if agent_id not in self.agent_states:
            self.agent_states[agent_id] = {}

        self.agent_states[agent_id].update(decision_data)
        self.agent_states[agent_id]["last_decision"] = event.timestamp.isoformat()

        # Update context data based on decision
        await self._update_context_from_decision(agent_id, decision_data)

    async def _apply_price_change(self, event: ContextEvent):
        """Update the price of matching selected options.

        An option matches when its ``id`` equals
        ``price_data["option_id"]``.
        """
        price_data = event.event_data.get("price_data", {})

        # Update relevant agent states
        for agent_id, agent_state in self.agent_states.items():
            if "selected_options" in agent_state:
                for option in agent_state["selected_options"]:
                    if option.get("id") == price_data.get("option_id"):
                        option["price"] = price_data["new_price"]
                        option["price_updated_at"] = event.timestamp.isoformat()

        # Trigger constraint re-evaluation
        await self._reevaluate_price_constraints(price_data)

    async def _apply_availability_change(self, event: ContextEvent):
        """Update the availability flag of matching selected options."""
        availability_data = event.event_data.get("availability_data", {})

        # Update relevant agent states
        for agent_id, agent_state in self.agent_states.items():
            if "selected_options" in agent_state:
                for option in agent_state["selected_options"]:
                    if option.get("id") == availability_data.get("option_id"):
                        option["available"] = availability_data["available"]
                        option["availability_updated_at"] = event.timestamp.isoformat()

        # Trigger constraint re-evaluation
        await self._reevaluate_availability_constraints(availability_data)

    async def _apply_rollback(self, event: ContextEvent):
        """Roll back to ``event_data["target_version"]`` when one is given."""
        target_version = event.event_data.get("target_version")
        if target_version is not None:
            await self.rollback_to_version(target_version)

    async def _update_context_from_decision(self, agent_id: str, decision_data: Dict[str, Any]):
        """Copy an agent's selection into ``context_data``.

        Only ``flight_agent``, ``hotel_agent`` and ``poi_agent`` are
        recognised. Selections exposing ``to_dict`` are stored in dictionary
        form; anything else is stored as given.
        """
        if agent_id == "flight_agent":
            if "selected_flight" in decision_data:
                flight = decision_data["selected_flight"]
                self.context_data["selected_flight"] = flight.to_dict() if hasattr(flight, 'to_dict') else flight

                # Update travel dates if not set
                if not self.context_data.get("travel_dates"):
                    self.context_data["travel_dates"] = {
                        "departure": flight.departure_time.isoformat() if hasattr(flight, 'departure_time') else None,
                        "return": flight.arrival_time.isoformat() if hasattr(flight, 'arrival_time') else None
                    }

        elif agent_id == "hotel_agent":
            if "selected_hotel" in decision_data:
                hotel = decision_data["selected_hotel"]
                self.context_data["selected_hotel"] = hotel.to_dict() if hasattr(hotel, 'to_dict') else hotel

        elif agent_id == "poi_agent":
            if "selected_pois" in decision_data:
                pois = decision_data["selected_pois"]
                self.context_data["selected_pois"] = [
                    poi.to_dict() if hasattr(poi, 'to_dict') else poi
                    for poi in pois
                ]

    async def _update_budget_constraints(self):
        """Reset ``max_budget`` on active budget constraints.

        Each active BUDGET constraint that already has a ``max_budget`` key
        is set to 80% of the current remaining budget.
        """
        remaining_budget = self.context_data.get("remaining_budget", 0)

        # Update existing budget constraints
        for constraint in self.constraints.values():
            if constraint.constraint_type == ConstraintType.BUDGET and constraint.is_active:
                if "max_budget" in constraint.constraint_data:
                    constraint.constraint_data["max_budget"] = remaining_budget * 0.8  # 80% of remaining

    async def _update_preference_constraints(self):
        """Merge current preferences into active preference constraints."""
        preferences = self.context_data.get("preferences", {})

        # Update existing preference constraints
        for constraint in self.constraints.values():
            if constraint.constraint_type == ConstraintType.PREFERENCE and constraint.is_active:
                constraint.constraint_data.update(preferences)

    async def _reevaluate_price_constraints(self, price_data: Dict[str, Any]):
        """Notify agents that selected the option whose price changed."""
        affected_agents = set()

        # Find agents affected by price change
        for agent_id, agent_state in self.agent_states.items():
            if "selected_options" in agent_state:
                for option in agent_state["selected_options"]:
                    if option.get("id") == price_data.get("option_id"):
                        affected_agents.add(agent_id)

        # Trigger constraint re-evaluation for affected agents
        for agent_id in affected_agents:
            await self._notify_constraint_change(agent_id, "price_change")

    async def _reevaluate_availability_constraints(self, availability_data: Dict[str, Any]):
        """Notify agents that selected the option whose availability changed."""
        affected_agents = set()

        # Find agents affected by availability change
        for agent_id, agent_state in self.agent_states.items():
            if "selected_options" in agent_state:
                for option in agent_state["selected_options"]:
                    if option.get("id") == availability_data.get("option_id"):
                        affected_agents.add(agent_id)

        # Trigger constraint re-evaluation for affected agents
        for agent_id in affected_agents:
            await self._notify_constraint_change(agent_id, "availability_change")

    async def _notify_constraint_change(self, agent_id: str, change_type: str):
        """Send a system event carrying ``change_type`` to one observer.

        Does nothing when ``agent_id`` has no registered observer.
        """
        if agent_id in self.observers:
            observer = self.observers[agent_id]
            # NOTE(review): the notification reuses CONSTRAINT_ADDED as its
            # event type even though nothing was added - confirm intent.
            event = ContextEvent(
                event_type=ContextEventType.CONSTRAINT_ADDED,
                source_agent="system",
                target_agents=[agent_id],
                event_data={"change_type": change_type}
            )
            await observer.on_context_update(event, self.context_data)

    async def _propagate_event(self, event: ContextEvent):
        """Deliver ``event`` to every relevant registered observer.

        Recipients are the event's own targets, then the propagation rules
        for its source agent, then that agent's dependency-graph edges.
        Each recipient is notified at most once; observer errors are printed
        and do not stop delivery to the others.
        """
        # Copy so the event's (and its constraint's) own list is never mutated.
        recipients = list(event.target_agents or [])

        # Add agents based on propagation rules
        recipients.extend(self.propagation_rules.get(event.source_agent, []))

        # Add agents based on dependency graph (.get avoids creating entries)
        recipients.extend(self.dependency_graph.get(event.source_agent, ()))

        # Remove duplicates while keeping first-seen order
        for agent_id in dict.fromkeys(recipients):
            observer = self.observers.get(agent_id)
            if observer is None:
                continue
            try:
                await observer.on_context_update(event, self.context_data)
            except Exception as e:
                print(f"Error notifying observer {agent_id}: {e}")

    def _is_significant_change(self, event: ContextEvent) -> bool:
        """Determine if an event represents a significant change requiring a snapshot."""
        significant_events = {
            ContextEventType.BUDGET_UPDATE,
            ContextEventType.PREFERENCE_CHANGE,
            ContextEventType.CONSTRAINT_ADDED,
            ContextEventType.CONSTRAINT_REMOVED,
            ContextEventType.AGENT_DECISION
        }
        return event.event_type in significant_events

    def _create_snapshot(self, reason: str):
        """Append a deep-copied snapshot of the current state.

        Args:
            reason: Label for why the snapshot was taken (currently not
                stored on the snapshot).
        """
        # Deep copies keep the snapshot independent of later in-place changes
        # to nested dicts and to the live Constraint objects.
        snapshot = ContextSnapshot(
            version=self.version,
            context_data=copy.deepcopy(self.context_data),
            constraints=[copy.deepcopy(c) for c in self.constraints.values() if c.is_active],
            agent_states=copy.deepcopy(self.agent_states)
        )

        self.snapshots.append(snapshot)

        # Trigger persistence if enabled
        if self.persistence_enabled:
            # Schedule persistence task without blocking
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                # No running event loop, just log it
                print(f"Scheduling snapshot persistence for {snapshot.snapshot_id}")
            else:
                self._track_task(loop.create_task(self._persist_snapshot(snapshot)))

    async def _persist_snapshot(self, snapshot: ContextSnapshot):
        """Persist snapshot to storage (placeholder: only prints)."""
        try:
            # In a real implementation, this would save to database/file
            # For now, we'll just log it
            print(f"Persisting snapshot {snapshot.snapshot_id} (version {snapshot.version})")
        except Exception as e:
            print(f"Error persisting snapshot: {e}")

    # Public API Methods

    async def update_context(self, event: ContextEvent) -> bool:
        """Queue ``event`` for asynchronous processing.

        Returns:
            True if the event was queued, False if queuing failed. A True
            result does not mean the event has been applied yet.
        """
        try:
            # Start the workers now if the context was built without a loop.
            self._start_background_tasks()
            await self.update_queue.put(event)
            return True
        except Exception as e:
            print(f"Error queuing update: {e}")
            return False

    async def add_constraint(self, constraint: Constraint) -> bool:
        """Queue a CONSTRAINT_ADDED event for ``constraint``.

        HARD constraints are queued with priority 1, all others with 3.
        """
        event = ContextEvent(
            event_type=ContextEventType.CONSTRAINT_ADDED,
            source_agent=constraint.source_agent,
            target_agents=constraint.target_agents,
            event_data={"constraint_data": constraint.to_dict()},
            priority=1 if constraint.severity == ConstraintSeverity.HARD else 3
        )

        return await self.update_context(event)

    async def remove_constraint(self, constraint_id: str, source_agent: str) -> bool:
        """Queue a CONSTRAINT_REMOVED event for a known constraint.

        Returns:
            False if ``constraint_id`` is not currently stored, otherwise
            the result of queuing the event.
        """
        if constraint_id not in self.constraints:
            return False

        event = ContextEvent(
            event_type=ContextEventType.CONSTRAINT_REMOVED,
            source_agent=source_agent,
            event_data={"constraint_id": constraint_id}
        )

        return await self.update_context(event)

    def get_constraints_for_agent(self, agent_id: str, constraint_type: Optional[ConstraintType] = None) -> List[Constraint]:
        """Return active constraints that target or originate from ``agent_id``.

        Args:
            agent_id: Agent to look up.
            constraint_type: When given, only constraints of this type are
                returned.
        """
        return [
            constraint
            for constraint in self.constraints.values()
            if constraint.is_active
            and (agent_id in constraint.target_agents or constraint.source_agent == agent_id)
            and (constraint_type is None or constraint.constraint_type == constraint_type)
        ]

    def validate_agent_decision(self, agent_id: str, decision_data: Dict[str, Any]) -> Tuple[bool, List[str]]:
        """Validate an agent's decision against its HARD constraints.

        Returns:
            ``(ok, violations)`` where ``ok`` is True when no hard
            constraint is violated and ``violations`` lists one message per
            violated constraint.
        """
        violations = []
        constraints = self.get_constraints_for_agent(agent_id)

        for constraint in constraints:
            if constraint.severity == ConstraintSeverity.HARD:
                # The helper returns True when the constraint is satisfied.
                if not self._check_constraint_violation(constraint, decision_data):
                    violations.append(f"Hard constraint violated: {constraint.constraint_id}")

        return len(violations) == 0, violations

    def _check_constraint_violation(self, constraint: Constraint, decision_data: Dict[str, Any]) -> bool:
        """Return True when ``decision_data`` SATISFIES ``constraint``.

        Despite the name, False means the constraint is violated. Only
        BUDGET constraints are actually checked (``cost`` against
        ``max_budget``); every other type is treated as satisfied.
        """
        if constraint.constraint_type == ConstraintType.BUDGET:
            cost = decision_data.get("cost", 0)
            max_budget = constraint.constraint_data.get("max_budget", float('inf'))
            return cost <= max_budget

        elif constraint.constraint_type == ConstraintType.TIME:
            # Add time constraint validation logic
            return True

        elif constraint.constraint_type == ConstraintType.LOCATION:
            # Add location constraint validation logic
            return True

        return True

    async def rollback_to_version(self, target_version: int) -> bool:
        """Restore the newest snapshot whose version is <= ``target_version``.

        Returns:
            True on success; False if no such snapshot exists or an error
            occurred (the error is printed).
        """
        try:
            # Find target snapshot
            target_snapshot = next(
                (snap for snap in reversed(self.snapshots) if snap.version <= target_version),
                None
            )

            if target_snapshot is None:
                return False

            # Restore deep copies so later edits cannot corrupt the snapshot
            # and the same version can be restored again.
            self.context_data = copy.deepcopy(target_snapshot.context_data)
            self.agent_states = copy.deepcopy(target_snapshot.agent_states)

            # Restore constraints
            self.constraints.clear()
            for stored in target_snapshot.constraints:
                constraint = copy.deepcopy(stored)
                self.constraints[constraint.constraint_id] = constraint

            # Update version
            self.version = target_snapshot.version
            self.last_updated = datetime.now()

            # Create rollback event
            event = ContextEvent(
                event_type=ContextEventType.ROLLBACK_REQUEST,
                source_agent="system",
                event_data={"target_version": target_version, "rollback_reason": "user_request"}
            )

            await self._propagate_event(event)

            return True

        except Exception as e:
            print(f"Error during rollback: {e}")
            return False

    def register_observer(self, agent_id: str, observer: ContextObserver):
        """Register ``observer`` under ``agent_id``, replacing any previous one."""
        self.observers[agent_id] = observer

    def unregister_observer(self, agent_id: str):
        """Remove the observer registered under ``agent_id``, if any."""
        if agent_id in self.observers:
            del self.observers[agent_id]

    def set_propagation_rule(self, source_agent: str, target_agents: List[str]):
        """Set which agents also receive events sent by ``source_agent``."""
        self.propagation_rules[source_agent] = target_agents

    async def save_state(self) -> bool:
        """Save current state to persistent storage (placeholder).

        Returns:
            False when persistence is disabled or an error occurred, True
            otherwise.
        """
        try:
            if not self.persistence_enabled:
                return False

            # Create current snapshot; a real backend would write this out.
            snapshot = ContextSnapshot(
                version=self.version,
                context_data=self.context_data,
                constraints=list(self.constraints.values()),
                agent_states=self.agent_states
            )

            # In a real implementation, save to database/file
            # For now, we'll just log it
            print(f"Saving state for trip {self.trip_id} (version {self.version})")

            return True

        except Exception as e:
            print(f"Error saving state: {e}")
            return False

    async def load_state(self, trip_id: str) -> bool:
        """Load state from persistent storage (placeholder: only prints)."""
        try:
            # In a real implementation, load from database/file
            # For now, we'll just log it
            print(f"Loading state for trip {trip_id}")
            return True

        except Exception as e:
            print(f"Error loading state: {e}")
            return False

    def get_context_summary(self) -> Dict[str, Any]:
        """Get a summary of the current context state."""
        return {
            "trip_id": self.trip_id,
            "version": self.version,
            "created_at": self.created_at.isoformat(),
            "last_updated": self.last_updated.isoformat(),
            "total_budget": self.context_data.get("total_budget", 0),
            "remaining_budget": self.context_data.get("remaining_budget", 0),
            "active_constraints": len([c for c in self.constraints.values() if c.is_active]),
            "registered_agents": list(self.observers.keys()),
            "snapshots_count": len(self.snapshots),
            "events_count": len(self.event_history)
        }

    async def cleanup(self):
        """Cancel background tasks, wait for them, then save final state."""
        # Work on a copy: done-callbacks shrink the set as tasks finish.
        tasks = list(self._background_tasks)

        # Cancel background tasks
        for task in tasks:
            task.cancel()

        # Wait for tasks to complete
        await asyncio.gather(*tasks, return_exceptions=True)

        # Save final state
        if self.persistence_enabled:
            await self.save_state()
class DistributedStateManager:
    """
    Manages multiple TripContext instances and provides coordination between them.

    This class implements patterns for:
    - Multi-trip coordination
    - Cross-trip constraint propagation
    - Global state synchronization
    - Resource management
    """

    def __init__(self):
        # Contexts keyed by trip id.
        self.trip_contexts: Dict[str, TripContext] = {}
        # Constraints applied to every context, keyed by constraint id.
        self.global_constraints: Dict[str, Constraint] = {}
        # Source trip id -> trip ids that receive its synchronized events.
        self.coordination_rules: Dict[str, List[str]] = {}
        self.global_observers: List[ContextObserver] = []

    async def create_trip_context(self, trip_id: str, initial_budget: Decimal = Decimal('2000.00')) -> TripContext:
        """Create and register a new trip context.

        If a context already exists under ``trip_id`` it is cleaned up
        before being replaced, so its background tasks do not keep running.
        All current global constraints are queued on the new context.

        Returns:
            The newly created context.
        """
        previous = self.trip_contexts.pop(trip_id, None)
        if previous is not None:
            await previous.cleanup()

        context = TripContext(trip_id, initial_budget)
        self.trip_contexts[trip_id] = context

        # Apply global constraints
        for constraint in self.global_constraints.values():
            await context.add_constraint(constraint)

        return context

    async def remove_trip_context(self, trip_id: str) -> bool:
        """Remove a trip context and cleanup resources.

        Returns:
            False if ``trip_id`` is unknown, True after the context has been
            cleaned up and unregistered.
        """
        context = self.trip_contexts.get(trip_id)
        if context is None:
            return False

        await context.cleanup()
        del self.trip_contexts[trip_id]
        return True

    def get_trip_context(self, trip_id: str) -> Optional[TripContext]:
        """Get a trip context by ID, or None if it does not exist."""
        return self.trip_contexts.get(trip_id)

    async def add_global_constraint(self, constraint: Constraint) -> bool:
        """Add a constraint that applies to all trips.

        The constraint is stored for future contexts and queued on every
        existing one. Always returns True.
        """
        self.global_constraints[constraint.constraint_id] = constraint

        # Apply to all existing contexts
        for context in self.trip_contexts.values():
            await context.add_constraint(constraint)

        return True

    def set_coordination_rule(self, source_trip: str, target_trips: List[str]):
        """Set which trips receive events synchronized from ``source_trip``."""
        self.coordination_rules[source_trip] = target_trips

    async def synchronize_contexts(self, trip_id: str, event: ContextEvent):
        """Queue ``event`` on every trip coordinated with ``trip_id``.

        Target ids without a registered context are skipped; nothing
        happens when ``trip_id`` has no coordination rule.
        """
        for target_trip_id in self.coordination_rules.get(trip_id, []):
            target_context = self.trip_contexts.get(target_trip_id)
            if target_context is not None:
                await target_context.update_context(event)

    def get_global_summary(self) -> Dict[str, Any]:
        """Get a summary of all trip contexts."""
        return {
            "active_trips": len(self.trip_contexts),
            "global_constraints": len(self.global_constraints),
            "coordination_rules": len(self.coordination_rules),
            "trip_summaries": {
                trip_id: context.get_context_summary()
                for trip_id, context in self.trip_contexts.items()
            }
        }