Flight-Search / backend /flight_generator.py
fyliu's picture
Add connection-type MCT, overnight/short-connection warnings
d6def07
"""Generate concrete flights for a route + date."""
from __future__ import annotations
import random
from datetime import date, datetime, timedelta, timezone
from zoneinfo import ZoneInfo
from .config import (
AIRCRAFT_BY_DISTANCE,
CONNECTING_BASE_DISCOUNT,
CURFEW_FACTORS,
DEPARTURE_WEIGHT_JITTER,
DEPARTURE_WEIGHTS,
EMISSIONS_KG_PER_KM,
HUB_FLIGHT_BONUS,
LEGROOM_RANGES,
MAX_FLIGHTS_MULTI_CARRIER,
MAX_FLIGHTS_SINGLE_CARRIER,
MAX_LAYOVER_MINUTES,
MIN_CONNECTION_MINUTES,
MIN_FLIGHTS_PER_DAY,
MIN_LAYOVER_MINUTES,
OVERNIGHT_EXEMPT_ROUTES,
POPULARITY_LATEST_EXTENSION_MAX,
POPULARITY_TIME_LIMITS,
POWER_AIRCRAFT,
PREMIUM_CLASS_AMENITY_BOOST,
SAME_AIRLINE_CONNECTION_DISCOUNT,
SHORT_CONNECTION_MINUTES,
VIDEO_AIRCRAFT,
WIFI_AIRCRAFT,
)
from .data_loader import Route, RouteGraph
from .models import CabinClass, FlightOffer, FlightSegment, LayoverInfo
from .price_engine import compute_price
from .route_finder import RoutePlan
from .seed_utils import seeded_random
def _pick_aircraft(distance_km: int, rng: random.Random) -> str:
    """Pick an aircraft type for a leg of *distance_km*.

    ``AIRCRAFT_BY_DISTANCE`` is scanned in order. The first bracket whose
    upper bound is >= *distance_km* supplies the candidate list, and one type
    is drawn from it with ``rng.choice``. Legs longer than every bracket get
    "777-300ER".
    """
    no_match = object()
    fleet = next(
        (
            candidates
            for ceiling, candidates in AIRCRAFT_BY_DISTANCE
            if distance_km <= ceiling
        ),
        no_match,
    )
    if fleet is no_match:
        return "777-300ER"
    return rng.choice(fleet)
def _make_flight_number(carrier_iata: str, rng: random.Random) -> str:
    """Return a flight number: *carrier_iata* followed by a random 100–9999."""
    numeric_part = rng.randint(100, 9999)
    return "{}{}".format(carrier_iata, numeric_part)
