Ivan000 commited on
Commit
927c930
·
verified ·
1 Parent(s): 375aee6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -45
app.py CHANGED
@@ -64,25 +64,25 @@ class Brick:
64
  self.rect = pygame.Rect(x, y, BRICK_WIDTH - 5, BRICK_HEIGHT - 5)
65
 
66
  class ArkanoidEnv(gym.Env):
67
- def __init__(self, reward_size=1, penalty_size=-1, platform_reward=5, inactivity_penalty=-0.5):
68
  super(ArkanoidEnv, self).__init__()
69
  self.action_space = gym.spaces.Discrete(3) # 0: stay, 1: move left, 2: move right
70
- self.observation_space = gym.spaces.Box(low=0, high=SCREEN_WIDTH, shape=(3,), dtype=np.float32)
71
  self.reward_size = reward_size
72
  self.penalty_size = penalty_size
73
  self.platform_reward = platform_reward
74
- self.inactivity_penalty = inactivity_penalty
75
- self.inactivity_counter = 0
76
  self.reset()
77
 
78
  def reset(self, seed=None, options=None):
 
 
 
79
  self.paddle = Paddle()
80
  self.ball = Ball()
81
  self.bricks = [Brick(x, y) for y in range(BRICK_HEIGHT, BRICK_HEIGHT * (BRICK_ROWS + 1), BRICK_HEIGHT)
82
  for x in range(BRICK_WIDTH, SCREEN_WIDTH - BRICK_WIDTH, BRICK_WIDTH)]
83
  self.done = False
84
  self.score = 0
85
- self.inactivity_counter = 0
86
  return self._get_state(), {}
87
 
88
  def step(self, action):
@@ -93,40 +93,46 @@ class ArkanoidEnv(gym.Env):
93
  elif action == 2:
94
  self.paddle.move(1)
95
 
96
- if action == 0:
97
- self.inactivity_counter += 1 / FPS
98
- else:
99
- self.inactivity_counter = 0
100
-
101
- if self.inactivity_counter >= 1:
102
- reward = self.inactivity_penalty
103
- return self._get_state(), reward, self.done, False, {}
104
-
105
  self.ball.move()
106
 
107
  if self.ball.rect.colliderect(self.paddle.rect):
108
  self.ball.velocity[1] = -self.ball.velocity[1]
109
- self.ball.velocity[0] += random.uniform(-1, 1)
110
  self.score += self.platform_reward
111
 
112
  for brick in self.bricks[:]:
113
  if self.ball.rect.colliderect(brick.rect):
114
  self.bricks.remove(brick)
115
  self.ball.velocity[1] = -self.ball.velocity[1]
116
- self.ball.velocity[0] += random.uniform(-1, 1)
117
  self.score += 1
 
118
  if not self.bricks:
 
119
  self.done = True
120
- return self._get_state(), self.reward_size, self.done, False, {}
 
121
 
122
  if self.ball.rect.bottom >= SCREEN_HEIGHT:
123
  self.done = True
124
- return self._get_state(), self.penalty_size, self.done, False, {}
 
 
 
 
125
 
126
- return self._get_state(), 0, self.done, False, {}
127
 
128
  def _get_state(self):
129
- return np.array([self.ball.rect.x, self.paddle.rect.x, len(self.bricks)], dtype=np.float32)
 
 
 
 
 
 
 
 
 
 
130
 
131
  def render(self, mode='rgb_array'):
132
  surface = pygame.Surface((SCREEN_WIDTH, SCREEN_HEIGHT))
@@ -145,32 +151,40 @@ class ArkanoidEnv(gym.Env):
145
  def close(self):
146
  pygame.quit()
147
 
148
# Training and playing function
def train_and_play(reward_size, penalty_size, platform_reward, inactivity_penalty, iterations):
    """Train a DQN agent on ArkanoidEnv, roll out one episode, save it as MP4.

    Returns:
        Path of the written video file.

    Fixes: the rollout used a stochastic `predict` for evaluation and looped
    with no upper bound, so a policy that never drops the ball hangs forever.
    """
    env = ArkanoidEnv(reward_size, penalty_size, platform_reward, inactivity_penalty)
    model = DQN("MlpPolicy", env, verbose=0)
    model.learn(total_timesteps=iterations)

    obs, _ = env.reset()
    frames = []
    # Cap the rollout (~60 s of footage) so it always terminates.
    max_steps = 60 * FPS
    for _ in range(max_steps):
        # Greedy action for evaluation; exploration belongs to training only.
        action, _states = model.predict(obs, deterministic=True)
        obs, _, done, truncated, _ = env.step(action)
        frames.append(env.render(mode="rgb_array"))
        if done or truncated:
            break
    env.close()

    video_path = "/tmp/arkanoid.mp4"
    out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), FPS, (SCREEN_WIDTH, SCREEN_HEIGHT))
    for frame in frames:
        # OpenCV expects BGR channel order.
        out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    out.release()
    return video_path
172
 
173
- # Gradio interface
174
  def main():
175
  iface = gr.Interface(
176
  fn=train_and_play,
@@ -178,10 +192,10 @@ def main():
178
  gr.Number(label="Reward Size", value=1),
179
  gr.Number(label="Penalty Size", value=-1),
180
  gr.Number(label="Platform Reward", value=5),
181
- gr.Number(label="Inactivity Penalty", value=-0.5),
182
  gr.Slider(label="Iterations", minimum=10, maximum=100000, step=10, value=10000)
183
  ],
184
- outputs="video"
 
