flight-rebooking / environment.py
dhnkhr's picture
Production-ready: Clean code with Groq API integration, LoRA model support, and FastAPI app
9753ee2
"""
Flight Rebooking Environment Engine
===================================
Real-world simulation of airline disruption recovery where an agent must
rebook stranded passengers under strict business constraints.
OpenEnv interface:
- reset() -> Observation
- step(Action) -> tuple[Observation, Reward, bool, dict]
- state() -> EnvState
"""
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel, Field
class PriorityTier(str, Enum):
PLATINUM = "Platinum"
GOLD = "Gold"
SILVER = "Silver"
STANDARD = "Standard"
class CabinClass(str, Enum):
BUSINESS = "Business"
ECONOMY = "Economy"
class PassengerStatus(str, Enum):
PENDING = "pending"
REBOOKED = "rebooked"
DOWNGRADED = "downgraded"
HOTEL_BOOKED = "hotel_booked"
PARTNER_REBOOKED = "partner_rebooked"
NO_SOLUTION = "no_solution"
class ActionType(str, Enum):
REBOOK_PASSENGER = "rebook_passenger"
OFFER_DOWNGRADE = "offer_downgrade"
BOOK_HOTEL = "book_hotel"
REBOOK_ON_PARTNER = "rebook_on_partner"
MARK_NO_SOLUTION = "mark_no_solution"
FINALIZE = "finalize"
class Passenger(BaseModel):
"""A stranded passenger awaiting re-accommodation."""
id: str
name: str
priority_tier: PriorityTier
original_flight: str
cabin_class: CabinClass
connection_deadline_hrs: Optional[float] = None
status: PassengerStatus = PassengerStatus.PENDING
assigned_flight: Optional[str] = None
class Flight(BaseModel):
"""A candidate replacement flight."""
id: str
destination: str
departure_hrs: float
economy_seats: int
business_seats: int
is_partner: bool = False
class Action(BaseModel):
"""Action model consumed by step()."""
action_type: ActionType
passenger_id: Optional[str] = None
flight_id: Optional[str] = None
class Reward(BaseModel):
"""Typed reward payload in the [0.0, 1.0] range."""
value: float = Field(ge=0.0, le=1.0)
components: Dict[str, float] = Field(default_factory=dict)
notes: List[str] = Field(default_factory=list)
class Observation(BaseModel):
"""Agent-visible state after each transition."""
pending_passengers: List[Dict[str, Any]]
available_flights: List[Dict[str, Any]]
budget_remaining: float
budget_spent: float
processed_count: int
total_passengers: int
invalid_actions: int
step_count: int
class EnvState(BaseModel):
"""Full simulator state for graders and debugging."""
passengers: List[Passenger] = Field(default_factory=list)
flights: List[Flight] = Field(default_factory=list)
budget_spent: float = 0.0
max_budget: float = 0.0
actions_taken: List[Dict[str, Any]] = Field(default_factory=list)
invalid_actions: int = 0
finalized: bool = False
step_count: int = 0
ACTION_COSTS = {
ActionType.REBOOK_PASSENGER: 0.0,
ActionType.OFFER_DOWNGRADE: 500.0,
ActionType.BOOK_HOTEL: 250.0,
ActionType.REBOOK_ON_PARTNER: 800.0,
ActionType.MARK_NO_SOLUTION: 0.0,
ActionType.FINALIZE: 0.0,
}
PRIORITY_WEIGHTS = {
PriorityTier.PLATINUM: 4,
PriorityTier.GOLD: 3,
PriorityTier.SILVER: 2,
PriorityTier.STANDARD: 1,
}
OUTCOME_QUALITY = {
PassengerStatus.REBOOKED: 1.00,
PassengerStatus.PARTNER_REBOOKED: 0.85,
PassengerStatus.DOWNGRADED: 0.65,
PassengerStatus.HOTEL_BOOKED: 0.45,
PassengerStatus.NO_SOLUTION: 0.05,
}
class FlightRebookingEnv:
"""OpenEnv-compatible flight rebooking simulator."""
def __init__(self, task_data: dict):
self.task_data = task_data
self._state: Optional[EnvState] = None
self._step_count = 0
self._max_steps = int(task_data.get("max_steps", 80))
def reset(self) -> Observation:
passengers = [Passenger(**p) for p in self.task_data["passengers"]]
flights = [Flight(**f) for f in self.task_data["flights"]]
self._state = EnvState(
passengers=passengers,
flights=flights,
budget_spent=0.0,
max_budget=self.task_data["max_budget"],
actions_taken=[],
invalid_actions=0,
finalized=False,
step_count=0,
)
self._step_count = 0
return self._get_observation()
def state(self) -> EnvState:
if self._state is None:
raise RuntimeError("Environment is not initialized. Call reset() first.")
return self._state
def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
if self._state is None:
raise RuntimeError("Environment is not initialized. Call reset() first.")
self._step_count += 1
self._state.step_count = self._step_count
info: Dict[str, Any] = {}
if self._state.finalized:
reward = Reward(value=0.01, components={"terminal": 1.0}, notes=["episode_already_finalized"])
return self._get_observation(), reward, True, {"warning": "Episode already finalized."}
if self._step_count > self._max_steps:
self._state.finalized = True
reward = Reward(
value=0.01,
components={
"progress": self._completion_ratio(),
"budget_efficiency": self._budget_efficiency(),
"max_step_exceeded": 1.0,
},
notes=["forced_finalize_max_steps"],
)
self._record_action(action, reward, success=False, done=True, info={"warning": "Max steps reached."})
return self._get_observation(), reward, True, {"warning": "Max steps reached, forcing finalize."}
if action.action_type == ActionType.FINALIZE:
reward = self._build_finalize_reward()
self._state.finalized = True
unresolved = [p.id for p in self._state.passengers if p.status == PassengerStatus.PENDING]
if unresolved:
info["unresolved_passengers"] = unresolved
self._record_action(action, reward, success=(len(unresolved) == 0), done=True, info=info)
return self._get_observation(), reward, True, info
passenger = self._find_passenger(action.passenger_id)
if passenger is None:
reward = self._invalid_reward("passenger_not_found")
info["error"] = f"Passenger not found: {action.passenger_id}"
self._record_action(action, reward, success=False, done=False, info=info)
return self._get_observation(), reward, False, info
if passenger.status != PassengerStatus.PENDING:
reward = self._invalid_reward("passenger_already_processed")
info["error"] = f"Passenger {action.passenger_id} already processed ({passenger.status.value})."
self._record_action(action, reward, success=False, done=False, info=info)
return self._get_observation(), reward, False, info
priority_inversion = self._has_higher_priority_pending(passenger)
handler = {
ActionType.REBOOK_PASSENGER: self._handle_rebook,
ActionType.OFFER_DOWNGRADE: self._handle_downgrade,
ActionType.BOOK_HOTEL: self._handle_hotel,
ActionType.REBOOK_ON_PARTNER: self._handle_partner,
ActionType.MARK_NO_SOLUTION: self._handle_no_solution,
}[action.action_type]
success, action_info = handler(passenger, action)
info.update(action_info)
if not success:
reward = self._invalid_reward(info.get("error", "invalid_action"))
self._record_action(action, reward, success=False, done=False, info=info)
return self._get_observation(), reward, False, info
repeat_penalty = self._repeat_failure_penalty(action)
reward = self._build_resolution_reward(
passenger=passenger,
flight=self._find_flight(passenger.assigned_flight),
action_cost=ACTION_COSTS[action.action_type],
priority_inversion=priority_inversion,
repeat_penalty=repeat_penalty,
)
done = all(p.status != PassengerStatus.PENDING for p in self._state.passengers)
if done:
self._state.finalized = True
reward = self._add_terminal_bonus(reward)
info["auto_finalized"] = True
self._record_action(action, reward, success=True, done=done, info=info)
return self._get_observation(), reward, done, info
def _handle_rebook(self, passenger: Passenger, action: Action) -> Tuple[bool, Dict[str, Any]]:
flight = self._find_flight(action.flight_id)
if flight is None:
return False, {"error": f"Flight not found: {action.flight_id}"}
if flight.is_partner:
return False, {"error": "Use rebook_on_partner for partner flights."}
ok, msg = self._consume_seat(flight, passenger.cabin_class)
if not ok:
return False, {"error": msg}
passenger.status = PassengerStatus.REBOOKED
passenger.assigned_flight = flight.id
return True, {"resolved_status": passenger.status.value}
def _handle_downgrade(self, passenger: Passenger, action: Action) -> Tuple[bool, Dict[str, Any]]:
if passenger.cabin_class != CabinClass.BUSINESS:
return False, {"error": "Can only downgrade Business passengers."}
cost = ACTION_COSTS[ActionType.OFFER_DOWNGRADE]
if not self._spend(cost):
return False, {"error": f"Insufficient budget. Need ${cost:.0f}, have ${self._budget_remaining():.0f}."}
flight = self._find_flight(action.flight_id)
if flight is None:
self._refund(cost)
return False, {"error": f"Flight not found: {action.flight_id}"}
ok, msg = self._consume_seat(flight, CabinClass.ECONOMY)
if not ok:
self._refund(cost)
return False, {"error": msg}
passenger.status = PassengerStatus.DOWNGRADED
passenger.assigned_flight = flight.id
return True, {"resolved_status": passenger.status.value}
def _handle_hotel(self, passenger: Passenger, action: Action) -> Tuple[bool, Dict[str, Any]]:
cost = ACTION_COSTS[ActionType.BOOK_HOTEL]
if not self._spend(cost):
return False, {"error": f"Insufficient budget. Need ${cost:.0f}, have ${self._budget_remaining():.0f}."}
passenger.status = PassengerStatus.HOTEL_BOOKED
passenger.assigned_flight = None
return True, {"resolved_status": passenger.status.value}
def _handle_partner(self, passenger: Passenger, action: Action) -> Tuple[bool, Dict[str, Any]]:
cost = ACTION_COSTS[ActionType.REBOOK_ON_PARTNER]
if not self._spend(cost):
return False, {"error": f"Insufficient budget. Need ${cost:.0f}, have ${self._budget_remaining():.0f}."}
flight = self._find_flight(action.flight_id)
if flight is None:
self._refund(cost)
return False, {"error": f"Flight not found: {action.flight_id}"}
if not flight.is_partner:
self._refund(cost)
return False, {"error": f"Flight {action.flight_id} is not a partner flight."}
ok, msg = self._consume_seat(flight, passenger.cabin_class)
if not ok:
self._refund(cost)
return False, {"error": msg}
passenger.status = PassengerStatus.PARTNER_REBOOKED
passenger.assigned_flight = flight.id
return True, {"resolved_status": passenger.status.value}
def _handle_no_solution(self, passenger: Passenger, action: Action) -> Tuple[bool, Dict[str, Any]]:
passenger.status = PassengerStatus.NO_SOLUTION
passenger.assigned_flight = None
return True, {"resolved_status": passenger.status.value}
def _invalid_reward(self, reason: str) -> Reward:
self._state.invalid_actions += 1
penalty = min(0.08 * self._state.invalid_actions, 0.5)
return Reward(
value=max(0.01, 0.05 - penalty),
components={
"progress": self._completion_ratio(),
"budget_efficiency": self._budget_efficiency(),
"invalid_action_penalty": penalty,
},
notes=[reason, "invalid_action"],
)
def _build_resolution_reward(
self,
passenger: Passenger,
flight: Optional[Flight],
action_cost: float,
priority_inversion: bool,
repeat_penalty: float,
) -> Reward:
progress = self._completion_ratio()
outcome_quality = OUTCOME_QUALITY.get(passenger.status, 0.0)
priority_score = PRIORITY_WEIGHTS[passenger.priority_tier] / 4.0
deadline_score = self._deadline_score(passenger, flight)
budget_efficiency = self._budget_efficiency()
penalty = 0.0
notes: List[str] = []
if priority_inversion:
penalty += 0.15
notes.append("priority_inversion")
if repeat_penalty > 0:
penalty += repeat_penalty
notes.append("repeated_failed_action_pattern")
if passenger.status == PassengerStatus.NO_SOLUTION:
penalty += 0.2
notes.append("no_solution_penalty")
if action_cost > 0:
# Costly actions are valid but receive a mild regularization penalty.
penalty += min(action_cost / max(self._state.max_budget, 1.0), 0.15)
base = (
(0.30 * outcome_quality)
+ (0.25 * progress)
+ (0.15 * priority_score)
+ (0.15 * deadline_score)
+ (0.15 * budget_efficiency)
)
value = self._clamp(base - penalty)
return Reward(
value=value,
components={
"progress": progress,
"outcome_quality": outcome_quality,
"priority_score": priority_score,
"deadline_score": deadline_score,
"budget_efficiency": budget_efficiency,
"penalty": penalty,
},
notes=notes,
)
def _build_finalize_reward(self) -> Reward:
pending_count = sum(1 for p in self._state.passengers if p.status == PassengerStatus.PENDING)
total = max(len(self._state.passengers), 1)
completion = self._completion_ratio()
budget_efficiency = self._budget_efficiency()
if pending_count == 0:
value = self._clamp((0.85 * completion) + (0.15 * budget_efficiency))
notes = ["clean_finalize"]
else:
unresolved_penalty = pending_count / total
value = self._clamp(0.20 * completion - 0.30 * unresolved_penalty)
notes = ["early_finalize_penalty"]
return Reward(
value=value,
components={
"completion": completion,
"budget_efficiency": budget_efficiency,
"pending_ratio": pending_count / total,
},
notes=notes,
)
def _add_terminal_bonus(self, reward: Reward) -> Reward:
bonus = 0.1 * max(0.0, 1.0 - (self._state.invalid_actions * 0.05))
merged = dict(reward.components)
merged["terminal_bonus"] = bonus
return Reward(
value=self._clamp(reward.value + bonus),
components=merged,
notes=reward.notes + ["all_passengers_processed"],
)
def _deadline_score(self, passenger: Passenger, flight: Optional[Flight]) -> float:
if passenger.connection_deadline_hrs is None:
return 1.0
if flight is None:
return 0.0
if flight.departure_hrs <= passenger.connection_deadline_hrs:
return 1.0
return 0.2
def _has_higher_priority_pending(self, passenger: Passenger) -> bool:
current_weight = PRIORITY_WEIGHTS[passenger.priority_tier]
for other in self._state.passengers:
if other.id == passenger.id or other.status != PassengerStatus.PENDING:
continue
other_weight = PRIORITY_WEIGHTS[other.priority_tier]
if other_weight > current_weight:
return True
if (
other_weight == current_weight
and other.connection_deadline_hrs is not None
and passenger.connection_deadline_hrs is not None
and other.connection_deadline_hrs < passenger.connection_deadline_hrs
):
return True
if (
other_weight == current_weight
and other.connection_deadline_hrs is not None
and passenger.connection_deadline_hrs is None
):
return True
return False
def _repeat_failure_penalty(self, action: Action) -> float:
if len(self._state.actions_taken) < 2:
return 0.0
signature = self._signature(action)
recent = self._state.actions_taken[-2:]
repeated_failures = all(
(not item.get("success", True)) and tuple(item.get("signature", ())) == signature
for item in recent
)
return 0.1 if repeated_failures else 0.0
def _completion_ratio(self) -> float:
total = max(len(self._state.passengers), 1)
processed = sum(1 for p in self._state.passengers if p.status != PassengerStatus.PENDING)
return processed / total
def _budget_efficiency(self) -> float:
if self._state.max_budget <= 0:
return 1.0
return self._clamp(1.0 - (self._state.budget_spent / self._state.max_budget))
def _spend(self, cost: float) -> bool:
if (self._state.budget_spent + cost) > self._state.max_budget:
return False
self._state.budget_spent += cost
return True
def _refund(self, cost: float) -> None:
self._state.budget_spent = max(0.0, self._state.budget_spent - cost)
def _find_passenger(self, passenger_id: Optional[str]) -> Optional[Passenger]:
if passenger_id is None:
return None
for passenger in self._state.passengers:
if passenger.id == passenger_id:
return passenger
return None
def _find_flight(self, flight_id: Optional[str]) -> Optional[Flight]:
if flight_id is None:
return None
for flight in self._state.flights:
if flight.id == flight_id:
return flight
return None
def _consume_seat(self, flight: Flight, cabin: CabinClass) -> Tuple[bool, str]:
if cabin == CabinClass.BUSINESS:
if flight.business_seats <= 0:
return False, f"No Business seats on {flight.id}."
flight.business_seats -= 1
return True, ""
if flight.economy_seats <= 0:
return False, f"No Economy seats on {flight.id}."
flight.economy_seats -= 1
return True, ""
def _budget_remaining(self) -> float:
return self._state.max_budget - self._state.budget_spent
def _signature(self, action: Action) -> Tuple[str, Optional[str], Optional[str]]:
return action.action_type.value, action.passenger_id, action.flight_id
def _record_action(self, action: Action, reward: Reward, success: bool, done: bool, info: Dict[str, Any]) -> None:
self._state.actions_taken.append(
{
"step": self._step_count,
"signature": self._signature(action),
"action": action.model_dump(mode="json"),
"reward": reward.model_dump(mode="json"),
"success": success,
"done": done,
"info": info,
}
)
def _clamp(self, value: float) -> float:
return max(0.01, min(0.99, value))
def _get_observation(self) -> Observation:
pending_passengers: List[Dict[str, Any]] = []
for passenger in self._state.passengers:
if passenger.status != PassengerStatus.PENDING:
continue
pending_passengers.append(
{
"id": passenger.id,
"name": passenger.name,
"priority_tier": passenger.priority_tier.value,
"original_flight": passenger.original_flight,
"cabin_class": passenger.cabin_class.value,
"connection_deadline_hrs": passenger.connection_deadline_hrs,
}
)
pending_passengers.sort(
key=lambda p: (
-PRIORITY_WEIGHTS[PriorityTier(p["priority_tier"])],
p["connection_deadline_hrs"] if p["connection_deadline_hrs"] is not None else 1e9,
)
)
available_flights: List[Dict[str, Any]] = []
for flight in self._state.flights:
available_flights.append(
{
"id": flight.id,
"destination": flight.destination,
"departure_hrs": flight.departure_hrs,
"economy_seats": flight.economy_seats,
"business_seats": flight.business_seats,
"is_partner": flight.is_partner,
}
)
processed = sum(1 for p in self._state.passengers if p.status != PassengerStatus.PENDING)
return Observation(
pending_passengers=pending_passengers,
available_flights=available_flights,
budget_remaining=self._budget_remaining(),
budget_spent=self._state.budget_spent,
processed_count=processed,
total_passengers=len(self._state.passengers),
invalid_actions=self._state.invalid_actions,
step_count=self._step_count,
)