numb3r33 committed · Commit ce7d977 · verified · 1 parent: 351f3fc

Upload folder using huggingface_hub

Dockerfile ADDED
@@ -0,0 +1,19 @@
+ FROM python:3.12-slim
+
+ WORKDIR /app
+
+ # Install git and uv
+ RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
+ RUN pip install uv
+
+ # Copy your environment code
+ COPY . /app
+
+ # Install dependencies
+ RUN uv pip install --system fastapi uvicorn pydantic git+https://github.com/meta-pytorch/OpenEnv.git
+
+ # Expose port
+ EXPOSE 7860
+
+ # Run the server
+ CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,12 +1,114 @@
- title: Disruption Env
- emoji: 📉
- colorFrom: yellow
- colorTo: gray
- sdk: gradio
- sdk_version: 6.3.0
- app_file: app.py
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: Disruption Recovery RL Environment
+ emoji: ✈️🚨
+ colorFrom: blue
+ colorTo: red
+ sdk: docker
+ sdk_version: "1.0"
  pinned: false
+ tags:
+ - reinforcement-learning
+ - llm-agents
+ - decision-making
+ - simulation
+ - openenv
+ - travel
+ - operations-research
+ - planning
+ - ai-agents
  ---

+ # Disruption Recovery RL Environment
+
+ An OpenEnv-compatible reinforcement learning environment for training AI agents to handle travel disruptions. When flights are delayed or cancelled, the agent must rebook connections and adjust ground transportation to minimize cost, delay, and traveler stress.
+
+ ## Environment Overview
+
+ This environment simulates real-world travel disruption scenarios where an agent must:
+ - Detect when a flight delay or cancellation breaks downstream bookings
+ - Evaluate alternative flight options with time and cost tradeoffs
+ - Rebook connections before lock windows expire
+ - Adjust cab/transfer timing to match new arrival times
+ - Balance competing objectives: minimize delay, cost, and number of changes
+
+ ## State Space
+
+ The agent observes:
+ - **Disruption type**: delay or cancellation
+ - **Connection status**: buffer time, feasibility, lock-in windows
+ - **Alternative flights**: departure/arrival times, costs, availability, buffers
+ - **Cab status**: pickup time, mismatch with arrival, cancellation policy
+ - **Derived features**: `connection_resolved`, `cab_resolved`, `next_task`
+
+ ## Action Space
+
+ Four actions are available:
+ 1. **`rebook_connection`**: switch to an alternative flight
+ 2. **`keep_connection`**: maintain the current booking
+ 3. **`cancel_cab_rebook`**: adjust ground transfer timing
+ 4. **`request_refund`**: cancel and request a refund
+
+ ## Reward Structure
+
+ **Terminal rewards:**
+ - Success: +100 (reach destination with all bookings resolved)
+ - Failure: -200 (missed connection, locked out, or max steps)
+
+ **Shaped rewards** (applied on success):
+ - Time penalty: -0.5 per minute of additional delay
+ - Cost penalty: -1.0 per currency unit spent
+ - Change penalty: -15 per major booking change
+ - Stress penalty: -25 for unresolved critical items near lock time
+
+ The environment uses potential-based reward shaping on search-space reduction to provide a dense learning signal.
+
+ ## Usage
+
+ ### Start the server
+
+ The environment runs as a FastAPI server on port 7860:
+
+ `uvicorn server.app:app --host 0.0.0.0 --port 7860`
+
+ ### Interact via HTTP
+
+ **Reset:**
+ `curl -X POST http://localhost:7860/reset`
+
+ **Step:**
+ `curl -X POST http://localhost:7860/step -H "Content-Type: application/json" -d '{"action": {"type": "rebook_connection", "alt_flight_id": "ALT_0"}}'`
+
+ ### Python Client Example
+
+ ```py
+ import httpx
+ from disruption_env.models import DisruptionAction
+
+ client = httpx.Client(base_url="http://localhost:7860")
+ obs = client.post("/reset").json()
+
+ action = DisruptionAction(type="rebook_connection", alt_flight_id="ALT_0")
+ result = client.post("/step", json={"action": action.model_dump()}).json()
+
+ print(f"Reward: {result['reward']}, Done: {result['done']}")
+ ```
+
+ ## Design Features
+
+ - **Phase-aware observations**: Clearly signals whether the connection or the cab needs attention
+ - **Feasibility labeling**: Marks invalid options (locked, insufficient buffer) to prevent impossible actions
+ - **Temporal constraints**: Lock windows create urgency and force timely decisions
+ - **Multi-objective optimization**: Trade off time, cost, and operational complexity
+
+ ## Competition Context
+
+ Built for the **OpenEnv Student Challenge** to demonstrate:
+ - Real-world applicability (travel operations)
+ - Novel domain (traveler-centric disruption recovery vs airline ops)
+ - Rich decision space (temporal constraints, multi-step planning, competing objectives)
+ - Production relevance (applicable to travel booking platforms)
+
+ ## License
+
+ BSD-3-Clause (following OpenEnv conventions)
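
The terminal reward described in the README's Reward Structure section can be checked by hand. The sketch below is a minimal illustration, not the committed implementation: `terminal_reward` and its parameter names are hypothetical, the coefficients are the ones the README lists, and clamping negative delay to zero is an assumption taken from `_reward` in `server/disruption_environment.py`.

```python
def terminal_reward(success: bool,
                    extra_delay_min: float = 0.0,
                    cost: float = 0.0,
                    changes: int = 0,
                    unresolved_critical: int = 0) -> float:
    """Terminal reward using the coefficients listed in the README (hypothetical helper)."""
    if not success:
        return -200.0
    delay = max(0.0, extra_delay_min)  # assumption: arriving early earns no bonus
    return (100.0
            - 0.5 * delay
            - 1.0 * cost
            - 15.0 * changes
            - 25.0 * unresolved_critical)


# One rebooking that costs 25 and arrives 30 minutes late:
# 100 - 0.5 * 30 - 25 - 15 * 1 = 45
print(terminal_reward(True, extra_delay_min=30, cost=25.0, changes=1))  # 45.0
```

Under these coefficients a single paid rebooking can consume more than half of the success bonus, which is the trade-off the agent has to weigh against the -200 failure reward.
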
__init__.py ADDED
File without changes
__pycache__/__init__.cpython-312.pyc ADDED
Binary file (182 Bytes)
__pycache__/models.cpython-312.pyc ADDED
Binary file (5.25 kB)
models.py ADDED
@@ -0,0 +1,107 @@
+ from typing import Optional, Dict, Any, List, Literal
+ from dataclasses import dataclass, field
+ import time, random, math
+ from pydantic import BaseModel, Field
+ from openenv.core.env_server import Action, Observation
+
+ # ---------- Types ----------
+ DisruptionType = Literal["delay", "cancel"]
+ ActionType = Literal["rebook_connection", "keep_connection", "cancel_cab_rebook", "request_refund"]
+
+ # ---------- Models ----------
+ class ObservationFeatures(BaseModel):
+     buffer_min: float
+     flight_lock_in_min: float
+     cab_lock_in_min: float
+     cab_pickup_in_min: float
+     cab_mismatch_min: float  # pickup - arrival - slack
+     connection_resolved: bool
+     cab_resolved: bool
+     next_task: str  # e.g., "FIX_CONNECTION", "FIX_CAB", "DONE"
+
+ class AltFlightFeatures(BaseModel):
+     alt_buffer_min: float
+     alt_arrival_delay_min: float
+     alt_lock_in_min: float
+
+ class AltFlight(BaseModel):
+     id: str
+     depart_ts: int
+     arrive_ts: int
+     price_delta: float = 0.0
+     reliability: float = 0.8
+     lock_ts: Optional[int] = None
+     features: Optional[AltFlightFeatures] = None
+
+ class CabPolicy(BaseModel):
+     cancel_fee: float = 0.0
+     rebook_fee: float = 0.0
+
+ class DisruptionAction(Action):
+     type: ActionType
+     alt_flight_id: Optional[str] = None
+     new_cab_pickup_ts: Optional[int] = None
+
+ class DisruptionObservation(Observation):
+     now_ts: int
+     at_airport: str
+     orig_conn_depart_ts: int
+     orig_conn_arrive_ts: int
+     disruption_type: DisruptionType
+     delay_minutes: int = 0
+     cancelled: bool = False
+     alt_flights: List[AltFlight] = Field(default_factory=list)
+     cab_pickup_ts: int
+     cab_policy: CabPolicy
+     cab_lock_ts: int
+     flight_lock_ts: int
+     features: Optional[ObservationFeatures] = None
+
+ # ---------- Internal State ----------
+ @dataclass
+ class FlightOption:
+     id: str
+     depart_ts: int
+     arrive_ts: int
+     price_delta: float = 0.0
+     reliability: float = 0.8
+     lock_ts: Optional[int] = None
+
+ @dataclass
+ class Cab:
+     pickup_ts: int
+     lock_ts: int
+     cancel_fee: float = 0.0
+     rebook_fee: float = 0.0
+     active: bool = True
+     penalties_paid: float = 0.0
+
+ @dataclass
+ class DisruptionState:
+     seed: int
+     rng: random.Random
+     now_ts: int
+     at_airport: str
+     disruption_type: DisruptionType
+     orig_conn_depart_ts: int
+     orig_conn_arrive_ts: int
+     flight_lock_ts: int
+     planned_gate_ready_ts: int
+     baseline_arrive_ts: int
+     # Fields with defaults below
+     delay_mean_min: float = 0.0
+     delay_std_min: float = 0.0
+     delay_realized_min: Optional[int] = None
+     cancelled: bool = False
+     alt_flights: List[FlightOption] = field(default_factory=list)
+     cab: Optional[Cab] = None  # None only until reset() assigns a Cab
+     min_connection_buffer_min: int = 45
+     changes: int = 0
+     reversals: int = 0
+     unresolved_critical: int = 0
+     cost_incurred: float = 0.0
+     done: bool = False
+     success: bool = False
+     connection_resolved: bool = False
+     cab_resolved: bool = False
+     fail_reason: Optional[str] = None
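
The `cab_mismatch_min` field above is annotated as pickup - arrival - slack. As a rough sketch of that arithmetic, the snippet below assumes the 10-minute slack and the 0 to 30 minute tolerance window that appear in `server/disruption_environment.py`; the constants and helper names are illustrative and are not part of the commit.

```python
SLACK_MIN = 10      # assumed minutes between landing and pickup
TOLERANCE_MIN = 30  # assumed upper bound of the accepted mismatch window


def cab_mismatch_min(pickup_ts: int, arrival_ts: int, slack_min: int = SLACK_MIN) -> float:
    """Minutes by which the pickup trails (arrival + slack); negative means too early."""
    return (pickup_ts - arrival_ts - slack_min * 60) / 60.0


def cab_is_resolved(pickup_ts: int, arrival_ts: int, active: bool = True) -> bool:
    """A cancelled cab counts as resolved; otherwise the mismatch must lie in [0, 30]."""
    if not active:
        return True
    return 0 <= cab_mismatch_min(pickup_ts, arrival_ts) <= TOLERANCE_MIN


planned_arrival = 7 * 3600            # seconds, relative to an arbitrary origin
pickup = planned_arrival + 20 * 60    # cab booked 20 minutes after planned arrival

print(cab_mismatch_min(pickup, planned_arrival))   # 10.0
delayed_arrival = planned_arrival + 45 * 60        # flight lands 45 minutes late
print(cab_mismatch_min(pickup, delayed_arrival))   # -35.0
print(cab_is_resolved(pickup, delayed_arrival))    # False
```

A negative mismatch means the cab would arrive before the traveler is ready, which is the case the agent is expected to repair with `cancel_cab_rebook`.
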
server/__init__.py ADDED
File without changes
server/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (189 Bytes)
server/__pycache__/app.cpython-312.pyc ADDED
Binary file (1.43 kB)
server/__pycache__/disruption_environment.cpython-312.pyc ADDED
Binary file (19.3 kB)
server/app.py ADDED
@@ -0,0 +1,26 @@
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+
+ from disruption_env.server.disruption_environment import DisruptionEnv
+ from disruption_env.models import DisruptionAction
+
+ app = FastAPI()
+ env = DisruptionEnv(seed=0)
+
+ class StepRequest(BaseModel):
+     action: DisruptionAction
+
+ @app.post("/reset")
+ def reset():
+     obs = env.reset()
+     return obs.model_dump()
+
+ @app.post("/step")
+ def step(request: StepRequest):
+     obs, reward, done, info = env.step(request.action)
+     return {
+         "observation": obs.model_dump(),
+         "reward": reward,
+         "done": done,
+         "info": info
+     }
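
The server above keeps one module-level `env` that every request shares, and FastAPI runs plain `def` endpoints in a thread pool, so two clients could call `step` at the same time. One possible mitigation, sketched here under assumptions, is to serialize access with a lock. `LockedEnv` and `CounterEnv` are hypothetical names; `CounterEnv` is only a stand-in so the example runs without the real environment.

```python
import threading


class LockedEnv:
    """Serializes reset()/step() calls on a wrapped environment (illustrative wrapper)."""

    def __init__(self, env):
        self._env = env
        self._lock = threading.Lock()

    def reset(self, **kwargs):
        with self._lock:
            return self._env.reset(**kwargs)

    def step(self, action):
        with self._lock:
            return self._env.step(action)


class CounterEnv:
    """Hypothetical stand-in environment that only counts steps."""

    def __init__(self):
        self.steps = 0

    def reset(self):
        self.steps = 0
        return self.steps

    def step(self, action):
        current = self.steps      # read-modify-write that the lock protects
        self.steps = current + 1
        return self.steps


inner = CounterEnv()
env = LockedEnv(inner)
env.reset()

threads = [threading.Thread(target=env.step, args=("noop",)) for _ in range(50)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(inner.steps)  # 50
```

A lock keeps a single episode consistent but still interleaves steps from different clients; giving each client its own environment instance would be a separate design choice that this sketch does not cover.
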
server/disruption_environment.py ADDED
@@ -0,0 +1,416 @@
+ from disruption_env.models import *
+
+ class DisruptionEnv:
+     def __init__(self, seed: int = 0, max_steps: int = 6):
+         self.seed = seed
+         self.max_steps = max_steps
+         self._episode_idx = 0
+         self.state: Optional[DisruptionState] = None
+         self.steps = 0
+
+     def _update_resolution_status(self, s: DisruptionState):
+         """Update connection_resolved and cab_resolved flags."""
+         # Connection resolved if buffer sufficient and not cancelled
+         buffer = self._connection_buffer_min(s)
+         s.connection_resolved = (buffer >= s.min_connection_buffer_min and not s.cancelled)
+
+         # Cab resolved if mismatch within tolerance (0-30 min after arrival)
+         arrival_ts = self._final_arrival_ts(s)
+         cab_mismatch_min = (s.cab.pickup_ts - arrival_ts - 10*60) / 60.0
+         s.cab_resolved = (0 <= cab_mismatch_min <= 30) or not s.cab.active
+
+     def _compute_features(self, s: DisruptionState) -> ObservationFeatures:
+         now = s.now_ts
+         buffer_min = self._connection_buffer_min(s)
+         flight_lock_in_min = (s.flight_lock_ts - now) / 60.0
+         cab_lock_in_min = (s.cab.lock_ts - now) / 60.0
+
+         arrival_ts = self._final_arrival_ts(s)
+         cab_pickup_in_min = (s.cab.pickup_ts - now) / 60.0
+         cab_mismatch_min = (s.cab.pickup_ts - arrival_ts - 10*60) / 60.0  # 10 min slack
+
+         if not s.connection_resolved:
+             next_task = "FIX_CONNECTION"
+         elif not s.cab_resolved:
+             next_task = "FIX_CAB"
+         else:
+             next_task = "DONE"
+
+         return ObservationFeatures(
+             buffer_min=buffer_min,
+             flight_lock_in_min=flight_lock_in_min,
+             cab_lock_in_min=cab_lock_in_min,
+             cab_pickup_in_min=cab_pickup_in_min,
+             cab_mismatch_min=cab_mismatch_min,
+             connection_resolved=s.connection_resolved,
+             cab_resolved=s.cab_resolved,
+             next_task=next_task
+         )
+
+     def _compute_alt_features(self, s: DisruptionState, alt: FlightOption) -> AltFlightFeatures:
+         now = s.now_ts
+         # Buffer if we switch to this alt
+         gate_ready = self._compute_gate_ready_ts(s)
+         alt_buffer_min = (alt.depart_ts - gate_ready) / 60.0
+         alt_arrival_delay_min = (alt.arrive_ts - s.baseline_arrive_ts) / 60.0
+         alt_lock_in_min = ((alt.lock_ts - now) / 60.0) if alt.lock_ts is not None else 999.0
+
+         return AltFlightFeatures(
+             alt_buffer_min=alt_buffer_min,
+             alt_arrival_delay_min=alt_arrival_delay_min,
+             alt_lock_in_min=alt_lock_in_min
+         )
+
+
+     # ---------- Timing Helpers ----------
+     def _compute_gate_ready_ts(self, s: DisruptionState) -> int:
+         t = s.planned_gate_ready_ts
+         if s.disruption_type == "cancel" and s.cancelled:
+             return t  # cancellation handled elsewhere
+         if s.disruption_type == "delay":
+             t += int(round(s.delay_mean_min)) * 60
+         return t
+
+     def _connection_buffer_min(self, s: DisruptionState) -> float:
+         gate_ready = self._compute_gate_ready_ts(s)
+         return (s.orig_conn_depart_ts - gate_ready) / 60.0
+
+     def _final_arrival_ts(self, s: DisruptionState) -> int:
+         t = s.orig_conn_arrive_ts
+         if s.disruption_type == "delay":
+             t += int(round(s.delay_mean_min)) * 60
+         return t
+
+     def _delta_delay_min(self, s: DisruptionState) -> float:
+         return (self._final_arrival_ts(s) - s.baseline_arrive_ts) / 60.0
+
+     # ---------- Terminal Resolution ----------
+     def _resolve_terminal(self, s: DisruptionState):
+         print(f"DEBUG: Entering _resolve_terminal, unresolved_critical={s.unresolved_critical}")
+
+         # Cancellation unresolved past lock
+         if s.now_ts >= s.flight_lock_ts and s.disruption_type == "cancel" and s.cancelled:
+             s.done = True
+             s.success = False
+             s.fail_reason = "flight_cancelled_unresolved"
+             return
+
+         # Cancellation still active
+         if s.disruption_type == "cancel" and s.cancelled:
+             s.done = True
+             s.success = False
+             s.fail_reason = "cancelled_connection"
+             return
+
+         # Buffer check for delay
+         buf = self._connection_buffer_min(s)
+         if buf < s.min_connection_buffer_min:
+             s.done = True
+             s.success = False
+             s.fail_reason = "missed_connection"
+             return
+
+         # Cab feasibility
+         arrival_ts = self._final_arrival_ts(s)
+         if s.cab.active and s.cab.pickup_ts < arrival_ts + 10 * 60:
+             if s.now_ts >= s.cab.lock_ts:
+                 s.done = True
+                 s.success = False
+                 s.fail_reason = "cab_missed_locked"
+                 return
+             else:
+                 s.unresolved_critical = 1
+
+         print(f"DEBUG: Before success check - unresolved={s.unresolved_critical}, cancelled={s.cancelled}")
+
+         # Success if no unresolved issues
+         if s.connection_resolved and s.cab_resolved and s.unresolved_critical == 0:
+             print("DEBUG: Setting success=True")
+             s.done = True
+             s.success = True
+             s.fail_reason = None
+
+     # ---------- Reward ----------
+     def _reward(self, s: DisruptionState) -> float:
+         if s.done and s.success:
+             dT = max(0.0, self._delta_delay_min(s))
+             return 100.0 - 0.5 * dT - 1.0 * s.cost_incurred - 15.0 * s.changes - 25.0 * s.unresolved_critical
+         if s.done and not s.success:
+             return -200.0
+         return -0.2  # step penalty
+
+     # ---------- Observation Projection ----------
+     def _obs_from_state(self, s: DisruptionState) -> DisruptionObservation:
+         delay_minutes = int(round(s.delay_mean_min)) if s.disruption_type == "delay" else 0
+
+         # Compute features for alt flights
+         alt = []
+         for f in s.alt_flights:
+             alt_features = self._compute_alt_features(s, f)
+             alt.append(AltFlight(
+                 id=f.id, depart_ts=f.depart_ts, arrive_ts=f.arrive_ts,
+                 price_delta=f.price_delta, reliability=f.reliability, lock_ts=f.lock_ts,
+                 features=alt_features
+             ))
+
+         # Compute main features
+         features = self._compute_features(s)
+
+         return DisruptionObservation(
+             now_ts=s.now_ts,
+             at_airport=s.at_airport,
+             orig_conn_depart_ts=s.orig_conn_depart_ts,
+             orig_conn_arrive_ts=s.orig_conn_arrive_ts,
+             disruption_type=s.disruption_type,
+             delay_minutes=delay_minutes,
+             cancelled=s.cancelled,
+             alt_flights=alt,
+             cab_pickup_ts=s.cab.pickup_ts,
+             cab_policy=CabPolicy(cancel_fee=s.cab.cancel_fee, rebook_fee=s.cab.rebook_fee),
+             cab_lock_ts=s.cab.lock_ts,
+             flight_lock_ts=s.flight_lock_ts,
+             features=features
+         )
+
+     # ---------- Reset ----------
+     def reset(self, *, seed: Optional[int] = None) -> DisruptionObservation:
+         if seed is None:
+             seed = self.seed + self._episode_idx
+         rng = random.Random(seed)
+         self._episode_idx += 1
+         self.steps = 0
+
+         now_ts = int(time.time())
+         at_airport = "ORIGIN"
+
+         orig_conn_depart_ts = now_ts + 3 * 3600
+         orig_conn_arrive_ts = now_ts + 7 * 3600
+         baseline_arrive_ts = orig_conn_arrive_ts
+
+         # Gate ready with buffer
+         base_buffer_min = rng.choice([35, 50, 70, 90])
+         planned_gate_ready_ts = orig_conn_depart_ts - base_buffer_min * 60
+
+         disruption_type: DisruptionType = rng.choice(["delay", "cancel"])
+         cancelled = (disruption_type == "cancel")
+
+         if disruption_type == "delay":
+             delay_mean_min = rng.choice([15, 45, 90, 180])
+             delay_std_min = 0.0
+         else:
+             delay_mean_min = 0.0
+             delay_std_min = 0.0
+
+         flight_lock_ts = now_ts + rng.choice([20, 40, 90]) * 60
+         cab_lock_ts = now_ts + rng.choice([20, 60, 120]) * 60
+
+         cab = Cab(
+             pickup_ts=orig_conn_arrive_ts + 20 * 60,
+             lock_ts=cab_lock_ts,
+             cancel_fee=rng.choice([0.0, 10.0, 25.0]),
+             rebook_fee=rng.choice([0.0, 5.0, 15.0]),
+             active=True,
+             penalties_paid=0.0,
+         )
+
+         alt_flights: List[FlightOption] = []
+         for i in range(rng.randint(1, 4)):
+             depart = orig_conn_depart_ts + rng.choice([-60, 0, 60, 120, 240]) * 60
+             arrive = orig_conn_arrive_ts + rng.choice([0, 30, 60, 120, 240]) * 60
+             alt_flights.append(FlightOption(
+                 id=f"ALT_{i}",
+                 depart_ts=depart,
+                 arrive_ts=arrive,
+                 price_delta=rng.choice([0.0, 25.0, 60.0, 120.0]),
+                 reliability=rng.choice([0.6, 0.75, 0.9]),
+                 lock_ts=now_ts + rng.choice([15, 30, 60]) * 60
+             ))
+
+         self.state = DisruptionState(
+             seed=seed, rng=rng,
+             now_ts=now_ts, at_airport=at_airport,
+             disruption_type=disruption_type,
+             delay_mean_min=delay_mean_min, delay_std_min=delay_std_min,
+             delay_realized_min=None,
+             cancelled=cancelled,
+             orig_conn_depart_ts=orig_conn_depart_ts,
+             orig_conn_arrive_ts=orig_conn_arrive_ts,
+             flight_lock_ts=flight_lock_ts,
+             alt_flights=alt_flights,
+             cab=cab,
+             planned_gate_ready_ts=planned_gate_ready_ts,
+             baseline_arrive_ts=baseline_arrive_ts,
+             min_connection_buffer_min=45,
+             changes=0, reversals=0, unresolved_critical=0, cost_incurred=0.0,
+             done=False, success=False, fail_reason=None,
+         )
+         return self._obs_from_state(self.state)
+
+     # ---------- Step ----------
+     def step(self, action: DisruptionAction):
+         s = self.state
+         if s is None:
+             raise RuntimeError("Call reset() before step().")
+         if s.done:
+             return self._obs_from_state(s), 0.0, True, {"fail_reason": s.fail_reason, "success": s.success}
+
+         self.steps += 1
+         now = s.now_ts
+
+         # Apply action
+         if action.type == "rebook_connection":
+             if now >= s.flight_lock_ts:
+                 s.done = True; s.success = False; s.fail_reason = "flight_locked"
+             elif not action.alt_flight_id:
+                 s.done = True; s.success = False; s.fail_reason = "missing_alt_flight_id"
+             else:
+                 chosen = next((f for f in s.alt_flights if f.id == action.alt_flight_id), None)
+                 if chosen is None:
+                     s.done = True; s.success = False; s.fail_reason = "unknown_alt_flight"
+                 elif chosen.lock_ts is not None and now >= chosen.lock_ts:
+                     s.done = True; s.success = False; s.fail_reason = "alt_flight_locked"
+                 else:
+                     s.changes += 1
+                     s.cost_incurred += chosen.price_delta
+                     s.orig_conn_depart_ts = chosen.depart_ts
+                     s.orig_conn_arrive_ts = chosen.arrive_ts
+                     s.cancelled = False  # Resolve cancellation
+
+
+         elif action.type == "keep_connection":
+             pass  # No change, feasibility checked in _resolve_terminal
+
+         elif action.type == "cancel_cab_rebook":
+             if now >= s.cab.lock_ts:
+                 s.done = True; s.success = False; s.fail_reason = "cab_locked"
+             elif action.new_cab_pickup_ts is None:
+                 s.done = True; s.success = False; s.fail_reason = "missing_new_cab_pickup_ts"
+             else:
+                 s.changes += 1
+                 s.cost_incurred += s.cab.rebook_fee
+                 s.cab.pickup_ts = action.new_cab_pickup_ts
+                 # Clear unresolved if cab timing is now valid
+                 s.unresolved_critical = 0
+
+         elif action.type == "request_refund":
+             s.changes += 1
+
+         else:
+             s.done = True; s.success = False; s.fail_reason = "unknown_action"
+
+         # Advance time
+         s.now_ts += 5 * 60
+
+         # Check terminal conditions if not already failed
+         if not s.done:
+             self._update_resolution_status(s)
+             self._resolve_terminal(s)
+
+         # Max steps fallback
+         if not s.done and self.steps >= self.max_steps:
+             s.done = True
+             s.success = False
+             s.fail_reason = "max_steps"
+
+         reward = self._reward(s)
+         info = {
+             "success": s.success,
+             "fail_reason": s.fail_reason,
+             "cost_incurred": s.cost_incurred,
+             "changes": s.changes,
+             "delta_delay_min": self._delta_delay_min(s) if s.success else None
+         }
+         return self._obs_from_state(s), reward, s.done, info
+
+
+ def format_observation_for_llm(obs: DisruptionObservation) -> str:
+     f = obs.features
+     min_buffer = 45
+
+     # Phase-aware header
+     if f.connection_resolved and f.cab_resolved:
+         status_header = "=== STATUS: ALL RESOLVED ==="
+     elif f.connection_resolved:
+         status_header = "=== STATUS: CONNECTION RESOLVED, CAB NEEDS ATTENTION ==="
+     else:
+         status_header = "=== STATUS: CONNECTION UNRESOLVED ==="
+
+     # Build prompt based on phase
+     prompt = f"""{status_header}
+ Next task: {f.next_task}
+
+ === SITUATION ===
+ Disruption: {obs.disruption_type}
+ """
+
+     # Connection section
+     if f.connection_resolved:
+         prompt += f"Connection: RESOLVED (buffer {f.buffer_min:.0f} min)\n"
+     else:
+         feasible = f.buffer_min >= min_buffer and not obs.cancelled
+         prompt += f"Original plan feasible: {'YES' if feasible else 'NO'}\n"
+         if obs.cancelled:
+             prompt += "Hard constraint: keep_connection will FAIL (flight cancelled)\n"
+         elif f.buffer_min < min_buffer:
+             prompt += f"Hard constraint: keep_connection will FAIL (buffer {f.buffer_min:.0f} min < {min_buffer} min required)\n"
+         prompt += f"Required min buffer: {min_buffer} min\n"
+
+     # Urgency
+     alt_locks = [alt.features.alt_lock_in_min for alt in obs.alt_flights if alt.features]
+     min_alt_lock = min(alt_locks) if alt_locks else 999
+     prompt += f"""
+ === URGENCY ===
+ Next lock: {min(min_alt_lock, f.flight_lock_in_min, f.cab_lock_in_min):.0f} min
+ Flight lock: {f.flight_lock_in_min:.0f} min
+ Cab lock: {f.cab_lock_in_min:.0f} min
+ """
+
+     # Cab section (emphasize if that's the task)
+     cab_status = "OK" if f.cab_resolved else ("too early" if f.cab_mismatch_min < 0 else "too late")
+     prompt += f"""
+ === CAB STATUS ===
+ Cab pickup mismatch: {f.cab_mismatch_min:+.0f} min ({cab_status})
+ """
+     if not f.cab_resolved:
+         prompt += "** ACTION REQUIRED: Rebook cab to align with arrival **\n"
+
+     # Options section (only if connection not resolved)
+     if not f.connection_resolved:
+         options_lines = []
+         for alt in obs.alt_flights:
+             af = alt.features
+             if af.alt_lock_in_min <= 0:
+                 feasible_label = "INVALID"
+                 reason = "locked"
+             elif af.alt_buffer_min < min_buffer:
+                 feasible_label = "INVALID"
+                 reason = f"buffer {af.alt_buffer_min:.0f} < {min_buffer}"
+             else:
+                 feasible_label = "FEASIBLE"
+                 reason = ""
+
+             reason_str = f" ({reason})" if reason else ""
+             options_lines.append(
+                 f"{alt.id}: [{feasible_label}{reason_str}] buffer={af.alt_buffer_min:+.0f}, "
+                 f"delay={af.alt_arrival_delay_min:+.0f}, cost={alt.price_delta:+.0f}, locks_in={af.alt_lock_in_min:.0f}"
+             )
+
+         prompt += f"""
+ === FLIGHT OPTIONS ===
+ {chr(10).join(options_lines)}
+ """
+
+     # Action instructions
+     prompt += """
+ === ACTION (strict) ===
+ Only choose FEASIBLE options!
+ Respond exactly:
+ rebook_connection | alt_flight_id=ALT_X
+ OR
+ cancel_cab_rebook | new_cab_pickup_in_min=X
+ OR
+ keep_connection |
+ OR
+ request_refund |
+ """
+     return prompt
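
The prompt above asks the model to answer as `action | key=value`, and it names the cab argument `new_cab_pickup_in_min`, while `DisruptionAction` expects `new_cab_pickup_ts`. The commit does not include a parser for these replies, so the following is only a sketch of one. `parse_llm_action` is a hypothetical helper, it returns a plain dict rather than a `DisruptionAction` so it runs without the package, and it assumes the minutes value is relative to the observation's `now_ts`.

```python
from typing import Any, Dict, Optional

VALID_ACTIONS = {"rebook_connection", "keep_connection", "cancel_cab_rebook", "request_refund"}


def parse_llm_action(text: str, now_ts: int) -> Optional[Dict[str, Any]]:
    """Parse 'action_type | key=value' into DisruptionAction-style fields.

    Returns None for malformed replies so the caller can decide how to recover.
    """
    head, _, tail = text.strip().partition("|")
    action_type = head.strip()
    if action_type not in VALID_ACTIONS:
        return None

    action: Dict[str, Any] = {"type": action_type}
    if action_type in ("keep_connection", "request_refund"):
        return action  # these actions take no argument

    key, sep, value = tail.partition("=")
    key, value = key.strip(), value.strip()
    if not sep or not value:
        return None

    if action_type == "rebook_connection" and key == "alt_flight_id":
        action["alt_flight_id"] = value
        return action

    if action_type == "cancel_cab_rebook" and key == "new_cab_pickup_in_min":
        try:
            # Assumption: minutes are counted from now_ts
            action["new_cab_pickup_ts"] = now_ts + int(round(float(value) * 60))
        except (ValueError, OverflowError):
            return None
        return action

    return None


print(parse_llm_action("rebook_connection | alt_flight_id=ALT_1", now_ts=1000))
print(parse_llm_action("cancel_cab_rebook | new_cab_pickup_in_min=90", now_ts=1000))
print(parse_llm_action("keep_connection |", now_ts=1000))
```

Returning `None` instead of raising keeps the decision about malformed replies with the caller, which matters here because the environment ends the episode with a failure when an action is missing its argument.
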