Spaces:
Runtime error
Runtime error
| import cv2 | |
| import gradio as gr | |
| import time | |
| from huggingface_sb3 import load_from_hub | |
| from stable_baselines3 import PPO | |
| from stable_baselines3.common.env_util import make_atari_env | |
| from stable_baselines3.common.vec_env import VecFrameStack | |
| from stable_baselines3.common.env_util import make_atari_env | |
# Module-level step budget.
# NOTE(review): replay() does not read this name; it applies its own per-call
# step limit. Kept only in case other code imports it — confirm before removing.
max_steps = 5000
# The loading helpers below are adapted from Edward Beeching's code.
def load_env(env_name):
    """Create the vectorised Atari environment used for playback.

    A single environment (``n_envs=1``) is built with ``make_atari_env`` and
    wrapped in ``VecFrameStack`` so each observation stacks the last 4 frames.

    Args:
        env_name: Gym/Atari environment id, e.g. ``"PongNoFrameskip-v4"``.

    Returns:
        The frame-stacked vectorised environment.
    """
    return VecFrameStack(make_atari_env(env_name, n_envs=1), n_stack=4)
def load_model(env_name):
    """Fetch the pretrained PPO checkpoint for *env_name* and load it.

    The checkpoint is downloaded from the hub repository
    ``ThomasSimonini/ppo-{env_name}`` (file ``ppo-{env_name}.zip``).

    Args:
        env_name: Atari environment id; used to build the repo and file names.

    Returns:
        The loaded ``PPO`` model.
    """
    repo_id = f"ThomasSimonini/ppo-{env_name}"
    archive_name = f"ppo-{env_name}.zip"
    checkpoint = load_from_hub(repo_id, archive_name)

    def _zero(_progress):
        # Constant schedule returning 0.0 regardless of training progress.
        return 0.0

    # Override the training-only hyperparameters stored in the archive.
    # NOTE(review): presumably done so that pickled schedules saved under a
    # different Python/SB3 version need not be deserialised — confirm.
    overrides = dict(
        learning_rate=0.0,
        lr_schedule=_zero,
        clip_range=_zero,
    )
    return PPO.load(checkpoint, custom_objects=overrides)
def replay(env_name, time_sleep, max_steps=500):
    """Stream rendered frames of a pretrained PPO agent playing *env_name*.

    Args:
        env_name: Atari environment id; also selects the checkpoint loaded by
            ``load_model``.
        time_sleep: Seconds to pause after each environment step, throttling
            how fast frames are streamed to the UI.
        max_steps: Maximum number of steps (and therefore frames) to play.
            Defaults to 500, the limit previously hard-coded here.

    Yields:
        The frame from ``env.render(mode="rgb_array")``, captured before each
        step. Iteration stops early once the environment reports ``done``.
    """
    env = load_env(env_name)
    try:
        model = load_model(env_name)
        obs = env.reset()
        # Bounded loop: exactly max_steps frames at most (the old counter
        # started at 1 and compared with <, yielding one frame fewer).
        for _ in range(max_steps):
            frame = env.render(mode="rgb_array")
            # predict() on a vectorised observation already returns one
            # action per env, which is the shape VecEnv.step expects — no
            # extra list wrapping.
            action, _states = model.predict(obs)
            obs, _reward, done, _info = env.step(action)
            time.sleep(time_sleep)
            yield frame
            # done holds one flag per env (load_env builds a single env).
            # NOTE(review): the Atari wrapper may flag done on life loss
            # rather than game over — confirm against make_atari_env defaults.
            if done.any():
                break
    finally:
        # Runs on normal exit, on error, and when the consumer closes the
        # generator (e.g. the UI request is cancelled).
        env.close()
# Gradio UI: pick an environment and a frame delay; replay() streams frames
# into the image output.
demo = gr.Interface(
    replay,
    [
        gr.Dropdown(
            [
                "SpaceInvadersNoFrameskip-v4",
                "PongNoFrameskip-v4",
                "SeaquestNoFrameskip-v4",
                "QbertNoFrameskip-v4",
            ],
            # A default avoids passing None to replay() when the user
            # submits without choosing an environment.
            value="SpaceInvadersNoFrameskip-v4",
            label="env_name",
        ),
        gr.Slider(0.01, 1, value=0.05, label="time_sleep"),
    ],
    gr.Image(),
    title="Watch Agents playing Atari games 🤖",
    description="Select an environment to watch a Hugging Face's trained deep reinforcement learning agent.",
    article="time_sleep is the time delay between each frame (0.05 by default).",
)

# replay is a generator, so streaming needs the queue enabled. queue() must be
# called BEFORE launch(): launch() blocks when run as a script, so the
# original `.launch().queue()` never enabled queueing in time.
demo.queue().launch()