"""Distributed PPO training: a pool of worker processes, each hosting a
Numba-optimized VectorEnvAdapter, aggregated behind a single SB3 VecEnv."""
import os
import sys
import numpy as np
import torch
import torch.multiprocessing as mp
# Ensure project root is in path
sys.path.append(os.getcwd())
from ai.vec_env_adapter import VectorEnvAdapter
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecEnv
# Worker function to run in a separate process
def worker_process(remote, parent_remote, num_envs):
    """Worker loop: host one VectorEnvAdapter and service pipe commands.

    Commands received over ``remote`` (as ``(cmd, data)`` tuples):
      - "step":     data is an action batch; replies (obs, rewards, dones, infos)
      - "reset":    replies the reset observations
      - "get_attr": data is an attribute name; replies getattr(env, data)
      - "close":    closes the pipe and exits the loop

    Any unknown command raises NotImplementedError (kills the worker).
    """
    # Close the parent's end of the pipe in the child so only one process
    # holds it and EOF propagates correctly.
    parent_remote.close()
    # Initialize the Numba-optimized vector environment
    env = VectorEnvAdapter(num_envs=num_envs)
    try:
        while True:
            cmd, data = remote.recv()
            if cmd == "step":
                # data is the action batch for this worker's envs
                obs, rewards, dones, infos = env.step(data)
                remote.send((obs, rewards, dones, infos))
            elif cmd == "reset":
                remote.send(env.reset())
            elif cmd == "close":
                remote.close()
                break
            elif cmd == "get_attr":
                remote.send(getattr(env, data))
            else:
                raise NotImplementedError(f"Worker received unknown command: {cmd}")
    except KeyboardInterrupt:
        print("Worker interrupt.")
    finally:
        # Single close point. The original also called env.close() inside the
        # "close" branch, closing the adapter twice; now only this runs.
        env.close()
class DistributedVectorEnv(VecEnv):
    """
    A distributed Vector Environment that manages multiple worker processes,
    each running a Numba-optimized VectorEnvAdapter.

    Structure:
        Main Process (PPO) -> DistributedVectorEnv
            -> Worker Process 1 -> VectorEnvAdapter (N=1024) -> Numba
            -> Worker Process 2 -> VectorEnvAdapter (N=1024) -> Numba
            ...
    """

    def __init__(self, num_workers: int, envs_per_worker: int):
        """Spawn ``num_workers`` daemon processes, each owning ``envs_per_worker`` envs."""
        self.num_workers = num_workers
        self.envs_per_worker = envs_per_worker
        self.total_envs = num_workers * envs_per_worker
        # Define spaces (assuming consistent across all envs):
        # create a throwaway adapter just to read them.
        dummy = VectorEnvAdapter(num_envs=1)
        observation_space = dummy.observation_space
        action_space = dummy.action_space
        dummy.close()
        del dummy
        super().__init__(self.total_envs, observation_space, action_space)
        self.closed = False
        self.waiting = False
        self.remotes, self.work_remotes = zip(*[mp.Pipe() for _ in range(num_workers)])
        self.processes = []
        for work_remote, remote in zip(self.work_remotes, self.remotes):
            p = mp.Process(target=worker_process, args=(work_remote, remote, envs_per_worker))
            p.daemon = True  # Kill if main process dies
            p.start()
            self.processes.append(p)
            # Parent closes its copy of the worker-side pipe end.
            work_remote.close()

    def _num_selected(self, indices) -> int:
        """Number of envs addressed by an SB3 ``indices`` argument (None/int/iterable)."""
        if indices is None:
            return self.num_envs
        if isinstance(indices, int):
            return 1
        return len(list(indices))

    def step_async(self, actions):
        # Split actions into one contiguous chunk per worker.
        chunks = np.array_split(actions, self.num_workers)
        for remote, action_chunk in zip(self.remotes, chunks):
            remote.send(("step", action_chunk))
        self.waiting = True

    def step_wait(self):
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        # Aggregate per-worker results into flat batch arrays.
        obs_list, rews_list, dones_list, infos_list = zip(*results)
        # Flat comprehension instead of sum(infos_list, []): list-sum is
        # quadratic, which matters at tens of thousands of envs per step.
        infos = [info for worker_infos in infos_list for info in worker_infos]
        return (
            np.concatenate(obs_list),
            np.concatenate(rews_list),
            np.concatenate(dones_list),
            infos,
        )

    def reset(self):
        for remote in self.remotes:
            remote.send(("reset", None))
        results = [remote.recv() for remote in self.remotes]
        return np.concatenate(results)

    def close(self):
        if self.closed:
            return
        if self.waiting:
            # Drain pending step results so workers aren't blocked on send.
            for remote in self.remotes:
                remote.recv()
        for remote in self.remotes:
            remote.send(("close", None))
        for p in self.processes:
            p.join()
        self.closed = True

    def get_attr(self, attr_name, indices=None):
        # SB3's VecEnv contract: return ONE VALUE PER selected env (a list).
        # All workers run identical adapters, so query worker 0 once and
        # replicate. (Previously returned a bare scalar, which breaks
        # framework callers that iterate the result.)
        self.remotes[0].send(("get_attr", attr_name))
        value = self.remotes[0].recv()
        return [value] * self._num_selected(indices)

    def set_attr(self, attr_name, value, indices=None):
        # Intentional no-op: workers do not support attribute injection.
        pass

    def env_method(self, method_name, *method_args, **method_kwargs):
        # Intentional no-op: per-env method calls are not forwarded.
        pass

    def env_is_wrapped(self, wrapper_class, indices=None):
        # No gym wrappers are ever applied; length must match the selection
        # (previously always returned total_envs entries regardless of indices).
        return [False] * self._num_selected(indices)
