Spaces:
Runtime error
Runtime error
| """ | |
| Evaluate a trained PPO MLP baseline on the DroneWindEnv environment. | |
| This script loads a saved PPO model and runs evaluation episodes, | |
| printing statistics about average reward and episode length. | |
| """ | |
| import os | |
| import sys | |
| import argparse | |
| import numpy as np | |
| from stable_baselines3 import PPO | |
| # Add project root to path | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from env.drone_env import DroneWindEnv | |
def _parse_args(argv=None):
    """Parse the evaluation script's command-line options.

    Args:
        argv: Sequence of argument strings. ``None`` lets argparse fall back
            to ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``model_path``, ``episodes``,
        ``render`` and ``seed`` attributes.
    """
    parser = argparse.ArgumentParser(description="Evaluate PPO agent on DroneWindEnv")
    parser.add_argument(
        "--model-path",
        type=str,
        default="models/mlp_baseline.zip",
        help="Path to the trained model (default: models/mlp_baseline.zip)",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        default=10,
        help="Number of evaluation episodes (default: 10)",
    )
    parser.add_argument(
        "--render",
        action="store_true",
        help="Print environment state to console during evaluation",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for evaluation (default: None)",
    )
    return parser.parse_args(argv)


def _run_episode(env, model, episode_number, seed=None, render=False):
    """Roll out a single episode using the model's deterministic action.

    Args:
        env: Environment exposing ``reset(seed=...)``, ``step(action)`` and
            ``render()`` as used below.
        model: Object exposing ``predict(obs, deterministic=True)``.
        episode_number: 1-based episode index, used only for the render header.
        seed: Seed forwarded to ``env.reset``; ``None`` leaves it unseeded.
        render: When true, call ``env.render()`` after reset and every step.

    Returns:
        Tuple ``(total_reward, step_count, terminated)`` where ``terminated``
        is the first of the two end-of-episode flags returned by ``env.step``.
    """
    obs, _info = env.reset(seed=seed)
    if render:
        print(f"\nEpisode {episode_number}:")
        env.render()

    terminated = False
    truncated = False
    total_reward = 0.0
    step_count = 0
    while not (terminated or truncated):
        action, _state = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, _info = env.step(action)
        # Accumulate as a plain float so the ":.2f" formatting below always works.
        total_reward += float(reward)
        step_count += 1
        if render:
            env.render()
    return total_reward, step_count, terminated


def _print_summary(rewards, episode_lengths):
    """Print aggregate and per-episode statistics for the evaluation run."""
    print("\n" + "=" * 60)
    print("Evaluation Results")
    print("=" * 60)
    print(f"Average reward: {np.mean(rewards):.2f} ± {np.std(rewards):.2f}")
    print(f"Average episode length: {np.mean(episode_lengths):.1f} ± {np.std(episode_lengths):.1f}")
    print(f"Min reward: {np.min(rewards):.2f}")
    print(f"Max reward: {np.max(rewards):.2f}")
    print(f"Min episode length: {np.min(episode_lengths)}")
    print(f"Max episode length: {np.max(episode_lengths)}")
    print("=" * 60)

    print("\nPer-episode rewards:")
    for i, reward in enumerate(rewards, 1):
        print(f" Episode {i}: {reward:.2f}")


