"""Record frame sequences of an agent playing Super Mario Bros.

Runs a few episodes with a loaded agent checkpoint. For every episode in
which ``info["flag_get"]`` is true, the raw frames logged by ``SkipFrame``
are upscaled and written to ``games/game_<episode>/`` together with the
controller image for the action taken on each frame.
"""
import os

import gym_super_mario_bros
from gym import Wrapper
from gym.wrappers import FrameStack, GrayScaleObservation, ResizeObservation
from gym_super_mario_bros.actions import RIGHT_ONLY
from nes_py.wrappers import JoypadSpace
from PIL import Image

from agent import Agent


# ----- Custom SkipFrame wrapper -----
class SkipFrame(Wrapper):
    """Repeat each action for ``skip`` env steps, logging every frame.

    Attributes:
        frames_log: Copy of every observation returned by the wrapped env
            since the last ``reset`` (the reset observation comes first).
        actions_log: Action associated with each entry of ``frames_log``;
            the reset observation is paired with action ``0``.
    """

    def __init__(self, env, skip: int):
        """Wrap ``env`` so that one ``step`` call spans ``skip`` env steps.

        Raises:
            ValueError: If ``skip`` is smaller than 1. ``step`` would then
                have no observation to return.
        """
        super().__init__(env)
        if skip < 1:
            raise ValueError(f"skip must be >= 1, got {skip}")
        self._skip = skip
        self.frames_log = []
        self.actions_log = []

    def step(self, action):
        """Apply ``action`` up to ``skip`` times and sum the rewards.

        Stops early as soon as the wrapped env reports ``done`` or ``trunc``.

        Returns:
            The last observation, the summed reward, and the ``done``,
            ``trunc`` and ``info`` values of the last inner step.
        """
        total_reward = 0.0
        done = False
        trunc = False
        info = {}
        for _ in range(self._skip):
            next_state, reward, done, trunc, info = self.env.step(action)

            # Log every intermediate frame, not just the one returned to the
            # caller, so the recording keeps the frames that are skipped.
            self.frames_log.append(next_state.copy())
            self.actions_log.append(action)

            total_reward += reward
            if done or trunc:
                break
        return next_state, total_reward, done, trunc, info

    def reset(self, **kwargs):
        """Reset the wrapped env and restart both logs from its first frame."""
        state, info = self.env.reset(**kwargs)
        self.frames_log = [state.copy()]
        self.actions_log = [0]
        return state, info


def apply_wrappers(env):
    """Stack the preprocessing wrappers on ``env`` and return the result.

    Order (innermost first): ``SkipFrame`` (4 steps per action), resize to
    84, grayscale, then a 4-frame stack.
    """
    env = SkipFrame(env, skip=4)
    env = ResizeObservation(env, shape=84)
    env = GrayScaleObservation(env)
    env = FrameStack(env, num_stack=4, lz4_compress=True)
    return env


# -------- CONFIG --------
ENV_NAME = "SuperMarioBros-1-1-v0"
NUM_OF_EPISODES = 5            # just a few episodes to record
GAMES_ROOT = "games"
CONTROLLERS_DIR = "controllers"  # folder with 0.png ... 4.png
SCALING_FACTOR = 10            # upscale frames for nicer video
# TODO: change this path to your trained checkpoint,
# e.g. "models/best_model.pt" or "models/mario/model_500000.pt".
# If you don't have a trained model yet, comment out the load_model call
# in main().
MODEL_PATH = "models/your_model_checkpoint.pt"
# ------------------------


def _load_controller_images(directory, count):
    """Load ``0.png`` .. ``<count-1>.png`` from ``directory`` into memory."""
    images = []
    for index in range(count):
        # Image.open is lazy and keeps the file open; copy the pixel data
        # inside the context manager so the file handle is released.
        with Image.open(os.path.join(directory, f"{index}.png")) as img:
            images.append(img.copy())
    return images


def _find_wrapper(env, wrapper_type):
    """Return the first ``wrapper_type`` instance in ``env``'s wrapper chain.

    Walks ``.env`` links from the outermost wrapper inwards.

    Raises:
        ValueError: If no wrapper of that type is found.
    """
    current = env
    while current is not None:
        if isinstance(current, wrapper_type):
            return current
        current = getattr(current, "env", None)
    raise ValueError(f"No {wrapper_type.__name__} found in the wrapper chain")


def _save_episode(game_dir, frames_log, actions_log, controllers):
    """Write upscaled frames and matching controller images to ``game_dir``."""
    os.makedirs(game_dir, exist_ok=True)
    for j, (frame_array, action_taken) in enumerate(zip(frames_log, actions_log)):
        # PIL sizes are (width, height) while array shapes are (rows, cols).
        new_dims = (
            frame_array.shape[1] * SCALING_FACTOR,
            frame_array.shape[0] * SCALING_FACTOR,
        )
        # NEAREST keeps the pixel-art look when upscaling.
        frame_img = Image.fromarray(frame_array).resize(new_dims, Image.NEAREST)
        frame_img.save(os.path.join(game_dir, f"frame_{j}.png"))

        # save controller overlay frames (optional)
        controllers[action_taken].save(
            os.path.join(game_dir, f"controller_{j}.png")
        )


def _run_episode(env, agent):
    """Play one episode; return ``(total_reward, last_info)``."""
    state, _ = env.reset()
    total_reward = 0
    done = False
    truncated = False
    info = {}
    while not (done or truncated):
        action = agent.choose_action(state)
        state, reward, done, truncated, info = env.step(action)
        total_reward += reward
    return total_reward, info


def main():
    """Run ``NUM_OF_EPISODES`` episodes and save those that reach the flag."""
    os.makedirs(GAMES_ROOT, exist_ok=True)

    # One controller image per action in RIGHT_ONLY (0.png ... 4.png).
    controllers = _load_controller_images(CONTROLLERS_DIR, len(RIGHT_ONLY))

    # ---- create env ----
    env = gym_super_mario_bros.make(
        ENV_NAME,
        render_mode="rgb_array",
        apply_api_compatibility=True,
    )
    try:
        env = JoypadSpace(env, RIGHT_ONLY)
        env = apply_wrappers(env)

        # Look the logging wrapper up by type rather than by a fixed number
        # of ``.env`` hops, so changes to apply_wrappers do not break this.
        frame_skip_env = _find_wrapper(env, SkipFrame)

        # ---- agent ----
        agent = Agent(
            input_dims=env.observation_space.shape,
            num_actions=env.action_space.n,
        )
        agent.load_model(MODEL_PATH)

        for i in range(NUM_OF_EPISODES):
            rewards, info = _run_episode(env, agent)
            print(f"Episode: {i}, Reward: {rewards}")

            # Only save a game if Mario reached the flag
            if info.get("flag_get", False):
                _save_episode(
                    os.path.join(GAMES_ROOT, f"game_{i}"),
                    frame_skip_env.frames_log,
                    frame_skip_env.actions_log,
                    controllers,
                )
    finally:
        # Always release the env, even if loading or an episode fails.
        env.close()

    print("Finished generating frame sequences in 'games/'.")


if __name__ == "__main__":
    main()