"""Schema and internal state for a travel-disruption environment.

The pydantic models define the action / observation schema (subclassing
openenv's ``Action`` and ``Observation``); the dataclasses hold the mutable
internal state of an episode.
"""

import math
import random
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional

from pydantic import BaseModel, Field

from openenv.core.env_server import Action, Observation

# ---------- Types ----------

# How the original connection is disrupted.
DisruptionType = Literal["delay", "cancel"]

# The actions an agent may submit via ``DisruptionAction.type``.
ActionType = Literal[
    "rebook_connection",
    "keep_connection",
    "cancel_cab_rebook",
    "request_refund",
]


# ---------- Models ----------


class ObservationFeatures(BaseModel):
    """Derived scalar features attached to a ``DisruptionObservation``.

    Fields suffixed ``_min`` appear to be minutes (per their naming); how each
    value is computed lives outside this module -- TODO confirm.
    """

    buffer_min: float
    flight_lock_in_min: float
    cab_lock_in_min: float
    cab_pickup_in_min: float
    cab_mismatch_min: float  # pickup - arrival - slack
    connection_resolved: bool
    cab_resolved: bool
    next_task: str  # e.g., "FIX_CONNECTION", "FIX_CAB", "DONE"


class AltFlightFeatures(BaseModel):
    """Derived features for a single alternative flight (``AltFlight.features``)."""

    alt_buffer_min: float
    alt_arrival_delay_min: float
    alt_lock_in_min: float


class AltFlight(BaseModel):
    """An alternative connection as exposed in an observation.

    Carries the same fields as the internal ``FlightOption`` dataclass plus
    optional derived ``features``.
    """

    id: str
    depart_ts: int
    arrive_ts: int
    price_delta: float = 0.0
    # NOTE(review): presumably a probability in [0, 1]; not validated here.
    reliability: float = 0.8
    lock_ts: Optional[int] = None
    features: Optional[AltFlightFeatures] = None


class CabPolicy(BaseModel):
    """Fees for cancelling or rebooking the cab."""

    cancel_fee: float = 0.0
    rebook_fee: float = 0.0


class DisruptionAction(Action):
    """A single agent action.

    ``type`` is one of ``ActionType``. ``alt_flight_id`` and
    ``new_cab_pickup_ts`` are optional parameters -- presumably consumed by
    ``rebook_connection`` and ``cancel_cab_rebook`` respectively; confirm
    against the environment's step logic.
    """

    type: ActionType
    alt_flight_id: Optional[str] = None
    new_cab_pickup_ts: Optional[int] = None


class DisruptionObservation(Observation):
    """What the agent observes: clock, original connection, disruption, options, and cab."""

    now_ts: int
    at_airport: str
    orig_conn_depart_ts: int
    orig_conn_arrive_ts: int
    disruption_type: DisruptionType
    delay_minutes: int = 0
    cancelled: bool = False
    # default_factory gives every observation its own list instance.
    alt_flights: List[AltFlight] = Field(default_factory=list)
    cab_pickup_ts: int
    cab_policy: CabPolicy
    cab_lock_ts: int
    flight_lock_ts: int
    features: Optional[ObservationFeatures] = None


# ---------- Internal State ----------


@dataclass
class FlightOption:
    """Internal record of an alternative flight (observation counterpart: ``AltFlight``)."""

    id: str
    depart_ts: int
    arrive_ts: int
    price_delta: float = 0.0
    reliability: float = 0.8
    lock_ts: Optional[int] = None


@dataclass
class Cab:
    """Internal record of the booked cab, its fees, and penalties accumulated so far."""

    pickup_ts: int
    lock_ts: int
    cancel_fee: float = 0.0
    rebook_fee: float = 0.0
    active: bool = True
    penalties_paid: float = 0.0


@dataclass
class DisruptionState:
    """Mutable state of one disruption episode.

    Holds the scenario (airport, original connection, disruption parameters,
    alternative flights, cab), running counters (``changes``, ``reversals``,
    ``unresolved_critical``, ``cost_incurred``) and terminal flags
    (``done``, ``success``, ``fail_reason``).
    """

    seed: int
    rng: random.Random
    now_ts: int
    at_airport: str
    disruption_type: DisruptionType
    orig_conn_depart_ts: int
    orig_conn_arrive_ts: int
    flight_lock_ts: int
    planned_gate_ready_ts: int
    baseline_arrive_ts: int
    # Fields with defaults below (dataclasses require non-default fields first).
    delay_mean_min: float = 0.0
    delay_std_min: float = 0.0
    delay_realized_min: Optional[int] = None
    cancelled: bool = False
    alt_flights: List[FlightOption] = field(default_factory=list)
    # Optional: the default is None until a cab is attached.
    cab: Optional[Cab] = None
    min_connection_buffer_min: int = 45
    changes: int = 0
    reversals: int = 0
    unresolved_critical: int = 0
    cost_incurred: float = 0.0
    done: bool = False
    success: bool = False
    connection_resolved: bool = False
    cab_resolved: bool = False
    fail_reason: Optional[str] = None