def _save_plots(rewards, episode_lengths):
    """Best-effort plotting of rewards and lengths to ``eval_results.png``.

    Plotting is optional: a missing matplotlib is skipped silently, and any
    other plotting failure is reported as a note without aborting the run.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        # Matplotlib not available, skip plotting
        return

    try:
        mean_reward = np.mean(rewards)
        mean_length = np.mean(episode_lengths)

        plt.figure(figsize=(10, 5))
        try:
            # Plot 1: Episode rewards
            plt.subplot(1, 2, 1)
            plt.plot(range(1, len(rewards) + 1), rewards, 'o-', linewidth=2, markersize=6)
            plt.axhline(y=mean_reward, color='r', linestyle='--', label=f'Mean: {mean_reward:.2f}')
            plt.xlabel('Episode')
            plt.ylabel('Total Reward')
            plt.title('Episode Rewards')
            plt.grid(True, alpha=0.3)
            plt.legend()

            # Plot 2: Episode lengths
            plt.subplot(1, 2, 2)
            plt.plot(range(1, len(episode_lengths) + 1), episode_lengths, 's-',
                     linewidth=2, markersize=6, color='green')
            plt.axhline(y=mean_length, color='r', linestyle='--',
                        label=f'Mean: {mean_length:.1f}')
            plt.xlabel('Episode')
            plt.ylabel('Episode Length')
            plt.title('Episode Lengths')
            plt.grid(True, alpha=0.3)
            plt.legend()

            plt.tight_layout()
            plt.savefig('eval_results.png', dpi=150, bbox_inches='tight')
            print("\n✓ Evaluation plots saved to eval_results.png")
            # The window is non-blocking and is closed below after the pause,
            # so the message must not ask the user to close it.
            print(" (Plot window closes automatically after 2 seconds)")
            plt.show(block=False)
            plt.pause(2)  # Show for 2 seconds
        finally:
            # Release the figure even when drawing or saving failed part-way.
            plt.close()
    except Exception as e:
        print(f"\nNote: Could not generate plots: {e}")


def main(argv=None):
    """Evaluate a saved PPO model on DroneWindEnv and report statistics.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            parses ``sys.argv[1:]``, so calling ``main()`` behaves as before.

    Returns:
        ``0`` when the evaluation completed, ``1`` when it could not run
        (invalid episode count, missing model file, or model load failure).
    """
    args = _parse_args(argv)

    print("=" * 60)
    print("Evaluating PPO Agent on DroneWindEnv")
    print("=" * 60)
    print(f"Model path: {args.model_path}")
    print(f"Number of episodes: {args.episodes}")
    print("=" * 60)

    # With no episodes the statistics below are undefined (np.min([]) raises),
    # so reject the request up front.
    if args.episodes < 1:
        print(f"\nError: --episodes must be at least 1 (got {args.episodes})")
        return 1

    # Check if model file exists
    if not os.path.exists(args.model_path):
        print(f"\nError: Model file not found at {args.model_path}")
        print("Please train a model first using:")
        print(" python train/train_mlp_ppo.py")
        return 1

    # Create environment
    print("\nCreating environment...")
    env = DroneWindEnv()
    try:
        # Load the model
        print(f"Loading model from {args.model_path}...")
        try:
            model = PPO.load(args.model_path, env=env)
        except Exception as e:
            # Script-level boundary: report any load failure and stop cleanly.
            print(f"\nError loading model: {e}")
            return 1
        print("Model loaded successfully!")

        # Run evaluation episodes
        print(f"\nRunning {args.episodes} evaluation episodes...")
        print("-" * 60)
        rewards = []
        episode_lengths = []
        for episode in range(args.episodes):
            # Re-using the same seed on every reset would replay one identical
            # episode N times; offset it so runs stay reproducible but distinct.
            episode_seed = None if args.seed is None else args.seed + episode
            total_reward, step_count, terminated = _run_episode(
                env, model, episode + 1, seed=episode_seed, render=args.render
            )
            rewards.append(total_reward)
            episode_lengths.append(step_count)
            status = "terminated" if terminated else "truncated"
            print(f"Episode {episode + 1}: Reward = {total_reward:.2f}, "
                  f"Length = {step_count} steps ({status})")
    finally:
        # Always release environment resources, including on early return.
        env.close()

    _print_summary(rewards, episode_lengths)
    _save_plots(rewards, episode_lengths)
    return 0


if __name__ == "__main__":
    # Propagate the status so shells and CI can detect a failed evaluation.
    sys.exit(main())