atin5551 committed on
Commit
cb70a7d
·
1 Parent(s): 7e43148

Deploy Varaha OpenEnv Docker Space

Browse files
openenv_wrapper/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Varaha OpenEnv package — public API re-exports."""
2
+
3
+ from openenv_wrapper.models import VarahaAction, VarahaObservation, VarahaState
4
+ from openenv_wrapper.varaha_environment import VarahaEnvironment
5
+ from openenv_wrapper.client import VarahaEnvClient
6
+
7
+ __all__ = [
8
+ "VarahaAction",
9
+ "VarahaObservation",
10
+ "VarahaState",
11
+ "VarahaEnvironment",
12
+ "VarahaEnvClient",
13
+ ]
openenv_wrapper/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (514 Bytes). View file
 
openenv_wrapper/__pycache__/client.cpython-313.pyc ADDED
Binary file (1.97 kB). View file
 
openenv_wrapper/__pycache__/models.cpython-313.pyc ADDED
Binary file (4.21 kB). View file
 
openenv_wrapper/__pycache__/varaha_environment.cpython-313.pyc ADDED
Binary file (6.77 kB). View file
 
openenv_wrapper/client.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """WebSocket client for the Varaha OpenEnv server."""
2
+
3
+ from typing import Any, Dict
4
+
5
+ from openenv.core.env_client import EnvClient
6
+ from openenv.core.client_types import StepResult
7
+
8
+ from openenv_wrapper.models import VarahaAction, VarahaObservation, VarahaState
9
+
10
+
11
class VarahaEnvClient(EnvClient[VarahaAction, VarahaObservation, VarahaState]):
    """Typed client that speaks to a running Varaha OpenEnv server."""

    def _step_payload(self, action: VarahaAction) -> Dict[str, Any]:
        # Serialise the action for transport, dropping the metadata field.
        return action.model_dump(exclude={"metadata"})

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[VarahaObservation]:
        # The server may nest the observation under "observation" or "data";
        # failing both, the payload itself is treated as the observation dict.
        if "observation" in payload:
            obs_data = payload["observation"]
        elif "data" in payload:
            obs_data = payload["data"]
        else:
            obs_data = payload
        observation = VarahaObservation(**obs_data)
        # Prefer top-level reward/done; fall back to the observation's copies.
        reward = payload.get("reward", observation.reward)
        done = payload.get("done", observation.done)
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> VarahaState:
        # The state endpoint returns a flat dict matching VarahaState's fields.
        return VarahaState(**payload)
openenv_wrapper/models.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic models for the Varaha OpenEnv environment."""
2
+
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ from pydantic import Field
6
+ from openenv.core.env_server.types import Action, Observation, State
7
+
8
+
9
class VarahaAction(Action):
    """Drone acceleration command with automatic delivery/recharge.

    Field values are forwarded verbatim to the underlying simulation's
    ``step()`` as a plain dict (see ``VarahaEnvironment.step``).
    """

    ax: float = Field(0.0, description="Desired acceleration along x-axis (m/s^2)")
    ay: float = Field(0.0, description="Desired acceleration along y-axis (m/s^2)")
    az: float = Field(0.0, description="Desired acceleration along z-axis (m/s^2)")
    # deliver/recharge default to True so delivery/recharge is attempted
    # automatically whenever the drone is in range.
    deliver: bool = Field(True, description="Attempt delivery when near a target")
    recharge: bool = Field(True, description="Attempt recharge when near base station")
    tool_call: str = Field(
        "",
        description="Optional tool call: request_intel[:target_id] | battery_forecast | mission_report",
    )
21
+
22
+
23
class VarahaObservation(Observation):
    """Full observation returned after each step/reset.

    Populated by ``VarahaEnvironment._build_observation`` from the
    simulation's observation dict; ``reward``/``done``/``metadata`` come
    from the ``Observation`` base class.
    """

    drone_position: Dict[str, float] = Field(
        default_factory=dict, description="Drone {x, y, z} in local metres"
    )
    drone_velocity: Dict[str, float] = Field(
        default_factory=dict, description="Drone velocity {x, y, z} in m/s"
    )
    battery: float = Field(0.0, description="Remaining battery units")
    carrying_payload: bool = Field(True, description="Whether the drone still carries payload")
    alive: bool = Field(True, description="Whether the drone is still operational")
    targets: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Per-target relative position, urgency, delivered status",
    )
    hazards: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Per-hazard relative position, current height, severity",
    )
    step_num: int = Field(0, description="Current step number in the episode")
    max_steps: int = Field(2000, description="Maximum allowed steps")
    reward_breakdown: Dict[str, float] = Field(
        default_factory=dict, description="Itemised reward components from the last step"
    )
    mission: Dict[str, Any] = Field(
        default_factory=dict,
        description="Instruction-mode progress, next instruction, and violation counters",
    )
    last_tool_result: Dict[str, Any] = Field(
        default_factory=dict,
        description="Result payload from the most recent tool call",
    )
    success: bool = Field(False, description="Whether the mission is successfully completed")
    # Only filled in by the wrapper when done=True (see _build_observation).
    trace: Optional[Dict[str, Any]] = Field(
        None, description="Full episode trace (only populated on the final step)"
    )
60
+
61
+
62
class VarahaState(State):
    """Internal environment state exposed via the state property.

    Built fresh on each access by ``VarahaEnvironment.state``; the
    ``episode_id`` and ``step_count`` fields it is constructed with come
    from the ``State`` base class.
    """

    cumulative_reward: float = Field(0.0, description="Total accumulated reward")
    deliveries_completed: int = Field(0, description="Number of targets delivered so far")
    total_targets: int = Field(0, description="Total number of targets in the episode")
    battery: float = Field(0.0, description="Current battery level")
    success: bool = Field(False, description="Whether the mission is complete")
openenv_wrapper/server/__init__.py ADDED
File without changes
openenv_wrapper/server/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (145 Bytes). View file
 
openenv_wrapper/server/__pycache__/app.cpython-313.pyc ADDED
Binary file (786 Bytes). View file
 
openenv_wrapper/server/app.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""FastAPI application for the Varaha OpenEnv environment."""

import sys
import os

# Make the repository root importable so sibling top-level modules
# (e.g. varaha_env.py, imported by varaha_environment) resolve when this
# module is loaded as an ASGI target from inside the package.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from openenv.core.env_server import create_app

from openenv_wrapper.models import VarahaAction, VarahaObservation
from openenv_wrapper.varaha_environment import VarahaEnvironment

# ASGI entry point: create_app wires VarahaEnvironment behind the standard
# OpenEnv HTTP routes with typed action/observation (de)serialisation.
app = create_app(
    VarahaEnvironment,
    VarahaAction,
    VarahaObservation,
    env_name="varaha",
)
openenv_wrapper/varaha_environment.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OpenEnv-compatible Varaha wildfire drone environment."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import sys
6
+ import os
7
+ import uuid
8
+ from typing import Any, Callable, Optional
9
+
10
+ from openenv.core.env_server.interfaces import Environment
11
+ from openenv.core.env_server.types import EnvironmentMetadata
12
+
13
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
14
+
15
+ from varaha_env import VarahaConfig, VarahaEnv, build_random_world
16
+ from openenv_wrapper.models import VarahaAction, VarahaObservation, VarahaState
17
+
18
+
19
class VarahaEnvironment(Environment[VarahaAction, VarahaObservation, VarahaState]):
    """Wildfire logistics drone environment wrapped for OpenEnv.

    Each episode the drone must deliver supplies to responder zones near
    wildfire hazards, then return to base. Supports domain-randomised
    worlds when ``world_fn`` is provided.
    """

    def __init__(
        self,
        config: Optional[VarahaConfig] = None,
        world_fn: Optional[Callable[..., None]] = None,
    ) -> None:
        """Create the wrapped simulation.

        Args:
            config: Tuning parameters; defaults to a fresh VarahaConfig.
            world_fn: Optional world-builder called by VarahaEnv (e.g.
                build_random_world) for domain randomisation.
        """
        super().__init__()
        self._config = config or VarahaConfig()
        self._world_fn = world_fn
        self._env = VarahaEnv(config=self._config, world_fn=self._world_fn)
        self._episode_id = str(uuid.uuid4())
        # Raw info dict from the most recent step; cleared on reset.
        self._last_info: dict[str, Any] = {}

    # ------------------------------------------------------------------
    # OpenEnv abstract interface
    # ------------------------------------------------------------------

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> VarahaObservation:
        """Start a new episode and return the initial observation."""
        # Generate an id if the caller did not provide one.
        self._episode_id = episode_id or str(uuid.uuid4())
        obs_dict = self._env.reset(seed=seed)
        self._last_info = {}
        return self._build_observation(obs_dict, reward=0.0, done=False)

    def step(
        self,
        action: VarahaAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> VarahaObservation:
        """Apply one action and return the resulting observation.

        Note: ``timeout_s`` is accepted for interface compatibility but
        not used by the synchronous simulation.
        """
        # Flatten the typed action into the plain dict VarahaEnv expects.
        action_dict = {
            "ax": action.ax,
            "ay": action.ay,
            "az": action.az,
            "deliver": action.deliver,
            "recharge": action.recharge,
            "tool_call": action.tool_call,
        }
        obs_dict, reward, done, info = self._env.step(action_dict)
        self._last_info = info
        return self._build_observation(obs_dict, reward=reward, done=done, info=info)

    @property
    def state(self) -> VarahaState:
        """Snapshot of episode progress, rebuilt on every access."""
        delivered = sum(1 for t in self._env.targets if t.delivered)
        return VarahaState(
            episode_id=self._episode_id,
            step_count=self._env.step_count,
            cumulative_reward=round(self._env.cumulative_reward, 4),
            deliveries_completed=delivered,
            total_targets=len(self._env.targets),
            battery=round(self._env.drone.battery, 4),
            success=self._env._is_success(),
        )

    # ------------------------------------------------------------------
    # Optional overrides
    # ------------------------------------------------------------------

    def get_metadata(self) -> EnvironmentMetadata:
        """Static descriptive metadata served to clients."""
        return EnvironmentMetadata(
            name="Varaha Wildfire Logistics",
            description=(
                "A 3D drone delivery environment where an agent must navigate "
                "wildfire hazards and obstacles to deliver supplies to responder "
                "zones, then return to base."
            ),
            version="1.0.0",
            author="Varaha Team",
        )

    def close(self) -> None:
        # The simulation holds no external resources; nothing to release.
        pass

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _build_observation(
        self,
        obs_dict: dict[str, Any],
        *,
        reward: float,
        done: bool,
        info: dict[str, Any] | None = None,
    ) -> VarahaObservation:
        """Convert a raw simulation observation dict into the pydantic model."""
        info = info or {}
        # The full trajectory trace is expensive; attach it only on the
        # terminal step.
        trace = self._env.get_trace() if done else None
        return VarahaObservation(
            done=done,
            reward=round(reward, 4),
            metadata={"info": info},
            drone_position=obs_dict["drone_position"],
            drone_velocity=obs_dict["drone_velocity"],
            battery=obs_dict["battery"],
            carrying_payload=obs_dict["carrying_payload"],
            alive=obs_dict["alive"],
            targets=obs_dict["targets"],
            hazards=obs_dict.get("hazards", []),
            mission=obs_dict.get("mission", {}),
            last_tool_result=obs_dict.get("last_tool_result", {}),
            step_num=obs_dict["step"],
            max_steps=obs_dict["max_steps"],
            reward_breakdown=info.get("reward_breakdown", {}),
            success=self._env._is_success(),
            trace=trace,
        )
sim_types.py ADDED
@@ -0,0 +1,475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Varaha simulation types — core data structures for the wildfire logistics environment."""
2
+
3
+ import math
4
+ from dataclasses import dataclass, field
5
+ from typing import Any
6
+
7
+
8
+ # ---------------------------------------------------------------------------
9
+ # Vec3
10
+ # ---------------------------------------------------------------------------
11
+
12
@dataclass
class Vec3:
    """Minimal 3-component vector with the arithmetic the simulation needs."""

    x: float = 0.0
    y: float = 0.0
    z: float = 0.0

    # --- arithmetic ---

    def __add__(self, other: "Vec3") -> "Vec3":
        return Vec3(self.x + other.x, self.y + other.y, self.z + other.z)

    def __sub__(self, other: "Vec3") -> "Vec3":
        return Vec3(self.x - other.x, self.y - other.y, self.z - other.z)

    def scale(self, s: float) -> "Vec3":
        """Return this vector multiplied component-wise by scalar *s*."""
        return Vec3(s * self.x, s * self.y, s * self.z)

    # --- magnitude ---

    def norm(self) -> float:
        """Euclidean length of the vector."""
        return math.sqrt(sum(c ** 2 for c in (self.x, self.y, self.z)))

    def normalized(self) -> "Vec3":
        """Unit vector in this direction; the zero vector for tiny magnitudes."""
        magnitude = self.norm()
        if magnitude < 1e-9:
            return Vec3(0.0, 0.0, 0.0)
        return self.scale(1.0 / magnitude)

    def clamp_magnitude(self, max_mag: float) -> "Vec3":
        """Return a copy whose length is capped at *max_mag*."""
        magnitude = self.norm()
        needs_clamp = magnitude > max_mag and magnitude > 1e-9
        if needs_clamp:
            return self.scale(max_mag / magnitude)
        return Vec3(self.x, self.y, self.z)

    # --- distance ---

    def distance_to(self, other: "Vec3") -> float:
        delta = self - other
        return delta.norm()

    def horizontal_distance_to(self, other: "Vec3") -> float:
        """Distance in the x-y plane only (altitude ignored)."""
        dx = self.x - other.x
        dy = self.y - other.y
        return math.sqrt(dx * dx + dy * dy)

    # --- serialization ---

    def to_dict(self) -> dict[str, float]:
        # Round for compact, stable serialised output.
        return {axis: round(value, 4)
                for axis, value in (("x", self.x), ("y", self.y), ("z", self.z))}

    def __repr__(self) -> str:
        return "Vec3({:.2f}, {:.2f}, {:.2f})".format(self.x, self.y, self.z)
