"""Flight search endpoint."""

from __future__ import annotations

from collections.abc import Callable
from datetime import date, timedelta

from fastapi import APIRouter, HTTPException

from ..config import (
    BEST_FLIGHTS_COUNT,
    BEST_WEIGHT_DURATION,
    BEST_WEIGHT_PRICE,
    BEST_WEIGHT_STOPS,
    MAX_RESULTS,
    ROUND_TRIP_SAME_AIRLINE_DISCOUNT,
)
from ..data_loader import get_route_graph
from ..flight_generator import generate_flights_for_route
from ..hub_detector import compute_hub_scores
from ..models import FlightOffer, SearchRequest, SearchResponse, SortBy, TripType
from ..route_finder import find_routes
from ..seed_utils import make_seed

router = APIRouter(prefix="/api", tags=["search"])

# Module-level hub cache, filled lazily by _get_hubs() on first use.
_hub_iatas: list[str] | None = None


def _get_hubs() -> list[str]:
    """Return hub airport IATA codes, computing them on first call only.

    The result of ``compute_hub_scores`` over the route graph is cached in the
    module-level ``_hub_iatas`` and reused on subsequent calls.
    """
    global _hub_iatas
    if _hub_iatas is None:
        _hub_iatas = compute_hub_scores(get_route_graph())
    return _hub_iatas


def _parse_hhmm(value: str) -> int:
    """Convert an ``"HH:MM"`` time-filter string to minutes after midnight.

    Raises:
        HTTPException: 400 if *value* is not two colon-separated integers, so a
            malformed filter is reported as a client error rather than a 500.
    """
    try:
        hours, minutes = map(int, value.split(":"))
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid time filter {value!r}; expected HH:MM",
        ) from None
    return hours * 60 + minutes


def _departure_minutes(flight: FlightOffer) -> int:
    """Return the flight's departure clock time as minutes after midnight."""
    return flight.departure.hour * 60 + flight.departure.minute


def _apply_filters(flights: list[FlightOffer], req: SearchRequest) -> list[FlightOffer]:
    """Return the flights that satisfy every filter set on ``req.filters``.

    Unset filters (``None`` / empty) are skipped. The airline filter keeps a
    flight if ANY of its segments is operated by a requested airline. All
    numeric and departure-time bounds are inclusive; departure bounds compare
    clock time only (hour and minute), not the date.

    Raises:
        HTTPException: 400 if a departure-time bound is not ``"HH:MM"``.
    """
    f = req.filters
    result = flights
    if f.max_stops is not None:
        result = [fl for fl in result if fl.stops <= f.max_stops]
    if f.max_price is not None:
        result = [fl for fl in result if fl.price_usd <= f.max_price]
    if f.max_duration_minutes is not None:
        result = [fl for fl in result if fl.total_duration_minutes <= f.max_duration_minutes]
    if f.airlines:
        airline_set = set(f.airlines)
        result = [
            fl for fl in result if any(seg.airline_code in airline_set for seg in fl.segments)
        ]
    if f.departure_time_min:
        min_minutes = _parse_hhmm(f.departure_time_min)
        result = [fl for fl in result if _departure_minutes(fl) >= min_minutes]
    if f.departure_time_max:
        max_minutes = _parse_hhmm(f.departure_time_max)
        result = [fl for fl in result if _departure_minutes(fl) <= max_minutes]
    return result


def _best_score_key(flights: list[FlightOffer]) -> Callable[[FlightOffer], float]:
    """Build the composite "best" score function, normalized over *flights*.

        score = W_price    * (price    / max_price)
              + W_duration * (duration / max_duration)
              + W_stops    * (stops    / max_stops)

    Each maximum is taken across *flights*; a maximum of 0 (or an empty list)
    falls back to 1 to avoid division by zero. Lower score = better. Shared by
    the "best" sort and ``_tag_best_flights`` so both rank flights identically.
    """
    max_price = max((f.price_usd for f in flights), default=0) or 1
    max_dur = max((f.total_duration_minutes for f in flights), default=0) or 1
    max_stops = max((f.stops for f in flights), default=0) or 1

    def score(f: FlightOffer) -> float:
        return (
            BEST_WEIGHT_PRICE * (f.price_usd / max_price)
            + BEST_WEIGHT_DURATION * (f.total_duration_minutes / max_dur)
            + BEST_WEIGHT_STOPS * (f.stops / max_stops)
        )

    return score


def _sort_flights(flights: list[FlightOffer], sort_by: SortBy) -> list[FlightOffer]:
    """Return a new list of *flights* ordered according to *sort_by*.

    ``cheapest`` sorts by price, ``fastest`` by total duration, and any other
    value ("best") by the same composite score used to tag ``is_best``.
    """
    if sort_by == SortBy.cheapest:
        return sorted(flights, key=lambda f: f.price_usd)
    if sort_by == SortBy.fastest:
        return sorted(flights, key=lambda f: f.total_duration_minutes)
    # best: weighted price / duration / stops, consistent with _tag_best_flights
    return sorted(flights, key=_best_score_key(flights))