def run_training():
    """Build the distributed vector env, train PPO on it, and save on exit.

    Configuration is read from environment variables (TRAIN_ENVS,
    NUM_WORKERS, TRAIN_STEPS, TRAIN_BATCH_SIZE) with sensible defaults.
    """
    banner = "========================================================"
    print(banner)
    print(" LovecaSim - DISTRIBUTED GPU TRAINING (Async Workers) ")
    print(banner)

    # Configuration (env-var overridable)
    TRAIN_ENVS = int(os.getenv("TRAIN_ENVS", "16384"))  # Increased default
    NUM_WORKERS = int(os.getenv("NUM_WORKERS", "4"))
    ENVS_PER_WORKER = TRAIN_ENVS // NUM_WORKERS
    TRAIN_STEPS = int(os.getenv("TRAIN_STEPS", "100_000_000"))
    BATCH_SIZE = int(os.getenv("TRAIN_BATCH_SIZE", "32768"))  # Increased batch size for GPU

    for line in (
        f" [Config] Total Envs: {TRAIN_ENVS}",
        f" [Config] Workers: {NUM_WORKERS} (Envs/Worker: {ENVS_PER_WORKER})",
        f" [Config] Batch Size: {BATCH_SIZE}",
        f" [Config] Architecture: Main(PPO) <-> {NUM_WORKERS} Workers <-> Numba(Vectors)",
        f" [Init] Launching {NUM_WORKERS} distributed worker processes...",
    ):
        print(line)

    vec_env = DistributedVectorEnv(NUM_WORKERS, ENVS_PER_WORKER)

    print(" [Init] Creating PPO Model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ppo_kwargs = dict(
        verbose=1,
        learning_rate=3e-4,
        n_steps=128,
        batch_size=BATCH_SIZE,
        n_epochs=4,
        gamma=0.99,
        gae_lambda=0.95,
        ent_coef=0.01,
        tensorboard_log="./logs/gpu_workers_tensorboard/",
        device=device,
    )
    model = PPO("MlpPolicy", vec_env, **ppo_kwargs)
    print(f" [Init] Model Device: {model.device}")

    try:
        print(" [Train] Starting Distributed Training...")
        model.learn(total_timesteps=TRAIN_STEPS, progress_bar=True)
    except KeyboardInterrupt:
        print("\n [Stop] Interrupted by user.")
    finally:
        print(" [Done] Saving model and closing workers...")
        model.save("./checkpoints/gpu_workers_final")
        vec_env.close()
if __name__ == "__main__":
    # NOTE(review): "spawn" gives each worker a fresh interpreter; presumably
    # chosen to avoid fork-inheriting CUDA/Numba state — confirm, but forcing
    # it here makes start-method behavior consistent across platforms.
    mp.set_start_method("spawn", force=True)
    run_training()