team_22 / demo_interactive.py
Antigravity Agent
Deploy Neuro-Flyt 3D Training
6083286
import numpy as np
import time
import os
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from env.drone_3d import Drone3DEnv
from models.liquid_ppo import make_liquid_ppo
from stable_baselines3 import PPO
def _load_model(env, model_path):
    """Return a PPO policy for ``env``.

    Loads the saved model at ``model_path`` when that file exists; otherwise
    builds a fresh, untrained policy via ``make_liquid_ppo``.
    """
    if os.path.exists(model_path):
        print(f"Loading trained model from {model_path}...")
        return PPO.load(model_path, env=env)
    print("No trained model found. Using untrained Liquid Brain.")
    return make_liquid_ppo(env, verbose=1)


def _build_dashboard(history_len):
    """Create the dashboard figure: a 3D view plus altitude and wind strip charts.

    Args:
        history_len: Number of samples shown on the strip charts' x-axis.

    Returns:
        ``(fig, ax_3d, line_alt, line_wind)``: the figure, the 3D axes, and
        the two ``Line2D`` artists updated every frame.
    """
    fig = plt.figure(figsize=(14, 8))
    gs = GridSpec(2, 2, width_ratios=[2, 1])

    # 3D view (left column, spanning both rows).
    ax_3d = fig.add_subplot(gs[:, 0], projection='3d')

    # Altitude strip chart (top right).
    ax_alt = fig.add_subplot(gs[0, 1])
    ax_alt.set_title("Altitude (Z)")
    ax_alt.set_ylim(0, 15)
    ax_alt.set_xlim(0, history_len)
    line_alt, = ax_alt.plot([], [], 'b-')

    # Wind-magnitude strip chart (bottom right).
    ax_wind = fig.add_subplot(gs[1, 1])
    ax_wind.set_title("Wind Magnitude")
    ax_wind.set_ylim(0, 10)
    ax_wind.set_xlim(0, history_len)
    line_wind, = ax_wind.plot([], [], 'r-')

    return fig, ax_3d, line_alt, line_wind


def _draw_3d_view(ax_3d, step, pos, wind, wind_mag, target):
    """Redraw the 3D scene for one frame.

    Draws the drone at ``pos``, the wind vector anchored at the drone, the
    target marker, and a text box with altitude / wind / horizontal drift.
    ``pos``, ``wind`` and ``target`` are indexed as [x, y, z].
    """
    # clear() also wipes limits, labels and title, so they are re-applied each frame.
    ax_3d.clear()
    ax_3d.set_xlim(-20, 20)
    ax_3d.set_ylim(-20, 20)
    ax_3d.set_zlim(0, 20)
    ax_3d.set_xlabel('X')
    ax_3d.set_ylabel('Y')
    ax_3d.set_zlabel('Z')
    ax_3d.set_title(f'Neuro-Flyt 3D | Step: {step}')

    ax_3d.scatter(pos[0], pos[1], pos[2], c='blue', s=100, label='Drone')
    ax_3d.quiver(pos[0], pos[1], pos[2], wind[0], wind[1], wind[2], length=1.0, color='red', label='Wind')
    ax_3d.scatter(target[0], target[1], target[2], c='green', marker='x', s=100, label='Target')
    ax_3d.legend(loc='upper left')

    # Drift is the horizontal distance of the drone from the x/y origin.
    stats = f"Alt: {pos[2]:.2f}m\nWind: {wind_mag:.2f} N\nDrift: {np.linalg.norm(pos[:2]):.2f}m"
    ax_3d.text2D(0.05, 0.95, stats, transform=ax_3d.transAxes, fontsize=12, bbox=dict(facecolor='white', alpha=0.7))


def run_interactive_demo(model_path="liquid_ppo_drone_final.zip", wind_scale=5.0,
                         wind_speed=2.0, history_len=100, max_steps=None):
    """Run a live matplotlib dashboard of the drone policy flying in ``Drone3DEnv``.

    Each loop iteration queries the policy deterministically, steps the env,
    and redraws a 3D view plus rolling altitude and wind-magnitude charts.
    The episode is reset on termination/truncation. The loop runs until the
    dashboard window is closed, Ctrl-C is pressed, or ``max_steps`` is reached.

    Args:
        model_path: Saved PPO model to load. If the file is missing, an
            untrained policy from ``make_liquid_ppo`` is used instead.
        wind_scale: Forwarded to ``Drone3DEnv(wind_scale=...)``.
        wind_speed: Forwarded to ``Drone3DEnv(wind_speed=...)``.
        history_len: Number of recent samples kept on the strip charts.
        max_steps: Optional cap on env steps; ``None`` runs indefinitely.

    Errors during model loading or dashboard setup propagate to the caller.
    Errors inside the render loop are printed with a traceback and the demo
    exits. The env and the dashboard figure are always closed.
    """
    import traceback
    from collections import deque

    print("Initializing Interactive Dashboard...")
    env = Drone3DEnv(render_mode="human", wind_scale=wind_scale, wind_speed=wind_speed)
    fig = None
    try:
        # Setup lives inside the outer try so env.close() runs even if loading fails.
        model = _load_model(env, model_path)
        obs, info = env.reset()

        plt.ion()
        fig, ax_3d, line_alt, line_wind = _build_dashboard(history_len)

        # Fixed-length rolling buffers; deque(maxlen) drops the oldest sample in O(1).
        x_axis = range(history_len)
        alt_history = deque([10.0] * history_len, maxlen=history_len)
        wind_history = deque([0.0] * history_len, maxlen=history_len)

        print("\n=== DASHBOARD LIVE ===")
        print("Close the window to exit.")
        try:
            step = 0
            while max_steps is None or step < max_steps:
                # Stop before touching a figure the user has already closed.
                if not plt.fignum_exists(fig.number):
                    break

                action, _ = model.predict(obs, deterministic=True)
                obs, _reward, term, trunc, info = env.step(action)

                # NOTE(review): the dashboard treats obs[0:3] as the drone's
                # (x, y, z) position; confirm against Drone3DEnv's obs layout.
                pos = obs[0:3]
                wind = info.get("wind", np.zeros(3))
                wind_mag = np.linalg.norm(wind)
                alt_history.append(pos[2])
                wind_history.append(wind_mag)

                target = info.get("target", np.array([0, 0, 10.0]))
                _draw_3d_view(ax_3d, step, pos, wind, wind_mag, target)
                line_alt.set_data(x_axis, list(alt_history))
                line_wind.set_data(x_axis, list(wind_history))

                # Redraw *our* figure explicitly: with render_mode="human" the
                # env may own another pyplot figure, so plt.draw() (current
                # figure) could target the wrong window.
                fig.canvas.draw_idle()
                plt.pause(0.01)

                if term or trunc:
                    obs, info = env.reset()
                step += 1
        except KeyboardInterrupt:
            print("Interrupted.")
        except Exception as e:
            # Best-effort demo: report and exit, but keep the traceback for debugging.
            print(f"Error: {e}")
            traceback.print_exc()
    finally:
        if fig is not None:
            # Close the dashboard figure specifically, not whatever is "current".
            plt.close(fig)
        env.close()
if __name__ == "__main__":
    # Script entry point: launch the dashboard only when run directly, never on import.
    run_interactive_demo()