jatinror committed on
Commit
5ab0a29
·
verified ·
1 Parent(s): f9f2735

Upload train_pixelcopter.py

Browse files
Files changed (1) hide show
  1. train_pixelcopter.py +88 -0
train_pixelcopter.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ from gymnasium import spaces
3
+ import numpy as np
4
+ from stable_baselines3 import PPO
5
+
6
+ # -------------------------
7
+ # Side-Scrolling PixelCopter Environment (Medium/Certification Friendly)
8
+ # -------------------------
9
class PixelCopterCertEnv(gym.Env):
    """Side-scrolling PixelCopter environment (medium / certification friendly).

    Observation: the copter's y-position followed by the gap-top row of every
    wall column, so the vector has shape ``(screen_width + 1,)``.
    Actions: 0 = do nothing (gravity only), 1 = apply upward lift.
    Reward: +1 per surviving step, -5 on collision with a wall.
    """

    def __init__(self, screen_width=50, screen_height=10, gap_size=6):
        super().__init__()
        self.screen_width = screen_width
        self.screen_height = screen_height
        self.copter_y = self.screen_height // 2
        self.copter_velocity = 0
        self.gravity = 0.25
        self.lift = -0.9
        self.done = False
        self.timestep = 0
        self.max_timesteps = 500
        self.gap_size = gap_size

        # One gap-top per wall column; keep at least one wall cell above the
        # gap (>= 1) and room for the gap itself below.
        self.wall_gap_positions = [
            np.random.randint(1, self.screen_height - self.gap_size - 1)
            for _ in range(screen_width)
        ]

        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Box(
            low=0,
            high=self.screen_height,
            shape=(self.screen_width + 1,),
            dtype=np.float32,
        )

    def reset(self, seed=None, options=None):
        """Reset the episode; honors `seed` for reproducible rollouts."""
        # BUG FIX: the original ignored `seed`. gymnasium requires calling
        # super().reset(seed=seed) so self.np_random is (re)seeded; the walls
        # are then drawn from that seeded generator instead of the global
        # np.random state.
        super().reset(seed=seed)
        self.copter_y = self.screen_height // 2
        self.copter_velocity = 0
        self.done = False
        self.timestep = 0
        self.wall_gap_positions = [
            int(self.np_random.integers(1, self.screen_height - self.gap_size - 1))
            for _ in range(self.screen_width)
        ]
        obs = np.array([self.copter_y] + self.wall_gap_positions, dtype=np.float32)
        return obs, {}

    def step(self, action):
        """Advance one frame; returns (obs, reward, terminated, truncated, info)."""
        # Apply action (1 = lift); gravity always pulls downward.
        if action == 1:
            self.copter_velocity += self.lift
        self.copter_velocity += self.gravity
        self.copter_y += self.copter_velocity
        self.copter_y = np.clip(self.copter_y, 0, self.screen_height)

        # Scroll walls one column left and extend the track with a
        # random-walk step of the gap position, clamped to valid rows.
        self.wall_gap_positions = self.wall_gap_positions[1:]
        last_gap = self.wall_gap_positions[-1]
        new_gap = last_gap + self.np_random.choice([-1, 0, 1])
        new_gap = int(np.clip(new_gap, 1, self.screen_height - self.gap_size - 1))
        self.wall_gap_positions.append(new_gap)

        # Collision against the nearest (leftmost) wall column.
        gap_top = self.wall_gap_positions[0]
        gap_bottom = gap_top + self.gap_size
        terminated = bool(self.copter_y <= gap_top or self.copter_y >= gap_bottom)
        reward = -5 if terminated else 1

        self.timestep += 1
        # BUG FIX: the time limit is a truncation, not a termination, under
        # the gymnasium API — reporting it as `terminated` makes value-based
        # algorithms (e.g. PPO) wrongly skip bootstrapping at the cutoff.
        truncated = self.timestep >= self.max_timesteps
        self.done = terminated or truncated  # kept for backward compatibility

        obs = np.array([self.copter_y] + self.wall_gap_positions, dtype=np.float32)
        return obs, reward, terminated, truncated, {}
71
+
72
# -------------------------
# Training
# -------------------------
def main():
    """Train a PPO agent on the PixelCopter environment and save it to disk."""
    env = PixelCopterCertEnv(screen_width=80, screen_height=10, gap_size=6)
    model = PPO("MlpPolicy", env, verbose=1)

    print("Training started...")
    model.learn(total_timesteps=500_000)  # Enough for certification
    print("Training finished!")

    model.save("ppo_pixelcopter_cert")
    print("Model saved as 'ppo_pixelcopter_cert.zip'")


if __name__ == "__main__":
    # Guard the entry point so importing this module (e.g. to reuse the
    # environment class) does not kick off a 500k-step training run.
    main()
84
+
85
+
86
+
87
+
88
+