def _generate_amenities(
    aircraft: str, cabin_class: CabinClass, distance_km: int, rng: random.Random,
) -> dict:
    """Randomly assign seat amenities for one flight segment.

    Presence probabilities come from the per-aircraft tables
    ``WIFI_AIRCRAFT``, ``POWER_AIRCRAFT`` and ``VIDEO_AIRCRAFT``. For
    business/first cabins they are raised by ``PREMIUM_CLASS_AMENITY_BOOST``,
    and every probability is capped at 1.0. Legroom is drawn from the cabin's
    ``LEGROOM_RANGES`` entry, or 29–32 when the cabin is not listed.

    Returns:
        Dict with keys ``legroom_inches``, ``has_wifi``, ``wifi_type``,
        ``has_power``, ``has_usb``, ``has_video`` and ``video_type``. The
        ``*_type`` entries are ``None`` when the amenity is absent.
    """
    tier = cabin_class.value
    premium = tier in ("business", "first")
    bonus = PREMIUM_CLASS_AMENITY_BOOST if premium else 0.0

    def capped(table: dict, fallback: float) -> float:
        # Per-aircraft probability plus the premium bonus, never above 1.0.
        return min(1.0, table.get(aircraft, fallback) + bonus)

    legroom_min, legroom_max = LEGROOM_RANGES.get(tier, (29, 32))
    legroom_inches = rng.randint(legroom_min, legroom_max)

    # Wi-Fi: roll for presence first, then for whether it is complimentary.
    # Free Wi-Fi is likelier in premium cabins and on 787 / A35x airframes.
    wifi_label = None
    wifi_present = rng.random() < capped(WIFI_AIRCRAFT, 0.3)
    if wifi_present:
        if premium:
            complimentary_odds = 0.7
        elif aircraft.startswith(("787", "A35")):
            complimentary_odds = 0.3
        else:
            complimentary_odds = 0.12
        if rng.random() < complimentary_odds:
            wifi_label = "Free Wi-Fi"
        else:
            wifi_label = "Wi-Fi for a fee"

    # Seat power. USB is always present alongside power; otherwise it gets its
    # own roll at a slightly higher probability.
    power_odds = capped(POWER_AIRCRAFT, 0.2)
    power_present = rng.random() < power_odds
    if power_present:
        usb_present = True
    else:
        usb_present = rng.random() < min(1.0, power_odds + 0.15)

    # In-flight entertainment; legs over 4000 km get an extra +0.15 chance.
    screen_odds = capped(VIDEO_AIRCRAFT, 0.05)
    if distance_km > 4000:
        screen_odds = min(1.0, screen_odds + 0.15)
    video_present = rng.random() < screen_odds
    video_label = None
    if video_present:
        flagships = ("A380", "777-300ER", "787-9", "787-10", "A350-900", "A350-1000")
        if aircraft in flagships:
            video_label = "On-demand video"
        elif distance_km > 3000:
            video_label = "On-demand video" if rng.random() < 0.7 else "Seatback screen"
        else:
            video_label = "Live TV" if rng.random() < 0.4 else "Seatback screen"

    return {
        "legroom_inches": legroom_inches,
        "has_wifi": wifi_present,
        "wifi_type": wifi_label,
        "has_power": power_present,
        "has_usb": usb_present,
        "has_video": video_present,
        "video_type": video_label,
    }
def _classify_route_type(origin_continent: str, dest_continent: str) -> str:
    """Map an (origin, destination) continent pair to a DEPARTURE_WEIGHTS key.

    Same-continent pairs are "domestic". The six directed pairs among North
    America ("NA"), Europe ("EU") and Asia ("AS") each have their own key.
    Every other pair is "default".
    """
    if origin_continent == dest_continent:
        return "domestic"
    directed_keys = {
        ("NA", "EU"): "na_to_eu",
        ("EU", "NA"): "eu_to_na",
        ("NA", "AS"): "na_to_asia",
        ("AS", "NA"): "asia_to_na",
        ("EU", "AS"): "eu_to_asia",
        ("AS", "EU"): "asia_to_eu",
    }
    return directed_keys.get((origin_continent, dest_continent), "default")
