File size: 5,064 Bytes
298814a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
979ff77
298814a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5ae145
298814a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
979ff77
298814a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# References:
# https://pytorch.org/tutorials/intermediate/mario_rl_tutorial.html
# https://github.com/yfeng997/MadMario/
# https://stackoverflow.com/questions/52726475/display-openai-gym-in-jupyter-notebook-only
# https://github.com/uvipen/Super-mario-bros-PPO-pytorch/blob/master/src/env.py
# https://stackoverflow.com/questions/753190/programmatically-generate-video-or-animated-gif-in-python


#%%bash
#pip install gym-super-mario-bros==7.4.0
#pip install tensordict==0.3.0
#pip install torchrl==0.3.0
#pip install torchvision
#pip install matplotlib
#pip install imageio

from mad_mario import *
#from IPython.display import clear_output
import imageio


class MyMario(Mario):
    """``Mario`` agent carrying this script's exploration and replay settings.

    Only the constructor is specialised: it runs the base ``Mario``
    constructor and then overrides a handful of instance attributes.
    """

    def __init__(self, state_dim, action_dim, save_dir):
        super().__init__(state_dim, action_dim, save_dir)

        # --- Replay / learning settings -----------------------------------
        # NOTE(review): presumably the number of steps to collect before
        # learning starts -- confirm against Mario.learn in mad_mario.
        self.burnin = 200
        # train() caches every terminal (done) transition this many extra
        # times so the agent sees deaths more often.
        self.learn_from_death_count = 10

        # --- Exploration schedule -----------------------------------------
        # The per-step shrink (1 - decay) of the reference factor 0.99999975
        # is made ten times larger. NOTE(review): assumes Mario multiplies
        # exploration_rate by this factor every step.
        reference_decay = 0.99999975
        self.exploration_rate_decay = 1 - (1 - reference_decay) * 10
        # Floor for exploration_rate; 0 lets it decay all the way down
        # (assuming Mario clamps against this value).
        self.exploration_rate_min = 0

        # --- Persistence ----------------------------------------------------
        # Checkpoint path (relative to the working directory) that
        # load_mario() reads and train() overwrites.
        self.default_chkpoint = "chkpoint.chkpt"

def load_mario(env, save_dir):
    """Create a ``MyMario`` agent for *env*, warm-started from disk if possible.

    The agent expects stacked observations of shape ``(4, 84, 84)``, which
    must match the FrameStack / ResizeObservation settings used in train().
    If the file named by ``default_chkpoint`` exists, its ``"model"`` weights
    and ``"exploration_rate"`` are loaded into the new agent.

    Args:
        env: Wrapped environment; only ``env.action_space.n`` is read.
        save_dir: Directory handed to the agent for its own checkpoints.

    Returns:
        The constructed ``MyMario`` instance.
    """
    agent = MyMario(
        state_dim=(4, 84, 84),
        action_dim=env.action_space.n,
        save_dir=save_dir,
    )

    # Nothing saved yet: start from the freshly initialised network.
    if not Path(agent.default_chkpoint).is_file():
        return agent

    # Load onto the CPU first so a checkpoint written on a GPU machine
    # can still be opened here; load_state_dict copies into the live params.
    saved = torch.load(agent.default_chkpoint, map_location=torch.device('cpu'))
    agent.net.load_state_dict(saved["model"])
    agent.exploration_rate = saved["exploration_rate"]
    return agent

def _version_tuple(version):
    """Return the leading numeric components of a dotted version string.

    ``"0.26.2"`` -> ``(0, 26, 2)``, ``"0.21.0rc1"`` -> ``(0, 21, 0)``.
    Parsing stops at the first component without a numeric prefix, so the
    result can be compared with plain tuples such as ``(0, 26)``.
    """
    import re

    numbers = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        numbers.append(int(match.group()))
        if match.end() != len(piece):  # e.g. "0rc1": keep 0, then stop
            break
    return tuple(numbers)


