"""
Visual demonstration of the drone environment using Pygame.

This script loads a trained model and visualizes the drone navigating
through the environment with wind forces.
"""

import argparse
import os
import sys
from typing import Optional

import numpy as np
import pygame

# Add project root to path (must happen before the project-local import below)
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from env.drone_env import (
    DroneWindEnv,
    TARGET_X_MAX,
    TARGET_X_MIN,
    TARGET_Y_MAX,
    TARGET_Y_MIN,
)
from stable_baselines3 import PPO

# Pygame constants
WINDOW_WIDTH = 800
WINDOW_HEIGHT = 600
FPS = 30

# Color definitions
BLACK = (0, 0, 0)
WHITE = (255, 255, 255)
RED = (255, 0, 0)
GREEN = (0, 255, 0)
BLUE = (0, 0, 255)
YELLOW = (255, 255, 0)
CYAN = (0, 255, 255)
MAGENTA = (255, 0, 255)
GRAY = (128, 128, 128)
DARK_GRAY = (64, 64, 64)
ORANGE = (255, 165, 0)

# Vectors whose components are both at or below this magnitude are not drawn.
MIN_VECTOR_COMPONENT = 0.01

# Labels shown for each discrete action index.
ACTION_NAMES = ("No thrust", "Up", "Down", "Left", "Right")

# Step count used to decide whether the target is visible when the env's
# info dict carries no "target_spawned" key.
# NOTE(review): presumably mirrors TARGET_SPAWN_DELAY in env.drone_env - confirm.
FALLBACK_TARGET_SPAWN_STEP = 50