65
+
66
+
67
+ # ---------------------------------------------------------------------------
68
+ # Drone
69
+ # ---------------------------------------------------------------------------
70
+
71
@dataclass
class DroneState:
    """Full kinematic + status state of the drone."""

    position: Vec3 = field(default_factory=Vec3)
    velocity: Vec3 = field(default_factory=Vec3)
    battery: float = 100.0
    carrying_payload: bool = True
    alive: bool = True

    def to_dict(self) -> dict[str, Any]:
        """Serialise the drone state; floats rounded for compact output."""
        return dict(
            position=self.position.to_dict(),
            velocity=self.velocity.to_dict(),
            battery=round(self.battery, 4),
            carrying_payload=self.carrying_payload,
            alive=self.alive,
        )
89
+
90
+
91
+ # ---------------------------------------------------------------------------
92
+ # World entities
93
+ # ---------------------------------------------------------------------------
94
+
95
@dataclass
class BaseStation:
    """Home base where the drone launches, lands, and recharges."""

    position: Vec3 = field(default_factory=Vec3)
    # Horizontal distance within which recharging is possible.
    recharge_radius: float = 20.0

    def to_dict(self) -> dict[str, Any]:
        """Serialise the station for observations/traces."""
        return dict(
            position=self.position.to_dict(),
            recharge_radius=self.recharge_radius,
        )
107
+
108
+
109
@dataclass
class DeliveryTarget:
    """A responder zone requiring supply delivery."""

    id: str = ""
    position: Vec3 = field(default_factory=Vec3)
    urgency: float = 0.5
    delivered: bool = False
    # Radius within which a delivery attempt counts.
    delivery_radius: float = 15.0

    def to_dict(self) -> dict[str, Any]:
        """Serialise the target; urgency rounded for compact output."""
        return dict(
            id=self.id,
            position=self.position.to_dict(),
            urgency=round(self.urgency, 4),
            delivered=self.delivered,
            delivery_radius=self.delivery_radius,
        )
127
+
128
+
129
@dataclass
class HazardRegion:
    """Wildfire danger zone modeled as a ground-level dome.

    Danger is nonzero only inside a cylinder of ``radius`` around the
    center that extends ``_current_height`` metres above it — drones can
    overfly a fire at sufficient altitude. Inside the dome, danger grows
    toward the ground-level center both horizontally and vertically.

    ``growth_rate`` is the per-step height increase (metres/step) used to
    simulate fire growth over an episode.
    """

    id: str = ""
    center: Vec3 = field(default_factory=Vec3)
    radius: float = 50.0
    severity: float = 0.5
    height: float = 80.0
    growth_rate: float = 0.0
    # Dynamic height, reinitialised from `height` each episode.
    _current_height: float = field(default=0.0, init=False, repr=False)

    def __post_init__(self):
        self._current_height = self.height

    def reset(self):
        """Reset dynamic state for a new episode."""
        self._current_height = self.height

    def tick(self):
        """Advance one timestep — grow the fire."""
        if self.growth_rate > 0:
            self._current_height = self._current_height + self.growth_rate

    def contains(self, pos: Vec3) -> bool:
        """True when *pos* lies inside the hazard dome."""
        dx = pos.x - self.center.x
        dy = pos.y - self.center.y
        horiz = (dx ** 2 + dy ** 2) ** 0.5
        alt = pos.z - self.center.z
        return horiz <= self.radius and 0 <= alt < self._current_height

    def danger_factor(self, pos: Vec3) -> float:
        """0 outside the dome, scales up toward the ground-level center."""
        dx = pos.x - self.center.x
        dy = pos.y - self.center.y
        horiz = (dx ** 2 + dy ** 2) ** 0.5
        if horiz >= self.radius:
            return 0.0
        alt = pos.z - self.center.z
        if alt >= self._current_height or alt < 0:
            return 0.0
        # Linear falloff in both the radial and vertical directions.
        horiz_factor = 1.0 - horiz / self.radius
        vert_factor = 1.0 - alt / self._current_height
        return self.severity * horiz_factor * vert_factor

    def to_dict(self) -> dict[str, Any]:
        """Serialise both static config and current dynamic height."""
        return dict(
            id=self.id,
            center=self.center.to_dict(),
            radius=self.radius,
            severity=self.severity,
            height=self.height,
            current_height=round(self._current_height, 2),
            growth_rate=self.growth_rate,
        )
189
+
190
+
191
@dataclass
class ObstacleVolume:
    """Axis-aligned 3D box that the drone must not enter."""

    id: str = ""
    min_corner: Vec3 = field(default_factory=Vec3)
    max_corner: Vec3 = field(default_factory=Vec3)
    kind: str = "building"

    def contains(self, pos: Vec3) -> bool:
        """True when *pos* lies inside (or on the faces of) the box."""
        within_x = self.min_corner.x <= pos.x <= self.max_corner.x
        within_y = self.min_corner.y <= pos.y <= self.max_corner.y
        within_z = self.min_corner.z <= pos.z <= self.max_corner.z
        return within_x and within_y and within_z

    @property
    def center(self) -> Vec3:
        """Geometric center of the box."""
        return (self.min_corner + self.max_corner).scale(0.5)

    @property
    def half_size(self) -> Vec3:
        """Half-extent of the box along each axis."""
        return (self.max_corner - self.min_corner).scale(0.5)

    @property
    def height(self) -> float:
        """Top face altitude (box roofs sit at max_corner.z)."""
        return self.max_corner.z

    def nearest_surface_dist(self, pos: Vec3) -> float:
        """Clamped distance from *pos* to the box surface (0.0 when inside)."""
        dx = max(abs(pos.x - self.center.x) - self.half_size.x, 0.0)
        dy = max(abs(pos.y - self.center.y) - self.half_size.y, 0.0)
        # At most one of the two vertical gaps is nonzero for a valid box.
        dz_below = max(self.min_corner.z - pos.z, 0.0)
        dz_above = max(pos.z - self.max_corner.z, 0.0)
        return math.sqrt(dx * dx + dy * dy + (dz_below + dz_above) ** 2)

    def to_dict(self) -> dict[str, Any]:
        """Serialise the box for observations/traces."""
        return dict(
            id=self.id,
            min_corner=self.min_corner.to_dict(),
            max_corner=self.max_corner.to_dict(),
            kind=self.kind,
        )
244
+
245
+
246
@dataclass
class CylindricalObstacle:
    """Vertical cylinder obstacle — trees, poles, pillars, tanks."""

    id: str = ""
    center: Vec3 = field(default_factory=Vec3)
    radius: float = 10.0
    height: float = 50.0
    kind: str = "tree"

    def _horiz_dist(self, pos: Vec3) -> float:
        """Horizontal distance from *pos* to the cylinder's axis."""
        dx = pos.x - self.center.x
        dy = pos.y - self.center.y
        return math.sqrt(dx * dx + dy * dy)

    def contains(self, pos: Vec3) -> bool:
        """True when *pos* is within the cylinder (base at z=0)."""
        return self._horiz_dist(pos) <= self.radius and 0 <= pos.z <= self.height

    def nearest_surface_dist(self, pos: Vec3) -> float:
        """Clamped distance from *pos* to the cylinder surface (0.0 inside)."""
        radial_gap = max(self._horiz_dist(pos) - self.radius, 0.0)
        # Vertical gap above the top, or below ground level, whichever applies.
        vert_gap = max(pos.z - self.height, 0.0) if pos.z > self.height else max(-pos.z, 0.0)
        return math.sqrt(radial_gap ** 2 + vert_gap ** 2)

    def to_dict(self) -> dict[str, Any]:
        """Serialise the cylinder; dimensions rounded for compact output."""
        return dict(
            id=self.id,
            center=self.center.to_dict(),
            radius=round(self.radius, 2),
            height=round(self.height, 2),
            kind=self.kind,
        )
278
+
279
+
280
+ # ---------------------------------------------------------------------------
281
+ # Responder units — dynamic actors that alter mission conditions mid-episode
282
+ # ---------------------------------------------------------------------------
283
+
284
# Status levels a responder can report, with their numeric encoding
# (consumed by ResponderUnit.status_code).
RESPONDER_STATUSES = ("stable", "urgent", "critical")
RESPONDER_STATUS_MAP = {"stable": 0.0, "urgent": 0.5, "critical": 1.0}