def _make_env():
    """Build the preprocessed SuperMarioBros-1-1 environment used by train()."""
    env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0", render_mode='rgb_array', apply_api_compatibility=True)
    # Limit the action-space to
    #   0. walk right
    #   1. jump right
    env = JoypadSpace(env, [["right"], ["right", "A"]])

    # One sanity-check step on the raw (unwrapped) env to show what it emits.
    env.reset()
    next_state, reward, done, trunc, info = env.step(action=0)
    print(f"{next_state.shape},\n {reward},\n {done},\n {info}")

    # Apply Wrappers to environment
    env = SkipFrame(env, skip=4)
    env = GrayScaleObservation(env)
    env = ResizeObservation(env, shape=84)
    # Compare versions numerically: a string comparison would rank "0.9.6"
    # above "0.26".
    if _version_tuple(gym.__version__) < (0, 26):
        env = FrameStack(env, num_stack=4, new_step_api=True)
    else:
        env = FrameStack(env, num_stack=4)
    return env


def _cache_keep_probability(curr_ep_length, avg_ep_length):
    """Probability of storing the current step in replay memory.

    Steps early in an episode (relative to the recent average length) are
    kept less often, so newer exploration dominates the buffer. The value is
    clamped to ``[0.2, 1.0]``, and anything past 80% of the average is always
    kept. A non-positive average (no usable history) also means "always keep".
    """
    if avg_ep_length <= 0:
        return 1.0
    progress = float(curr_ep_length) / float(avg_ep_length)
    if progress > 0.8:
        return 1.0
    return max(progress, 0.2)


def _remember(mario, logger, state, next_state, action, reward, done):
    """Cache one transition according to train()'s replay policy."""
    if len(logger.moving_avg_ep_lengths) > 0:
        keep_p = _cache_keep_probability(logger.curr_ep_length, logger.moving_avg_ep_lengths[-1])
        if np.random.rand() < keep_p:
            mario.cache(state, next_state, action, reward, done)
    else:
        mario.cache(state, next_state, action, reward, done)

    # Over-sample deaths so the agent learns from them more often.
    if done:
        for _ in range(mario.learn_from_death_count):
            mario.cache(state, next_state, action, reward, done)


def _save_checkpoints(mario):
    """Save to a fresh timestamped dir and overwrite the default checkpoint."""
    # Save to timestamped dir
    save_dir = Path("checkpoints_save") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    save_dir.mkdir(parents=True, exist_ok=True)
    mario.save_dir = save_dir
    mario.save()

    # Save to default path (read back by load_mario on the next run)
    torch.save(
        dict(model=mario.net.state_dict(), exploration_rate=mario.exploration_rate),
        mario.default_chkpoint,
    )
    print(f"MarioNet saved to {mario.default_chkpoint}")


def train(is_eval=False, episodes=1000):
    """Train (or evaluate) the Mario agent on SuperMarioBros-1-1.

    Args:
        is_eval: When True, no learning/logging/saving happens; each frame is
            shown with matplotlib and the frames of the last episode are
            written to ``movie.gif``.
        episodes: Number of episodes to play.
    """
    env = _make_env()

    use_cuda = torch.cuda.is_available()
    print(f"Using CUDA: {use_cuda}")
    print()

    save_dir = Path("checkpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    # exist_ok: two runs started within the same second share the dir
    # instead of crashing with FileExistsError.
    save_dir.mkdir(parents=True, exist_ok=True)

    mario = load_mario(env, save_dir)
    logger = MetricLogger(save_dir)

    images = []
    for e in range(episodes):
        state = env.reset()
        images = []  # frames are kept per episode; only the last one is saved

        # Play the game!
        while True:
            if is_eval:
                img = env.render()
                plt.imshow(img)
                plt.show()
                images.append(img.copy())

            # Run agent on the state and perform the action
            action = mario.act(state)
            next_state, reward, done, trunc, info = env.step(action)

            if not is_eval:
                _remember(mario, logger, state, next_state, action, reward, done)
                q, loss = mario.learn()
                logger.log_step(reward, loss, q)

            state = next_state

            # End of episode: terminated, truncated (Gym >= 0.26 step API),
            # or the level flag was reached.
            if done or trunc or info["flag_get"]:
                break

        if not is_eval:
            logger.log_episode()

            if (e % 20 == 0) or (e == episodes - 1):
                logger.record(episode=e, epsilon=mario.exploration_rate, step=mario.curr_step)

            if (e % 200 == 0) or (e == episodes - 1):
                _save_checkpoints(mario)

    # Skip writing when no frames were captured (e.g. episodes == 0).
    if is_eval and images:
        imageio.mimsave('movie.gif', images)


# For training
#train(is_eval = False, episodes = 1000)

# For evaluation
#train(is_eval = True, episodes = 1)