Spaces:
Sleeping
Sleeping
File size: 2,970 Bytes
ce7d977 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 | from typing import Optional, Dict, Any, List, Literal
from dataclasses import dataclass, field
import time, random, math
from pydantic import BaseModel, Field
from openenv.core.env_server import Action, Observation
# ---------- Types ----------
# Kind of disruption applied to the original connecting flight
# (used as DisruptionObservation.disruption_type and DisruptionState.disruption_type).
DisruptionType = Literal[
    "delay",
    "cancel",
]

# Operations an agent may request through DisruptionAction.type.
ActionType = Literal[
    "rebook_connection",
    "keep_connection",
    "cancel_cab_rebook",
    "request_refund",
]
# ---------- Models ----------
# Feature summary attached to an observation via DisruptionObservation.features.
# All fields are required (none has a default).
# Fields ending in ``_min`` appear to be durations in minutes (inferred from the
# suffix and the cab_mismatch_min formula) -- NOTE(review): confirm units.
class ObservationFeatures(BaseModel):
    buffer_min: float
    flight_lock_in_min: float
    cab_lock_in_min: float
    cab_pickup_in_min: float
    cab_mismatch_min: float # pickup - arrival - slack
    connection_resolved: bool  # presumably mirrors DisruptionState.connection_resolved
    cab_resolved: bool  # presumably mirrors DisruptionState.cab_resolved
    next_task: str # free-form, not validated; e.g., "FIX_CONNECTION", "FIX_CAB", "DONE"
# Per-alternative feature summary attached via AltFlight.features.
# All fields are required. The ``_min`` suffix presumably means minutes --
# NOTE(review): confirm units against the code that fills these in.
class AltFlightFeatures(BaseModel):
    alt_buffer_min: float
    alt_arrival_delay_min: float
    alt_lock_in_min: float
# Alternative connecting flight listed in DisruptionObservation.alt_flights.
# Carries the same fields as the internal FlightOption dataclass, plus an
# optional ``features`` summary.
# ``*_ts`` fields are integer timestamps (presumably epoch seconds -- TODO confirm).
class AltFlight(BaseModel):
    id: str
    depart_ts: int
    arrive_ts: int
    price_delta: float = 0.0
    reliability: float = 0.8  # presumably a probability in [0, 1]; not range-validated here
    lock_ts: Optional[int] = None
    features: Optional[AltFlightFeatures] = None
# Cab fee schedule exposed as DisruptionObservation.cab_policy.
# Both fees default to 0.0; they match the fee fields of the internal Cab dataclass.
class CabPolicy(BaseModel):
    cancel_fee: float = 0.0
    rebook_fee: float = 0.0
# Agent action for the disruption environment (base ``Action`` comes from openenv).
# ``type`` selects the operation; the optional fields carry its arguments.
# Nothing here ties the optional fields to ``type`` -- presumably the
# environment's step logic validates that (not visible in this section).
class DisruptionAction(Action):
    type: ActionType  # one of the ActionType literals
    alt_flight_id: Optional[str] = None  # presumably an AltFlight.id, used with "rebook_connection"
    new_cab_pickup_ts: Optional[int] = None  # presumably the new pickup time for "cancel_cab_rebook"
# Observation returned to the agent describing the disruption scenario.
# NOTE(review): the base ``Observation`` comes from openenv. ``alt_flights`` uses
# pydantic ``Field``, and required fields (cab_pickup_ts onward) follow defaulted
# ones; both presume Observation is a pydantic model -- confirm against the
# installed openenv version.
# ``*_ts`` fields are integer timestamps (presumably epoch seconds -- TODO confirm).
class DisruptionObservation(Observation):
    now_ts: int
    at_airport: str
    orig_conn_depart_ts: int
    orig_conn_arrive_ts: int
    disruption_type: DisruptionType  # one of the DisruptionType literals
    delay_minutes: int = 0
    cancelled: bool = False
    alt_flights: List[AltFlight] = Field(default_factory=list)  # fresh empty list per instance
    cab_pickup_ts: int
    cab_policy: CabPolicy  # cab cancel/rebook fee schedule
    cab_lock_ts: int
    flight_lock_ts: int
    features: Optional[ObservationFeatures] = None  # optional feature summary
# ---------- Internal State ----------
@dataclass
class FlightOption:
    """Internal record of one alternative connecting flight.

    Holds the same fields as the ``AltFlight`` observation model, minus the
    optional ``features`` summary. ``id``, ``depart_ts`` and ``arrive_ts`` are
    required; the remaining fields fall back to the defaults declared below.
    """

    id: str
    depart_ts: int
    arrive_ts: int
    # Optional attributes, each with an explicit default.
    price_delta: float = field(default=0.0)
    reliability: float = field(default=0.8)
    lock_ts: Optional[int] = field(default=None)
@dataclass
class Cab:
    """Internal record of the booked cab and the fees charged against it.

    ``pickup_ts`` and ``lock_ts`` are required. Fees and the running penalty
    total start at 0.0, and the booking starts active.
    """

    pickup_ts: int
    lock_ts: int
    # Fee schedule; mirrors the CabPolicy observation model.
    cancel_fee: float = field(default=0.0)
    rebook_fee: float = field(default=0.0)
    # Booking status and accumulated penalties.
    active: bool = field(default=True)
    penalties_paid: float = field(default=0.0)
@dataclass
class DisruptionState:
    """Mutable internal state for one run of the disruption environment.

    The first block of fields is required. The second block has defaults and
    must stay after the required ones, because dataclasses forbid
    non-default fields after defaulted fields. ``alt_flights`` uses a
    ``default_factory`` so each instance gets its own list.
    """

    seed: int
    rng: random.Random  # presumably seeded from ``seed`` -- confirm at the construction site
    now_ts: int
    at_airport: str
    disruption_type: DisruptionType
    orig_conn_depart_ts: int
    orig_conn_arrive_ts: int
    flight_lock_ts: int
    planned_gate_ready_ts: int
    baseline_arrive_ts: int
    # Fields with defaults below
    delay_mean_min: float = 0.0
    delay_std_min: float = 0.0
    delay_realized_min: Optional[int] = None  # None until set; presumably filled once the delay is drawn
    cancelled: bool = False
    alt_flights: List[FlightOption] = field(default_factory=list)
    # Defaults to None, so the annotation must admit None (was plain ``Cab``,
    # which type checkers reject for a None default).
    cab: Optional[Cab] = None
    min_connection_buffer_min: int = 45
    # Counters / bookkeeping, all starting at zero.
    changes: int = 0
    reversals: int = 0
    unresolved_critical: int = 0
    cost_incurred: float = 0.0
    # Episode status flags.
    done: bool = False
    success: bool = False
    connection_resolved: bool = False
    cab_resolved: bool = False
    fail_reason: Optional[str] = None  # None unless a failure reason has been recorded