File size: 2,970 Bytes
ce7d977
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from typing import Optional, Dict, Any, List, Literal
from dataclasses import dataclass, field
import time, random, math
from pydantic import BaseModel, Field
from openenv.core.env_server import Action, Observation

# ---------- Types ----------
# Kind of disruption hitting the original connection (see DisruptionState.disruption_type).
DisruptionType = Literal["delay", "cancel"]
# Every value DisruptionAction.type may take.
ActionType = Literal[
    "rebook_connection",
    "keep_connection",
    "cancel_cab_rebook",
    "request_refund",
]

# ---------- Models ----------
class ObservationFeatures(BaseModel):
    """Derived summary features optionally attached to a DisruptionObservation.

    NOTE(review): the code that computes these values is not in this section;
    the per-field notes below are inferred from names and should be confirmed.
    The ``_min`` suffix suggests every float here is measured in minutes.
    """
    buffer_min: float  # presumably the connection buffer in minutes — confirm
    flight_lock_in_min: float  # presumably time remaining until the flight lock — confirm
    cab_lock_in_min: float  # presumably time remaining until the cab lock — confirm
    cab_pickup_in_min: float  # presumably time remaining until cab pickup — confirm
    cab_mismatch_min: float  # pickup - arrival - slack
    connection_resolved: bool  # same name as the DisruptionState flag; presumably copied from it
    cab_resolved: bool  # same name as the DisruptionState flag; presumably copied from it
    next_task: str  # e.g., "FIX_CONNECTION", "FIX_CAB", "DONE"

class AltFlightFeatures(BaseModel):
    """Derived features for a single alternative flight (held in AltFlight.features).

    NOTE(review): the ``_min`` suffix suggests minutes, and the reference points
    for each value are not defined here. Confirm them against the producer.
    """
    alt_buffer_min: float  # presumably the connection buffer if this flight is taken — confirm
    alt_arrival_delay_min: float  # presumably arrival lateness vs. the original plan — confirm
    alt_lock_in_min: float  # presumably time remaining until this flight's lock_ts — confirm

class AltFlight(BaseModel):
    """One alternative flight, as listed in DisruptionObservation.alt_flights.

    Has the same fields as the internal ``FlightOption`` dataclass, plus an
    optional block of derived ``features``. Timestamps are plain ints; their
    epoch and unit are not fixed by this model.
    """

    # Identity and schedule (required)
    id: str
    depart_ts: int
    arrive_ts: int
    # Optional attributes; reliability presumably a probability in [0, 1] — not validated
    price_delta: float = Field(default=0.0)
    reliability: float = Field(default=0.8)
    lock_ts: Optional[int] = Field(default=None)
    features: Optional[AltFlightFeatures] = Field(default=None)

class CabPolicy(BaseModel):
    """Cab cancellation and rebooking fees as exposed in DisruptionObservation.

    Has the same fee fields as the internal ``Cab`` dataclass. Both default to zero.
    """

    cancel_fee: float = Field(default=0.0)
    rebook_fee: float = Field(default=0.0)

class DisruptionAction(Action):
    """Action submitted to the disruption environment.

    ``type`` selects one of the ActionType literals. The optional arguments are
    presumably only meaningful for some types (``alt_flight_id`` for
    "rebook_connection", ``new_cab_pickup_ts`` for "cancel_cab_rebook").
    Nothing in this model enforces that pairing.
    NOTE(review): confirm where it is validated.
    """
    type: ActionType
    alt_flight_id: Optional[str] = None  # presumably an AltFlight.id from the latest observation
    new_cab_pickup_ts: Optional[int] = None  # requested pickup time; assumed same convention as cab_pickup_ts

class DisruptionObservation(Observation):
    """Observation emitted by the disruption environment.

    Describes the original connection, the disruption (delay or cancellation),
    the alternative flights, and the booked cab with its fee policy.
    ``features`` optionally carries a precomputed ObservationFeatures summary.
    Field order is unchanged; the comments only group related fields.
    """

    # Current clock and location
    now_ts: int
    at_airport: str
    # Original connecting flight
    orig_conn_depart_ts: int
    orig_conn_arrive_ts: int
    # Disruption details and rebooking options
    disruption_type: DisruptionType
    delay_minutes: int = Field(default=0)
    cancelled: bool = Field(default=False)
    alt_flights: List[AltFlight] = Field(default_factory=list)
    # Ground transport
    cab_pickup_ts: int
    cab_policy: CabPolicy
    cab_lock_ts: int  # presumably the deadline for changing the cab — confirm
    # Flight change deadline (presumed meaning of "lock") — confirm
    flight_lock_ts: int
    # Optional derived summary
    features: Optional[ObservationFeatures] = Field(default=None)

# ---------- Internal State ----------
@dataclass
class FlightOption:
    """Internal record of one alternative flight.

    Has the same fields as the public ``AltFlight`` model, minus ``features``.
    Positional construction follows declaration order: id, depart_ts,
    arrive_ts, price_delta, reliability, lock_ts.
    """

    id: str
    depart_ts: int
    arrive_ts: int
    price_delta: float = field(default=0.0)
    reliability: float = field(default=0.8)
    lock_ts: Optional[int] = field(default=None)

@dataclass
class Cab:
    """Internal record of the booked cab.

    The fee fields match the public ``CabPolicy`` model. ``active`` and
    ``penalties_paid`` are booking state that presumably changes during the
    episode — confirm against the environment logic.
    """

    # Booking times (required)
    pickup_ts: int
    lock_ts: int
    # Fee schedule
    cancel_fee: float = field(default=0.0)
    rebook_fee: float = field(default=0.0)
    # Booking state
    active: bool = field(default=True)
    penalties_paid: float = field(default=0.0)  # running total, presumably of fees charged

@dataclass
class DisruptionState:
    """Mutable internal state of one disruption scenario.

    Holds the scenario setup (original connection, disruption, alternative
    flights, cab), the RNG, and bookkeeping counters and flags. All required
    fields are declared before any defaulted field, because dataclasses reject
    a non-default field that follows a defaulted one.
    """

    # --- Scenario setup (required) ---
    seed: int
    rng: random.Random  # presumably seeded from `seed` by the caller — confirm
    now_ts: int
    at_airport: str
    disruption_type: DisruptionType
    orig_conn_depart_ts: int
    orig_conn_arrive_ts: int
    flight_lock_ts: int
    planned_gate_ready_ts: int
    baseline_arrive_ts: int
    # Fields with defaults below
    # Delay model: mean/std presumably parameterize a sampled delay, and
    # delay_realized_min stays None until one is drawn — confirm.
    delay_mean_min: float = 0.0
    delay_std_min: float = 0.0
    delay_realized_min: Optional[int] = None
    cancelled: bool = False
    alt_flights: List[FlightOption] = field(default_factory=list)
    # Defaults to None, so the annotation must admit None. The former `Cab = None`
    # relied on implicit Optional, which PEP 484 no longer allows.
    cab: Optional[Cab] = None
    min_connection_buffer_min: int = 45
    # --- Episode bookkeeping ---
    changes: int = 0
    reversals: int = 0
    unresolved_critical: int = 0
    cost_incurred: float = 0.0
    done: bool = False
    success: bool = False
    connection_resolved: bool = False
    cab_resolved: bool = False
    fail_reason: Optional[str] = None