trioskosmos committed on
Commit
5ff442a
·
verified ·
1 Parent(s): e4accf1

Upload ai/training/train_gpu_workers.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/training/train_gpu_workers.py +194 -0
ai/training/train_gpu_workers.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.multiprocessing as mp
7
+
8
+ # Ensure project root is in path
9
+ sys.path.append(os.getcwd())
10
+
11
+ from ai.vec_env_adapter import VectorEnvAdapter
12
+ from stable_baselines3 import PPO
13
+ from stable_baselines3.common.vec_env import VecEnv
14
+
15
+
16
+ # Worker function to run in a separate process
17
+ def worker_process(remote, parent_remote, num_envs):
18
+ parent_remote.close()
19
+
20
+ # Initialize the Numba-optimized vector environment
21
+ env = VectorEnvAdapter(num_envs=num_envs)
22
+
23
+ try:
24
+ while True:
25
+ cmd, data = remote.recv()
26
+ if cmd == "step":
27
+ # data is actions
28
+ obs, rewards, dones, infos = env.step(data)
29
+ remote.send((obs, rewards, dones, infos))
30
+ elif cmd == "reset":
31
+ obs = env.reset()
32
+ remote.send(obs)
33
+ elif cmd == "close":
34
+ env.close()
35
+ remote.close()
36
+ break
37
+ elif cmd == "get_attr":
38
+ remote.send(getattr(env, data))
39
+ else:
40
+ raise NotImplementedError(f"Worker received unknown command: {cmd}")
41
+ except KeyboardInterrupt:
42
+ print("Worker interrupt.")
43
+ finally:
44
+ env.close()
45
+
46
+
47
+ class DistributedVectorEnv(VecEnv):
48
+ """
49
+ A distributed Vector Environment that manages multiple worker processes,
50
+ each running a Numba-optimized VectorEnvAdapter.
51
+
52
+ Structure:
53
+ Main Process (PPO) -> DistributedVectorEnv
54
+ -> Worker Process 1 -> VectorEnvAdapter (N=1024) -> Numba
55
+ -> Worker Process 2 -> VectorEnvAdapter (N=1024) -> Numba
56
+ ...
57
+ """
58
+
59
+ def __init__(self, num_workers: int, envs_per_worker: int):
60
+ self.num_workers = num_workers
61
+ self.envs_per_worker = envs_per_worker
62
+ self.total_envs = num_workers * envs_per_worker
63
+
64
+ # Define spaces (assuming consistent across all envs)
65
+ # We create a dummy adapter just to get the spaces
66
+ dummy = VectorEnvAdapter(num_envs=1)
67
+ observation_space = dummy.observation_space
68
+ action_space = dummy.action_space
69
+ dummy.close()
70
+ del dummy
71
+
72
+ super().__init__(self.total_envs, observation_space, action_space)
73
+
74
+ self.closed = False
75
+ self.waiting = False
76
+ self.remotes, self.work_remotes = zip(*[mp.Pipe() for _ in range(num_workers)])
77
+ self.processes = []
78
+
79
+ for work_remote, remote in zip(self.work_remotes, self.remotes):
80
+ p = mp.Process(target=worker_process, args=(work_remote, remote, envs_per_worker))
81
+ p.daemon = True # Kill if main process dies
82
+ p.start()
83
+ self.processes.append(p)
84
+ work_remote.close()
85
+
86
+ def step_async(self, actions):
87
+ # Split actions into chunks for each worker
88
+ chunks = np.array_split(actions, self.num_workers)
89
+ for remote, action_chunk in zip(self.remotes, chunks):
90
+ remote.send(("step", action_chunk))
91
+ self.waiting = True
92
+
93
+ def step_wait(self):
94
+ results = [remote.recv() for remote in self.remotes]
95
+ self.waiting = False
96
+
97
+ # Aggregate results
98
+ obs_list, rews_list, dones_list, infos_list = zip(*results)
99
+
100
+ return (
101
+ np.concatenate(obs_list),
102
+ np.concatenate(rews_list),
103
+ np.concatenate(dones_list),
104
+ # Infos are lists of dicts, so we just add them
105
+ sum(infos_list, []),
106
+ )
107
+
108
+ def reset(self):
109
+ for remote in self.remotes:
110
+ remote.send(("reset", None))
111
+
112
+ results = [remote.recv() for remote in self.remotes]
113
+ return np.concatenate(results)
114
+
115
+ def close(self):
116
+ if self.closed:
117
+ return
118
+ if self.waiting:
119
+ for remote in self.remotes:
120
+ remote.recv()
121
+ for remote in self.remotes:
122
+ remote.send(("close", None))
123
+ for p in self.processes:
124
+ p.join()
125
+ self.closed = True
126
+
127
+ def get_attr(self, attr_name, indices=None):
128
+ # Simplified: return from first worker
129
+ self.remotes[0].send(("get_attr", attr_name))
130
+ return self.remotes[0].recv()
131
+
132
+ def set_attr(self, attr_name, value, indices=None):
133
+ pass
134
+
135
+ def env_method(self, method_name, *method_args, **method_kwargs):
136
+ pass
137
+
138
+ def env_is_wrapped(self, wrapper_class, indices=None):
139
+ return [False] * self.total_envs
140
+
141
+
142
+ def run_training():
143
+ print("========================================================")
144
+ print(" LovecaSim - DISTRIBUTED GPU TRAINING (Async Workers) ")
145
+ print("========================================================")
146
+
147
+ # Configuration
148
+ TRAIN_ENVS = int(os.getenv("TRAIN_ENVS", "16384")) # Increased default
149
+ NUM_WORKERS = int(os.getenv("NUM_WORKERS", "4"))
150
+ ENVS_PER_WORKER = TRAIN_ENVS // NUM_WORKERS
151
+
152
+ TRAIN_STEPS = int(os.getenv("TRAIN_STEPS", "100_000_000"))
153
+ BATCH_SIZE = int(os.getenv("TRAIN_BATCH_SIZE", "32768")) # Increased batch size for GPU
154
+
155
+ print(f" [Config] Total Envs: {TRAIN_ENVS}")
156
+ print(f" [Config] Workers: {NUM_WORKERS} (Envs/Worker: {ENVS_PER_WORKER})")
157
+ print(f" [Config] Batch Size: {BATCH_SIZE}")
158
+ print(f" [Config] Architecture: Main(PPO) <-> {NUM_WORKERS} Workers <-> Numba(Vectors)")
159
+
160
+ print(f" [Init] Launching {NUM_WORKERS} distributed worker processes...")
161
+ vec_env = DistributedVectorEnv(NUM_WORKERS, ENVS_PER_WORKER)
162
+
163
+ print(" [Init] Creating PPO Model...")
164
+ model = PPO(
165
+ "MlpPolicy",
166
+ vec_env,
167
+ verbose=1,
168
+ learning_rate=3e-4,
169
+ n_steps=128,
170
+ batch_size=BATCH_SIZE,
171
+ n_epochs=4,
172
+ gamma=0.99,
173
+ gae_lambda=0.95,
174
+ ent_coef=0.01,
175
+ tensorboard_log="./logs/gpu_workers_tensorboard/",
176
+ device="cuda" if torch.cuda.is_available() else "cpu",
177
+ )
178
+
179
+ print(f" [Init] Model Device: {model.device}")
180
+
181
+ try:
182
+ print(" [Train] Starting Distributed Training...")
183
+ model.learn(total_timesteps=TRAIN_STEPS, progress_bar=True)
184
+ except KeyboardInterrupt:
185
+ print("\n [Stop] Interrupted by user.")
186
+ finally:
187
+ print(" [Done] Saving model and closing workers...")
188
+ model.save("./checkpoints/gpu_workers_final")
189
+ vec_env.close()
190
+
191
+
192
+ if __name__ == "__main__":
193
+ mp.set_start_method("spawn", force=True)
194
+ run_training()