Update app.py
Browse files
app.py
CHANGED
|
@@ -106,14 +106,18 @@ class ArkanoidEnv(gym.Env):
|
|
| 106 |
self.bricks.remove(brick)
|
| 107 |
self.ball.velocity[1] = -self.ball.velocity[1]
|
| 108 |
self.score += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
if self.ball.rect.bottom >= SCREEN_HEIGHT:
|
| 111 |
self.done = True
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
-
if not self.bricks:
|
| 114 |
-
self.done = True
|
| 115 |
-
|
| 116 |
-
reward = 1 if self.score > 0 else -1
|
| 117 |
return self._get_state(), reward, self.done, {}
|
| 118 |
|
| 119 |
def _get_state(self):
|
|
@@ -142,16 +146,14 @@ class ArkanoidEnv(gym.Env):
|
|
| 142 |
pygame.quit()
|
| 143 |
|
| 144 |
# Training function
|
| 145 |
-
def train_model():
|
| 146 |
-
env = ArkanoidEnv()
|
| 147 |
model = DQN('MlpPolicy', env, verbose=1)
|
| 148 |
-
model.learn(total_timesteps=
|
| 149 |
model.save("arkanoid_model")
|
| 150 |
return model
|
| 151 |
|
| 152 |
# Evaluation function
|
| 153 |
-
def evaluate_model(model):
|
| 154 |
-
env = ArkanoidEnv()
|
| 155 |
mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=10, render=False)
|
| 156 |
return mean_reward
|
| 157 |
|
|
@@ -170,22 +172,33 @@ def play_game():
|
|
| 170 |
frames.append(gr.Image(value="frame.png"))
|
| 171 |
return frames
|
| 172 |
|
| 173 |
-
#
|
| 174 |
-
def
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
|
|
|
|
|
|
|
| 186 |
# Gradio interface
|
| 187 |
iface = gr.Interface(
|
| 188 |
-
fn=
|
| 189 |
inputs=None,
|
| 190 |
outputs="image",
|
| 191 |
live=True
|
|
|
|
| 106 |
self.bricks.remove(brick)
|
| 107 |
self.ball.velocity[1] = -self.ball.velocity[1]
|
| 108 |
self.score += 1
|
| 109 |
+
reward = 1
|
| 110 |
+
if not self.bricks:
|
| 111 |
+
reward += 10 # Bonus reward for breaking all bricks
|
| 112 |
+
self.done = True
|
| 113 |
+
return self._get_state(), reward, self.done, {}
|
| 114 |
|
| 115 |
if self.ball.rect.bottom >= SCREEN_HEIGHT:
|
| 116 |
self.done = True
|
| 117 |
+
reward = -1
|
| 118 |
+
else:
|
| 119 |
+
reward = 0
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
return self._get_state(), reward, self.done, {}
|
| 122 |
|
| 123 |
def _get_state(self):
|
|
|
|
| 146 |
pygame.quit()
|
| 147 |
|
| 148 |
# Training function
|
| 149 |
+
def train_model(env, total_timesteps=10000):
|
|
|
|
| 150 |
model = DQN('MlpPolicy', env, verbose=1)
|
| 151 |
+
model.learn(total_timesteps=total_timesteps)
|
| 152 |
model.save("arkanoid_model")
|
| 153 |
return model
|
| 154 |
|
| 155 |
# Evaluation function
|
| 156 |
+
def evaluate_model(model, env):
|
|
|
|
| 157 |
mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=10, render=False)
|
| 158 |
return mean_reward
|
| 159 |
|
|
|
|
| 172 |
frames.append(gr.Image(value="frame.png"))
|
| 173 |
return frames
|
| 174 |
|
| 175 |
+
# Real-time training function
|
| 176 |
+
def train_and_play():
|
| 177 |
+
env = ArkanoidEnv()
|
| 178 |
+
model = DQN('MlpPolicy', env, verbose=1)
|
| 179 |
+
total_timesteps = 10000
|
| 180 |
+
timesteps_per_update = 1000
|
| 181 |
+
frames = []
|
| 182 |
|
| 183 |
+
for i in range(0, total_timesteps, timesteps_per_update):
|
| 184 |
+
model.learn(total_timesteps=timesteps_per_update)
|
| 185 |
+
obs = env.reset()[0]
|
| 186 |
+
done = False
|
| 187 |
+
episode_frames = []
|
| 188 |
+
while not done:
|
| 189 |
+
action, _states = model.predict(obs, deterministic=True)
|
| 190 |
+
obs, reward, done, info = env.step(action)
|
| 191 |
+
env.render()
|
| 192 |
+
pygame.image.save(screen, "frame.png")
|
| 193 |
+
episode_frames.append(gr.Image(value="frame.png"))
|
| 194 |
+
frames.extend(episode_frames)
|
| 195 |
+
yield frames
|
| 196 |
|
| 197 |
+
# Main function
|
| 198 |
+
def main():
|
| 199 |
# Gradio interface
|
| 200 |
iface = gr.Interface(
|
| 201 |
+
fn=train_and_play,
|
| 202 |
inputs=None,
|
| 203 |
outputs="image",
|
| 204 |
live=True
|