# ImaghT's picture
# Upload Unit_3_upload.py with huggingface_hub
# f56d328 verified
# ============================================================
# 评估模型并上传到 Hugging Face(修复版)
# ============================================================
import gymnasium as gym
import ale_py
import numpy as np
import os
import shutil
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
from huggingface_hub import HfApi, create_repo
# ============================================================
# Configuration (edit these values)
# ============================================================
USERNAME = "ImaghT"                             # Hub namespace the repo is created under
MODEL_NAME = "dqn-SpaceInvadersNoFrameskip-v4"  # repository name on the Hub
MODEL_FILE = "dqn-SpaceInvaders.zip"            # local SB3 checkpoint to evaluate and upload
ENV_ID = "ALE/SpaceInvaders-v5"                 # environment id used for the evaluation run
N_EVAL_EPISODES = 10                            # number of complete episodes to evaluate
# Full Hub repository id, "<namespace>/<name>".
repo_id = "/".join((USERNAME, MODEL_NAME))
# ============================================================
# 1. Register the ALE environments
# ============================================================
# Make the ALE environment ids (such as ENV_ID) known to Gymnasium
# before the evaluation environment is created further down.
print("Registering ALE environments...")
gym.register_envs(ale_py)
print("✅ Environments registered\n")
# ============================================================
# 2. Load the trained model
# ============================================================
print("Loading trained model...")
if not os.path.exists(MODEL_FILE):
    print(f"❌ Error: Model file '{MODEL_FILE}' not found!")
    # `exit()` is a convenience installed by the `site` module for the
    # interactive shell and may be missing; SystemExit is the portable way
    # for a script to stop with a non-zero status.
    raise SystemExit(1)
# Evaluation only needs the policy. Extra keyword arguments to `load` override
# the saved attributes before the model is rebuilt, so a dummy buffer size
# avoids allocating the full training replay buffer (listed as 200,000 in the
# model card below) just to run inference.
model = DQN.load(MODEL_FILE, buffer_size=1)
print(f"✅ Model loaded from {MODEL_FILE}\n")
# ============================================================
# 3. Create the evaluation environment
# ============================================================
print("Creating evaluation environment...")
# A single Atari env (seed 42) built by make_atari_env, with the last 4 frames
# stacked; presumably this matches the stacking used during training (the
# model card's usage snippet uses the same n_stack=4).
# NOTE(review): SB3 documents its Atari wrapper for "*NoFrameskip-v4" ids,
# whereas ALE v5 ids are believed to apply their own frame skip and sticky
# actions by default; confirm this env matches the one the model was trained on.
eval_env = VecFrameStack(make_atari_env(ENV_ID, n_envs=1, seed=42), n_stack=4)
print("✅ Evaluation environment ready\n")
# ============================================================
# 4. Run the evaluation
# ============================================================
print("="*60)
print(f"Starting Evaluation ({N_EVAL_EPISODES} episodes)...")
print("="*60)
episode_rewards = []
episode_lengths = []
obs = eval_env.reset()
# Step until enough *complete* episodes have been collected. There is no
# manual reset: SB3 VecEnvs reset a finished sub-environment automatically.
while len(episode_rewards) < N_EVAL_EPISODES:
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = eval_env.step(action)
    # The per-step `reward` is not the episode return. When an episode really
    # ends, info[0]["episode"] holds its return "r" and length "l" (presumably
    # added by the Monitor wrapper that make_atari_env applies). A `done`
    # without that entry (an intermediate termination, e.g. a lost life if the
    # wrappers signal one) is therefore not counted.
    if done[0] and "episode" in info[0]:
        ep_reward = info[0]["episode"]["r"]
        ep_length = info[0]["episode"]["l"]
        episode_rewards.append(ep_reward)
        episode_lengths.append(ep_length)
        print(f"Episode {len(episode_rewards)}/{N_EVAL_EPISODES}: "
              f"Reward = {ep_reward:.2f}, Length = {ep_length}")
# The environment is not used after this point; release its resources.
eval_env.close()
# ============================================================
# 5. Compute summary statistics
# ============================================================
# Minimum "mean - std" score needed to pass; defined once so the printed
# baseline, the pass check and the shortfall message cannot drift apart.
PASS_THRESHOLD = 200.0
mean_reward = np.mean(episode_rewards)
std_reward = np.std(episode_rewards)
min_reward = np.min(episode_rewards)
max_reward = np.max(episode_rewards)
mean_length = np.mean(episode_lengths)
# Pass criterion: mean return penalised by its spread across episodes.
score = mean_reward - std_reward
print("\n" + "="*60)
print("Evaluation Results:")
print(f" Mean Reward: {mean_reward:.2f}")
print(f" Std Reward: {std_reward:.2f}")
print(f" Min Reward: {min_reward:.2f}")
print(f" Max Reward: {max_reward:.2f}")
print(f" Mean Length: {mean_length:.2f}")
print(f" Score (mean - std): {score:.2f}")
print(f" Baseline Required: {PASS_THRESHOLD:.1f}")
if score >= PASS_THRESHOLD:
    print(" Status: ✅ PASSED")
