Spaces:
Sleeping
Sleeping
File size: 5,064 Bytes
298814a 979ff77 298814a b5ae145 298814a 979ff77 298814a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
# References:
# https://pytorch.org/tutorials/intermediate/mario_rl_tutorial.html
# https://github.com/yfeng997/MadMario/
# https://stackoverflow.com/questions/52726475/display-openai-gym-in-jupyter-notebook-only
# https://github.com/uvipen/Super-mario-bros-PPO-pytorch/blob/master/src/env.py
# https://stackoverflow.com/questions/753190/programmatically-generate-video-or-animated-gif-in-python
#%%bash
#pip install gym-super-mario-bros==7.4.0
#pip install tensordict==0.3.0
#pip install torchrl==0.3.0
#pip install torchvision
#pip install matplotlib
#pip install imageio
from mad_mario import *
#from IPython.display import clear_output
import imageio
class MyMario(Mario):
    """Mario agent with custom exploration and replay settings.

    Overrides a handful of hyper-parameters after the base ``Mario``
    constructor has run, and adds two attributes used by this script:
    ``default_chkpoint`` and ``learn_from_death_count``.
    """

    def __init__(self, state_dim, action_dim, save_dir):
        """Initialise the base agent, then apply this script's overrides.

        Args:
            state_dim: Forwarded unchanged to ``Mario.__init__``.
            action_dim: Forwarded unchanged to ``Mario.__init__``.
            save_dir: Forwarded unchanged to ``Mario.__init__``.
        """
        super().__init__(state_dim, action_dim, save_dir)

        # Checkpoint file (relative to the working directory) that is read
        # on start-up and rewritten on every periodic save.
        self.default_chkpoint = "chkpoint.chkpt"

        # Number of extra times a terminal transition is pushed into replay
        # memory by the training loop.
        self.learn_from_death_count = 10

        # NOTE(review): presumably the number of steps the base class waits
        # before it starts learning -- confirm against mad_mario.Mario.
        self.burnin = 200

        # Take the gap between 0.99999975 and 1.0 and make it ten times
        # larger, i.e. a decay factor ten times further from 1.0.
        # NOTE(review): presumably multiplied into exploration_rate by the
        # base class on every action -- confirm against mad_mario.Mario.
        decay_gap = 1 - 0.99999975
        self.exploration_rate_decay = 1 - (decay_gap * 10)

        # Lower bound of zero: exploration is allowed to vanish entirely.
        self.exploration_rate_min = 0
def load_mario(env, save_dir):
    """Create a ``MyMario`` agent and restore it from the default checkpoint.

    Args:
        env: Environment whose ``action_space.n`` sets the agent's action
            dimension.
        save_dir: Directory handed to the agent for its own checkpoints.

    Returns:
        The agent. If the file named by ``default_chkpoint`` exists, the
        network weights and exploration rate are loaded from it; otherwise
        the agent is returned freshly initialised.
    """
    agent = MyMario(
        state_dim=(4, 84, 84),
        action_dim=env.action_space.n,
        save_dir=save_dir,
    )

    # No default checkpoint on disk: start from scratch.
    if not Path(agent.default_chkpoint).is_file():
        return agent

    # Always deserialise onto the CPU so a checkpoint written on a GPU
    # machine can be read anywhere.
    cpu = torch.device('cpu')
    saved = torch.load(agent.default_chkpoint, map_location=cpu)
    agent.net.load_state_dict(saved["model"])
    agent.exploration_rate = saved["exploration_rate"]
    return agent
def _build_env():
    """Create the Super Mario Bros 1-1 environment with its wrappers applied.

    Returns:
        The wrapped environment. The caller is responsible for closing it.
    """
    env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0", render_mode='rgb_array', apply_api_compatibility=True)

    # Limit the action-space to
    #   0. walk right
    #   1. jump right
    env = JoypadSpace(env, [["right"], ["right", "A"]])

    # Take one step on the raw environment and print what comes back, as a
    # quick sanity check before the observation wrappers change the shape.
    env.reset()
    next_state, reward, done, trunc, info = env.step(action=0)
    print(f"{next_state.shape},\n {reward},\n {done},\n {info}")

    # Apply Wrappers to environment
    env = SkipFrame(env, skip=4)
    env = GrayScaleObservation(env)
    env = ResizeObservation(env, shape=84)
    # NOTE(review): this is a lexicographic string comparison, not a numeric
    # version comparison (e.g. '0.9' < '0.26' is False). Left as-is because
    # apply_api_compatibility above already implies a recent gym -- confirm.
    if gym.__version__ < '0.26':
        env = FrameStack(env, num_stack=4, new_step_api=True)
    else:
        env = FrameStack(env, num_stack=4)
    return env