# Kinds of hazard intel a responder can broadcast.
INTEL_TYPES = (
    "none",
    "blocked_north", "blocked_south", "blocked_east", "blocked_west",
    "safe_north", "safe_south", "safe_east", "safe_west",
    "fire_expanded", "fire_receded",
)

# Unit (x, y) direction hint per intel type; (0, 0) means non-directional
# intel (consumed by ResponderUnit.intel_direction).
INTEL_DIRECTION_VECS = {
    "none": (0.0, 0.0),
    "blocked_north": (0.0, 1.0), "blocked_south": (0.0, -1.0),
    "blocked_east": (1.0, 0.0), "blocked_west": (-1.0, 0.0),
    "safe_north": (0.0, 1.0), "safe_south": (0.0, -1.0),
    "safe_east": (1.0, 0.0), "safe_west": (-1.0, 0.0),
    "fire_expanded": (0.0, 0.0), "fire_receded": (0.0, 0.0),
}
302
+
303
+
304
@dataclass
class ScheduledEvent:
    """A future event a responder will trigger at a specific step."""

    step: int = 0                 # episode step at which the event fires
    event_type: str = ""          # event discriminator
    payload: dict[str, Any] = field(default_factory=dict)
    fired: bool = False           # set once the event has been applied
311
+
312
+
313
@dataclass
class ResponderUnit:
    """First responder on the ground linked to a delivery target.

    Can dynamically alter mission conditions mid-episode:
    1. Update urgency of their linked target
    2. Relocate the drop-zone (move target position)
    3. Broadcast hazard intel (structured approach guidance)
    """

    id: str = ""
    position: Vec3 = field(default_factory=Vec3)
    linked_target_id: str = ""
    status: str = "stable"
    current_need: str = "supplies"
    message: str = ""
    can_update_dropzone: bool = False
    active: bool = True

    # Most recent broadcast intel and its severity.
    latest_intel: str = "none"
    intel_severity: float = 0.0

    # Pending mid-episode events this responder will fire.
    scheduled_events: list[ScheduledEvent] = field(default_factory=list)

    def status_code(self) -> float:
        """Numeric encoding of status; unknown statuses map to 0.0."""
        try:
            return RESPONDER_STATUS_MAP[self.status]
        except KeyError:
            return 0.0

    def intel_direction(self) -> tuple[float, float]:
        """(x, y) direction hint for the latest intel; (0, 0) if none."""
        try:
            return INTEL_DIRECTION_VECS[self.latest_intel]
        except KeyError:
            return (0.0, 0.0)

    def to_dict(self) -> dict[str, Any]:
        """Serialise responder state (scheduled events are kept internal)."""
        return dict(
            id=self.id,
            position=self.position.to_dict(),
            linked_target_id=self.linked_target_id,
            status=self.status,
            current_need=self.current_need,
            message=self.message,
            can_update_dropzone=self.can_update_dropzone,
            active=self.active,
            latest_intel=self.latest_intel,
            intel_severity=round(self.intel_severity, 4),
        )
356
+
357
+
358
+ # ---------------------------------------------------------------------------
359
+ # Observation & step diagnostics
360
+ # ---------------------------------------------------------------------------
361
+
362
@dataclass
class VarahaObservation:
    """Structured observation returned to the agent each step.

    Kept as a dataclass for documentation; the env also offers a plain-dict
    path via ``get_observation()`` for maximum serialisation flexibility.
    """

    drone_position: Vec3 = field(default_factory=Vec3)
    drone_velocity: Vec3 = field(default_factory=Vec3)
    battery: float = 100.0
    carrying_payload: bool = True
    alive: bool = True
    targets: list[dict[str, Any]] = field(default_factory=list)
    step: int = 0
    max_steps: int = 500

    def to_dict(self) -> dict[str, Any]:
        """Flatten to the plain dict shape the wrapper/server expect."""
        return dict(
            drone_position=self.drone_position.to_dict(),
            drone_velocity=self.drone_velocity.to_dict(),
            battery=round(self.battery, 4),
            carrying_payload=self.carrying_payload,
            alive=self.alive,
            targets=self.targets,
            step=self.step,
            max_steps=self.max_steps,
        )
390
+
391
+
392
@dataclass
class MissionInstruction:
    """Single mission instruction used for long-horizon planning mode."""

    id: str = ""
    kind: str = ""                # instruction category
    description: str = ""         # human/LLM-readable text
    target_id: str = ""           # delivery target this refers to, if any
    tool_name: str = ""           # tool this instruction asks for, if any
    completed: bool = False
    violated: bool = False

    def to_dict(self) -> dict[str, Any]:
        """Serialise the instruction for observations/traces."""
        return dict(
            id=self.id,
            kind=self.kind,
            description=self.description,
            target_id=self.target_id,
            tool_name=self.tool_name,
            completed=self.completed,
            violated=self.violated,
        )
414
+
415
+
416
@dataclass
class TracePoint:
    """Single frame of the drone's recorded trajectory."""

    step: int = 0
    position: Vec3 = field(default_factory=Vec3)
    velocity: Vec3 = field(default_factory=Vec3)
    battery: float = 100.0
    reward: float = 0.0
    cumulative_reward: float = 0.0
    events: list[str] = field(default_factory=list)
    observation: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialise one frame; copies `events` so callers can't mutate it."""
        return dict(
            step=self.step,
            position=self.position.to_dict(),
            velocity=self.velocity.to_dict(),
            battery=round(self.battery, 4),
            reward=round(self.reward, 4),
            cumulative_reward=round(self.cumulative_reward, 4),
            events=list(self.events),
            observation=self.observation,
        )
440
+
441
+
442
@dataclass
class StepInfo:
    """Per-step diagnostic info returned alongside the reward."""

    collision: bool = False
    delivered_target_ids: list[str] = field(default_factory=list)
    in_hazard: bool = False
    hazard_severity: float = 0.0
    reached_base: bool = False
    distance_traveled: float = 0.0
    tool_call: str = ""
    tool_result: dict[str, Any] = field(default_factory=dict)
    instruction_completed: int = 0
    instruction_total: int = 0
    instruction_violations: int = 0
    reward_breakdown: dict[str, float] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialise diagnostics; floats rounded, lists copied defensively."""
        rounded_breakdown = {
            name: round(value, 4) for name, value in self.reward_breakdown.items()
        }
        return dict(
            collision=self.collision,
            delivered_target_ids=list(self.delivered_target_ids),
            in_hazard=self.in_hazard,
            hazard_severity=round(self.hazard_severity, 4),
            reached_base=self.reached_base,
            distance_traveled=round(self.distance_traveled, 4),
            tool_call=self.tool_call,
            tool_result=self.tool_result,
            instruction_completed=self.instruction_completed,
            instruction_total=self.instruction_total,
            instruction_violations=self.instruction_violations,
            reward_breakdown=rounded_breakdown,
        )
varaha_env.py ADDED
@@ -0,0 +1,1323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Varaha — wildfire logistics simulation environment.
2
+
3
+ A drone must deliver supplies to responder zones near wildfire hazards in
4
+ California-like terrain. The environment uses lightweight 3D kinematics with
5
+ local metre-based coordinates and an optional lat/lon conversion helper for
6
+ later Cesium visualisation.
7
+ """
8
+
9
+ import math
10
+ import random
11
+ from dataclasses import dataclass
12
+ from typing import Any, Optional
13
+
14
+ from sim_types import (
15
+ Vec3,
16
+ DroneState,
17
+ BaseStation,
18
+ DeliveryTarget,
19
+ HazardRegion,
20
+ ObstacleVolume,
21
+ CylindricalObstacle,
22
+ ResponderUnit,
23
+ ScheduledEvent,
24
+ RESPONDER_STATUSES,
25
+ INTEL_TYPES,
26
+ StepInfo,
27
+ TracePoint,
28
+ MissionInstruction,
29
+ )
30
+
31
+
32
+ # ---------------------------------------------------------------------------
33
+ # Configuration
34
+ # ---------------------------------------------------------------------------
35
+
36
@dataclass
class VarahaConfig:
    """All tunable environment parameters live here."""

    # World bounds (metres) — 5 km × 5 km operational area
    world_x: float = 5000.0
    world_y: float = 5000.0
    world_z: float = 200.0

    # Drone physics
    battery_capacity: float = 300.0
    max_speed: float = 25.0  # m/s
    max_acceleration: float = 8.0  # m/s²
    dt: float = 0.5  # seconds per step

    # Episode
    max_episode_steps: int = 2000

    # Battery drain coefficients (tuned for 5 km scale)
    drain_per_meter: float = 0.008
    drain_elevation_factor: float = 0.02
    drain_idle_per_step: float = 0.005
    recharge_rate: float = 5.0  # battery units restored per recharge step

    # Reward knobs
    delivery_reward: float = 200.0
    return_bonus: float = 100.0
    step_penalty: float = 0.05
    battery_cost_factor: float = 0.3
    collision_penalty: float = 500.0
    hazard_penalty: float = 5.0
    failure_penalty: float = 200.0
    distance_shaping_factor: float = 0.05
    obstacle_proximity_penalty: float = 1.5
    obstacle_proximity_radius: float = 80.0

    # Long-horizon instruction mode (LLM-oriented)
    instruction_mode: bool = False
    instruction_count: int = 60
    sparse_reward_mode: bool = False
    instruction_completion_reward: float = 0.5
    instruction_terminal_success_bonus: float = 2200.0
    instruction_terminal_progress_bonus: float = 800.0
    instruction_violation_penalty: float = 120.0
    instruction_unfinished_penalty: float = 10.0
    # Tool names accepted in VarahaAction.tool_call.
    available_tools: tuple[str, ...] = (
        "request_intel",
        "battery_forecast",
        "mission_report",
    )

    # California origin anchor (near Sacramento — wildfire-relevant)
    origin_lat: float = 38.55
    origin_lon: float = -121.47
90
+
91
+
92
+ # ---------------------------------------------------------------------------
93
+ # Random world generator for domain randomization
94
+ # ---------------------------------------------------------------------------
95
+
96
def build_random_world(env: "VarahaEnv") -> None:
    """Backward-compatible alias: delegates to the hardcore generator.

    Kept so older callers that expect the original "easy" world builder
    keep working; it now produces the same worlds as build_hardcore_world.
    """
    build_hardcore_world(env)
99
+
100
+
101
def _hdist(a: "Vec3", b: "Vec3") -> float:
    """Return the horizontal (XY-plane) distance between *a* and *b*.

    The z component is deliberately ignored — world-gen spacing rules are
    defined in ground-plane terms. Uses math.hypot, which is clearer and
    numerically more robust than the manual sqrt-of-sum-of-squares form.
    (String annotations match the file's existing "VarahaEnv" forward-ref style.)
    """
    return math.hypot(a.x - b.x, a.y - b.y)
