Spaces:
Runtime error
Runtime error
| """Flight search endpoint.""" | |
| from __future__ import annotations | |
| from datetime import date, timedelta | |
| from fastapi import APIRouter, HTTPException | |
| from ..config import ( | |
| BEST_FLIGHTS_COUNT, | |
| BEST_WEIGHT_DURATION, | |
| BEST_WEIGHT_PRICE, | |
| BEST_WEIGHT_STOPS, | |
| MAX_RESULTS, | |
| ROUND_TRIP_SAME_AIRLINE_DISCOUNT, | |
| ) | |
| from ..data_loader import get_route_graph | |
| from ..flight_generator import generate_flights_for_route | |
| from ..hub_detector import compute_hub_scores | |
| from ..models import FlightOffer, SearchRequest, SearchResponse, SortBy, TripType | |
| from ..route_finder import find_routes | |
| from ..seed_utils import make_seed | |
# Router for the flight-search endpoints: path prefix "/api", OpenAPI tag "search".
router = APIRouter(tags=["search"], prefix="/api")
# Hub IATA codes, computed lazily on the first _get_hubs() call and reused afterwards.
_hub_iatas: list[str] | None = None


def _get_hubs() -> list[str]:
    """Return the hub airport IATA codes, computing them on first use.

    The first call builds them with ``compute_hub_scores`` over the route
    graph and stores the result in the module-level ``_hub_iatas``; every
    later call returns that stored value without recomputing.
    """
    global _hub_iatas
    if _hub_iatas is not None:
        return _hub_iatas
    _hub_iatas = compute_hub_scores(get_route_graph())
    return _hub_iatas
| def _apply_filters(flights: list[FlightOffer], req: SearchRequest) -> list[FlightOffer]: | |
| f = req.filters | |
| result = flights | |
| if f.max_stops is not None: | |
| result = [fl for fl in result if fl.stops <= f.max_stops] | |
| if f.max_price is not None: | |
| result = [fl for fl in result if fl.price_usd <= f.max_price] | |
| if f.max_duration_minutes is not None: | |
| result = [fl for fl in result if fl.total_duration_minutes <= f.max_duration_minutes] | |
| if f.airlines: | |
| airline_set = set(f.airlines) | |
| result = [ | |
| fl for fl in result | |
| if any(seg.airline_code in airline_set for seg in fl.segments) | |
| ] | |
| if f.departure_time_min: | |
| h, m = map(int, f.departure_time_min.split(":")) | |
| min_minutes = h * 60 + m | |
| result = [ | |
| fl for fl in result | |
| if fl.departure.hour * 60 + fl.departure.minute >= min_minutes | |
| ] | |
| if f.departure_time_max: | |
| h, m = map(int, f.departure_time_max.split(":")) | |
| max_minutes = h * 60 + m | |
| result = [ | |
| fl for fl in result | |
| if fl.departure.hour * 60 + fl.departure.minute <= max_minutes | |
| ] | |
| return result | |
def _sort_flights(flights: list[FlightOffer], sort_by: SortBy) -> list[FlightOffer]:
    """Return *flights* ordered according to *sort_by*.

    - ``SortBy.cheapest``: ascending ``price_usd``.
    - ``SortBy.fastest``: ascending ``total_duration_minutes``.
    - any other value ("best"): ascending composite score
      ``BEST_WEIGHT_PRICE * price/max_price
      + BEST_WEIGHT_DURATION * duration/max_duration
      + BEST_WEIGHT_STOPS * stops/max_stops``.
      This is the same score ``_tag_best_flights`` uses, so the "best"
      ordering agrees with which flights receive ``is_best``.

    Sorting is stable, so ties keep their input order. An empty input in
    "best" mode is returned unchanged.
    """
    if sort_by == SortBy.cheapest:
        return sorted(flights, key=lambda fl: fl.price_usd)
    if sort_by == SortBy.fastest:
        return sorted(flights, key=lambda fl: fl.total_duration_minutes)

    # "best": balance price, duration and stops using the configured weights.
    if not flights:
        return flights
    # `or 1` avoids division by zero when every value of a factor is 0.
    max_price = max(fl.price_usd for fl in flights) or 1
    max_dur = max(fl.total_duration_minutes for fl in flights) or 1
    max_stops = max(fl.stops for fl in flights) or 1

    def best_score(fl) -> float:
        return (
            BEST_WEIGHT_PRICE * (fl.price_usd / max_price)
            + BEST_WEIGHT_DURATION * (fl.total_duration_minutes / max_dur)
            + BEST_WEIGHT_STOPS * (fl.stops / max_stops)
        )

    return sorted(flights, key=best_score)
