"""ML policy helpers for flight rebooking.

This module provides:
- deterministic expert policy used for dataset generation,
- fixed-length feature extraction for supervised learning,
- safe action construction from ranked action-type preferences.
"""

from __future__ import annotations

from typing import Any, Dict, Iterable, List, Optional

from environment import ActionType, CabinClass, PriorityTier

# Canonical ordering of action-type values.
ACTION_TYPE_ORDER: List[str] = [
    ActionType.REBOOK_PASSENGER.value,
    ActionType.OFFER_DOWNGRADE.value,
    ActionType.REBOOK_ON_PARTNER.value,
    ActionType.BOOK_HOTEL.value,
    ActionType.MARK_NO_SOLUTION.value,
    ActionType.FINALIZE.value,
]

# Minimum remaining budget required before each paid action is proposed.
# Shared by ``heuristic_action`` and ``build_feasible_action_for_type`` so the
# expert policy and the feasibility filter cannot drift apart.
_DOWNGRADE_MIN_BUDGET = 500.0
_PARTNER_MIN_BUDGET = 800.0
_HOTEL_MIN_BUDGET = 250.0

# Sort sentinel for "no deadline" / "no departure time": orders after any
# realistic number of hours.
_FAR_FUTURE_HRS = 1e9

# Each passenger slot and each flight slot contributes this many features.
_SLOT_WIDTH = 4

# Priority tier value -> sort/feature weight (higher is served first).
_TIER_WEIGHTS: Dict[str, int] = {
    PriorityTier.PLATINUM.value: 4,
    PriorityTier.GOLD.value: 3,
    PriorityTier.SILVER.value: 2,
    PriorityTier.STANDARD.value: 1,
}


def _tier_weight(tier: str) -> int:
    """Return the weight of a priority tier; unknown tiers weigh 1."""
    return _TIER_WEIGHTS.get(tier, 1)


def _deadline_sort_value(deadline_hrs: Optional[float]) -> float:
    """Return a sortable deadline; a missing deadline sorts last."""
    if deadline_hrs is None:
        return _FAR_FUTURE_HRS
    return float(deadline_hrs)


def _has_seat(flight: Dict[str, Any], cabin_class: str) -> bool:
    """Return True if *flight* has a free seat for *cabin_class*.

    Business passengers need a business seat; every other cabin class is
    checked against economy seats. A flight without the seat-count key is
    treated as having no seats, matching the ``.get(..., 0)`` reads used
    elsewhere in this module.
    """
    if cabin_class == CabinClass.BUSINESS.value:
        seat_key = "business_seats"
    else:
        seat_key = "economy_seats"
    return int(flight.get(seat_key, 0)) > 0


def _is_partner(flight: Dict[str, Any]) -> bool:
    """Return True if *flight* is flagged as a partner flight."""
    return bool(flight.get("is_partner", False))


def _budget_remaining(observation: Dict[str, Any]) -> float:
    """Return the remaining budget from *observation* (0.0 when absent)."""
    return float(observation.get("budget_remaining", 0.0))


