Spaces:
Runtime error
Runtime error
File size: 5,570 Bytes
6083286 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
"""
Evaluate a trained PPO MLP baseline on the DroneWindEnv environment.
This script loads a saved PPO model and runs evaluation episodes,
printing statistics about average reward and episode length.
"""
import os
import sys
import argparse
import numpy as np
from stable_baselines3 import PPO
# Prepend the project root (the parent of this script's directory) to the
# module search path so the project's ``env`` package can be imported when
# this file is run directly as a script.
sys.path[:0] = [os.path.dirname(os.path.dirname(os.path.abspath(__file__)))]
from env.drone_env import DroneWindEnv
def _positive_int(value):
    """Argparse ``type`` callback: parse *value* as an int and require it to be > 0."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number <= 0:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def _parse_args(argv=None):
    """Build the CLI parser and parse *argv* (``None`` means ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser(description="Evaluate PPO agent on DroneWindEnv")
    parser.add_argument(
        "--model-path",
        type=str,
        default="models/mlp_baseline.zip",
        help="Path to the trained model (default: models/mlp_baseline.zip)"
    )
    parser.add_argument(
        "--episodes",
        # Validated as > 0: the summary statistics below are undefined for zero episodes.
        type=_positive_int,
        default=10,
        help="Number of evaluation episodes (default: 10)"
    )
    parser.add_argument(
        "--render",
        action="store_true",
        help="Print environment state to console during evaluation"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for evaluation (default: None)"
    )
    return parser.parse_args(argv)


def _run_episode(env, model, episode_number, *, seed=None, render=False):
    """Roll out one episode using the model's deterministic policy.

    Args:
        env: Environment with the Gymnasium-style API used here
            (``reset(seed=...)`` -> ``(obs, info)``; ``step`` -> 5-tuple).
        model: Loaded policy exposing ``predict(obs, deterministic=...)``.
        episode_number: 1-based episode index, used only for the render header.
        seed: Forwarded to ``env.reset``; ``None`` does not reseed.
        render: If true, call ``env.render()`` after the reset and after every step.

    Returns:
        ``(total_reward, step_count, terminated)``. ``terminated`` is False when
        the episode ended by truncation instead.
    """
    obs, _info = env.reset(seed=seed)
    terminated = False
    truncated = False
    total_reward = 0.0
    step_count = 0

    if render:
        print(f"\nEpisode {episode_number}:")
        env.render()

    while not (terminated or truncated):
        # Get action from the model (deterministic)
        action, _state = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, _info = env.step(action)
        total_reward += reward
        step_count += 1
        if render:
            env.render()

    return total_reward, step_count, terminated


def _print_summary(rewards, episode_lengths):
    """Print aggregate and per-episode statistics; both lists must be non-empty."""
    print("\n" + "=" * 60)
    print("Evaluation Results")
    print("=" * 60)
    print(f"Average reward: {np.mean(rewards):.2f} ± {np.std(rewards):.2f}")
    print(f"Average episode length: {np.mean(episode_lengths):.1f} ± {np.std(episode_lengths):.1f}")
    print(f"Min reward: {np.min(rewards):.2f}")
    print(f"Max reward: {np.max(rewards):.2f}")
    print(f"Min episode length: {np.min(episode_lengths)}")
    print(f"Max episode length: {np.max(episode_lengths)}")
    print("=" * 60)

    print("\nPer-episode rewards:")
    for i, reward in enumerate(rewards, 1):
        print(f" Episode {i}: {reward:.2f}")


def _plot_results(rewards, episode_lengths, output_path="eval_results.png"):
    """Best effort: save reward/length plots to *output_path* and show them briefly.

    Silently skipped when matplotlib is not installed. Any other plotting
    failure is reported and otherwise ignored, so evaluation never fails
    because of plotting.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        return  # Matplotlib not available, skip plotting

    fig = None
    try:
        fig, (ax_reward, ax_length) = plt.subplots(1, 2, figsize=(10, 5))

        # Plot 1: Episode rewards
        mean_reward = np.mean(rewards)
        ax_reward.plot(range(1, len(rewards) + 1), rewards, 'o-', linewidth=2, markersize=6)
        ax_reward.axhline(y=mean_reward, color='r', linestyle='--', label=f'Mean: {mean_reward:.2f}')
        ax_reward.set_xlabel('Episode')
        ax_reward.set_ylabel('Total Reward')
        ax_reward.set_title('Episode Rewards')
        ax_reward.grid(True, alpha=0.3)
        ax_reward.legend()

        # Plot 2: Episode lengths
        mean_length = np.mean(episode_lengths)
        ax_length.plot(range(1, len(episode_lengths) + 1), episode_lengths, 's-',
                       linewidth=2, markersize=6, color='green')
        ax_length.axhline(y=mean_length, color='r', linestyle='--',
                          label=f'Mean: {mean_length:.1f}')
        ax_length.set_xlabel('Episode')
        ax_length.set_ylabel('Episode Length')
        ax_length.set_title('Episode Lengths')
        ax_length.grid(True, alpha=0.3)
        ax_length.legend()

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
        print(f"\n✓ Evaluation plots saved to {output_path}")
        # The window is not modal: it is shown non-blocking and closed below.
        print(" (Plot window closes automatically after 2 seconds)")
        plt.show(block=False)
        plt.pause(2)  # Show for 2 seconds
    except Exception as e:
        print(f"\nNote: Could not generate plots: {e}")
    finally:
        # Release the figure even when saving/showing failed.
        if fig is not None:
            plt.close(fig)


def main(argv=None):
    """Evaluate a saved PPO model on DroneWindEnv and report statistics.

    Args:
        argv: Command-line arguments to parse; ``None`` uses ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 1 if the model file is missing or
        cannot be loaded.
    """
    args = _parse_args(argv)

    print("=" * 60)
    print("Evaluating PPO Agent on DroneWindEnv")
    print("=" * 60)
    print(f"Model path: {args.model_path}")
    print(f"Number of episodes: {args.episodes}")
    print("=" * 60)

    if not os.path.exists(args.model_path):
        print(f"\nError: Model file not found at {args.model_path}")
        print("Please train a model first using:")
        print(" python train/train_mlp_ppo.py")
        return 1

    print("\nCreating environment...")
    env = DroneWindEnv()
    rewards = []
    episode_lengths = []
    try:
        print(f"Loading model from {args.model_path}...")
        try:
            model = PPO.load(args.model_path, env=env)
        except Exception as e:
            print(f"\nError loading model: {e}")
            return 1
        print("Model loaded successfully!")

        print(f"\nRunning {args.episodes} evaluation episodes...")
        print("-" * 60)
        for episode in range(args.episodes):
            # Seed only the first reset. Passing an int to reset() reseeds the
            # env's RNG every time, so reusing the seed for every episode with a
            # deterministic policy would replay one identical episode N times.
            seed = args.seed if episode == 0 else None
            total_reward, step_count, terminated = _run_episode(
                env, model, episode + 1, seed=seed, render=args.render
            )
            rewards.append(total_reward)
            episode_lengths.append(step_count)
            status = "terminated" if terminated else "truncated"
            print(f"Episode {episode + 1}: Reward = {total_reward:.2f}, "
                  f"Length = {step_count} steps ({status})")
    finally:
        # Always release environment resources, including on load failure.
        env.close()

    _print_summary(rewards, episode_lengths)
    _plot_results(rewards, episode_lengths)
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status so shells and
    # CI can detect failures; sys.exit(None) still exits with status 0.
    sys.exit(main())
|