ImaghT commited on
Commit
f56d328
·
verified ·
1 Parent(s): 1b99688

Upload Unit_3_upload.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. Unit_3_upload.py +233 -0
Unit_3_upload.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================
2
+ # 评估模型并上传到 Hugging Face(修复版)
3
+ # ============================================================
4
+
5
+ import gymnasium as gym
6
+ import ale_py
7
+ import numpy as np
8
+ import os
9
+ import shutil
10
+ from stable_baselines3 import DQN
11
+ from stable_baselines3.common.env_util import make_atari_env
12
+ from stable_baselines3.common.vec_env import VecFrameStack
13
+ from huggingface_hub import HfApi, create_repo
14
+
15
+ # ============================================================
16
+ # 配置参数(修改这里)
17
+ # ============================================================
18
+ USERNAME = "ImaghT"
19
+ MODEL_NAME = "dqn-SpaceInvadersNoFrameskip-v4"
20
+ MODEL_FILE = "dqn-SpaceInvaders.zip"
21
+ ENV_ID = "ALE/SpaceInvaders-v5"
22
+ N_EVAL_EPISODES = 10
23
+
24
+ repo_id = f"{USERNAME}/{MODEL_NAME}"
25
+
26
+ # ============================================================
27
+ # 1. 注册 ALE 环境
28
+ # ============================================================
29
+ print("Registering ALE environments...")
30
+ gym.register_envs(ale_py)
31
+ print("✅ Environments registered\n")
32
+
33
+ # ============================================================
34
+ # 2. 加载训练好的模型
35
+ # ============================================================
36
+ print("Loading trained model...")
37
+ if not os.path.exists(MODEL_FILE):
38
+ print(f"❌ Error: Model file '{MODEL_FILE}' not found!")
39
+ exit(1)
40
+
41
+ model = DQN.load(MODEL_FILE)
42
+ print(f"✅ Model loaded from {MODEL_FILE}\n")
43
+
44
+ # ============================================================
45
+ # 3. 创建评估环境
46
+ # ============================================================
47
+ print("Creating evaluation environment...")
48
+ eval_env = make_atari_env(ENV_ID, n_envs=1, seed=42)
49
+ eval_env = VecFrameStack(eval_env, n_stack=4)
50
+ print("✅ Evaluation environment ready\n")
51
+
52
+ # ============================================================
53
+ # 4. 运行评估(正确方法)
54
+ # ============================================================
55
+ print("="*60)
56
+ print(f"Starting Evaluation ({N_EVAL_EPISODES} episodes)...")
57
+ print("="*60)
58
+
59
+ episode_rewards = []
60
+ episode_lengths = []
61
+
62
+ obs = eval_env.reset()
63
+ current_episode = 0
64
+
65
+ # 持续运行直到收集到足够的 episode
66
+ while len(episode_rewards) < N_EVAL_EPISODES:
67
+ action, _states = model.predict(obs, deterministic=True)
68
+ obs, reward, done, info = eval_env.step(action)
69
+
70
+ # 🔥 关键修复:在 VecEnv 中,episode 结束时真实 reward 在 info 中
71
+ if done[0]:
72
+ # info 是一个列表,info[0] 包含第一个环境的信息
73
+ if 'episode' in info[0]:
74
+ ep_reward = info[0]['episode']['r']
75
+ ep_length = info[0]['episode']['l']
76
+ episode_rewards.append(ep_reward)
77
+ episode_lengths.append(ep_length)
78
+ print(f"Episode {len(episode_rewards)}/{N_EVAL_EPISODES}: "
79
+ f"Reward = {ep_reward:.2f}, Length = {ep_length}")
80
+
81
+ # ============================================================
82
+ # 5. 计算统计数据
83
+ # ============================================================
84
+ mean_reward = np.mean(episode_rewards)
85
+ std_reward = np.std(episode_rewards)
86
+ min_reward = np.min(episode_rewards)
87
+ max_reward = np.max(episode_rewards)
88
+ mean_length = np.mean(episode_lengths)
89
+ score = mean_reward - std_reward
90
+
91
+ print("\n" + "="*60)
92
+ print("Evaluation Results:")
93
+ print(f" Mean Reward: {mean_reward:.2f}")
94
+ print(f" Std Reward: {std_reward:.2f}")
95
+ print(f" Min Reward: {min_reward:.2f}")
96
+ print(f" Max Reward: {max_reward:.2f}")
97
+ print(f" Mean Length: {mean_length:.2f}")
98
+ print(f" Score (mean - std): {score:.2f}")
99
+ print(f" Baseline Required: 200.0")
100
+ if score >= 200:
101
+ print(f" Status: ✅ PASSED")
102
+ else:
103
+ print(f" Status: ❌ NOT PASSED (need {200 - score:.2f} more points)")
104
+ print("="*60 + "\n")
105
+
106
+ # ============================================================
107
+ # 6. 创建 README.md
108
+ # ============================================================
109
+ readme_content = f"""---
110
+ library_name: stable-baselines3
111
+ tags:
112
+ - SpaceInvadersNoFrameskip-v4
113
+ - deep-reinforcement-learning
114
+ - reinforcement-learning
115
+ - stable-baselines3
116
+ model-index:
117
+ - name: DQN
118
+ results:
119
+ - task:
120
+ type: reinforcement-learning
121
+ name: reinforcement-learning
122
+ dataset:
123
+ name: SpaceInvadersNoFrameskip-v4
124
+ type: SpaceInvadersNoFrameskip-v4
125
+ metrics:
126
+ - type: mean_reward
127
+ value: {mean_reward:.2f} +/- {std_reward:.2f}
128
+ name: mean_reward
129
+ verified: false
130
+ ---
131
+
132
+ # **DQN** Agent playing **SpaceInvadersNoFrameskip-v4**
133
+
134
+ This is a trained model of a **DQN** agent playing **SpaceInvadersNoFrameskip-v4**
135
+ using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3)
136
+ and the [Deep Reinforcement Learning Course](https://huggingface.co/deep-rl-course/unit3).
137
+
138
+ ## Evaluation Results
139
+
140
+ | Metric | Value |
141
+ |--------|-------|
142
+ | Mean Reward | {mean_reward:.2f} |
143
+ | Std Reward | {std_reward:.2f} |
144
+ | Min Reward | {min_reward:.2f} |
145
+ | Max Reward | {max_reward:.2f} |
146
+ | Mean Episode Length | {mean_length:.2f} |
147
+ | Score (mean - std) | {score:.2f} |
148
+ | Evaluation Episodes | {N_EVAL_EPISODES} |
149
+
150
+ ## Usage
151
+
152
+ ```python
153
+ from stable_baselines3 import DQN
154
+ from stable_baselines3.common.env_util import make_atari_env
155
+ from stable_baselines3.common.vec_env import VecFrameStack
156
+ import gymnasium as gym
157
+ import ale_py
158
+
159
+ gym.register_envs(ale_py)
160
+
161
+ env = make_atari_env("ALE/SpaceInvaders-v5", n_envs=1, seed=0)
162
+ env = VecFrameStack(env, n_stack=4)
163
+
164
+ model = DQN.load("dqn-SpaceInvaders")
165
+
166
+ obs = env.reset()
167
+ for i in range(1000):
168
+ action, _states = model.predict(obs, deterministic=True)
169
+ obs, reward, done, info = env.step(action)
170
+ if done:
171
+ obs = env.reset()
172
+
173
+ ```
174
+ ## Training Configuration
175
+
176
+ - **Algorithm**: DQN (Deep Q-Network)
177
+ - **Policy**: CnnPolicy
178
+ - **Total Timesteps**: 10,000,000
179
+ - **Learning Rate**: 1e-4
180
+ - **Buffer Size**: 200,000
181
+ - **Batch Size**: 32
182
+ - **Device**: CUDA
183
+ """
184
+
185
+ # ============================================================
186
+ # 7. 准备上传文件
187
+ # ============================================================
188
+ print("Preparing files for upload...")
189
+ upload_folder = "./upload_temp"
190
+ os.makedirs(upload_folder, exist_ok=True)
191
+
192
+ readme_path = os.path.join(upload_folder, "README.md")
193
+ with open(readme_path, "w", encoding="utf-8") as f:
194
+ f.write(readme_content)
195
+ print(f"✅ Created README.md")
196
+
197
+ model_dest = os.path.join(upload_folder, MODEL_FILE)
198
+ shutil.copy(MODEL_FILE, model_dest)
199
+ print(f"✅ Copied {MODEL_FILE}\n")
200
+
201
+ # ============================================================
202
+ # 8. 上传到 Hugging Face
203
+ # ============================================================
204
+ print(f"Uploading to {repo_id}...")
205
+
206
+ api = HfApi()
207
+
208
+ try:
209
+ create_repo(repo_id, repo_type="model", exist_ok=True)
210
+ print(f"✅ Repository created/verified")
211
+ except Exception as e:
212
+ print(f"⚠️ Repository warning: {e}")
213
+
214
+ try:
215
+ api.upload_folder(
216
+ folder_path=upload_folder,
217
+ repo_id=repo_id,
218
+ repo_type="model",
219
+ commit_message=f"DQN SpaceInvaders - Mean: {mean_reward:.2f}, Std: {std_reward:.2f}"
220
+ )
221
+ print(f"\n{'='*60}")
222
+ print("✅ Upload Successful!")
223
+ print(f"{'='*60}")
224
+ print(f"🔗 Model Page: https://huggingface.co/{repo_id}")
225
+ print(f"🏆 Check Progress: https://huggingface.co/spaces/ThomasSimonini/Check-my-progress-Deep-RL-Course")
226
+ print(f"{'='*60}\n")
227
+ except Exception as e:
228
+ print(f"\n❌ Upload failed: {e}\n")
229
+ finally:
230
+ shutil.rmtree(upload_folder)
231
+ print("🧹 Cleaned up temporary files")
232
+
233
+ print("✨ Done!")