# disruption-env / models.py
# numb3r33's picture
# Upload folder using huggingface_hub
# ce7d977 verified
from typing import Optional, Dict, Any, List, Literal
from dataclasses import dataclass, field
import time, random, math
from pydantic import BaseModel, Field
from openenv.core.env_server import Action, Observation
# ---------- Types ----------
# Allowed values for the `disruption_type` field used by the observation
# and state classes in this module.
DisruptionType = Literal[
    "delay",
    "cancel",
]

# Allowed values for `DisruptionAction.type`.
ActionType = Literal[
    "rebook_connection",
    "keep_connection",
    "cancel_cab_rebook",
    "request_refund",
]
# ---------- Models ----------
class ObservationFeatures(BaseModel):
    """Feature bundle carried in ``DisruptionObservation.features``.

    Every field is required. Names ending in ``_min`` presumably hold
    durations in minutes -- confirm against the code that fills them in.
    """

    buffer_min: float
    flight_lock_in_min: float
    cab_lock_in_min: float
    cab_pickup_in_min: float
    cab_mismatch_min: float  # pickup - arrival - slack
    connection_resolved: bool
    cab_resolved: bool
    next_task: str  # e.g., "FIX_CONNECTION", "FIX_CAB", "DONE"
class AltFlightFeatures(BaseModel):
    """Per-flight feature bundle carried in ``AltFlight.features``.

    Every field is required. The ``_min`` suffix presumably means
    minutes -- confirm against the code that computes these values.
    """

    alt_buffer_min: float
    alt_arrival_delay_min: float
    alt_lock_in_min: float
class AltFlight(BaseModel):
    """One entry of ``DisruptionObservation.alt_flights``.

    Carries the same fields as the internal ``FlightOption`` dataclass,
    plus an optional ``features`` bundle.
    """

    id: str
    depart_ts: int
    arrive_ts: int
    price_delta: float = 0.0
    reliability: float = 0.8
    lock_ts: Optional[int] = None  # None when no lock timestamp is set
    features: Optional[AltFlightFeatures] = None
class CabPolicy(BaseModel):
    """Cab fee schedule exposed through ``DisruptionObservation.cab_policy``.

    Both fees default to ``0.0``.
    """

    cancel_fee: float = 0.0
    rebook_fee: float = 0.0
class DisruptionAction(Action):
    """Action message for the disruption environment.

    ``type`` is required and must be one of the ``ActionType`` literals.
    The two optional fields default to ``None``.
    """

    type: ActionType
    # Presumably names the chosen alternative when rebooking the
    # connection -- verify against the environment's step logic.
    alt_flight_id: Optional[str] = None
    # Presumably the replacement pickup time for "cancel_cab_rebook" --
    # verify against the environment's step logic.
    new_cab_pickup_ts: Optional[int] = None
class DisruptionObservation(Observation):
    """Observation message for the disruption environment.

    ``delay_minutes`` defaults to ``0``, ``cancelled`` to ``False``,
    ``alt_flights`` to a fresh empty list per instance, and ``features``
    to ``None``. All remaining fields declared here are required.
    """

    now_ts: int
    at_airport: str

    # Original connection and what happened to it.
    orig_conn_depart_ts: int
    orig_conn_arrive_ts: int
    disruption_type: DisruptionType
    delay_minutes: int = 0
    cancelled: bool = False

    # Alternatives offered; default_factory avoids sharing one list
    # object between instances.
    alt_flights: List[AltFlight] = Field(default_factory=list)

    # Cab booking and lock timestamps.
    cab_pickup_ts: int
    cab_policy: CabPolicy
    cab_lock_ts: int
    flight_lock_ts: int

    features: Optional[ObservationFeatures] = None
# ---------- Internal State ----------
@dataclass
class FlightOption:
    """Internal record for one flight option.

    Has the same fields as the ``AltFlight`` model minus ``features``.
    ``id``, ``depart_ts`` and ``arrive_ts`` are required; the rest have
    defaults.
    """

    id: str
    depart_ts: int
    arrive_ts: int
    price_delta: float = 0.0
    reliability: float = 0.8
    lock_ts: Optional[int] = None  # None when no lock timestamp is set
@dataclass
class Cab:
    """Internal record for the cab booking.

    ``pickup_ts`` and ``lock_ts`` are required. Fees and
    ``penalties_paid`` default to ``0.0`` and ``active`` defaults to
    ``True``.
    """

    pickup_ts: int
    lock_ts: int
    cancel_fee: float = 0.0
    rebook_fee: float = 0.0
    active: bool = True
    penalties_paid: float = 0.0
@dataclass
class DisruptionState:
    """Internal state record for one disruption scenario.

    The first ten fields are required and must come first in the
    generated ``__init__``; everything after them has a default.
    ``cab`` defaults to ``None``, so code reading it must handle the
    not-yet-assigned case.
    """

    # Required fields (no defaults).
    seed: int
    rng: random.Random
    now_ts: int
    at_airport: str
    disruption_type: DisruptionType
    orig_conn_depart_ts: int
    orig_conn_arrive_ts: int
    flight_lock_ts: int
    planned_gate_ready_ts: int
    baseline_arrive_ts: int

    # Fields with defaults below
    delay_mean_min: float = 0.0
    delay_std_min: float = 0.0
    delay_realized_min: Optional[int] = None
    cancelled: bool = False
    # default_factory gives every instance its own list.
    alt_flights: List[FlightOption] = field(default_factory=list)
    # Annotated Optional because the default is None.
    cab: Optional[Cab] = None
    min_connection_buffer_min: int = 45

    # Bookkeeping counters and totals, all starting at zero.
    changes: int = 0
    reversals: int = 0
    unresolved_critical: int = 0
    cost_incurred: float = 0.0

    # Outcome and progress flags.
    done: bool = False
    success: bool = False
    connection_resolved: bool = False
    cab_resolved: bool = False
    fail_reason: Optional[str] = None