103
+
104
+
105
def build_hardcore_world(env: "VarahaEnv", ultra_hard: bool = False) -> None:
    """Generate an extremely challenging randomized world for serious RL training.

    Features template-based obstacle placement (urban grid, dense forest,
    corridor maze, river valley, fortress, mixed), cylindrical obstacles,
    responder units with dynamic events, and adversarial target placement.

    When ultra_hard=True: denser obstacles, more hazards, more targets, longer episodes.

    Mutates ``env`` in place: sets base, targets, hazards, obstacles,
    cylinders and responders. Uses the module-level ``random`` RNG, so a
    prior ``random.seed(...)`` makes generation reproducible.
    """
    cfg = env.cfg
    # Alias the module-level RNG; every draw below goes through `rng`.
    rng = random

    # NOTE(review): wz is currently unused by the generator — kept for symmetry.
    wx, wy, wz = cfg.world_x, cfg.world_y, cfg.world_z
    margin = 200.0  # keep generated entities this far from the world edges (m)

    def _rpos(z_lo=10.0, z_hi=60.0):
        # Random in-bounds position with altitude drawn from [z_lo, z_hi].
        return Vec3(rng.uniform(margin, wx - margin),
                    rng.uniform(margin, wy - margin),
                    rng.uniform(z_lo, z_hi))

    def _rpos_ground():
        # Random in-bounds position pinned to ground level (z = 0).
        return Vec3(rng.uniform(margin, wx - margin),
                    rng.uniform(margin, wy - margin), 0.0)

    # --- Base station ---
    base_pos = Vec3(rng.uniform(100, wx - 100), rng.uniform(100, wy - 100), 0.0)
    env.base = BaseStation(position=base_pos, recharge_radius=rng.uniform(60, 100))

    # --- Targets (2-5 normal, 3-6 ultra) ---
    if ultra_hard:
        n_targets = rng.choices([3, 4, 5, 6], weights=[0.15, 0.35, 0.35, 0.15])[0]
    else:
        n_targets = rng.choices([2, 3, 4, 5], weights=[0.15, 0.40, 0.30, 0.15])[0]
    targets = []
    for i in range(n_targets):
        # Rejection-sample a position >500 m from base and >400 m from every
        # previously placed target. If all 120 attempts fail, the LAST
        # candidate is used as-is (spacing constraints are best-effort).
        for _ in range(120):
            pos = _rpos(z_lo=5.0, z_hi=60.0)
            if _hdist(pos, base_pos) < 500:
                continue
            if all(_hdist(pos, t.position) > 400 for t in targets):
                break
        targets.append(DeliveryTarget(
            id=f"T{i+1}", position=pos,
            urgency=rng.uniform(0.3, 1.0),
            delivery_radius=rng.uniform(70.0, 130.0),
        ))
    env.targets = targets

    # --- Hazards (3-8 normal, 5-10 ultra) with wild variety ---
    if ultra_hard:
        n_hazards = rng.choices([5, 6, 7, 8, 9, 10], weights=[0.10, 0.20, 0.25, 0.25, 0.15, 0.05])[0]
    else:
        n_hazards = rng.choices([3, 4, 5, 6, 7, 8], weights=[0.10, 0.20, 0.25, 0.25, 0.15, 0.05])[0]
    hazards = []
    for i in range(n_hazards):
        center = _rpos_ground()
        # Each fire archetype trades off radius (r), severity (sev),
        # plume height (ht) and growth rate (gr) differently.
        fire_type = rng.choice(["tiny_intense", "massive_low", "tall_mid", "standard"])
        if fire_type == "tiny_intense":
            r, sev, ht, gr = rng.uniform(80, 200), rng.uniform(0.9, 1.0), rng.uniform(140, 195), rng.uniform(0.012, 0.025)
        elif fire_type == "massive_low":
            r, sev, ht, gr = rng.uniform(500, 1000), rng.uniform(0.3, 0.5), rng.uniform(25, 50), rng.uniform(0.001, 0.004)
        elif fire_type == "tall_mid":
            r, sev, ht, gr = rng.uniform(250, 500), rng.uniform(0.7, 0.95), rng.uniform(100, 180), rng.uniform(0.008, 0.015)
        else:
            r, sev, ht, gr = rng.uniform(200, 600), rng.uniform(0.4, 0.9), rng.uniform(40, 120), rng.uniform(0.003, 0.012)
        hazards.append(HazardRegion(id=f"H{i+1}", center=center,
                                    radius=r, severity=sev, height=ht, growth_rate=gr))
    env.hazards = hazards

    # --- Obstacle templates ---
    obstacles: list[ObstacleVolume] = []
    cylinders: list[CylindricalObstacle] = []
    # Single-element list so the nested helper can mutate the counter.
    oid = [0]

    def _next_oid(prefix="O"):
        # Monotonic id generator shared by boxes ("O…") and cylinders ("C…").
        oid[0] += 1
        return f"{prefix}{oid[0]}"

    def _add_box(cx, cy, w, h, zt, kind="building"):
        # Axis-aligned box centred at (cx, cy), footprint w×h, top at zt.
        obstacles.append(ObstacleVolume(
            id=_next_oid(), kind=kind,
            min_corner=Vec3(cx - w / 2, cy - h / 2, 0.0),
            max_corner=Vec3(cx + w / 2, cy + h / 2, zt),
        ))

    def _add_cyl(cx, cy, radius, height, kind="tree"):
        # Vertical cylinder rooted at ground level.
        cylinders.append(CylindricalObstacle(
            id=_next_oid("C"), kind=kind,
            center=Vec3(cx, cy, 0.0), radius=radius, height=height,
        ))

    # Ultra-hard biases heavily toward the "mixed" template (several
    # patterns layered together); normal mode picks uniformly.
    if ultra_hard:
        template = rng.choices(["urban_grid", "dense_forest", "corridor_maze",
                                "river_valley", "fortress", "mixed"],
                               weights=[0.08, 0.12, 0.12, 0.10, 0.10, 0.48])[0]
    else:
        template = rng.choice(["urban_grid", "dense_forest", "corridor_maze",
                               "river_valley", "fortress", "mixed"])

    # ---- URBAN GRID: rows and columns of buildings ----
    if template == "urban_grid" or template == "mixed":
        ox = rng.uniform(500, 1500)
        oy = rng.uniform(500, 1500)
        rows = rng.randint(2, 5) if ultra_hard else rng.randint(2, 4)
        cols = rng.randint(3, 6) if ultra_hard else rng.randint(3, 5)
        spacing = rng.uniform(300, 550) if ultra_hard else rng.uniform(350, 600)
        for r in range(rows):
            for c in range(cols):
                bx = ox + c * spacing + rng.uniform(-80, 80)
                by = oy + r * spacing + rng.uniform(-80, 80)
                # Skip buildings whose jittered centre leaves the safe area.
                if bx < margin or bx > wx - margin or by < margin or by > wy - margin:
                    continue
                bw = rng.uniform(80, 300)
                bh = rng.uniform(80, 300)
                # Bimodal heights: low-rise or near the altitude ceiling.
                bzt = rng.choice([rng.uniform(30, 60), rng.uniform(100, 195)])
                _add_box(bx, by, bw, bh, bzt)
                # Occasionally attach a slightly shorter "arm" annex.
                if rng.random() < (0.45 if ultra_hard else 0.3):
                    arm_dir = rng.choice(["east", "north"])
                    if arm_dir == "east":
                        _add_box(bx + bw / 2 + 40, by, 80, bh * 0.6, bzt * 0.9)
                    else:
                        _add_box(bx, by + bh / 2 + 40, bw * 0.6, 80, bzt * 0.9)

    # ---- DENSE FOREST: many cylindrical trees ----
    if template == "dense_forest" or template == "mixed":
        forest_cx = rng.uniform(800, wx - 800)
        forest_cy = rng.uniform(800, wy - 800)
        n_trees = rng.randint(25, 60) if ultra_hard else rng.randint(15, 40)
        for _ in range(n_trees):
            # Gaussian scatter around the forest centre, clamped in-bounds.
            tx = forest_cx + rng.gauss(0, 600)
            ty = forest_cy + rng.gauss(0, 600)
            tx = max(margin, min(wx - margin, tx))
            ty = max(margin, min(wy - margin, ty))
            tree_type = rng.choice(["pine", "oak", "palm", "dead"])
            if tree_type == "pine":
                _add_cyl(tx, ty, rng.uniform(8, 20), rng.uniform(40, 100), "tree_pine")
            elif tree_type == "oak":
                _add_cyl(tx, ty, rng.uniform(15, 40), rng.uniform(25, 60), "tree_oak")
            elif tree_type == "palm":
                _add_cyl(tx, ty, rng.uniform(5, 12), rng.uniform(30, 80), "tree_palm")
            else:
                _add_cyl(tx, ty, rng.uniform(10, 25), rng.uniform(20, 50), "tree_dead")

    # ---- CORRIDOR MAZE: parallel walls with gaps ----
    if template == "corridor_maze" or template == "mixed":
        maze_ox = rng.uniform(400, wx / 2)
        maze_oy = rng.uniform(400, wy / 2)
        n_walls = rng.randint(6, 12) if ultra_hard else rng.randint(4, 8)
        wall_dir = rng.choice(["horizontal", "vertical"])
        spacing = rng.uniform(200, 500)
        for w in range(n_walls):
            wl = rng.uniform(400, 1500)
            wt = rng.uniform(40, 80)
            wzt = rng.uniform(100, 195)
            if wall_dir == "horizontal":
                wy_pos = maze_oy + w * spacing
                if wy_pos > wy - margin:
                    continue
                _add_box(maze_ox + wl / 2, wy_pos, wl, wt, wzt, "wall")
                # Zero-height "gap" marker box — filtered out below by the
                # max_corner.z > 1.0 cull, so it never becomes a collider.
                gap_x = maze_ox + rng.uniform(0.2, 0.8) * wl
                _add_box(gap_x, wy_pos, rng.uniform(80, 200), wt, 0, "gap")
            else:
                wx_pos = maze_ox + w * spacing
                if wx_pos > wx - margin:
                    continue
                # NOTE(review): vertical walls get no gap marker, unlike
                # horizontal ones — confirm whether that asymmetry is intended.
                _add_box(wx_pos, maze_oy + wl / 2, wt, wl, wzt, "wall")

    # ---- RIVER VALLEY: chain of low flat boxes + scattered trees ----
    if template == "river_valley" or (template == "mixed" and rng.random() < (0.7 if ultra_hard else 0.5)):
        river_start_x = rng.uniform(margin, wx / 3)
        river_y = rng.uniform(wy * 0.3, wy * 0.7)
        n_segs = rng.randint(10, 18) if ultra_hard else rng.randint(6, 12)
        for seg in range(n_segs):
            seg_x = river_start_x + seg * rng.uniform(200, 400)
            seg_y = river_y + rng.gauss(0, 150)
            if seg_x > wx - margin:
                break
            seg_y = max(margin, min(wy - margin, seg_y))
            # Low (3-10 m) flat boxes approximate the water surface.
            _add_box(seg_x, seg_y, rng.uniform(200, 400), rng.uniform(60, 150),
                     rng.uniform(3, 10), "river")
            # Trees along the banks of each segment.
            for _ in range(rng.randint(2, 6) if ultra_hard else rng.randint(1, 4)):
                bank_offset = rng.choice([-1, 1]) * rng.uniform(100, 300)
                _add_cyl(seg_x + rng.uniform(-100, 100),
                         seg_y + bank_offset,
                         rng.uniform(8, 20), rng.uniform(30, 80), "tree_bank")

    # ---- FORTRESS: walls surrounding a target area ----
    if template == "fortress" or (template == "mixed" and rng.random() < (0.6 if ultra_hard else 0.4)):
        if targets:
            # Ring one random delivery target with four tall walls.
            fort_target = rng.choice(targets)
            ftx, fty = fort_target.position.x, fort_target.position.y
            wall_half = rng.uniform(250, 500)
            wall_zt = rng.uniform(120, 190)
            wall_thick = rng.uniform(50, 80)
            _add_box(ftx, fty - wall_half, wall_half * 2, wall_thick, wall_zt, "fortress_wall")
            _add_box(ftx, fty + wall_half, wall_half * 2, wall_thick, wall_zt, "fortress_wall")
            _add_box(ftx - wall_half, fty, wall_thick, wall_half * 2, wall_zt, "fortress_wall")
            _add_box(ftx + wall_half, fty, wall_thick, wall_half * 2, wall_zt, "fortress_wall")

    # ---- Always scatter some light poles and random pillars ----
    n_poles = rng.randint(6, 18) if ultra_hard else rng.randint(3, 10)
    for _ in range(n_poles):
        px = rng.uniform(margin, wx - margin)
        py = rng.uniform(margin, wy - margin)
        _add_cyl(px, py, rng.uniform(2, 6), rng.uniform(30, 80), "light_pole")

    n_pillars = rng.randint(4, 12) if ultra_hard else rng.randint(2, 6)
    for _ in range(n_pillars):
        px = rng.uniform(margin, wx - margin)
        py = rng.uniform(margin, wy - margin)
        _add_cyl(px, py, rng.uniform(15, 50), rng.uniform(80, 195), "pillar")

    # Cull degenerate boxes (including the zero-height "gap" markers).
    obstacles = [o for o in obstacles if o.max_corner.z > 1.0]
    env.obstacles = obstacles
    env.cylinders = cylinders

    # --- Responder units (1 per target, up to 5 in ultra) ---
    responders = []
    max_resp = 5 if ultra_hard else 4
    for i, tgt in enumerate(targets[:max_resp]):
        r = ResponderUnit(
            id=f"R{i+1}",
            position=Vec3(tgt.position.x + rng.uniform(-50, 50),
                          tgt.position.y + rng.uniform(-50, 50), 0.0),
            linked_target_id=tgt.id,
            status="stable",
            current_need=rng.choice(["supplies", "medical", "evacuation", "water"]),
            can_update_dropzone=rng.random() < 0.5,
            active=True,
        )
        events = []

        # 70% chance the responder later escalates/changes urgency.
        if rng.random() < 0.7:
            events.append(ScheduledEvent(
                step=rng.randint(100, 600),
                event_type="urgency_update",
                payload={"new_urgency": rng.uniform(0.5, 1.0)},
            ))

        # Dropzone relocation only for responders flagged as able to do so.
        if r.can_update_dropzone and rng.random() < 0.5:
            events.append(ScheduledEvent(
                step=rng.randint(200, 800),
                event_type="dropzone_relocation",
                payload={"dx": rng.uniform(-200, 200), "dy": rng.uniform(-200, 200)},
            ))

        # 60% chance of a one-off hazard intel broadcast.
        if rng.random() < 0.6:
            intel = rng.choice([
                "blocked_north", "blocked_south", "blocked_east", "blocked_west",
                "safe_north", "safe_south", "safe_east", "safe_west",
                "fire_expanded", "fire_receded",
            ])
            events.append(ScheduledEvent(
                step=rng.randint(50, 500),
                event_type="hazard_intel",
                payload={"intel": intel, "severity": rng.uniform(0.3, 1.0)},
            ))

        r.scheduled_events = events
        responders.append(r)
    env.responders = responders
