File size: 2,987 Bytes
298814a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
649ce69
298814a
 
 
fa29ef9
 
 
298814a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa29ef9
 
 
 
 
 
2ee197d
fa29ef9
 
298814a
 
 
 
 
 
fa29ef9
 
 
 
298814a
fa29ef9
 
298814a
fa29ef9
 
 
 
 
298814a
 
 
fa29ef9
298814a
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import gradio as gr
import os
from MyMarioAI import *


def _version_tuple(version):
    """Return the leading all-digit components of a dotted version string.

    Parsing stops at the first component that is not a plain integer, so
    ``"0.26.2"`` -> ``(0, 26, 2)`` and ``"0.26.0rc1"`` -> ``(0, 26)``. The
    result compares correctly against a tuple such as ``(0, 26)``, unlike a
    plain string comparison (as strings, ``"0.100" < "0.26"`` is True).
    """
    parts = []
    for piece in version.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def _make_env():
    """Build the wrapped SuperMarioBros-1-1 environment used by ``play``."""
    env = gym_super_mario_bros.make(
        "SuperMarioBros-1-1-v0", render_mode="rgb_array", apply_api_compatibility=True
    )
    # Limit the action-space to
    #   0. walk right
    #   1. jump right
    env = JoypadSpace(env, [["right"], ["right", "A"]])

    # Take one step on the unwrapped env and print what it returns
    # (observation shape, reward, done flag, info dict) as a sanity check.
    env.reset()
    next_state, reward, done, trunc, info = env.step(action=0)
    print(f"{next_state.shape},\n {reward},\n {done},\n {info}")

    # Apply wrappers to environment
    env = SkipFrame(env, skip=4)
    env = GrayScaleObservation(env)
    env = ResizeObservation(env, shape=84)
    # Compare versions numerically; a string comparison mis-orders e.g. "0.100".
    if _version_tuple(gym.__version__) < (0, 26):
        env = FrameStack(env, num_stack=4, new_step_api=True)
    else:
        env = FrameStack(env, num_stack=4)
    return env


def play(is_eval=True, episodes=5):
    """Let the loaded agent play and stream frames to the Gradio UI.

    This is a generator. While ``is_eval`` is true, every rendered frame is
    yielded as ``(frame, None)``, so only the first output (the live image)
    changes. When an episode reaches the flag, that episode's frames are
    written to ``movie_new.gif``. A final ``(last_frame, "movie_new.gif")``
    pair is then yielded, so the second output shows the new recording, and
    the generator stops. If no episode reaches the flag within ``episodes``
    attempts, no GIF is written.

    Args:
        is_eval: Render and stream frames while playing (also required for
            the GIF, which is built from the rendered frames).
        episodes: Maximum number of episodes to attempt.
    """
    env = _make_env()
    try:
        use_cuda = torch.cuda.is_available()
        print(f"Using CUDA: {use_cuda}")
        print()

        save_dir = Path("checkpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
        save_dir.mkdir(parents=True)

        mario = load_mario(env, save_dir)

        # NOTE(review): `logger` is never used below; the call is kept in case
        # MetricLogger's constructor has side effects (e.g. log files) that
        # something relies on — confirm and drop if not.
        logger = MetricLogger(save_dir)

        for _ in range(episodes):
            state = env.reset()
            frames = []

            # Play the game!
            while True:
                if is_eval:
                    frame = env.render().copy()
                    frames.append(frame)
                    yield frame, None

                # Run agent on the state
                with torch.no_grad():
                    action = mario.act(state)

                # Agent performs action; the new observation becomes the state.
                state, _reward, done, trunc, info = env.step(action)

                # Stop on episode end, truncation, or reaching the flag;
                # stepping a finished episode further is not meaningful.
                if done or trunc or info["flag_get"]:
                    break

            if info["flag_get"]:
                # A generator's `return value` only travels on StopIteration and
                # is not delivered by iteration, so the final frame/GIF pair
                # must be yielded to reach the outputs.
                if frames:
                    imageio.mimsave('movie_new.gif', frames)
                    yield frames[-1], 'movie_new.gif'
                return
    finally:
        # Runs on normal completion, on error, and when the consumer closes
        # the generator early.
        env.close()

def refresh_playback():
    """Return the path of the GIF written by the most recent successful run.

    Used by the "Refresh" button to reload the playback image.
    """
    playback_path = 'movie_new.gif'
    return playback_path


with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Mario AI</h1>""")
    gr.HTML("""<h1 align="center">(May take a few re-plays to pass full scenario. )</h1>""")
    # NOTE(review): not wired to any event below; kept so the layout is unchanged.
    session_data = gr.State([])

    with gr.Row():
        # Left column: start a new run and watch its frames stream in.
        with gr.Column(scale=1):
            play_mario = gr.Button("Let AI Play")
            mario_image = gr.Image(height=400, width=400, label="New play.")
        # Right column: GIF playback, initially showing 'movie.gif'.
        with gr.Column(scale=1):
            refresh = gr.Button("Refresh")
            mario_gif = gr.Image(height=400, width=400, value='movie.gif', label="Playback previous AI run.")

    # Reload the playback panel from the GIF written by the last winning run.
    refresh.click(
        refresh_playback,
        [],
        [mario_gif],
    )
    # `play` is a generator: each yield updates (live frame, playback GIF).
    play_mario.click(
        play,
        [],
        [mario_image, mario_gif],
    )

# Launch only when run as a script, so importing this module (e.g. for tests
# or tooling that looks up `demo`) does not start a server.
if __name__ == "__main__":
    demo.queue().launch(share=False, inbrowser=True)