else:
    print(f" Status: ❌ NOT PASSED (need {PASS_THRESHOLD - score:.2f} more points)")
print("="*60 + "\n")
# ============================================================
# 6. Create README.md (model card)
# ============================================================
# The YAML front matter is parsed by the Hugging Face Hub, and its indentation
# is significant: model-index -> results -> task/dataset/metrics must stay nested.
# NOTE(review): the card labels the environment SpaceInvadersNoFrameskip-v4
# (the id used by the Deep RL course), while the evaluation above ran on ENV_ID;
# the evaluation env is listed in the results table so the difference is visible.
# The "Training Configuration" section is hard-coded text and must be kept in
# sync with the actual training run by hand.
readme_content = f"""---
library_name: stable-baselines3
tags:
- SpaceInvadersNoFrameskip-v4
- deep-reinforcement-learning
- reinforcement-learning
- stable-baselines3
model-index:
- name: DQN
  results:
  - task:
      type: reinforcement-learning
      name: reinforcement-learning
    dataset:
      name: SpaceInvadersNoFrameskip-v4
      type: SpaceInvadersNoFrameskip-v4
    metrics:
    - type: mean_reward
      value: {mean_reward:.2f} +/- {std_reward:.2f}
      name: mean_reward
      verified: false
---

# **DQN** Agent playing **SpaceInvadersNoFrameskip-v4**

This is a trained model of a **DQN** agent playing **SpaceInvadersNoFrameskip-v4**
using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3)
and the [Deep Reinforcement Learning Course](https://huggingface.co/deep-rl-course/unit3).

## Evaluation Results

| Metric | Value |
|--------|-------|
| Mean Reward | {mean_reward:.2f} |
| Std Reward | {std_reward:.2f} |
| Min Reward | {min_reward:.2f} |
| Max Reward | {max_reward:.2f} |
| Mean Episode Length | {mean_length:.2f} |
| Score (mean - std) | {score:.2f} |
| Evaluation Episodes | {N_EVAL_EPISODES} |
| Evaluation Environment | {ENV_ID} |

## Usage

```python
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
import gymnasium as gym
import ale_py

gym.register_envs(ale_py)

env = make_atari_env("ALE/SpaceInvaders-v5", n_envs=1, seed=0)
env = VecFrameStack(env, n_stack=4)

model = DQN.load("dqn-SpaceInvaders")

obs = env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    if done:
        obs = env.reset()
```

## Training Configuration

- **Algorithm**: DQN (Deep Q-Network)
- **Policy**: CnnPolicy
- **Total Timesteps**: 10,000,000
- **Learning Rate**: 1e-4
- **Buffer Size**: 200,000
- **Batch Size**: 32
- **Device**: CUDA
"""
# ============================================================
# 7. Prepare the files to upload
# ============================================================
print("Preparing files for upload...")
upload_folder = "./upload_temp"
# Start from an empty staging directory: the whole folder is uploaded below, so
# leftovers from an earlier interrupted run must not be picked up.
shutil.rmtree(upload_folder, ignore_errors=True)
os.makedirs(upload_folder, exist_ok=True)
readme_path = os.path.join(upload_folder, "README.md")
with open(readme_path, "w", encoding="utf-8") as f:
    f.write(readme_content)
print("✅ Created README.md")
# Use only the file name so a MODEL_FILE containing directories is still
# copied to the top of the staging folder instead of a non-existent subfolder.
model_dest = os.path.join(upload_folder, os.path.basename(MODEL_FILE))
shutil.copy(MODEL_FILE, model_dest)
print(f"✅ Copied {MODEL_FILE}\n")
# ============================================================
# 8. Upload to Hugging Face
# ============================================================
print(f"Uploading to {repo_id}...")
api = HfApi()
# Best effort: a failure here is only reported. If the repository is really
# unusable, the upload below fails and is handled there.
try:
    create_repo(repo_id, repo_type="model", exist_ok=True)
    print("✅ Repository created/verified")
except Exception as e:
    print(f"⚠️ Repository warning: {e}")
upload_ok = False
try:
    api.upload_folder(
        folder_path=upload_folder,
        repo_id=repo_id,
        repo_type="model",
        commit_message=f"DQN SpaceInvaders - Mean: {mean_reward:.2f}, Std: {std_reward:.2f}",
    )
except Exception as e:
    print(f"\n❌ Upload failed: {e}\n")
else:
    upload_ok = True
    print(f"\n{'='*60}")
    print("✅ Upload Successful!")
    print(f"{'='*60}")
    print(f"🔗 Model Page: https://huggingface.co/{repo_id}")
    print("🏆 Check Progress: https://huggingface.co/spaces/ThomasSimonini/Check-my-progress-Deep-RL-Course")
    print(f"{'='*60}\n")
finally:
    # Remove the staging directory whether or not the upload succeeded.
    shutil.rmtree(upload_folder)
    print("🧹 Cleaned up temporary files")
if not upload_ok:
    # Report the failure through the exit status so shells and CI can detect
    # it, instead of finishing with "Done!" and status 0.
    raise SystemExit(1)
print("✨ Done!")