# scripts/ – Parallel Rollout Architecture
This directory holds the helper modules that make 8 concurrent multi-turn rollouts against the AWS RL environment possible – the scaling trick that turns GRPO from a thought experiment into something you can actually train on a single GPU.
If you only read one section, read §2, Three coordinated pool layers. It explains the architecture in one page.
## Table of contents
- Why parallel rollouts matter
- Three coordinated pool layers
- Walking through one GRPO step
- The all-or-nothing connect protocol
- Concurrency-safety guarantees
- Configuration
- Running the multi-connection demo
- Files in this directory
## 1. Why parallel rollouts matter
GRPO computes group-relative advantages: every gradient step needs G rollouts on the same prompt so the algorithm can normalize rewards within the group. With G = 8, multi-turn episodes (≤ 6 turns), and an env step that round-trips an AWS CLI invocation through MiniStack (~50 ms), the math is:
```
Serial:   8 rollouts × 6 turns × 50 ms = 2,400 ms env-time per GRPO step
Parallel: max(8 envs) × 6 turns × 50 ms =  300 ms env-time per GRPO step
```
That's an 8× speedup on the env side. The model forward pass still serialises (single GPU), so the practical end-to-end gain depends on the env/compute ratio – but for an env that takes ~50 ms per step, parallelism is the difference between a tractable training run and a 24-hour one.
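The arithmetic above is easy to sanity-check in a few lines. This is a back-of-envelope sketch; the 50 ms figure is the estimate from the text, not a measurement:

```python
# Back-of-envelope check of the serial vs. parallel env-time figures above.
G = 8          # rollouts per GRPO group
TURNS = 6      # max turns per episode
STEP_MS = 50   # approximate MiniStack round-trip per env step

serial_ms = G * TURNS * STEP_MS    # rollouts wait for each other
parallel_ms = TURNS * STEP_MS      # all G envs step concurrently

print(serial_ms, parallel_ms, serial_ms // parallel_ms)  # 2400 300 8
```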
The parallelism isn't free: each rollout needs state isolation. If two rollouts share an AWS world, rollout 1's S3 buckets bleed into rollout 2's view, the curriculum mastery numbers go to garbage, and the agent can hack the reward by piggy-backing off siblings. The three coordinated pools below exist to make state isolation cheap and automatic.
## 2. Three coordinated pool layers
The system has three pools that work together. They look similar at first glance – all of them deal with N concurrent envs – but each operates at a different layer of the stack:
```
+------------------------------------------------------------------------------
| Layer 3 – Trainer-process pool
|   MultiTurnEnvPool (train_grpo.py)
|   • owns a background asyncio loop
|   • exposes a sync run_group() that the GRPO trainer can call
|   • used by the in-process trainer (CLI: python train_grpo.py)
+-------------------------------------+----------------------------------------
                                      | N WebSocket clients
+-------------------------------------v----------------------------------------
| Layer 3 alt – Notebook-friendly pool
|   GrpoPool (scripts/grpo_pool.py)
|   • async-native API (async with GrpoPool(...) as pool: ...)
|   • used by Colab notebooks where the cell IS the asyncio loop
|   • simpler interface (no background thread)
+-------------------------------------+----------------------------------------
                                      | N WebSocket clients
+-------------------------------------v----------------------------------------
| Layer 2 – OpenEnv max_concurrent_envs
|   create_app(env_factory, ..., max_concurrent_envs=POOL_SIZE)
|   • OpenEnv reserves up to N env instances at once
|   • returns 503 if a 9th client tries to connect when POOL_SIZE=8
+-------------------------------------+----------------------------------------
                                      | env_factory() invoked per session
+-------------------------------------v----------------------------------------
| Layer 1 – Server-side MiniStack pool
|   MiniStackPool (server/app.py)
|   • free-list of MiniStack ports (BASE..BASE+POOL_SIZE-1)
|   • acquire()/release() under a threading.Lock
|   • each WS session binds to ONE port for its lifetime – state isolation
+------------------------------------------------------------------------------
                                      |
                                      v
                      N independent MiniStack processes
                      (started by Dockerfile / Makefile)
```
### Layer 1 – Server-side MiniStackPool
Lives in server/app.py:75–138. Documented in detail in server/README.md §6.
- A `threading.Lock`-guarded free list of port numbers
- `acquire()` returns a port; `release(port)` puts it back
- `RuntimeError("MiniStack pool exhausted")` if depleted
- The Dockerfile launches `POOL_SIZE` MiniStack processes on consecutive ports before the FastAPI server starts accepting connections
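The free-list pattern is small enough to sketch in full. This is an illustrative re-implementation, not the real code: the class name `PortPool` is an assumption, and the actual `MiniStackPool` in server/app.py may differ in details.

```python
import threading


class PortPool:
    """Lock-guarded free list of ports (sketch of the Layer-1 pattern)."""

    def __init__(self, base: int, size: int) -> None:
        self._free = list(range(base, base + size))
        self._lock = threading.Lock()

    def acquire(self) -> int:
        with self._lock:
            if not self._free:
                raise RuntimeError("MiniStack pool exhausted")
            return self._free.pop()

    def release(self, port: int) -> None:
        with self._lock:
            self._free.append(port)
```

Each WebSocket session would call `acquire()` on connect and `release()` on close, giving it exclusive use of one MiniStack instance for its lifetime.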
### Layer 2 – OpenEnv max_concurrent_envs
When create_app() is called with max_concurrent_envs=POOL_SIZE, OpenEnv enforces the cap upstream – clients beyond the cap get a clean 503 instead of a server-side RuntimeError. Defence in depth.
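The capping behaviour can be sketched with a simple counter (OpenEnv's actual mechanism is not reproduced here; this only illustrates the reject-over-capacity semantics):

```python
class SessionCap:
    """Illustrative sketch of a max-concurrent-sessions check."""

    def __init__(self, max_sessions: int) -> None:
        self.max_sessions = max_sessions
        self.active = 0

    def try_enter(self) -> bool:
        # A real server would answer HTTP 503 at this point.
        if self.active >= self.max_sessions:
            return False
        self.active += 1
        return True

    def leave(self) -> None:
        self.active -= 1
```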
### Layer 3 – Client pools
Two flavours, same parallelism model, different ergonomics:
| | MultiTurnEnvPool (train_grpo.py) | GrpoPool (scripts/grpo_pool.py) |
|---|---|---|
| API | Sync – `pool.run_group(task, ...)` | Async – `await pool.run_group(rollout_fn)` |
| Loop | Owns a background thread + asyncio loop | Caller is the asyncio loop (Colab cell) |
| Use case | In-process trainer (`python train_grpo.py`) | Notebooks driving training from Colab |
| Connection | `await asyncio.gather(*(e.connect() for e in envs))` on a background thread | Same, but on the caller's loop |
| `record_result()` | Trainer calls `Curriculum.record_result()` directly | `pool.record_group_result(task, rewards)` helper baked in |
Both share the all-or-nothing connect protocol described in §4.
#### Why two client pools?
Real life: the trainer process (python train_grpo.py) runs synchronously – TRL's GRPOTrainer.train() blocks. To drive await asyncio.gather from inside that blocking call, we need a background asyncio loop on a separate thread. That's MultiTurnEnvPool.
Colab cells, on the other hand, are the asyncio loop – the IPython kernel already runs its own event loop, so awaiting at the top of a cell just works. Running a background thread + loop there is overkill and creates ordering bugs. GrpoPool is the simpler async-native variant for that case.
The two pools share semantic invariants β same N, same all-or-nothing connect, same task scoping β so behaviour is identical regardless of which entry point you use.
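The background-loop trick that lets a blocking trainer call async code can be sketched as follows. The class name and `run()` method are assumptions; the real `MultiTurnEnvPool` in train_grpo.py is more elaborate.

```python
import asyncio
import threading


class BackgroundLoop:
    """A dedicated asyncio loop on a daemon thread, plus a sync bridge."""

    def __init__(self) -> None:
        self.loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self.loop.run_forever, daemon=True)
        self._thread.start()

    def run(self, coro):
        # Blocking bridge: safe to call from synchronous trainer code.
        return asyncio.run_coroutine_threadsafe(coro, self.loop).result()

    def stop(self) -> None:
        self.loop.call_soon_threadsafe(self.loop.stop)
```

A sync trainer can then call `bg.run(asyncio.gather(...))` to fan out N env steps while itself remaining blocking.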
## 3. Walking through one GRPO step
```
1. trainer picks one task from the Curriculum             (1 task)
2. pool.run_group(task)                                   (asyncio.gather over N envs)
3. for turn in 0..MAX_TURNS:
     prompts      = build_prompts(observations)           (CPU)
     completions  = policy.generate(prompts)              (1 batched fwd, GPU)
     actions      = parse_completions(completions)        (CPU; extract `aws ...` line)
     observations = await pool.run_group_step(actions)    (N concurrent env.step)
4. rewards = sum_per_episode(rewards_lists)               (N floats)
5. GRPO computes group-relative advantages, KL, loss      (1 backward, GPU)
6. Curriculum.record_result(task, mean(rewards))          (1 update)
```
A couple of subtleties:
### Generation is serialised, env-step is not
train_grpo.py:_GENERATE_LOCK is a threading.Lock around model.generate(). The model lives on a single GPU; concurrent generate() calls would clobber each other. We let env step calls run concurrently (the slow part – WebSocket round-trip + MiniStack execution); only generation serialises.
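The pattern, in miniature (illustrative names; the real lock is train_grpo.py's module-level `_GENERATE_LOCK`):

```python
import threading

# Illustrative stand-in for train_grpo.py's _GENERATE_LOCK.
_generate_lock = threading.Lock()


def generate_serialised(generate_fn, prompts):
    """Serialise GPU generation; env steps elsewhere stay concurrent."""
    with _generate_lock:
        return generate_fn(prompts)
```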
### Per-turn token accumulation
rollout_one_episode() accumulates prompt_ids, completion_ids, and logprobs across turns into a single sequence. GRPO then assigns the episode-level reward to that full sequence. This matches the multi-turn structure of the underlying decision problem.
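The accumulation step can be sketched like this (the per-turn dict shape is hypothetical; see rollout_one_episode() for the real structure):

```python
def accumulate_episode(turns):
    """Concatenate per-turn token data into one episode-level sequence.

    The episode reward is later assigned to this full concatenated
    sequence, matching the multi-turn structure of the problem.
    """
    prompt_ids, completion_ids, logprobs = [], [], []
    for turn in turns:
        prompt_ids += turn["prompt_ids"]
        completion_ids += turn["completion_ids"]
        logprobs += turn["logprobs"]
    return prompt_ids, completion_ids, logprobs
```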
### Why every rollout in a group runs the same task
GRPO's group-relative advantage is (reward_i − group_mean) / group_std. If different rollouts ran different tasks, group statistics would mean nothing. The curriculum picks one task per GRPO step; the pool's reset_group(task) forces every env to that task; only then can the group statistics be meaningful.
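As a concrete sketch of that formula (GRPO implementations differ in details such as population vs. sample std and zero-std handling, so treat this as illustrative):

```python
import statistics


def group_advantages(rewards):
    """Group-relative advantage: (reward_i - group_mean) / group_std.

    Uses population std and falls back to 1.0 when all rewards are
    equal, so the all-tied case yields zero advantages.
    """
    mu = statistics.mean(rewards)
    sigma = statistics.pstdev(rewards) or 1.0
    return [(r - mu) / sigma for r in rewards]
```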
## 4. The all-or-nothing connect protocol
scripts/grpo_pool.py:58–82 – the most non-obvious correctness detail in the whole pool stack.
```python
async def connect(self) -> None:
    if self.envs:
        return
    envs = [AwsRlEnv(base_url=self.base_url) for _ in range(self.size)]
    try:
        await asyncio.gather(*(e.connect() for e in envs))
    except BaseException:
        # Roll back: close every env (successful or not). return_exceptions
        # so a close() failure doesn't mask the original connect error.
        await asyncio.gather(
            *(e.close() for e in envs),
            return_exceptions=True,
        )
        raise
    # Only publish the pool after the entire group connected successfully.
    self.envs = envs
```
What makes this important:
- `asyncio.gather` raises on the first failure. If 3 of 8 connects succeed and the 4th raises, the other 4 may or may not have connected yet. Their state is undefined.
- Server-side state matters. Each successful connect acquired a MiniStack port from the server pool. If we just `raise` without cleanup, those ports stay held until the WebSocket times out – typically minutes. The next training run hits "pool exhausted".
- `self.envs` is published only after success. If any partial state were exposed, callers might call `pool.run_group()` on a half-initialised pool and get N/M valid results.
- `return_exceptions=True` on the rollback. A close error must not mask the original connect error – the user needs to know the real reason connect failed, not a downstream cleanup failure.
These four invariants are the difference between "training reliably resumes after a flake" and "every flake leaks 7 ports and you're rebuilding the container at 3 AM".
MultiTurnEnvPool._connect_all() in train_grpo.py:473-480 implements the same pattern.
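The rollback behaviour is easy to exercise with a toy env. Everything below is a test double, not the real AwsRlEnv; `connect_all` just mirrors the all-or-nothing shape shown above.

```python
import asyncio


class FakeEnv:
    """Test double standing in for an env client."""

    def __init__(self, fail: bool) -> None:
        self.fail = fail
        self.closed = False

    async def connect(self) -> None:
        if self.fail:
            raise ConnectionError("flaky connect")

    async def close(self) -> None:
        self.closed = True


async def connect_all(envs):
    # Same all-or-nothing shape as GrpoPool.connect: on any failure,
    # close every env (return_exceptions so cleanup errors don't mask
    # the original one), then re-raise.
    try:
        await asyncio.gather(*(e.connect() for e in envs))
    except BaseException:
        await asyncio.gather(*(e.close() for e in envs), return_exceptions=True)
        raise
    return envs
```

Running it with one flaky env shows the original error propagating while every env, including the successful ones, gets closed.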
## 5. Concurrency-safety guarantees
| Concern | Guarantee | Where enforced |
|---|---|---|
| Cross-rollout state isolation | Each WebSocket session holds its own MiniStack port for its lifetime | `MiniStackPool.acquire`/`release` (server/app.py) |
| Curriculum coherence | One curriculum instance per training run; `record_result()` is the only mutation point | `make_rollout_func` in train_grpo.py |
| GPU contention | `model.generate()` calls serialised behind `_GENERATE_LOCK` | train_grpo.py:`_GENERATE_LOCK` |
| Pool slot leakage on flake | All-or-nothing connect with rollback close | `GrpoPool.connect`, `MultiTurnEnvPool._connect_all` |
| Hung shutdown | Pool close runs `asyncio.gather(..., return_exceptions=True)` then stops the loop with timeout | `MultiTurnEnvPool.close()` |
| Web playground vs pool collisions | Web routes refuse to mount when `POOL_SIZE > 1` | server/app.py:171 |
Tests covering these:
- tests/test_pool.py – server-side `MiniStackPool` acquire/release, exhaustion behaviour
- tests/test_grpo_pool.py – `GrpoPool` connect/close lifecycle, partial-connect rollback, group-result aggregation
## 6. Configuration
| Variable | Default | Purpose |
|---|---|---|
| `AWS_RL_ENV_POOL_SIZE` | `1` | Server-side MiniStack pool size. Set to 8 for GRPO training. Must be ≥ training-time `num_generations`. |
| `AWS_RL_ENV_MINISTACK_BASE_PORT` | `4566` | First MiniStack port; the pool covers `[BASE, BASE + POOL_SIZE)` |
| `BACKEND_TYPE` | `simulator` | `simulator` (default; pool is meaningful) or `aws` (real AWS; pool disabled) |
| `NUM_GENERATIONS` (in trainer cfg) | `8` | Number of WebSocket clients the pool opens. Should equal `AWS_RL_ENV_POOL_SIZE` for full parallelism. |
| `MAX_TURNS` (in trainer cfg) | `6` | Per-rollout episode length cap |
| `MAX_TOTAL_TOKENS` (in trainer cfg) | `4096` | Per-episode token budget (anti-OOM) |
When deploying to HuggingFace Spaces, pool size is constrained by container memory – each MiniStack process is ~50–100 MB resident.
## 7. Running the multi-connection demo
scripts/TestMultipleConnects.ipynb is a hands-on notebook that proves all 8 sessions stay isolated.
```bash
# 1. Start the env server with pool size 8
AWS_RL_ENV_POOL_SIZE=8 make run

# 2. Run the notebook
jupyter notebook scripts/TestMultipleConnects.ipynb
```
Expected output: 8 simultaneous "connection open" lines, 8 independent reset/step traces, no resource bleed across sessions.
The screenshot at docs/figures/env_init_screenshot.png captures one such run.
## See also
- Main README – project overview
- server/README.md – environment internals (server-side pool detail in §6)
- train/README.md – SFT + GRPO training pipeline (this pool plugs into the GRPO loop)
- tests/test_pool.py – server-side pool acquire/release tests
- tests/test_grpo_pool.py – client-side pool lifecycle tests