def _tag_best_flights(flights: list[FlightOffer], count: int = BEST_FLIGHTS_COUNT) -> None:
    """Tag the top *count* flights by composite score with ``is_best=True``.

    "Best" flights balance price, duration, and number of stops — like Google
    Flights' "Top departing flights" section (see ``_best_score_key`` for the
    formula). Mutates the given offers in place; other flights are left as-is.
    With fewer than two flights, every flight is tagged.
    """
    if len(flights) < 2:
        for f in flights:
            f.is_best = True
        return
    # sorted() is stable, so ties keep their incoming order.
    for f in sorted(flights, key=_best_score_key(flights))[:count]:
        f.is_best = True


def _is_round_trip(req: SearchRequest) -> bool:
    """True when *req* is a round trip that actually carries a return leg."""
    return req.trip_type == TripType.round_trip and len(req.legs) >= 2


def _validate_request(req: SearchRequest, graph) -> None:
    """Reject requests that cannot be served.

    *graph* is the route graph from ``get_route_graph()``; only its
    ``airports`` container is consulted here. Checks run in this order and the
    first failure wins:

    1. at least one leg (400);
    2. every leg's origin and destination is a known airport (404);
    3. no leg is dated more than 365 days from today (400);
    4. for round trips, the return date is not before the departure date (400).

    Raises:
        HTTPException: as listed above.
    """
    if not req.legs:
        raise HTTPException(status_code=400, detail="At least one leg required")

    for leg in req.legs:
        if leg.origin.upper() not in graph.airports:
            raise HTTPException(status_code=404, detail=f"Airport {leg.origin} not found")
        if leg.destination.upper() not in graph.airports:
            raise HTTPException(status_code=404, detail=f"Airport {leg.destination} not found")

    # Booking window: flights can only be booked up to 12 months in advance.
    max_date = date.today() + timedelta(days=365)
    if any(leg.date > max_date for leg in req.legs):
        raise HTTPException(
            status_code=400,
            detail="Flights can only be booked up to 12 months in advance",
        )

    # Round-trip date order: return date must be >= departure date.
    if _is_round_trip(req) and req.legs[1].date < req.legs[0].date:
        raise HTTPException(
            status_code=400,
            detail="Return date must be on or after the departure date",
        )


def _search_leg(graph, leg, req: SearchRequest, hub_iatas: list[str]) -> list[FlightOffer]:
    """Produce the ranked flight list for a single leg of *req*.

    Finds route plans (respecting ``req.filters.max_stops``), generates offers
    for each plan on the leg's date and the requested cabin class, applies the
    request filters, sorts by ``req.sort_by``, keeps the first ``MAX_RESULTS``
    and tags the best ones.
    """
    origin = leg.origin.upper()
    destination = leg.destination.upper()
    plans = find_routes(
        graph, origin, destination, hub_iatas, max_stops=req.filters.max_stops
    )
    flights = [
        offer
        for plan in plans
        for offer in generate_flights_for_route(
            graph, plan, leg.date, req.cabin_class, hub_iatas
        )
    ]
    flights = _apply_filters(flights, req)
    flights = _sort_flights(flights, req.sort_by)[:MAX_RESULTS]
    _tag_best_flights(flights)
    return flights


@router.post("/search", response_model=SearchResponse)
async def search_flights(req: SearchRequest) -> SearchResponse:
    """Search outbound flights and, for round trips, return flights.

    Only ``req.legs[0]`` is searched as the outbound leg, plus ``req.legs[1]``
    as the return leg when the trip is a round trip; every leg is validated.

    Raises:
        HTTPException: 400 for an empty leg list, a date beyond the 12-month
            booking window, a return date before the departure date, or a
            malformed departure-time filter; 404 for an unknown airport.
    """
    graph = get_route_graph()
    hub_iatas = _get_hubs()
    _validate_request(req, graph)

    outbound_leg = req.legs[0]
    origin = outbound_leg.origin.upper()
    destination = outbound_leg.destination.upper()
    outbound_flights = _search_leg(graph, outbound_leg, req, hub_iatas)

    # NOTE(review): a round_trip request with fewer than two legs silently
    # returns no return flights — confirm whether that should be a 400.
    return_flights: list[FlightOffer] = []
    if _is_round_trip(req):
        # Round-trip same-airline discount is applied client-side based on the
        # user's selected outbound flight (not blanket across all outbound
        # airlines). The discount rate is returned in same_airline_discount.
        return_flights = _search_leg(graph, req.legs[1], req, hub_iatas)

    # Derived from the outbound origin, destination and date only.
    search_id = str(make_seed(origin, destination, outbound_leg.date.isoformat()))
    return SearchResponse(
        outbound_flights=outbound_flights,
        return_flights=return_flights,
        search_id=search_id,
        origin=origin,
        destination=destination,
        same_airline_discount=ROUND_TRIP_SAME_AIRLINE_DISCOUNT if return_flights else 1.0,
    )