Update app.py
Browse files
app.py
CHANGED
|
@@ -162,28 +162,12 @@ def evaluate_model(model, env):
|
|
| 162 |
mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=10, render=False)
|
| 163 |
return mean_reward
|
| 164 |
|
| 165 |
-
# Gradio interface
def play_game():
    """Run one episode of Arkanoid with the saved DQN policy, capturing frames.

    Loads the model from "arkanoid_model", steps the environment with
    deterministic actions until the episode ends, saves each rendered frame
    to "frame.png" via pygame, and collects a gr.Image per frame.

    Returns:
        list: one gr.Image (wrapping "frame.png") per environment step.
        NOTE(review): every gr.Image points at the same file, which is
        overwritten each step — presumably only the last frame survives on
        disk; confirm whether per-step filenames were intended.
    """
    env = ArkanoidEnv()
    model = DQN.load("arkanoid_model")
    # env.reset() follows the Gymnasium API and returns (obs, info).
    obs = env.reset()[0]
    done = False
    truncated = False
    frames = []
    # Fix: the original looped on `done` alone, ignoring the `truncated`
    # flag returned by env.step(); a time-limit truncation would never end
    # the loop. Matches the loop condition used in train_and_play.
    while not done and not truncated:
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, done, truncated, info = env.step(action)
        env.render()
        # `screen` is assumed to be the module-level pygame display surface
        # — TODO confirm it is defined at file scope.
        pygame.image.save(screen, "frame.png")
        frames.append(gr.Image(value="frame.png"))
    return frames
|
| 179 |
-
|
| 180 |
# Real-time training function
|
| 181 |
def train_and_play():
|
| 182 |
env = ArkanoidEnv()
|
| 183 |
model = DQN('MlpPolicy', env, verbose=1)
|
| 184 |
total_timesteps = 10000
|
| 185 |
timesteps_per_update = 1000
|
| 186 |
-
frames = []
|
| 187 |
video_frames = []
|
| 188 |
|
| 189 |
for i in range(0, total_timesteps, timesteps_per_update):
|
|
@@ -191,7 +175,6 @@ def train_and_play():
|
|
| 191 |
obs = env.reset()[0]
|
| 192 |
done = False
|
| 193 |
truncated = False
|
| 194 |
-
episode_frames = []
|
| 195 |
while not done and not truncated:
|
| 196 |
action, _states = model.predict(obs, deterministic=True)
|
| 197 |
obs, reward, done, truncated, info = env.step(action)
|
|
@@ -200,9 +183,6 @@ def train_and_play():
|
|
| 200 |
frame = pygame.surfarray.array3d(pygame.display.get_surface())
|
| 201 |
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
| 202 |
video_frames.append(frame)
|
| 203 |
-
episode_frames.append(gr.Image(value="frame.png"))
|
| 204 |
-
frames.extend(episode_frames)
|
| 205 |
-
yield frames
|
| 206 |
|
| 207 |
# Save the video
|
| 208 |
video_path = "arkanoid_training.mp4"
|
|
@@ -213,7 +193,7 @@ def train_and_play():
|
|
| 213 |
video_writer.release()
|
| 214 |
|
| 215 |
# Return the video path
|
| 216 |
-
return
|
| 217 |
|
| 218 |
# Main function
|
| 219 |
def main():
|
|
|
|
| 162 |
mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=10, render=False)
|
| 163 |
return mean_reward
|
| 164 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
# Real-time training function
|
| 166 |
def train_and_play():
|
| 167 |
env = ArkanoidEnv()
|
| 168 |
model = DQN('MlpPolicy', env, verbose=1)
|
| 169 |
total_timesteps = 10000
|
| 170 |
timesteps_per_update = 1000
|
|
|
|
| 171 |
video_frames = []
|
| 172 |
|
| 173 |
for i in range(0, total_timesteps, timesteps_per_update):
|
|
|
|
| 175 |
obs = env.reset()[0]
|
| 176 |
done = False
|
| 177 |
truncated = False
|
|
|
|
| 178 |
while not done and not truncated:
|
| 179 |
action, _states = model.predict(obs, deterministic=True)
|
| 180 |
obs, reward, done, truncated, info = env.step(action)
|
|
|
|
| 183 |
frame = pygame.surfarray.array3d(pygame.display.get_surface())
|
| 184 |
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
| 185 |
video_frames.append(frame)
|
|
|
|
|
|
|
|
|
|
| 186 |
|
| 187 |
# Save the video
|
| 188 |
video_path = "arkanoid_training.mp4"
|
|
|
|
| 193 |
video_writer.release()
|
| 194 |
|
| 195 |
# Return the video path
|
| 196 |
+
return video_path
|
| 197 |
|
| 198 |
# Main function
|
| 199 |
def main():
|