366
+
367
+
368
def build_hardcore_world_v2(env: "VarahaEnv") -> None:
    """Generate the ultra-hard world variant.

    Same generator as build_hardcore_world, with the ultra_hard flag set:
    denser obstacles, more hazards, and more delivery targets.
    """
    build_hardcore_world(env, ultra_hard=True)
371
+
372
+
373
+ # ---------------------------------------------------------------------------
374
+ # Environment
375
+ # ---------------------------------------------------------------------------
376
+
377
+ class VarahaEnv:
378
+ """Core wildfire logistics simulation.
379
+
380
+ Action format (dict)::
381
+
382
+ {
383
+ "ax": float, # desired acceleration x (m/s²)
384
+ "ay": float, # desired acceleration y
385
+ "az": float, # desired acceleration z
386
+ "deliver": bool, # attempt delivery if near a target
387
+ "recharge": bool, # attempt recharge if near base
388
+ "tool_call": str, # optional: request_intel | battery_forecast | mission_report
389
+ }
390
+
391
+ Returns ``(obs_dict, reward, done, info_dict)`` per OpenAI-gym convention.
392
+ """
393
+
394
    def __init__(self, config: Optional[VarahaConfig] = None,
                 world_fn: Optional[Any] = None) -> None:
        """Create the environment and build its initial world.

        Args:
            config: Tunable parameters; defaults to a fresh ``VarahaConfig``.
            world_fn: Optional callable ``world_fn(env)`` that populates the
                world (base, targets, hazards, obstacles, responders). When
                ``None``, the hardcoded demo world is used instead.
        """
        self.cfg = config or VarahaConfig()
        self._world_fn = world_fn

        # World entities — populated by _rebuild_world() at the end of init.
        self.base: BaseStation
        self.drone: DroneState
        self.targets: list[DeliveryTarget] = []
        self.hazards: list[HazardRegion] = []
        self.obstacles: list[ObstacleVolume] = []
        self.cylinders: list[CylindricalObstacle] = []
        self.responders: list[ResponderUnit] = []

        # Episode bookkeeping (reset() re-initialises these each episode).
        self.step_count: int = 0
        self.cumulative_reward: float = 0.0
        self.done: bool = False
        self.trace: list[TracePoint] = []

        # Distance-shaping cache and per-world hazard baselines used by reset().
        self._prev_nearest_dist: float = 0.0
        self._hazard_base_heights: list[float] = []
        self._hazard_base_severities: list[float] = []
        # Long-horizon instruction-mode state.
        self.instructions: list[MissionInstruction] = []
        self._instruction_cursor: int = 0
        self._instruction_violations: int = 0
        self._tool_history: list[str] = []
        self._last_tool_result: dict[str, Any] = {}
        self._instruction_progress_reward: float = 0.0

        self._rebuild_world()
423
+
424
+ def _rebuild_world(self):
425
+ if self._world_fn is not None:
426
+ self._world_fn(self)
427
+ else:
428
+ self._build_demo_world()
429
+ self._hazard_base_heights = [h.height for h in self.hazards]
430
+ self._hazard_base_severities = [h.severity for h in self.hazards]
431
+
432
+ # ------------------------------------------------------------------
433
+ # World setup
434
+ # ------------------------------------------------------------------
435
+
436
    def _build_demo_world(self) -> None:
        """Hardcoded 5 km demo scenario.

        Layout (top-down, +x → east, +y → north, 5 km × 5 km)::

            T3 (1000,4200)
                ·
            H2 (900,3200)   O2 [500-1500, 2600-3000]
                ·
                ·                      T2 (4100,2900) ← inside H1 fringe
                ·                  H1 (3800,2600)
                ·
                ·        O1 [2200-2800, 1000-2200]
                ·
                ·   T1 (1800,600)
                ·
            Base (250,250)

        - T2 sits inside the fringe of hazard H1 → brief hazard exposure required
        - T3 is behind obstacle O2 and near hazard H2
        - O1 blocks direct mid-map routing from T1 to T2
        - Drone can fly over obstacles if altitude > obstacle height
        - Total route ≈ 12 km, battery budget ≈ 300 units

        Note: this demo sets no cylinders or responders — those lists keep
        the empty values assigned in ``__init__``.
        """
        self.base = BaseStation(position=Vec3(250.0, 250.0, 0.0), recharge_radius=80.0)

        self.targets = [
            DeliveryTarget(
                id="T1", position=Vec3(1800.0, 600.0, 30.0),
                urgency=0.6, delivery_radius=80.0,
            ),
            DeliveryTarget(
                id="T2", position=Vec3(4100.0, 2900.0, 50.0),
                urgency=1.0, delivery_radius=120.0,
            ),
            DeliveryTarget(
                id="T3", position=Vec3(1000.0, 4200.0, 20.0),
                urgency=0.8, delivery_radius=100.0,
            ),
        ]

        self.hazards = [
            HazardRegion(
                id="H1", center=Vec3(3800.0, 2600.0, 0.0),
                radius=500.0, severity=0.9,
                height=70.0, growth_rate=0.005,
            ),
            HazardRegion(
                id="H2", center=Vec3(900.0, 3200.0, 0.0),
                radius=400.0, severity=0.7,
                height=55.0, growth_rate=0.008,
            ),
        ]

        self.obstacles = [
            # O1: 120 m tall block — overflyable only near the altitude ceiling.
            ObstacleVolume(
                id="O1",
                min_corner=Vec3(2200.0, 1000.0, 0.0),
                max_corner=Vec3(2800.0, 2200.0, 120.0),
            ),
            # O2: 90 m tall block shielding T3.
            ObstacleVolume(
                id="O2",
                min_corner=Vec3(500.0, 2600.0, 0.0),
                max_corner=Vec3(1500.0, 3000.0, 90.0),
            ),
        ]
502
+
503
+ # ------------------------------------------------------------------
504
+ # Core API
505
+ # ------------------------------------------------------------------
506
+
507
    def reset(self, seed: Optional[int] = None) -> dict[str, Any]:
        """Reset the environment and return the initial observation.

        Args:
            seed: Optional seed. NOTE: this seeds the *global* ``random``
                module RNG, which world generation also uses — so seeding
                here makes both reset jitter and (with a world_fn) the
                regenerated world reproducible, at the cost of affecting
                any other user of the global RNG.
        """
        if seed is not None:
            random.seed(seed)

        # Domain randomisation: only custom world functions regenerate the
        # world each episode; the demo world stays fixed.
        if self._world_fn is not None:
            self._rebuild_world()

        # Fresh drone at the base, fully charged, payload on board.
        self.drone = DroneState(
            position=Vec3(self.base.position.x, self.base.position.y, 0.0),
            velocity=Vec3(0.0, 0.0, 0.0),
            battery=self.cfg.battery_capacity,
            carrying_payload=True,
            alive=True,
        )

        for t in self.targets:
            t.delivered = False

        # Perturb each hazard around its generation-time baseline, then
        # let it re-derive its internal dynamic state.
        for i, h in enumerate(self.hazards):
            h.height = self._hazard_base_heights[i] * random.uniform(0.85, 1.15)
            h.severity = max(0.3, min(1.0, self._hazard_base_severities[i] + random.uniform(-0.1, 0.1)))
            h.reset()

        # Responders back to their quiescent state; re-arm scheduled events.
        for r in self.responders:
            r.active = True
            r.status = "stable"
            r.latest_intel = "none"
            r.intel_severity = 0.0
            r.message = ""
            for ev in r.scheduled_events:
                ev.fired = False

        # Remember original dropzone positions (events may relocate them).
        self._target_base_positions = {
            t.id: Vec3(t.position.x, t.position.y, t.position.z)
            for t in self.targets
        }
        self._build_instruction_program()
        self._instruction_progress_reward = 0.0
        self._last_tool_result = {}
        self._tool_history = []

        self.step_count = 0
        self.cumulative_reward = 0.0
        self.done = False
        self.trace = []
        # Seed the distance-shaping baseline for the first step's reward.
        self._prev_nearest_dist = self._nearest_target_dist()

        obs = self.get_observation()

        # Step-0 trace point so replays include the starting state.
        self.trace.append(TracePoint(
            step=0,
            position=Vec3(self.drone.position.x, self.drone.position.y, self.drone.position.z),
            velocity=Vec3(0.0, 0.0, 0.0),
            battery=self.drone.battery,
            reward=0.0,
            cumulative_reward=0.0,
            events=["reset"],
            observation=obs,
        ))

        return obs
