Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| from MyMarioAI import * | |
def play():
    """Let the trained agent play World 1-1 and stream its frames to the UI.

    Generator wired to the "Let AI Play" button; every item it yields is a
    ``(frame, gif)`` pair for the ``[mario_image, mario_gif]`` outputs.
    While an episode is running it yields the current RGB frame and ``None``.

    Up to ``episodes`` episodes are played. As soon as one reaches the flag,
    that episode's frames are written to ``movie_new.gif`` and a final
    ``(frame, 'movie_new.gif')`` pair is yielded so the playback panel shows
    the new recording. If no episode reaches the flag the generator simply
    ends and the previous GIF file is left as it was.
    """
    is_eval = True
    episodes = 5
    gif_path = 'movie_new.gif'

    env = gym_super_mario_bros.make(
        "SuperMarioBros-1-1-v0", render_mode='rgb_array', apply_api_compatibility=True
    )
    # Limit the action space to
    #   0. walk right
    #   1. jump right
    env = JoypadSpace(env, [["right"], ["right", "A"]])

    # Apply the observation wrappers to the environment.
    env = SkipFrame(env, skip=4)
    env = GrayScaleObservation(env)
    env = ResizeObservation(env, shape=84)
    # NOTE(review): plain string comparison, not a true version compare;
    # adequate for 0.2x vs 0.26 but would misorder e.g. '0.100'.
    if gym.__version__ < '0.26':
        env = FrameStack(env, num_stack=4, new_step_api=True)
    else:
        env = FrameStack(env, num_stack=4)

    try:
        use_cuda = torch.cuda.is_available()
        print(f"Using CUDA: {use_cuda}")
        print()

        save_dir = Path("checkpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
        save_dir.mkdir(parents=True)
        mario = load_mario(env, save_dir)

        for _ in range(episodes):
            # Assumes mario.act() accepts whatever env.reset() returns here
            # (possibly an (obs, info) tuple under the 0.26 API) — confirm in MyMarioAI.
            state = env.reset()
            frames = []
            while True:
                if is_eval:
                    # Snapshot the frame in case render() hands back a buffer
                    # that the emulator reuses on the next step.
                    frame = env.render().copy()
                    frames.append(frame)
                    yield frame, None

                # Run the agent on the current state (inference only).
                with torch.no_grad():
                    action = mario.act(state)
                state, reward, done, trunc, info = env.step(action)

                # Episode ends on termination, truncation, or reaching the flag.
                if done or trunc or info["flag_get"]:
                    break

            if info["flag_get"]:
                imageio.mimsave(gif_path, frames)
                # Gradio ignores a generator's return value, so the final
                # outputs must be yielded for the playback panel to update.
                yield frame, gif_path
                return
    finally:
        # Runs on normal completion and when Gradio closes the generator early.
        env.close()
def refresh_playback():
    """Give the "Refresh" button the path of the most recently saved run GIF.

    ``play`` writes its recording to this same file name when the agent
    reaches the flag, so re-emitting the path makes the playback image reload.
    """
    latest_recording = 'movie_new.gif'
    return latest_recording
# Two-column UI: a live view of a new AI run on the left, and playback of a
# recorded GIF on the right.
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Mario AI</h1>""")
    gr.HTML("""<h1 align="center">(May take a few re-plays to pass full scenario. )</h1>""")
    with gr.Row():
        with gr.Column(scale=1):
            play_mario = gr.Button("Let AI Play")
            mario_image = gr.Image(height=400, width=400, label="New play.")
        with gr.Column(scale=1):
            refresh = gr.Button("Refresh")
            mario_gif = gr.Image(
                height=400, width=400, value='movie.gif', label="Playback previous AI run."
            )

    # Event listeners must be registered inside the Blocks context.
    refresh.click(refresh_playback, [], [mario_gif])
    # `play` is a generator: it streams frames into mario_image and finally
    # points mario_gif at the newly recorded GIF.
    play_mario.click(play, [], [mario_image, mario_gif])

# Launch only when run as a script, so importing this module (tests, reload
# tooling) does not start a server. queue() is kept because `play` is a generator.
if __name__ == "__main__":
    demo.queue().launch(share=False, inbrowser=True)