| """ |
| KAPS Gym Environment |
| ==================== |
| Wraps KAPSSimulation as a Gymnasium environment for RL training. |
| |
| Observation space: TAB positions, velocities, cable tensions, threats |
| Action space: Control surface deflections, release commands |
| Rewards: Intercepts, survival, formation quality |
| """ |
|
|
| import gymnasium as gym |
| from gymnasium import spaces |
| import numpy as np |
| from typing import Dict, Tuple, Optional, Any |
|
|
| import sys |
| import os |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
| from ..main import KAPSSimulation |
| from ..ai.defensive_matrix import ThreatType |
|
|
|
|
class KAPSEnv(gym.Env):
    """
    Gymnasium environment for training RL agents on KAPS.

    The agent controls:
    - Formation spread/mode
    - Individual TAB control surface commands
    - Release decisions (which TAB to slingshot)

    The agent observes:
    - Mother drone state (position, velocity, orientation)
    - All TAB states (positions, velocities, attached status)
    - Cable tensions
    - Threat positions and velocities (up to N threats)
    - Defensive bubble status

    Observation vector layout (length 9 + 32 + 7 * max_threats):
    - [0:9]   mother drone: position/1000, velocity/100, orientation/pi
    - [9:41]  four TABs (UP, DOWN, LEFT, RIGHT), 8 values each:
              position/1000, velocity/100, attached flag, tension/10000
    - [41:]   max_threats slots of 7 values each:
              position/1000, velocity/100, presence flag (zero-padded
              when fewer threats are active)

    Action vector layout (13 continuous values in [-1, 1]):
    - [0]     formation command (NOTE(review): not consumed by _apply_action)
    - [1:5]   elevator commands for the UP, DOWN, LEFT, RIGHT TABs
    - [5:9]   rudder commands for the same TABs
    - [9:13]  release triggers; a value > 0.8 slingshots that TAB
    """

    # Gymnasium metadata: supported render modes and nominal frame rate.
    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30}

    def __init__(self,
                 render_mode: Optional[str] = None,
                 max_threats: int = 8,
                 episode_steps: int = 3000,
                 threat_spawn_rate: float = 0.01):
        """
        Args:
            render_mode: One of metadata["render_modes"], or None for headless.
            max_threats: Number of threat slots in the observation vector.
            episode_steps: Step count after which the episode is truncated.
            threat_spawn_rate: Per-step probability of spawning a new threat.
        """
        super().__init__()

        self.render_mode = render_mode
        self.max_threats = max_threats
        self.episode_steps = episode_steps
        self.threat_spawn_rate = threat_spawn_rate

        # Simulation is created lazily in reset(); step() asserts on this.
        self.sim: Optional[KAPSSimulation] = None
        self.step_count = 0
        self.total_reward = 0

        # Per-episode counters; re-zeroed in reset() and surfaced via info.
        self.stats = {
            'threats_spawned': 0,
            'threats_intercepted': 0,
            'tabs_released': 0,
            'damage_taken': 0
        }

        # 9 mother-drone values + 4 TABs * 8 values + 7 values per threat slot.
        obs_dim = 9 + 32 + 7 * max_threats

        # Unbounded Box: observation components are normalized (see
        # _get_observation) but not clipped.
        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(obs_dim,),
            dtype=np.float32
        )

        # 13 continuous actions: 1 formation + 4 elevator + 4 rudder + 4 release.
        self.action_space = spaces.Box(
            low=-1.0,
            high=1.0,
            shape=(13,),
            dtype=np.float32
        )

        # Reward shaping weights. NOTE(review): only 'survival',
        # 'formation_bonus' and 'threat_proximity' are applied in
        # _compute_reward(); the event-based weights ('intercept', 'damage',
        # 'cable_snap', 'tab_lost') are defined but currently unused.
        self.reward_weights = {
            'intercept': 100.0,
            'damage': -200.0,
            'cable_snap': -50.0,
            'tab_lost': -25.0,
            'formation_bonus': 0.1,
            'survival': 1.0,
            'threat_proximity': -0.01
        }

    def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[np.ndarray, Dict]:
        """Reset environment to initial state.

        Builds a fresh KAPSSimulation and zeroes all episode counters.
        NOTE(review): `seed` seeds only this env's np_random (threat spawning);
        it is not forwarded to KAPSSimulation — confirm whether the sim has
        its own randomness that should be seeded for full reproducibility.

        Returns:
            (observation, info) per the Gymnasium API.
        """
        super().reset(seed=seed)

        # Fresh simulation every episode.
        self.sim = KAPSSimulation()
        self.step_count = 0
        self.total_reward = 0
        self.stats = {
            'threats_spawned': 0,
            'threats_intercepted': 0,
            'tabs_released': 0,
            'damage_taken': 0
        }

        obs = self._get_observation()
        info = self._get_info()

        return obs, info

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict]:
        """Execute one environment step.

        Order of operations: apply the agent's action, possibly spawn a new
        threat, advance the simulation, then score and check termination.

        Args:
            action: Length-13 vector in [-1, 1]; see class docstring for layout.

        Returns:
            (observation, reward, terminated, truncated, info).
        """
        assert self.sim is not None, "Must call reset() before step()"

        self._apply_action(action)

        # Bernoulli spawn: one new threat per step with prob threat_spawn_rate.
        if self.np_random.random() < self.threat_spawn_rate:
            self._spawn_random_threat()

        # Advance the physics/AI simulation; telemetry drives the reward.
        telemetry = self.sim.step()
        self.step_count += 1

        reward = self._compute_reward(telemetry)
        self.total_reward += reward

        terminated = self._check_terminated(telemetry)
        # Truncate (not terminate) when the fixed episode budget is exhausted.
        truncated = self.step_count >= self.episode_steps

        obs = self._get_observation()
        info = self._get_info()

        return obs, reward, terminated, truncated, info

    def _get_observation(self) -> np.ndarray:
        """Build observation vector (see class docstring for the layout).

        Components are scaled to roughly unit range: positions /1000,
        velocities /100, orientation /pi, tension /10000. Units are whatever
        KAPSSimulation uses internally — presumably meters and m/s; verify.
        """
        obs = []

        # Mother drone: 3 position + 3 velocity + 3 orientation values.
        md = self.sim.mother_drone
        obs.extend(md.position / 1000.0)
        obs.extend(md.velocity / 100.0)
        obs.extend(md.orientation / np.pi)

        # Each TAB contributes 8 values: pos(3), vel(3), attached flag, tension.
        for tab_id in ["UP", "DOWN", "LEFT", "RIGHT"]:
            tab = self.sim.tab_array.tabs[tab_id]
            obs.extend(tab.position / 1000.0)
            obs.extend(tab.velocity / 100.0)
            obs.append(1.0 if tab.is_attached else 0.0)

            # Tether API is probed defensively; tension falls back to 0.
            tension = self.sim.tether_array.get_tension(tab_id) if hasattr(self.sim.tether_array, 'get_tension') else 0
            obs.append(tension / 10000.0)

        # Fixed-size threat block: active threats first, zero padding after.
        threats = self.sim.defensive_ai.get_active_threats() if hasattr(self.sim.defensive_ai, 'get_active_threats') else []
        for i in range(self.max_threats):
            if i < len(threats):
                t = threats[i]
                obs.extend(t.get('position', np.zeros(3)) / 1000.0)
                obs.extend(t.get('velocity', np.zeros(3)) / 100.0)
                obs.append(1.0)  # presence flag
            else:
                obs.extend([0.0] * 7)

        return np.array(obs, dtype=np.float32)

    def _apply_action(self, action: np.ndarray):
        """Apply agent's action to simulation.

        Mapping (indices into the 13-element action vector):
        - action[0]:    formation command — NOTE(review): currently unread.
        - action[1:5]:  elevator targets, scaled by each TAB's elevator_max.
        - action[5:9]:  rudder targets, scaled by each TAB's rudder_max.
        - action[9:13]: release triggers; > 0.8 slingshots the TAB toward
                        the closest threat (no-op when no threat is active).
        """
        # Control surfaces: only attached TABs accept deflection commands.
        tab_ids = ["UP", "DOWN", "LEFT", "RIGHT"]
        for i, tab_id in enumerate(tab_ids):
            if tab_id in self.sim.tab_array.tabs:
                tab = self.sim.tab_array.tabs[tab_id]
                if tab.is_attached:
                    # Scale normalized [-1, 1] commands to physical limits.
                    elevator = float(action[1 + i]) * tab.config.elevator_max
                    rudder = float(action[5 + i]) * tab.config.rudder_max
                    tab.set_control_targets(elevator=elevator, rudder=rudder)

        # Release decisions: threshold at 0.8 so small exploration noise
        # does not trigger an irreversible slingshot.
        for i, tab_id in enumerate(tab_ids):
            if action[9 + i] > 0.8:
                if tab_id in self.sim.tab_array.tabs:
                    tab = self.sim.tab_array.tabs[tab_id]
                    if tab.is_attached:
                        # Aim the released TAB at the nearest threat with a
                        # 50-unit speed boost along that direction.
                        threat_dir = self._get_closest_threat_direction()
                        if threat_dir is not None:
                            tab.execute_release(tab.velocity + threat_dir * 50)
                            self.stats['tabs_released'] += 1

    def _get_closest_threat_direction(self) -> Optional[np.ndarray]:
        """Get direction to closest threat.

        Returns:
            Unit vector from the mother drone toward the nearest active
            threat, or None when no threats are active.
        """
        threats = self.sim.defensive_ai.get_active_threats() if hasattr(self.sim.defensive_ai, 'get_active_threats') else []
        if not threats:
            return None

        md_pos = self.sim.mother_drone.position
        closest = None
        min_dist = float('inf')

        # Linear scan for the nearest threat position.
        for t in threats:
            t_pos = t.get('position', np.zeros(3))
            dist = np.linalg.norm(t_pos - md_pos)
            if dist < min_dist:
                min_dist = dist
                closest = t_pos

        if closest is None:
            return None

        # Normalize; epsilon guards against a zero-length vector.
        direction = closest - md_pos
        return direction / (np.linalg.norm(direction) + 1e-8)

    def _spawn_random_threat(self):
        """Spawn a random threat on an inbound trajectory.

        Position is sampled in spherical coordinates around the mother drone
        (elevation limited to +/-45 degrees, range 300-500); velocity points
        straight back at the drone with speed 100-200.
        """
        md_pos = self.sim.mother_drone.position

        # Spherical sample: azimuth, elevation, range. The draw order matters
        # for seeded reproducibility — do not reorder these calls.
        theta = self.np_random.uniform(0, 2 * np.pi)
        phi = self.np_random.uniform(-np.pi/4, np.pi/4)
        dist = self.np_random.uniform(300, 500)

        # Spherical-to-Cartesian offset from the mother drone.
        offset = np.array([
            dist * np.cos(phi) * np.cos(theta),
            dist * np.cos(phi) * np.sin(theta),
            dist * np.sin(phi)
        ])

        threat_pos = md_pos + offset
        # Velocity aims directly back at the mother drone.
        threat_vel = -offset / np.linalg.norm(offset) * self.np_random.uniform(100, 200)

        self.sim.inject_threat(
            position=threat_pos,
            velocity=threat_vel,
            threat_type=ThreatType.MISSILE_IR
        )
        self.stats['threats_spawned'] += 1

    def _compute_reward(self, telemetry: Dict) -> float:
        """Compute reward for this step.

        Currently applies: a constant survival bonus, a formation-quality
        bonus, and a proximity penalty while the defense AI reports RED alert.
        NOTE(review): the 'intercept', 'damage', 'cable_snap' and 'tab_lost'
        weights are defined in __init__ but not yet applied here.
        """
        reward = 0.0

        # Constant per-step bonus for staying alive.
        reward += self.reward_weights['survival']

        # Formation quality assumed in [0, 1]; defaults to 0.5 if absent.
        formation = telemetry.get('formation', {})
        quality = formation.get('quality', 0.5)
        reward += self.reward_weights['formation_bonus'] * quality

        # Small penalty while the highest alert level is active.
        defense = telemetry.get('defense', {})
        alert_level = defense.get('alert_level', 'GREEN')

        if alert_level == 'RED':
            reward += self.reward_weights['threat_proximity']

        return reward

    def _check_terminated(self, telemetry: Dict) -> bool:
        """Check if episode should terminate.

        NOTE(review): always False for now — episodes end only by truncation
        at episode_steps. Mother-drone destruction should terminate here once
        telemetry exposes it.
        """
        return False

    def _get_info(self) -> Dict:
        """Get auxiliary information: step count, cumulative reward,
        attached-TAB count, plus the per-episode stats counters."""
        return {
            'step': self.step_count,
            'total_reward': self.total_reward,
            'tabs_attached': self.sim.tab_array.count_attached() if self.sim else 0,
            **self.stats
        }

    def render(self):
        """Render the environment. Both modes are currently stubs.

        NOTE(review): Gymnasium expects "rgb_array" mode to return an
        (H, W, 3) array; this stub returns None — confirm before use with
        wrappers such as RecordVideo.
        """
        if self.render_mode == "human":
            # TODO: hook up an interactive visualizer.
            pass
        elif self.render_mode == "rgb_array":
            # TODO: return an (H, W, 3) uint8 frame.
            pass

    def close(self):
        """Clean up by dropping the simulation reference."""
        self.sim = None
|
|
|
|
| |
# Register the environment so trainers can build it via gym.make("KAPS-v0").
# max_episode_steps mirrors KAPSEnv's default episode_steps truncation budget.
_KAPS_SPEC = dict(
    id='KAPS-v0',
    entry_point='src.training.kaps_env:KAPSEnv',
    max_episode_steps=3000,
)
gym.register(**_KAPS_SPEC)
|
|