Spaces:
Sleeping
Sleeping
| """ | |
| SubprocVecEnv — N environments running in parallel worker subprocesses. | |
| Each worker owns one RaceEnvironment and runs it in isolation. | |
| The main process sends actions to all workers simultaneously, then | |
| collects results — replacing the sequential env-stepping loop with | |
| a single parallel scatter/gather. | |
| Protocol (over multiprocessing.Pipe): | |
| main → worker : (_CMD_STEP, (accel, steer)) | |
| main → worker : (_CMD_RESET, track_level: int) | |
| main → worker : (_CMD_CLOSE, None) | |
| worker → main : _StepResult namedtuple | |
| On Linux (fork) workers start instantly by inheriting the parent's | |
| memory. SDL_VIDEODRIVER=dummy is set before any pygame import so | |
| every worker gets a headless pygame context. | |
| """ | |
| from __future__ import annotations | |
| import multiprocessing as mp | |
| import os | |
| import sys | |
| from typing import List, NamedTuple, Optional | |
| import numpy as np | |
| # ── IPC command tokens ──────────────────────────────────────────────────────── | |
# Integer opcodes placed first in every (command, payload) tuple that the main
# process sends to a worker: step = 0, reset = 1, close = 2.
_CMD_STEP, _CMD_RESET, _CMD_CLOSE = range(3)
class _StepResult(NamedTuple):
    """Per-step payload a worker sends back to the main process.

    One instance is produced for every reset and every step command.
    """

    # Rendered frame; documented upstream as a (64, 64, 3) uint8 array.
    image: np.ndarray
    # float32 vector; documented upstream as shape (9,):
    # speed, ang_vel, 5 rays, wp_sin, wp_cos.
    scalars: np.ndarray
    # Reward reported by the environment for this transition.
    reward: float
    # True once the episode has finished.
    done: bool
    # Free-form extra information from the environment (may be empty).
    metadata: dict
| # ── Worker entry point ──────────────────────────────────────────────────────── | |
def _worker_fn(conn: mp.connection.Connection, max_steps: int, laps_target: int) -> None:
    """
    Subprocess entry point: own one RaceEnvironment and serve commands from ``conn``.

    The function loops on ``conn.recv()``, expecting ``(cmd, data)`` tuples:

    * ``_CMD_RESET``: ``data`` is a track level.  The matching track is built,
      a new RaceEnvironment is created and reset, and the first observation
      is sent back as a _StepResult.
    * ``_CMD_STEP``: ``data`` is ``(accel, steer)``.  The environment is
      stepped once and the resulting _StepResult is sent back.
    * ``_CMD_CLOSE``: leave the loop.

    Any other command is ignored and produces no reply.

    Parameters
    ----------
    conn        : worker end of the duplex pipe shared with the main process
    max_steps   : forwarded to RaceEnvironment on every reset
    laps_target : forwarded to RaceEnvironment on every reset

    Raises
    ------
    RuntimeError
        If a step command arrives before any reset, i.e. there is no
        environment to step yet.  A ``KeyError`` from an unknown track level
        also propagates.  In both cases the pipe is closed on the way out.

    EOFError, BrokenPipeError and KeyboardInterrupt end the loop quietly.
    """
    # Must be set before any pygame/game import.  setdefault leaves a value
    # that is already present in the environment untouched.
    os.environ.setdefault("SDL_VIDEODRIVER", "dummy")
    os.environ.setdefault("SDL_AUDIODRIVER", "dummy")

    # Lazy imports: they run inside the subprocess rather than at module
    # import time in the main process.
    from game.tracks import TRACKS
    from env.environment import RaceEnvironment
    from env.models import DriveAction

    tracks_by_level = {t.level: t for t in TRACKS}
    env: Optional[RaceEnvironment] = None

    try:
        while True:
            cmd, data = conn.recv()

            if cmd == _CMD_RESET:
                track_level: int = data
                track = tracks_by_level[track_level]
                # Per the original note, build() creates the track's
                # pygame.Surface inside this worker.
                track.build()
                env = RaceEnvironment(track, max_steps, laps_target, use_image=True)
                obs = env.reset()
                conn.send(_make_result(obs))

            elif cmd == _CMD_STEP:
                if env is None:
                    # Fail with a clear message instead of an opaque
                    # "'NoneType' object has no attribute 'step'".
                    raise RuntimeError(
                        "_CMD_STEP received before _CMD_RESET: "
                        "this worker has no environment to step yet"
                    )
                accel, steer = data
                obs = env.step(DriveAction(accel=accel, steer=steer))
                conn.send(_make_result(obs))

            elif cmd == _CMD_CLOSE:
                break
    except (EOFError, BrokenPipeError, KeyboardInterrupt):
        # The main process went away or interrupted us: exit without noise.
        pass
    finally:
        conn.close()
def _make_result(obs) -> _StepResult:
    """Pack an environment observation into the _StepResult sent over the pipe.

    Reads ``image``, ``scalars``, ``reward``, ``done`` and ``metadata`` from
    ``obs``.  ``scalars`` is copied into a new float32 array, and a falsy
    ``metadata`` (``None`` or empty) is replaced by a fresh empty dict.  The
    remaining fields are passed through unchanged.
    """
    frame = obs.image  # documented upstream as a (64, 64, 3) uint8 array
    scalar_vec = np.array(obs.scalars, dtype=np.float32)
    reward, done = obs.reward, obs.done
    extra = obs.metadata or {}
    return _StepResult(frame, scalar_vec, reward, done, extra)
