Spaces:
Build error
Build error
| import gym | |
| from stable_baselines3 import DQN | |
| from stable_baselines3.common.evaluation import evaluate_policy | |
| from tetris_env import TetrisEnv | |
| from callbacks import SaveFramesCallback | |
| import os | |
def main(total_timesteps=550000, n_eval_episodes=10):
    """Train a DQN agent on TetrisEnv, save it, and print its evaluation reward.

    Args:
        total_timesteps: Number of environment steps passed to ``model.learn``.
            Defaults to the script's original 550000.
        n_eval_episodes: Number of episodes run by ``evaluate_policy`` after
            training. Defaults to the script's original 10.

    Side effects:
        Creates ``models/`` and ``models/frames/``, writes the trained model to
        ``models/dqn_tetris.zip``, and prints progress/results to stdout.
    """
    # Create output directories before training: the callback is handed
    # "models/frames" as its save_path and fires during learn(), long before
    # the final model.save() call. Creating the nested path also creates "models".
    os.makedirs("models/frames", exist_ok=True)

    env = TetrisEnv()
    model = DQN(
        "MlpPolicy",
        env,
        verbose=1,
        learning_rate=1e-3,
        buffer_size=50000,
        learning_starts=1000,
        batch_size=32,
        gamma=0.99,
        target_update_interval=1000,
        exploration_fraction=0.1,
        exploration_final_eps=0.02,
    )

    callback = SaveFramesCallback(save_freq=5000, save_path="models/frames", verbose=1)
    model.learn(total_timesteps=total_timesteps, callback=callback)

    # SB3 appends the ".zip" suffix, hence the path printed below.
    model.save("models/dqn_tetris")
    print("Model saved to models/dqn_tetris.zip")

    # Evaluate on a fresh environment rather than the instance the model's
    # internal VecEnv still wraps, so evaluation does not disturb its state.
    eval_env = TetrisEnv()
    try:
        mean_reward, std_reward = evaluate_policy(
            model, eval_env, n_eval_episodes=n_eval_episodes
        )
    finally:
        eval_env.close()
    print(f"Mean Reward: {mean_reward} +/- {std_reward}")


if __name__ == "__main__":
    main()