# LovecaSim: ai/training/train_gpu_workers.py
import os
import sys
import numpy as np
import torch
import torch.multiprocessing as mp
# Ensure project root is in path
sys.path.append(os.getcwd())
from ai.vec_env_adapter import VectorEnvAdapter
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecEnv

# Worker function run in a separate process: owns one VectorEnvAdapter and
# serves (cmd, data) requests arriving over the pipe.
def worker_process(remote, parent_remote, num_envs):
    # Close the parent's end of the pipe in the child; only `remote` is used here.
    parent_remote.close()
    # Initialize the Numba-optimized vector environment
    env = VectorEnvAdapter(num_envs=num_envs)
    try:
        while True:
            cmd, data = remote.recv()
            if cmd == "step":
                # data is this worker's chunk of actions
                obs, rewards, dones, infos = env.step(data)
                remote.send((obs, rewards, dones, infos))
            elif cmd == "reset":
                obs = env.reset()
                remote.send(obs)
            elif cmd == "close":
                remote.close()
                break  # the env is closed exactly once in the finally block
            elif cmd == "get_attr":
                remote.send(getattr(env, data))
            else:
                raise NotImplementedError(f"Worker received unknown command: {cmd}")
    except KeyboardInterrupt:
        print("Worker interrupted.")
    finally:
        env.close()
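
# A minimal sketch of the pipe protocol above, assuming a single worker and
# that the adapter accepts a NumPy array of actions. Illustrative only: this
# function is never called, and the action shape/dtype are hypothetical.
def _example_worker_roundtrip():
    parent_end, child_end = mp.Pipe()
    p = mp.Process(target=worker_process, args=(child_end, parent_end, 4))
    p.start()
    child_end.close()  # the worker owns this end now
    parent_end.send(("reset", None))
    obs = parent_end.recv()
    actions = np.zeros(4, dtype=np.int64)  # hypothetical; depends on action_space
    parent_end.send(("step", actions))
    obs, rewards, dones, infos = parent_end.recv()
    parent_end.send(("close", None))
    p.join()
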
class DistributedVectorEnv(VecEnv):
    """
    A distributed vector environment that manages multiple worker processes,
    each running a Numba-optimized VectorEnvAdapter.

    Structure:
        Main Process (PPO) -> DistributedVectorEnv
            -> Worker Process 1 -> VectorEnvAdapter (N=1024) -> Numba
            -> Worker Process 2 -> VectorEnvAdapter (N=1024) -> Numba
            ...
    """

    def __init__(self, num_workers: int, envs_per_worker: int):
        self.num_workers = num_workers
        self.envs_per_worker = envs_per_worker
        self.total_envs = num_workers * envs_per_worker

        # Define spaces (assumed consistent across all envs) by creating a
        # throwaway adapter just to read them.
        dummy = VectorEnvAdapter(num_envs=1)
        observation_space = dummy.observation_space
        action_space = dummy.action_space
        dummy.close()
        del dummy

        super().__init__(self.total_envs, observation_space, action_space)

        self.closed = False
        self.waiting = False
        self.remotes, self.work_remotes = zip(*[mp.Pipe() for _ in range(num_workers)])
        self.processes = []
        for work_remote, remote in zip(self.work_remotes, self.remotes):
            p = mp.Process(target=worker_process, args=(work_remote, remote, envs_per_worker))
            p.daemon = True  # Kill workers if the main process dies
            p.start()
            self.processes.append(p)
            # Close the worker's end of the pipe in the main process
            work_remote.close()

    def step_async(self, actions):
        # Scatter the batched actions into one chunk per worker
        chunks = np.array_split(actions, self.num_workers)
        for remote, action_chunk in zip(self.remotes, chunks):
            remote.send(("step", action_chunk))
        self.waiting = True
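
    # For example, with 8 total envs and 2 workers, np.array_split yields two
    # equal per-worker chunks:
    #   np.array_split(np.arange(8), 2) -> [array([0, 1, 2, 3]), array([4, 5, 6, 7])]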

    def step_wait(self):
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        # Gather per-worker results back into flat batches
        obs_list, rews_list, dones_list, infos_list = zip(*results)
        return (
            np.concatenate(obs_list),
            np.concatenate(rews_list),
            np.concatenate(dones_list),
            # Infos are lists of dicts, so concatenate the lists
            sum(infos_list, []),
        )

    def reset(self):
        for remote in self.remotes:
            remote.send(("reset", None))
        results = [remote.recv() for remote in self.remotes]
        return np.concatenate(results)

    def close(self):
        if self.closed:
            return
        if self.waiting:
            # Drain any in-flight step results before shutting workers down
            for remote in self.remotes:
                remote.recv()
        for remote in self.remotes:
            remote.send(("close", None))
        for p in self.processes:
            p.join()
        self.closed = True

    def get_attr(self, attr_name, indices=None):
        # Simplified: query only the first worker and assume homogeneity
        self.remotes[0].send(("get_attr", attr_name))
        return self.remotes[0].recv()

    def set_attr(self, attr_name, value, indices=None):
        # Not supported by the worker protocol; intentionally a no-op
        pass

    def env_method(self, method_name, *method_args, **method_kwargs):
        # Not supported by the worker protocol; intentionally a no-op
        pass

    def env_is_wrapped(self, wrapper_class, indices=None):
        # The raw Numba envs are never wrapped
        return [False] * self.total_envs
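
# A minimal smoke-test sketch for the class above, assuming the adapter's
# action_space supports .sample() (illustrative only; never called):
def _example_smoke_test():
    env = DistributedVectorEnv(num_workers=2, envs_per_worker=8)
    obs = env.reset()
    for _ in range(10):
        # One sampled action per environment, stacked into a batch
        actions = np.stack([env.action_space.sample() for _ in range(env.num_envs)])
        obs, rewards, dones, infos = env.step(actions)
    env.close()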

def run_training():
    print("========================================================")
    print(" LovecaSim - DISTRIBUTED GPU TRAINING (Async Workers) ")
    print("========================================================")

    # Configuration (overridable via environment variables)
    TRAIN_ENVS = int(os.getenv("TRAIN_ENVS", "16384"))
    NUM_WORKERS = int(os.getenv("NUM_WORKERS", "4"))
    assert TRAIN_ENVS % NUM_WORKERS == 0, "TRAIN_ENVS must be divisible by NUM_WORKERS"
    ENVS_PER_WORKER = TRAIN_ENVS // NUM_WORKERS
    TRAIN_STEPS = int(os.getenv("TRAIN_STEPS", "100_000_000"))
    BATCH_SIZE = int(os.getenv("TRAIN_BATCH_SIZE", "32768"))  # Large batch to keep the GPU busy

    print(f" [Config] Total Envs: {TRAIN_ENVS}")
    print(f" [Config] Workers: {NUM_WORKERS} (Envs/Worker: {ENVS_PER_WORKER})")
    print(f" [Config] Batch Size: {BATCH_SIZE}")
    print(f" [Config] Architecture: Main(PPO) <-> {NUM_WORKERS} Workers <-> Numba(Vectors)")

    print(f" [Init] Launching {NUM_WORKERS} distributed worker processes...")
    vec_env = DistributedVectorEnv(NUM_WORKERS, ENVS_PER_WORKER)
print(" [Init] Creating PPO Model...")
model = PPO(
"MlpPolicy",
vec_env,
verbose=1,
learning_rate=3e-4,
n_steps=128,
batch_size=BATCH_SIZE,
n_epochs=4,
gamma=0.99,
gae_lambda=0.95,
ent_coef=0.01,
tensorboard_log="./logs/gpu_workers_tensorboard/",
device="cuda" if torch.cuda.is_available() else "cpu",
)
print(f" [Init] Model Device: {model.device}")

    try:
        print(" [Train] Starting Distributed Training...")
        model.learn(total_timesteps=TRAIN_STEPS, progress_bar=True)
    except KeyboardInterrupt:
        print("\n [Stop] Interrupted by user.")
    finally:
        print(" [Done] Saving model and closing workers...")
        model.save("./checkpoints/gpu_workers_final")
        vec_env.close()

if __name__ == "__main__":
    # "spawn" avoids forking CUDA/Numba state into the worker processes
    mp.set_start_method("spawn", force=True)
    run_training()
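
# A hedged sketch for resuming training from the saved checkpoint (assumes the
# same environment layout; PPO.load restores weights and hyperparameters):
#
#   vec_env = DistributedVectorEnv(4, 4096)
#   model = PPO.load("./checkpoints/gpu_workers_final", env=vec_env)
#   model.learn(total_timesteps=10_000_000, progress_bar=True)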