Spaces:
Sleeping
Sleeping
| # References: | |
| # https://pytorch.org/tutorials/intermediate/mario_rl_tutorial.html | |
| # https://github.com/yfeng997/MadMario/ | |
| # https://stackoverflow.com/questions/52726475/display-openai-gym-in-jupyter-notebook-only | |
| # https://github.com/uvipen/Super-mario-bros-PPO-pytorch/blob/master/src/env.py | |
| # https://stackoverflow.com/questions/753190/programmatically-generate-video-or-animated-gif-in-python | |
| #%%bash | |
| #pip install gym-super-mario-bros==7.4.0 | |
| #pip install tensordict==0.3.0 | |
| #pip install torchrl==0.3.0 | |
| #pip install torchvision | |
| #pip install matplotlib | |
| #pip install imageio | |
| from mad_mario import * | |
| #from IPython.display import clear_output | |
| import imageio | |
class MyMario(Mario):
    """``Mario`` agent with customised exploration, burn-in and checkpoint settings.

    The constructor runs the base-class initialisation unchanged and then
    overrides a handful of attributes on the resulting instance.
    """

    def __init__(self, state_dim, action_dim, save_dir):
        """Initialise the base agent, then apply this project's overrides.

        Args:
            state_dim: Forwarded to ``Mario.__init__``.
            action_dim: Forwarded to ``Mario.__init__``.
            save_dir: Forwarded to ``Mario.__init__``.
        """
        super().__init__(state_dim, action_dim, save_dir)

        # Decay factor whose distance from 1 is ten times that of 0.99999975.
        base_decay = 0.99999975
        self.exploration_rate_decay = 1 - ((1 - base_decay) * 10)
        self.exploration_rate_min = 0

        # File name of the "latest" checkpoint, relative to the working directory.
        self.default_chkpoint = "chkpoint.chkpt"

        # NOTE(review): presumably the number of steps the base class waits
        # before it starts learning -- confirm against mad_mario.Mario.
        self.burnin = 200

        # How many extra times train() caches the terminal transition of an episode.
        self.learn_from_death_count = 10
def load_mario(env, save_dir):
    """Create a ``MyMario`` agent and restore it from the default checkpoint if present.

    Args:
        env: Environment whose ``action_space.n`` gives the number of actions.
        save_dir: Directory handed to the agent for its own checkpoints.

    Returns:
        The agent. When the file named by ``default_chkpoint`` exists, the
        network weights and exploration rate stored in it have been loaded;
        otherwise the agent is returned freshly initialised.
    """
    agent = MyMario(
        state_dim=(4, 84, 84),
        action_dim=env.action_space.n,
        save_dir=save_dir,
    )

    # Nothing to restore on a first run.
    if not Path(agent.default_chkpoint).is_file():
        return agent

    # Map tensors to the CPU so the checkpoint loads regardless of where it was saved.
    checkpoint = torch.load(agent.default_chkpoint, map_location=torch.device('cpu'))
    agent.net.load_state_dict(checkpoint["model"])
    agent.exploration_rate = checkpoint["exploration_rate"]
    return agent
def _version_key(version):
    """Return the leading numeric components of ``version`` as a tuple of ints.

    Parsing stops at the first dot-separated component that does not start
    with an ASCII digit, so ``"0.26.2"`` gives ``(0, 26, 2)`` and
    ``"0.26.0rc1"`` gives ``(0, 26, 0)``. Comparing these tuples orders
    versions numerically, unlike comparing the raw strings.
    """
    numbers = []
    for piece in str(version).split("."):
        digits = ""
        for char in piece:
            if char not in "0123456789":
                break
            digits += char
        if not digits:
            break
        numbers.append(int(digits))
    return tuple(numbers)


def _cache_probability(curr_ep_length, avg_ep_length):
    """Return the probability with which the current transition is cached.

    The probability is the ratio of the current episode length to the moving
    average episode length, raised to 1.0 above 0.8 and floored at 0.2, so
    transitions late in an episode are always kept and early ones are thinned
    out. A non-positive average yields 1.0 instead of dividing by zero.
    """
    if avg_ep_length <= 0:
        return 1.0
    ratio = float(curr_ep_length) / float(avg_ep_length)
    if ratio > 0.8:
        return 1.0
    if ratio < 0.2:
        return 0.2
    return ratio