class DroneVisualizer:
    """Pygame-based visualizer for the drone environment."""

    def __init__(self, env: DroneWindEnv, model: Optional[PPO] = None):
        """
        Initialize the visualizer.

        Args:
            env: DroneWindEnv instance
            model: Optional trained PPO model (if None, uses random actions)
        """
        self.env = env
        self.model = model

        # Initialize Pygame
        pygame.init()
        self.screen = pygame.display.set_mode((WINDOW_WIDTH, WINDOW_HEIGHT))
        pygame.display.set_caption("Drone RL - Visual Demonstration")
        self.clock = pygame.time.Clock()
        self.font = pygame.font.Font(None, 24)
        self.small_font = pygame.font.Font(None, 18)

        # World to screen scaling.
        # Environment is [0, 1] x [0, 1]; a fixed margin is left on every side.
        self.world_margin = 50
        self.world_width = WINDOW_WIDTH - 2 * self.world_margin
        self.world_height = WINDOW_HEIGHT - 2 * self.world_margin

    def world_to_screen(self, x: float, y: float) -> tuple[int, int]:
        """Convert world coordinates [0,1] to screen coordinates."""
        screen_x = int(self.world_margin + x * self.world_width)
        # Flip y-axis (world y=0 is bottom, screen y=0 is top)
        screen_y = int(WINDOW_HEIGHT - self.world_margin - y * self.world_height)
        return screen_x, screen_y

    @staticmethod
    def _arrowhead_wings(
        tip_x: int, tip_y: int, dx: float, dy: float, size: float
    ) -> list[tuple[int, int]]:
        """Return the two back-swept wing points of an arrowhead at the tip.

        Args:
            tip_x: Screen x of the arrow tip.
            tip_y: Screen y of the arrow tip.
            dx: World-space x component of the arrow direction.
            dy: World-space y component of the arrow direction.
            size: Length of each wing in pixels.

        Returns:
            Two integer screen points, one for each wing (at -30 and +30
            degrees from the shaft).
        """
        # Negative dy because screen y is flipped relative to world y.
        angle = np.arctan2(-dy, dx)
        return [
            (
                int(tip_x - size * np.cos(angle + spread)),
                int(tip_y - size * np.sin(angle + spread)),
            )
            for spread in (-np.pi / 6, np.pi / 6)
        ]

    def draw_drone(self, x: float, y: float, vx: float, vy: float):
        """Draw the drone as a circle with velocity vector.

        Args:
            x: Drone world x position.
            y: Drone world y position.
            vx: Drone velocity x component.
            vy: Drone velocity y component.
        """
        center = self.world_to_screen(x, y)

        # Draw drone body (filled circle plus outline)
        drone_radius = 15
        pygame.draw.circle(self.screen, CYAN, center, drone_radius)
        pygame.draw.circle(self.screen, BLUE, center, drone_radius, 2)

        # Skip the velocity vector when the drone is (nearly) stationary.
        if not (abs(vx) > MIN_VECTOR_COMPONENT or abs(vy) > MIN_VECTOR_COMPONENT):
            return

        # Scale velocity for visualization
        scale = 30
        tip = (center[0] + int(vx * scale), center[1] - int(vy * scale))
        pygame.draw.line(self.screen, YELLOW, center, tip, 3)

        # Draw arrowhead as two short lines
        for wing in self._arrowhead_wings(tip[0], tip[1], vx, vy, 8):
            pygame.draw.line(self.screen, YELLOW, tip, wing, 2)

    def draw_wind(self, wind_x: float, wind_y: float):
        """Draw wind arrows indicating direction.

        The same arrow is drawn at the centre of every cell of a 6x6 grid
        covering the world.

        Args:
            wind_x: Wind x component.
            wind_y: Wind y component.
        """
        # Nothing to draw for (nearly) calm wind.
        if not (
            abs(wind_x) > MIN_VECTOR_COMPONENT or abs(wind_y) > MIN_VECTOR_COMPONENT
        ):
            return

        # Everything below depends only on the wind, not on the grid cell,
        # so it is computed once instead of once per arrow.
        scale = 25
        offset_x = int(wind_x * scale)
        offset_y = int(wind_y * scale)

        # Color based on wind strength
        wind_strength = abs(wind_x) + abs(wind_y)
        if wind_strength < 1.0:
            color = GREEN
        elif wind_strength < 1.5:
            color = YELLOW
        else:
            color = ORANGE

        # Draw fewer, clearer wind arrows
        grid_size = 6
        for i in range(grid_size):
            for j in range(grid_size):
                start = self.world_to_screen(
                    (i + 0.5) / grid_size, (j + 0.5) / grid_size
                )
                tip = (start[0] + offset_x, start[1] - offset_y)  # Flip y

                # Arrow shaft
                pygame.draw.line(self.screen, color, start, tip, 3)

                # Filled arrowhead
                wings = self._arrowhead_wings(tip[0], tip[1], wind_x, wind_y, 10)
                pygame.draw.polygon(self.screen, color, [tip, *wings])

    def draw_boundaries(self):
        """Draw the world boundaries."""
        top_left = self.world_to_screen(0, 1)
        top_right = self.world_to_screen(1, 1)
        bot_left = self.world_to_screen(0, 0)
        bot_right = self.world_to_screen(1, 0)

        edges = (
            (top_left, top_right),   # Top boundary
            (bot_left, bot_right),   # Bottom boundary
            (top_left, bot_left),    # Left boundary
            (top_right, bot_right),  # Right boundary
        )
        for start, end in edges:
            pygame.draw.line(self.screen, RED, start, end, 3)

    def draw_target_zone(self, target_spawned: bool = True):
        """Draw the target zone (box) that the drone needs to reach.

        Args:
            target_spawned: When False nothing is drawn.
        """
        # Only draw if target has spawned
        if not target_spawned:
            return

        # Get screen coordinates for target zone corners
        top_left = self.world_to_screen(TARGET_X_MIN, TARGET_Y_MAX)
        top_right = self.world_to_screen(TARGET_X_MAX, TARGET_Y_MAX)
        bot_left = self.world_to_screen(TARGET_X_MIN, TARGET_Y_MIN)
        bot_right = self.world_to_screen(TARGET_X_MAX, TARGET_Y_MIN)

        # Semi-transparent fill. The overlay is sized to the target box only:
        # a window-sized overlay would blend its unfilled (black) area over
        # the whole screen and dim everything already drawn.
        width = top_right[0] - top_left[0]
        height = bot_left[1] - top_left[1]
        if width > 0 and height > 0:
            overlay = pygame.Surface((width, height))
            overlay.set_alpha(100)  # Semi-transparent
            overlay.fill(MAGENTA)
            self.screen.blit(overlay, top_left)

        # Draw border
        corners = (top_left, top_right, bot_right, bot_left)
        for index, start in enumerate(corners):
            end = corners[(index + 1) % len(corners)]
            pygame.draw.line(self.screen, MAGENTA, start, end, 3)

        # Draw label centred in the box
        label_x = (top_left[0] + top_right[0]) // 2
        label_y = (top_left[1] + bot_left[1]) // 2
        text = self.small_font.render("TARGET", True, WHITE)
        self.screen.blit(text, text.get_rect(center=(label_x, label_y)))

    @staticmethod
    def _action_label(action) -> str:
        """Return a display label for an action value.

        Accepts a plain int or a single-element NumPy value; indices outside
        ACTION_NAMES get an "Unknown" label instead of raising IndexError.
        """
        index = int(np.asarray(action).item())
        if 0 <= index < len(ACTION_NAMES):
            return ACTION_NAMES[index]
        return f"Unknown ({index})"

    def draw_info(
        self,
        step: int,
        reward: float,
        action: Optional[int] = None,
        in_target: bool = False,
    ):
        """Draw information text.

        Args:
            step: Step counter to display.
            reward: Reward value to display.
            action: Last action taken; the action line is omitted when None.
            in_target: Whether to show the "in target zone" status.
        """
        env = self.env
        if in_target:
            target_text, target_color = "IN TARGET ZONE!", GREEN
        else:
            target_text, target_color = "Not in target", GRAY

        # Each entry: (text, font, color, vertical advance after the line)
        lines = [
            (f"Step: {step}", self.font, WHITE, 30),
            (f"Reward: {reward:.2f}", self.font, WHITE, 30),
            (target_text, self.font, target_color, 30),
            (f"Position: ({env.x:.2f}, {env.y:.2f})", self.small_font, WHITE, 25),
            (f"Velocity: ({env.vx:.2f}, {env.vy:.2f})", self.small_font, WHITE, 25),
            (f"Wind: ({env.wind_x:.2f}, {env.wind_y:.2f})", self.small_font, GREEN, 25),
        ]
        if action is not None:
            lines.append(
                (f"Action: {self._action_label(action)}", self.small_font, YELLOW, 25)
            )
        if self.model is not None:
            lines.append(("Mode: AI Agent (Liquid NN)", self.small_font, CYAN, 0))
        else:
            lines.append(("Mode: Random Actions", self.small_font, GRAY, 0))

        y_offset = 10
        for content, font, color, advance in lines:
            self.screen.blit(font.render(content, True, color), (10, y_offset))
            y_offset += advance

    def _position_in_target(self) -> bool:
        """Return True when the env's current position lies inside the target box."""
        return (
            TARGET_X_MIN <= self.env.x <= TARGET_X_MAX
            and TARGET_Y_MIN <= self.env.y <= TARGET_Y_MAX
        )

    def _target_spawned(self, info: dict) -> bool:
        """Return whether the target should be shown for the given info dict.

        Uses the env-reported "target_spawned" flag when present; otherwise
        falls back to comparing the env's step counter with
        FALLBACK_TARGET_SPAWN_STEP. The fallback is only evaluated when needed.
        """
        if "target_spawned" in info:
            return info["target_spawned"]
        return self.env.step_count >= FALLBACK_TARGET_SPAWN_STEP

    def _draw_overlay(self, paused: bool):
        """Draw the pause banner (when paused) and the key-binding help."""
        if paused:
            text = self.font.render("PAUSED (SPACE to resume)", True, YELLOW)
            self.screen.blit(text, text.get_rect(center=(WINDOW_WIDTH // 2, 30)))

        controls_y = WINDOW_HEIGHT - 80
        controls = ["SPACE: Pause/Resume", "R: Reset", "ESC: Quit"]
        for i, control in enumerate(controls):
            text = self.small_font.render(control, True, GRAY)
            self.screen.blit(text, (10, controls_y + i * 20))

    def run(self, max_steps: int = 500, speed: float = 1.0):
        """
        Run the visualization.

        The loop ends when the window is closed, ESC is pressed, or the
        current episode's step counter reaches max_steps. Episodes that end
        earlier are restarted automatically after a one second pause.
        Pygame is shut down on exit, including when an exception escapes.

        Args:
            max_steps: Maximum number of steps to run
            speed: Speed multiplier (1.0 = normal, higher = faster)
        """
        obs, info = self.env.reset()
        done = False
        truncated = False
        step_count = 0
        action = None
        reward = 0.0  # defined up front so the first frame can be drawn paused
        running = True
        paused = False

        try:
            while running and step_count < max_steps:
                # Handle events
                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        running = False
                    elif event.type == pygame.KEYDOWN:
                        if event.key == pygame.K_SPACE:
                            paused = not paused
                        elif event.key == pygame.K_r:
                            # Manual reset: also clear per-episode display
                            # state so no stale action is shown while paused.
                            obs, info = self.env.reset()
                            done = False
                            truncated = False
                            step_count = 0
                            action = None
                            reward = 0.0
                        elif event.key == pygame.K_ESCAPE:
                            running = False

                # None means "no fresh value from the env this frame".
                in_target = None
                if not paused and not done and not truncated:
                    # Get action
                    if self.model is not None:
                        action, _ = self.model.predict(obs, deterministic=True)
                    else:
                        action = self.env.action_space.sample()

                    # Step environment
                    obs, reward, done, truncated, info = self.env.step(action)
                    step_count += 1
                    in_target = info.get("in_target", False)

                # Without a fresh env report, derive the status from position.
                if in_target is None:
                    in_target = self._position_in_target()

                # Draw everything
                self.screen.fill(BLACK)
                self.draw_boundaries()
                # The target stays visible while paused; a paused frame should
                # show the same scene as the frame before it.
                self.draw_target_zone(target_spawned=self._target_spawned(info))
                self.draw_wind(self.env.wind_x, self.env.wind_y)
                self.draw_drone(self.env.x, self.env.y, self.env.vx, self.env.vy)
                self.draw_info(
                    step_count, reward if not paused else 0, action, in_target
                )
                self._draw_overlay(paused)

                pygame.display.flip()

                # Control speed (a slow fixed rate is enough while paused)
                if not paused:
                    self.clock.tick(FPS * speed)
                else:
                    self.clock.tick(10)

                # Auto-reset on done/truncated
                if (done or truncated) and not paused:
                    pygame.time.wait(1000)  # Wait 1 second before reset
                    obs, info = self.env.reset()
                    done = False
                    truncated = False
                    step_count = 0
        finally:
            # Always release the display, even if the loop raised.
            pygame.quit()


def main():
    """Main function to run the visualization."""
    parser = argparse.ArgumentParser(description="Visualize drone environment")
    parser.add_argument(
        "--model-path",
        type=str,
        default="models/liquid_policy.zip",
        help="Path to trained model (default: models/liquid_policy.zip)",
    )
    parser.add_argument(
        "--random",
        action="store_true",
        help="Use random actions instead of trained model",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=500,
        help="Maximum steps per episode (default: 500)",
    )
    parser.add_argument(
        "--speed",
        type=float,
        default=1.0,
        help="Animation speed multiplier (default: 1.0)",
    )
    args = parser.parse_args()

    # Create environment
    env = DroneWindEnv()

    # Load model if specified; fall back to random actions when it is missing
    model = None
    if not args.random:
        if os.path.exists(args.model_path):
            print(f"Loading model from {args.model_path}...")
            model = PPO.load(args.model_path, env=env)
            print("Model loaded successfully!")
        else:
            print(f"Model not found at {args.model_path}, using random actions")

    # Create visualizer
    visualizer = DroneVisualizer(env, model)

    # Run visualization
    print("\nStarting visualization...")
    print("Controls:")
    print(" SPACE: Pause/Resume")
    print(" R: Reset episode")
    print(" ESC: Quit")
    print()

    visualizer.run(max_steps=args.max_steps, speed=args.speed)


if __name__ == "__main__":
    main()