# Spaces:
# Sleeping
# Sleeping
import math
import random
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional

from openenv.core.env_server import Action, Observation
from pydantic import BaseModel, Field
# ---------- Types ----------
# Disruption kinds; used as DisruptionObservation.disruption_type.
DisruptionType = Literal["delay", "cancel"]

# Allowed values of DisruptionAction.type.
ActionType = Literal[
    "rebook_connection",
    "keep_connection",
    "cancel_cab_rebook",
    "request_refund",
]
# ---------- Models ----------
class ObservationFeatures(BaseModel):
    """Derived scalar features carried in ``DisruptionObservation.features``.

    NOTE(review): the ``*_min`` fields are presumably durations in minutes
    (per the suffix); confirm against the code that populates them.
    """

    buffer_min: float  # presumably slack on the original connection; TODO confirm definition
    flight_lock_in_min: float  # presumably time remaining until flight_lock_ts; verify
    cab_lock_in_min: float  # presumably time remaining until cab_lock_ts; verify
    cab_pickup_in_min: float  # presumably time remaining until cab_pickup_ts; verify
    cab_mismatch_min: float  # pickup - arrival - slack
    connection_resolved: bool  # same name as DisruptionState.connection_resolved
    cab_resolved: bool  # same name as DisruptionState.cab_resolved
    next_task: str  # e.g., "FIX_CONNECTION", "FIX_CAB", "DONE"
class AltFlightFeatures(BaseModel):
    """Derived per-alternative features carried in ``AltFlight.features``.

    NOTE(review): ``*_min`` fields are presumably minutes; confirm.
    """

    alt_buffer_min: float
    alt_arrival_delay_min: float
    alt_lock_in_min: float
class AltFlight(BaseModel):
    """One alternative connecting flight, as listed in ``DisruptionObservation.alt_flights``."""

    id: str
    depart_ts: int  # int timestamp; unit/epoch not established here (presumably same clock as now_ts)
    arrive_ts: int
    price_delta: float = 0.0  # presumably price difference vs. the original connection; confirm sign
    reliability: float = 0.8  # NOTE(review): looks like a probability in [0, 1]; not validated here
    lock_ts: Optional[int] = None  # may be absent (None)
    features: Optional[AltFlightFeatures] = None  # optional derived features
class CabPolicy(BaseModel):
    """Cab fee schedule exposed as ``DisruptionObservation.cab_policy``.

    Both fees default to 0.0.
    """

    cancel_fee: float = 0.0
    rebook_fee: float = 0.0
class DisruptionAction(Action):
    """Agent action for the disruption environment.

    ``type`` must be one of the ``ActionType`` literals. The optional fields
    are per-action parameters; this model does not enforce which ones a
    given ``type`` requires. Presumably ``alt_flight_id`` goes with
    "rebook_connection" and ``new_cab_pickup_ts`` with "cancel_cab_rebook";
    verify against the environment's step logic.
    """

    type: ActionType
    alt_flight_id: Optional[str] = None  # presumably an AltFlight.id
    new_cab_pickup_ts: Optional[int] = None
class DisruptionObservation(Observation):
    """Observation model describing the current disruption and its options.

    ``*_ts`` fields are int timestamps; their unit/epoch is not established
    here (TODO confirm).
    """

    now_ts: int
    at_airport: str
    orig_conn_depart_ts: int  # original connecting flight's departure
    orig_conn_arrive_ts: int  # original connecting flight's arrival
    disruption_type: DisruptionType  # "delay" or "cancel"
    delay_minutes: int = 0
    cancelled: bool = False
    alt_flights: List[AltFlight] = Field(default_factory=list)  # fresh list per instance
    cab_pickup_ts: int
    cab_policy: CabPolicy
    cab_lock_ts: int
    flight_lock_ts: int
    features: Optional[ObservationFeatures] = None
# ---------- Internal State ----------
@dataclass
class FlightOption:
    """Internal record of one alternative connecting flight.

    Carries the same scalar fields as the public ``AltFlight`` model (without
    ``features``). Declared as a dataclass so it gets a generated
    ``__init__``/``__repr__``/``__eq__``. A plain annotated class would
    reject keyword construction.
    """

    id: str
    depart_ts: int  # int timestamp; presumably same clock as DisruptionState.now_ts
    arrive_ts: int
    price_delta: float = 0.0  # presumably price difference vs. the original connection; confirm sign
    reliability: float = 0.8  # NOTE(review): looks like a probability in [0, 1]; not validated
    lock_ts: Optional[int] = None
@dataclass
class Cab:
    """Mutable internal state of the booked cab.

    Declared as a dataclass so instances can be constructed with keyword or
    positional arguments and each carries its own mutable fields.
    """

    pickup_ts: int
    lock_ts: int  # presumably the time after which changes become restricted; confirm
    cancel_fee: float = 0.0
    rebook_fee: float = 0.0
    active: bool = True
    penalties_paid: float = 0.0  # presumably accumulated by the env as fees are charged; verify
@dataclass
class DisruptionState:
    """Mutable internal state of one disruption episode.

    Declared as a dataclass: required fields come first and defaulted
    fields after, as dataclasses require. ``field(default_factory=list)``
    then gives every instance its own ``alt_flights`` list.
    """

    seed: int
    rng: random.Random
    now_ts: int
    at_airport: str
    disruption_type: DisruptionType
    orig_conn_depart_ts: int
    orig_conn_arrive_ts: int
    flight_lock_ts: int
    planned_gate_ready_ts: int
    baseline_arrive_ts: int
    # Fields with defaults below
    delay_mean_min: float = 0.0
    delay_std_min: float = 0.0
    delay_realized_min: Optional[int] = None  # None until a delay value is set
    cancelled: bool = False
    alt_flights: List[FlightOption] = field(default_factory=list)
    cab: Optional[Cab] = None  # may be unset (None); annotated Optional to match the default
    min_connection_buffer_min: int = 45
    changes: int = 0
    reversals: int = 0
    unresolved_critical: int = 0
    cost_incurred: float = 0.0
    done: bool = False
    success: bool = False
    connection_resolved: bool = False
    cab_resolved: bool = False
    fail_reason: Optional[str] = None