569
+
570
    def step(self, action: dict[str, Any]) -> tuple[dict, float, bool, dict]:
        """Advance the simulation by one timestep.

        Args:
            action: Dict with optional keys ``ax``/``ay``/``az`` (desired
                acceleration, m/s²), ``deliver`` (bool), ``recharge`` (bool)
                and ``tool_call`` (str) — see the class docstring.

        Returns ``(observation, reward, done, info)``.

        Calling step() on a finished episode is a no-op that returns the
        current observation with zero reward.
        """
        if self.done:
            return self.get_observation(), 0.0, True, StepInfo().to_dict()

        self.step_count += 1

        # --- parse & clamp acceleration ---
        # Missing keys default to 0; magnitude is clamped, not per-axis.
        accel = Vec3(
            float(action.get("ax", 0.0)),
            float(action.get("ay", 0.0)),
            float(action.get("az", 0.0)),
        ).clamp_magnitude(self.cfg.max_acceleration)

        # --- kinematics (Euler integration) ---
        self.drone.velocity = (
            self.drone.velocity + accel.scale(self.cfg.dt)
        ).clamp_magnitude(self.cfg.max_speed)

        old_pos = Vec3(self.drone.position.x, self.drone.position.y, self.drone.position.z)
        self.drone.position = self.drone.position + self.drone.velocity.scale(self.cfg.dt)

        # clamp to world bounds
        self.drone.position.x = max(0.0, min(self.cfg.world_x, self.drone.position.x))
        self.drone.position.y = max(0.0, min(self.cfg.world_y, self.drone.position.y))
        self.drone.position.z = max(0.0, min(self.cfg.world_z, self.drone.position.z))

        # Distance is measured on the clamped path, so hitting a wall
        # shortens the effective travel (and battery drain) this step.
        dist_traveled = old_pos.distance_to(self.drone.position)
        elevation_change = abs(self.drone.position.z - old_pos.z)

        # --- battery ---
        drain = self._compute_battery_drain(dist_traveled, elevation_change)
        self.drone.battery -= drain

        # --- advance dynamic hazards ---
        # Hazards grow BEFORE collision/hazard checks for this step.
        for h in self.hazards:
            h.tick()

        # --- advance responder events ---
        self._tick_responders()

        # --- world interactions ---
        collision = self._check_collisions()
        in_hazard, hazard_sev = self._check_hazards()

        # Optional agent tool invocation (request_intel / battery_forecast /
        # mission_report); blank or missing strings are ignored.
        tool_call = ""
        tool_result: dict[str, Any] = {}
        raw_tool_call = action.get("tool_call")
        if raw_tool_call is not None and str(raw_tool_call).strip():
            tool_call, tool_result = self._execute_tool_call(str(raw_tool_call).strip())

        # Cursor snapshot lets us report how many instructions completed
        # within this single step.
        prev_instruction_cursor = self._instruction_cursor
        delivered_ids: list[str] = []
        if action.get("deliver", False):
            delivered_ids = self._deliver_targets()

        # Base proximity uses horizontal (XY) distance only — altitude is
        # irrelevant for recharging.
        reached_base = (
            ((self.drone.position.x - self.base.position.x) ** 2
             + (self.drone.position.y - self.base.position.y) ** 2) ** 0.5
            <= self.base.recharge_radius
        )
        if action.get("recharge", False) and reached_base:
            self.drone.battery = min(
                self.cfg.battery_capacity,
                self.drone.battery + self.cfg.recharge_rate,
            )

        # Instruction progress consumes this step's deliveries, tool call
        # and base status in one pass.
        self._update_instruction_progress(
            delivered_ids=delivered_ids,
            reached_base=reached_base,
            tool_call=tool_call,
        )
        completed_now = max(0, self._instruction_cursor - prev_instruction_cursor)

        if self._all_delivered():
            self.drone.carrying_payload = False

        # --- reward ---
        info = StepInfo(
            collision=collision,
            delivered_target_ids=delivered_ids,
            in_hazard=in_hazard,
            hazard_severity=hazard_sev,
            reached_base=reached_base,
            distance_traveled=dist_traveled,
            tool_call=tool_call,
            tool_result=tool_result,
            instruction_completed=self._instruction_cursor,
            instruction_total=len(self.instructions),
            instruction_violations=self._instruction_violations,
        )
        reward, breakdown = self._compute_reward(info)
        info.reward_breakdown = breakdown
        self.cumulative_reward += reward

        # --- termination ---
        # Order matters: collision dominates battery death, which dominates
        # success, which dominates the step-limit timeout.
        if collision:
            self.drone.alive = False
            self.done = True
        elif self.drone.battery <= 0.0:
            self.drone.battery = 0.0
            self.drone.alive = False
            self.done = True
        elif self._is_success():
            self.done = True
        elif self.step_count >= self.cfg.max_episode_steps:
            self.done = True

        # record trace
        events: list[str] = []
        for tid in delivered_ids:
            events.append(f"delivered_{tid}")
        if collision:
            events.append("collision")
        if in_hazard:
            events.append(f"hazard_{hazard_sev:.2f}")
        if self.drone.battery <= 0.0 and not collision:
            events.append("battery_dead")
        if self._is_success():
            events.append("success")
        if tool_call:
            events.append(f"tool_{tool_call}")
        if completed_now > 0:
            events.append(f"instruction+{completed_now}")

        obs = self.get_observation()

        self.trace.append(TracePoint(
            step=self.step_count,
            position=Vec3(self.drone.position.x, self.drone.position.y, self.drone.position.z),
            velocity=Vec3(self.drone.velocity.x, self.drone.velocity.y, self.drone.velocity.z),
            battery=self.drone.battery,
            reward=reward,
            cumulative_reward=self.cumulative_reward,
            events=events,
            observation=obs,
        ))

        return obs, reward, self.done, info.to_dict()
712
+
713
+ # ------------------------------------------------------------------
714
+ # Observation / render
715
+ # ------------------------------------------------------------------
716
+
717
    def get_observation(self) -> dict[str, Any]:
        """Compact, RL-friendly observation dict.

        All entity positions are reported RELATIVE to the drone so the
        policy sees translation-invariant features. Obstacles (boxes and
        cylinders merged) are sorted nearest-first.
        """
        dp = self.drone.position

        targets_obs = []
        for t in self.targets:
            rel = t.position - dp
            targets_obs.append({
                "id": t.id,
                "relative_position": rel.to_dict(),
                "urgency": t.urgency,
                "delivered": t.delivered,
            })

        hazards_obs = []
        for h in self.hazards:
            rel = h.center - dp
            hazards_obs.append({
                "id": h.id,
                "relative_position": rel.to_dict(),
                # NOTE(review): reaches into HazardRegion's private
                # _current_height (the dynamic, post-tick height) — a public
                # accessor on HazardRegion would be cleaner.
                "current_height": h._current_height,
                "severity": h.severity,
            })

        # Boxes and cylinders share one list, discriminated by "type";
        # "size_x"/"size_y" are footprint extents (diameter for cylinders).
        obstacles_obs = []
        for obs in self.obstacles:
            c = obs.center
            hs = obs.half_size
            rel = c - dp
            dist = dp.horizontal_distance_to(c)
            obstacles_obs.append({
                "type": "box",
                "relative_position": rel.to_dict(),
                "height": obs.height,
                "size_x": hs.x * 2,
                "size_y": hs.y * 2,
                "distance": dist,
                "kind": obs.kind,
            })
        for cyl in self.cylinders:
            rel = cyl.center - dp
            dist = dp.horizontal_distance_to(cyl.center)
            obstacles_obs.append({
                "type": "cylinder",
                "relative_position": rel.to_dict(),
                "height": cyl.height,
                "size_x": cyl.radius * 2,
                "size_y": cyl.radius * 2,
                "distance": dist,
                "kind": cyl.kind,
            })
        obstacles_obs.sort(key=lambda o: o["distance"])

        # Inactive responders are omitted entirely.
        responders_obs = []
        for r in self.responders:
            if not r.active:
                continue
            rel = r.position - dp
            intel_dir = r.intel_direction()
            responders_obs.append({
                "id": r.id,
                "relative_position": rel.to_dict(),
                "linked_target_id": r.linked_target_id,
                "status": r.status,
                "status_code": r.status_code(),
                "latest_intel": r.latest_intel,
                "intel_direction": {"x": intel_dir[0], "y": intel_dir[1]},
                "intel_severity": r.intel_severity,
            })

        mission_obs = self._instruction_snapshot()
        return {
            "drone_position": dp.to_dict(),
            "drone_velocity": self.drone.velocity.to_dict(),
            "battery": round(self.drone.battery, 4),
            "carrying_payload": self.drone.carrying_payload,
            "alive": self.drone.alive,
            "targets": targets_obs,
            "hazards": hazards_obs,
            "obstacles": obstacles_obs,
            "responders": responders_obs,
            "mission": mission_obs,
            "last_tool_result": self._last_tool_result,
            "step": self.step_count,
            "max_steps": self.cfg.max_episode_steps,
        }
803
+
804
+ def render_state(self) -> dict[str, Any]:
805
+ """Rich state dict for future Cesium / frontend rendering."""
806
+ return {
807
+ "base_station": self.base.to_dict(),
808
+ "drone": self.drone.to_dict(),
809
+ "targets": [t.to_dict() for t in self.targets],
810
+ "hazards": [h.to_dict() for h in self.hazards],
811
+ "obstacles": [o.to_dict() for o in self.obstacles],
812
+ "cylinders": [c.to_dict() for c in self.cylinders],
813
+ "responders": [r.to_dict() for r in self.responders],
814
+ "mission": self._instruction_snapshot(include_full=True),
815
+ "tool_history": list(self._tool_history),
816
+ "step": self.step_count,
817
+ "max_steps": self.cfg.max_episode_steps,
818
+ "cumulative_reward": round(self.cumulative_reward, 4),
819
+ "done": self.done,
820
+ }
821
+
822
    def get_trace(self) -> dict[str, Any]:
        """Full episode trace for replay / visualisation.

        Returns a dict with three sections: the static ``world`` layout,
        the per-step ``trace`` (one TracePoint per step, including step 0),
        and an episode ``summary``.
        """
        return {
            "world": {
                "bounds": {"x": self.cfg.world_x, "y": self.cfg.world_y, "z": self.cfg.world_z},
                "base_station": self.base.to_dict(),
                "targets": [t.to_dict() for t in self.targets],
                "hazards": [h.to_dict() for h in self.hazards],
                "obstacles": [o.to_dict() for o in self.obstacles],
                "cylinders": [c.to_dict() for c in self.cylinders],
                "responders": [r.to_dict() for r in self.responders],
                "mission": self._instruction_snapshot(include_full=True),
            },
            "trace": [tp.to_dict() for tp in self.trace],
            "summary": {
                "total_steps": self.step_count,
                "cumulative_reward": round(self.cumulative_reward, 4),
                "delivered": [t.id for t in self.targets if t.delivered],
                "alive": self.drone.alive,
                "final_battery": round(self.drone.battery, 4),
                "success": self._is_success(),
                "instruction_completed": self._instruction_cursor,
                "instruction_total": len(self.instructions),
                "instruction_violations": self._instruction_violations,
                "tool_calls": list(self._tool_history),
            },
        }
