Spaces:
Sleeping
Sleeping
| """ | |
| Flight Agent - Specialized agent for flight search and booking. | |
| This agent demonstrates inheritance from BaseAgent and shows how | |
| specialized agents can extend the base functionality. | |
| """ | |
| import asyncio | |
| from datetime import datetime, timezone | |
| from typing import Dict, List, Any, Optional | |
| from decimal import Decimal | |
| from ..core.base_agent import BaseAgent | |
| from ..models.flight_models import FlightOption | |
| from ..utils.config import get_settings | |
class FlightAgent(BaseAgent):
    """
    Specialized agent for flight search and booking.

    This agent demonstrates inheritance by:
    - Extending BaseAgent functionality
    - Adding flight-specific methods
    - Implementing abstract methods
    - Adding specialized error handling
    - Using flight-specific data models
    """

    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize the Flight Agent.

        Args:
            api_key: API key for flight service (optional, will use settings if not provided)
        """
        # Call parent constructor
        super().__init__(
            name="FlightAgent",
            description="AI agent for flight search and booking",
            base_url="https://api.flights.example.com",  # Replace with real API
            timeout=30,
            max_retries=3,
            retry_delay=1.0,
        )

        # Flight-specific configuration; an explicit key takes precedence over settings.
        self.settings = get_settings()
        self.api_key = api_key or getattr(self.settings, "flight_api_key", None)
        if not self.api_key:
            # Only a warning here: _initialize_agent is where a missing key is fatal.
            self.logger.warning("No flight API key configured")

        # Flight-specific state (in-memory only)
        self._search_cache: Dict[str, List[FlightOption]] = {}
        self._booking_cache: Dict[str, Dict[str, Any]] = {}

        # Flight-specific statistics
        self._flight_stats = {
            "searches_performed": 0,
            "flights_found": 0,
            "bookings_made": 0,
            "cache_hits": 0,
            "cache_misses": 0,
        }

    async def _initialize_agent(self) -> None:
        """
        Initialize flight-specific resources.

        This method is called by the parent class during initialization.

        Raises:
            ValueError: If no API key is configured or the API connectivity
                test fails.
        """
        self.logger.info("Initializing Flight Agent...")

        # Validate API key
        if not self.api_key:
            raise ValueError("Flight API key is required")

        # Test API connectivity
        try:
            await self._test_api_connectivity()
        except Exception as e:
            self.logger.error(f"Flight API connectivity test failed: {e}")
            # ValueError is kept for existing callers; the cause is chained.
            raise ValueError(f"API connectivity test failed: {e}") from e
        self.logger.info("Flight API connectivity test passed")

        # Start from empty caches
        self._search_cache.clear()
        self._booking_cache.clear()

        self.logger.info("Flight Agent initialized successfully")

    async def _cleanup_agent(self) -> None:
        """
        Cleanup flight-specific resources.

        This method is called by the parent class during cleanup.
        """
        self.logger.info("Cleaning up Flight Agent...")

        # Save cache to persistent storage (if needed)
        await self._save_cache()

        # Clear caches
        self._search_cache.clear()
        self._booking_cache.clear()

        self.logger.info("Flight Agent cleaned up successfully")

    async def _test_api_connectivity(self) -> None:
        """
        Check that the flight API's ``/health`` endpoint answers with HTTP 200.

        Raises:
            RuntimeError: If the request fails or returns a non-200 status.
        """
        try:
            response = await self.get("/health")
        except Exception as e:
            raise RuntimeError(f"Could not reach flight API: {e}") from e
        # Checked outside the try so this error is not caught and re-wrapped above.
        if response.status_code != 200:
            raise RuntimeError(f"API health check failed: {response.status_code}")

    async def _save_cache(self) -> None:
        """Save cache to persistent storage (currently a logging-only placeholder)."""
        # In a real implementation, you might save to Redis or database
        self.logger.debug("Saving flight search cache...")
        # Implementation would go here

    @staticmethod
    def _ensure_success(response: Any) -> None:
        """
        Raise RuntimeError unless ``response.status_code`` is in the 2xx range.

        This keeps error payloads from being parsed (and, for searches, cached)
        as if they were valid results.
        """
        status = response.status_code
        if not 200 <= status < 300:
            raise RuntimeError(f"Unexpected HTTP status {status}")

    @staticmethod
    def _parse_flight(flight_info: Dict[str, Any]) -> FlightOption:
        """
        Convert one flight record from the API into a FlightOption.

        Raises:
            KeyError: If a required field is missing. Other exceptions may come
                from the Decimal conversion or FlightOption construction.
        """
        return FlightOption(
            airline=flight_info["airline"],
            flight_number=flight_info["flight_number"],
            departure_city=flight_info["departure_city"],
            arrival_city=flight_info["arrival_city"],
            departure_time=flight_info["departure_time"],
            arrival_time=flight_info["arrival_time"],
            # Go through str() so float prices don't carry binary noise into Decimal.
            price=Decimal(str(flight_info["price"])),
            duration_minutes=flight_info["duration_minutes"],
            stops=flight_info["stops"],
        )

    async def search_flights(
        self,
        departure_city: str,
        arrival_city: str,
        departure_date: datetime,
        return_date: Optional[datetime] = None,
        passengers: int = 1,
        class_preference: str = "Economy",
    ) -> List[FlightOption]:
        """
        Search for flights between two cities.

        Results are cached in memory per route, departure day, return day,
        passenger count and class. A cache hit skips the API call.

        Args:
            departure_city: Departure city
            arrival_city: Arrival city
            departure_date: Departure date. A naive datetime is compared with
                local "now"; an aware one with "now" in its own timezone.
            return_date: Return date (optional for one-way)
            passengers: Number of passengers (must be at least 1)
            class_preference: Travel class preference

        Returns:
            List of available flights. This is a fresh list, so mutating it
            does not affect the cache.

        Raises:
            ValueError: If input validation fails
            RuntimeError: If the API request or response handling fails
        """
        # Validate inputs
        if not departure_city or not arrival_city:
            raise ValueError("Departure and arrival cities are required")
        # Take "now" in the same naive/aware flavour as the input; comparing a
        # naive datetime with an aware one would raise TypeError.
        if departure_date < datetime.now(departure_date.tzinfo):
            raise ValueError("Departure date cannot be in the past")
        if return_date is not None and return_date <= departure_date:
            raise ValueError("Return date must be after departure date")
        if passengers < 1:
            raise ValueError("At least one passenger is required")

        # The return day is part of the key: without it, one-way and round-trip
        # searches for the same outbound day would share cached results.
        return_key = return_date.date() if return_date is not None else "oneway"
        cache_key = (
            f"{departure_city}_{arrival_city}_{departure_date.date()}_"
            f"{return_key}_{passengers}_{class_preference}"
        )

        # Check cache first (stored values are always lists, never None)
        cached = self._search_cache.get(cache_key)
        if cached is not None:
            self._flight_stats["cache_hits"] += 1
            self.logger.info(f"Cache hit for flight search: {cache_key}")
            return list(cached)

        self._flight_stats["cache_misses"] += 1
        self._flight_stats["searches_performed"] += 1

        try:
            # Prepare search parameters
            search_params: Dict[str, Any] = {
                "departure_city": departure_city,
                "arrival_city": arrival_city,
                "departure_date": departure_date.isoformat(),
                "passengers": passengers,
                "class": class_preference,
            }
            if return_date is not None:
                search_params["return_date"] = return_date.isoformat()

            # Make API request
            self.logger.info(f"Searching flights: {departure_city} → {arrival_city}")
            response = await self.post("/search", json=search_params)
            self._ensure_success(response)

            # Parse response
            flight_data = response.json()
            flights: List[FlightOption] = []
            for flight_info in flight_data.get("flights", []):
                try:
                    flights.append(self._parse_flight(flight_info))
                except Exception as e:
                    # Best effort: one malformed record should not sink the whole search.
                    self.logger.warning(f"Failed to parse flight data: {e}")
        except Exception as e:
            self.logger.error(f"Flight search failed: {e}")
            raise RuntimeError(f"Flight search failed: {e}") from e

        # Cache a copy so callers mutating the returned list can't corrupt the cache.
        self._search_cache[cache_key] = list(flights)
        self._flight_stats["flights_found"] += len(flights)

        self.logger.info(f"Found {len(flights)} flights")
        return flights

    async def get_flight_details(self, flight_id: str) -> Optional[FlightOption]:
        """
        Get detailed information about a specific flight.

        Args:
            flight_id: Unique flight identifier

        Returns:
            Flight details or None if the API answers 404

        Raises:
            RuntimeError: If the request fails, returns another non-2xx status,
                or the payload cannot be parsed.
        """
        try:
            response = await self.get(f"/flights/{flight_id}")
            if response.status_code == 404:
                return None
            self._ensure_success(response)
            return self._parse_flight(response.json())
        except Exception as e:
            self.logger.error(f"Failed to get flight details: {e}")
            raise RuntimeError(f"Failed to get flight details: {e}") from e

    async def book_flight(
        self,
        flight_id: str,
        passenger_info: Dict[str, Any],
        payment_info: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Book a flight.

        Args:
            flight_id: Flight to book
            passenger_info: Passenger information
            payment_info: Payment information (sent to the API, never logged)

        Returns:
            Booking confirmation as returned by the API; it must contain
            a ``"booking_id"`` key.

        Raises:
            RuntimeError: If the request fails, returns a non-2xx status, or the
                confirmation lacks a ``"booking_id"``.
        """
        try:
            booking_data = {
                "flight_id": flight_id,
                "passenger_info": passenger_info,
                "payment_info": payment_info,
            }

            self.logger.info(f"Booking flight: {flight_id}")
            response = await self.post("/bookings", json=booking_data)
            self._ensure_success(response)

            booking_confirmation = response.json()
            booking_id = booking_confirmation["booking_id"]

            # Cache booking
            self._booking_cache[booking_id] = booking_confirmation
        except Exception as e:
            self.logger.error(f"Flight booking failed: {e}")
            raise RuntimeError(f"Flight booking failed: {e}") from e

        self._flight_stats["bookings_made"] += 1
        self.logger.info(f"Flight booked successfully: {booking_id}")
        return booking_confirmation

    async def process_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process a flight-related request.

        This method implements the abstract method from BaseAgent.

        Args:
            request_data: Dict with an ``"action"`` key. Other keys depend on
                the action:

                - ``search_flights`` and ``book_flight`` take their keyword
                  arguments from ``"params"``.
                - ``get_flight_details`` reads ``"flight_id"`` from the top
                  level, falling back to ``params["flight_id"]``.

        Returns:
            Response dict with ``"status"`` ("success"/"error") and ``"action"``,
            plus ``"data"`` (and ``"count"`` for searches) or ``"error"``.
            Exceptions are reported in the dict rather than raised.
        """
        action = request_data.get("action")

        try:
            # A missing or null "params" means "no keyword arguments".
            params = request_data.get("params") or {}

            if action == "search_flights":
                flights = await self.search_flights(**params)
                return {
                    "status": "success",
                    "action": action,
                    "data": [flight.to_dict() for flight in flights],
                    "count": len(flights),
                }

            elif action == "get_flight_details":
                flight_id = request_data.get("flight_id")
                if flight_id is None:
                    flight_id = params.get("flight_id")
                if not flight_id:
                    return {
                        "status": "error",
                        "action": action,
                        "error": "Missing flight_id",
                    }
                flight = await self.get_flight_details(flight_id)
                if flight:
                    return {
                        "status": "success",
                        "action": action,
                        "data": flight.to_dict(),
                    }
                return {
                    "status": "error",
                    "action": action,
                    "error": "Flight not found",
                }

            elif action == "book_flight":
                booking = await self.book_flight(**params)
                return {
                    "status": "success",
                    "action": action,
                    "data": booking,
                }

            else:
                return {
                    "status": "error",
                    "action": action,
                    "error": f"Unknown action: {action}",
                }

        except Exception as e:
            self.logger.error(f"Request processing failed: {e}")
            return {
                "status": "error",
                "action": action,
                "error": str(e),
            }

    def get_flight_stats(self) -> Dict[str, Any]:
        """Get base-agent statistics merged with flight-specific statistics and cache sizes."""
        base_stats = self.get_stats()
        flight_stats = self._flight_stats.copy()

        # Combine base and flight-specific stats
        return {
            **base_stats,
            "flight_stats": flight_stats,
            "cache_size": len(self._search_cache),
            "booking_cache_size": len(self._booking_cache),
        }
| # ============================================================================= | |
| # EXAMPLE USAGE | |
| # ============================================================================= | |
async def demonstrate_flight_agent():
    """
    Demonstrate the Flight Agent functionality.

    Opens a FlightAgent with ``async with`` and runs a direct search. It then
    runs the same search through the dict-based ``process_request`` API and
    prints the flight statistics. The configured base URL is a placeholder,
    so the API calls are expected to fail.
    """
    # Local import keeps this demo-only dependency out of the module header.
    from datetime import timedelta

    print("✈️ Flight Agent Demonstration")
    print("=" * 50)

    # Use a date safely in the future. A bare datetime.now() would already be
    # "in the past" when search_flights validates it, so the demo would never
    # reach the API call.
    departure = datetime.now(timezone.utc) + timedelta(days=30)

    # Create flight agent
    async with FlightAgent() as agent:
        print(f"✅ Flight Agent created: {agent}")
        print(f" Name: {agent.name}")
        print(f" Description: {agent.description}")
        print(f" Initialized: {agent._is_initialized}")

        # Test flight search (this would fail with real API, but shows the structure)
        try:
            flights = await agent.search_flights(
                departure_city="New York",
                arrival_city="Los Angeles",
                departure_date=departure,
                passengers=2,
            )
            print(f" Found {len(flights)} flights")
        except Exception as e:
            print(f" Flight search failed (expected): {e}")

        # Test request processing
        request_data = {
            "action": "search_flights",
            "params": {
                "departure_city": "New York",
                "arrival_city": "Los Angeles",
                "departure_date": departure,
                "passengers": 1,
            },
        }
        response = await agent.process_request(request_data)
        print(f" Request response: {response['status']}")
        if response["status"] == "error":
            print(f" Request error: {response.get('error')}")

        # Get statistics
        stats = agent.get_flight_stats()
        print(f" Flight stats: {stats['flight_stats']}")

    print("✅ Flight Agent demonstration completed")
if __name__ == "__main__":
    # Script entry point: build the demo coroutine and drive it on a fresh event loop.
    demo_coroutine = demonstrate_flight_agent()
    asyncio.run(demo_coroutine)