# MyMarioAI2 / MyMarioAI.py
# wb-droid -- "minor fix" (b5ae145)
# References:
# https://pytorch.org/tutorials/intermediate/mario_rl_tutorial.html
# https://github.com/yfeng997/MadMario/
# https://stackoverflow.com/questions/52726475/display-openai-gym-in-jupyter-notebook-only
# https://github.com/uvipen/Super-mario-bros-PPO-pytorch/blob/master/src/env.py
# https://stackoverflow.com/questions/753190/programmatically-generate-video-or-animated-gif-in-python
#%%bash
#pip install gym-super-mario-bros==7.4.0
#pip install tensordict==0.3.0
#pip install torchrl==0.3.0
#pip install torchvision
#pip install matplotlib
#pip install imageio
from mad_mario import *
#from IPython.display import clear_output
import imageio
class MyMario(Mario):
    """``Mario`` agent whose hyperparameters are adjusted after construction.

    The base ``Mario`` initialiser runs first with the same arguments; the
    attributes below are then set on the instance:

    * ``exploration_rate_decay`` -- ``1 - 10 * (1 - 0.99999975)``, i.e. the
      gap to 1.0 is ten times larger than for ``0.99999975``.
    * ``exploration_rate_min`` -- ``0``.
    * ``burnin`` -- ``200``.
    * ``learn_from_death_count`` -- ``10``; the training loop in this file
      caches a ``done`` transition this many extra times.
    * ``default_chkpoint`` -- ``"chkpoint.chkpt"``, the file the agent is
      restored from and the training loop writes back to.

    NOTE(review): how ``exploration_rate_decay``, ``exploration_rate_min`` and
    ``burnin`` are consumed is defined by the ``Mario`` base class in
    ``mad_mario`` -- confirm there before tuning.
    """

    def __init__(self, state_dim, action_dim, save_dir):
        super().__init__(state_dim, action_dim, save_dir)

        # Checkpoint file shared between runs.
        self.default_chkpoint = "chkpoint.chkpt"

        # Exploration schedule: widen the per-step gap to 1.0 tenfold.
        reference_decay = 0.99999975
        self.exploration_rate_decay = 1 - ((1 - reference_decay) * 10)
        self.exploration_rate_min = 0

        # Learning schedule.
        self.burnin = 200
        self.learn_from_death_count = 10
def load_mario(env, save_dir):
    """Create a ``MyMario`` agent and restore it from the default checkpoint.

    Args:
        env: Environment whose ``action_space.n`` gives the number of actions.
        save_dir: Directory handed to the agent for its own checkpoints.

    Returns:
        A ``MyMario`` built for ``(4, 84, 84)`` states. When the file named by
        ``default_chkpoint`` exists, its ``"model"`` entry is loaded into
        ``net`` and its ``"exploration_rate"`` entry replaces the agent's
        exploration rate; otherwise the freshly constructed agent is returned.
    """
    agent = MyMario(
        state_dim=(4, 84, 84),
        action_dim=env.action_space.n,
        save_dir=save_dir,
    )

    if not Path(agent.default_chkpoint).is_file():
        # No default checkpoint yet: start from the fresh agent.
        return agent

    # Tensors are mapped to the CPU while reading the checkpoint file.
    checkpoint = torch.load(
        agent.default_chkpoint,
        map_location=torch.device('cpu'),
    )
    agent.net.load_state_dict(checkpoint["model"])
    agent.exploration_rate = checkpoint["exploration_rate"]
    return agent
def _gym_major_minor(version):
    """Return the leading ``(major, minor)`` numbers of a version string."""
    import re

    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        # Unparseable version string: assume a current release.
        return (float("inf"), float("inf"))
    return (int(match.group(1)), int(match.group(2)))


def _cache_probability(current_length, average_length):
    """Return the probability of caching a transition at this episode step.

    The ratio ``current_length / average_length`` is used, raised to 1.0 once
    it exceeds 0.8 and floored at 0.2. A non-positive average yields 1.0.
    """
    average_length = float(average_length)
    if average_length <= 0:
        return 1.0
    ratio = float(current_length) / average_length
    if ratio > 0.8:
        ratio = 1.0
    if ratio < 0.2:
        ratio = 0.2
    return ratio