849
+
850
+ # ------------------------------------------------------------------
851
+ # Long-horizon instruction mode
852
+ # ------------------------------------------------------------------
853
+
854
    def _build_instruction_program(self) -> None:
        """Build the long-horizon instruction list for instruction mode.

        The program interleaves ``deliver_target`` and ``tool_call``
        instructions over the targets (ordered by urgency, tie-broken by
        id), repeating in cycles until the desired length is reached, and
        always ends with a single ``return_base`` instruction.
        No-op when instruction mode is off or there are no targets.
        """
        self.instructions = []
        self._instruction_cursor = 0
        self._instruction_violations = 0

        if not self.cfg.instruction_mode or not self.targets:
            return

        # Most urgent first; id breaks ties deterministically.
        ordered_targets = sorted(self.targets, key=lambda t: (-t.urgency, t.id))
        target_count = len(ordered_targets)
        # Desired TOTAL length incl. the final return_base; floor guarantees
        # at least one deliver+tool pair per target plus the return.
        desired_len = self.cfg.instruction_count if self.cfg.instruction_count > 0 else (target_count * 3 + 1)
        desired_len = max(desired_len, target_count * 2 + 1)

        instructions: list[MissionInstruction] = []
        inst_idx = 1
        cycle = 0
        # Fill up to desired_len - 1, leaving room for the return_base tail.
        while len(instructions) < max(desired_len - 1, 1):
            for tgt in ordered_targets:
                if len(instructions) >= max(desired_len - 1, 1):
                    break
                instructions.append(
                    MissionInstruction(
                        id=f"I{inst_idx}",
                        kind="deliver_target",
                        description=f"Cycle {cycle + 1}: deliver to {tgt.id} in order.",
                        target_id=tgt.id,
                    )
                )
                inst_idx += 1
                if len(instructions) >= max(desired_len - 1, 1):
                    break
                # Alternate the required tool each cycle.
                tool = "request_intel" if (cycle % 2 == 0) else "battery_forecast"
                instructions.append(
                    MissionInstruction(
                        id=f"I{inst_idx}",
                        kind="tool_call",
                        description=f"Call {tool} after servicing {tgt.id}.",
                        target_id=tgt.id,
                        tool_name=tool,
                    )
                )
                inst_idx += 1
            cycle += 1

        # Terminal instruction: come home once everything is delivered.
        instructions.append(
            MissionInstruction(
                id=f"I{inst_idx}",
                kind="return_base",
                description="Return to base only after all deliveries are completed.",
            )
        )
        self.instructions = instructions
906
+
907
+ def _current_instruction(self) -> Optional[MissionInstruction]:
908
+ if self._instruction_cursor >= len(self.instructions):
909
+ return None
910
+ return self.instructions[self._instruction_cursor]
911
+
912
+ def _instruction_snapshot(self, include_full: bool = False) -> dict[str, Any]:
913
+ total = len(self.instructions)
914
+ completed = min(self._instruction_cursor, total)
915
+ next_instruction = self._current_instruction()
916
+ out: dict[str, Any] = {
917
+ "enabled": self.cfg.instruction_mode,
918
+ "total": total,
919
+ "completed": completed,
920
+ "remaining": max(total - completed, 0),
921
+ "progress": (completed / total) if total > 0 else 1.0,
922
+ "violations": self._instruction_violations,
923
+ "next_instruction": next_instruction.to_dict() if next_instruction else None,
924
+ }
925
+ if include_full:
926
+ out["instructions"] = [inst.to_dict() for inst in self.instructions]
927
+ return out
928
+
929
+ def _complete_current_instruction(self) -> None:
930
+ inst = self._current_instruction()
931
+ if inst is None:
932
+ return
933
+ inst.completed = True
934
+ self._instruction_cursor += 1
935
+ self._instruction_progress_reward += self.cfg.instruction_completion_reward
936
+
937
+ def _record_instruction_violation(self) -> None:
938
+ self._instruction_violations += 1
939
+ inst = self._current_instruction()
940
+ if inst is not None:
941
+ inst.violated = True
942
+
943
+ def _tool_matches_instruction(self, tool_call: str, inst: MissionInstruction) -> bool:
944
+ base, _, arg = tool_call.partition(":")
945
+ if base != inst.tool_name:
946
+ return False
947
+ if inst.target_id and arg and arg != inst.target_id:
948
+ return False
949
+ return True
950
+
951
def _update_instruction_progress(
    self,
    delivered_ids: list[str],
    reached_base: bool,
    tool_call: str,
) -> None:
    """Advance the instruction cursor based on this step's outcomes.

    Args:
        delivered_ids: ids of targets delivered during this step.
        reached_base: whether the drone ended the step back at base.
        tool_call: normalized tool-call string ("" when no tool was used).

    No-op when instruction mode is off or no instructions exist.
    """
    if not self.cfg.instruction_mode or not self.instructions:
        return

    # Out-of-order deliveries while a deliver_target instruction is active
    # count as violations — one per wrongly-delivered target id.
    inst = self._current_instruction()
    if inst and inst.kind == "deliver_target":
        for tid in delivered_ids:
            if tid != inst.target_id:
                self._record_instruction_violation()

    # Consume as many consecutive instructions as this single step satisfies
    # (e.g. a delivery followed immediately by the matching tool call).
    while True:
        inst = self._current_instruction()
        if inst is None:
            break

        if inst.kind == "deliver_target":
            if inst.target_id in delivered_ids:
                self._complete_current_instruction()
                continue  # the next instruction may also be satisfied now
            break

        if inst.kind == "tool_call":
            if not tool_call:
                break
            if self._tool_matches_instruction(tool_call, inst):
                self._complete_current_instruction()
            else:
                self._record_instruction_violation()
            break  # at most one tool call per step — stop either way

        if inst.kind == "return_base":
            if reached_base and self._all_delivered():
                self._complete_current_instruction()
            break

        break  # unknown instruction kind: leave the cursor untouched
992
+
993
def _execute_tool_call(self, tool_call: str) -> tuple[str, dict[str, Any]]:
    """Execute one agent tool call and return ``(normalized_call, result)``.

    Supported tools (gated by ``cfg.available_tools``):
    ``request_intel[:target_id]``, ``battery_forecast``, ``mission_report``.
    Every non-empty call — even an unsupported one — is appended to
    ``_tool_history`` and cached in ``_last_tool_result``. An empty call is
    a no-op returning ``("", {})``.
    """
    # Normalise: trim and lower-case the whole call, including any argument.
    raw = tool_call.strip().lower()
    if not raw:
        return "", {}

    tool_name, _, arg = raw.partition(":")
    normalized_call = f"{tool_name}:{arg}" if arg else tool_name

    if tool_name not in self.cfg.available_tools:
        result = {"ok": False, "error": f"unsupported_tool:{tool_name}"}
        self._tool_history.append(normalized_call)
        self._last_tool_result = result
        return normalized_call, result

    if tool_name == "request_intel":
        # Prefer the active responder linked to the requested target;
        # otherwise fall back to any active responder.
        responder = None
        if arg:
            responder = next(
                (r for r in self.responders if r.active and r.linked_target_id.lower() == arg.lower()),
                None,
            )
        if responder is None:
            responder = next((r for r in self.responders if r.active), None)
        if responder is None:
            result = {"ok": True, "intel": "none", "message": "no_active_responders"}
        else:
            result = {
                "ok": True,
                "intel": responder.latest_intel,
                "intel_severity": round(responder.intel_severity, 3),
                "responder_id": responder.id,
                "target_id": responder.linked_target_id,
                "message": responder.message,
            }
    elif tool_name == "battery_forecast":
        # Range estimate from remaining battery at the per-metre burn rate;
        # clamp the rate away from zero to avoid division by zero.
        burn = max(self.cfg.drain_per_meter, 1e-6)
        est_range = self.drone.battery / burn
        result = {
            "ok": True,
            "battery": round(self.drone.battery, 3),
            "estimated_range_m": round(est_range, 1),
        }
    else:  # mission_report
        result = {
            "ok": True,
            "delivered": [t.id for t in self.targets if t.delivered],
            "remaining": [t.id for t in self.targets if not t.delivered],
            "instruction_progress": round(self._instruction_snapshot()["progress"], 3),
            "violations": self._instruction_violations,
        }

    self._tool_history.append(normalized_call)
    self._last_tool_result = result
    return normalized_call, result
1047
+
1048
+ # ------------------------------------------------------------------
1049
+ # Coordinate conversion
1050
+ # ------------------------------------------------------------------
1051
+
1052
def local_to_latlon(self, vec: Vec3) -> tuple[float, float, float]:
    """Convert local (x, y, z) metres to (lat, lon, alt).

    Uses a flat-earth approximation centred on ``cfg.origin_lat/lon``.
    Accurate enough for small areas (~tens of km) and Cesium plotting.
    """
    lat_scale = 111_320.0  # metres per degree of latitude
    lon_scale = lat_scale * math.cos(math.radians(self.cfg.origin_lat))
    return (
        round(self.cfg.origin_lat + vec.y / lat_scale, 7),
        round(self.cfg.origin_lon + vec.x / lon_scale, 7),
        round(vec.z, 2),
    )
