# MyMarioAI2 / app.py
# wb-droid's picture
# minor change
# 2ee197d
import gradio as gr
import os
from MyMarioAI import *
def play():
    """Stream the trained Mario agent playing World 1-1 to the Gradio UI.

    This is a generator used as a Gradio event handler. Every yielded value
    is a ``(frame, gif_path)`` pair: ``frame`` is the current RGB render of
    the environment and ``gif_path`` is ``None`` while an episode is running.

    Up to five episodes are attempted. As soon as one of them reaches the
    flag, its frames are written to ``movie_new.gif`` and one final
    ``(last_frame, 'movie_new.gif')`` pair is yielded before the generator
    finishes. If no episode reaches the flag, the generator simply ends after
    the last episode.

    The environment is closed when the generator finishes or is discarded.

    Yields:
        tuple: ``(frame, None)`` for every rendered step, then
        ``(frame, 'movie_new.gif')`` once after a successful run.
    """
    episodes = 5
    gif_path = 'movie_new.gif'

    env = gym_super_mario_bros.make(
        "SuperMarioBros-1-1-v0",
        render_mode='rgb_array',
        apply_api_compatibility=True,
    )
    try:
        # Limit the action-space to
        #   0. walk right
        #   1. jump right
        env = JoypadSpace(env, [["right"], ["right", "A"]])

        # Smoke-test a single step and print what the raw environment returns.
        env.reset()
        next_state, reward, done, trunc, info = env.step(action=0)
        print(f"{next_state.shape},\n {reward},\n {done},\n {info}")

        # Apply Wrappers to environment
        env = SkipFrame(env, skip=4)
        env = GrayScaleObservation(env)
        env = ResizeObservation(env, shape=84)
        # Compare the version numerically: a plain string comparison orders
        # e.g. '0.9' after '0.26' and would pick the wrong FrameStack call.
        gym_major_minor = tuple(
            int(part) for part in gym.__version__.split(".")[:2]
        )
        if gym_major_minor < (0, 26):
            env = FrameStack(env, num_stack=4, new_step_api=True)
        else:
            env = FrameStack(env, num_stack=4)

        use_cuda = torch.cuda.is_available()
        print(f"Using CUDA: {use_cuda}")
        print()

        save_dir = Path("checkpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
        # exist_ok: two runs started within the same second share a name.
        save_dir.mkdir(parents=True, exist_ok=True)

        mario = load_mario(env, save_dir)
        # NOTE(review): the logger is never used below; the call is kept in
        # case its constructor has side effects that are not visible here.
        MetricLogger(save_dir)

        for _ in range(episodes):
            # NOTE(review): depending on the gym API in use, reset() may
            # return (obs, info) rather than obs; presumably mario.act copes
            # with both - confirm against MyMarioAI.
            state = env.reset()
            images = []

            # Play the game!
            while True:
                # Copy the render before storing it; the same array is kept
                # for the GIF and streamed to the UI. Frames are no longer
                # drawn with matplotlib: doing so on every step piles up
                # artists on the current figure and can block on interactive
                # backends, while the UI already receives the frame.
                frame = env.render().copy()
                images.append(frame)
                yield frame, None

                # Run agent on the state
                with torch.no_grad():
                    action = mario.act(state)

                # Agent performs action; the new observation becomes the state
                state, _reward, done, _trunc, info = env.step(action)

                # Check if end of game
                if done or info["flag_get"]:
                    break

            if info["flag_get"]:
                imageio.mimsave(gif_path, images)
                # NOTE(review): pause kept from the original code; its
                # purpose is not evident here - confirm whether it is needed.
                time.sleep(5)
                # A generator's return value never reaches Gradio, so the
                # finished GIF has to be yielded to update the playback image.
                yield images[-1], gif_path
                return
    finally:
        # Release the emulator even if the generator is abandoned mid-run.
        env.close()
def refresh_playback():
    """Return the path of the GIF to show in the playback panel.

    ``movie_new.gif`` is written by ``play`` after a run that reaches the
    flag. Until such a run has happened the file may not exist, so fall back
    to ``movie.gif``, the recording the playback panel is initialised with.

    Returns:
        str: ``'movie_new.gif'`` if that file exists, otherwise ``'movie.gif'``.
    """
    latest_recording = 'movie_new.gif'
    if os.path.isfile(latest_recording):
        return latest_recording
    return 'movie.gif'
# Build the two-panel UI: live play on the left, GIF playback on the right.
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Mario AI</h1>""")
    gr.HTML("""<h1 align="center">(May take a few re-plays to pass full scenario. )</h1>""")
    session_data = gr.State([])

    with gr.Row():
        # Left column: start a new run and watch its frames as they stream in.
        with gr.Column(scale=1):
            play_mario = gr.Button("Let AI Play")
            mario_image = gr.Image(height=400, width=400, label="New play.")

        # Right column: replay a recorded run.
        with gr.Column(scale=1):
            refresh = gr.Button("Refresh")
            mario_gif = gr.Image(
                height=400,
                width=400,
                value='movie.gif',
                label="Playback previous AI run.",
            )

    # "Refresh" swaps the playback image for whatever refresh_playback returns.
    refresh.click(fn=refresh_playback, inputs=[], outputs=[mario_gif])

    # "Let AI Play" streams play()'s output into both image components.
    play_mario.click(fn=play, inputs=[], outputs=[mario_image, mario_gif])

# Enable the request queue, then start the server.
queued_demo = demo.queue()
queued_demo.launch(share=False, inbrowser=True)