shoyebb26 commited on
Commit
bfb0d21
·
verified ·
1 Parent(s): 4508cff

Update generate_clips.py

Browse files
Files changed (1) hide show
  1. generate_clips.py +87 -53
generate_clips.py CHANGED
@@ -1,41 +1,48 @@
 
 
 
1
  import gym_super_mario_bros
2
  from gym_super_mario_bros.actions import RIGHT_ONLY
3
  from nes_py.wrappers import JoypadSpace
4
- from agent import Agent
5
 
6
  from gym import Wrapper
7
  from gym.wrappers import GrayScaleObservation, ResizeObservation, FrameStack
8
 
9
- import os
10
- from PIL import Image
11
 
12
- # Modified SkipFrame wrapper to log frames and actions
13
  class SkipFrame(Wrapper):
14
- def __init__(self, env, skip):
15
  super().__init__(env)
16
  self._skip = skip
17
- self.counter = 0
18
  self.frames_log = []
19
  self.actions_log = []
20
-
21
  def step(self, action):
22
  total_reward = 0.0
23
  done = False
 
 
 
24
  for _ in range(self._skip):
25
  next_state, reward, done, trunc, info = self.env.step(action)
 
26
  self.frames_log.append(next_state.copy())
27
  self.actions_log.append(action)
28
  total_reward += reward
29
- if done:
30
  break
 
31
  return next_state, total_reward, done, trunc, info
32
-
33
  def reset(self, **kwargs):
34
  state, info = self.env.reset(**kwargs)
35
  self.frames_log = [state.copy()]
36
  self.actions_log = [0]
37
  return state, info
38
 
 
39
  def apply_wrappers(env):
40
  env = SkipFrame(env, skip=4)
41
  env = ResizeObservation(env, shape=84)
@@ -44,58 +51,85 @@ def apply_wrappers(env):
44
  return env
45
 
46
 
47
- ENV_NAME = 'SuperMarioBros-1-1-v0'
48
- NUM_OF_EPISODES = 1_000
49
- controllers = [Image.open(f"controllers/{i}.png") for i in range(5)]
50
-
51
- env = gym_super_mario_bros.make(ENV_NAME, render_mode='rgb_array', apply_api_compatibility=True)
52
- env = JoypadSpace(env, RIGHT_ONLY)
53
- env = apply_wrappers(env)
54
 
55
- agent = Agent(input_dims=env.observation_space.shape, num_actions=env.action_space.n)
56
 
57
- # agent.load_model("models/folder_name/ckpt_name")
 
58
 
59
- for i in range(NUM_OF_EPISODES):
60
- done = False
61
- state, _ = env.reset()
62
- rewards = 0
63
- while not done:
64
- action = agent.choose_action(state)
65
- frame = env.render()
66
- new_state, reward, done, truncated, info = env.step(action)
67
- rewards += reward
68
 
69
- # agent.store_in_memory(state, action, reward, new_state, done)
70
- # agent.learn()
 
 
 
 
 
 
71
 
72
- state = new_state
 
 
73
 
74
- if done:
75
- print(f"Episode: {i}, Reward: {rewards}")
76
- if info["flag_get"]:
77
- os.makedirs(os.path.join("games" f"game_{i}"), exist_ok=True)
78
- frame_skip_env = env.env.env.env # Unwrapping the environment to get the SkipFrame wrapper
79
- frames_log = frame_skip_env.frames_log
80
- actions_log = frame_skip_env.actions_log
81
- for j, (frame, action) in enumerate(zip(frames_log, actions_log)):
82
- # upscale frame
83
- scaling_factor = 10
84
- new_dims = (frame.shape[1] * scaling_factor, frame.shape[0] * scaling_factor)
85
- frame = Image.fromarray(frame).resize(new_dims, Image.NEAREST)
86
 
87
- frame.save(os.path.join("games" f"game_{i}", f"frame_{j}.png"))
88
- controllers[action].save(os.path.join("games" f"game_{i}", f"controller_{j}.png"))
89
-
90
- # if i % 5000 == 0 and i > 0:
91
- # agent.save_model(os.path.join("models", f"model_{i}_iter.pt"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
- env.close()
94
 
95
- # Bash commands to convert frames to video (Assuming you're in the games folder)
 
96
 
97
- # Video game frames
98
- # ffmpeg -framerate 60 -i frame_%d.png game.mp4
99
 
100
- # Controller frames (For transparent background)
101
- # ffmpeg -framerate 60 -i controller_%d.png -c:v prores_ks -profile:v 4444 -pix_fmt yuva444p10le -alpha_bits 16 controller.mov
 
1
+ import os
2
+ from PIL import Image
3
+
4
  import gym_super_mario_bros
5
  from gym_super_mario_bros.actions import RIGHT_ONLY
6
  from nes_py.wrappers import JoypadSpace
 
7
 
8
  from gym import Wrapper