1065
+
1066
+ # ------------------------------------------------------------------
1067
+ # Internal helpers
1068
+ # ------------------------------------------------------------------
1069
+
1070
+ def _compute_battery_drain(self, dist: float, elevation_change: float) -> float:
1071
+ return (
1072
+ dist * self.cfg.drain_per_meter
1073
+ + elevation_change * self.cfg.drain_elevation_factor
1074
+ + self.cfg.drain_idle_per_step
1075
+ )
1076
+
1077
+ def _check_collisions(self) -> bool:
1078
+ for obs in self.obstacles:
1079
+ if obs.contains(self.drone.position):
1080
+ return True
1081
+ for cyl in self.cylinders:
1082
+ if cyl.contains(self.drone.position):
1083
+ return True
1084
+ return False
1085
+
1086
+ def _check_hazards(self) -> tuple[bool, float]:
1087
+ max_sev = 0.0
1088
+ in_hazard = False
1089
+ for h in self.hazards:
1090
+ df = h.danger_factor(self.drone.position)
1091
+ if df > 0.0:
1092
+ in_hazard = True
1093
+ max_sev = max(max_sev, df)
1094
+ return in_hazard, max_sev
1095
+
1096
+ def _deliver_targets(self) -> list[str]:
1097
+ """Cylindrical delivery check — drone must be within horizontal radius
1098
+ and above the target (within a generous altitude window for drops)."""
1099
+ delivered: list[str] = []
1100
+ for t in self.targets:
1101
+ if t.delivered:
1102
+ continue
1103
+ dx = self.drone.position.x - t.position.x
1104
+ dy = self.drone.position.y - t.position.y
1105
+ horiz_dist = (dx * dx + dy * dy) ** 0.5
1106
+ alt_above = self.drone.position.z - t.position.z
1107
+ if horiz_dist <= t.delivery_radius and -10.0 <= alt_above <= t.delivery_radius * 2:
1108
+ t.delivered = True
1109
+ delivered.append(t.id)
1110
+ return delivered
1111
+
1112
+ def _all_delivered(self) -> bool:
1113
+ return all(t.delivered for t in self.targets)
1114
+
1115
+ def _is_success(self) -> bool:
1116
+ hdist = ((self.drone.position.x - self.base.position.x) ** 2
1117
+ + (self.drone.position.y - self.base.position.y) ** 2) ** 0.5
1118
+ return self._all_delivered() and hdist <= self.base.recharge_radius
1119
+
1120
+ def _nearest_target_dist(self) -> float:
1121
+ """Horizontal distance to closest undelivered target, or to base if all done."""
1122
+ dists = [
1123
+ ((self.drone.position.x - t.position.x) ** 2
1124
+ + (self.drone.position.y - t.position.y) ** 2) ** 0.5
1125
+ for t in self.targets
1126
+ if not t.delivered
1127
+ ]
1128
+ if not dists:
1129
+ return ((self.drone.position.x - self.base.position.x) ** 2
1130
+ + (self.drone.position.y - self.base.position.y) ** 2) ** 0.5
1131
+ return min(dists)
1132
+
1133
def _tick_responders(self) -> None:
    """Process scheduled responder events for the current step.

    Each event fires exactly once, on the step matching ``ev.step``.
    Handled kinds: ``urgency_update``, ``dropzone_relocation``,
    ``hazard_intel``; anything else is marked fired and ignored.
    """
    for r in self.responders:
        if not r.active:
            continue
        for ev in r.scheduled_events:
            # Skip already-fired events and those scheduled for other steps.
            if ev.fired or ev.step != self.step_count:
                continue
            ev.fired = True
            etype = ev.event_type

            if etype == "urgency_update":
                # Adjust the linked target's urgency (clamped to [0.1, 1.0])
                # and mirror the new level in the responder's status string.
                tgt = self._find_target(r.linked_target_id)
                if tgt and not tgt.delivered:
                    tgt.urgency = max(0.1, min(1.0, ev.payload.get("new_urgency", tgt.urgency)))
                    r.status = "critical" if tgt.urgency >= 0.9 else "urgent" if tgt.urgency >= 0.6 else "stable"
                    r.message = f"urgency->{tgt.urgency:.1f}"

            elif etype == "dropzone_relocation":
                # Shift the dropzone by (dx, dy), clamped 50 m inside the
                # world bounds, and move the responder marker with it.
                tgt = self._find_target(r.linked_target_id)
                if tgt and not tgt.delivered and r.can_update_dropzone:
                    dx = ev.payload.get("dx", 0.0)
                    dy = ev.payload.get("dy", 0.0)
                    tgt.position.x = max(50, min(self.cfg.world_x - 50, tgt.position.x + dx))
                    tgt.position.y = max(50, min(self.cfg.world_y - 50, tgt.position.y + dy))
                    r.position = Vec3(tgt.position.x, tgt.position.y, 0.0)
                    r.message = f"dropzone moved ({dx:+.0f},{dy:+.0f})"
                    # Re-baseline distance shaping so the relocation itself
                    # does not produce a spurious shaping reward/penalty.
                    self._prev_nearest_dist = self._nearest_target_dist()

            elif etype == "hazard_intel":
                # Cache intel on the responder for later request_intel calls.
                r.latest_intel = ev.payload.get("intel", "none")
                r.intel_severity = ev.payload.get("severity", 0.5)
                r.message = f"intel: {r.latest_intel}"
1167
+ def _find_target(self, tid: str) -> Optional[DeliveryTarget]:
1168
+ for t in self.targets:
1169
+ if t.id == tid:
1170
+ return t
1171
+ return None
1172
+
1173
+ def _obstacle_proximity_penalty(self) -> float:
1174
+ """Graduated penalty for flying close to any obstacle surface."""
1175
+ min_dist = float("inf")
1176
+ pos = self.drone.position
1177
+ for obs in self.obstacles:
1178
+ d = obs.nearest_surface_dist(pos)
1179
+ if d < min_dist:
1180
+ min_dist = d
1181
+ for cyl in self.cylinders:
1182
+ d = cyl.nearest_surface_dist(pos)
1183
+ if d < min_dist:
1184
+ min_dist = d
1185
+ if min_dist >= self.cfg.obstacle_proximity_radius:
1186
+ return 0.0
1187
+ factor = 1.0 - min_dist / self.cfg.obstacle_proximity_radius
1188
+ return self.cfg.obstacle_proximity_penalty * factor * factor
1189
+
1190
def _compute_reward(self, info: StepInfo) -> tuple[float, dict[str, float]]:
    """Dense shaped reward for one step.

    Returns ``(total, breakdown)`` where *breakdown* maps component name
    to its signed contribution (plus a ``"total"`` entry). Delegates to
    the sparse variant when instruction mode runs with sparse rewards.

    Side effects: consumes the banked ``_instruction_progress_reward`` and
    updates ``_prev_nearest_dist`` for the next step's distance shaping.
    """
    if self.cfg.instruction_mode and self.cfg.sparse_reward_mode:
        return self._compute_sparse_instruction_reward(info)

    bd: dict[str, float] = {}
    total = 0.0

    # per-step cost of time
    bd["step_penalty"] = -self.cfg.step_penalty
    total += bd["step_penalty"]

    # battery usage cost (proportional to energy spent)
    bd["battery_cost"] = -(
        info.distance_traveled * self.cfg.drain_per_meter * self.cfg.battery_cost_factor
    )
    total += bd["battery_cost"]

    # instruction completions banked since the last reward computation;
    # the accumulator is reset so each completion is paid exactly once
    if self._instruction_progress_reward > 0.0:
        bd["instruction_progress"] = self._instruction_progress_reward
        total += bd["instruction_progress"]
        self._instruction_progress_reward = 0.0

    # delivery rewards (scaled by urgency) + progress bonus
    for tid in info.delivered_target_ids:
        tgt = next(t for t in self.targets if t.id == tid)
        r = self.cfg.delivery_reward * (1.0 + tgt.urgency)
        bd[f"delivery_{tid}"] = r
        total += r

    if info.delivered_target_ids:
        n_remaining = sum(1 for t in self.targets if not t.delivered)
        progress_bonus = 50.0 * (1.0 - n_remaining / len(self.targets))
        bd["progress_bonus"] = progress_bonus
        total += progress_bonus

    # collision
    if info.collision:
        bd["collision"] = -self.cfg.collision_penalty
        total += bd["collision"]

    # hazard exposure (severity-weighted)
    if info.in_hazard:
        bd["hazard"] = -self.cfg.hazard_penalty * info.hazard_severity
        total += bd["hazard"]

    # safe return bonus
    if info.reached_base and self._all_delivered():
        bd["return_bonus"] = self.cfg.return_bonus
        total += bd["return_bonus"]

    # distance shaping — nudge toward nearest undelivered target (or base)
    # Skip shaping on delivery steps to avoid a huge negative spike
    # when the nearest-target reference jumps to a farther target.
    # Double the factor when heading home after all deliveries.
    curr_dist = self._nearest_target_dist()
    if info.delivered_target_ids:
        bd["distance_shaping"] = 0.0
        self._prev_nearest_dist = curr_dist
    else:
        factor = self.cfg.distance_shaping_factor
        if self._all_delivered():
            factor *= 2.0
        shaping = (self._prev_nearest_dist - curr_dist) * factor
        bd["distance_shaping"] = shaping
        total += shaping
        self._prev_nearest_dist = curr_dist

    # obstacle proximity (graduated — discourages flying close)
    prox = self._obstacle_proximity_penalty()
    if prox > 0:
        bd["obstacle_proximity"] = -prox
        total -= prox

    # failure (battery depletion; collision already penalised above)
    if self.drone.battery <= 0.0 and not info.collision:
        bd["failure"] = -self.cfg.failure_penalty
        total += bd["failure"]

    bd["total"] = total
    return total, bd
1270
+
1271
def _compute_sparse_instruction_reward(self, info: StepInfo) -> tuple[float, dict[str, float]]:
    """Sparse reward used when instruction mode + sparse mode are both on.

    Per-step signal is intentionally weak (reduced step penalty, reduced
    hazard penalty, banked instruction completions); the bulk of the
    reward/penalty lands on the terminal step. Returns ``(total,
    breakdown)`` with a ``"total"`` entry, like ``_compute_reward``.
    """
    bd: dict[str, float] = {}
    total = 0.0

    # Keep shaping intentionally small in sparse mode.
    bd["step_penalty"] = -(self.cfg.step_penalty * 0.25)
    total += bd["step_penalty"]

    # instruction completions banked since last computation; reset so each
    # completion is paid exactly once
    if self._instruction_progress_reward > 0.0:
        bd["instruction_progress"] = self._instruction_progress_reward
        total += bd["instruction_progress"]
        self._instruction_progress_reward = 0.0

    # hazard exposure at 20% of the dense-mode weight
    if info.in_hazard:
        bd["hazard"] = -(self.cfg.hazard_penalty * 0.2 * info.hazard_severity)
        total += bd["hazard"]

    # Terminal step: collision, dead battery, success, or step limit.
    terminal = (
        info.collision
        or self.drone.battery <= 0.0
        or self._is_success()
        or self.step_count >= self.cfg.max_episode_steps
    )
    if terminal:
        # Pay out proportional progress through the instruction list,
        # plus a success bonus or failure penalty.
        total_instr = len(self.instructions)
        progress = (self._instruction_cursor / total_instr) if total_instr > 0 else 1.0
        bd["terminal_progress"] = self.cfg.instruction_terminal_progress_bonus * progress
        total += bd["terminal_progress"]

        if self._is_success():
            bd["terminal_success"] = self.cfg.instruction_terminal_success_bonus
            total += bd["terminal_success"]
        else:
            bd["terminal_failure"] = -self.cfg.failure_penalty
            total += bd["terminal_failure"]

        # Penalise each instruction left unfinished at episode end.
        remaining = max(total_instr - self._instruction_cursor, 0)
        if remaining > 0:
            bd["unfinished_penalty"] = -remaining * self.cfg.instruction_unfinished_penalty
            total += bd["unfinished_penalty"]

        # Penalise accumulated instruction violations.
        if self._instruction_violations > 0:
            bd["instruction_violations"] = (
                -self._instruction_violations * self.cfg.instruction_violation_penalty
            )
            total += bd["instruction_violations"]

        if info.collision:
            bd["collision"] = -self.cfg.collision_penalty
            total += bd["collision"]

    bd["total"] = total
    return total, bd