def _wrap_env(env):
    """Restrict the action space, probe one step and add observation wrappers."""
    # Limit the action-space to
    #   0. walk right
    #   1. jump right
    env = JoypadSpace(env, [["right"], ["right", "A"]])

    # Take one step and print what the unwrapped observation looks like.
    env.reset()
    next_state, reward, done, _truncated, info = env.step(action=0)
    print(f"{next_state.shape},\n {reward},\n {done},\n {info}")

    # Apply wrappers to the environment.
    env = SkipFrame(env, skip=4)
    env = GrayScaleObservation(env)
    env = ResizeObservation(env, shape=84)
    # Compare versions numerically; comparing the strings orders "0.9" after "0.26".
    if _gym_major_minor(gym.__version__) < (0, 26):
        env = FrameStack(env, num_stack=4, new_step_api=True)
    else:
        env = FrameStack(env, num_stack=4)
    return env


def _run_episode(env, mario, logger, is_eval):
    """Play one episode; return the frames rendered (empty unless ``is_eval``)."""
    frames = []
    state = env.reset()

    while True:
        if is_eval:
            frame = env.render()
            plt.imshow(frame)
            plt.show()
            frames.append(frame.copy())

        # Run agent on the state, then let it act.
        action = mario.act(state)
        next_state, reward, done, _truncated, info = env.step(action)

        if not is_eval:
            # Remember newer actions to learn more on new exploration: steps
            # early in the episode (relative to the latest moving-average
            # episode length) are cached with a lower probability.
            if len(logger.moving_avg_ep_lengths) > 0:
                keep_probability = _cache_probability(
                    logger.curr_ep_length, logger.moving_avg_ep_lengths[-1]
                )
                if np.random.rand() < keep_probability:
                    mario.cache(state, next_state, action, reward, done)
            else:
                mario.cache(state, next_state, action, reward, done)

            if done:
                # Cache the episode-ending transition several more times
                # (presumably so it is over-represented when sampling).
                for _ in range(mario.learn_from_death_count):
                    mario.cache(state, next_state, action, reward, done)

            # Learn, then log the step.
            q, loss = mario.learn()
            logger.log_step(reward, loss, q)

        state = next_state

        # Check if end of game.
        if done or info["flag_get"]:
            break

    return frames


def _save_checkpoints(mario):
    """Save a timestamped checkpoint and rewrite the default checkpoint file."""
    stamped_dir = Path("checkpoints_save") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    # exist_ok: two saves within the same second must not abort training.
    stamped_dir.mkdir(parents=True, exist_ok=True)
    mario.save_dir = stamped_dir
    mario.save()

    torch.save(
        dict(model=mario.net.state_dict(), exploration_rate=mario.exploration_rate),
        mario.default_chkpoint,
    )
    print(f"MarioNet saved to {mario.default_chkpoint}")


def train(is_eval = False, episodes = 1000):
    """Train or evaluate the agent on ``SuperMarioBros-1-1-v0``.

    The environment is limited to two actions (walk right, jump right) and
    wrapped with ``SkipFrame``, ``GrayScaleObservation``, ``ResizeObservation``
    and ``FrameStack``. The agent comes from ``load_mario``.

    Training (``is_eval=False``): transitions are cached and the agent learns
    after every step. Metrics are recorded every 20 episodes and checkpoints
    are written every 200 episodes, plus on the final episode for both.

    Evaluation (``is_eval=True``): nothing is cached or learned. Every frame
    is rendered and shown with matplotlib, and the frames of the last episode
    are written to ``movie.gif``.

    Args:
        is_eval: Run in evaluation mode instead of training mode.
        episodes: Number of episodes to play.
    """
    env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0", render_mode='rgb_array', apply_api_compatibility=True)
    try:
        env = _wrap_env(env)

        use_cuda = torch.cuda.is_available()
        print(f"Using CUDA: {use_cuda}")
        print()

        save_dir = Path("checkpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
        save_dir.mkdir(parents=True, exist_ok=True)

        mario = load_mario(env, save_dir)
        logger = MetricLogger(save_dir)

        images = []
        for e in range(episodes):
            # Only the most recent episode's frames are kept.
            images = _run_episode(env, mario, logger, is_eval)

            if not is_eval:
                logger.log_episode()
                is_last_episode = e == episodes - 1

                if e % 20 == 0 or is_last_episode:
                    logger.record(episode=e, epsilon=mario.exploration_rate, step=mario.curr_step)

                if e % 200 == 0 or is_last_episode:
                    _save_checkpoints(mario)

        # Skip the GIF when no frame was captured (e.g. episodes == 0).
        if is_eval and images:
            imageio.mimsave('movie.gif', images)
    finally:
        # Release the environment even if the run fails or is interrupted.
        env.close()
# For training
#train(is_eval = False, episodes = 1000)
# For evaluation
#train(is_eval = True, episodes = 1)