def _sorted_pending_passengers(observation: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Return pending passengers, highest tier first, then earliest deadline.

    The sort is stable, so passengers with equal keys keep their input order.
    """
    return sorted(
        observation.get("pending_passengers", []),
        key=lambda p: (
            -_tier_weight(str(p.get("priority_tier", ""))),
            _deadline_sort_value(p.get("connection_deadline_hrs")),
        ),
    )


def _sorted_flights(observation: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Return available flights ordered by departure time, earliest first.

    Flights without ``departure_hrs`` sort last.
    """
    return sorted(
        observation.get("available_flights", []),
        key=lambda f: float(f.get("departure_hrs", _FAR_FUTURE_HRS)),
    )


def heuristic_action(observation: Dict[str, Any]) -> Dict[str, Any]:
    """Return the deterministic expert action for *observation*.

    Only the top-ranked pending passenger is considered. Options are tried in
    this order, and the first feasible one is returned:

    1. rebook on the earliest non-partner flight with a seat in their cabin,
    2. (business passengers) downgrade to economy on a non-partner flight,
    3. rebook on the earliest partner flight with a seat in their cabin,
    4. book a hotel,
    5. mark the passenger as having no solution.

    Options 2-4 also require the remaining budget to meet the corresponding
    minimum. With no pending passengers the action is ``FINALIZE``.
    """
    pending = _sorted_pending_passengers(observation)
    if not pending:
        return {"action_type": ActionType.FINALIZE.value}

    passenger = pending[0]
    flights = _sorted_flights(observation)
    own_flights = [f for f in flights if not _is_partner(f)]
    partner_flights = [f for f in flights if _is_partner(f)]

    for flight in own_flights:
        if _has_seat(flight, str(passenger["cabin_class"])):
            return {
                "action_type": ActionType.REBOOK_PASSENGER.value,
                "passenger_id": passenger["id"],
                "flight_id": flight["id"],
            }

    # Every remaining option depends on budget; parse it once.
    budget = _budget_remaining(observation)

    if passenger.get("cabin_class") == CabinClass.BUSINESS.value:
        for flight in own_flights:
            if int(flight.get("economy_seats", 0)) > 0 and budget >= _DOWNGRADE_MIN_BUDGET:
                return {
                    "action_type": ActionType.OFFER_DOWNGRADE.value,
                    "passenger_id": passenger["id"],
                    "flight_id": flight["id"],
                }

    for flight in partner_flights:
        if _has_seat(flight, str(passenger["cabin_class"])) and budget >= _PARTNER_MIN_BUDGET:
            return {
                "action_type": ActionType.REBOOK_ON_PARTNER.value,
                "passenger_id": passenger["id"],
                "flight_id": flight["id"],
            }

    if budget >= _HOTEL_MIN_BUDGET:
        return {
            "action_type": ActionType.BOOK_HOTEL.value,
            "passenger_id": passenger["id"],
        }

    return {
        "action_type": ActionType.MARK_NO_SOLUTION.value,
        "passenger_id": passenger["id"],
    }


def observation_to_features(
    observation: Dict[str, Any],
    max_pending: int = 5,
    max_flights: int = 6,
) -> List[float]:
    """Encode *observation* as a flat feature vector.

    Layout, in order:

    - 8 global features (pending/flight counts, budget ratios, progress,
      invalid-action and step counters, "any pending" flag),
    - 4 features for each of the first ``max_pending`` sorted passengers
      (tier weight, is-business, has-deadline, normalised deadline),
    - 4 features for each of the first ``max_flights`` sorted flights
      (is-partner, normalised departure, economy seats, business seats),
    - 4 aggregate seat totals (own economy/business, partner
      economy/business) over all flights.

    Unused passenger and flight slots are zero-padded, so for non-negative
    limits the length is ``8 + 4 * max_pending + 4 * max_flights + 4``
    (56 with the defaults).
    """
    pending = _sorted_pending_passengers(observation)
    flights = _sorted_flights(observation)

    processed_count = float(observation.get("processed_count", 0))
    total_passengers = max(float(observation.get("total_passengers", 1)), 1.0)
    budget_remaining = max(float(observation.get("budget_remaining", 0.0)), 0.0)
    budget_spent = max(float(observation.get("budget_spent", 0.0)), 0.0)
    # Floor at 1.0 so the ratios below never divide by zero.
    budget_total = max(budget_remaining + budget_spent, 1.0)

    features: List[float] = [
        min(len(pending), 20) / 20.0,
        min(len(flights), 20) / 20.0,
        budget_remaining / budget_total,
        budget_spent / budget_total,
        processed_count / total_passengers,
        min(float(observation.get("invalid_actions", 0)), 20.0) / 20.0,
        min(float(observation.get("step_count", 0)), 120.0) / 120.0,
        1.0 if pending else 0.0,
    ]

    shown_passengers = pending[:max_pending]
    for passenger in shown_passengers:
        deadline = passenger.get("connection_deadline_hrs")
        if deadline is not None:
            has_deadline = 1.0
            deadline_norm = min(float(deadline), 12.0) / 12.0
        else:
            # No deadline is encoded as the maximum (least urgent) value.
            has_deadline = 0.0
            deadline_norm = 1.0
        features.extend(
            [
                _tier_weight(str(passenger.get("priority_tier", ""))) / 4.0,
                1.0 if passenger.get("cabin_class") == CabinClass.BUSINESS.value else 0.0,
                has_deadline,
                deadline_norm,
            ]
        )
    # Zero-pad unused passenger slots (a non-positive count adds nothing).
    features.extend([0.0] * (_SLOT_WIDTH * (max_pending - len(shown_passengers))))

    shown_flights = flights[:max_flights]
    for flight in shown_flights:
        features.extend(
            [
                1.0 if _is_partner(flight) else 0.0,
                min(float(flight.get("departure_hrs", 12.0)), 12.0) / 12.0,
                min(float(flight.get("economy_seats", 0.0)), 12.0) / 12.0,
                min(float(flight.get("business_seats", 0.0)), 6.0) / 6.0,
            ]
        )
    # Zero-pad unused flight slots.
    features.extend([0.0] * (_SLOT_WIDTH * (max_flights - len(shown_flights))))

    # Seat totals use ALL flights, not just the ones that fit in the slots.
    same_econ = 0.0
    same_bus = 0.0
    partner_econ = 0.0
    partner_bus = 0.0
    for flight in flights:
        econ = float(flight.get("economy_seats", 0.0))
        bus = float(flight.get("business_seats", 0.0))
        if _is_partner(flight):
            partner_econ += econ
            partner_bus += bus
        else:
            same_econ += econ
            same_bus += bus

    features.extend(
        [
            min(same_econ, 30.0) / 30.0,
            min(same_bus, 20.0) / 20.0,
            min(partner_econ, 30.0) / 30.0,
            min(partner_bus, 20.0) / 20.0,
        ]
    )
    return features


def build_feasible_action_for_type(
    observation: Dict[str, Any], action_type: str
) -> Optional[Dict[str, Any]]:
    """Build a concrete, feasible action of *action_type*, or return None.

    With no pending passengers the result is always ``FINALIZE``, whatever
    *action_type* was requested. While passengers are still pending,
    ``FINALIZE`` is never returned. For the flight-based action types,
    passengers are scanned in priority order and flights in departure order,
    and the first feasible pair wins. Unknown action types yield None.
    """
    pending = _sorted_pending_passengers(observation)
    flights = _sorted_flights(observation)
    budget = _budget_remaining(observation)

    if not pending:
        return {"action_type": ActionType.FINALIZE.value}

    if action_type == ActionType.FINALIZE.value:
        return None

    if action_type == ActionType.BOOK_HOTEL.value:
        if budget < _HOTEL_MIN_BUDGET:
            return None
        return {
            "action_type": ActionType.BOOK_HOTEL.value,
            "passenger_id": pending[0]["id"],
        }

    if action_type == ActionType.MARK_NO_SOLUTION.value:
        return {
            "action_type": ActionType.MARK_NO_SOLUTION.value,
            "passenger_id": pending[0]["id"],
        }

    if action_type == ActionType.REBOOK_PASSENGER.value:
        own_flights = [f for f in flights if not _is_partner(f)]
        for passenger in pending:
            for flight in own_flights:
                if _has_seat(flight, str(passenger["cabin_class"])):
                    return {
                        "action_type": ActionType.REBOOK_PASSENGER.value,
                        "passenger_id": passenger["id"],
                        "flight_id": flight["id"],
                    }
        return None

    if action_type == ActionType.OFFER_DOWNGRADE.value:
        if budget < _DOWNGRADE_MIN_BUDGET:
            return None
        own_flights = [f for f in flights if not _is_partner(f)]
        # Only business passengers can be downgraded.
        for passenger in pending:
            if passenger.get("cabin_class") != CabinClass.BUSINESS.value:
                continue
            for flight in own_flights:
                if int(flight.get("economy_seats", 0)) > 0:
                    return {
                        "action_type": ActionType.OFFER_DOWNGRADE.value,
                        "passenger_id": passenger["id"],
                        "flight_id": flight["id"],
                    }
        return None

    if action_type == ActionType.REBOOK_ON_PARTNER.value:
        if budget < _PARTNER_MIN_BUDGET:
            return None
        partner_flights = [f for f in flights if _is_partner(f)]
        for passenger in pending:
            for flight in partner_flights:
                if _has_seat(flight, str(passenger["cabin_class"])):
                    return {
                        "action_type": ActionType.REBOOK_ON_PARTNER.value,
                        "passenger_id": passenger["id"],
                        "flight_id": flight["id"],
                    }
        return None

    return None


def choose_action_from_ranked_types(
    observation: Dict[str, Any], ranked_types: Iterable[str]
) -> Dict[str, Any]:
    """Return the first feasible action following *ranked_types* order.

    Falsy entries are skipped. If no ranked type yields a feasible action,
    the result falls back to ``heuristic_action(observation)``.
    """
    for action_type in ranked_types:
        if not action_type:
            continue
        candidate = build_feasible_action_for_type(observation, str(action_type))
        if candidate is not None:
            return candidate
    return heuristic_action(observation)