aws_rl_env / scripts /grpo_pool.py
Sizzing's picture
Upload folder using huggingface_hub
e56d042 verified
"""GRPO rollout pool helper — designed to run from a Google Colab notebook.

Opens N persistent WebSocket sessions against a single server deployed with
AWS_RL_ENV_POOL_SIZE=N. All rollouts in a group share the same task (picked by
one central Curriculum) and run concurrently via asyncio.gather.

Usage (Colab cell)::

    from scripts.grpo_pool import GrpoPool

    async def rollout(env, task):
        res = await env.reset(task=task)
        done = False
        total = 0.0
        while not done:
            action = AwsRlAction(command=policy(res.observation))
            res = await env.step(action)
            total += res.reward
            done = res.done
        return total

    async with GrpoPool(base_url="https://tunnel.example.com", size=8) as pool:
        for _ in range(num_grpo_steps):
            task = pool.curriculum.next_task()
            rewards = await pool.run_group(lambda e: rollout(e, task))
            pool.record_group_result(task, rewards)

In this example, ``policy`` and ``num_grpo_steps`` stand in for the notebook's
own policy function and step count, and ``AwsRlAction`` must be imported by
the notebook (this module does not import it).
"""
from __future__ import annotations

import asyncio
import logging
from contextlib import asynccontextmanager
from typing import (
    Any,
    AsyncIterator,
    Awaitable,
    Callable,
    Iterable,
    List,
    Optional,
    Sequence,
)

from client import AwsRlEnv
from models import Task
from server.services.curriculum import Curriculum
# Logger named after this module, so its output can be filtered or
# configured independently of the rest of the project.
logger = logging.getLogger(name=__name__)
class GrpoPool:
    """Manage N AwsRlEnv clients against a pooled server for GRPO rollouts.

    Each client holds its own persistent session, so one GRPO group of
    ``size`` rollouts can run concurrently. A single ``Curriculum``
    (``self.curriculum``) picks tasks and receives group-level results.

    Use either ``async with GrpoPool(...) as pool:`` or
    ``async with pool.session():``. Both connect on entry and close on exit.
    """

    def __init__(
        self,
        base_url: str,
        size: int = 8,
        curriculum: Optional[Curriculum] = None,
    ) -> None:
        """Configure the pool without opening any connections.

        Args:
            base_url: Server URL that every client connects to.
            size: Number of concurrent sessions (the GRPO group size).
            curriculum: Shared curriculum. A fresh ``Curriculum()`` is
                created only when this is ``None``.

        Raises:
            ValueError: If ``size`` is less than 1.
        """
        if size < 1:
            raise ValueError("size must be >= 1")
        self.base_url = base_url
        self.size = size
        # Explicit None check: a caller-supplied curriculum must never be
        # replaced just because it happens to evaluate as falsy.
        self.curriculum = curriculum if curriculum is not None else Curriculum()
        # Populated by connect() only once every session is open.
        self.envs: List[AwsRlEnv] = []

    async def connect(self) -> None:
        """Open N persistent WebSocket sessions. Each binds to its own MiniStack.

        All-or-nothing: every connect attempt is allowed to settle first. If
        any attempt failed, or this call was cancelled, every env is closed
        before the error is re-raised. As a result, the server's pool does not
        leak slots and callers never see a half-initialised pool. Calling this
        on an already connected pool does nothing.
        """
        if self.envs:
            return
        envs = [AwsRlEnv(base_url=self.base_url) for _ in range(self.size)]
        try:
            # _gather_all waits for *all* connects before raising. A plain
            # gather would raise on the first failure while other connect()
            # calls are still in flight; those could then finish after the
            # rollback below and leak a server-side session.
            await self._gather_all(env.connect() for env in envs)
        except BaseException:
            # Roll back: close every env, whether or not it connected. Close
            # failures are logged, not raised, so they cannot mask the
            # original connect error.
            await self._close_envs(envs)
            raise
        # Only publish the pool after the entire group connected successfully.
        self.envs = envs
        logger.info(
            "GrpoPool connected: %d sessions against %s", self.size, self.base_url
        )

    async def close(self) -> None:
        """Close all WebSocket sessions. Server releases MiniStacks back to pool.

        Every close is attempted. Individual failures are logged rather than
        raised, and the pool is empty afterwards. Calling this on a pool that
        is not connected does nothing.
        """
        if not self.envs:
            return
        await self._close_envs(self.envs)
        self.envs = []

    async def reset_group(self, task: Task) -> None:
        """Reset all N envs onto the same task, concurrently.

        The full Task is serialised to the server, so envs do not have to
        look the task up through their own curriculum. All resets settle
        before the first failure (if any) is re-raised.

        Raises:
            RuntimeError: If the pool is not connected.
        """
        self._require_connected()
        await self._gather_all(env.reset(task=task) for env in self.envs)

    async def run_group(
        self,
        rollout_fn: Callable[[AwsRlEnv], Awaitable[float]],
    ) -> List[float]:
        """Run ``rollout_fn`` on each of the N envs concurrently.

        The caller must either call reset_group() beforehand or reset inside
        ``rollout_fn`` onto the same task. Every rollout is allowed to finish
        before the first failure (if any) is re-raised. This means no rollout
        is left running on an env that the next group will reuse.

        Returns:
            One reward per env, in the same order as ``self.envs``.

        Raises:
            RuntimeError: If the pool is not connected. Without this check,
                an empty reward list would be returned silently and later
                recorded as a failed group.
        """
        self._require_connected()
        return await self._gather_all(rollout_fn(env) for env in self.envs)

    def record_group_result(
        self,
        task: Task,
        rewards: Sequence[float],
        success_threshold: float = 0.99,
    ) -> None:
        """Feed one group-level result back to the central curriculum.

        A group counts as "achieved" if at least one rollout scored at or
        above ``success_threshold``. The recorded reward is the group mean.
        An empty ``rewards`` sequence is recorded as not achieved, with a
        reward of 0.0.
        """
        achieved = any(r >= success_threshold for r in rewards)
        mean_reward = sum(rewards) / len(rewards) if rewards else 0.0
        self.curriculum.record_result(task, achieved=achieved, reward=mean_reward)

    @asynccontextmanager
    async def session(self) -> AsyncIterator["GrpoPool"]:
        """Connect on entry and close on exit, like ``async with pool``."""
        try:
            await self.connect()
            yield self
        finally:
            await self.close()

    async def __aenter__(self) -> "GrpoPool":
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        await self.close()

    def _require_connected(self) -> None:
        """Raise RuntimeError if connect() has not populated the pool."""
        if not self.envs:
            raise RuntimeError("GrpoPool is not connected; call connect() first")

    @staticmethod
    async def _gather_all(aws: Iterable[Awaitable[Any]]) -> List[Any]:
        """Await ``aws`` concurrently and re-raise the first failure only after all settle.

        A plain ``asyncio.gather`` raises as soon as one awaitable fails and
        leaves the others running. Here nothing is left running behind the
        caller's back.
        """
        results = await asyncio.gather(*aws, return_exceptions=True)
        for result in results:
            if isinstance(result, BaseException):
                raise result
        return list(results)

    @staticmethod
    async def _close_envs(envs: Sequence[AwsRlEnv]) -> None:
        """Close every env concurrently, logging (not raising) close failures."""
        results = await asyncio.gather(
            *(env.close() for env in envs), return_exceptions=True
        )
        for index, result in enumerate(results):
            if isinstance(result, BaseException):
                logger.warning(
                    "GrpoPool: failed to close session %d: %r", index, result
                )