185
  )
186
  iface.launch()
187
 
 
64
  self.rect = pygame.Rect(x, y, BRICK_WIDTH - 5, BRICK_HEIGHT - 5)
65
 
66
  class ArkanoidEnv(gym.Env):
67
+ def __init__(self, reward_size=1, penalty_size=-1, platform_reward=5):
68
  super(ArkanoidEnv, self).__init__()
69
  self.action_space = gym.spaces.Discrete(3) # 0: stay, 1: move left, 2: move right
70
+ self.observation_space = gym.spaces.Box(low=0, high=SCREEN_WIDTH, shape=(5 + BRICK_ROWS * BRICK_COLS * 2,), dtype=np.float32)
71
  self.reward_size = reward_size
72
  self.penalty_size = penalty_size
73
  self.platform_reward = platform_reward
 
 
74
  self.reset()
75
 
76
  def reset(self, seed=None, options=None):
77
+ if seed is not None:
78
+ random.seed(seed)
79
+ np.random.seed(seed)
80
  self.paddle = Paddle()
81
  self.ball = Ball()
82
  self.bricks = [Brick(x, y) for y in range(BRICK_HEIGHT, BRICK_HEIGHT * (BRICK_ROWS + 1), BRICK_HEIGHT)
83
  for x in range(BRICK_WIDTH, SCREEN_WIDTH - BRICK_WIDTH, BRICK_WIDTH)]
84
  self.done = False
85
  self.score = 0
 
86
  return self._get_state(), {}
87
 
88
  def step(self, action):
 
93
  elif action == 2:
94
  self.paddle.move(1)
95
 
 
 
 
 
 
 
 
 
 
96
  self.ball.move()
97
 
98
  if self.ball.rect.colliderect(self.paddle.rect):
99
  self.ball.velocity[1] = -self.ball.velocity[1]
 
100
  self.score += self.platform_reward
101
 
102
  for brick in self.bricks[:]:
103
  if self.ball.rect.colliderect(brick.rect):
104
  self.bricks.remove(brick)
105
  self.ball.velocity[1] = -self.ball.velocity[1]
 
106
  self.score += 1
107
+ reward = self.reward_size
108
  if not self.bricks:
109
+ reward += self.reward_size * 10 # Bonus reward for breaking all bricks
110
  self.done = True
111
+ truncated = False
112
+ return self._get_state(), reward, self.done, truncated, {}
113
 
114
  if self.ball.rect.bottom >= SCREEN_HEIGHT:
115
  self.done = True
116
+ reward = self.penalty_size
117
+ truncated = False
118
+ else:
119
+ reward = 0
120
+ truncated = False
121
 
122
+ return self._get_state(), reward, self.done, truncated, {}
123
 
124
  def _get_state(self):
125
+ state = [
126
+ self.paddle.rect.x,
127
+ self.ball.rect.x,
128
+ self.ball.rect.y,
129
+ self.ball.velocity[0],
130
+ self.ball.velocity[1]
131
+ ]
132
+ for brick in self.bricks:
133
+ state.extend([brick.rect.x, brick.rect.y])
134
+ state.extend([0, 0] * (BRICK_ROWS * BRICK_COLS - len(self.bricks))) # Padding for missing bricks
135
+ return np.array(state, dtype=np.float32)
136
 
137
  def render(self, mode='rgb_array'):
138
  surface = pygame.Surface((SCREEN_WIDTH, SCREEN_HEIGHT))
 
151
  def close(self):
152
  pygame.quit()
153
 
154
# Training and playing with custom parameters
def train_and_play(reward_size, penalty_size, platform_reward, iterations):
    """Train a DQN agent, record one greedy episode, and write it to MP4.

    Returns:
        Path of the saved video file.

    Fixes: the chunked training loop called `model.learn()` repeatedly, which
    restarts SB3's schedules/counters each call (default
    `reset_num_timesteps=True`) with no benefit — one call is correct. The
    rollout also had no step cap and ignored the `truncated` flag.
    """
    env = ArkanoidEnv(reward_size=reward_size, penalty_size=penalty_size, platform_reward=platform_reward)
    model = DQN('MlpPolicy', env, verbose=1)
    model.learn(total_timesteps=int(iterations))

    video_frames = []
    obs, _ = env.reset()
    # Cap the rollout (~60 s of footage) so a ball-juggling policy cannot hang the UI.
    for _ in range(60 * FPS):
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, done, truncated, _ = env.step(action)

        frame = env.render(mode='rgb_array')
        # NOTE(review): rot90 assumes render returns a (width, height, 3)
        # pygame-style array — confirm against the render implementation.
        frame = np.rot90(frame)
        video_frames.append(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        if done or truncated:
            break

    video_path = "arkanoid_training.mp4"
    video_writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), FPS, (SCREEN_WIDTH, SCREEN_HEIGHT))
    for frame in video_frames:
        video_writer.write(frame)
    video_writer.release()

    env.close()
    return video_path
186
 
187
+ # Main function with Gradio interface
188
  def main():
189
  iface = gr.Interface(
190
  fn=train_and_play,
 
192
  gr.Number(label="Reward Size", value=1),
193
  gr.Number(label="Penalty Size", value=-1),
194
  gr.Number(label="Platform Reward", value=5),
 
195
  gr.Slider(label="Iterations", minimum=10, maximum=100000, step=10, value=10000)
196
  ],
197
+ outputs="video",
198
+ live=False # Disable auto-generation on slider changes
199
  )
200
  iface.launch()
201