# team222 / env_3d.py
# ylop's picture
# Fix: Configure PyFlyt for headless Docker environment
# 963fbec verified
# raw
# history blame
# 10.4 kB
import numpy as np
import gymnasium as gym
from gymnasium import spaces
import pybullet as p
from PyFlyt.core import Aviary
from PyFlyt.core.drones import QuadX
class Drone3DEnv(gym.Env):
    """Fly a single PyFlyt quadrotor to a drifting target inside a windy box.

    Each step the wind vector and the target position both take a Gaussian
    random-walk step (clipped to ``[-5, 5]`` per axis and to ``TARGET_BOUNDS``
    respectively), the wind is applied to the drone as an external force, and
    the simulation is advanced once via ``Aviary.step()``.

    Observation (12 floats, ``float32``): target position relative to the
    drone, followed by the drone's orientation, linear velocity and angular
    velocity as reported by ``Aviary.state(0)``.

    Action (4 floats in ``[-1, 1]``): rescaled to ``[0, 1]`` and forwarded to
    ``Aviary.set_all_setpoints``.

    Episode end:
        * terminated when the drone leaves ``BOUNDS``, drops below z = 0.1, or
          gets within 0.5 of the target;
        * truncated once ``MAX_STEPS`` steps have elapsed.

    All randomness is drawn from ``self.np_random`` so that
    ``reset(seed=...)`` makes an episode reproducible.

    NOTE(review): the Aviary is always created with ``render=False``, so the
    ``"human"`` drawing code runs against a simulation without a GUI; it
    presumably produces no visible output there -- confirm if a GUI mode is
    ever re-enabled.
    """

    metadata = {"render_modes": ["human"]}

    def __init__(self, render_mode=None):
        """Create the simulation and declare the action/observation spaces.

        Args:
            render_mode: ``"human"`` enables the debug-drawing code paths;
                any other value (including ``None``) skips them.
        """
        super().__init__()
        self.render_mode = render_mode

        # Constants. Row 0 is [x_min, y_min, z_min], row 1 is [x_max, y_max, z_max].
        self.BOUNDS = np.array([[-5, -5, 0], [5, 5, 10]])
        self.TARGET_BOUNDS = np.array([[-4, -4, 1], [4, 4, 9]])
        self.MAX_STEPS = 1000
        self.WIND_SCALE = 1.0

        # Initialize PyFlyt Aviary
        self.start_pos = np.array([[0.0, 0.0, 1.0]])
        self.start_orn = np.array([[0.0, 0.0, 0.0]])

        # Headless mode for Docker: always run without rendering.
        # This prevents URDF loading errors in headless environments.
        self.aviary = Aviary(
            start_pos=self.start_pos,
            start_orn=self.start_orn,
            drone_type="quadx",
            render=False,  # Force headless mode
            use_camera=False,  # Disable cameras
        )
        self._closed = False

        # Action space: 4 commands in [-1, 1]; step() rescales them to [0, 1].
        # NOTE(review): the original comment calls these "4 motors", but no
        # flight mode is selected on the Aviary here -- confirm that the
        # drone's active mode really interprets setpoints as motor commands.
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)

        # Observation space: 12 dim
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(12,), dtype=np.float32)

        self.wind_vector = np.zeros(3)
        self.target_pos = np.zeros(3)
        self.step_count = 0

        # Visual IDs
        self.wind_arrow_ids = []  # list of (debug_item_id, start_pos)
        self.target_visual_id = None
        self.bound_box_ids = []
        self.path_points = []
        self.text_ids = []

    def reset(self, seed=None, options=None):
        """Start a new episode.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset``; it seeds
                ``self.np_random``, which drives the target and wind draws.
            options: Unused; accepted for Gymnasium API compatibility.

        Returns:
            ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)

        # Release drawing handles while they still refer to live objects, so
        # none of them is reused against the freshly reset simulation.
        self._clear_visuals()

        self.aviary.reset()
        self.step_count = 0
        self.path_points = []

        sim = self.aviary
        human = self.render_mode == "human"

        # Disable PyBullet GUI sidebars but enable mouse
        if human:
            sim.configureDebugVisualizer(p.COV_ENABLE_GUI, 0)
            sim.configureDebugVisualizer(p.COV_ENABLE_MOUSE_PICKING, 1)
            # Zoom out camera
            sim.resetDebugVisualizerCamera(
                cameraDistance=12.0,
                cameraYaw=45,
                cameraPitch=-30,
                cameraTargetPosition=[0, 0, 5],
            )

        # Random target (drawn from the env's seeded generator)
        self.target_pos = self.np_random.uniform(self.TARGET_BOUNDS[0], self.TARGET_BOUNDS[1])

        # Reset wind (random-walk start); no vertical component initially
        self.wind_vector = self.np_random.uniform(-1, 1, 3) * self.WIND_SCALE
        self.wind_vector[2] = 0

        # Visuals
        if human:
            self._draw_bounds()
            self._draw_target()
            self._init_wind_field()

        return self._get_obs(), {}

    def step(self, action):
        """Advance the environment by one control step.

        Args:
            action: Four values, nominally in ``[-1, 1]``. Anything array-like
                with exactly four elements is accepted; values outside the
                range are clipped.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.

        Raises:
            ValueError: If ``action`` does not contain exactly four values.
        """
        self.step_count += 1
        sim = self.aviary
        rng = self.np_random

        # 1. Update wind (Gaussian random walk, clipped per axis)
        self.wind_vector = np.clip(self.wind_vector + rng.normal(0, 0.1, 3), -5, 5)

        # 2. Update target (Gaussian random walk inside TARGET_BOUNDS)
        self.target_pos = np.clip(
            self.target_pos + rng.normal(0, 0.05, 3),
            self.TARGET_BOUNDS[0],
            self.TARGET_BOUNDS[1],
        )

        # 3. Apply wind force.
        # The wind vector is expressed in world coordinates (the arrows drawn
        # by _update_wind_field use it that way), so it is applied in the
        # world frame at the drone's base position. Using LINK_FRAME would
        # rotate the wind together with the drone body.
        # NOTE(review): the force is queued once per call; if Aviary.step()
        # runs several physics sub-steps it may only act on the first one --
        # confirm against the PyFlyt version in use.
        drone = self.aviary.drones[0]
        base_pos, _ = sim.getBasePositionAndOrientation(drone.Id)
        sim.applyExternalForce(
            drone.Id,
            -1,  # Base link
            forceObj=self.wind_vector.tolist(),
            posObj=base_pos,
            flags=p.WORLD_FRAME,
        )

        # 4. Step physics
        command = np.asarray(action, dtype=np.float64).reshape(-1)
        if command.size != 4:
            raise ValueError(f"action must contain 4 values, got shape {np.shape(action)}")
        # Clip before rescaling so the setpoint always lands in [0, 1].
        motor_command = (np.clip(command, -1.0, 1.0) + 1.0) / 2.0
        self.aviary.set_all_setpoints(motor_command.reshape(1, -1))
        self.aviary.step()

        # 5. Get state
        pos, orn, lin_vel, ang_vel = self._read_state()

        # 6. Compute reward
        reward = self._compute_reward(pos, orn, lin_vel, ang_vel)

        # 7. Check termination
        terminated = False

        # Out of bounds
        if (pos < self.BOUNDS[0]).any() or (pos > self.BOUNDS[1]).any():
            terminated = True
            reward -= 100.0  # Massive penalty for leaving

        # Crash (ground)
        if pos[2] < 0.1:
            terminated = True
            reward -= 100.0

        # Target reached
        if np.linalg.norm(pos - self.target_pos) < 0.5:
            terminated = True
            reward += 100.0

        truncated = self.step_count >= self.MAX_STEPS

        # 8. Update visuals
        if self.render_mode == "human":
            self._update_visuals(pos)

        return self._get_obs(), float(reward), terminated, truncated, {}

    def _read_state(self):
        """Return ``(pos, orn, lin_vel, ang_vel)`` rows of ``Aviary.state(0)``."""
        state = self.aviary.state(0)
        return state[-1], state[-3], state[-2], state[-4]

    def _get_obs(self):
        """Build the 12-dim observation for the current simulation state."""
        pos, orn, lin_vel, ang_vel = self._read_state()

        # Relative position to target
        rel_pos = self.target_pos - pos

        # [rel_x, rel_y, rel_z, roll, pitch, yaw, u, v, w, p, q, r]
        return np.concatenate([rel_pos, orn, lin_vel, ang_vel]).astype(np.float32)

    def _compute_reward(self, pos, orn, lin_vel, ang_vel):
        """Return the shaped per-step reward (terminal bonuses are added in step()).

        The reward is the sum of four non-positive terms: angular-rate
        penalty, roll/pitch penalty, distance-to-target penalty, and an
        exponential penalty once the drone is within 1.0 of any face of
        ``BOUNDS``. ``lin_vel`` is accepted for signature symmetry but unused.
        """
        # 1. Stability (smoothness): penalize high angular velocity
        r_smooth = -0.1 * np.linalg.norm(ang_vel)

        # Penalize extreme angles (flipping); only the first two angles count
        r_stability = -np.linalg.norm(orn[0:2])

        # 2. Target: distance penalty
        r_target = -0.5 * np.linalg.norm(pos - self.target_pos)

        # 3. Boundary safety: distance to the nearest face of the box
        d_min = min(np.min(pos - self.BOUNDS[0]), np.min(self.BOUNDS[1] - pos))

        # Exponential penalty as it gets closer than 1.0: -1 at 1, -e at 0
        r_boundary = -np.exp(1.0 - d_min) if d_min < 1.0 else 0.0

        return r_stability + r_smooth + r_target + r_boundary

    def _clear_visuals(self):
        """Remove every tracked debug item / body and forget its handle."""
        sim = self.aviary
        for uid in self.text_ids:
            sim.removeUserDebugItem(uid)
        for uid, _ in self.wind_arrow_ids:
            sim.removeUserDebugItem(uid)
        for uid in self.bound_box_ids:
            sim.removeUserDebugItem(uid)
        if self.target_visual_id is not None:
            sim.removeBody(self.target_visual_id)

        self.text_ids = []
        self.wind_arrow_ids = []
        self.bound_box_ids = []
        self.target_visual_id = None

    def _update_visuals(self, pos):
        """Redraw wind arrows, target marker, flight path and the info label."""
        sim = self.aviary
        self._update_wind_field()
        self._draw_target()  # Update target position

        # Flight path. Copy the position so later state updates cannot alias
        # into previously stored points.
        self.path_points.append(np.array(pos, copy=True))
        if len(self.path_points) > 1:
            sim.addUserDebugLine(
                self.path_points[-2],
                self.path_points[-1],
                lineColorRGB=[0, 1, 0],
                lineWidth=2,
                lifeTime=0,
            )

        # Info overlay: drop the previous label, then draw a fresh one
        for uid in self.text_ids:
            sim.removeUserDebugItem(uid)
        self.text_ids = []

        wind_speed = np.linalg.norm(self.wind_vector)
        info_text = f"Wind: {wind_speed:.2f} m/s\nDrone: {pos.round(2)}\nTarget: {self.target_pos.round(2)}"

        # The label is 3D text anchored slightly above the drone.
        uid = sim.addUserDebugText(info_text, [pos[0], pos[1], pos[2] + 0.5], [0, 0, 0], textSize=1.5)
        self.text_ids.append(uid)

    def _init_wind_field(self):
        """Create the grid of placeholder wind arrows (4 x 4 x 3 lines)."""
        sim = self.aviary

        # Clear old arrows
        for uid, _ in self.wind_arrow_ids:
            sim.removeUserDebugItem(uid)
        self.wind_arrow_ids = []

        # Sparse grid: 4 columns in x and y, 3 layers in z (1 inside the box)
        xs = np.linspace(self.BOUNDS[0][0], self.BOUNDS[1][0], 4)
        ys = np.linspace(self.BOUNDS[0][1], self.BOUNDS[1][1], 4)
        zs = np.linspace(self.BOUNDS[0][2] + 1, self.BOUNDS[1][2] - 1, 3)

        for x in xs:
            for y in ys:
                for z in zs:
                    start_pos = [x, y, z]
                    # Placeholder end point; _update_wind_field replaces it
                    end_pos = [x + 0.1, y, z]
                    uid = sim.addUserDebugLine(start_pos, end_pos, [1, 1, 0], 2)
                    self.wind_arrow_ids.append((uid, start_pos))

    def _update_wind_field(self):
        """Point every wind arrow along the current wind and recolor it."""
        sim = self.aviary

        # Scale wind for visual length
        scale = 0.5
        offset = self.wind_vector * scale

        # Color by intensity: map 0-5 to green -> red (same for all arrows)
        t = np.clip(np.linalg.norm(self.wind_vector) / 5.0, 0, 1)
        color = [t, 1.0 - t, 0]

        for uid, start_pos in self.wind_arrow_ids:
            end_pos = [
                start_pos[0] + offset[0],
                start_pos[1] + offset[1],
                start_pos[2] + offset[2],
            ]
            sim.addUserDebugLine(
                start_pos,
                end_pos,
                lineColorRGB=color,
                lineWidth=2,
                replaceItemUniqueId=uid,
            )

    def _draw_bounds(self):
        """Draw the 12 edges of ``BOUNDS`` as red debug lines."""
        sim = self.aviary

        # Drop edges from a previous call so they do not accumulate
        for uid in self.bound_box_ids:
            sim.removeUserDebugItem(uid)
        self.bound_box_ids = []

        min_x, min_y, min_z = self.BOUNDS[0]
        max_x, max_y, max_z = self.BOUNDS[1]

        corners = [
            [min_x, min_y, min_z], [max_x, min_y, min_z],
            [min_x, max_y, min_z], [max_x, max_y, min_z],
            [min_x, min_y, max_z], [max_x, min_y, max_z],
            [min_x, max_y, max_z], [max_x, max_y, max_z],
        ]

        lines = [
            (0, 1), (1, 3), (3, 2), (2, 0),  # Bottom
            (4, 5), (5, 7), (7, 6), (6, 4),  # Top
            (0, 4), (1, 5), (2, 6), (3, 7),  # Vertical
        ]

        for start, end in lines:
            uid = sim.addUserDebugLine(corners[start], corners[end], [1, 0, 0], 4)
            self.bound_box_ids.append(uid)

    def _draw_target(self):
        """Show the target sphere at ``self.target_pos``.

        The sphere body is created on first use and moved afterwards, instead
        of creating a new visual shape and body on every step.
        """
        sim = self.aviary
        position = np.asarray(self.target_pos, dtype=float).tolist()

        if self.target_visual_id is None:
            visual_shape = sim.createVisualShape(
                shapeType=p.GEOM_SPHERE, radius=0.3, rgbaColor=[1, 0, 1, 0.7]
            )
            self.target_visual_id = sim.createMultiBody(
                baseVisualShapeIndex=visual_shape, basePosition=position
            )
        else:
            # Identity quaternion: a sphere has no meaningful orientation
            sim.resetBasePositionAndOrientation(self.target_visual_id, position, [0, 0, 0, 1])

    def close(self):
        """Disconnect from the simulation; safe to call more than once."""
        if getattr(self, "_closed", False):
            return
        self._closed = True
        self.aviary.disconnect()