9
  from gym.wrappers import GrayScaleObservation, ResizeObservation, FrameStack
10
 
11
+ from agent import Agent
12
+
13
 
14
+ # ----- Custom SkipFrame wrapper -----
15
  class SkipFrame(Wrapper):
16
+ def __init__(self, env, skip: int):
17
  super().__init__(env)
18
  self._skip = skip
 
19
  self.frames_log = []
20
  self.actions_log = []
21
+
22
  def step(self, action):
23
  total_reward = 0.0
24
  done = False
25
+ trunc = False
26
+ info = {}
27
+
28
  for _ in range(self._skip):
29
  next_state, reward, done, trunc, info = self.env.step(action)
30
+ # log frames + actions
31
  self.frames_log.append(next_state.copy())
32
  self.actions_log.append(action)
33
  total_reward += reward
34
+ if done or trunc:
35
  break
36
+
37
  return next_state, total_reward, done, trunc, info
38
+
39
  def reset(self, **kwargs):
40
  state, info = self.env.reset(**kwargs)
41
  self.frames_log = [state.copy()]
42
  self.actions_log = [0]
43
  return state, info
44
 
45
+
46
  def apply_wrappers(env):
47
  env = SkipFrame(env, skip=4)
48
  env = ResizeObservation(env, shape=84)
 
51
  return env
52
 
53
 
54
+ # -------- CONFIG --------
55
+ ENV_NAME = "SuperMarioBros-1-1-v0"
56
+ NUM_OF_EPISODES = 5 # just a few episodes to record
57
+ GAMES_ROOT = "games"
58
+ CONTROLLERS_DIR = "controllers" # folder with 0.png ... 4.png
59
+ SCALING_FACTOR = 10 # upscale frames for nicer video
60
+ # ------------------------
61
 
 
62
 
63
+ def main():
64
+ os.makedirs(GAMES_ROOT, exist_ok=True)
65
 
66
+ # load controller images (optional, only for separate controller frames)
67
+ controllers = [Image.open(os.path.join(CONTROLLERS_DIR, f"{i}.png")) for i in range(5)]
 
 
 
 
 
 
 
68
 
69
+ # ---- create env ----
70
+ env = gym_super_mario_bros.make(
71
+ ENV_NAME,
72
+ render_mode="rgb_array",
73
+ apply_api_compatibility=True,
74
+ )
75
+ env = JoypadSpace(env, RIGHT_ONLY)
76
+ env = apply_wrappers(env)
77
 
78
+ # ---- agent ----
79
+ agent = Agent(input_dims=env.observation_space.shape,
80
+ num_actions=env.action_space.n)
81
 
82
+ # TODO: change this path to your trained checkpoint
83
+ # e.g. "models/best_model.pt" or "models/mario/model_500000.pt"
84
+ # If you don't have a trained model yet, comment this line.
85
+ agent.load_model("models/your_model_checkpoint.pt")
 
 
 
 
 
 
 
 
86
 
87
+ for i in range(NUM_OF_EPISODES):
88
+ done = False
89
+ state, _ = env.reset()
90
+ rewards = 0
91
+
92
+ while not done:
93
+ action = agent.choose_action(state)
94
+ frame = env.render() # rgb_array frame (not used directly, but ok)
95
+ new_state, reward, done, truncated, info = env.step(action)
96
+ rewards += reward
97
+ state = new_state
98
+
99
+ if done or truncated:
100
+ print(f"Episode: {i}, Reward: {rewards}")
101
+
102
+ # Only save a game if Mario reached the flag
103
+ if info.get("flag_get", False):
104
+ game_dir = os.path.join(GAMES_ROOT, f"game_{i}")
105
+ os.makedirs(game_dir, exist_ok=True)
106
+
107
+ # unwrap env to get SkipFrame wrapper
108
+ frame_skip_env = env.env.env.env
109
+ frames_log = frame_skip_env.frames_log
110
+ actions_log = frame_skip_env.actions_log
111
+
112
+ for j, (frame_array, action_taken) in enumerate(zip(frames_log, actions_log)):
113
+ # upscale frame for nicer video
114
+ new_dims = (
115
+ frame_array.shape[1] * SCALING_FACTOR,
116
+ frame_array.shape[0] * SCALING_FACTOR,
117
+ )
118
+ frame_img = Image.fromarray(frame_array).resize(
119
+ new_dims, Image.NEAREST
120
+ )
121
+ frame_img.save(os.path.join(game_dir, f"frame_{j}.png"))
122
+
123
+ # save controller overlay frames (optional)
124
+ controllers[action_taken].save(
125
+ os.path.join(game_dir, f"controller_{j}.png")
126
+ )
127
 
128
+ break
129
 
130
+ env.close()
131
+ print("Finished generating frame sequences in 'games/'.")
132
 
 
 
133
 
134
+ if __name__ == "__main__":
135
+ main()