def _tag_best_flights(flights: list[FlightOffer], count: int = BEST_FLIGHTS_COUNT) -> None:
    """Set ``is_best = True`` on the top-ranked flights, in place.

    Each flight is ranked by a weighted sum of three factors, each divided by
    the largest value of that factor in *flights* (lower is better):

        BEST_WEIGHT_PRICE    * price_usd / max price
      + BEST_WEIGHT_DURATION * total_duration_minutes / max duration
      + BEST_WEIGHT_STOPS    * stops / max stops

    The ``count`` lowest-scoring flights are tagged; ties keep input order.
    With fewer than two flights there is nothing to rank, so every flight
    present is tagged regardless of ``count``. Flights outside the top
    ``count`` are not modified.
    """
    if len(flights) < 2:
        for offer in flights:
            offer.is_best = True
        return

    # Normalisation denominators; `or 1` keeps an all-zero factor from dividing by zero.
    price_ceiling = max(offer.price_usd for offer in flights) or 1
    duration_ceiling = max(offer.total_duration_minutes for offer in flights) or 1
    stops_ceiling = max(offer.stops for offer in flights) or 1

    def composite(offer) -> float:
        return (
            BEST_WEIGHT_PRICE * (offer.price_usd / price_ceiling)
            + BEST_WEIGHT_DURATION * (offer.total_duration_minutes / duration_ceiling)
            + BEST_WEIGHT_STOPS * (offer.stops / stops_ceiling)
        )

    for offer in sorted(flights, key=composite)[:count]:
        offer.is_best = True
def _search_leg(graph, leg, req: SearchRequest, hub_iatas: list[str]) -> list[FlightOffer]:
    """Build the ranked offer list for a single leg of the trip.

    Finds routings from ``leg.origin`` to ``leg.destination`` (upper-cased)
    with ``find_routes``, honouring ``req.filters.max_stops``. Each routing is
    expanded into offers on ``leg.date`` in ``req.cabin_class``. The offers are
    then filtered, sorted by ``req.sort_by`` and truncated to ``MAX_RESULTS``.
    Finally the top ones are tagged as "best".
    """
    plans = find_routes(
        graph,
        leg.origin.upper(),
        leg.destination.upper(),
        hub_iatas,
        max_stops=req.filters.max_stops,
    )
    offers: list[FlightOffer] = []
    for plan in plans:
        offers.extend(
            generate_flights_for_route(graph, plan, leg.date, req.cabin_class, hub_iatas)
        )
    offers = _apply_filters(offers, req)
    offers = _sort_flights(offers, req.sort_by)[:MAX_RESULTS]
    _tag_best_flights(offers)
    return offers


async def search_flights(req: SearchRequest):
    """Search flights for the legs in *req* and return a ``SearchResponse``.

    All legs are validated. Only ``req.legs[0]`` (outbound) is searched,
    plus ``req.legs[1]`` (return) when the trip is a round trip with at
    least two legs. The response's ``origin``, ``destination`` and
    ``search_id`` come from the outbound leg. ``same_airline_discount`` is
    ``ROUND_TRIP_SAME_AIRLINE_DISCOUNT`` when there are return flights,
    otherwise 1.0.

    Raises:
        HTTPException: 400 if no legs are given, any leg is more than 365
            days from today, or a round trip's return date precedes its
            departure date. 404 if any leg's origin or destination is not
            an airport in the route graph.

    NOTE(review): no route decorator (e.g. ``@router.post(...)``) is visible
    on this function in the reviewed source, so it may not be registered on
    ``router``. Confirm. Also, the work below is synchronous inside an
    ``async def``; presumably it runs on the event loop thread. Verify it is
    cheap enough.
    """
    graph = get_route_graph()
    hub_iatas = _get_hubs()

    if not req.legs:
        raise HTTPException(status_code=400, detail="At least one leg required")

    # Every airport referenced by any leg must exist in the route graph.
    for leg in req.legs:
        for code in (leg.origin, leg.destination):
            if code.upper() not in graph.airports:
                raise HTTPException(status_code=404, detail=f"Airport {code} not found")

    # Booking window: no leg may be more than 365 days from today.
    max_date = date.today() + timedelta(days=365)
    if any(leg.date > max_date for leg in req.legs):
        raise HTTPException(
            status_code=400,
            detail="Flights can only be booked up to 12 months in advance",
        )

    # A round trip is only searched as such when a return leg is present.
    is_round_trip = req.trip_type == TripType.round_trip and len(req.legs) >= 2
    if is_round_trip and req.legs[1].date < req.legs[0].date:
        raise HTTPException(
            status_code=400,
            detail="Return date must be on or after the departure date",
        )

    outbound_leg = req.legs[0]
    outbound_flights = _search_leg(graph, outbound_leg, req, hub_iatas)

    # The round-trip same-airline discount is applied client-side based on the
    # user's selected outbound flight (not blanket across all outbound
    # airlines). Only the rate is returned, in same_airline_discount.
    return_flights: list[FlightOffer] = (
        _search_leg(graph, req.legs[1], req, hub_iatas) if is_round_trip else []
    )

    origin = outbound_leg.origin.upper()
    destination = outbound_leg.destination.upper()
    search_id = str(make_seed(origin, destination, outbound_leg.date.isoformat()))
    return SearchResponse(
        outbound_flights=outbound_flights,
        return_flights=return_flights,
        search_id=search_id,
        origin=origin,
        destination=destination,
        same_airline_discount=ROUND_TRIP_SAME_AIRLINE_DISCOUNT if return_flights else 1.0,
    )