def _remember_and_learn(mario, logger, state, next_state, action, reward, done):
    """Store one transition in replay memory, then run and log a learn step."""
    # Remember newer actions to learn more on new exploration: the further
    # the current episode has progressed relative to the latest moving
    # average episode length, the more likely the transition is kept.
    if len(logger.moving_avg_ep_lengths) > 0:
        keep_prob = float(logger.curr_ep_length) / float(logger.moving_avg_ep_lengths[-1])
        # Past 80% of the average length every transition is kept; early
        # transitions are never kept with less than 20% probability.
        if keep_prob > 0.8:
            keep_prob = 1.0
        if keep_prob < 0.2:
            keep_prob = 0.2
        if np.random.rand() < keep_prob:
            mario.cache(state, next_state, action, reward, done)
    else:
        # No moving average recorded yet: keep everything.
        mario.cache(state, next_state, action, reward, done)

    # Terminal transitions are rare, so push extra copies into memory to
    # give them more weight when batches are sampled.
    if done:
        for _ in range(mario.learn_from_death_count):
            mario.cache(state, next_state, action, reward, done)

    # Learn
    q, loss = mario.learn()

    # Logging
    logger.log_step(reward, loss, q)


def _save_checkpoints(mario):
    """Write a timestamped snapshot and refresh the default checkpoint file."""
    # Save to timestamped dir. This also redirects the agent's save_dir, so
    # any later mario.save() call writes into this directory too.
    snapshot_dir = Path("checkpoints_save") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    # exist_ok: two saves within the same second must not abort training.
    snapshot_dir.mkdir(parents=True, exist_ok=True)
    mario.save_dir = snapshot_dir
    mario.save()

    # Save to default dir
    torch.save(
        dict(model=mario.net.state_dict(), exploration_rate=mario.exploration_rate),
        mario.default_chkpoint,
    )
    print(f"MarioNet saved to {mario.default_chkpoint}")


def train(is_eval=False, episodes=1000):
    """Train the Mario agent, or replay it and record a GIF.

    Args:
        is_eval: When False, transitions are cached, the agent learns after
            every step, metrics are logged and checkpoints are written
            periodically. When True, none of that happens; each frame is
            rendered, shown with matplotlib and collected, and the frames of
            the last episode are written to ``movie.gif``.
        episodes: Number of episodes to play.

    The environment is always closed on exit, including when an exception
    is raised part-way through.
    """
    env = _build_env()
    try:
        use_cuda = torch.cuda.is_available()
        print(f"Using CUDA: {use_cuda}")
        print()

        save_dir = Path("checkpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
        # exist_ok: a second run started within the same second must not fail.
        save_dir.mkdir(parents=True, exist_ok=True)

        mario = load_mario(env, save_dir)
        logger = MetricLogger(save_dir)

        images = []
        for e in range(episodes):
            state = env.reset()
            # Only the frames of the current episode are retained.
            images = []

            # Play the game!
            while True:
                if is_eval:
                    img = env.render()
                    plt.imshow(img)
                    plt.show()
                    images.append(img.copy())

                # Run agent on the state
                action = mario.act(state)

                # Agent performs action
                next_state, reward, done, trunc, info = env.step(action)

                if not is_eval:
                    _remember_and_learn(mario, logger, state, next_state, action, reward, done)

                # Update state
                state = next_state

                # Check if end of game
                if done or info["flag_get"]:
                    break

            if not is_eval:
                logger.log_episode()
                is_last_episode = e == episodes - 1
                if e % 20 == 0 or is_last_episode:
                    logger.record(episode=e, epsilon=mario.exploration_rate, step=mario.curr_step)
                if e % 200 == 0 or is_last_episode:
                    _save_checkpoints(mario)

        if is_eval:
            imageio.mimsave('movie.gif', images)
    finally:
        # Release the emulator even if training was interrupted.
        env.close()
# For training:
#train(is_eval = False, episodes = 1000)
# For evaluation (plays one episode and writes movie.gif):
#train(is_eval = True, episodes = 1)