| # ── Main-process interface ──────────────────────────────────────────────────── | |
class SubprocVecEnv:
    """
    N environments stepping in parallel subprocesses.

    Replaces the sequential ``for n in range(N): envs[n].step(...)`` loop
    in the rollout collector with a scatter/gather over worker pipes.

    The object can be used as a context manager; leaving the ``with`` block
    calls :meth:`close`.

    Parameters
    ----------
    n_envs      : number of worker subprocesses to launch
    max_steps   : max steps per episode (passed to RaceEnvironment)
    laps_target : laps per episode (passed to RaceEnvironment)
    """

    def __init__(self, n_envs: int, max_steps: int = 3000, laps_target: int = 3):
        self.n_envs = n_envs
        self._closed = False

        # fork is the fastest start method on Linux (workers inherit parent memory).
        # Use spawn on Windows/macOS where fork is unavailable or unsafe.
        start_method = "fork" if sys.platform.startswith("linux") else "spawn"
        ctx = mp.get_context(start_method)

        self._remotes: List[mp.connection.Connection] = []
        self._procs: List[mp.Process] = []

        # Create each pipe immediately before starting its worker and close the
        # parent's copy of the worker end right after.  Under fork a child
        # inherits every descriptor open in the parent, so creating all pipes
        # up front would leave earlier workers holding later workers' pipe ends.
        # A crashed worker would then never produce EOF on the main side and
        # recv() could block forever.
        for _ in range(n_envs):
            main_end, work_end = ctx.Pipe(duplex=True)
            proc = ctx.Process(
                target=_worker_fn,
                args=(work_end, max_steps, laps_target),
                daemon=True,
            )
            proc.start()
            work_end.close()  # worker-end not needed in main process
            self._remotes.append(main_end)
            self._procs.append(proc)

        print(f"[SubprocVecEnv] {n_envs} worker processes started "
              f"(start_method={start_method!r})")

    def __enter__(self) -> "SubprocVecEnv":
        """Return self so the instance can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Shut the workers down when the ``with`` block ends."""
        self.close()

    def _require_one_per_worker(self, items: list, what: str) -> None:
        """Raise ValueError when ``items`` has fewer entries than there are workers.

        The gather step reads one reply from every worker, so a worker that was
        never sent a command would make the main process wait forever.
        """
        n_workers = len(self._remotes)
        if len(items) < n_workers:
            raise ValueError(
                f"{what} has {len(items)} entries but there are {n_workers} "
                "workers; every worker needs one"
            )

    # ── Bulk reset ────────────────────────────────────────────────────────────
    def reset(self, track_levels: List[int]) -> List[_StepResult]:
        """
        Reset all N envs simultaneously.

        Parameters
        ----------
        track_levels : list of int, length n_envs; one track level per worker.
                       Entries beyond the number of workers are ignored.

        Returns
        -------
        One _StepResult per worker, in worker order.

        Raises
        ------
        ValueError : if fewer track levels than workers are supplied.
        """
        levels = list(track_levels)
        self._require_one_per_worker(levels, "track_levels")
        for remote, level in zip(self._remotes, levels):
            remote.send((_CMD_RESET, level))
        return [remote.recv() for remote in self._remotes]

    # ── Single-env reset (for episode end during rollout) ─────────────────────
    def reset_one(self, n: int, track_level: int) -> _StepResult:
        """Reset worker n on track_level and return its first observation."""
        remote = self._remotes[n]
        remote.send((_CMD_RESET, track_level))
        return remote.recv()

    # ── Parallel step ─────────────────────────────────────────────────────────
    def step_async(self, actions: List[tuple]) -> None:
        """
        Send one action to every worker without waiting for the results.

        actions : list of (accel, steer) float tuples, length n_envs.
                  Entries beyond the number of workers are ignored.

        Raises
        ------
        ValueError : if fewer actions than workers are supplied.
        """
        pairs = list(actions)
        self._require_one_per_worker(pairs, "actions")
        for remote, (accel, steer) in zip(self._remotes, pairs):
            # Plain Python floats keep the pickled message small and free of
            # numpy scalar types.
            remote.send((_CMD_STEP, (float(accel), float(steer))))

    def step_wait(self) -> List[_StepResult]:
        """Collect one _StepResult from every worker, in worker order."""
        return [remote.recv() for remote in self._remotes]

    def step(self, actions: List[tuple]) -> List[_StepResult]:
        """Send actions to all workers and wait for all results."""
        self.step_async(actions)
        return self.step_wait()

    # ── Lifecycle ─────────────────────────────────────────────────────────────
    def close(self) -> None:
        """Ask every worker to exit, reap the processes and close the pipes.

        Safe to call more than once; only the first call does any work.
        """
        # getattr keeps this safe on a partially initialised instance.
        if getattr(self, "_closed", False):
            return
        self._closed = True

        for remote in self._remotes:
            try:
                remote.send((_CMD_CLOSE, None))
            except OSError:
                # The worker is already gone or the pipe is closed, so there
                # is nobody left to notify.
                pass

        for proc in self._procs:
            proc.join(timeout=5)
            if proc.is_alive():
                proc.terminate()
        for proc in self._procs:
            proc.join(timeout=2)

        # Release the main-side pipe ends.
        for remote in self._remotes:
            remote.close()