def train(is_eval=False, episodes=1000):
    """Train or evaluate the Mario agent on ``SuperMarioBros-1-1-v0``.

    Args:
        is_eval: When true, every frame is rendered and shown with matplotlib,
            nothing is cached, learned, logged or checkpointed, and the frames
            of the last episode are written to ``movie.gif`` afterwards.
        episodes: Number of episodes to play.

    Side effects:
        Creates ``checkpoints/<timestamp>``. In training mode it also creates
        ``checkpoints_save/<timestamp>`` on every 200th and on the final
        episode, and overwrites the agent's default checkpoint file there.
        The environment is closed when the function exits, even on error.
    """
    env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0", render_mode='rgb_array', apply_api_compatibility=True)
    try:
        # Limit the action-space to
        #   0. walk right
        #   1. jump right
        env = JoypadSpace(env, [["right"], ["right", "A"]])

        # One throw-away step to print what the raw environment returns.
        env.reset()
        next_state, reward, done, trunc, info = env.step(action=0)
        print(f"{next_state.shape},\n {reward},\n {done},\n {info}")

        # Apply Wrappers to environment
        env = SkipFrame(env, skip=4)
        env = GrayScaleObservation(env)
        env = ResizeObservation(env, shape=84)
        # Compare versions numerically: as strings, '0.9.0' sorts after '0.26'.
        # ``new_step_api`` is passed only for gym releases before 0.26.
        if _version_key(gym.__version__) < (0, 26):
            env = FrameStack(env, num_stack=4, new_step_api=True)
        else:
            env = FrameStack(env, num_stack=4)

        use_cuda = torch.cuda.is_available()
        print(f"Using CUDA: {use_cuda}")
        print()

        # exist_ok: the timestamp has one-second resolution, so two runs started
        # within the same second would otherwise collide.
        save_dir = Path("checkpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
        save_dir.mkdir(parents=True, exist_ok=True)

        mario = load_mario(env, save_dir)
        logger = MetricLogger(save_dir)

        images = []
        for e in range(episodes):
            state = env.reset()
            # Frames are kept per episode, so the GIF written at the end only
            # contains the last episode.
            images = []

            # Play the game!
            while True:
                if is_eval:
                    img = env.render()
                    plt.imshow(img)
                    plt.show()
                    images.append(img.copy())

                # Run agent on the state
                action = mario.act(state)

                # Agent performs action
                next_state, reward, done, trunc, info = env.step(action)

                if not is_eval:
                    # Remember newer actions to learn more on new exploration:
                    # once a moving average exists, early transitions are
                    # cached only with a reduced probability.
                    if len(logger.moving_avg_ep_lengths) > 0:
                        keep_probability = _cache_probability(
                            logger.curr_ep_length, logger.moving_avg_ep_lengths[-1]
                        )
                        if np.random.rand() < keep_probability:
                            mario.cache(state, next_state, action, reward, done)
                    else:
                        mario.cache(state, next_state, action, reward, done)

                    # Cache the terminal transition several extra times so it is
                    # sampled more often during learning.
                    if done:
                        for _ in range(mario.learn_from_death_count):
                            mario.cache(state, next_state, action, reward, done)

                    # Learn
                    q, loss = mario.learn()

                    # Logging
                    logger.log_step(reward, loss, q)

                # Update state
                state = next_state

                # Check if end of game
                if done or info["flag_get"]:
                    break

            if not is_eval:
                logger.log_episode()
                is_last_episode = e == episodes - 1

                if (e % 20 == 0) or is_last_episode:
                    logger.record(episode=e, epsilon=mario.exploration_rate, step=mario.curr_step)

                if (e % 200 == 0) or is_last_episode:
                    # Save to timestamped dir
                    snapshot_dir = Path("checkpoints_save") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
                    snapshot_dir.mkdir(parents=True, exist_ok=True)
                    mario.save_dir = snapshot_dir
                    mario.save()

                    # Save to default dir
                    torch.save(
                        dict(model=mario.net.state_dict(), exploration_rate=mario.exploration_rate),
                        mario.default_chkpoint,
                    )
                    print(f"MarioNet saved to {mario.default_chkpoint}")

        # Skip the GIF when no frame was recorded (e.g. episodes == 0).
        if is_eval and images:
            imageio.mimsave('movie.gif', images)
    finally:
        # Release the emulator even if training is interrupted or fails.
        env.close()
| # For training | |
| #train(is_eval = False, episodes = 1000) | |
| # For evaluation | |
| #train(is_eval = True, episodes = 1) |