# team_22 / eval /eval_mlp_baseline.py
# Antigravity Agent
# Deploy Neuro-Flyt 3D Training
# 6083286
"""
Evaluate a trained PPO MLP baseline on the DroneWindEnv environment.
This script loads a saved PPO model and runs evaluation episodes,
printing statistics about average reward and episode length.
"""
import os
import sys
import argparse
import numpy as np
from stable_baselines3 import PPO
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from env.drone_env import DroneWindEnv
def _build_arg_parser():
    """Build the command-line parser for the evaluation script."""
    parser = argparse.ArgumentParser(description="Evaluate PPO agent on DroneWindEnv")
    parser.add_argument(
        "--model-path",
        type=str,
        default="models/mlp_baseline.zip",
        help="Path to the trained model (default: models/mlp_baseline.zip)",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        default=10,
        help="Number of evaluation episodes (default: 10)",
    )
    parser.add_argument(
        "--render",
        action="store_true",
        help="Print environment state to console during evaluation",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for evaluation (default: None)",
    )
    return parser


def _run_episode(env, model, episode_number, seed=None, render=False):
    """Roll out one episode using the model's deterministic action.

    Args:
        env: Environment exposing ``reset(seed=...)``, ``step(action)`` and
            ``render()``; ``step`` must return a 5-tuple whose third and
            fourth items are the terminated / truncated flags.
        model: Object exposing ``predict(obs, deterministic=True)``.
        episode_number: 1-based episode index, used only for the render header.
        seed: Value forwarded to ``env.reset``; ``None`` leaves the
            environment's random state untouched.
        render: When true, call ``env.render()`` after reset and every step.

    Returns:
        Tuple ``(total_reward, step_count, terminated)``.
    """
    obs, _info = env.reset(seed=seed)
    terminated = False
    truncated = False
    total_reward = 0.0
    step_count = 0

    if render:
        print(f"\nEpisode {episode_number}:")
        env.render()

    while not (terminated or truncated):
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, _info = env.step(action)
        # Convert explicitly so the running total stays a Python float even
        # if the environment hands back a NumPy scalar.
        total_reward += float(reward)
        step_count += 1
        if render:
            env.render()

    return total_reward, step_count, terminated


def _print_summary(rewards, episode_lengths):
    """Print aggregate and per-episode statistics for non-empty result lists."""
    print("\n" + "=" * 60)
    print("Evaluation Results")
    print("=" * 60)
    print(f"Average reward: {np.mean(rewards):.2f} ± {np.std(rewards):.2f}")
    print(f"Average episode length: {np.mean(episode_lengths):.1f} ± {np.std(episode_lengths):.1f}")
    print(f"Min reward: {np.min(rewards):.2f}")
    print(f"Max reward: {np.max(rewards):.2f}")
    print(f"Min episode length: {np.min(episode_lengths)}")
    print(f"Max episode length: {np.max(episode_lengths)}")
    print("=" * 60)

    print("\nPer-episode rewards:")
    for i, reward in enumerate(rewards, 1):
        print(f" Episode {i}: {reward:.2f}")


def _save_plots(rewards, episode_lengths):
    """Best-effort: save reward/length plots to eval_results.png and show them briefly."""
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        # Plotting is optional; silently skip when matplotlib is not installed.
        return

    try:
        plt.figure(figsize=(10, 5))

        # Left panel: total reward per episode with the mean as a dashed line.
        plt.subplot(1, 2, 1)
        plt.plot(range(1, len(rewards) + 1), rewards, 'o-', linewidth=2, markersize=6)
        plt.axhline(y=np.mean(rewards), color='r', linestyle='--',
                    label=f'Mean: {np.mean(rewards):.2f}')
        plt.xlabel('Episode')
        plt.ylabel('Total Reward')
        plt.title('Episode Rewards')
        plt.grid(True, alpha=0.3)
        plt.legend()

        # Right panel: episode length with the mean as a dashed line.
        plt.subplot(1, 2, 2)
        plt.plot(range(1, len(episode_lengths) + 1), episode_lengths, 's-',
                 linewidth=2, markersize=6, color='green')
        plt.axhline(y=np.mean(episode_lengths), color='r', linestyle='--',
                    label=f'Mean: {np.mean(episode_lengths):.1f}')
        plt.xlabel('Episode')
        plt.ylabel('Episode Length')
        plt.title('Episode Lengths')
        plt.grid(True, alpha=0.3)
        plt.legend()

        plt.tight_layout()
        plt.savefig('eval_results.png', dpi=150, bbox_inches='tight')
        print("\n✓ Evaluation plots saved to eval_results.png")
        print(" (Close the plot window to continue)")
        plt.show(block=False)
        plt.pause(2)  # Keep the window up for two seconds, then close it.
        plt.close()
    except Exception as e:
        # Plotting must never fail the evaluation (e.g. no display available).
        print(f"\nNote: Could not generate plots: {e}")


def main():
    """Evaluate a saved PPO model on DroneWindEnv and report statistics.

    Parses ``--model-path``, ``--episodes``, ``--render`` and ``--seed`` from
    the command line, runs the requested number of deterministic episodes,
    prints per-episode and aggregate results, and tries to save plots.

    Invalid arguments, a missing model file, or a model that fails to load
    are reported on stdout and the function returns early.
    """
    args = _build_arg_parser().parse_args()

    print("=" * 60)
    print("Evaluating PPO Agent on DroneWindEnv")
    print("=" * 60)
    print(f"Model path: {args.model_path}")
    print(f"Number of episodes: {args.episodes}")
    print("=" * 60)

    # The summary statistics are undefined for zero episodes (np.min of an
    # empty list raises), so reject that before doing any work.
    if args.episodes < 1:
        print(f"\nError: --episodes must be a positive integer (got {args.episodes})")
        return

    if not os.path.exists(args.model_path):
        print(f"\nError: Model file not found at {args.model_path}")
        print("Please train a model first using:")
        print(" python train/train_mlp_ppo.py")
        return

    print("\nCreating environment...")
    env = DroneWindEnv()
    try:
        print(f"Loading model from {args.model_path}...")
        try:
            model = PPO.load(args.model_path, env=env)
        except Exception as e:
            print(f"\nError loading model: {e}")
            return
        print("Model loaded successfully!")

        print(f"\nRunning {args.episodes} evaluation episodes...")
        print("-" * 60)

        rewards = []
        episode_lengths = []
        for episode in range(args.episodes):
            # Seed only the first reset. Passing the same seed to every reset
            # would replay one identical episode under a deterministic policy,
            # making the statistics meaningless. This assumes the environment
            # follows the Gymnasium convention that reset(seed=None) keeps the
            # existing RNG state.
            episode_seed = args.seed if episode == 0 else None
            total_reward, step_count, terminated = _run_episode(
                env, model, episode + 1, seed=episode_seed, render=args.render
            )
            rewards.append(total_reward)
            episode_lengths.append(step_count)

            status = "terminated" if terminated else "truncated"
            print(f"Episode {episode + 1}: Reward = {total_reward:.2f}, "
                  f"Length = {step_count} steps ({status})")
    finally:
        # Release whatever the environment holds, on every exit path.
        env.close()

    _print_summary(rewards, episode_lengths)
    _save_plots(rewards, episode_lengths)
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()