def _build_departure_weights(
    route_type: str,
    num_carriers: int,
    origin_route_count: int,
    rng: random.Random,
) -> list[float]:
    """Return 24 hourly departure weights (index = local departure hour 0–23).

    The starting point is the ``DEPARTURE_WEIGHTS`` profile for *route_type*
    (or the ``"default"`` profile). Three adjustments are applied in order:

    1. Operating window. Hours outside the window of the first
       ``POPULARITY_TIME_LIMITS`` entry whose carrier threshold
       *num_carriers* meets are zeroed (07–21 if none match). With two or
       more carriers, the closing hour may move later by a random draw of
       up to ``POPULARITY_LATEST_EXTENSION_MAX`` minutes. Only whole hours of
       that draw count, and the result is capped at 23.
    2. Curfew. Hours 0–5 are multiplied by the factor of the first
       ``CURFEW_FACTORS`` entry whose route threshold *origin_route_count*
       meets (0.01 if none match).
    3. Jitter. Every hour still above zero is multiplied by a uniform factor
       in ``[1 - DEPARTURE_WEIGHT_JITTER, 1 + DEPARTURE_WEIGHT_JITTER)``.

    Route types in ``OVERNIGHT_EXEMPT_ROUTES`` (e.g. asia_to_eu with
    23:00–02:30 departures) skip steps 1 and 2, so their late-night pattern
    survives. They are still jittered.
    """
    profile = DEPARTURE_WEIGHTS.get(route_type, DEPARTURE_WEIGHTS["default"])
    hourly = list(map(float, profile))

    if route_type not in OVERNIGHT_EXEMPT_ROUTES:
        # Step 1: operating window from the first popularity tier that matches.
        open_hour, close_hour = next(
            (
                (first, last)
                for min_carriers, first, last in POPULARITY_TIME_LIMITS
                if num_carriers >= min_carriers
            ),
            (7, 21),
        )
        if num_carriers >= 2:
            extra_hours = rng.randint(0, POPULARITY_LATEST_EXTENSION_MAX) // 60
            close_hour = min(23, close_hour + extra_hours)
        for hour in [*range(open_hour), *range(close_hour + 1, 24)]:
            hourly[hour] = 0.0

        # Step 2: damp the small hours according to how big the airport is.
        damping = next(
            (
                factor
                for min_routes, factor in CURFEW_FACTORS
                if origin_route_count >= min_routes
            ),
            0.01,
        )
        for hour in range(6):
            hourly[hour] *= damping

    # Step 3: per-route jitter on every hour that still carries weight.
    spread = 2 * DEPARTURE_WEIGHT_JITTER
    for hour in range(24):
        if hourly[hour] > 0:
            hourly[hour] *= (1.0 - DEPARTURE_WEIGHT_JITTER) + spread * rng.random()
    return hourly
def _sample_departure_minutes(
    weights: list[float], num_flights: int, rng: random.Random,
) -> list[int]:
    """Draw *num_flights* departure times, in minutes after local midnight.

    Each flight picks an hour with probability proportional to *weights*
    (index = hour 0–23), then a uniform minute 0–59 within that hour. When
    the weights sum to zero or less, times are instead drawn uniformly
    between 07:00 and 21:00 inclusive. The result is sorted ascending.
    """
    if sum(weights) <= 0:
        picks = [rng.randint(7 * 60, 21 * 60) for _ in range(num_flights)]
    else:
        chosen_hours = rng.choices(range(24), weights=weights, k=num_flights)
        picks = [60 * hour + rng.randint(0, 59) for hour in chosen_hours]
    picks.sort()
    return picks
def _carrier_weights_at_airport(
    graph: RouteGraph, origin: str, leg_carriers: list[dict],
) -> list[float]:
    """Return one selection weight per entry of *leg_carriers*.

    A carrier's weight is the square root of how many routes out of
    *origin* list it as a carrier. Carriers not found there count as 1.
    Hub airlines are therefore picked more often without completely
    crowding out smaller ones.
    """
    codes_at_origin = [
        listed["iata"]
        for route in graph.airports[origin].routes
        for listed in route.carriers
    ]
    presence: dict[str, int] = {}
    for code in codes_at_origin:
        presence[code] = presence.get(code, 0) + 1
    return [pow(presence.get(entry["iata"], 1), 0.5) for entry in leg_carriers]
def _get_timezone(graph: RouteGraph, iata: str) -> ZoneInfo:
    """Return the time zone of airport *iata*, falling back to UTC.

    UTC is returned when the airport is not in *graph*, has an empty/missing
    ``timezone`` value, or that value cannot be loaded as a ``ZoneInfo``.
    """
    airport = graph.airports.get(iata)
    if airport and airport.timezone:
        try:
            return ZoneInfo(airport.timezone)
        except (KeyError, ValueError):
            # ZoneInfoNotFoundError (a KeyError subclass) covers unknown keys.
            # ValueError covers malformed keys, e.g. absolute or
            # non-normalized paths. Both mean "no usable zone", so fall back.
            pass
    return ZoneInfo("UTC")
