"""Varaha — wildfire logistics simulation environment. A drone must deliver supplies to responder zones near wildfire hazards in California-like terrain. The environment uses lightweight 3D kinematics with local metre-based coordinates and an optional lat/lon conversion helper for later Cesium visualisation. """ import math import random from dataclasses import dataclass from typing import Any, Optional from sim_types import ( Vec3, DroneState, BaseStation, DeliveryTarget, HazardRegion, ObstacleVolume, CylindricalObstacle, ResponderUnit, ScheduledEvent, RESPONDER_STATUSES, INTEL_TYPES, StepInfo, TracePoint, MissionInstruction, ) # --------------------------------------------------------------------------- # Configuration # --------------------------------------------------------------------------- @dataclass class VarahaConfig: """All tunable environment parameters live here.""" # World bounds (metres) — 5 km × 5 km operational area world_x: float = 5000.0 world_y: float = 5000.0 world_z: float = 200.0 # Drone physics battery_capacity: float = 300.0 max_speed: float = 25.0 # m/s max_acceleration: float = 8.0 # m/s² dt: float = 0.5 # seconds per step # Episode max_episode_steps: int = 2000 # Battery drain coefficients (tuned for 5 km scale) drain_per_meter: float = 0.008 drain_elevation_factor: float = 0.02 drain_idle_per_step: float = 0.005 recharge_rate: float = 5.0 # battery units restored per recharge step # Reward knobs delivery_reward: float = 200.0 return_bonus: float = 100.0 step_penalty: float = 0.05 battery_cost_factor: float = 0.3 collision_penalty: float = 500.0 hazard_penalty: float = 5.0 failure_penalty: float = 200.0 distance_shaping_factor: float = 0.05 obstacle_proximity_penalty: float = 1.5 obstacle_proximity_radius: float = 80.0 # Long-horizon instruction mode (LLM-oriented) instruction_mode: bool = False instruction_count: int = 60 sparse_reward_mode: bool = False instruction_completion_reward: float = 0.5 instruction_terminal_success_bonus: 
float = 2200.0 instruction_terminal_progress_bonus: float = 800.0 instruction_violation_penalty: float = 120.0 instruction_unfinished_penalty: float = 10.0 available_tools: tuple[str, ...] = ( "request_intel", "battery_forecast", "mission_report", ) # California origin anchor (near Sacramento — wildfire-relevant) origin_lat: float = 38.55 origin_lon: float = -121.47 # --------------------------------------------------------------------------- # Random world generator for domain randomization # --------------------------------------------------------------------------- def build_random_world(env: "VarahaEnv") -> None: """Legacy easy world gen — kept for backward compatibility.""" build_hardcore_world(env) def _hdist(a: Vec3, b: Vec3) -> float: return ((a.x - b.x) ** 2 + (a.y - b.y) ** 2) ** 0.5 def build_hardcore_world(env: "VarahaEnv", ultra_hard: bool = False) -> None: """Generate an extremely challenging randomized world for serious RL training. Features template-based obstacle placement (urban grid, dense forest, corridor maze, river valley, fortress, mixed), cylindrical obstacles, responder units with dynamic events, and adversarial target placement. When ultra_hard=True: denser obstacles, more hazards, more targets, longer episodes. 
""" cfg = env.cfg rng = random wx, wy, wz = cfg.world_x, cfg.world_y, cfg.world_z margin = 200.0 def _rpos(z_lo=10.0, z_hi=60.0): return Vec3(rng.uniform(margin, wx - margin), rng.uniform(margin, wy - margin), rng.uniform(z_lo, z_hi)) def _rpos_ground(): return Vec3(rng.uniform(margin, wx - margin), rng.uniform(margin, wy - margin), 0.0) # --- Base station --- base_pos = Vec3(rng.uniform(100, wx - 100), rng.uniform(100, wy - 100), 0.0) env.base = BaseStation(position=base_pos, recharge_radius=rng.uniform(60, 100)) # --- Targets (2-5 normal, 3-6 ultra) --- if ultra_hard: n_targets = rng.choices([3, 4, 5, 6], weights=[0.15, 0.35, 0.35, 0.15])[0] else: n_targets = rng.choices([2, 3, 4, 5], weights=[0.15, 0.40, 0.30, 0.15])[0] targets = [] for i in range(n_targets): for _ in range(120): pos = _rpos(z_lo=5.0, z_hi=60.0) if _hdist(pos, base_pos) < 500: continue if all(_hdist(pos, t.position) > 400 for t in targets): break targets.append(DeliveryTarget( id=f"T{i+1}", position=pos, urgency=rng.uniform(0.3, 1.0), delivery_radius=rng.uniform(70.0, 130.0), )) env.targets = targets # --- Hazards (3-8 normal, 5-10 ultra) with wild variety --- if ultra_hard: n_hazards = rng.choices([5, 6, 7, 8, 9, 10], weights=[0.10, 0.20, 0.25, 0.25, 0.15, 0.05])[0] else: n_hazards = rng.choices([3, 4, 5, 6, 7, 8], weights=[0.10, 0.20, 0.25, 0.25, 0.15, 0.05])[0] hazards = [] for i in range(n_hazards): center = _rpos_ground() fire_type = rng.choice(["tiny_intense", "massive_low", "tall_mid", "standard"]) if fire_type == "tiny_intense": r, sev, ht, gr = rng.uniform(80, 200), rng.uniform(0.9, 1.0), rng.uniform(140, 195), rng.uniform(0.012, 0.025) elif fire_type == "massive_low": r, sev, ht, gr = rng.uniform(500, 1000), rng.uniform(0.3, 0.5), rng.uniform(25, 50), rng.uniform(0.001, 0.004) elif fire_type == "tall_mid": r, sev, ht, gr = rng.uniform(250, 500), rng.uniform(0.7, 0.95), rng.uniform(100, 180), rng.uniform(0.008, 0.015) else: r, sev, ht, gr = rng.uniform(200, 600), rng.uniform(0.4, 0.9), 
rng.uniform(40, 120), rng.uniform(0.003, 0.012) hazards.append(HazardRegion(id=f"H{i+1}", center=center, radius=r, severity=sev, height=ht, growth_rate=gr)) env.hazards = hazards # --- Obstacle templates --- obstacles: list[ObstacleVolume] = [] cylinders: list[CylindricalObstacle] = [] oid = [0] def _next_oid(prefix="O"): oid[0] += 1 return f"{prefix}{oid[0]}" def _add_box(cx, cy, w, h, zt, kind="building"): obstacles.append(ObstacleVolume( id=_next_oid(), kind=kind, min_corner=Vec3(cx - w / 2, cy - h / 2, 0.0), max_corner=Vec3(cx + w / 2, cy + h / 2, zt), )) def _add_cyl(cx, cy, radius, height, kind="tree"): cylinders.append(CylindricalObstacle( id=_next_oid("C"), kind=kind, center=Vec3(cx, cy, 0.0), radius=radius, height=height, )) if ultra_hard: template = rng.choices(["urban_grid", "dense_forest", "corridor_maze", "river_valley", "fortress", "mixed"], weights=[0.08, 0.12, 0.12, 0.10, 0.10, 0.48])[0] else: template = rng.choice(["urban_grid", "dense_forest", "corridor_maze", "river_valley", "fortress", "mixed"]) # ---- URBAN GRID: rows and columns of buildings ---- if template == "urban_grid" or template == "mixed": ox = rng.uniform(500, 1500) oy = rng.uniform(500, 1500) rows = rng.randint(2, 5) if ultra_hard else rng.randint(2, 4) cols = rng.randint(3, 6) if ultra_hard else rng.randint(3, 5) spacing = rng.uniform(300, 550) if ultra_hard else rng.uniform(350, 600) for r in range(rows): for c in range(cols): bx = ox + c * spacing + rng.uniform(-80, 80) by = oy + r * spacing + rng.uniform(-80, 80) if bx < margin or bx > wx - margin or by < margin or by > wy - margin: continue bw = rng.uniform(80, 300) bh = rng.uniform(80, 300) bzt = rng.choice([rng.uniform(30, 60), rng.uniform(100, 195)]) _add_box(bx, by, bw, bh, bzt) if rng.random() < (0.45 if ultra_hard else 0.3): arm_dir = rng.choice(["east", "north"]) if arm_dir == "east": _add_box(bx + bw / 2 + 40, by, 80, bh * 0.6, bzt * 0.9) else: _add_box(bx, by + bh / 2 + 40, bw * 0.6, 80, bzt * 0.9) # ---- DENSE FOREST: 
many cylindrical trees ---- if template == "dense_forest" or template == "mixed": forest_cx = rng.uniform(800, wx - 800) forest_cy = rng.uniform(800, wy - 800) n_trees = rng.randint(25, 60) if ultra_hard else rng.randint(15, 40) for _ in range(n_trees): tx = forest_cx + rng.gauss(0, 600) ty = forest_cy + rng.gauss(0, 600) tx = max(margin, min(wx - margin, tx)) ty = max(margin, min(wy - margin, ty)) tree_type = rng.choice(["pine", "oak", "palm", "dead"]) if tree_type == "pine": _add_cyl(tx, ty, rng.uniform(8, 20), rng.uniform(40, 100), "tree_pine") elif tree_type == "oak": _add_cyl(tx, ty, rng.uniform(15, 40), rng.uniform(25, 60), "tree_oak") elif tree_type == "palm": _add_cyl(tx, ty, rng.uniform(5, 12), rng.uniform(30, 80), "tree_palm") else: _add_cyl(tx, ty, rng.uniform(10, 25), rng.uniform(20, 50), "tree_dead") # ---- CORRIDOR MAZE: parallel walls with gaps ---- if template == "corridor_maze" or template == "mixed": maze_ox = rng.uniform(400, wx / 2) maze_oy = rng.uniform(400, wy / 2) n_walls = rng.randint(6, 12) if ultra_hard else rng.randint(4, 8) wall_dir = rng.choice(["horizontal", "vertical"]) spacing = rng.uniform(200, 500) for w in range(n_walls): wl = rng.uniform(400, 1500) wt = rng.uniform(40, 80) wzt = rng.uniform(100, 195) if wall_dir == "horizontal": wy_pos = maze_oy + w * spacing if wy_pos > wy - margin: continue _add_box(maze_ox + wl / 2, wy_pos, wl, wt, wzt, "wall") gap_x = maze_ox + rng.uniform(0.2, 0.8) * wl _add_box(gap_x, wy_pos, rng.uniform(80, 200), wt, 0, "gap") else: wx_pos = maze_ox + w * spacing if wx_pos > wx - margin: continue _add_box(wx_pos, maze_oy + wl / 2, wt, wl, wzt, "wall") # ---- RIVER VALLEY: chain of low flat boxes + scattered trees ---- if template == "river_valley" or (template == "mixed" and rng.random() < (0.7 if ultra_hard else 0.5)): river_start_x = rng.uniform(margin, wx / 3) river_y = rng.uniform(wy * 0.3, wy * 0.7) n_segs = rng.randint(10, 18) if ultra_hard else rng.randint(6, 12) for seg in range(n_segs): seg_x = 
river_start_x + seg * rng.uniform(200, 400) seg_y = river_y + rng.gauss(0, 150) if seg_x > wx - margin: break seg_y = max(margin, min(wy - margin, seg_y)) _add_box(seg_x, seg_y, rng.uniform(200, 400), rng.uniform(60, 150), rng.uniform(3, 10), "river") for _ in range(rng.randint(2, 6) if ultra_hard else rng.randint(1, 4)): bank_offset = rng.choice([-1, 1]) * rng.uniform(100, 300) _add_cyl(seg_x + rng.uniform(-100, 100), seg_y + bank_offset, rng.uniform(8, 20), rng.uniform(30, 80), "tree_bank") # ---- FORTRESS: walls surrounding a target area ---- if template == "fortress" or (template == "mixed" and rng.random() < (0.6 if ultra_hard else 0.4)): if targets: fort_target = rng.choice(targets) ftx, fty = fort_target.position.x, fort_target.position.y wall_half = rng.uniform(250, 500) wall_zt = rng.uniform(120, 190) wall_thick = rng.uniform(50, 80) _add_box(ftx, fty - wall_half, wall_half * 2, wall_thick, wall_zt, "fortress_wall") _add_box(ftx, fty + wall_half, wall_half * 2, wall_thick, wall_zt, "fortress_wall") _add_box(ftx - wall_half, fty, wall_thick, wall_half * 2, wall_zt, "fortress_wall") _add_box(ftx + wall_half, fty, wall_thick, wall_half * 2, wall_zt, "fortress_wall") # ---- Always scatter some light poles and random pillars ---- n_poles = rng.randint(6, 18) if ultra_hard else rng.randint(3, 10) for _ in range(n_poles): px = rng.uniform(margin, wx - margin) py = rng.uniform(margin, wy - margin) _add_cyl(px, py, rng.uniform(2, 6), rng.uniform(30, 80), "light_pole") n_pillars = rng.randint(4, 12) if ultra_hard else rng.randint(2, 6) for _ in range(n_pillars): px = rng.uniform(margin, wx - margin) py = rng.uniform(margin, wy - margin) _add_cyl(px, py, rng.uniform(15, 50), rng.uniform(80, 195), "pillar") obstacles = [o for o in obstacles if o.max_corner.z > 1.0] env.obstacles = obstacles env.cylinders = cylinders # --- Responder units (1 per target, up to 5 in ultra) --- responders = [] max_resp = 5 if ultra_hard else 4 for i, tgt in enumerate(targets[:max_resp]): 
r = ResponderUnit( id=f"R{i+1}", position=Vec3(tgt.position.x + rng.uniform(-50, 50), tgt.position.y + rng.uniform(-50, 50), 0.0), linked_target_id=tgt.id, status="stable", current_need=rng.choice(["supplies", "medical", "evacuation", "water"]), can_update_dropzone=rng.random() < 0.5, active=True, ) events = [] if rng.random() < 0.7: events.append(ScheduledEvent( step=rng.randint(100, 600), event_type="urgency_update", payload={"new_urgency": rng.uniform(0.5, 1.0)}, )) if r.can_update_dropzone and rng.random() < 0.5: events.append(ScheduledEvent( step=rng.randint(200, 800), event_type="dropzone_relocation", payload={"dx": rng.uniform(-200, 200), "dy": rng.uniform(-200, 200)}, )) if rng.random() < 0.6: intel = rng.choice([ "blocked_north", "blocked_south", "blocked_east", "blocked_west", "safe_north", "safe_south", "safe_east", "safe_west", "fire_expanded", "fire_receded", ]) events.append(ScheduledEvent( step=rng.randint(50, 500), event_type="hazard_intel", payload={"intel": intel, "severity": rng.uniform(0.3, 1.0)}, )) r.scheduled_events = events responders.append(r) env.responders = responders def build_hardcore_world_v2(env: "VarahaEnv") -> None: """Ultra-hard variant: denser obstacles, more hazards, more targets.""" build_hardcore_world(env, ultra_hard=True) # --------------------------------------------------------------------------- # Environment # --------------------------------------------------------------------------- class VarahaEnv: """Core wildfire logistics simulation. Action format (dict):: { "ax": float, # desired acceleration x (m/s²) "ay": float, # desired acceleration y "az": float, # desired acceleration z "deliver": bool, # attempt delivery if near a target "recharge": bool, # attempt recharge if near base "tool_call": str, # optional: request_intel | battery_forecast | mission_report } Returns ``(obs_dict, reward, done, info_dict)`` per OpenAI-gym convention. 
""" def __init__(self, config: Optional[VarahaConfig] = None, world_fn: Optional[Any] = None) -> None: self.cfg = config or VarahaConfig() self._world_fn = world_fn self.base: BaseStation self.drone: DroneState self.targets: list[DeliveryTarget] = [] self.hazards: list[HazardRegion] = [] self.obstacles: list[ObstacleVolume] = [] self.cylinders: list[CylindricalObstacle] = [] self.responders: list[ResponderUnit] = [] self.step_count: int = 0 self.cumulative_reward: float = 0.0 self.done: bool = False self.trace: list[TracePoint] = [] self._prev_nearest_dist: float = 0.0 self._hazard_base_heights: list[float] = [] self._hazard_base_severities: list[float] = [] self.instructions: list[MissionInstruction] = [] self._instruction_cursor: int = 0 self._instruction_violations: int = 0 self._tool_history: list[str] = [] self._last_tool_result: dict[str, Any] = {} self._instruction_progress_reward: float = 0.0 self._rebuild_world() def _rebuild_world(self): if self._world_fn is not None: self._world_fn(self) else: self._build_demo_world() self._hazard_base_heights = [h.height for h in self.hazards] self._hazard_base_severities = [h.severity for h in self.hazards] # ------------------------------------------------------------------ # World setup # ------------------------------------------------------------------ def _build_demo_world(self) -> None: """Hardcoded 5 km demo scenario. 
Layout (top-down, +x → east, +y → north, 5 km × 5 km):: T3 (1000,4200) · H2 (900,3200) O2 [500-1500, 2600-3000] · · T2 (4100,2900) ← inside H1 fringe · H1 (3800,2600) · · O1 [2200-2800, 1000-2200] · · T1 (1800,600) · Base (250,250) - T2 sits inside the fringe of hazard H1 → brief hazard exposure required - T3 is behind obstacle O2 and near hazard H2 - O1 blocks direct mid-map routing from T1 to T2 - Drone can fly over obstacles if altitude > obstacle height - Total route ≈ 12 km, battery budget ≈ 300 units """ self.base = BaseStation(position=Vec3(250.0, 250.0, 0.0), recharge_radius=80.0) self.targets = [ DeliveryTarget( id="T1", position=Vec3(1800.0, 600.0, 30.0), urgency=0.6, delivery_radius=80.0, ), DeliveryTarget( id="T2", position=Vec3(4100.0, 2900.0, 50.0), urgency=1.0, delivery_radius=120.0, ), DeliveryTarget( id="T3", position=Vec3(1000.0, 4200.0, 20.0), urgency=0.8, delivery_radius=100.0, ), ] self.hazards = [ HazardRegion( id="H1", center=Vec3(3800.0, 2600.0, 0.0), radius=500.0, severity=0.9, height=70.0, growth_rate=0.005, ), HazardRegion( id="H2", center=Vec3(900.0, 3200.0, 0.0), radius=400.0, severity=0.7, height=55.0, growth_rate=0.008, ), ] self.obstacles = [ ObstacleVolume( id="O1", min_corner=Vec3(2200.0, 1000.0, 0.0), max_corner=Vec3(2800.0, 2200.0, 120.0), ), ObstacleVolume( id="O2", min_corner=Vec3(500.0, 2600.0, 0.0), max_corner=Vec3(1500.0, 3000.0, 90.0), ), ] # ------------------------------------------------------------------ # Core API # ------------------------------------------------------------------ def reset(self, seed: Optional[int] = None) -> dict[str, Any]: """Reset the environment and return the initial observation.""" if seed is not None: random.seed(seed) if self._world_fn is not None: self._rebuild_world() self.drone = DroneState( position=Vec3(self.base.position.x, self.base.position.y, 0.0), velocity=Vec3(0.0, 0.0, 0.0), battery=self.cfg.battery_capacity, carrying_payload=True, alive=True, ) for t in self.targets: t.delivered 
= False for i, h in enumerate(self.hazards): h.height = self._hazard_base_heights[i] * random.uniform(0.85, 1.15) h.severity = max(0.3, min(1.0, self._hazard_base_severities[i] + random.uniform(-0.1, 0.1))) h.reset() for r in self.responders: r.active = True r.status = "stable" r.latest_intel = "none" r.intel_severity = 0.0 r.message = "" for ev in r.scheduled_events: ev.fired = False self._target_base_positions = { t.id: Vec3(t.position.x, t.position.y, t.position.z) for t in self.targets } self._build_instruction_program() self._instruction_progress_reward = 0.0 self._last_tool_result = {} self._tool_history = [] self.step_count = 0 self.cumulative_reward = 0.0 self.done = False self.trace = [] self._prev_nearest_dist = self._nearest_target_dist() obs = self.get_observation() self.trace.append(TracePoint( step=0, position=Vec3(self.drone.position.x, self.drone.position.y, self.drone.position.z), velocity=Vec3(0.0, 0.0, 0.0), battery=self.drone.battery, reward=0.0, cumulative_reward=0.0, events=["reset"], observation=obs, )) return obs def step(self, action: dict[str, Any]) -> tuple[dict, float, bool, dict]: """Advance the simulation by one timestep. Returns ``(observation, reward, done, info)``. 
""" if self.done: return self.get_observation(), 0.0, True, StepInfo().to_dict() self.step_count += 1 # --- parse & clamp acceleration --- accel = Vec3( float(action.get("ax", 0.0)), float(action.get("ay", 0.0)), float(action.get("az", 0.0)), ).clamp_magnitude(self.cfg.max_acceleration) # --- kinematics (Euler integration) --- self.drone.velocity = ( self.drone.velocity + accel.scale(self.cfg.dt) ).clamp_magnitude(self.cfg.max_speed) old_pos = Vec3(self.drone.position.x, self.drone.position.y, self.drone.position.z) self.drone.position = self.drone.position + self.drone.velocity.scale(self.cfg.dt) # clamp to world bounds self.drone.position.x = max(0.0, min(self.cfg.world_x, self.drone.position.x)) self.drone.position.y = max(0.0, min(self.cfg.world_y, self.drone.position.y)) self.drone.position.z = max(0.0, min(self.cfg.world_z, self.drone.position.z)) dist_traveled = old_pos.distance_to(self.drone.position) elevation_change = abs(self.drone.position.z - old_pos.z) # --- battery --- drain = self._compute_battery_drain(dist_traveled, elevation_change) self.drone.battery -= drain # --- advance dynamic hazards --- for h in self.hazards: h.tick() # --- advance responder events --- self._tick_responders() # --- world interactions --- collision = self._check_collisions() in_hazard, hazard_sev = self._check_hazards() tool_call = "" tool_result: dict[str, Any] = {} raw_tool_call = action.get("tool_call") if raw_tool_call is not None and str(raw_tool_call).strip(): tool_call, tool_result = self._execute_tool_call(str(raw_tool_call).strip()) prev_instruction_cursor = self._instruction_cursor delivered_ids: list[str] = [] if action.get("deliver", False): delivered_ids = self._deliver_targets() reached_base = ( ((self.drone.position.x - self.base.position.x) ** 2 + (self.drone.position.y - self.base.position.y) ** 2) ** 0.5 <= self.base.recharge_radius ) if action.get("recharge", False) and reached_base: self.drone.battery = min( self.cfg.battery_capacity, self.drone.battery 
+ self.cfg.recharge_rate, ) self._update_instruction_progress( delivered_ids=delivered_ids, reached_base=reached_base, tool_call=tool_call, ) completed_now = max(0, self._instruction_cursor - prev_instruction_cursor) if self._all_delivered(): self.drone.carrying_payload = False # --- reward --- info = StepInfo( collision=collision, delivered_target_ids=delivered_ids, in_hazard=in_hazard, hazard_severity=hazard_sev, reached_base=reached_base, distance_traveled=dist_traveled, tool_call=tool_call, tool_result=tool_result, instruction_completed=self._instruction_cursor, instruction_total=len(self.instructions), instruction_violations=self._instruction_violations, ) reward, breakdown = self._compute_reward(info) info.reward_breakdown = breakdown self.cumulative_reward += reward # --- termination --- if collision: self.drone.alive = False self.done = True elif self.drone.battery <= 0.0: self.drone.battery = 0.0 self.drone.alive = False self.done = True elif self._is_success(): self.done = True elif self.step_count >= self.cfg.max_episode_steps: self.done = True # record trace events: list[str] = [] for tid in delivered_ids: events.append(f"delivered_{tid}") if collision: events.append("collision") if in_hazard: events.append(f"hazard_{hazard_sev:.2f}") if self.drone.battery <= 0.0 and not collision: events.append("battery_dead") if self._is_success(): events.append("success") if tool_call: events.append(f"tool_{tool_call}") if completed_now > 0: events.append(f"instruction+{completed_now}") obs = self.get_observation() self.trace.append(TracePoint( step=self.step_count, position=Vec3(self.drone.position.x, self.drone.position.y, self.drone.position.z), velocity=Vec3(self.drone.velocity.x, self.drone.velocity.y, self.drone.velocity.z), battery=self.drone.battery, reward=reward, cumulative_reward=self.cumulative_reward, events=events, observation=obs, )) return obs, reward, self.done, info.to_dict() # ------------------------------------------------------------------ # 
Observation / render # ------------------------------------------------------------------ def get_observation(self) -> dict[str, Any]: """Compact, RL-friendly observation dict.""" dp = self.drone.position targets_obs = [] for t in self.targets: rel = t.position - dp targets_obs.append({ "id": t.id, "relative_position": rel.to_dict(), "urgency": t.urgency, "delivered": t.delivered, }) hazards_obs = [] for h in self.hazards: rel = h.center - dp hazards_obs.append({ "id": h.id, "relative_position": rel.to_dict(), "current_height": h._current_height, "severity": h.severity, }) obstacles_obs = [] for obs in self.obstacles: c = obs.center hs = obs.half_size rel = c - dp dist = dp.horizontal_distance_to(c) obstacles_obs.append({ "type": "box", "relative_position": rel.to_dict(), "height": obs.height, "size_x": hs.x * 2, "size_y": hs.y * 2, "distance": dist, "kind": obs.kind, }) for cyl in self.cylinders: rel = cyl.center - dp dist = dp.horizontal_distance_to(cyl.center) obstacles_obs.append({ "type": "cylinder", "relative_position": rel.to_dict(), "height": cyl.height, "size_x": cyl.radius * 2, "size_y": cyl.radius * 2, "distance": dist, "kind": cyl.kind, }) obstacles_obs.sort(key=lambda o: o["distance"]) responders_obs = [] for r in self.responders: if not r.active: continue rel = r.position - dp intel_dir = r.intel_direction() responders_obs.append({ "id": r.id, "relative_position": rel.to_dict(), "linked_target_id": r.linked_target_id, "status": r.status, "status_code": r.status_code(), "latest_intel": r.latest_intel, "intel_direction": {"x": intel_dir[0], "y": intel_dir[1]}, "intel_severity": r.intel_severity, }) mission_obs = self._instruction_snapshot() return { "drone_position": dp.to_dict(), "drone_velocity": self.drone.velocity.to_dict(), "battery": round(self.drone.battery, 4), "carrying_payload": self.drone.carrying_payload, "alive": self.drone.alive, "targets": targets_obs, "hazards": hazards_obs, "obstacles": obstacles_obs, "responders": responders_obs, 
"mission": mission_obs, "last_tool_result": self._last_tool_result, "step": self.step_count, "max_steps": self.cfg.max_episode_steps, } def render_state(self) -> dict[str, Any]: """Rich state dict for future Cesium / frontend rendering.""" return { "base_station": self.base.to_dict(), "drone": self.drone.to_dict(), "targets": [t.to_dict() for t in self.targets], "hazards": [h.to_dict() for h in self.hazards], "obstacles": [o.to_dict() for o in self.obstacles], "cylinders": [c.to_dict() for c in self.cylinders], "responders": [r.to_dict() for r in self.responders], "mission": self._instruction_snapshot(include_full=True), "tool_history": list(self._tool_history), "step": self.step_count, "max_steps": self.cfg.max_episode_steps, "cumulative_reward": round(self.cumulative_reward, 4), "done": self.done, } def get_trace(self) -> dict[str, Any]: """Full episode trace for replay / visualisation.""" return { "world": { "bounds": {"x": self.cfg.world_x, "y": self.cfg.world_y, "z": self.cfg.world_z}, "base_station": self.base.to_dict(), "targets": [t.to_dict() for t in self.targets], "hazards": [h.to_dict() for h in self.hazards], "obstacles": [o.to_dict() for o in self.obstacles], "cylinders": [c.to_dict() for c in self.cylinders], "responders": [r.to_dict() for r in self.responders], "mission": self._instruction_snapshot(include_full=True), }, "trace": [tp.to_dict() for tp in self.trace], "summary": { "total_steps": self.step_count, "cumulative_reward": round(self.cumulative_reward, 4), "delivered": [t.id for t in self.targets if t.delivered], "alive": self.drone.alive, "final_battery": round(self.drone.battery, 4), "success": self._is_success(), "instruction_completed": self._instruction_cursor, "instruction_total": len(self.instructions), "instruction_violations": self._instruction_violations, "tool_calls": list(self._tool_history), }, } # ------------------------------------------------------------------ # Long-horizon instruction mode # 
------------------------------------------------------------------ def _build_instruction_program(self) -> None: self.instructions = [] self._instruction_cursor = 0 self._instruction_violations = 0 if not self.cfg.instruction_mode or not self.targets: return ordered_targets = sorted(self.targets, key=lambda t: (-t.urgency, t.id)) target_count = len(ordered_targets) desired_len = self.cfg.instruction_count if self.cfg.instruction_count > 0 else (target_count * 3 + 1) desired_len = max(desired_len, target_count * 2 + 1) instructions: list[MissionInstruction] = [] inst_idx = 1 cycle = 0 while len(instructions) < max(desired_len - 1, 1): for tgt in ordered_targets: if len(instructions) >= max(desired_len - 1, 1): break instructions.append( MissionInstruction( id=f"I{inst_idx}", kind="deliver_target", description=f"Cycle {cycle + 1}: deliver to {tgt.id} in order.", target_id=tgt.id, ) ) inst_idx += 1 if len(instructions) >= max(desired_len - 1, 1): break tool = "request_intel" if (cycle % 2 == 0) else "battery_forecast" instructions.append( MissionInstruction( id=f"I{inst_idx}", kind="tool_call", description=f"Call {tool} after servicing {tgt.id}.", target_id=tgt.id, tool_name=tool, ) ) inst_idx += 1 cycle += 1 instructions.append( MissionInstruction( id=f"I{inst_idx}", kind="return_base", description="Return to base only after all deliveries are completed.", ) ) self.instructions = instructions def _current_instruction(self) -> Optional[MissionInstruction]: if self._instruction_cursor >= len(self.instructions): return None return self.instructions[self._instruction_cursor] def _instruction_snapshot(self, include_full: bool = False) -> dict[str, Any]: total = len(self.instructions) completed = min(self._instruction_cursor, total) next_instruction = self._current_instruction() out: dict[str, Any] = { "enabled": self.cfg.instruction_mode, "total": total, "completed": completed, "remaining": max(total - completed, 0), "progress": (completed / total) if total > 0 else 1.0, 
"violations": self._instruction_violations, "next_instruction": next_instruction.to_dict() if next_instruction else None, } if include_full: out["instructions"] = [inst.to_dict() for inst in self.instructions] return out def _complete_current_instruction(self) -> None: inst = self._current_instruction() if inst is None: return inst.completed = True self._instruction_cursor += 1 self._instruction_progress_reward += self.cfg.instruction_completion_reward def _record_instruction_violation(self) -> None: self._instruction_violations += 1 inst = self._current_instruction() if inst is not None: inst.violated = True def _tool_matches_instruction(self, tool_call: str, inst: MissionInstruction) -> bool: base, _, arg = tool_call.partition(":") if base != inst.tool_name: return False if inst.target_id and arg and arg != inst.target_id: return False return True def _update_instruction_progress( self, delivered_ids: list[str], reached_base: bool, tool_call: str, ) -> None: if not self.cfg.instruction_mode or not self.instructions: return inst = self._current_instruction() if inst and inst.kind == "deliver_target": for tid in delivered_ids: if tid != inst.target_id: self._record_instruction_violation() while True: inst = self._current_instruction() if inst is None: break if inst.kind == "deliver_target": if inst.target_id in delivered_ids: self._complete_current_instruction() continue break if inst.kind == "tool_call": if not tool_call: break if self._tool_matches_instruction(tool_call, inst): self._complete_current_instruction() else: self._record_instruction_violation() break if inst.kind == "return_base": if reached_base and self._all_delivered(): self._complete_current_instruction() break break def _execute_tool_call(self, tool_call: str) -> tuple[str, dict[str, Any]]: raw = tool_call.strip().lower() if not raw: return "", {} tool_name, _, arg = raw.partition(":") normalized_call = f"{tool_name}:{arg}" if arg else tool_name if tool_name not in self.cfg.available_tools: 
result = {"ok": False, "error": f"unsupported_tool:{tool_name}"} self._tool_history.append(normalized_call) self._last_tool_result = result return normalized_call, result if tool_name == "request_intel": responder = None if arg: responder = next( (r for r in self.responders if r.active and r.linked_target_id.lower() == arg.lower()), None, ) if responder is None: responder = next((r for r in self.responders if r.active), None) if responder is None: result = {"ok": True, "intel": "none", "message": "no_active_responders"} else: result = { "ok": True, "intel": responder.latest_intel, "intel_severity": round(responder.intel_severity, 3), "responder_id": responder.id, "target_id": responder.linked_target_id, "message": responder.message, } elif tool_name == "battery_forecast": burn = max(self.cfg.drain_per_meter, 1e-6) est_range = self.drone.battery / burn result = { "ok": True, "battery": round(self.drone.battery, 3), "estimated_range_m": round(est_range, 1), } else: # mission_report result = { "ok": True, "delivered": [t.id for t in self.targets if t.delivered], "remaining": [t.id for t in self.targets if not t.delivered], "instruction_progress": round(self._instruction_snapshot()["progress"], 3), "violations": self._instruction_violations, } self._tool_history.append(normalized_call) self._last_tool_result = result return normalized_call, result # ------------------------------------------------------------------ # Coordinate conversion # ------------------------------------------------------------------ def local_to_latlon(self, vec: Vec3) -> tuple[float, float, float]: """Convert local (x, y, z) metres to (lat, lon, alt). Uses a flat-earth approximation centred on ``cfg.origin_lat/lon``. Accurate enough for small areas (~tens of km) and Cesium plotting. 
""" meters_per_deg_lat = 111_320.0 meters_per_deg_lon = 111_320.0 * math.cos(math.radians(self.cfg.origin_lat)) lat = self.cfg.origin_lat + vec.y / meters_per_deg_lat lon = self.cfg.origin_lon + vec.x / meters_per_deg_lon alt = vec.z return (round(lat, 7), round(lon, 7), round(alt, 2)) # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ def _compute_battery_drain(self, dist: float, elevation_change: float) -> float: return ( dist * self.cfg.drain_per_meter + elevation_change * self.cfg.drain_elevation_factor + self.cfg.drain_idle_per_step ) def _check_collisions(self) -> bool: for obs in self.obstacles: if obs.contains(self.drone.position): return True for cyl in self.cylinders: if cyl.contains(self.drone.position): return True return False def _check_hazards(self) -> tuple[bool, float]: max_sev = 0.0 in_hazard = False for h in self.hazards: df = h.danger_factor(self.drone.position) if df > 0.0: in_hazard = True max_sev = max(max_sev, df) return in_hazard, max_sev def _deliver_targets(self) -> list[str]: """Cylindrical delivery check — drone must be within horizontal radius and above the target (within a generous altitude window for drops).""" delivered: list[str] = [] for t in self.targets: if t.delivered: continue dx = self.drone.position.x - t.position.x dy = self.drone.position.y - t.position.y horiz_dist = (dx * dx + dy * dy) ** 0.5 alt_above = self.drone.position.z - t.position.z if horiz_dist <= t.delivery_radius and -10.0 <= alt_above <= t.delivery_radius * 2: t.delivered = True delivered.append(t.id) return delivered def _all_delivered(self) -> bool: return all(t.delivered for t in self.targets) def _is_success(self) -> bool: hdist = ((self.drone.position.x - self.base.position.x) ** 2 + (self.drone.position.y - self.base.position.y) ** 2) ** 0.5 return self._all_delivered() and hdist <= self.base.recharge_radius def _nearest_target_dist(self) -> float: 
"""Horizontal distance to closest undelivered target, or to base if all done.""" dists = [ ((self.drone.position.x - t.position.x) ** 2 + (self.drone.position.y - t.position.y) ** 2) ** 0.5 for t in self.targets if not t.delivered ] if not dists: return ((self.drone.position.x - self.base.position.x) ** 2 + (self.drone.position.y - self.base.position.y) ** 2) ** 0.5 return min(dists) def _tick_responders(self) -> None: """Process scheduled responder events for the current step.""" for r in self.responders: if not r.active: continue for ev in r.scheduled_events: if ev.fired or ev.step != self.step_count: continue ev.fired = True etype = ev.event_type if etype == "urgency_update": tgt = self._find_target(r.linked_target_id) if tgt and not tgt.delivered: tgt.urgency = max(0.1, min(1.0, ev.payload.get("new_urgency", tgt.urgency))) r.status = "critical" if tgt.urgency >= 0.9 else "urgent" if tgt.urgency >= 0.6 else "stable" r.message = f"urgency->{tgt.urgency:.1f}" elif etype == "dropzone_relocation": tgt = self._find_target(r.linked_target_id) if tgt and not tgt.delivered and r.can_update_dropzone: dx = ev.payload.get("dx", 0.0) dy = ev.payload.get("dy", 0.0) tgt.position.x = max(50, min(self.cfg.world_x - 50, tgt.position.x + dx)) tgt.position.y = max(50, min(self.cfg.world_y - 50, tgt.position.y + dy)) r.position = Vec3(tgt.position.x, tgt.position.y, 0.0) r.message = f"dropzone moved ({dx:+.0f},{dy:+.0f})" self._prev_nearest_dist = self._nearest_target_dist() elif etype == "hazard_intel": r.latest_intel = ev.payload.get("intel", "none") r.intel_severity = ev.payload.get("severity", 0.5) r.message = f"intel: {r.latest_intel}" def _find_target(self, tid: str) -> Optional[DeliveryTarget]: for t in self.targets: if t.id == tid: return t return None def _obstacle_proximity_penalty(self) -> float: """Graduated penalty for flying close to any obstacle surface.""" min_dist = float("inf") pos = self.drone.position for obs in self.obstacles: d = obs.nearest_surface_dist(pos) 
            if d < min_dist:
                min_dist = d
        for cyl in self.cylinders:
            d = cyl.nearest_surface_dist(pos)
            if d < min_dist:
                min_dist = d
        if min_dist >= self.cfg.obstacle_proximity_radius:
            return 0.0
        # Quadratic ramp: zero at the proximity radius, full penalty at the surface.
        factor = 1.0 - min_dist / self.cfg.obstacle_proximity_radius
        return self.cfg.obstacle_proximity_penalty * factor * factor

    def _compute_reward(self, info: StepInfo) -> tuple[float, dict[str, float]]:
        """Dense shaped reward for this step.

        Returns (total, breakdown) where breakdown maps component names to
        their signed contributions (plus a "total" entry). Side effects:
        consumes self._instruction_progress_reward and updates
        self._prev_nearest_dist.
        """
        if self.cfg.instruction_mode and self.cfg.sparse_reward_mode:
            return self._compute_sparse_instruction_reward(info)
        bd: dict[str, float] = {}
        total = 0.0
        # per-step cost of time
        bd["step_penalty"] = -self.cfg.step_penalty
        total += bd["step_penalty"]
        # battery usage cost (proportional to energy spent)
        bd["battery_cost"] = -(
            info.distance_traveled * self.cfg.drain_per_meter * self.cfg.battery_cost_factor
        )
        total += bd["battery_cost"]
        # Pay out any instruction-completion reward banked elsewhere, then
        # reset the accumulator so it is counted only once.
        if self._instruction_progress_reward > 0.0:
            bd["instruction_progress"] = self._instruction_progress_reward
            total += bd["instruction_progress"]
            self._instruction_progress_reward = 0.0
        # delivery rewards (scaled by urgency) + progress bonus
        for tid in info.delivered_target_ids:
            tgt = next(t for t in self.targets if t.id == tid)
            r = self.cfg.delivery_reward * (1.0 + tgt.urgency)
            bd[f"delivery_{tid}"] = r
            total += r
        if info.delivered_target_ids:
            # Bonus grows with the fraction of targets completed so far.
            n_remaining = sum(1 for t in self.targets if not t.delivered)
            progress_bonus = 50.0 * (1.0 - n_remaining / len(self.targets))
            bd["progress_bonus"] = progress_bonus
            total += progress_bonus
        # collision
        if info.collision:
            bd["collision"] = -self.cfg.collision_penalty
            total += bd["collision"]
        # hazard exposure (severity-weighted)
        if info.in_hazard:
            bd["hazard"] = -self.cfg.hazard_penalty * info.hazard_severity
            total += bd["hazard"]
        # safe return bonus
        if info.reached_base and self._all_delivered():
            bd["return_bonus"] = self.cfg.return_bonus
            total += bd["return_bonus"]
        # distance shaping — nudge toward nearest undelivered target (or base)
        # Skip shaping on delivery steps to avoid a huge negative spike
        # when the nearest-target reference jumps to a farther target.
        # Double the factor when heading home after all deliveries.
        curr_dist = self._nearest_target_dist()
        if info.delivered_target_ids:
            bd["distance_shaping"] = 0.0
            self._prev_nearest_dist = curr_dist
        else:
            factor = self.cfg.distance_shaping_factor
            if self._all_delivered():
                factor *= 2.0
            shaping = (self._prev_nearest_dist - curr_dist) * factor
            bd["distance_shaping"] = shaping
            total += shaping
            self._prev_nearest_dist = curr_dist
        # obstacle proximity (graduated — discourages flying close)
        prox = self._obstacle_proximity_penalty()
        if prox > 0:
            bd["obstacle_proximity"] = -prox
            total -= prox
        # failure (battery depletion; collision already penalised above)
        if self.drone.battery <= 0.0 and not info.collision:
            bd["failure"] = -self.cfg.failure_penalty
            total += bd["failure"]
        bd["total"] = total
        return total, bd

    def _compute_sparse_instruction_reward(self, info: StepInfo) -> tuple[float, dict[str, float]]:
        """Sparse-mode reward: mostly terminal bonuses/penalties keyed to instruction progress."""
        bd: dict[str, float] = {}
        total = 0.0
        # Keep shaping intentionally small in sparse mode.
bd["step_penalty"] = -(self.cfg.step_penalty * 0.25) total += bd["step_penalty"] if self._instruction_progress_reward > 0.0: bd["instruction_progress"] = self._instruction_progress_reward total += bd["instruction_progress"] self._instruction_progress_reward = 0.0 if info.in_hazard: bd["hazard"] = -(self.cfg.hazard_penalty * 0.2 * info.hazard_severity) total += bd["hazard"] terminal = ( info.collision or self.drone.battery <= 0.0 or self._is_success() or self.step_count >= self.cfg.max_episode_steps ) if terminal: total_instr = len(self.instructions) progress = (self._instruction_cursor / total_instr) if total_instr > 0 else 1.0 bd["terminal_progress"] = self.cfg.instruction_terminal_progress_bonus * progress total += bd["terminal_progress"] if self._is_success(): bd["terminal_success"] = self.cfg.instruction_terminal_success_bonus total += bd["terminal_success"] else: bd["terminal_failure"] = -self.cfg.failure_penalty total += bd["terminal_failure"] remaining = max(total_instr - self._instruction_cursor, 0) if remaining > 0: bd["unfinished_penalty"] = -remaining * self.cfg.instruction_unfinished_penalty total += bd["unfinished_penalty"] if self._instruction_violations > 0: bd["instruction_violations"] = ( -self._instruction_violations * self.cfg.instruction_violation_penalty ) total += bd["instruction_violations"] if info.collision: bd["collision"] = -self.cfg.collision_penalty total += bd["collision"] bd["total"] = total return total, bd