# (Extraction artifact from the hosting page UI — was: "Spaces: Running")
| import os | |
| import sys | |
| import numpy as np | |
| import torch | |
| import torch.multiprocessing as mp | |
| # Ensure project root is in path | |
| sys.path.append(os.getcwd()) | |
| from ai.vec_env_adapter import VectorEnvAdapter | |
| from stable_baselines3 import PPO | |
| from stable_baselines3.common.vec_env import VecEnv | |
# Worker function to run in a separate process
def worker_process(remote, parent_remote, num_envs):
    """Serve pipe commands against a Numba-optimized VectorEnvAdapter.

    Runs in a child process. Commands arrive on `remote` as (cmd, data)
    tuples:
      - ("step", actions):  step all envs, reply (obs, rewards, dones, infos)
      - ("reset", None):    reset all envs, reply obs
      - ("get_attr", name): reply getattr(env, name)
      - ("close", None):    close env and pipe, then exit the loop

    Args:
        remote: child end of the pipe; commands are read from here.
        parent_remote: parent end of the pipe; closed immediately since the
            child never uses it.
        num_envs: number of sub-environments the adapter simulates.
    """
    parent_remote.close()
    # Initialize the Numba-optimized vector environment
    env = VectorEnvAdapter(num_envs=num_envs)
    env_closed = False  # guards against a second env.close() in `finally`
    try:
        while True:
            cmd, data = remote.recv()
            if cmd == "step":
                # data is actions
                obs, rewards, dones, infos = env.step(data)
                remote.send((obs, rewards, dones, infos))
            elif cmd == "reset":
                obs = env.reset()
                remote.send(obs)
            elif cmd == "close":
                env.close()
                env_closed = True
                remote.close()
                break
            elif cmd == "get_attr":
                remote.send(getattr(env, data))
            else:
                raise NotImplementedError(f"Worker received unknown command: {cmd}")
    except KeyboardInterrupt:
        print("Worker interrupt.")
    finally:
        # Only close if the "close" command has not already done so; the
        # original code double-closed the env on a clean shutdown.
        if not env_closed:
            env.close()
class DistributedVectorEnv(VecEnv):
    """
    A distributed Vector Environment that manages multiple worker processes,
    each running a Numba-optimized VectorEnvAdapter.

    Structure:
        Main Process (PPO) -> DistributedVectorEnv
            -> Worker Process 1 -> VectorEnvAdapter (N=1024) -> Numba
            -> Worker Process 2 -> VectorEnvAdapter (N=1024) -> Numba
            ...
    """

    def __init__(self, num_workers: int, envs_per_worker: int):
        """Spawn `num_workers` daemon processes of `envs_per_worker` envs each."""
        self.num_workers = num_workers
        self.envs_per_worker = envs_per_worker
        self.total_envs = num_workers * envs_per_worker

        # Define spaces (assuming consistent across all envs).
        # We create a throwaway single-env adapter just to read the spaces.
        dummy = VectorEnvAdapter(num_envs=1)
        observation_space = dummy.observation_space
        action_space = dummy.action_space
        dummy.close()
        del dummy

        super().__init__(self.total_envs, observation_space, action_space)

        self.closed = False
        self.waiting = False  # True between step_async() and step_wait()

        self.remotes, self.work_remotes = zip(*[mp.Pipe() for _ in range(num_workers)])
        self.processes = []
        for work_remote, remote in zip(self.work_remotes, self.remotes):
            p = mp.Process(target=worker_process, args=(work_remote, remote, envs_per_worker))
            p.daemon = True  # Kill if main process dies
            p.start()
            self.processes.append(p)
            # Parent keeps only `remote`; drop its copy of the child's end so
            # EOF propagates correctly if the worker dies.
            work_remote.close()

    def step_async(self, actions):
        """Send one chunk of actions to each worker without blocking."""
        # total_envs == num_workers * envs_per_worker, so the split is even.
        chunks = np.array_split(actions, self.num_workers)
        for remote, action_chunk in zip(self.remotes, chunks):
            remote.send(("step", action_chunk))
        self.waiting = True

    def step_wait(self):
        """Block until every worker replies; return concatenated batch results."""
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        # Aggregate results
        obs_list, rews_list, dones_list, infos_list = zip(*results)
        return (
            np.concatenate(obs_list),
            np.concatenate(rews_list),
            np.concatenate(dones_list),
            # Infos are per-worker lists of dicts; flatten in O(n) rather than
            # the quadratic sum(infos_list, []).
            [info for worker_infos in infos_list for info in worker_infos],
        )

    def reset(self):
        """Reset all workers and return the concatenated observations."""
        for remote in self.remotes:
            remote.send(("reset", None))
        results = [remote.recv() for remote in self.remotes]
        return np.concatenate(results)

    def close(self):
        """Drain pending step replies, ask workers to exit, and join them."""
        if self.closed:
            return
        if self.waiting:
            # Consume outstanding step replies so no worker is left blocked
            # in send() when the close command arrives.
            for remote in self.remotes:
                try:
                    remote.recv()
                except (EOFError, ConnectionResetError):
                    pass  # worker already died; nothing left to drain
        for remote in self.remotes:
            try:
                remote.send(("close", None))
            except BrokenPipeError:
                pass  # worker already gone; still join it below
        for p in self.processes:
            p.join()
        self.closed = True

    def get_attr(self, attr_name, indices=None):
        # Simplified: `indices` is ignored; the attribute is read from the
        # first worker only (assumed identical across workers).
        self.remotes[0].send(("get_attr", attr_name))
        return self.remotes[0].recv()

    def set_attr(self, attr_name, value, indices=None):
        # Not supported: attribute writes are silently ignored.
        pass

    def env_method(self, method_name, *method_args, **method_kwargs):
        # Not supported: sub-env method calls are silently ignored.
        pass

    def env_is_wrapped(self, wrapper_class, indices=None):
        # No sub-env is wrapped; report False for every env.
        return [False] * self.total_envs