def _leg_type(graph: RouteGraph, origin_iata: str, dest_iata: str) -> str:
    """Classify a leg as 'domestic' (same country_code) or 'international'.

    A leg whose origin or destination is unknown to *graph* counts as
    'international'.
    """
    src = graph.airports.get(origin_iata)
    dst = graph.airports.get(dest_iata)
    if not (src and dst):
        return "international"
    return "domestic" if src.country_code == dst.country_code else "international"
def _min_connection_minutes(
    graph: RouteGraph,
    arriving_origin: str,
    hub: str,
    departing_dest: str,
) -> int:
    """Return the minimum connection time (minutes) at *hub*.

    ``MIN_CONNECTION_MINUTES`` is looked up by the pair (arriving leg type,
    departing leg type), each 'domestic' or 'international' per
    ``_leg_type``. Pairs not in the table fall back to
    ``MIN_LAYOVER_MINUTES``.
    """
    connection_kind = (
        _leg_type(graph, arriving_origin, hub),
        _leg_type(graph, hub, departing_dest),
    )
    return MIN_CONNECTION_MINUTES.get(connection_kind, MIN_LAYOVER_MINUTES)
def generate_flights_for_route(
    graph: RouteGraph,
    route_plan: RoutePlan,
    departure_date: date,
    cabin_class: CabinClass,
    hub_iatas: list[str],
) -> list[FlightOffer]:
    """Generate concrete flight offers for a route plan on a given date.

    The random generator is seeded from origin, destination, date, cabin
    class and every waypoint, so identical inputs yield identical offers.
    Nonstop plans (``stops == 0``) go to the direct-flight generator, and
    all others go to the connecting-flight generator.

    ``hub_iatas`` is accepted for interface compatibility but is not used
    here.
    """
    waypoints = route_plan.waypoints
    seed_key = f"{waypoints[0]}-{waypoints[-1]}-{departure_date.isoformat()}-{cabin_class.value}"
    rng = seeded_random(seed_key, *waypoints)
    builder = (
        _generate_direct_flights if route_plan.stops == 0 else _generate_connecting_flights
    )
    return builder(graph, route_plan, departure_date, cabin_class, rng)
