Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import time | |
| import os | |
| import matplotlib.pyplot as plt | |
| from matplotlib.gridspec import GridSpec | |
| from env.drone_3d import Drone3DEnv | |
| from models.liquid_ppo import make_liquid_ppo | |
| from stable_baselines3 import PPO | |
def run_interactive_demo():
    """Run a live matplotlib dashboard of a PPO policy flying the drone env.

    Creates a ``Drone3DEnv``, loads the policy stored in
    ``liquid_ppo_drone_final.zip`` if that file exists (otherwise builds an
    untrained one via ``make_liquid_ppo``), and then steps the environment in
    an endless loop while redrawing three panels: a 3D view with the drone,
    wind vector and target, a rolling altitude plot, and a rolling wind
    magnitude plot.

    The loop ends when the dashboard window is closed, on Ctrl-C, or when an
    exception is raised while stepping/drawing (which is reported, not
    re-raised). Errors during setup propagate to the caller. The figure and
    the environment are closed in every case.
    """
    # Function-scope imports keep this block self-contained without touching
    # the file-level import list.
    import traceback
    from collections import deque

    print("Initializing Interactive Dashboard...")
    env = Drone3DEnv(render_mode="human", wind_scale=5.0, wind_speed=2.0)
    fig = None
    try:
        # Everything after env creation sits under the outer ``finally`` so
        # that a failure while loading the model or building the figure still
        # releases the environment.
        model_path = "liquid_ppo_drone_final.zip"
        if os.path.exists(model_path):
            print(f"Loading trained model from {model_path}...")
            model = PPO.load(model_path, env=env)
        else:
            print("No trained model found. Using untrained Liquid Brain.")
            model = make_liquid_ppo(env, verbose=1)

        obs, info = env.reset()

        # Setup Dashboard
        plt.ion()
        fig = plt.figure(figsize=(14, 8))
        gs = GridSpec(2, 2, width_ratios=[2, 1])

        # 3D View (Left, spanning both rows)
        ax_3d = fig.add_subplot(gs[:, 0], projection='3d')

        # Altitude Plot (Top Right)
        ax_alt = fig.add_subplot(gs[0, 1])
        ax_alt.set_title("Altitude (Z)")
        ax_alt.set_ylim(0, 15)
        ax_alt.set_xlim(0, 100)
        line_alt, = ax_alt.plot([], [], 'b-')

        # Wind Speed Plot (Bottom Right)
        ax_wind = fig.add_subplot(gs[1, 1])
        ax_wind.set_title("Wind Magnitude")
        ax_wind.set_ylim(0, 10)
        ax_wind.set_xlim(0, 100)
        line_wind, = ax_wind.plot([], [], 'r-')

        # Data Buffers: fixed-length windows; ``maxlen`` drops the oldest
        # sample automatically on append.
        history_len = 100
        x_axis = list(range(history_len))  # loop-invariant x coordinates
        alt_history = deque([10.0] * history_len, maxlen=history_len)
        wind_history = deque([0.0] * history_len, maxlen=history_len)

        print("\n=== DASHBOARD LIVE ===")
        print("Close the window to exit.")

        try:
            step = 0
            while True:
                # Predict & Step
                action, _ = model.predict(obs, deterministic=True)
                obs, reward, term, trunc, info = env.step(action)

                # Update Data
                # NOTE(review): assumes obs[0:3] is the drone's (x, y, z)
                # position -- confirm against Drone3DEnv's observation layout.
                pos = obs[0:3]
                wind = info.get("wind", np.zeros(3))
                wind_mag = np.linalg.norm(wind)
                alt_history.append(pos[2])
                wind_history.append(wind_mag)

                # --- Render 3D View ---
                # clear() wipes limits, labels and artists, so they are all
                # re-applied every frame.
                ax_3d.clear()
                ax_3d.set_xlim(-20, 20)
                ax_3d.set_ylim(-20, 20)
                ax_3d.set_zlim(0, 20)
                ax_3d.set_xlabel('X')
                ax_3d.set_ylabel('Y')
                ax_3d.set_zlabel('Z')
                ax_3d.set_title(f'Neuro-Flyt 3D | Step: {step}')

                # Drone
                ax_3d.scatter(pos[0], pos[1], pos[2], c='blue', s=100, label='Drone')

                # Wind Vector
                ax_3d.quiver(
                    pos[0], pos[1], pos[2],
                    wind[0], wind[1], wind[2],
                    length=1.0, color='red', label='Wind',
                )

                # Target
                target = info.get("target", np.array([0, 0, 10.0]))
                ax_3d.scatter(
                    target[0], target[1], target[2],
                    c='green', marker='x', s=100, label='Target',
                )
                ax_3d.legend(loc='upper left')

                # --- Render Stats ---
                line_alt.set_data(x_axis, list(alt_history))
                line_wind.set_data(x_axis, list(wind_history))

                # Stats Text
                stats = f"Alt: {pos[2]:.2f}m\nWind: {wind_mag:.2f} N\nDrift: {np.linalg.norm(pos[:2]):.2f}m"
                ax_3d.text2D(
                    0.05, 0.95, stats,
                    transform=ax_3d.transAxes,
                    fontsize=12,
                    bbox=dict(facecolor='white', alpha=0.7),
                )

                plt.draw()
                plt.pause(0.01)

                if term or trunc:
                    obs, info = env.reset()
                step += 1

                # Check if window is closed
                if not plt.fignum_exists(fig.number):
                    break
        except KeyboardInterrupt:
            print("Interrupted.")
        except Exception as e:
            # Best-effort demo: report and fall through to cleanup, but keep
            # the traceback so the failure can actually be diagnosed.
            print(f"Error: {e}")
            traceback.print_exc()
    finally:
        # Close the dashboard figure specifically; a bare plt.close() targets
        # whichever figure happens to be current.
        if fig is not None:
            plt.close(fig)
        env.close()
# Launch the dashboard only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    run_interactive_demo()