Spaces:
Runtime error
Runtime error
| """Varaha — wildfire logistics simulation environment. | |
| A drone must deliver supplies to responder zones near wildfire hazards in | |
| California-like terrain. The environment uses lightweight 3D kinematics with | |
| local metre-based coordinates and an optional lat/lon conversion helper for | |
| later Cesium visualisation. | |
| """ | |
| import math | |
| import random | |
| from dataclasses import dataclass | |
| from typing import Any, Optional | |
| from sim_types import ( | |
| Vec3, | |
| DroneState, | |
| BaseStation, | |
| DeliveryTarget, | |
| HazardRegion, | |
| ObstacleVolume, | |
| CylindricalObstacle, | |
| ResponderUnit, | |
| ScheduledEvent, | |
| RESPONDER_STATUSES, | |
| INTEL_TYPES, | |
| StepInfo, | |
| TracePoint, | |
| MissionInstruction, | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Configuration | |
| # --------------------------------------------------------------------------- | |
@dataclass
class VarahaConfig:
    """All tunable environment parameters live here.

    Declared as a dataclass so any knob can be overridden at construction
    time, e.g. ``VarahaConfig(instruction_mode=True, max_episode_steps=4000)``;
    ``VarahaConfig()`` still yields the defaults below. Every default is an
    immutable value (number, bool or tuple), so instances never share state.
    """

    # World bounds (metres) — 5 km × 5 km operational area
    world_x: float = 5000.0
    world_y: float = 5000.0
    world_z: float = 200.0

    # Drone physics
    battery_capacity: float = 300.0
    max_speed: float = 25.0  # m/s
    max_acceleration: float = 8.0  # m/s²
    dt: float = 0.5  # seconds per step

    # Episode
    max_episode_steps: int = 2000

    # Battery drain coefficients (tuned for 5 km scale)
    drain_per_meter: float = 0.008
    drain_elevation_factor: float = 0.02
    drain_idle_per_step: float = 0.005
    recharge_rate: float = 5.0  # battery units restored per recharge step

    # Reward knobs
    delivery_reward: float = 200.0
    return_bonus: float = 100.0
    step_penalty: float = 0.05
    battery_cost_factor: float = 0.3
    collision_penalty: float = 500.0
    hazard_penalty: float = 5.0
    failure_penalty: float = 200.0
    distance_shaping_factor: float = 0.05
    obstacle_proximity_penalty: float = 1.5
    obstacle_proximity_radius: float = 80.0

    # Long-horizon instruction mode (LLM-oriented)
    instruction_mode: bool = False
    instruction_count: int = 60  # <= 0 means "derive from the number of targets"
    sparse_reward_mode: bool = False
    instruction_completion_reward: float = 0.5
    instruction_terminal_success_bonus: float = 2200.0
    instruction_terminal_progress_bonus: float = 800.0
    instruction_violation_penalty: float = 120.0
    instruction_unfinished_penalty: float = 10.0
    # Tool names matching the options of the ``tool_call`` action key
    available_tools: tuple[str, ...] = (
        "request_intel",
        "battery_forecast",
        "mission_report",
    )

    # California origin anchor (near Sacramento — wildfire-relevant)
    origin_lat: float = 38.55
    origin_lon: float = -121.47
| # --------------------------------------------------------------------------- | |
| # Random world generator for domain randomization | |
| # --------------------------------------------------------------------------- | |
def build_random_world(env: "VarahaEnv") -> None:
    """Legacy world generator, kept for backward compatibility.

    Despite the historical name, this no longer builds an "easy" world: it
    simply delegates to ``build_hardcore_world(env)`` with its default
    (non-ultra) settings.
    """
    build_hardcore_world(env)
def _hdist(a: Vec3, b: Vec3) -> float:
    """Return the horizontal (x/y-plane) Euclidean distance between two points.

    The z components of ``a`` and ``b`` are ignored.
    """
    dx = a.x - b.x
    dy = a.y - b.y
    return (dx ** 2 + dy ** 2) ** 0.5
def build_hardcore_world(env: "VarahaEnv", ultra_hard: bool = False) -> None:
    """Generate a challenging randomized world for serious RL training.

    Populates ``env.base``, ``env.targets``, ``env.hazards``,
    ``env.obstacles``, ``env.cylinders`` and ``env.responders`` in place.
    All randomness comes from the module-level ``random`` generator, so a
    prior ``random.seed(...)`` (as done by ``VarahaEnv.reset(seed=...)``)
    makes the generated world reproducible.

    Features template-based obstacle placement (urban grid, dense forest,
    corridor maze, river valley, fortress, mixed), cylindrical obstacles,
    responder units with scheduled events, and adversarial target placement
    (targets are kept away from the base and from each other).

    Obstacles whose footprint comes within the base's recharge radius are
    discarded: the drone spawns on the ground at the base, so such an
    obstacle would put the spawn point or the recharge zone inside solid
    geometry.

    Args:
        env: Environment to populate; ``env.cfg`` supplies the world bounds.
        ultra_hard: Denser obstacles, more hazards and more targets.
    """
    cfg = env.cfg
    rng = random
    wx, wy = cfg.world_x, cfg.world_y
    margin = 200.0  # most entities stay at least this far from the world edge

    def _rpos(z_lo=10.0, z_hi=60.0):
        # Random point inside the margin-inset area at altitude [z_lo, z_hi].
        return Vec3(rng.uniform(margin, wx - margin),
                    rng.uniform(margin, wy - margin),
                    rng.uniform(z_lo, z_hi))

    def _rpos_ground():
        # Random ground-level (z = 0) point inside the margin-inset area.
        return Vec3(rng.uniform(margin, wx - margin),
                    rng.uniform(margin, wy - margin), 0.0)

    # --- Base station ---
    base_pos = Vec3(rng.uniform(100, wx - 100), rng.uniform(100, wy - 100), 0.0)
    env.base = BaseStation(position=base_pos, recharge_radius=rng.uniform(60, 100))
    base_clearance = env.base.recharge_radius

    # --- Targets (2-5 normal, 3-6 ultra) ---
    if ultra_hard:
        n_targets = rng.choices([3, 4, 5, 6], weights=[0.15, 0.35, 0.35, 0.15])[0]
    else:
        n_targets = rng.choices([2, 3, 4, 5], weights=[0.15, 0.40, 0.30, 0.15])[0]
    targets = []
    for i in range(n_targets):
        # Rejection-sample a spot >= 500 m from the base and > 400 m from every
        # other target; if all attempts fail, the last candidate is used anyway.
        for _ in range(120):
            pos = _rpos(z_lo=5.0, z_hi=60.0)
            if _hdist(pos, base_pos) < 500:
                continue
            if all(_hdist(pos, t.position) > 400 for t in targets):
                break
        targets.append(DeliveryTarget(
            id=f"T{i+1}", position=pos,
            urgency=rng.uniform(0.3, 1.0),
            delivery_radius=rng.uniform(70.0, 130.0),
        ))
    env.targets = targets

    # --- Hazards (3-8 normal, 5-10 ultra) with wild variety ---
    if ultra_hard:
        n_hazards = rng.choices([5, 6, 7, 8, 9, 10], weights=[0.10, 0.20, 0.25, 0.25, 0.15, 0.05])[0]
    else:
        n_hazards = rng.choices([3, 4, 5, 6, 7, 8], weights=[0.10, 0.20, 0.25, 0.25, 0.15, 0.05])[0]
    hazards = []
    for i in range(n_hazards):
        center = _rpos_ground()
        fire_type = rng.choice(["tiny_intense", "massive_low", "tall_mid", "standard"])
        if fire_type == "tiny_intense":
            r, sev, ht, gr = rng.uniform(80, 200), rng.uniform(0.9, 1.0), rng.uniform(140, 195), rng.uniform(0.012, 0.025)
        elif fire_type == "massive_low":
            r, sev, ht, gr = rng.uniform(500, 1000), rng.uniform(0.3, 0.5), rng.uniform(25, 50), rng.uniform(0.001, 0.004)
        elif fire_type == "tall_mid":
            r, sev, ht, gr = rng.uniform(250, 500), rng.uniform(0.7, 0.95), rng.uniform(100, 180), rng.uniform(0.008, 0.015)
        else:
            r, sev, ht, gr = rng.uniform(200, 600), rng.uniform(0.4, 0.9), rng.uniform(40, 120), rng.uniform(0.003, 0.012)
        hazards.append(HazardRegion(id=f"H{i+1}", center=center,
                                    radius=r, severity=sev, height=ht, growth_rate=gr))
    env.hazards = hazards

    # --- Obstacle templates ---
    obstacles: list[ObstacleVolume] = []
    cylinders: list[CylindricalObstacle] = []
    oid = 0

    def _next_oid(prefix="O"):
        nonlocal oid
        oid += 1
        return f"{prefix}{oid}"

    def _add_box(cx, cy, w, h, zt, kind="building"):
        # Ground-based box with a w x h footprint centred on (cx, cy), top at zt.
        # Horizontal distance from the base to the nearest point of the footprint:
        off_x = max(cx - w / 2 - base_pos.x, 0.0, base_pos.x - (cx + w / 2))
        off_y = max(cy - h / 2 - base_pos.y, 0.0, base_pos.y - (cy + h / 2))
        if math.hypot(off_x, off_y) <= base_clearance:
            return  # would intrude on the spawn / recharge zone
        obstacles.append(ObstacleVolume(
            id=_next_oid(), kind=kind,
            min_corner=Vec3(cx - w / 2, cy - h / 2, 0.0),
            max_corner=Vec3(cx + w / 2, cy + h / 2, zt),
        ))

    def _add_cyl(cx, cy, radius, height, kind="tree"):
        # Ground-based vertical cylinder centred on (cx, cy).
        if math.hypot(cx - base_pos.x, cy - base_pos.y) - radius <= base_clearance:
            return  # would intrude on the spawn / recharge zone
        cylinders.append(CylindricalObstacle(
            id=_next_oid("C"), kind=kind,
            center=Vec3(cx, cy, 0.0), radius=radius, height=height,
        ))

    def _add_wall_with_gap(start, length, across, thickness, height, horizontal):
        # Maze wall spanning [start, start + length] along x (horizontal) or y,
        # at fixed coordinate `across`. A gap of width 80-200 m is centred at
        # 20-80 % of the wall length, and the wall is emitted as the (up to)
        # two solid segments on either side of that gap.
        gap_mid = start + rng.uniform(0.2, 0.8) * length
        gap_half = rng.uniform(80, 200) / 2
        end = start + length
        for lo, hi in ((start, max(start, gap_mid - gap_half)),
                       (min(end, gap_mid + gap_half), end)):
            seg_len = hi - lo
            if seg_len <= 1.0:
                continue  # the gap reaches this end of the wall
            mid = (lo + hi) / 2
            if horizontal:
                _add_box(mid, across, seg_len, thickness, height, "wall")
            else:
                _add_box(across, mid, thickness, seg_len, height, "wall")

    if ultra_hard:
        template = rng.choices(["urban_grid", "dense_forest", "corridor_maze",
                                "river_valley", "fortress", "mixed"],
                               weights=[0.08, 0.12, 0.12, 0.10, 0.10, 0.48])[0]
    else:
        template = rng.choice(["urban_grid", "dense_forest", "corridor_maze",
                               "river_valley", "fortress", "mixed"])

    # ---- URBAN GRID: rows and columns of buildings ----
    if template in ("urban_grid", "mixed"):
        ox = rng.uniform(500, 1500)
        oy = rng.uniform(500, 1500)
        rows = rng.randint(2, 5) if ultra_hard else rng.randint(2, 4)
        cols = rng.randint(3, 6) if ultra_hard else rng.randint(3, 5)
        spacing = rng.uniform(300, 550) if ultra_hard else rng.uniform(350, 600)
        for row in range(rows):
            for col in range(cols):
                bx = ox + col * spacing + rng.uniform(-80, 80)
                by = oy + row * spacing + rng.uniform(-80, 80)
                if bx < margin or bx > wx - margin or by < margin or by > wy - margin:
                    continue
                bw = rng.uniform(80, 300)
                bh = rng.uniform(80, 300)
                bzt = rng.choice([rng.uniform(30, 60), rng.uniform(100, 195)])
                _add_box(bx, by, bw, bh, bzt)
                # Some buildings get an L-shaped annex to the east or north.
                if rng.random() < (0.45 if ultra_hard else 0.3):
                    arm_dir = rng.choice(["east", "north"])
                    if arm_dir == "east":
                        _add_box(bx + bw / 2 + 40, by, 80, bh * 0.6, bzt * 0.9)
                    else:
                        _add_box(bx, by + bh / 2 + 40, bw * 0.6, 80, bzt * 0.9)

    # ---- DENSE FOREST: many cylindrical trees ----
    if template in ("dense_forest", "mixed"):
        forest_cx = rng.uniform(800, wx - 800)
        forest_cy = rng.uniform(800, wy - 800)
        n_trees = rng.randint(25, 60) if ultra_hard else rng.randint(15, 40)
        for _ in range(n_trees):
            tx = forest_cx + rng.gauss(0, 600)
            ty = forest_cy + rng.gauss(0, 600)
            tx = max(margin, min(wx - margin, tx))
            ty = max(margin, min(wy - margin, ty))
            tree_type = rng.choice(["pine", "oak", "palm", "dead"])
            if tree_type == "pine":
                _add_cyl(tx, ty, rng.uniform(8, 20), rng.uniform(40, 100), "tree_pine")
            elif tree_type == "oak":
                _add_cyl(tx, ty, rng.uniform(15, 40), rng.uniform(25, 60), "tree_oak")
            elif tree_type == "palm":
                _add_cyl(tx, ty, rng.uniform(5, 12), rng.uniform(30, 80), "tree_palm")
            else:
                _add_cyl(tx, ty, rng.uniform(10, 25), rng.uniform(20, 50), "tree_dead")

    # ---- CORRIDOR MAZE: parallel walls, each with one fly-through gap ----
    if template in ("corridor_maze", "mixed"):
        maze_ox = rng.uniform(400, wx / 2)
        maze_oy = rng.uniform(400, wy / 2)
        n_walls = rng.randint(6, 12) if ultra_hard else rng.randint(4, 8)
        wall_dir = rng.choice(["horizontal", "vertical"])
        spacing = rng.uniform(200, 500)
        for w in range(n_walls):
            wl = rng.uniform(400, 1500)
            wt = rng.uniform(40, 80)
            wzt = rng.uniform(100, 195)
            if wall_dir == "horizontal":
                wy_pos = maze_oy + w * spacing
                if wy_pos > wy - margin:
                    continue
                _add_wall_with_gap(maze_ox, wl, wy_pos, wt, wzt, horizontal=True)
            else:
                wx_pos = maze_ox + w * spacing
                if wx_pos > wx - margin:
                    continue
                _add_wall_with_gap(maze_oy, wl, wx_pos, wt, wzt, horizontal=False)

    # ---- RIVER VALLEY: chain of low flat boxes + scattered trees ----
    if template == "river_valley" or (template == "mixed" and rng.random() < (0.7 if ultra_hard else 0.5)):
        river_start_x = rng.uniform(margin, wx / 3)
        river_y = rng.uniform(wy * 0.3, wy * 0.7)
        n_segs = rng.randint(10, 18) if ultra_hard else rng.randint(6, 12)
        for seg in range(n_segs):
            seg_x = river_start_x + seg * rng.uniform(200, 400)
            seg_y = river_y + rng.gauss(0, 150)
            if seg_x > wx - margin:
                break
            seg_y = max(margin, min(wy - margin, seg_y))
            _add_box(seg_x, seg_y, rng.uniform(200, 400), rng.uniform(60, 150),
                     rng.uniform(3, 10), "river")
            for _ in range(rng.randint(2, 6) if ultra_hard else rng.randint(1, 4)):
                bank_offset = rng.choice([-1, 1]) * rng.uniform(100, 300)
                _add_cyl(seg_x + rng.uniform(-100, 100),
                         seg_y + bank_offset,
                         rng.uniform(8, 20), rng.uniform(30, 80), "tree_bank")

    # ---- FORTRESS: walls surrounding a target area ----
    if template == "fortress" or (template == "mixed" and rng.random() < (0.6 if ultra_hard else 0.4)):
        if targets:
            fort_target = rng.choice(targets)
            ftx, fty = fort_target.position.x, fort_target.position.y
            wall_half = rng.uniform(250, 500)
            wall_zt = rng.uniform(120, 190)
            wall_thick = rng.uniform(50, 80)
            _add_box(ftx, fty - wall_half, wall_half * 2, wall_thick, wall_zt, "fortress_wall")
            _add_box(ftx, fty + wall_half, wall_half * 2, wall_thick, wall_zt, "fortress_wall")
            _add_box(ftx - wall_half, fty, wall_thick, wall_half * 2, wall_zt, "fortress_wall")
            _add_box(ftx + wall_half, fty, wall_thick, wall_half * 2, wall_zt, "fortress_wall")

    # ---- Always scatter some light poles and random pillars ----
    n_poles = rng.randint(6, 18) if ultra_hard else rng.randint(3, 10)
    for _ in range(n_poles):
        px = rng.uniform(margin, wx - margin)
        py = rng.uniform(margin, wy - margin)
        _add_cyl(px, py, rng.uniform(2, 6), rng.uniform(30, 80), "light_pole")
    n_pillars = rng.randint(4, 12) if ultra_hard else rng.randint(2, 6)
    for _ in range(n_pillars):
        px = rng.uniform(margin, wx - margin)
        py = rng.uniform(margin, wy - margin)
        _add_cyl(px, py, rng.uniform(15, 50), rng.uniform(80, 195), "pillar")

    # Defensive: drop degenerate (near-zero-height) boxes.
    env.obstacles = [o for o in obstacles if o.max_corner.z > 1.0]
    env.cylinders = cylinders

    # --- Responder units (1 per target, up to 5 in ultra / 4 in normal) ---
    responders = []
    max_resp = 5 if ultra_hard else 4
    for i, tgt in enumerate(targets[:max_resp]):
        unit = ResponderUnit(
            id=f"R{i+1}",
            position=Vec3(tgt.position.x + rng.uniform(-50, 50),
                          tgt.position.y + rng.uniform(-50, 50), 0.0),
            linked_target_id=tgt.id,
            status="stable",
            current_need=rng.choice(["supplies", "medical", "evacuation", "water"]),
            can_update_dropzone=rng.random() < 0.5,
            active=True,
        )
        events = []
        if rng.random() < 0.7:
            events.append(ScheduledEvent(
                step=rng.randint(100, 600),
                event_type="urgency_update",
                payload={"new_urgency": rng.uniform(0.5, 1.0)},
            ))
        if unit.can_update_dropzone and rng.random() < 0.5:
            events.append(ScheduledEvent(
                step=rng.randint(200, 800),
                event_type="dropzone_relocation",
                payload={"dx": rng.uniform(-200, 200), "dy": rng.uniform(-200, 200)},
            ))
        if rng.random() < 0.6:
            intel = rng.choice([
                "blocked_north", "blocked_south", "blocked_east", "blocked_west",
                "safe_north", "safe_south", "safe_east", "safe_west",
                "fire_expanded", "fire_receded",
            ])
            events.append(ScheduledEvent(
                step=rng.randint(50, 500),
                event_type="hazard_intel",
                payload={"intel": intel, "severity": rng.uniform(0.3, 1.0)},
            ))
        unit.scheduled_events = events
        responders.append(unit)
    env.responders = responders
def build_hardcore_world_v2(env: "VarahaEnv") -> None:
    """Ultra-hard world: denser obstacles, more hazards and more targets.

    Single-argument alias for ``build_hardcore_world(env, ultra_hard=True)``,
    matching the ``world_fn(env)`` calling convention used by ``VarahaEnv``.
    """
    return build_hardcore_world(env=env, ultra_hard=True)
| # --------------------------------------------------------------------------- | |
| # Environment | |
| # --------------------------------------------------------------------------- | |
| class VarahaEnv: | |
| """Core wildfire logistics simulation. | |
| Action format (dict):: | |
| { | |
| "ax": float, # desired acceleration x (m/s²) | |
| "ay": float, # desired acceleration y | |
| "az": float, # desired acceleration z | |
| "deliver": bool, # attempt delivery if near a target | |
| "recharge": bool, # attempt recharge if near base | |
| "tool_call": str, # optional: request_intel | battery_forecast | mission_report | |
| } | |
| Returns ``(obs_dict, reward, done, info_dict)`` per OpenAI-gym convention. | |
| """ | |
def __init__(self, config: Optional[VarahaConfig] = None,
             world_fn: Optional[Any] = None) -> None:
    """Create the environment and build its initial world.

    Args:
        config: Tunable parameters; a default ``VarahaConfig()`` is used when
            omitted (or falsy).
        world_fn: Optional callable invoked as ``world_fn(env)`` to populate
            the world (e.g. ``build_hardcore_world``). When given, the world
            is regenerated on every ``reset()``; otherwise the hardcoded demo
            world is built once, here.

    Note:
        ``self.drone`` is only created by ``reset()``, so call ``reset()``
        before ``step()`` or ``get_observation()``.
    """
    self.cfg = config or VarahaConfig()
    self._world_fn = world_fn

    # World contents; filled in by the world builder via _rebuild_world().
    self.base: BaseStation
    self.drone: DroneState  # created by reset()
    self.targets: list[DeliveryTarget] = []
    self.hazards: list[HazardRegion] = []
    self.obstacles: list[ObstacleVolume] = []
    self.cylinders: list[CylindricalObstacle] = []
    self.responders: list[ResponderUnit] = []

    # Per-episode bookkeeping (re-initialised by reset()).
    self.step_count: int = 0
    self.cumulative_reward: float = 0.0
    self.done: bool = False
    self.trace: list[TracePoint] = []
    self._prev_nearest_dist: float = 0.0
    # Target positions snapshotted at reset(); declared here so the attribute
    # always exists, not only after the first reset().
    self._target_base_positions: dict[str, Vec3] = {}

    # Hazard baselines that reset() jitters around (set by _rebuild_world()).
    self._hazard_base_heights: list[float] = []
    self._hazard_base_severities: list[float] = []

    # Long-horizon instruction-mode state.
    self.instructions: list[MissionInstruction] = []
    self._instruction_cursor: int = 0
    self._instruction_violations: int = 0
    self._tool_history: list[str] = []
    self._last_tool_result: dict[str, Any] = {}
    self._instruction_progress_reward: float = 0.0

    self._rebuild_world()
def _rebuild_world(self):
    """(Re)populate the world and snapshot the hazard baselines.

    Uses the user-supplied ``world_fn`` when present, otherwise the hardcoded
    demo world. The base height/severity lists recorded here are what
    ``reset()`` jitters from, so repeated resets never compound the jitter.
    """
    builder = self._world_fn
    if builder is None:
        self._build_demo_world()
    else:
        builder(self)
    hazards = self.hazards
    self._hazard_base_heights = [hz.height for hz in hazards]
    self._hazard_base_severities = [hz.severity for hz in hazards]
| # ------------------------------------------------------------------ | |
| # World setup | |
| # ------------------------------------------------------------------ | |
def _build_demo_world(self) -> None:
    """Populate the hardcoded 5 km × 5 km demo scenario.

    Coordinates are local metres, +x → east, +y → north::

        Base  (250, 250)              recharge radius 80
        T1    (1800, 600, 30)         urgency 0.6, delivery radius 80
        T2    (4100, 2900, 50)        urgency 1.0, delivery radius 120
        T3    (1000, 4200, 20)        urgency 0.8, delivery radius 100
        H1    centre (3800, 2600)     radius 500, severity 0.9, height 70
        H2    centre (900, 3200)      radius 400, severity 0.7, height 55
        O1    x 2200-2800, y 1000-2200, 120 m tall
        O2    x 500-1500,  y 2600-3000,  90 m tall

    - T2 sits inside the fringe of hazard H1 (~424 m from its centre), so a
      brief hazard exposure is required to deliver there.
    - T3 lies beyond obstacle O2 as seen from the base, and near hazard H2.
    - O1 blocks the straight mid-map route from T1 to T2.
    - The drone can fly over obstacles if its altitude exceeds their height.
    - Total route (base → T1 → T2 → T3 → base) ≈ 12 km; battery budget
      ≈ 300 units.

    The demo world has no cylindrical obstacles and no responder units;
    both collections are reset to empty so this builder fully defines the
    world, as ``build_hardcore_world`` does.
    """
    self.base = BaseStation(position=Vec3(250.0, 250.0, 0.0), recharge_radius=80.0)
    self.targets = [
        DeliveryTarget(
            id="T1", position=Vec3(1800.0, 600.0, 30.0),
            urgency=0.6, delivery_radius=80.0,
        ),
        DeliveryTarget(
            id="T2", position=Vec3(4100.0, 2900.0, 50.0),
            urgency=1.0, delivery_radius=120.0,
        ),
        DeliveryTarget(
            id="T3", position=Vec3(1000.0, 4200.0, 20.0),
            urgency=0.8, delivery_radius=100.0,
        ),
    ]
    self.hazards = [
        HazardRegion(
            id="H1", center=Vec3(3800.0, 2600.0, 0.0),
            radius=500.0, severity=0.9,
            height=70.0, growth_rate=0.005,
        ),
        HazardRegion(
            id="H2", center=Vec3(900.0, 3200.0, 0.0),
            radius=400.0, severity=0.7,
            height=55.0, growth_rate=0.008,
        ),
    ]
    self.obstacles = [
        ObstacleVolume(
            id="O1",
            min_corner=Vec3(2200.0, 1000.0, 0.0),
            max_corner=Vec3(2800.0, 2200.0, 120.0),
        ),
        ObstacleVolume(
            id="O2",
            min_corner=Vec3(500.0, 2600.0, 0.0),
            max_corner=Vec3(1500.0, 3000.0, 90.0),
        ),
    ]
    self.cylinders = []
    self.responders = []
| # ------------------------------------------------------------------ | |
| # Core API | |
| # ------------------------------------------------------------------ | |
def reset(self, seed: Optional[int] = None) -> dict[str, Any]:
    """Start a new episode and return its first observation.

    Args:
        seed: If given, re-seeds the module-level ``random`` generator (a
            process-wide side effect) before anything random happens, so
            world regeneration and hazard jitter become reproducible.

    Returns:
        The ``get_observation()`` dict, which is also stored in the first
        trace point (tagged with the ``"reset"`` event).
    """
    if seed is not None:
        random.seed(seed)
    # Only user-supplied generators are re-run; the demo world is static.
    if self._world_fn is not None:
        self._rebuild_world()

    home = self.base.position
    self.drone = DroneState(
        position=Vec3(home.x, home.y, 0.0),
        velocity=Vec3(0.0, 0.0, 0.0),
        battery=self.cfg.battery_capacity,
        carrying_payload=True,
        alive=True,
    )

    for target in self.targets:
        target.delivered = False

    # Jitter each hazard around its baseline. The draw order (height, then
    # severity, hazard by hazard) matters for seeded reproducibility.
    for idx, hazard in enumerate(self.hazards):
        hazard.height = self._hazard_base_heights[idx] * random.uniform(0.85, 1.15)
        jittered = self._hazard_base_severities[idx] + random.uniform(-0.1, 0.1)
        hazard.severity = max(0.3, min(1.0, jittered))
        hazard.reset()

    for unit in self.responders:
        unit.active = True
        unit.status = "stable"
        unit.latest_intel = "none"
        unit.intel_severity = 0.0
        unit.message = ""
        for event in unit.scheduled_events:
            event.fired = False

    # Copy (not alias) each target's starting position for this episode.
    snapshot = {}
    for target in self.targets:
        p = target.position
        snapshot[target.id] = Vec3(p.x, p.y, p.z)
    self._target_base_positions = snapshot

    self._build_instruction_program()
    self._instruction_progress_reward = 0.0
    self._last_tool_result = {}
    self._tool_history = []

    self.step_count = 0
    self.cumulative_reward = 0.0
    self.done = False
    self.trace = []
    self._prev_nearest_dist = self._nearest_target_dist()

    first_obs = self.get_observation()
    start = self.drone.position
    self.trace.append(TracePoint(
        step=0,
        position=Vec3(start.x, start.y, start.z),
        velocity=Vec3(0.0, 0.0, 0.0),
        battery=self.drone.battery,
        reward=0.0,
        cumulative_reward=0.0,
        events=["reset"],
        observation=first_obs,
    ))
    return first_obs
def step(self, action: dict[str, Any]) -> tuple[dict, float, bool, dict]:
    """Advance the simulation by one ``cfg.dt`` timestep.

    Within a step: kinematics → battery drain → hazard growth → responder
    events → collision/hazard checks → optional tool call → optional
    delivery → optional recharge → instruction progress → reward →
    termination → trace.

    Args:
        action: Action dict as described in the class docstring; missing
            keys mean zero acceleration and no deliver/recharge/tool call.

    Returns:
        ``(observation, reward, done, info)``. After the episode has ended,
        every call returns the current observation, ``0.0``, ``True`` and a
        default ``StepInfo`` dict without advancing time.
    """
    if self.done:
        return self.get_observation(), 0.0, True, StepInfo().to_dict()

    cfg = self.cfg
    self.step_count += 1

    # --- kinematics: clamp commanded acceleration, Euler-integrate ---
    accel = Vec3(
        float(action.get("ax", 0.0)),
        float(action.get("ay", 0.0)),
        float(action.get("az", 0.0)),
    ).clamp_magnitude(cfg.max_acceleration)
    self.drone.velocity = (
        self.drone.velocity + accel.scale(cfg.dt)
    ).clamp_magnitude(cfg.max_speed)
    before = Vec3(self.drone.position.x, self.drone.position.y, self.drone.position.z)
    self.drone.position = self.drone.position + self.drone.velocity.scale(cfg.dt)
    # Keep the drone inside the world box (velocity is left as-is).
    pos = self.drone.position
    pos.x = max(0.0, min(cfg.world_x, pos.x))
    pos.y = max(0.0, min(cfg.world_y, pos.y))
    pos.z = max(0.0, min(cfg.world_z, pos.z))
    travelled = before.distance_to(pos)
    climb = abs(pos.z - before.z)

    # --- battery ---
    self.drone.battery -= self._compute_battery_drain(travelled, climb)

    # --- dynamic world ---
    for hazard in self.hazards:
        hazard.tick()
    self._tick_responders()

    # --- interactions with the updated world ---
    collision = self._check_collisions()
    in_hazard, hazard_sev = self._check_hazards()

    tool_call = ""
    tool_result: dict[str, Any] = {}
    requested = action.get("tool_call")
    if requested is not None:
        requested = str(requested).strip()
        if requested:
            tool_call, tool_result = self._execute_tool_call(requested)

    cursor_before = self._instruction_cursor
    delivered_ids: list[str] = (
        self._deliver_targets() if action.get("deliver", False) else []
    )

    # Base proximity only considers horizontal distance.
    reached_base = _hdist(self.drone.position, self.base.position) <= self.base.recharge_radius
    if action.get("recharge", False) and reached_base:
        self.drone.battery = min(
            cfg.battery_capacity,
            self.drone.battery + cfg.recharge_rate,
        )

    self._update_instruction_progress(
        delivered_ids=delivered_ids,
        reached_base=reached_base,
        tool_call=tool_call,
    )
    newly_completed = max(0, self._instruction_cursor - cursor_before)
    if self._all_delivered():
        self.drone.carrying_payload = False

    # --- reward ---
    info = StepInfo(
        collision=collision,
        delivered_target_ids=delivered_ids,
        in_hazard=in_hazard,
        hazard_severity=hazard_sev,
        reached_base=reached_base,
        distance_traveled=travelled,
        tool_call=tool_call,
        tool_result=tool_result,
        instruction_completed=self._instruction_cursor,
        instruction_total=len(self.instructions),
        instruction_violations=self._instruction_violations,
    )
    reward, breakdown = self._compute_reward(info)
    info.reward_breakdown = breakdown
    self.cumulative_reward += reward

    # --- termination (first matching cause wins) ---
    if collision:
        self.drone.alive = False
        self.done = True
    elif self.drone.battery <= 0.0:
        self.drone.battery = 0.0
        self.drone.alive = False
        self.done = True
    elif self._is_success() or self.step_count >= cfg.max_episode_steps:
        self.done = True

    # --- trace ---
    events: list[str] = [f"delivered_{tid}" for tid in delivered_ids]
    if collision:
        events.append("collision")
    if in_hazard:
        events.append(f"hazard_{hazard_sev:.2f}")
    if self.drone.battery <= 0.0 and not collision:
        events.append("battery_dead")
    if self._is_success():
        events.append("success")
    if tool_call:
        events.append(f"tool_{tool_call}")
    if newly_completed > 0:
        events.append(f"instruction+{newly_completed}")

    obs = self.get_observation()
    here = self.drone.position
    vel = self.drone.velocity
    self.trace.append(TracePoint(
        step=self.step_count,
        position=Vec3(here.x, here.y, here.z),
        velocity=Vec3(vel.x, vel.y, vel.z),
        battery=self.drone.battery,
        reward=reward,
        cumulative_reward=self.cumulative_reward,
        events=events,
        observation=obs,
    ))
    return obs, reward, self.done, info.to_dict()
| # ------------------------------------------------------------------ | |
| # Observation / render | |
| # ------------------------------------------------------------------ | |
def get_observation(self) -> dict[str, Any]:
    """Build the compact, RL-friendly observation for the current state.

    Targets, hazards, obstacles and responders are positioned relative to
    the drone. Box and cylinder obstacles share one list sorted by
    horizontal distance from the drone (the sort is stable, so boxes come
    before cylinders on ties). Inactive responders are omitted.
    """
    here = self.drone.position

    def _box_entry(box):
        centre = box.center
        half = box.half_size
        return {
            "type": "box",
            "relative_position": (centre - here).to_dict(),
            "height": box.height,
            "size_x": half.x * 2,
            "size_y": half.y * 2,
            "distance": here.horizontal_distance_to(centre),
            "kind": box.kind,
        }

    def _cylinder_entry(cyl):
        return {
            "type": "cylinder",
            "relative_position": (cyl.center - here).to_dict(),
            "height": cyl.height,
            "size_x": cyl.radius * 2,
            "size_y": cyl.radius * 2,
            "distance": here.horizontal_distance_to(cyl.center),
            "kind": cyl.kind,
        }

    def _responder_entry(unit):
        direction = unit.intel_direction()
        return {
            "id": unit.id,
            "relative_position": (unit.position - here).to_dict(),
            "linked_target_id": unit.linked_target_id,
            "status": unit.status,
            "status_code": unit.status_code(),
            "latest_intel": unit.latest_intel,
            "intel_direction": {"x": direction[0], "y": direction[1]},
            "intel_severity": unit.intel_severity,
        }

    targets_obs = [
        {
            "id": tgt.id,
            "relative_position": (tgt.position - here).to_dict(),
            "urgency": tgt.urgency,
            "delivered": tgt.delivered,
        }
        for tgt in self.targets
    ]
    hazards_obs = [
        {
            "id": hz.id,
            "relative_position": (hz.center - here).to_dict(),
            "current_height": hz._current_height,
            "severity": hz.severity,
        }
        for hz in self.hazards
    ]
    obstacles_obs = [_box_entry(box) for box in self.obstacles]
    obstacles_obs += [_cylinder_entry(cyl) for cyl in self.cylinders]
    obstacles_obs.sort(key=lambda entry: entry["distance"])
    responders_obs = [_responder_entry(unit) for unit in self.responders if unit.active]
    mission_obs = self._instruction_snapshot()

    return {
        "drone_position": here.to_dict(),
        "drone_velocity": self.drone.velocity.to_dict(),
        "battery": round(self.drone.battery, 4),
        "carrying_payload": self.drone.carrying_payload,
        "alive": self.drone.alive,
        "targets": targets_obs,
        "hazards": hazards_obs,
        "obstacles": obstacles_obs,
        "responders": responders_obs,
        "mission": mission_obs,
        "last_tool_result": self._last_tool_result,
        "step": self.step_count,
        "max_steps": self.cfg.max_episode_steps,
    }
def render_state(self) -> dict[str, Any]:
    """Return a rich state dict for future Cesium / frontend rendering.

    Every world entity is serialised through its own ``to_dict()``; the
    full instruction program, tool-call history and episode counters are
    included alongside.
    """
    state: dict[str, Any] = {
        "base_station": self.base.to_dict(),
        "drone": self.drone.to_dict(),
    }
    for key, entities in (
        ("targets", self.targets),
        ("hazards", self.hazards),
        ("obstacles", self.obstacles),
        ("cylinders", self.cylinders),
        ("responders", self.responders),
    ):
        state[key] = [entity.to_dict() for entity in entities]
    state.update(
        mission=self._instruction_snapshot(include_full=True),
        tool_history=list(self._tool_history),
        step=self.step_count,
        max_steps=self.cfg.max_episode_steps,
        cumulative_reward=round(self.cumulative_reward, 4),
        done=self.done,
    )
    return state
def get_trace(self) -> dict[str, Any]:
    """Return the full episode record for replay / visualisation.

    The result has three parts: ``world`` (bounds, every entity's
    ``to_dict()`` and the full instruction program), ``trace`` (one entry
    per recorded ``TracePoint``, beginning with the reset point) and
    ``summary`` (episode totals and outcome flags).
    """
    cfg = self.cfg
    world: dict[str, Any] = {
        "bounds": {"x": cfg.world_x, "y": cfg.world_y, "z": cfg.world_z},
        "base_station": self.base.to_dict(),
    }
    for key, entities in (
        ("targets", self.targets),
        ("hazards", self.hazards),
        ("obstacles", self.obstacles),
        ("cylinders", self.cylinders),
        ("responders", self.responders),
    ):
        world[key] = [entity.to_dict() for entity in entities]
    world["mission"] = self._instruction_snapshot(include_full=True)

    frames = [point.to_dict() for point in self.trace]

    summary = {
        "total_steps": self.step_count,
        "cumulative_reward": round(self.cumulative_reward, 4),
        "delivered": [tgt.id for tgt in self.targets if tgt.delivered],
        "alive": self.drone.alive,
        "final_battery": round(self.drone.battery, 4),
        "success": self._is_success(),
        "instruction_completed": self._instruction_cursor,
        "instruction_total": len(self.instructions),
        "instruction_violations": self._instruction_violations,
        "tool_calls": list(self._tool_history),
    }
    return {"world": world, "trace": frames, "summary": summary}
| # ------------------------------------------------------------------ | |
| # Long-horizon instruction mode | |
| # ------------------------------------------------------------------ | |
def _build_instruction_program(self) -> None:
    """(Re)create the long-horizon instruction list for a new episode.

    Always resets the cursor and violation count. When instruction mode is
    off, or there are no targets, the program is left empty.

    Otherwise targets are ordered by descending urgency (ties broken by id).
    The program repeats cycles of, for each target in that order, a
    ``deliver_target`` step followed by a ``tool_call`` step
    (``request_intel`` on even cycles, ``battery_forecast`` on odd ones).
    It ends with a single ``return_base`` step. The total length is
    ``max(L, 2*n + 1)`` for ``n`` targets, where ``L`` is
    ``cfg.instruction_count`` if positive and ``3*n + 1`` otherwise.
    """
    self.instructions = []
    self._instruction_cursor = 0
    self._instruction_violations = 0
    if not (self.cfg.instruction_mode and self.targets):
        return

    ordered = sorted(self.targets, key=lambda t: (-t.urgency, t.id))
    n_targets = len(ordered)
    requested = self.cfg.instruction_count
    total_len = requested if requested > 0 else n_targets * 3 + 1
    total_len = max(total_len, n_targets * 2 + 1)
    body_len = max(total_len - 1, 1)  # everything except the final return_base

    def _cycle_plan():
        # Endless (cycle, target, tool) stream; tool is None for deliveries.
        cycle = 0
        while True:
            tool = "request_intel" if cycle % 2 == 0 else "battery_forecast"
            for tgt in ordered:
                yield cycle, tgt, None
                yield cycle, tgt, tool
            cycle += 1

    program: list[MissionInstruction] = []
    # zip() stops on the range first, so the endless plan is truncated cleanly.
    for number, (cycle, tgt, tool) in zip(range(1, body_len + 1), _cycle_plan()):
        if tool is None:
            program.append(MissionInstruction(
                id=f"I{number}",
                kind="deliver_target",
                description=f"Cycle {cycle + 1}: deliver to {tgt.id} in order.",
                target_id=tgt.id,
            ))
        else:
            program.append(MissionInstruction(
                id=f"I{number}",
                kind="tool_call",
                description=f"Call {tool} after servicing {tgt.id}.",
                target_id=tgt.id,
                tool_name=tool,
            ))
    program.append(MissionInstruction(
        id=f"I{body_len + 1}",
        kind="return_base",
        description="Return to base only after all deliveries are completed.",
    ))
    self.instructions = program
def _current_instruction(self) -> Optional[MissionInstruction]:
    """Return the instruction at the cursor, or ``None`` once all are done."""
    idx = self._instruction_cursor
    return self.instructions[idx] if idx < len(self.instructions) else None
def _instruction_snapshot(self, include_full: bool = False) -> dict[str, Any]:
    """Summarise instruction-mode progress as a plain dict.

    Keys:
      * ``enabled``: ``cfg.instruction_mode``.
      * ``total``, ``completed``, ``remaining``: instruction counts, with
        ``completed`` capped at ``total``.
      * ``progress``: ``completed / total``, or 1.0 when there are no
        instructions.
      * ``violations``: the violation counter.
      * ``next_instruction``: ``to_dict()`` of the current instruction, or
        ``None``.

    When ``include_full`` is true, ``instructions`` additionally holds every
    instruction serialised with ``to_dict()``.
    """
    n_total = len(self.instructions)
    n_done = min(self._instruction_cursor, n_total)
    upcoming = self._current_instruction()

    snapshot: dict[str, Any] = {}
    snapshot["enabled"] = self.cfg.instruction_mode
    snapshot["total"] = n_total
    snapshot["completed"] = n_done
    snapshot["remaining"] = max(n_total - n_done, 0)
    snapshot["progress"] = (n_done / n_total) if n_total > 0 else 1.0
    snapshot["violations"] = self._instruction_violations
    snapshot["next_instruction"] = upcoming.to_dict() if upcoming else None
    if include_full:
        snapshot["instructions"] = [entry.to_dict() for entry in self.instructions]
    return snapshot
def _complete_current_instruction(self) -> None:
    """Mark the current instruction done, advance the cursor and bank its reward.

    Does nothing once every instruction has been consumed. The banked amount
    (``cfg.instruction_completion_reward``) accumulates in
    ``self._instruction_progress_reward``; the reward functions read it and
    reset it to 0.
    """
    current = self._current_instruction()
    if current is not None:
        current.completed = True
        self._instruction_cursor += 1
        self._instruction_progress_reward += self.cfg.instruction_completion_reward
def _record_instruction_violation(self) -> None:
    """Count one instruction violation and flag the current instruction, if any."""
    self._instruction_violations += 1
    current = self._current_instruction()
    if current is None:
        return
    current.violated = True
def _tool_matches_instruction(self, tool_call: str, inst: MissionInstruction) -> bool:
    """Return True when ``tool_call`` satisfies the ``tool_call`` instruction ``inst``.

    ``tool_call`` has the form ``"name"`` or ``"name:arg"``. The tool name
    must equal ``inst.tool_name``. If both an argument and
    ``inst.target_id`` are present, they must name the same target. An
    omitted argument matches any target.

    Comparisons ignore case and surrounding whitespace. Calls normalised by
    ``_execute_tool_call`` are lower-cased, and that method already looks
    targets up case-insensitively. An exact comparison would therefore
    reject a correct ``request_intel:t1`` for target ``T1``.
    """
    if inst.tool_name is None:
        return False
    base, _, arg = tool_call.partition(":")
    if base.strip().lower() != inst.tool_name.strip().lower():
        return False
    arg = arg.strip()
    if inst.target_id and arg and arg.lower() != inst.target_id.lower():
        return False
    return True
def _update_instruction_progress(
    self,
    delivered_ids: list[str],
    reached_base: bool,
    tool_call: str,
) -> None:
    """Advance the instruction cursor using the events of one step.

    Args:
        delivered_ids: ids of targets delivered this step.
        reached_base: whether the drone reached the base this step.
        tool_call: the tool call issued this step (empty string if none).

    Does nothing unless instruction mode is on and a program exists. First,
    every delivered id that differs from the current ``deliver_target``
    instruction's target is counted as a violation. Then:

    * consecutive ``deliver_target`` instructions whose target is in
      ``delivered_ids`` are completed;
    * at most one ``tool_call`` instruction is consumed, and only if a call
      was made: it completes on a match and counts as a violation otherwise;
    * ``return_base`` completes only when the base was reached with every
      target delivered.
    """
    if not self.cfg.instruction_mode or not self.instructions:
        return

    head = self._current_instruction()
    if head and head.kind == "deliver_target":
        for _ in (tid for tid in delivered_ids if tid != head.target_id):
            self._record_instruction_violation()

    while True:
        current = self._current_instruction()
        if current is None:
            return
        kind = current.kind
        if kind == "deliver_target" and current.target_id in delivered_ids:
            self._complete_current_instruction()
            continue
        if kind == "tool_call" and tool_call:
            if self._tool_matches_instruction(tool_call, current):
                self._complete_current_instruction()
            else:
                self._record_instruction_violation()
        elif kind == "return_base" and reached_base and self._all_delivered():
            self._complete_current_instruction()
        # Anything else (unmet delivery, no tool call, unknown kind) waits.
        return
def _execute_tool_call(self, tool_call: str) -> tuple[str, dict[str, Any]]:
    """Execute one mission tool call and return ``(normalized_call, result)``.

    ``tool_call`` is ``"name"`` or ``"name:arg"``. The whole call is stripped
    and lower-cased, then the name and argument are stripped individually.
    The normalised form is ``"name"`` or ``"name:arg"``. An empty call
    returns ``("", {})`` and is not recorded. Every other call, including
    rejected ones, is appended to ``self._tool_history``, and its result is
    stored in ``self._last_tool_result``.

    Implemented tools:
      * ``request_intel[:target_id]``: latest intel from the active
        responder linked to ``target_id`` (case-insensitive). Falls back to
        the first active responder, and the result's ``target_id`` shows
        which one answered.
      * ``battery_forecast``: current battery and a naive range estimate,
        ``battery / cfg.drain_per_meter``, which ignores idle and climb
        drain.
      * ``mission_report``: delivered and remaining target ids, instruction
        progress and the violation count.

    Names missing from ``cfg.available_tools`` yield
    ``{"ok": False, "error": "unsupported_tool:<name>"}``. Names that are
    configured but have no handler here yield
    ``{"ok": False, "error": "unimplemented_tool:<name>"}``.
    """
    raw = (tool_call or "").strip().lower()
    if not raw:
        return "", {}
    name_part, _, arg_part = raw.partition(":")
    # Strip each part so "request_intel : T1" resolves like "request_intel:t1".
    tool_name = name_part.strip()
    arg = arg_part.strip()
    normalized_call = f"{tool_name}:{arg}" if arg else tool_name

    result: dict[str, Any]
    if tool_name not in self.cfg.available_tools:
        result = {"ok": False, "error": f"unsupported_tool:{tool_name}"}
    elif tool_name == "request_intel":
        responder = None
        if arg:
            responder = next(
                (
                    r
                    for r in self.responders
                    if r.active and (r.linked_target_id or "").lower() == arg
                ),
                None,
            )
        if responder is None:
            responder = next((r for r in self.responders if r.active), None)
        if responder is None:
            result = {"ok": True, "intel": "none", "message": "no_active_responders"}
        else:
            result = {
                "ok": True,
                "intel": responder.latest_intel,
                "intel_severity": round(responder.intel_severity, 3),
                "responder_id": responder.id,
                "target_id": responder.linked_target_id,
                "message": responder.message,
            }
    elif tool_name == "battery_forecast":
        burn = max(self.cfg.drain_per_meter, 1e-6)  # guard against division by zero
        est_range = self.drone.battery / burn
        result = {
            "ok": True,
            "battery": round(self.drone.battery, 3),
            "estimated_range_m": round(est_range, 1),
        }
    elif tool_name == "mission_report":
        result = {
            "ok": True,
            "delivered": [t.id for t in self.targets if t.delivered],
            "remaining": [t.id for t in self.targets if not t.delivered],
            "instruction_progress": round(self._instruction_snapshot()["progress"], 3),
            "violations": self._instruction_violations,
        }
    else:
        # Configured but not implemented: report it instead of silently
        # answering with a different tool's output.
        result = {"ok": False, "error": f"unimplemented_tool:{tool_name}"}

    self._tool_history.append(normalized_call)
    self._last_tool_result = result
    return normalized_call, result
| # ------------------------------------------------------------------ | |
| # Coordinate conversion | |
| # ------------------------------------------------------------------ | |
def local_to_latlon(self, vec: Vec3) -> tuple[float, float, float]:
    """Map a local ``(x, y, z)`` position in metres to ``(lat, lon, alt)``.

    Flat-earth approximation around ``(cfg.origin_lat, cfg.origin_lon)``:

    * ``y`` is converted at a fixed 111 320 m per degree of latitude.
    * ``x`` is converted at that figure scaled by ``cos(origin_lat)`` per
      degree of longitude.
    * ``z`` is passed through as altitude.

    Latitude and longitude are rounded to 7 decimals, altitude to 2.
    Adequate for areas of a few tens of km, such as Cesium plotting.
    """
    origin_lat = self.cfg.origin_lat
    origin_lon = self.cfg.origin_lon
    deg_lat_m = 111_320.0
    deg_lon_m = 111_320.0 * math.cos(math.radians(origin_lat))
    return (
        round(origin_lat + vec.y / deg_lat_m, 7),
        round(origin_lon + vec.x / deg_lon_m, 7),
        round(vec.z, 2),
    )
| # ------------------------------------------------------------------ | |
| # Internal helpers | |
| # ------------------------------------------------------------------ | |
def _compute_battery_drain(self, dist: float, elevation_change: float) -> float:
    """Battery units consumed by one step.

    The drain is the sum of three terms:

    * ``dist * cfg.drain_per_meter``;
    * ``elevation_change * cfg.drain_elevation_factor``;
    * the constant ``cfg.drain_idle_per_step``.

    NOTE(review): ``elevation_change`` is used as-is, so a negative value
    lowers the drain. Presumably callers pass a non-negative climb; confirm.
    """
    cfg = self.cfg
    travel = dist * cfg.drain_per_meter
    climb = elevation_change * cfg.drain_elevation_factor
    return travel + climb + cfg.drain_idle_per_step
def _check_collisions(self) -> bool:
    """Return True if the drone's position is inside any obstacle.

    ``self.obstacles`` is checked before ``self.cylinders``, and the check
    stops at the first hit.
    """
    here = self.drone.position
    return any(box.contains(here) for box in self.obstacles) or any(
        cyl.contains(here) for cyl in self.cylinders
    )
def _check_hazards(self) -> tuple[bool, float]:
    """Return ``(in_hazard, max_severity)`` for the drone's current position.

    ``in_hazard`` is True when at least one hazard reports a positive
    ``danger_factor``. ``max_severity`` is the largest such factor, or 0.0
    when none is positive.
    """
    here = self.drone.position
    positive = [
        factor
        for factor in (hazard.danger_factor(here) for hazard in self.hazards)
        if factor > 0.0
    ]
    return bool(positive), max(positive, default=0.0)
def _deliver_targets(self) -> list[str]:
    """Mark targets delivered when the drone is inside their drop cylinder.

    An undelivered target counts as delivered when both hold:

    * the drone is within ``delivery_radius`` of it horizontally;
    * the drone's altitude is between 10 m below and
      ``2 * delivery_radius`` above the target.

    Returns the ids delivered by this call, in ``self.targets`` order.
    """
    pos = self.drone.position
    newly_delivered: list[str] = []
    for target in (t for t in self.targets if not t.delivered):
        dx = pos.x - target.position.x
        dy = pos.y - target.position.y
        within_radius = (dx * dx + dy * dy) ** 0.5 <= target.delivery_radius
        rise = pos.z - target.position.z
        within_band = -10.0 <= rise <= target.delivery_radius * 2
        if within_radius and within_band:
            target.delivered = True
            newly_delivered.append(target.id)
    return newly_delivered
def _all_delivered(self) -> bool:
    """True when no target is still awaiting delivery (vacuously True with no targets)."""
    return not any(not target.delivered for target in self.targets)
def _is_success(self) -> bool:
    """Mission success: every target is delivered and the drone is within the
    base's ``recharge_radius``, measured horizontally."""
    drone_pos = self.drone.position
    base_pos = self.base.position
    offset_sq = (drone_pos.x - base_pos.x) ** 2 + (drone_pos.y - base_pos.y) ** 2
    horizontal = offset_sq ** 0.5
    return self._all_delivered() and horizontal <= self.base.recharge_radius
def _nearest_target_dist(self) -> float:
    """Horizontal distance from the drone to the closest undelivered target.

    Once every target is delivered, returns the horizontal distance to the
    base instead.
    """
    px = self.drone.position.x
    py = self.drone.position.y

    def horizontal(point) -> float:
        return ((px - point.x) ** 2 + (py - point.y) ** 2) ** 0.5

    pending = [horizontal(t.position) for t in self.targets if not t.delivered]
    return min(pending) if pending else horizontal(self.base.position)
def _tick_responders(self) -> None:
    """Fire the responder events scheduled for the current step.

    Inactive responders are skipped entirely. An event is due when it has
    not fired yet and its ``step`` equals ``self.step_count``. It is marked
    fired before being handled, so it runs at most once. Handled types:

    * ``urgency_update``: for the linked, undelivered target, set
      ``urgency`` to ``payload["new_urgency"]`` (unchanged by default),
      clamped to [0.1, 1.0]. Then set the responder status to ``critical``
      (>= 0.9), ``urgent`` (>= 0.6) or ``stable``.
    * ``dropzone_relocation``: if the responder ``can_update_dropzone`` and
      the linked target is undelivered:
        - shift the target by ``payload["dx"]`` / ``payload["dy"]``
          (default 0), clamped to stay 50 m inside the world bounds;
        - move the responder onto the target;
        - re-baseline ``self._prev_nearest_dist`` so distance shaping does
          not see a jump.
    * ``hazard_intel``: store ``payload["intel"]`` (default ``"none"``) and
      ``payload["severity"]`` (default 0.5) on the responder.

    Any other event type is marked fired and otherwise ignored.
    """
    for responder in self.responders:
        if not responder.active:
            continue
        # Filtered lazily so each event's ``fired`` flag is checked right
        # before that event is handled.
        due_events = (
            ev
            for ev in responder.scheduled_events
            if not ev.fired and ev.step == self.step_count
        )
        for ev in due_events:
            ev.fired = True
            kind = ev.event_type

            if kind == "urgency_update":
                target = self._find_target(responder.linked_target_id)
                if not target or target.delivered:
                    continue
                requested = ev.payload.get("new_urgency", target.urgency)
                target.urgency = max(0.1, min(1.0, requested))
                if target.urgency >= 0.9:
                    responder.status = "critical"
                elif target.urgency >= 0.6:
                    responder.status = "urgent"
                else:
                    responder.status = "stable"
                responder.message = f"urgency->{target.urgency:.1f}"

            elif kind == "dropzone_relocation":
                target = self._find_target(responder.linked_target_id)
                if not target or target.delivered or not responder.can_update_dropzone:
                    continue
                shift_x = ev.payload.get("dx", 0.0)
                shift_y = ev.payload.get("dy", 0.0)
                spot = target.position
                spot.x = max(50, min(self.cfg.world_x - 50, spot.x + shift_x))
                spot.y = max(50, min(self.cfg.world_y - 50, spot.y + shift_y))
                responder.position = Vec3(spot.x, spot.y, 0.0)
                responder.message = f"dropzone moved ({shift_x:+.0f},{shift_y:+.0f})"
                # The reference target moved: reset the shaping baseline.
                self._prev_nearest_dist = self._nearest_target_dist()

            elif kind == "hazard_intel":
                responder.latest_intel = ev.payload.get("intel", "none")
                responder.intel_severity = ev.payload.get("severity", 0.5)
                responder.message = f"intel: {responder.latest_intel}"
def _find_target(self, tid: str) -> Optional[DeliveryTarget]:
    """Return the first target whose ``id`` equals ``tid``, or ``None`` if absent."""
    return next((target for target in self.targets if target.id == tid), None)
def _obstacle_proximity_penalty(self) -> float:
    """Quadratic penalty for flying close to an obstacle surface.

    Let ``d`` be the smallest ``nearest_surface_dist`` over
    ``self.obstacles`` and ``self.cylinders``, and let
    ``R = cfg.obstacle_proximity_radius``.

    * Returns 0.0 when ``d >= R``, including when there are no obstacles.
    * Otherwise returns
      ``cfg.obstacle_proximity_penalty * (1 - d / R) ** 2``.

    NOTE(review): if ``nearest_surface_dist`` can be negative inside an
    obstacle, the penalty exceeds ``cfg.obstacle_proximity_penalty``.
    Confirm this is intended.
    """
    here = self.drone.position
    closest = float("inf")
    for shape in (*self.obstacles, *self.cylinders):
        gap = shape.nearest_surface_dist(here)
        if gap < closest:
            closest = gap
    radius = self.cfg.obstacle_proximity_radius
    if closest >= radius:
        return 0.0
    closeness = 1.0 - closest / radius
    return self.cfg.obstacle_proximity_penalty * closeness * closeness
def _compute_reward(self, info: StepInfo) -> tuple[float, dict[str, float]]:
    """Dense per-step reward together with its named breakdown.

    Delegates to ``_compute_sparse_instruction_reward`` when both
    ``cfg.instruction_mode`` and ``cfg.sparse_reward_mode`` are set.
    Otherwise these terms are summed in order:

    * ``step_penalty``: ``-cfg.step_penalty``.
    * ``battery_cost``: ``-(distance_traveled * drain_per_meter * battery_cost_factor)``.
    * ``instruction_progress``: the banked instruction reward, when positive.
      It is then reset to 0.
    * ``delivery_<id>``: ``delivery_reward * (1 + urgency)`` per newly
      delivered target.
    * ``progress_bonus``: ``50 * fraction of targets delivered``, on
      delivery steps only.
    * ``collision``: ``-collision_penalty``.
    * ``hazard``: ``-hazard_penalty * hazard_severity``.
    * ``return_bonus``: ``return_bonus``, when at base with everything
      delivered.
    * ``distance_shaping``: decrease in distance to the nearest pending
      target (or base) times ``distance_shaping_factor``. The factor is
      doubled once everything is delivered. On delivery steps the term is
      recorded as 0.0 and not added.
    * ``obstacle_proximity``: minus the graduated proximity penalty.
    * ``failure``: ``-failure_penalty`` when the battery is empty without a
      collision.

    Returns ``(total, breakdown)``, where ``breakdown["total"]`` repeats the
    total. Also updates ``self._prev_nearest_dist``.
    """
    if self.cfg.instruction_mode and self.cfg.sparse_reward_mode:
        return self._compute_sparse_instruction_reward(info)

    cfg = self.cfg
    breakdown: dict[str, float] = {}
    running = 0.0

    def add(key: str, amount: float) -> None:
        nonlocal running
        breakdown[key] = amount
        running += amount

    add("step_penalty", -cfg.step_penalty)
    add(
        "battery_cost",
        -(info.distance_traveled * cfg.drain_per_meter * cfg.battery_cost_factor),
    )

    banked = self._instruction_progress_reward
    if banked > 0.0:
        add("instruction_progress", banked)
        self._instruction_progress_reward = 0.0

    delivered_now = info.delivered_target_ids
    for tid in delivered_now:
        target = next(t for t in self.targets if t.id == tid)
        add(f"delivery_{tid}", cfg.delivery_reward * (1.0 + target.urgency))
    if delivered_now:
        still_pending = sum(1 for t in self.targets if not t.delivered)
        add("progress_bonus", 50.0 * (1.0 - still_pending / len(self.targets)))

    if info.collision:
        add("collision", -cfg.collision_penalty)
    if info.in_hazard:
        add("hazard", -cfg.hazard_penalty * info.hazard_severity)
    if info.reached_base and self._all_delivered():
        add("return_bonus", cfg.return_bonus)

    # On delivery steps the nearest-target reference jumps to another
    # target, so shaping is skipped (recorded as 0.0, not summed).
    curr_dist = self._nearest_target_dist()
    if delivered_now:
        breakdown["distance_shaping"] = 0.0
    else:
        gain = cfg.distance_shaping_factor
        if self._all_delivered():
            gain *= 2.0  # stronger pull once heading home
        add("distance_shaping", (self._prev_nearest_dist - curr_dist) * gain)
    self._prev_nearest_dist = curr_dist

    proximity = self._obstacle_proximity_penalty()
    if proximity > 0:
        add("obstacle_proximity", -proximity)

    # Battery depletion; a collision was already penalised above.
    if self.drone.battery <= 0.0 and not info.collision:
        add("failure", -cfg.failure_penalty)

    breakdown["total"] = running
    return running, breakdown
def _compute_sparse_instruction_reward(self, info: StepInfo) -> tuple[float, dict[str, float]]:
    """Sparse reward for instruction mode, with its named breakdown.

    Every step adds a small ``step_penalty``
    (``-0.25 * cfg.step_penalty``). Any banked instruction reward is added
    and reset to 0. While in a hazard, ``hazard`` adds
    ``-0.2 * cfg.hazard_penalty * hazard_severity``.

    The step is terminal on collision, an empty battery, mission success,
    or once ``step_count`` reaches ``cfg.max_episode_steps``. Terminal steps
    also add, in order:

    * ``terminal_progress``: the progress bonus scaled by the fraction of
      instructions completed (1.0 when there are none);
    * ``terminal_success`` (success bonus) or ``terminal_failure``
      (``-failure_penalty``);
    * ``unfinished_penalty``: per instruction left;
    * ``instruction_violations``: per recorded violation;
    * ``collision``: ``-collision_penalty``.

    Returns ``(total, breakdown)``, where ``breakdown["total"]`` repeats the
    total.
    """
    cfg = self.cfg
    parts: dict[str, float] = {}
    acc = 0.0

    def book(label: str, value: float) -> None:
        nonlocal acc
        parts[label] = value
        acc += value

    # Shaping is deliberately kept small in sparse mode.
    book("step_penalty", -(cfg.step_penalty * 0.25))
    pending_bonus = self._instruction_progress_reward
    if pending_bonus > 0.0:
        book("instruction_progress", pending_bonus)
        self._instruction_progress_reward = 0.0
    if info.in_hazard:
        book("hazard", -(cfg.hazard_penalty * 0.2 * info.hazard_severity))

    episode_over = (
        info.collision
        or self.drone.battery <= 0.0
        or self._is_success()
        or self.step_count >= cfg.max_episode_steps
    )
    if not episode_over:
        parts["total"] = acc
        return acc, parts

    n_instr = len(self.instructions)
    n_done = self._instruction_cursor
    fraction = (n_done / n_instr) if n_instr > 0 else 1.0
    book("terminal_progress", cfg.instruction_terminal_progress_bonus * fraction)
    if self._is_success():
        book("terminal_success", cfg.instruction_terminal_success_bonus)
    else:
        book("terminal_failure", -cfg.failure_penalty)
    unfinished = max(n_instr - n_done, 0)
    if unfinished > 0:
        book("unfinished_penalty", -unfinished * cfg.instruction_unfinished_penalty)
    if self._instruction_violations > 0:
        book(
            "instruction_violations",
            -self._instruction_violations * cfg.instruction_violation_penalty,
        )
    if info.collision:
        book("collision", -cfg.collision_penalty)

    parts["total"] = acc
    return acc, parts