def _generate_direct_flights(
    graph: RouteGraph,
    route_plan: RoutePlan,
    departure_date: date,
    cabin_class: CabinClass,
    rng: random.Random,
) -> list[FlightOffer]:
    """Generate multiple direct flight options for a single-leg route.

    The daily flight count depends on the leg's carrier count:
    ``MIN_FLIGHTS_PER_DAY``..``MAX_FLIGHTS_SINGLE_CARRIER`` for one carrier,
    3–8 for two or three, and 8..``MAX_FLIGHTS_MULTI_CARRIER`` otherwise. The
    first matching ``HUB_FLIGHT_BONUS`` tier for the origin's route count is
    added on top. Departure times follow ``_build_departure_weights`` for the
    route type. Each flight's carrier is drawn with a bias towards carriers
    serving more routes from the origin.

    Returns:
        One ``FlightOffer`` per sampled departure, or ``[]`` when the leg has
        no carriers. Offer ids are unique within the returned list.
    """
    leg = route_plan.legs[0]
    origin = route_plan.waypoints[0]
    destination = route_plan.waypoints[1]
    origin_airport = graph.airports[origin]
    dest_airport = graph.airports[destination]
    origin_route_count = len(origin_airport.routes)

    # A leg without carriers cannot produce flights. The connecting generator
    # already discards such legs; without this guard, rng.choices() below
    # fails on an empty population.
    if not leg.carriers:
        return []

    # Number of flights based on carrier count
    num_carriers = len(leg.carriers)
    if num_carriers == 1:
        num_flights = rng.randint(MIN_FLIGHTS_PER_DAY, MAX_FLIGHTS_SINGLE_CARRIER)
    elif num_carriers <= 3:
        num_flights = rng.randint(3, 8)
    else:
        num_flights = rng.randint(8, MAX_FLIGHTS_MULTI_CARRIER)

    # Hub airport bonus: the first tier whose route threshold the origin meets
    for min_routes, bonus_min, bonus_max in HUB_FLIGHT_BONUS:
        if origin_route_count >= min_routes:
            num_flights += rng.randint(bonus_min, bonus_max)
            break

    # Realistic departure time distribution based on route type
    route_type = _classify_route_type(origin_airport.continent, dest_airport.continent)
    dep_weights = _build_departure_weights(route_type, num_carriers, origin_route_count, rng)
    departure_minutes = _sample_departure_minutes(dep_weights, num_flights, rng)

    origin_tz = _get_timezone(graph, origin)
    dest_tz = _get_timezone(graph, destination)

    # Loop invariants: hub-airline carrier bias and per-flight emissions
    carrier_wts = _carrier_weights_at_airport(graph, origin, leg.carriers)
    emissions = int(leg.distance_km * EMISSIONS_KG_PER_KM.get(cabin_class.value, 0.09))

    flights = []
    seen_ids: set[str] = set()
    for dep_minutes in departure_minutes:
        carrier = rng.choices(leg.carriers, weights=carrier_wts, k=1)[0]
        dep_hour, dep_min = divmod(dep_minutes, 60)
        # Round-trip through UTC so a wall time inside a DST gap (which does
        # not exist locally) is normalized to the real instant it denotes.
        departure_dt = datetime(
            departure_date.year, departure_date.month, departure_date.day,
            dep_hour, dep_min,
            tzinfo=origin_tz,
        ).astimezone(timezone.utc).astimezone(origin_tz)

        # Add the flight duration in UTC. Aware datetime + timedelta is
        # wall-clock arithmetic in Python, which is off by an hour when a DST
        # change falls inside the flight.
        arrival_dt = (
            departure_dt.astimezone(timezone.utc) + timedelta(minutes=leg.duration_min)
        ).astimezone(dest_tz)

        price = compute_price(
            distance_km=leg.distance_km,
            cabin_class=cabin_class.value,
            departure_date=departure_date,
            departure_hour=dep_hour,
            num_carriers=num_carriers,
            dest_continent=dest_airport.continent,
            rng=rng,
        )

        # Two flights can share the same minute and carrier. Keep the
        # historical id when it is unique and add a suffix only on collision.
        base_id = f"{origin}{destination}{departure_date.isoformat()}{dep_minutes}{carrier['iata']}"
        flight_id = base_id
        dup_index = 1
        while flight_id in seen_ids:
            flight_id = f"{base_id}-{dup_index}"
            dup_index += 1
        seen_ids.add(flight_id)

        aircraft = _pick_aircraft(leg.distance_km, rng)
        amenities = _generate_amenities(aircraft, cabin_class, leg.distance_km, rng)

        # Overnight: compare local calendar dates in their respective time
        # zones (.date() on an aware datetime gives the date in its own tz).
        is_overnight = arrival_dt.date() != departure_dt.date()

        segment = FlightSegment(
            airline_code=carrier["iata"],
            airline_name=carrier["name"],
            flight_number=_make_flight_number(carrier["iata"], rng),
            aircraft=aircraft,
            origin=origin,
            origin_city=origin_airport.city_name,
            destination=destination,
            destination_city=dest_airport.city_name,
            departure=departure_dt,
            arrival=arrival_dt,
            duration_minutes=leg.duration_min,
            is_overnight=is_overnight,
            **amenities,
        )
        flights.append(FlightOffer(
            id=flight_id,
            segments=[segment],
            total_duration_minutes=leg.duration_min,
            stops=0,
            price_usd=price,
            cabin_class=cabin_class,
            origin=origin,
            destination=destination,
            departure=departure_dt,
            arrival=arrival_dt,
            emissions_kg=emissions,
        ))
    return flights
