""" Drone Forest Navigation Environment. A quadrotor drone navigates through a forest of columns (trees) to reach a target. The RL policy commands velocity (forward/left/up/turn) while a built-in PD flight controller handles low-level motor mixing. """ import base64 import io import os import sys from pathlib import Path from typing import Any, Dict, List, Optional from uuid import uuid4 # Configure MuJoCo rendering backend before importing mujoco if "MUJOCO_GL" not in os.environ and sys.platform != "darwin": os.environ.setdefault("MUJOCO_GL", "egl") import numpy as np try: from openenv.core.env_server.interfaces import Environment from ..models import DMControlAction, DMControlObservation, DMControlState except ImportError: from openenv.core.env_server.interfaces import Environment try: import sys as _sys from pathlib import Path as _Path _parent = str(_Path(__file__).parent.parent) if _parent not in _sys.path: _sys.path.insert(0, _parent) from models import DMControlAction, DMControlObservation, DMControlState except ImportError: try: from dm_control_env.models import ( DMControlAction, DMControlObservation, DMControlState, ) except ImportError: from envs.dm_control_env.models import ( DMControlAction, DMControlObservation, DMControlState, ) # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- NUM_TREES = 25 ARENA_HALF = 10.0 # arena is 20x20 m MAX_ALTITUDE = 8.0 MIN_ALTITUDE = 0.1 TARGET_RADIUS = 0.5 # success if within this distance TREE_MIN_SPACING = 1.5 # min distance between tree centres SPAWN_CLEAR_RADIUS = 2.0 # keep trees away from spawn TARGET_MIN_DIST = 5.0 # target at least this far from spawn MAX_STEPS = 1000 PHYSICS_DT = 0.002 CONTROL_DT = 0.02 # 50 Hz control # Velocity limits MAX_XY_VEL = 3.0 # m/s MAX_Z_VEL = 2.0 # m/s MAX_YAW_RATE = 2.0 # rad/s # Flight-controller PD gains KP_VEL = 4.0 KD_VEL = 1.5 KP_ATT = 8.0 KD_ATT = 2.0 # Drone physical parameters DRONE_MASS = 0.48 # total mass (body 0.4 + arms 0.08) close to XML GRAVITY = 9.81 HOVER_THRUST = DRONE_MASS * GRAVITY / 4.0 # per-motor hover ARM_LENGTH = 0.14 # distance from CoM to rotor XML_PATH = str(Path(__file__).parent / "drone_forest.xml") class DroneForestEnvironment(Environment): """Drone navigates a randomised forest of columns to reach a target.""" SUPPORTS_CONCURRENT_SESSIONS = True def __init__( self, render_height: Optional[int] = None, render_width: Optional[int] = None, **kwargs, ): self._model = None self._data = None self._render_height = render_height or int( os.environ.get("DMCONTROL_RENDER_HEIGHT", "480") ) self._render_width = render_width or int( os.environ.get("DMCONTROL_RENDER_WIDTH", "640") ) self._include_pixels = False self._step_count = 0 self._prev_dist = None self._tree_positions: List[np.ndarray] = [] self._target_pos = np.zeros(3) self._done = False self._rng = np.random.RandomState() self._state = DMControlState( episode_id=str(uuid4()), step_count=0, domain_name="drone_forest", task_name="navigate", ) # ------------------------------------------------------------------ # Model loading # ------------------------------------------------------------------ def _ensure_model(self): """Load MuJoCo model if not loaded yet.""" if self._model is not None: return import mujoco self._model = mujoco.MjModel.from_xml_path(XML_PATH) self._data = mujoco.MjData(self._model) # Precompute body / geom ids self._drone_body_id = mujoco.mj_name2id( self._model, mujoco.mjtObj.mjOBJ_BODY, "drone" ) self._target_body_id = mujoco.mj_name2id( self._model, mujoco.mjtObj.mjOBJ_BODY, "target" ) self._tree_body_ids = [ mujoco.mj_name2id(self._model, mujoco.mjtObj.mjOBJ_BODY, f"tree_{i}") for i in range(NUM_TREES) ] self._trunk_geom_ids = [ mujoco.mj_name2id(self._model, mujoco.mjtObj.mjOBJ_GEOM, f"trunk_{i}") for i in range(NUM_TREES) ] self._drone_body_geom_id = mujoco.mj_name2id( self._model, mujoco.mjtObj.mjOBJ_GEOM, "drone_body" ) self._ground_geom_id = mujoco.mj_name2id( self._model, mujoco.mjtObj.mjOBJ_GEOM, "ground" ) # Set state metadata self._state.action_spec = { "shape": [4], "dtype": "float64", "minimum": [-1.0, -1.0, -1.0, -1.0], "maximum": [1.0, 1.0, 1.0, 1.0], "name": "velocity_command", } self._state.observation_spec = { "position": {"shape": [3], "dtype": "float64"}, "velocity": {"shape": [3], "dtype": "float64"}, "orientation": {"shape": [3], "dtype": "float64"}, "angular_velocity": {"shape": [3], "dtype": "float64"}, "target_relative": {"shape": [3], "dtype": "float64"}, "obstacle_distances": {"shape": [8], "dtype": "float64"}, } self._state.physics_timestep = PHYSICS_DT self._state.control_timestep = CONTROL_DT # ------------------------------------------------------------------ # Forest randomisation # ------------------------------------------------------------------ def _randomise_forest(self): """Place trees and target using rejection sampling.""" import mujoco positions = [] attempts = 0 while len(positions) < NUM_TREES and attempts < 5000: x = self._rng.uniform(-ARENA_HALF + 1, ARENA_HALF - 1) y = self._rng.uniform(-ARENA_HALF + 1, ARENA_HALF - 1) # Keep clear of spawn if np.sqrt(x ** 2 + y ** 2) < SPAWN_CLEAR_RADIUS: attempts += 1 continue # Min spacing from existing trees ok = True for p in positions: if np.sqrt((x - p[0]) ** 2 + (y - p[1]) ** 2) < TREE_MIN_SPACING: ok = False break if ok: positions.append(np.array([x, y])) attempts += 1 # Pad with far-away positions if we didn't get enough while len(positions) < NUM_TREES: positions.append(np.array([100.0, 100.0])) self._tree_positions = positions # Set tree body positions in the model for i, pos in enumerate(positions): body_id = self._tree_body_ids[i] self._model.body_pos[body_id] = [pos[0], pos[1], 0.0] # Place target: at least TARGET_MIN_DIST from origin, away from trees for _ in range(1000): angle = self._rng.uniform(0, 2 * np.pi) dist = self._rng.uniform(TARGET_MIN_DIST, ARENA_HALF - 2) tx, ty = dist * np.cos(angle), dist * np.sin(angle) tz = self._rng.uniform(1.0, 3.0) # Check clearance from trees clear = True for p in positions[:NUM_TREES]: if np.sqrt((tx - p[0]) ** 2 + (ty - p[1]) ** 2) < 1.5: clear = False break if clear: break self._target_pos = np.array([tx, ty, tz]) self._model.body_pos[self._target_body_id] = self._target_pos.copy() # Recompute derived quantities after changing body positions mujoco.mj_forward(self._model, self._data) # ------------------------------------------------------------------ # Flight controller # ------------------------------------------------------------------ def _flight_controller(self, cmd: np.ndarray) -> np.ndarray: """ Convert velocity commands [vx, vy, vz, yaw_rate] in [-1,1] to 4 motor thrusts. """ # Scale commands vx_cmd = cmd[0] * MAX_XY_VEL vy_cmd = cmd[1] * MAX_XY_VEL vz_cmd = cmd[2] * MAX_Z_VEL yaw_rate_cmd = cmd[3] * MAX_YAW_RATE # Current state pos = self._data.qpos[:3].copy() quat = self._data.qpos[3:7].copy() # w, x, y, z vel = self._data.qvel[:3].copy() ang_vel = self._data.qvel[3:6].copy() # Extract yaw from quaternion roll, pitch, yaw = self._quat_to_euler(quat) # Rotate desired world-frame velocity into body XY cos_yaw, sin_yaw = np.cos(yaw), np.sin(yaw) # World-frame desired velocity vx_world = vx_cmd * cos_yaw - vy_cmd * sin_yaw vy_world = vx_cmd * sin_yaw + vy_cmd * cos_yaw # Velocity error in world frame vx_err = vx_world - vel[0] vy_err = vy_world - vel[1] vz_err = vz_cmd - vel[2] # Desired roll/pitch from XY velocity error (small angle approx) desired_pitch = np.clip(KP_VEL * vx_err, -0.5, 0.5) desired_roll = np.clip(-KP_VEL * vy_err, -0.5, 0.5) # Attitude PD roll_err = desired_roll - roll pitch_err = desired_pitch - pitch yaw_rate_err = yaw_rate_cmd - ang_vel[2] torque_roll = KP_ATT * roll_err - KD_ATT * ang_vel[0] torque_pitch = KP_ATT * pitch_err - KD_ATT * ang_vel[1] torque_yaw = KP_ATT * yaw_rate_err # Collective thrust: hover + vertical velocity correction thrust = DRONE_MASS * GRAVITY + KP_VEL * vz_err * DRONE_MASS # Quadrotor mixer: convert thrust + torques to 4 motor thrusts # Layout: FR(+x,-y), FL(+x,+y), BR(-x,-y), BL(-x,+y) L = ARM_LENGTH t_fr = thrust / 4.0 + torque_pitch / (4.0 * L) - torque_roll / (4.0 * L) - torque_yaw / 4.0 t_fl = thrust / 4.0 + torque_pitch / (4.0 * L) + torque_roll / (4.0 * L) + torque_yaw / 4.0 t_br = thrust / 4.0 - torque_pitch / (4.0 * L) - torque_roll / (4.0 * L) + torque_yaw / 4.0 t_bl = thrust / 4.0 - torque_pitch / (4.0 * L) + torque_roll / (4.0 * L) - torque_yaw / 4.0 # Clamp to actuator range [0, 3] motors = np.clip([t_fr, t_fl, t_br, t_bl], 0.0, 3.0) return motors @staticmethod def _quat_to_euler(quat: np.ndarray): """Convert quaternion [w, x, y, z] to Euler angles [roll, pitch, yaw].""" w, x, y, z = quat # Roll (x-axis rotation) sinr = 2.0 * (w * x + y * z) cosr = 1.0 - 2.0 * (x * x + y * y) roll = np.arctan2(sinr, cosr) # Pitch (y-axis rotation) sinp = 2.0 * (w * y - z * x) sinp = np.clip(sinp, -1.0, 1.0) pitch = np.arcsin(sinp) # Yaw (z-axis rotation) siny = 2.0 * (w * z + x * y) cosy = 1.0 - 2.0 * (y * y + z * z) yaw = np.arctan2(siny, cosy) return roll, pitch, yaw # ------------------------------------------------------------------ # Observations # ------------------------------------------------------------------ def _get_obs(self) -> Dict[str, List[float]]: pos = self._data.qpos[:3].copy() vel = self._data.qvel[:3].copy() quat = self._data.qpos[3:7].copy() ang_vel = self._data.qvel[3:6].copy() roll, pitch, yaw = self._quat_to_euler(quat) target_rel = self._target_pos - pos # 8 nearest obstacle distances (XY plane, from drone position) dists = [] for tp in self._tree_positions: dx = tp[0] - pos[0] dy = tp[1] - pos[1] dists.append(np.sqrt(dx ** 2 + dy ** 2)) dists.sort() obstacle_distances = dists[:8] # Pad if fewer than 8 while len(obstacle_distances) < 8: obstacle_distances.append(50.0) return { "position": pos.tolist(), "velocity": vel.tolist(), "orientation": [float(roll), float(pitch), float(yaw)], "angular_velocity": ang_vel.tolist(), "target_relative": target_rel.tolist(), "obstacle_distances": obstacle_distances, } # ------------------------------------------------------------------ # Collision detection # ------------------------------------------------------------------ def _check_collisions(self) -> bool: """Return True if drone collides with any tree trunk or ground.""" import mujoco for i in range(self._data.ncon): contact = self._data.contact[i] g1, g2 = contact.geom1, contact.geom2 pair = {g1, g2} if self._drone_body_geom_id not in pair: continue other = (pair - {self._drone_body_geom_id}).pop() if other == self._ground_geom_id or other in self._trunk_geom_ids: return True return False # ------------------------------------------------------------------ # Reward # ------------------------------------------------------------------ def _compute_reward(self, pos: np.ndarray) -> float: dist = np.linalg.norm(self._target_pos - pos) reward = 0.0 # Shaping: reward for getting closer if self._prev_dist is not None: reward += 1.0 * (self._prev_dist - dist) self._prev_dist = dist # Time pressure reward -= 0.01 return float(reward) # ------------------------------------------------------------------ # Termination # ------------------------------------------------------------------ def _check_termination(self, pos: np.ndarray): """Returns (done, bonus_reward).""" dist = np.linalg.norm(self._target_pos - pos) # Success if dist < TARGET_RADIUS: return True, 100.0 # Collision if self._check_collisions(): return True, -50.0 # Out of bounds if (abs(pos[0]) > ARENA_HALF or abs(pos[1]) > ARENA_HALF or pos[2] > MAX_ALTITUDE or pos[2] < MIN_ALTITUDE): return True, -10.0 # Max steps if self._step_count >= MAX_STEPS: return True, 0.0 return False, 0.0 # ------------------------------------------------------------------ # Core interface # ------------------------------------------------------------------ def reset( self, domain_name: Optional[str] = None, task_name: Optional[str] = None, seed: Optional[int] = None, render: bool = False, **kwargs, ) -> DMControlObservation: import mujoco self._ensure_model() self._include_pixels = render if seed is not None: self._rng = np.random.RandomState(seed) # Reset data to defaults mujoco.mj_resetData(self._model, self._data) # Randomise forest layout self._randomise_forest() # Place drone at origin, altitude 1.5 self._data.qpos[:3] = [0.0, 0.0, 1.5] self._data.qpos[3:7] = [1.0, 0.0, 0.0, 0.0] # identity quaternion self._data.qvel[:] = 0.0 mujoco.mj_forward(self._model, self._data) self._step_count = 0 pos = self._data.qpos[:3].copy() self._prev_dist = float(np.linalg.norm(self._target_pos - pos)) self._done = False self._state = DMControlState( episode_id=str(uuid4()), step_count=0, domain_name="drone_forest", task_name="navigate", action_spec=self._state.action_spec, observation_spec=self._state.observation_spec, physics_timestep=PHYSICS_DT, control_timestep=CONTROL_DT, ) obs = self._get_obs() pixels = self._render_pixels() if render else None return DMControlObservation( observations=obs, pixels=pixels, reward=0.0, done=False, ) def step( self, action: DMControlAction, render: bool = False, **kwargs, ) -> DMControlObservation: import mujoco if self._model is None or self._data is None: raise RuntimeError("Environment not initialized. Call reset() first.") if self._done: raise RuntimeError("Episode is done. Call reset() to start a new episode.") # Clip action to [-1, 1] cmd = np.clip(np.array(action.values[:4], dtype=np.float64), -1.0, 1.0) # Run flight controller to get motor thrusts motors = self._flight_controller(cmd) # Set actuator controls self._data.ctrl[:4] = motors # Step physics for one control timestep (multiple physics substeps) n_substeps = int(CONTROL_DT / PHYSICS_DT) for _ in range(n_substeps): mujoco.mj_step(self._model, self._data) self._step_count += 1 self._state.step_count = self._step_count pos = self._data.qpos[:3].copy() # Compute reward and check termination reward = self._compute_reward(pos) done, bonus = self._check_termination(pos) reward += bonus self._done = done obs = self._get_obs() pixels = self._render_pixels() if (render or self._include_pixels) else None return DMControlObservation( observations=obs, pixels=pixels, reward=float(reward), done=done, ) async def reset_async(self, **kwargs) -> DMControlObservation: if sys.platform == "darwin": return self.reset(**kwargs) else: import asyncio return await asyncio.to_thread(self.reset, **kwargs) async def step_async(self, action: DMControlAction, render: bool = False, **kwargs) -> DMControlObservation: if sys.platform == "darwin": return self.step(action, render=render, **kwargs) else: import asyncio return await asyncio.to_thread(self.step, action, render=render, **kwargs) # ------------------------------------------------------------------ # Rendering # ------------------------------------------------------------------ def _render_pixels(self) -> Optional[str]: try: import mujoco renderer = mujoco.Renderer(self._model, height=self._render_height, width=self._render_width) renderer.update_scene(self._data, camera="tracking") frame = renderer.render() renderer.close() from PIL import Image img = Image.fromarray(frame) buf = io.BytesIO() img.save(buf, format="PNG") return base64.b64encode(buf.getvalue()).decode("utf-8") except Exception: return None @property def state(self) -> DMControlState: return self._state def close(self) -> None: self._model = None self._data = None def __del__(self): try: self.close() except Exception: pass