# wanderlust.ai — src/wanderlust_ai/agents/flight_agent.py
# Uploaded by BlakeL ("Upload 115 files"), commit 3f9f85b (verified).
"""
Flight Agent - Specialized agent for flight search and booking.
This agent demonstrates inheritance from BaseAgent and shows how
specialized agents can extend the base functionality.
"""
import asyncio
from datetime import datetime, timezone
from typing import Dict, List, Any, Optional
from decimal import Decimal
from ..core.base_agent import BaseAgent
from ..models.flight_models import FlightOption
from ..utils.config import get_settings
class FlightAgent(BaseAgent):
    """
    Specialized agent for flight search and booking.

    This agent demonstrates inheritance by:
    - Extending BaseAgent functionality
    - Adding flight-specific methods
    - Implementing abstract methods
    - Adding specialized error handling
    - Using flight-specific data models

    Search results are memoized per agent instance in an unbounded dict keyed
    by the search parameters; booking confirmations are kept by
    ``booking_id``. Both caches are cleared on initialization and cleanup.
    """

    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize the Flight Agent.

        Args:
            api_key: API key for the flight service. When omitted, the
                ``flight_api_key`` attribute of the application settings is
                used if present.
        """
        super().__init__(
            name="FlightAgent",
            description="AI agent for flight search and booking",
            base_url="https://api.flights.example.com",  # Replace with real API
            timeout=30,
            max_retries=3,
            retry_delay=1.0,
        )

        # Flight-specific configuration
        self.settings = get_settings()
        # NOTE(review): the key is validated but never attached to the
        # requests made below (no auth header/param is passed to get/post);
        # confirm whether BaseAgent is expected to inject it.
        self.api_key = api_key or getattr(self.settings, "flight_api_key", None)
        if not self.api_key:
            # Only a warning here: construction succeeds, but
            # _initialize_agent() refuses to start without a key.
            self.logger.warning("No flight API key configured")

        # Flight-specific state
        self._search_cache: Dict[str, List[FlightOption]] = {}
        self._booking_cache: Dict[str, Dict[str, Any]] = {}

        # Flight-specific statistics
        self._flight_stats = {
            "searches_performed": 0,
            "flights_found": 0,
            "bookings_made": 0,
            "cache_hits": 0,
            "cache_misses": 0,
        }

    async def _initialize_agent(self) -> None:
        """
        Initialize flight-specific resources.

        Called by the parent class during initialization.

        Raises:
            ValueError: If no API key is configured or the API connectivity
                check fails (the original error is chained as the cause).
        """
        self.logger.info("Initializing Flight Agent...")

        if not self.api_key:
            raise ValueError("Flight API key is required")

        try:
            await self._test_api_connectivity()
        except Exception as e:
            self.logger.error(f"Flight API connectivity test failed: {e}")
            raise ValueError(f"API connectivity test failed: {e}") from e
        self.logger.info("Flight API connectivity test passed")

        # Start every (re)initialization from empty caches.
        self._search_cache.clear()
        self._booking_cache.clear()
        self.logger.info("Flight Agent initialized successfully")

    async def _cleanup_agent(self) -> None:
        """
        Cleanup flight-specific resources.

        Called by the parent class during cleanup. The caches are cleared
        even if persisting them raises.
        """
        self.logger.info("Cleaning up Flight Agent...")
        try:
            # Save cache to persistent storage (currently a no-op).
            await self._save_cache()
        finally:
            self._search_cache.clear()
            self._booking_cache.clear()
        self.logger.info("Flight Agent cleaned up successfully")

    async def _test_api_connectivity(self) -> None:
        """
        Check that the flight API answers ``GET /health`` with HTTP 200.

        Raises:
            RuntimeError: If the request itself fails or the status is not 200.
        """
        try:
            response = await self.get("/health")
        except Exception as e:
            raise RuntimeError(f"API connectivity test failed: {e}") from e
        # Checked outside the try so this error is not re-wrapped above.
        if response.status_code != 200:
            raise RuntimeError(f"API health check failed: {response.status_code}")

    async def _save_cache(self) -> None:
        """Save cache to persistent storage (placeholder; currently only logs)."""
        # In a real implementation, you might save to Redis or database
        self.logger.debug("Saving flight search cache...")

    @staticmethod
    def _as_utc(value: datetime) -> datetime:
        """Return *value* as an aware UTC datetime; naive values are assumed to be UTC."""
        if value.tzinfo is None or value.tzinfo.utcoffset(value) is None:
            return value.replace(tzinfo=timezone.utc)
        return value.astimezone(timezone.utc)

    @staticmethod
    def _parse_flight(flight_info: Dict[str, Any]) -> FlightOption:
        """
        Build a FlightOption from one flight record of an API response.

        Raises:
            KeyError: If a required field is missing. Other errors may come
                from ``Decimal`` or ``FlightOption`` validation.
        """
        return FlightOption(
            airline=flight_info["airline"],
            flight_number=flight_info["flight_number"],
            departure_city=flight_info["departure_city"],
            arrival_city=flight_info["arrival_city"],
            departure_time=flight_info["departure_time"],
            arrival_time=flight_info["arrival_time"],
            # Go through str() so a float price does not carry binary
            # rounding artifacts into the Decimal.
            price=Decimal(str(flight_info["price"])),
            duration_minutes=flight_info["duration_minutes"],
            stops=flight_info["stops"],
        )

    async def search_flights(
        self,
        departure_city: str,
        arrival_city: str,
        departure_date: datetime,
        return_date: Optional[datetime] = None,
        passengers: int = 1,
        class_preference: str = "Economy",
    ) -> List[FlightOption]:
        """
        Search for flights between two cities.

        Results are cached per (cities, departure day, return day,
        passengers, class); a repeated search returns the cached flights
        without calling the API.

        Args:
            departure_city: Departure city.
            arrival_city: Arrival city.
            departure_date: Departure date. Naive datetimes are treated as UTC.
            return_date: Return date (omit for one-way). Naive datetimes are
                treated as UTC.
            passengers: Number of passengers (at least 1).
            class_preference: Travel class preference.

        Returns:
            A new list of available flights. Records that fail to parse are
            skipped and logged.

        Raises:
            ValueError: If input validation fails.
            RuntimeError: If the API request or response handling fails.
        """
        if not departure_city or not arrival_city:
            raise ValueError("Departure and arrival cities are required")
        if passengers < 1:
            raise ValueError("At least one passenger is required")

        # Normalize to aware UTC so naive and aware inputs can be compared.
        departure_utc = self._as_utc(departure_date)
        return_utc = self._as_utc(return_date) if return_date is not None else None

        # Compare calendar days: a departure later today (or "now") is valid;
        # only strictly earlier days are rejected.
        if departure_utc.date() < datetime.now(timezone.utc).date():
            raise ValueError("Departure date cannot be in the past")
        if return_utc is not None and return_utc <= departure_utc:
            raise ValueError("Return date must be after departure date")

        # The return leg must be part of the key; otherwise one-way and
        # round-trip searches for the same outbound day would share results.
        return_key = return_date.date().isoformat() if return_date is not None else "oneway"
        cache_key = (
            f"{departure_city}_{arrival_city}_{departure_date.date()}_"
            f"{return_key}_{passengers}_{class_preference}"
        )

        cached = self._search_cache.get(cache_key)
        if cached is not None:
            self._flight_stats["cache_hits"] += 1
            self.logger.info(f"Cache hit for flight search: {cache_key}")
            # Hand out a copy so callers cannot mutate the cached list.
            return list(cached)

        self._flight_stats["cache_misses"] += 1
        self._flight_stats["searches_performed"] += 1

        search_params: Dict[str, Any] = {
            "departure_city": departure_city,
            "arrival_city": arrival_city,
            "departure_date": departure_date.isoformat(),
            "passengers": passengers,
            "class": class_preference,
        }
        if return_date is not None:
            search_params["return_date"] = return_date.isoformat()

        try:
            self.logger.info(f"Searching flights: {departure_city} -> {arrival_city}")
            response = await self.post("/search", json=search_params)
            flight_data = response.json()

            flights: List[FlightOption] = []
            for flight_info in flight_data.get("flights", []):
                try:
                    flights.append(self._parse_flight(flight_info))
                except Exception as e:
                    # Skip malformed records instead of failing the whole search.
                    self.logger.warning(f"Failed to parse flight data: {e}")
        except Exception as e:
            self.logger.error(f"Flight search failed: {e}")
            raise RuntimeError(f"Flight search failed: {e}") from e

        self._search_cache[cache_key] = flights
        self._flight_stats["flights_found"] += len(flights)
        self.logger.info(f"Found {len(flights)} flights")
        return list(flights)

    async def get_flight_details(self, flight_id: str) -> Optional[FlightOption]:
        """
        Get detailed information about a specific flight.

        Args:
            flight_id: Unique flight identifier. It is URL-escaped before
                being placed in the request path.

        Returns:
            Flight details, or None if the API answers 404.

        Raises:
            ValueError: If ``flight_id`` is empty or None.
            RuntimeError: If the request or response parsing fails.
        """
        from urllib.parse import quote

        if not flight_id:
            raise ValueError("flight_id is required")

        # Escape every reserved character (including '/') so an id cannot
        # redirect the request to another API path.
        safe_id = quote(str(flight_id), safe="")
        try:
            response = await self.get(f"/flights/{safe_id}")
            if response.status_code == 404:
                return None
            return self._parse_flight(response.json())
        except Exception as e:
            self.logger.error(f"Failed to get flight details: {e}")
            raise RuntimeError(f"Failed to get flight details: {e}") from e

    async def book_flight(
        self,
        flight_id: str,
        passenger_info: Dict[str, Any],
        payment_info: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Book a flight.

        Args:
            flight_id: Flight to book.
            passenger_info: Passenger information (sent as-is to the API).
            payment_info: Payment information (sent as-is; never logged here).

        Returns:
            The booking confirmation returned by the API. It must contain a
            ``booking_id`` key, which is used to cache it.

        Raises:
            RuntimeError: If the request fails or the confirmation lacks
                ``booking_id``.
        """
        booking_data = {
            "flight_id": flight_id,
            "passenger_info": passenger_info,
            "payment_info": payment_info,
        }
        try:
            self.logger.info(f"Booking flight: {flight_id}")
            # NOTE(review): if BaseAgent.post retries on failure
            # (max_retries=3), a timed-out booking could be submitted twice;
            # confirm the API is idempotent or disable retries here.
            response = await self.post("/bookings", json=booking_data)
            booking_confirmation = response.json()
            booking_id = booking_confirmation["booking_id"]
        except Exception as e:
            self.logger.error(f"Flight booking failed: {e}")
            raise RuntimeError(f"Flight booking failed: {e}") from e

        self._booking_cache[booking_id] = booking_confirmation
        self._flight_stats["bookings_made"] += 1
        self.logger.info(f"Flight booked successfully: {booking_id}")
        return booking_confirmation

    async def process_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process a flight-related request.

        Implements the abstract method from BaseAgent. Supported actions:

        - ``"search_flights"``: ``params`` are passed as keyword arguments to
          :meth:`search_flights`.
        - ``"get_flight_details"``: the id is read from the top-level
          ``"flight_id"`` key, falling back to ``params["flight_id"]``.
        - ``"book_flight"``: ``params`` are passed as keyword arguments to
          :meth:`book_flight`.

        Args:
            request_data: Request data containing ``action`` and parameters.

        Returns:
            A dict with ``status`` (``"success"`` or ``"error"``), ``action``,
            and either ``data`` (plus ``count`` for searches) or ``error``.
            Exceptions raised while handling the action are reported in this
            dict rather than propagated.
        """
        action = request_data.get("action")
        try:
            # ``params`` may be missing or explicitly null.
            params = request_data.get("params") or {}

            if action == "search_flights":
                flights = await self.search_flights(**params)
                return {
                    "status": "success",
                    "action": action,
                    "data": [flight.to_dict() for flight in flights],
                    "count": len(flights),
                }

            if action == "get_flight_details":
                flight_id = request_data.get("flight_id") or params.get("flight_id")
                flight = await self.get_flight_details(flight_id)
                if flight is None:
                    return {
                        "status": "error",
                        "action": action,
                        "error": "Flight not found",
                    }
                return {
                    "status": "success",
                    "action": action,
                    "data": flight.to_dict(),
                }

            if action == "book_flight":
                booking = await self.book_flight(**params)
                return {
                    "status": "success",
                    "action": action,
                    "data": booking,
                }

            return {
                "status": "error",
                "action": action,
                "error": f"Unknown action: {action}",
            }
        except Exception as e:
            self.logger.error(f"Request processing failed: {e}")
            return {
                "status": "error",
                "action": action,
                "error": str(e),
            }

    def get_flight_stats(self) -> Dict[str, Any]:
        """
        Get flight-specific statistics.

        Returns:
            The parent's ``get_stats()`` entries, plus a copy of the flight
            counters under ``flight_stats`` and the sizes of both caches.
        """
        return {
            **self.get_stats(),
            "flight_stats": self._flight_stats.copy(),
            "cache_size": len(self._search_cache),
            "booking_cache_size": len(self._booking_cache),
        }
# =============================================================================
# EXAMPLE USAGE
# =============================================================================
async def demonstrate_flight_agent():
    """
    Demonstrate the Flight Agent functionality.

    The agent points at a placeholder API, so start-up or the API calls are
    expected to fail. Failures are reported instead of propagating out of
    the demo.
    """
    # Local import keeps this demo-only dependency out of the module header.
    from datetime import timedelta

    print("✈️ Flight Agent Demonstration")
    print("=" * 50)

    # Use a date safely in the future so the search passes input validation
    # and actually reaches the API call. datetime.now() itself is already
    # "in the past" by the time it is validated.
    departure = datetime.now(timezone.utc) + timedelta(days=30)

    try:
        # Entering the context presumably runs agent initialization, which
        # raises without an API key or reachable API. TODO confirm in BaseAgent.
        async with FlightAgent() as agent:
            print(f"✅ Flight Agent created: {agent}")
            print(f"   Name: {agent.name}")
            print(f"   Description: {agent.description}")
            print(f"   Initialized: {agent._is_initialized}")

            # Test flight search (this would fail with real API, but shows the structure)
            try:
                flights = await agent.search_flights(
                    departure_city="New York",
                    arrival_city="Los Angeles",
                    departure_date=departure,
                    passengers=2,
                )
                print(f"   Found {len(flights)} flights")
            except Exception as e:
                print(f"   Flight search failed (expected): {e}")

            # Test request processing
            request_data = {
                "action": "search_flights",
                "params": {
                    "departure_city": "New York",
                    "arrival_city": "Los Angeles",
                    "departure_date": departure,
                    "passengers": 1,
                },
            }
            response = await agent.process_request(request_data)
            print(f"   Request response: {response['status']}")
            if "error" in response:
                print(f"   Request error: {response['error']}")

            # Get statistics
            stats = agent.get_flight_stats()
            print(f"   Flight stats: {stats['flight_stats']}")
    except Exception as e:
        print(f"❌ Flight Agent demonstration failed: {e}")
        return

    print("✅ Flight Agent demonstration completed")
if "__main__" == __name__:
    # Script entry point: build the demo coroutine and drive it to completion
    # on a fresh event loop.
    demo = demonstrate_flight_agent()
    asyncio.run(demo)