def run_training():
    """Configure, launch, and run distributed PPO training.

    Configuration is read from environment variables:
        TRAIN_ENVS        total parallel envs (default 16384)
        NUM_WORKERS       worker process count (default 4)
        TRAIN_STEPS       total PPO timesteps (default 100_000_000)
        TRAIN_BATCH_SIZE  PPO minibatch size (default 32768)

    On exit (normal or Ctrl-C) the model is saved and workers are closed.
    """
    print("========================================================")
    print(" LovecaSim - DISTRIBUTED GPU TRAINING (Async Workers) ")
    print("========================================================")

    # Configuration
    TRAIN_ENVS = int(os.getenv("TRAIN_ENVS", "16384"))  # Increased default
    NUM_WORKERS = int(os.getenv("NUM_WORKERS", "4"))
    ENVS_PER_WORKER = TRAIN_ENVS // NUM_WORKERS
    TRAIN_STEPS = int(os.getenv("TRAIN_STEPS", "100_000_000"))
    BATCH_SIZE = int(os.getenv("TRAIN_BATCH_SIZE", "32768"))  # Increased batch size for GPU

    print(f" [Config] Total Envs: {TRAIN_ENVS}")
    print(f" [Config] Workers: {NUM_WORKERS} (Envs/Worker: {ENVS_PER_WORKER})")
    print(f" [Config] Batch Size: {BATCH_SIZE}")
    print(f" [Config] Architecture: Main(PPO) <-> {NUM_WORKERS} Workers <-> Numba(Vectors)")

    print(f" [Init] Launching {NUM_WORKERS} distributed worker processes...")
    vec_env = DistributedVectorEnv(NUM_WORKERS, ENVS_PER_WORKER)

    print(" [Init] Creating PPO Model...")
    model = PPO(
        "MlpPolicy",
        vec_env,
        verbose=1,
        learning_rate=3e-4,
        n_steps=128,
        batch_size=BATCH_SIZE,
        n_epochs=4,
        gamma=0.99,
        gae_lambda=0.95,
        ent_coef=0.01,
        tensorboard_log="./logs/gpu_workers_tensorboard/",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    print(f" [Init] Model Device: {model.device}")

    try:
        print(" [Train] Starting Distributed Training...")
        model.learn(total_timesteps=TRAIN_STEPS, progress_bar=True)
    except KeyboardInterrupt:
        print("\n [Stop] Interrupted by user.")
    finally:
        print(" [Done] Saving model and closing workers...")
        # Ensure the checkpoint directory exists before saving; model.save
        # may not create intermediate directories — TODO confirm against the
        # installed SB3 version.
        os.makedirs("./checkpoints", exist_ok=True)
        model.save("./checkpoints/gpu_workers_final")
        vec_env.close()
| if __name__ == "__main__": | |
| mp.set_start_method("spawn", force=True) | |
| run_training() | |