pabloramesc commited on
Commit
8097aca
·
verified ·
1 Parent(s): 584e7e3

Upload unit1.py

Browse files
Files changed (1) hide show
  1. unit1.py +102 -0
unit1.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+ # Import required packages
3
+ import gymnasium as gym
4
+
5
+ from huggingface_sb3 import package_to_hub
6
+ from stable_baselines3 import PPO
7
+ from stable_baselines3.common.evaluation import evaluate_policy
8
+ from stable_baselines3.common.monitor import Monitor
9
+ from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
10
+
11
+ # %%
12
+ # Test random environment
13
+ env_id = "LunarLander-v3"
14
+ env = gym.make(env_id)
15
+
16
+ observation, info = env.reset()
17
+
18
+ for _ in range(20):
19
+
20
+ action = env.action_space.sample()
21
+ print("Action taken:", action)
22
+
23
+ observation, reward, terminated, truncated, info = env.step(action)
24
+
25
+ if terminated or truncated:
26
+ print("Environment is reset")
27
+ observation, info = env.reset()
28
+
29
+ env.close()
30
+
31
+
32
+ # %%
33
+ # Check observation and action spaces
34
+ env.reset()
35
+
36
+ print("_____OBSERVATION SPACE_____ \n")
37
+ print("Observation Space Shape", env.observation_space.shape)
38
+ print("Sample observation", env.observation_space.sample()) # Get a random observation
39
+
40
+ print("\n _____ACTION SPACE_____ \n")
41
+ print("Action Space Shape", env.action_space.n)
42
+ print("Action Space Sample", env.action_space.sample()) # Take a random action
43
+
44
+
45
+ # %%
46
+ # Check SB3 model device
47
+ model = PPO("MlpPolicy", env, device="auto")
48
+ print(model.device)
49
+
50
+
51
+ # %%
52
+ # Train PPO agent
53
+ model = PPO(
54
+ policy="MlpPolicy",
55
+ env=env,
56
+ n_steps=1024,
57
+ batch_size=64,
58
+ n_epochs=4,
59
+ gamma=0.999,
60
+ gae_lambda=0.98,
61
+ ent_coef=0.01,
62
+ verbose=1,
63
+ )
64
+ model.learn(total_timesteps=500_000)
65
+
66
+
67
+ # %%
68
+ # Train agent for 1M timesteps
69
+ model.learn(total_timesteps=1_000_000)
70
+ model.save("ppo-lunar-lander")
71
+
72
+ # %%
73
+ # Evaluate the agent
74
+ model = PPO.load("ppo-lunar-lander", env=env)
75
+ eval_env = Monitor(gym.make(env_id))
76
+ mean_reward, std_reward = evaluate_policy(
77
+ model, eval_env, n_eval_episodes=100, deterministic=True
78
+ )
79
+ print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")
80
+
81
+ # %%
82
+ # Publish the trained agent
83
+ eval_env = DummyVecEnv(
84
+ [lambda: Monitor(gym.make(env_id, render_mode="rgb_array"))]
85
+ )
86
+ eval_env = VecVideoRecorder(
87
+ eval_env,
88
+ "videos/",
89
+ record_video_trigger=lambda x: x == 0,
90
+ video_length=1000,
91
+ name_prefix="ppo-lunar-lander-demo",
92
+ )
93
+ package_to_hub(
94
+ model=model,
95
+ model_name="ppo-lunar-lander-v2",
96
+ model_architecture="PPO",
97
+ env_id=env_id,
98
+ eval_env=eval_env,
99
+ repo_id="pabloramesc/ppo-lunar-lander-v2",
100
+ commit_message="Upload PPO agent for LunarLander-v2",
101
+ )
102
+ # %%