"""Generate concrete flights for a route + date.""" from __future__ import annotations import random from datetime import date, datetime, timedelta, timezone from zoneinfo import ZoneInfo from .config import ( AIRCRAFT_BY_DISTANCE, CONNECTING_BASE_DISCOUNT, CURFEW_FACTORS, DEPARTURE_WEIGHT_JITTER, DEPARTURE_WEIGHTS, EMISSIONS_KG_PER_KM, HUB_FLIGHT_BONUS, LEGROOM_RANGES, MAX_FLIGHTS_MULTI_CARRIER, MAX_FLIGHTS_SINGLE_CARRIER, MAX_LAYOVER_MINUTES, MIN_CONNECTION_MINUTES, MIN_FLIGHTS_PER_DAY, MIN_LAYOVER_MINUTES, OVERNIGHT_EXEMPT_ROUTES, POPULARITY_LATEST_EXTENSION_MAX, POPULARITY_TIME_LIMITS, POWER_AIRCRAFT, PREMIUM_CLASS_AMENITY_BOOST, SAME_AIRLINE_CONNECTION_DISCOUNT, SHORT_CONNECTION_MINUTES, VIDEO_AIRCRAFT, WIFI_AIRCRAFT, ) from .data_loader import Route, RouteGraph from .models import CabinClass, FlightOffer, FlightSegment, LayoverInfo from .price_engine import compute_price from .route_finder import RoutePlan from .seed_utils import seeded_random def _pick_aircraft(distance_km: int, rng: random.Random) -> str: for max_dist, aircraft_list in AIRCRAFT_BY_DISTANCE: if distance_km <= max_dist: return rng.choice(aircraft_list) return "777-300ER" def _make_flight_number(carrier_iata: str, rng: random.Random) -> str: return f"{carrier_iata}{rng.randint(100, 9999)}" def _generate_amenities( aircraft: str, cabin_class: CabinClass, distance_km: int, rng: random.Random, ) -> dict: """Generate seat amenities based on aircraft, cabin class, and distance.""" cls = cabin_class.value lo, hi = LEGROOM_RANGES.get(cls, (29, 32)) legroom = rng.randint(lo, hi) is_premium = cls in ("business", "first") boost = PREMIUM_CLASS_AMENITY_BOOST if is_premium else 0.0 # WiFi wifi_prob = min(1.0, WIFI_AIRCRAFT.get(aircraft, 0.3) + boost) has_wifi = rng.random() < wifi_prob wifi_type = None if has_wifi: # Free WiFi more common on premium classes and newer aircraft free_prob = 0.7 if is_premium else (0.3 if aircraft.startswith(("787", "A35")) else 0.12) wifi_type = "Free Wi-Fi" if rng.random() < free_prob else "Wi-Fi for a fee" # Power & USB power_prob = min(1.0, POWER_AIRCRAFT.get(aircraft, 0.2) + boost) has_power = rng.random() < power_prob # USB slightly more common than AC power has_usb = has_power or rng.random() < min(1.0, power_prob + 0.15) # Video/IFE video_prob = min(1.0, VIDEO_AIRCRAFT.get(aircraft, 0.05) + boost) # Long-haul (>4000km) gets a boost if distance_km > 4000: video_prob = min(1.0, video_prob + 0.15) has_video = rng.random() < video_prob video_type = None if has_video: if aircraft in ("A380", "777-300ER", "787-9", "787-10", "A350-900", "A350-1000"): video_type = "On-demand video" elif distance_km > 3000: video_type = "On-demand video" if rng.random() < 0.7 else "Seatback screen" else: video_type = "Live TV" if rng.random() < 0.4 else "Seatback screen" return { "legroom_inches": legroom, "has_wifi": has_wifi, "wifi_type": wifi_type, "has_power": has_power, "has_usb": has_usb, "has_video": has_video, "video_type": video_type, } def _classify_route_type(origin_continent: str, dest_continent: str) -> str: """Classify a route for departure time distribution selection. Returns a key into DEPARTURE_WEIGHTS. """ if origin_continent == dest_continent: return "domestic" o, d = origin_continent, dest_continent if o == "NA" and d == "EU": return "na_to_eu" if o == "EU" and d == "NA": return "eu_to_na" if o == "NA" and d == "AS": return "na_to_asia" if o == "AS" and d == "NA": return "asia_to_na" if o == "EU" and d == "AS": return "eu_to_asia" if o == "AS" and d == "EU": return "asia_to_eu" return "default" def _build_departure_weights( route_type: str, num_carriers: int, origin_route_count: int, rng: random.Random, ) -> list[float]: """Build hourly departure weights adjusted for popularity, curfew, and jitter. Returns a 24-element list of non-negative floats (one weight per hour 0–23). Overnight-exempt routes (e.g. asia_to_eu with 23:00–02:30 departures) skip the popularity window and curfew so their late-night pattern is preserved intact. """ base = list(DEPARTURE_WEIGHTS.get(route_type, DEPARTURE_WEIGHTS["default"])) weights = [float(w) for w in base] is_overnight = route_type in OVERNIGHT_EXEMPT_ROUTES if not is_overnight: # 1. Popularity-based time window: zero out hours outside the range earliest, latest = 7, 21 # defaults for min_carriers, eh, lh in POPULARITY_TIME_LIMITS: if num_carriers >= min_carriers: earliest, latest = eh, lh break # Random extension up to 1.5 h on the latest hour if num_carriers >= 2: latest = min(23, latest + rng.randint(0, POPULARITY_LATEST_EXTENSION_MAX) // 60) for h in range(0, earliest): weights[h] = 0.0 for h in range(latest + 1, 24): weights[h] = 0.0 # 2. Airport curfew: reduce midnight–5 AM based on airport size curfew_factor = 0.01 for min_routes, factor in CURFEW_FACTORS: if origin_route_count >= min_routes: curfew_factor = factor break for h in range(0, 6): weights[h] *= curfew_factor # 3. Per-route jitter: ±DEPARTURE_WEIGHT_JITTER on each hour for h in range(24): if weights[h] > 0: jitter = 1.0 - DEPARTURE_WEIGHT_JITTER + 2 * DEPARTURE_WEIGHT_JITTER * rng.random() weights[h] *= jitter return weights def _sample_departure_minutes( weights: list[float], num_flights: int, rng: random.Random, ) -> list[int]: """Sample N departure times (minutes since midnight) from hourly weights. Each sampled departure gets a random minute offset within its chosen hour. Returns sorted list. """ total = sum(weights) if total <= 0: # Fallback: uniform 7 AM – 9 PM return sorted(rng.randint(7 * 60, 21 * 60) for _ in range(num_flights)) # Sample hours via weighted random selection hours = rng.choices(range(24), weights=weights, k=num_flights) # Convert to minutes with random offset within the hour minutes = [h * 60 + rng.randint(0, 59) for h in hours] return sorted(minutes) def _carrier_weights_at_airport( graph: RouteGraph, origin: str, leg_carriers: list[dict], ) -> list[float]: """Weight carriers by their presence at the origin airport. Carriers that operate more routes from this airport (hub carriers) get higher selection weight, so they appear on more flights. """ # Count how many routes each carrier operates from origin carrier_route_counts: dict[str, int] = {} for route in graph.airports[origin].routes: for carrier in route.carriers: iata = carrier["iata"] carrier_route_counts[iata] = carrier_route_counts.get(iata, 0) + 1 weights = [] for carrier in leg_carriers: count = carrier_route_counts.get(carrier["iata"], 1) # Square-root scaling: hubs get more weight but don't completely dominate weights.append(count ** 0.5) return weights def _get_timezone(graph: RouteGraph, iata: str) -> ZoneInfo: airport = graph.airports.get(iata) if airport and airport.timezone: try: return ZoneInfo(airport.timezone) except KeyError: pass return ZoneInfo("UTC") def _leg_type(graph: RouteGraph, origin_iata: str, dest_iata: str) -> str: """Return 'domestic' if both airports are in the same country, else 'international'.""" o = graph.airports.get(origin_iata) d = graph.airports.get(dest_iata) if o and d and o.country_code == d.country_code: return "domestic" return "international" def _min_connection_minutes( graph: RouteGraph, arriving_origin: str, hub: str, departing_dest: str, ) -> int: """Return the minimum connection time at *hub* given the arriving and departing legs.""" arr_type = _leg_type(graph, arriving_origin, hub) dep_type = _leg_type(graph, hub, departing_dest) return MIN_CONNECTION_MINUTES.get((arr_type, dep_type), MIN_LAYOVER_MINUTES) def generate_flights_for_route( graph: RouteGraph, route_plan: RoutePlan, departure_date: date, cabin_class: CabinClass, hub_iatas: list[str], ) -> list[FlightOffer]: """Generate concrete flight offers for a route plan on a given date.""" origin = route_plan.waypoints[0] destination = route_plan.waypoints[-1] # Seed based on route + date for determinism seed_key = f"{origin}-{destination}-{departure_date.isoformat()}-{cabin_class.value}" rng = seeded_random(seed_key, *route_plan.waypoints) if route_plan.stops == 0: return _generate_direct_flights(graph, route_plan, departure_date, cabin_class, rng) else: return _generate_connecting_flights(graph, route_plan, departure_date, cabin_class, rng) def _generate_direct_flights( graph: RouteGraph, route_plan: RoutePlan, departure_date: date, cabin_class: CabinClass, rng: random.Random, ) -> list[FlightOffer]: """Generate multiple direct flight options for a single-leg route.""" leg = route_plan.legs[0] origin = route_plan.waypoints[0] destination = route_plan.waypoints[1] origin_airport = graph.airports[origin] dest_airport = graph.airports[destination] origin_route_count = len(origin_airport.routes) # Number of flights based on carrier count num_carriers = len(leg.carriers) if num_carriers == 1: num_flights = rng.randint(MIN_FLIGHTS_PER_DAY, MAX_FLIGHTS_SINGLE_CARRIER) elif num_carriers <= 3: num_flights = rng.randint(3, 8) else: num_flights = rng.randint(8, MAX_FLIGHTS_MULTI_CARRIER) # Hub airport bonus: major airports generate more flights for min_routes, bonus_min, bonus_max in HUB_FLIGHT_BONUS: if origin_route_count >= min_routes: num_flights += rng.randint(bonus_min, bonus_max) break # Realistic departure time distribution based on route type route_type = _classify_route_type(origin_airport.continent, dest_airport.continent) dep_weights = _build_departure_weights(route_type, num_carriers, origin_route_count, rng) departure_hours = _sample_departure_minutes(dep_weights, num_flights, rng) origin_tz = _get_timezone(graph, origin) dest_tz = _get_timezone(graph, destination) # Pre-compute carrier weights for hub-airline bias carrier_wts = _carrier_weights_at_airport(graph, origin, leg.carriers) flights = [] for dep_minutes in departure_hours: carrier = rng.choices(leg.carriers, weights=carrier_wts, k=1)[0] dep_hour = dep_minutes // 60 dep_min = dep_minutes % 60 departure_dt = datetime( departure_date.year, departure_date.month, departure_date.day, dep_hour, dep_min, tzinfo=origin_tz, ) # Calculate arrival arrival_dt = departure_dt + timedelta(minutes=leg.duration_min) arrival_dt = arrival_dt.astimezone(dest_tz) price = compute_price( distance_km=leg.distance_km, cabin_class=cabin_class.value, departure_date=departure_date, departure_hour=dep_hour, num_carriers=num_carriers, dest_continent=dest_airport.continent, rng=rng, ) flight_id = f"{origin}{destination}{departure_date.isoformat()}{dep_minutes}{carrier['iata']}" aircraft = _pick_aircraft(leg.distance_km, rng) amenities = _generate_amenities(aircraft, cabin_class, leg.distance_km, rng) # Overnight: compare local calendar dates in their respective timezones. # Python .date() on a tz-aware datetime returns the local date in that tz. is_overnight = arrival_dt.date() != departure_dt.date() segment = FlightSegment( airline_code=carrier["iata"], airline_name=carrier["name"], flight_number=_make_flight_number(carrier["iata"], rng), aircraft=aircraft, origin=origin, origin_city=origin_airport.city_name, destination=destination, destination_city=dest_airport.city_name, departure=departure_dt, arrival=arrival_dt, duration_minutes=leg.duration_min, is_overnight=is_overnight, **amenities, ) emissions = int(leg.distance_km * EMISSIONS_KG_PER_KM.get(cabin_class.value, 0.09)) flights.append(FlightOffer( id=flight_id, segments=[segment], total_duration_minutes=leg.duration_min, stops=0, price_usd=price, cabin_class=cabin_class, origin=origin, destination=destination, departure=departure_dt, arrival=arrival_dt, emissions_kg=emissions, )) return flights def _generate_connecting_flights( graph: RouteGraph, route_plan: RoutePlan, departure_date: date, cabin_class: CabinClass, rng: random.Random, ) -> list[FlightOffer]: """Generate connecting flight options (1-stop or 2-stop).""" origin = route_plan.waypoints[0] destination = route_plan.waypoints[-1] origin_airport = graph.airports[origin] dest_airport = graph.airports[destination] origin_route_count = len(origin_airport.routes) # Realistic departure times for the first leg first_leg = route_plan.legs[0] num_first_carriers = len(first_leg.carriers) if first_leg.carriers else 1 route_type = _classify_route_type(origin_airport.continent, dest_airport.continent) dep_weights = _build_departure_weights(route_type, num_first_carriers, origin_route_count, rng) # Generate 2-5 options per connecting route num_options = rng.randint(2, 5) dep_times = _sample_departure_minutes(dep_weights, num_options, rng) flights = [] for option_idx in range(num_options): departure_minutes = dep_times[option_idx] segments = [] layover_infos: list[LayoverInfo] = [] current_time = datetime( departure_date.year, departure_date.month, departure_date.day, departure_minutes // 60, departure_minutes % 60, tzinfo=_get_timezone(graph, origin), ) total_price = 0.0 total_duration = 0 valid = True for i, leg in enumerate(route_plan.legs): leg_origin = route_plan.waypoints[i] leg_dest = route_plan.waypoints[i + 1] origin_tz = _get_timezone(graph, leg_origin) dest_tz = _get_timezone(graph, leg_dest) origin_ap = graph.airports[leg_origin] dest_ap = graph.airports[leg_dest] if not leg.carriers: valid = False break leg_cwts = _carrier_weights_at_airport(graph, leg_origin, leg.carriers) carrier = rng.choices(leg.carriers, weights=leg_cwts, k=1)[0] departure_dt = current_time.astimezone(origin_tz) arrival_dt = departure_dt + timedelta(minutes=leg.duration_min) arrival_dt = arrival_dt.astimezone(dest_tz) # Overnight: compare local calendar dates in their respective timezones is_overnight = arrival_dt.date() != departure_dt.date() # Per-leg price with connection discount # (see price_engine.py and config.py for full pricing docs) leg_price = compute_price( distance_km=leg.distance_km, cabin_class=cabin_class.value, departure_date=departure_date, departure_hour=departure_dt.hour, num_carriers=len(leg.carriers), dest_continent=dest_ap.continent, rng=rng, is_connection=True, ) leg_price *= CONNECTING_BASE_DISCOUNT total_price += leg_price aircraft = _pick_aircraft(leg.distance_km, rng) amenities = _generate_amenities(aircraft, cabin_class, leg.distance_km, rng) segments.append(FlightSegment( airline_code=carrier["iata"], airline_name=carrier["name"], flight_number=_make_flight_number(carrier["iata"], rng), aircraft=aircraft, origin=leg_origin, origin_city=origin_ap.city_name, destination=leg_dest, destination_city=dest_ap.city_name, departure=departure_dt, arrival=arrival_dt, duration_minutes=leg.duration_min, is_overnight=is_overnight, **amenities, )) # Add layover time for next leg if i < len(route_plan.legs) - 1: next_dest = route_plan.waypoints[i + 2] min_conn = _min_connection_minutes(graph, leg_origin, leg_dest, next_dest) layover = rng.randint(min_conn, MAX_LAYOVER_MINUTES) current_time = arrival_dt + timedelta(minutes=layover) total_duration += leg.duration_min + layover # Layover overnight: both times at the hub airport (same tz). # arrival_dt is already in dest_tz (= hub tz). # current_time inherits that tz after timedelta addition. layover_overnight = current_time.date() != arrival_dt.date() layover_infos.append(LayoverInfo( duration_minutes=layover, airport=leg_dest, airport_city=dest_ap.city_name, is_overnight=layover_overnight, is_short_connection=layover < SHORT_CONNECTION_MINUTES, )) # Check if layover pushes to next day too far if (current_time - datetime( departure_date.year, departure_date.month, departure_date.day, tzinfo=origin_tz )).days > 1: valid = False break else: total_duration += leg.duration_min if not valid: continue # Same-airline connection discount: if all segments are on the same # carrier, apply an additional discount (bundled itinerary fare). carrier_codes = {seg.airline_code for seg in segments} if len(carrier_codes) == 1: total_price *= SAME_AIRLINE_CONNECTION_DISCOUNT total_price = round(total_price, 0) first_departure = segments[0].departure last_arrival = segments[-1].arrival flight_id = ( f"{origin}{destination}{departure_date.isoformat()}" f"{departure_minutes}{'-'.join(route_plan.waypoints)}{option_idx}" ) total_distance = sum(leg.distance_km for leg in route_plan.legs) emissions = int(total_distance * EMISSIONS_KG_PER_KM.get(cabin_class.value, 0.09)) flights.append(FlightOffer( id=flight_id, segments=segments, layovers=layover_infos, total_duration_minutes=total_duration, stops=route_plan.stops, price_usd=total_price, cabin_class=cabin_class, origin=origin, destination=destination, departure=first_departure, arrival=last_arrival, emissions_kg=emissions, )) return flights