|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
import shutil
import sys
import tempfile

import ale_py
import gymnasium as gym
import numpy as np
from huggingface_hub import HfApi, create_repo
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Run configuration ----------------------------------------------------
# Hub account and repository name; combined below into the upload target.
USERNAME: str = "ImaghT"
MODEL_NAME: str = "dqn-SpaceInvadersNoFrameskip-v4"

# Local SB3 archive produced by training; it is evaluated and then uploaded.
MODEL_FILE: str = "dqn-SpaceInvaders.zip"

# Environment id used to build the evaluation env.
# NOTE(review): MODEL_NAME and the generated model card refer to
# SpaceInvadersNoFrameskip-v4, but evaluation runs on ALE/SpaceInvaders-v5.
# Confirm ENV_ID matches the environment the model was actually trained on.
ENV_ID: str = "ALE/SpaceInvaders-v5"

# Number of complete games to play when computing the reported statistics.
N_EVAL_EPISODES: int = 10

# "<user>/<repo>" identifier passed to the Hugging Face Hub calls below.
repo_id = "/".join((USERNAME, MODEL_NAME))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Make the "ALE/..." environment ids available to gymnasium before any env
# is constructed below.
# NOTE(review): with recent gymnasium/ale_py, the `import ale_py` itself
# performs the registration and `gym.register_envs(ale_py)` mainly keeps that
# import from looking unused — confirm against the installed versions.
print("Registering ALE environments...")
gym.register_envs(ale_py)
print("✅ Environments registered\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the trained DQN agent from the local SB3 archive.
print("Loading trained model...")

# Fail fast with a readable message rather than a traceback from DQN.load.
# isfile() (unlike exists()) also rejects a directory that merely shares
# the archive's name.
if not os.path.isfile(MODEL_FILE):
    print(f"❌ Error: Model file '{MODEL_FILE}' not found!", file=sys.stderr)
    # sys.exit is always available; the bare `exit` builtin is injected by
    # the `site` module for interactive use and may be absent (e.g. `python -S`).
    sys.exit(1)

model = DQN.load(MODEL_FILE)
print(f"✅ Model loaded from {MODEL_FILE}\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Build a single seeded evaluation env with SB3's Atari preprocessing, then
# stack the last 4 observations so the policy sees the same input shape as the
# README usage snippet.
# NOTE(review): SB3 documents its Atari wrapper as intended for
# "*NoFrameskip-v4" ids. ALE "-v5" ids skip frames internally and use sticky
# actions by default, so frame skipping may be applied twice here. Whatever
# is used must match training — confirm before changing ENV_ID.
print("Creating evaluation environment...")
eval_env = VecFrameStack(make_atari_env(ENV_ID, n_envs=1, seed=42), n_stack=4)
print("✅ Evaluation environment ready\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Roll out the greedy (deterministic) policy until N_EVAL_EPISODES complete
# games have been recorded in episode_rewards / episode_lengths.
rule = "=" * 60
print(rule)
print(f"Starting Evaluation ({N_EVAL_EPISODES} episodes)...")
print(rule)

episode_rewards = []
episode_lengths = []

obs = eval_env.reset()
current_episode = 0  # NOTE(review): never read anywhere in this script

while len(episode_rewards) < N_EVAL_EPISODES:
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = eval_env.step(action)

    # Only one env (n_envs=1), so index 0 is the whole batch. A step is
    # recorded only when it is done AND carries an "episode" summary; a bare
    # done without the summary is skipped (presumably a life-loss reset from
    # the Atari wrapper rather than game over — confirm).
    if not (done[0] and "episode" in info[0]):
        continue

    summary = info[0]["episode"]
    ep_reward = summary["r"]
    ep_length = summary["l"]
    episode_rewards.append(ep_reward)
    episode_lengths.append(ep_length)
    print(f"Episode {len(episode_rewards)}/{N_EVAL_EPISODES}: "
          f"Reward = {ep_reward:.2f}, Length = {ep_length}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Summarise the evaluation run. The pass criterion used here is
#   score = mean_reward - std_reward >= BASELINE_SCORE
# where std is numpy's population standard deviation (np.std default ddof=0).
BASELINE_SCORE = 200.0

if not episode_rewards:
    # np.min / np.max raise ValueError on an empty sequence (e.g. when
    # N_EVAL_EPISODES <= 0); stop with a clear message instead of a traceback.
    sys.exit("❌ No evaluation episodes were completed; nothing to report or upload.")

mean_reward = np.mean(episode_rewards)
std_reward = np.std(episode_rewards)
min_reward = np.min(episode_rewards)
max_reward = np.max(episode_rewards)
mean_length = np.mean(episode_lengths)
score = mean_reward - std_reward

print("\n" + "=" * 60)
print("Evaluation Results:")
print(f" Mean Reward: {mean_reward:.2f}")
print(f" Std Reward: {std_reward:.2f}")
print(f" Min Reward: {min_reward:.2f}")
print(f" Max Reward: {max_reward:.2f}")
print(f" Mean Length: {mean_length:.2f}")
print(f" Score (mean - std): {score:.2f}")
# Printed from the same constant that the check below uses, so the two cannot drift.
print(f" Baseline Required: {BASELINE_SCORE}")
if score >= BASELINE_SCORE:
    print(" Status: ✅ PASSED")
else:
    print(f" Status: ❌ NOT PASSED (need {BASELINE_SCORE - score:.2f} more points)")
print("=" * 60 + "\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Model card uploaded as README.md.
# - The YAML front matter follows the nested `model-index` layout the Hub uses
#   for SB3 cards (name -> results -> task/dataset/metrics). Its indentation is
#   significant: flattening it breaks parsing of the reported mean_reward.
# - The markdown tables and code fence are filled from the statistics above.
# NOTE(review): the "Training Configuration" values are hand-written literals,
# not read from the loaded model — keep them in sync with the training run.
readme_content = f"""---
library_name: stable-baselines3
tags:
- SpaceInvadersNoFrameskip-v4
- deep-reinforcement-learning
- reinforcement-learning
- stable-baselines3
model-index:
- name: DQN
  results:
  - task:
      type: reinforcement-learning
      name: reinforcement-learning
    dataset:
      name: SpaceInvadersNoFrameskip-v4
      type: SpaceInvadersNoFrameskip-v4
    metrics:
    - type: mean_reward
      value: {mean_reward:.2f} +/- {std_reward:.2f}
      name: mean_reward
      verified: false
---

# **DQN** Agent playing **SpaceInvadersNoFrameskip-v4**

This is a trained model of a **DQN** agent playing **SpaceInvadersNoFrameskip-v4**
using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3)
and the [Deep Reinforcement Learning Course](https://huggingface.co/deep-rl-course/unit3).

## Evaluation Results

| Metric | Value |
|--------|-------|
| Mean Reward | {mean_reward:.2f} |
| Std Reward | {std_reward:.2f} |
| Min Reward | {min_reward:.2f} |
| Max Reward | {max_reward:.2f} |
| Mean Episode Length | {mean_length:.2f} |
| Score (mean - std) | {score:.2f} |
| Evaluation Episodes | {N_EVAL_EPISODES} |

## Usage

```python
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
import gymnasium as gym
import ale_py

gym.register_envs(ale_py)

env = make_atari_env("ALE/SpaceInvaders-v5", n_envs=1, seed=0)
env = VecFrameStack(env, n_stack=4)

model = DQN.load("dqn-SpaceInvaders")

obs = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    # Vectorized envs reset themselves at the end of an episode,
    # so no manual env.reset() is needed inside the loop.
    obs, reward, done, info = env.step(action)
```

## Training Configuration

- **Algorithm**: DQN (Deep Q-Network)
- **Policy**: CnnPolicy
- **Total Timesteps**: 10,000,000
- **Learning Rate**: 1e-4
- **Buffer Size**: 200,000
- **Batch Size**: 32
- **Device**: CUDA
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Stage exactly the files to publish (model card + model archive) in a
# directory that is then uploaded as a whole.
print("Preparing files for upload...")

# A fresh, uniquely-named directory guarantees only the files written below
# are uploaded. The previous fixed "./upload_temp" with exist_ok=True would
# silently reuse a leftover directory, and push its stale contents.
upload_folder = tempfile.mkdtemp(prefix="upload_temp_")

readme_path = os.path.join(upload_folder, "README.md")
with open(readme_path, "w", encoding="utf-8") as f:
    f.write(readme_content)
print("✅ Created README.md")

# basename(): keep the archive at the repo root even if MODEL_FILE is ever
# given as a path rather than a bare file name.
model_dest = os.path.join(upload_folder, os.path.basename(MODEL_FILE))
shutil.copy(MODEL_FILE, model_dest)
print(f"✅ Copied {MODEL_FILE}\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Publish the staged folder to the Hugging Face Hub, then remove it.
# - Repo creation is best-effort: a failure is reported and the upload is
#   still attempted.
# - A failed upload now ends the script with a non-zero exit status instead of
#   falling through to "Done!" as if it had succeeded.
print(f"Uploading to {repo_id}...")

api = HfApi()

try:
    create_repo(repo_id, repo_type="model", exist_ok=True)
    print("✅ Repository created/verified")
except Exception as e:
    print(f"⚠️ Repository warning: {e}", file=sys.stderr)

try:
    api.upload_folder(
        folder_path=upload_folder,
        repo_id=repo_id,
        repo_type="model",
        commit_message=f"DQN SpaceInvaders - Mean: {mean_reward:.2f}, Std: {std_reward:.2f}",
    )
except Exception as e:
    print(f"\n❌ Upload failed: {e}\n", file=sys.stderr)
    # SystemExit raised here still runs the `finally` cleanup below.
    sys.exit(1)
else:
    rule = "=" * 60
    print(f"\n{rule}")
    print("✅ Upload Successful!")
    print(rule)
    print(f"🔗 Model Page: https://huggingface.co/{repo_id}")
    print("🏆 Check Progress: https://huggingface.co/spaces/ThomasSimonini/Check-my-progress-Deep-RL-Course")
    print(f"{rule}\n")
finally:
    # ignore_errors: a cleanup problem must not mask the upload result.
    shutil.rmtree(upload_folder, ignore_errors=True)
    print("🧹 Cleaned up temporary files")

print("✨ Done!")
|
|
|