# Lunar_Lander-V3_Discrete / trained_agent.py
# Page-header residue captured together with the file (not part of the program):
#   privateboss's picture
#   Update trained_agent.py
#   1349907 verified
import gymnasium as gym
import numpy as np
import tensorflow as tf
import os
import time
from config import ENV_ID, SEED, SAVE_PATH, TOTAL_TIMESTEPS
from agent import PPOAgent
from reward_shaping import LunarLanderRewardShaping
# Name of the sub-directory under SAVE_PATH that is searched for checkpoints.
CHECKPOINT_SUBDIR = 'tf_checkpoints'
# Directory in which checkpoint prefixes such as 'ckpt-<step>' are looked up.
CHECKPOINT_ROOT = os.path.join(SAVE_PATH, CHECKPOINT_SUBDIR)
# Make tf.function-decorated code run eagerly instead of as traced graphs.
# NOTE(review): presumably needed so the agent returns concrete values during
# playback -- confirm against agent.py before removing.
tf.config.run_functions_eagerly(True)
def _resolve_checkpoint(agent, target_step):
    """Return the checkpoint prefix to restore, or a falsy value if none exists.

    When ``target_step`` is given and ``ckpt-<target_step>`` exists under
    ``CHECKPOINT_ROOT`` that prefix is returned; otherwise the checkpoint
    manager's latest checkpoint is used.
    """
    if target_step is not None:
        checkpoint_name_prefix = f'ckpt-{target_step}'
        potential_path = os.path.join(CHECKPOINT_ROOT, checkpoint_name_prefix)
        # A checkpoint is addressed by prefix; the '.index' file is what
        # actually exists on disk for that prefix.
        if os.path.exists(f'{potential_path}.index'):
            print(f"\nAttempting to load specific checkpoint: T={target_step}")
            return potential_path
        print(f"\nERROR: Specified checkpoint prefix '{checkpoint_name_prefix}' was not found in '{CHECKPOINT_ROOT}'.")
        print("Falling back to the latest available checkpoint.")
    return agent.checkpoint_manager.latest_checkpoint


def _restore_agent(agent, checkpoint_path):
    """Restore ``agent`` from ``checkpoint_path`` and sync its observation stats."""
    # expect_partial() silences warnings about checkpoint values this script
    # does not use.
    # NOTE(review): restore() alone does not verify that the checkpoint matches
    # the agent's variables -- confirm whether a stricter check is wanted.
    agent.checkpoint.restore(checkpoint_path).expect_partial()
    # Copy the restored running statistics from the checkpointed variables
    # into the agent's obs_rms object.
    agent.obs_rms.mean = agent.rms_mean_var.numpy()
    agent.obs_rms.var = agent.rms_var_var.numpy()
    agent.obs_rms.count = agent.rms_count_var.numpy()


def _checkpoint_timesteps(checkpoint_path):
    """Return the step number encoded in a 'ckpt-<N>' name, or 0 if absent."""
    checkpoint_name = os.path.basename(checkpoint_path)
    if 'ckpt-' not in checkpoint_name:
        return 0
    try:
        return int(checkpoint_name.split('-')[-1])
    except ValueError:
        # The number is only used for a status message; an unusual suffix must
        # not be reported as a failed restore.
        return 0


def _play_episode(env, agent, obs_shape, current_obs):
    """Run one episode from ``current_obs``; return ``(total_reward, steps)``."""
    episode_reward = 0
    step_count = 0
    done = False
    while not done:
        # The agent is fed a batch of one observation.
        obs_to_agent = current_obs.reshape(1, *obs_shape)
        actions, _, _ = agent.select_action(obs_to_agent)
        # With render_mode="human" the frame is drawn inside step()/reset(),
        # so no explicit env.render() call is needed here.
        current_obs, reward, terminated, truncated, _ = env.step(actions[0])
        done = terminated or truncated
        episode_reward += reward
        step_count += 1
        time.sleep(0.01)
    return episode_reward, step_count


def run_trained_agent(episodes=10, target_step=None):
    """Play back a trained PPO agent in a human-rendered environment.

    Creates ``ENV_ID`` with ``render_mode="human"`` wrapped in
    ``LunarLanderRewardShaping``, restores the agent from a checkpoint under
    ``CHECKPOINT_ROOT``, runs ``episodes`` episodes and prints the reward of
    each one plus the average. Failures to create the environment, find a
    checkpoint, or restore it are reported on stdout and end the run early.

    Args:
        episodes: Number of playback episodes to run.
        target_step: Step number of a specific ``ckpt-<step>`` checkpoint to
            load. ``None``, or a step that cannot be found, selects the latest
            checkpoint known to the agent's checkpoint manager.

    Returns:
        None.
    """
    print(f"--- Running Trained Agent on {ENV_ID} with Human Rendering ---")
    print(f"Checking for checkpoints in: {CHECKPOINT_ROOT}")

    env = None
    try:
        env = gym.make(ENV_ID, render_mode="human")
        env = LunarLanderRewardShaping(env)
    except Exception as e:
        print(f"ERROR: Could not create environment {ENV_ID} or apply wrapper. Details: {e}")
        # gym.make may have succeeded before the wrapper failed.
        if env is not None:
            env.close()
        return

    # From here on the environment is always closed, including when the agent
    # constructor or the playback loop raises.
    try:
        obs_shape = env.observation_space.shape
        action_size = env.action_space.n
        agent = PPOAgent(obs_shape, action_size, TOTAL_TIMESTEPS)

        checkpoint_to_load = _resolve_checkpoint(agent, target_step)
        if not checkpoint_to_load:
            print("\nERROR: Could not find any checkpoint in the designated save path.")
            return

        try:
            _restore_agent(agent, checkpoint_to_load)
        except Exception as e:
            print(f"\nERROR: Failed to restore checkpoint at {checkpoint_to_load}. Details: {e}")
            print("Suggestion: Check the consistency of your environment setup (wrapper, action size, file names).")
            return

        loaded_timesteps = _checkpoint_timesteps(checkpoint_to_load)
        print(f"\nSuccessfully loaded checkpoint trained to T={loaded_timesteps}")

        print(f"\nStarting {episodes} playback episodes...")
        total_rewards = []
        for i in range(episodes):
            # Seed only the first reset; later resets continue the same
            # random stream. Resetting at the start of each episode also
            # avoids a spare reset after the final one.
            if i == 0:
                current_obs, _ = env.reset(seed=SEED)
            else:
                current_obs, _ = env.reset()
            episode_reward, step_count = _play_episode(env, agent, obs_shape, current_obs)
            total_rewards.append(episode_reward)
            print(f"Episode {i+1}: Reward = {episode_reward:7.2f}, Steps = {step_count}")
    finally:
        env.close()

    if total_rewards:
        print("-" * 30)
        print(f"Average Reward over {episodes} episodes: {np.mean(total_rewards):7.2f}")
        print("-" * 30)
if __name__ == "__main__":
    # Script entry point: play back 15 episodes; target_step is left at its
    # default (None), which selects the latest available checkpoint.
    run_trained_agent(episodes=15)