def _generate_connecting_flights(
    graph: RouteGraph,
    route_plan: RoutePlan,
    departure_date: date,
    cabin_class: CabinClass,
    rng: random.Random,
) -> list[FlightOffer]:
    """Generate connecting flight options (1-stop or 2-stop).

    Between 2 and 5 itineraries are sampled. Each departs the trip origin at
    a time drawn from the route-type departure distribution, then chains
    every leg of *route_plan*:

    - a carrier is drawn per leg, weighted towards carriers with more routes
      at that leg's origin;
    - the leg is priced with ``is_connection=True`` and multiplied by
      ``CONNECTING_BASE_DISCOUNT``;
    - a layover is inserted before the next leg. Its length runs from the
      connection-type minimum (``_min_connection_minutes``) up to
      ``MAX_LAYOVER_MINUTES``.

    An itinerary is dropped when a leg has no carriers, or when a connection
    would depart 2 or more days after local midnight of *departure_date* at
    the trip origin. If all segments share one carrier, the total is further
    multiplied by ``SAME_AIRLINE_CONNECTION_DISCOUNT``. The total is rounded
    with ``round(total, 0)``.

    All elapsed-time arithmetic is done in UTC, so flights and layovers that
    span a DST change get correct local times.

    Returns:
        The surviving itineraries as ``FlightOffer`` objects (possibly empty).
    """
    origin = route_plan.waypoints[0]
    destination = route_plan.waypoints[-1]
    origin_airport = graph.airports[origin]
    dest_airport = graph.airports[destination]
    origin_route_count = len(origin_airport.routes)
    trip_origin_tz = _get_timezone(graph, origin)

    # Realistic departure times for the first leg
    first_leg = route_plan.legs[0]
    num_first_carriers = len(first_leg.carriers) if first_leg.carriers else 1
    route_type = _classify_route_type(origin_airport.continent, dest_airport.continent)
    dep_weights = _build_departure_weights(route_type, num_first_carriers, origin_route_count, rng)

    # Generate 2-5 options per connecting route
    num_options = rng.randint(2, 5)
    dep_times = _sample_departure_minutes(dep_weights, num_options, rng)

    # Invariants across options. The "too late" check is always measured from
    # local midnight of departure_date at the TRIP origin. (The per-leg
    # origin_tz inside the loop would move this reference for later legs.)
    trip_day_start_utc = datetime(
        departure_date.year, departure_date.month, departure_date.day,
        tzinfo=trip_origin_tz,
    ).astimezone(timezone.utc)
    total_distance = sum(leg.distance_km for leg in route_plan.legs)
    emissions = int(total_distance * EMISSIONS_KG_PER_KM.get(cabin_class.value, 0.09))

    flights = []
    for option_idx, departure_minutes in enumerate(dep_times):
        segments = []
        layover_infos: list[LayoverInfo] = []
        # Kept in UTC between legs; converting from the local wall time also
        # normalizes a departure that falls inside a DST gap.
        current_time = datetime(
            departure_date.year, departure_date.month, departure_date.day,
            departure_minutes // 60, departure_minutes % 60,
            tzinfo=trip_origin_tz,
        ).astimezone(timezone.utc)
        total_price = 0.0
        total_duration = 0
        valid = True
        for i, leg in enumerate(route_plan.legs):
            leg_origin = route_plan.waypoints[i]
            leg_dest = route_plan.waypoints[i + 1]
            origin_tz = _get_timezone(graph, leg_origin)
            dest_tz = _get_timezone(graph, leg_dest)
            origin_ap = graph.airports[leg_origin]
            dest_ap = graph.airports[leg_dest]
            if not leg.carriers:
                valid = False
                break
            leg_cwts = _carrier_weights_at_airport(graph, leg_origin, leg.carriers)
            carrier = rng.choices(leg.carriers, weights=leg_cwts, k=1)[0]

            departure_dt = current_time.astimezone(origin_tz)
            # Flight duration added in UTC (aware + timedelta is wall-clock
            # arithmetic, wrong by an hour across a DST change).
            arrival_dt = (
                current_time.astimezone(timezone.utc) + timedelta(minutes=leg.duration_min)
            ).astimezone(dest_tz)
            # Overnight: compare local calendar dates in their respective timezones
            is_overnight = arrival_dt.date() != departure_dt.date()

            # Per-leg price with connection discount
            # (see price_engine.py and config.py for full pricing docs)
            leg_price = compute_price(
                distance_km=leg.distance_km,
                cabin_class=cabin_class.value,
                departure_date=departure_date,
                departure_hour=departure_dt.hour,
                num_carriers=len(leg.carriers),
                dest_continent=dest_ap.continent,
                rng=rng,
                is_connection=True,
            )
            leg_price *= CONNECTING_BASE_DISCOUNT
            total_price += leg_price

            aircraft = _pick_aircraft(leg.distance_km, rng)
            amenities = _generate_amenities(aircraft, cabin_class, leg.distance_km, rng)
            segments.append(FlightSegment(
                airline_code=carrier["iata"],
                airline_name=carrier["name"],
                flight_number=_make_flight_number(carrier["iata"], rng),
                aircraft=aircraft,
                origin=leg_origin,
                origin_city=origin_ap.city_name,
                destination=leg_dest,
                destination_city=dest_ap.city_name,
                departure=departure_dt,
                arrival=arrival_dt,
                duration_minutes=leg.duration_min,
                is_overnight=is_overnight,
                **amenities,
            ))

            if i < len(route_plan.legs) - 1:
                # Layover before the next leg
                next_dest = route_plan.waypoints[i + 2]
                min_conn = _min_connection_minutes(graph, leg_origin, leg_dest, next_dest)
                # max() keeps randint valid if a configured minimum connection
                # time ever exceeds MAX_LAYOVER_MINUTES.
                layover = rng.randint(min_conn, max(min_conn, MAX_LAYOVER_MINUTES))
                current_time = (
                    arrival_dt.astimezone(timezone.utc) + timedelta(minutes=layover)
                ).astimezone(dest_tz)
                total_duration += leg.duration_min + layover
                # Both times are expressed in the hub's tz (dest_tz), so this
                # compares hub-local calendar dates.
                layover_overnight = current_time.date() != arrival_dt.date()
                layover_infos.append(LayoverInfo(
                    duration_minutes=layover,
                    airport=leg_dest,
                    airport_city=dest_ap.city_name,
                    is_overnight=layover_overnight,
                    is_short_connection=layover < SHORT_CONNECTION_MINUTES,
                ))
                # Reject itineraries whose next leg would depart 2+ days after
                # the trip's local start-of-day.
                if (current_time.astimezone(timezone.utc) - trip_day_start_utc).days > 1:
                    valid = False
                    break
            else:
                total_duration += leg.duration_min

        if not valid:
            continue

        # Same-airline connection discount: if all segments are on the same
        # carrier, apply an additional discount (bundled itinerary fare).
        carrier_codes = {seg.airline_code for seg in segments}
        if len(carrier_codes) == 1:
            total_price *= SAME_AIRLINE_CONNECTION_DISCOUNT
        total_price = round(total_price, 0)

        first_departure = segments[0].departure
        last_arrival = segments[-1].arrival
        flight_id = (
            f"{origin}{destination}{departure_date.isoformat()}"
            f"{departure_minutes}{'-'.join(route_plan.waypoints)}{option_idx}"
        )
        flights.append(FlightOffer(
            id=flight_id,
            segments=segments,
            layovers=layover_infos,
            total_duration_minutes=total_duration,
            stops=route_plan.stops,
            price_usd=total_price,
            cabin_class=cabin_class,
            origin=origin,
            destination=destination,
            departure=first_departure,
            arrival=last_arrival,
            emissions_kg=emissions,
        ))
    return flights