Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import os | |
| from moviepy.editor import * | |
def replay(option):
    """Return the path of a playable replay video for the selected agent.

    The dropdown label is matched on its environment-name prefix only, so the
    decorative emoji that follows the name does not have to match exactly.

    Args:
        option: Label chosen in the dropdown; expected to start with
            "LunarLander-v2", "CartPole-v1" or "Atari Space Invaders".

    Returns:
        The string 'new_filename.mp4', i.e. the re-encoded copy of the
        pre-recorded video that was written by this call.

    Raises:
        ValueError: If ``option`` is not a string naming a known agent.
    """
    # Environment-name prefix of each dropdown label -> pre-recorded video.
    recordings = {
        "LunarLander-v2": "./LunarLander-v2.mp4",
        "CartPole-v1": "./CartPole-v1.mp4",
        "Atari Space Invaders": "./SpaceInvadersNoFrameskip-v4.mp4",
    }

    path = ""
    if isinstance(option, str):
        for prefix, recording in recordings.items():
            if option.startswith(prefix):
                path = recording
                break

    # Fail early with a clear message instead of handing an empty path to
    # VideoFileClip when nothing (or something unknown) was selected.
    if not path:
        raise ValueError(f"Unknown agent option: {option!r}")

    # Workaround for a Base64 video problem: re-encode the recording into a
    # new file and return that file rather than the original one.
    videoclip = VideoFileClip(path)
    try:
        videoclip.write_videofile("new_filename.mp4")
    finally:
        # Release the reader held by the clip even if encoding fails.
        videoclip.close()
    return "new_filename.mp4"
# Labels offered in the dropdown; the selected label is passed to `replay`.
_agent_choices = [
    "Atari Space Invaders πΎ",
    "CartPole-v1 πΉοΈ",
    "LunarLander-v2 ππ©βπ",
]

# HTML passed as the interface's `article` text.
_article_html = '''<div>
<p style="text-align: center">This version of the RL library allows you to load models directly from the Hugging Face Hub</p>
<p style="text-align: center"> Select the trained agent you want to watch perform.
These models are from <a href="https://github.com/araffin/rl-baselines-zoo">Stable Baseline Zoo</a></p>
<p>
There are currently 3 models:
<ul>
<li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-SpaceInvadersNoFrameskip-v4">PPO SpaceInvadersNoFrameskip-v4</a></li>
<li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-LunarLander-v2">PPO LunarLander-v2</a></li>
<li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-CartPole-v1">PPO CartPole-v1</a></li>
</ul>
</div>'''

# One dropdown input, one video output, served by `replay`.
iface = gr.Interface(
    replay,
    [gr.inputs.Dropdown(_agent_choices)],
    "video",
    title='Stable Baselines 3 with π€',
    description='',
    article=_article_html,
)

iface.launch()
| """ | |
| TODO: Next version with live video generation. The draft below is not executed: it is wrapped in a module-level string literal. | |
| import gradio as gr | |
| import os | |
| from Recorder import Recorder | |
| from stable_baselines3 import PPO | |
| #The Agent plays and we generate the video | |
| def replay(option): | |
| video_path = "" | |
| # Get the correct model | |
| if (option == "LunarLander-v2 ππ©βπ"): | |
| env_name = "Lunar Lander v2" | |
| agent_name = "PPO" | |
| print("TEST") | |
| hf_model_filename = "LunarLander-v2" | |
| hf_model_id = "ThomasSimonini/stable-baselines3-ppo-LunarLander-v2" | |
| video_path = replay_gym(hf_model_filename, hf_model_id) | |
| elif(option == "CartPole-v1 πΉοΈ"): | |
| hf_model_filename = "CartPole-v1" | |
| hf_model_id = "ThomasSimonini/stable-baselines3-ppo-CartPole-v1" | |
| video_path = replay_gym(hf_model_filename, hf_model_id) | |
| elif(option == "Atari Space Invaders πΎ"): | |
| hf_model_filename = "SpaceInvadersNoFrameskip-v4" | |
| hf_model_id = "ThomasSimonini/stable-baselines3-ppo-SpaceInvadersNoFrameskip-v4" | |
| video_path = replay_atari(hf_model_filename, hf_model_id) | |
| #video_path = "./SpaceInvadersNoFrameskip-v4.mp4" | |
| return video_path | |
| def replay_gym(hf_model_filename, hf_model_id): | |
| import gym | |
| from stable_baselines3.common.evaluation import evaluate_policy | |
| model = PPO.load_from_huggingface(hf_model_id,hf_model_filename) | |
| eval_env = gym.make(hf_model_filename) | |
| directory = './video' | |
| env = Recorder(eval_env, directory) | |
| obs = env.reset() | |
| done = False | |
| while not done: | |
| action, _state = model.predict(obs) | |
| obs, reward, done, info = env.step(action) | |
| clip = env.play() | |
| return clip | |
| def replay_atari(hf_model_filename, hf_model_id): | |
| os.system("python -m atari_py.import_roms \"content/atari_roms\"") | |
| import gym | |
| from stable_baselines3.common.env_util import make_atari_env | |
| from stable_baselines3.common.vec_env import VecFrameStack | |
| from stable_baselines3.common.evaluation import evaluate_policy | |
| model = PPO.load_from_huggingface(hf_model_id, hf_model_filename) | |
| eval_env = make_atari_env(hf_model_filename, n_envs=1, seed=0) | |
| eval_env = VecFrameStack(eval_env, n_stack=4) | |
| model = PPO.load_from_huggingface(hf_model_id, hf_model_filename) | |
| import gym | |
| directory = './video' | |
| env = Recorder(eval_env, directory) | |
| obs = env.reset() | |
| done = False | |
| while not done: | |
| action, _state = model.predict(obs) | |
| obs, reward, done, info = env.step(action) | |
| clip = env.play() | |
| return clip | |
| iface = gr.Interface( | |
| replay, | |
| [ | |
| gr.inputs.Dropdown(["Atari Space Invaders πΎ", "CartPole-v1 πΉοΈ", "LunarLander-v2 ππ©βπ"]), | |
| ], | |
| "video", | |
| title = 'Stable Baselines 3 with π€', | |
| description = '', | |
| article = | |
| '''<div> | |
| <p style="text-align: center">This version of the RL library allows you to load models directly from the Hugging Face Hub</p> | |
| <p style="text-align: center"> Select the trained agent you want to watch perform. We record your agent playing. | |
| <p style="text-align: center"> Don't forget to <b>click on clear between each record.</b> </p> | |
| These models are from <a href="https://github.com/araffin/rl-baselines-zoo">Stable Baseline Zoo</a></p> | |
| <p> | |
| There are currently 3 models: | |
| <ul> | |
| <li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-SpaceInvadersNoFrameskip-v4">PPO SpaceInvadersNoFrameskip-v4</a></li> | |
| <li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-LunarLander-v2">PPO LunarLander-v2</a></li> | |
| <li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-CartPole-v1">PPO CartPole-v1</a></li> | |
| </ul> | |
| </div>''' | |
| ) | |
| iface.launch() | |
| """ | |