Spaces:
Running
Running
"""GRPO rollout pool helper — designed to run from a Google Colab notebook.

Opens N persistent WebSocket sessions against a single server deployed with
AWS_RL_ENV_POOL_SIZE=N. All rollouts in a group share the same task (picked by
one central Curriculum) and run concurrently via asyncio.gather.

Usage (Colab cell):

    from scripts.grpo_pool import GrpoPool

    async def rollout(env, task):
        res = await env.reset(task=task)
        done = False
        total = 0.0
        while not done:
            action = AwsRlAction(command=policy(res.observation))
            res = await env.step(action)
            total += res.reward
            done = res.done
        return total

    async with GrpoPool(base_url="https://tunnel.example.com", size=8) as pool:
        for _ in range(num_grpo_steps):
            task = pool.curriculum.next_task()
            rewards = await pool.run_group(lambda e: rollout(e, task))
            pool.record_group_result(task, rewards)
"""
| from __future__ import annotations | |
| import asyncio | |
| import logging | |
| from contextlib import asynccontextmanager | |
| from typing import Awaitable, Callable, List, Optional, Sequence | |
| from client import AwsRlEnv | |
| from models import Task | |
| from server.services.curriculum import Curriculum | |
| logger = logging.getLogger(__name__) | |
class GrpoPool:
    """Manages N AwsRlEnv clients against a pooled server for GRPO rollouts.

    The pool owns ``size`` env clients (``self.envs``) plus one central
    curriculum (``self.curriculum``). It can be used either as an async
    context manager (``async with GrpoPool(...) as pool``), through
    ``async with pool.session()``, or by calling ``connect()`` / ``close()``
    explicitly.
    """

    def __init__(
        self,
        base_url: str,
        size: int = 8,
        curriculum: Optional[Curriculum] = None,
    ) -> None:
        """Store the pool configuration; no network I/O happens here.

        Args:
            base_url: Server URL handed to every ``AwsRlEnv`` client.
            size: Number of env sessions to open; must be at least 1.
            curriculum: Central curriculum shared by the whole group. A fresh
                ``Curriculum()`` is created only when this is ``None``.

        Raises:
            ValueError: If ``size`` is smaller than 1.
        """
        if size < 1:
            raise ValueError("size must be >= 1")
        self.base_url = base_url
        self.size = size
        # Explicit None check: a caller-supplied curriculum must never be
        # swapped for a default just because it happens to evaluate as falsy
        # (e.g. a container-like object that is currently empty).
        self.curriculum = curriculum if curriculum is not None else Curriculum()
        self.envs: List[AwsRlEnv] = []

    @staticmethod
    async def _gather_or_cancel(*aws: Awaitable) -> list:
        """Await all awaitables concurrently; on failure cancel the stragglers.

        Plain ``asyncio.gather`` propagates the first exception but leaves the
        remaining awaitables running in the background. Here every awaitable
        is wrapped in a task so that, on any failure or cancellation, the
        unfinished ones are cancelled and fully drained before re-raising.

        Returns:
            The results in the same order as ``aws``.
        """
        tasks = [asyncio.ensure_future(aw) for aw in aws]
        try:
            return list(await asyncio.gather(*tasks))
        except BaseException:
            for task in tasks:
                task.cancel()  # no-op for tasks that already finished
            # Drain so no task is still touching its env once we re-raise;
            # return_exceptions keeps the original error from being masked.
            await asyncio.gather(*tasks, return_exceptions=True)
            raise

    async def connect(self) -> None:
        """Open N persistent WebSocket sessions. Each binds to its own MiniStack.

        All-or-nothing: if any single session fails to connect, the connects
        still in flight are cancelled and every env is closed before
        re-raising, so the server's pool does not leak slots and callers never
        see a half-initialised pool. Calling this on an already connected pool
        is a no-op.
        """
        if self.envs:
            return
        envs = [AwsRlEnv(base_url=self.base_url) for _ in range(self.size)]
        try:
            await self._gather_or_cancel(*(e.connect() for e in envs))
        except BaseException:
            # Roll back: close every env (successful or not). return_exceptions
            # so a close() failure doesn't mask the original connect error.
            await asyncio.gather(
                *(e.close() for e in envs),
                return_exceptions=True,
            )
            raise
        # Only publish the pool after the entire group connected successfully.
        self.envs = envs
        logger.info(
            "GrpoPool connected: %d sessions against %s", self.size, self.base_url
        )

    async def close(self) -> None:
        """Close all WebSocket sessions. Server releases MiniStacks back to pool.

        Best-effort: a failing ``close()`` on one env never prevents the
        others from closing and is never raised, but it is logged. Calling
        this on an unconnected pool is a no-op.
        """
        if not self.envs:
            return
        outcomes = await asyncio.gather(
            *(e.close() for e in self.envs), return_exceptions=True
        )
        self.envs = []
        for outcome in outcomes:
            if isinstance(outcome, BaseException):
                logger.warning("GrpoPool: error while closing session: %r", outcome)

    async def reset_group(self, task: Task) -> None:
        """Reset all N envs onto the same task. Runs concurrently.

        The full Task is serialised to the server, so envs do not have to
        look the task up through their own curriculum. If one reset fails,
        the remaining resets are cancelled before the error propagates.
        """
        await self._gather_or_cancel(*(e.reset(task=task) for e in self.envs))

    async def run_group(
        self,
        rollout_fn: Callable[[AwsRlEnv], Awaitable[float]],
    ) -> List[float]:
        """Run `rollout_fn` on each of the N envs concurrently, return rewards.

        The caller is responsible for calling reset_group() beforehand (or
        doing the reset inside rollout_fn with the same task).

        If any rollout raises, the sibling rollouts are cancelled and awaited
        before the exception propagates, so no orphaned rollout keeps driving
        an env that the next group is about to reuse.

        Returns:
            One reward per env, in ``self.envs`` order.
        """
        return await self._gather_or_cancel(*(rollout_fn(e) for e in self.envs))

    def record_group_result(
        self,
        task: Task,
        rewards: Sequence[float],
        success_threshold: float = 0.99,
    ) -> None:
        """Feed one group-level result back to the central curriculum.

        A group is considered "achieved" if at least one rollout scored at or
        above the success threshold. The recorded reward is the group mean
        (0.0 for an empty group).
        """
        achieved = any(r >= success_threshold for r in rewards)
        count = len(rewards)
        mean_reward = sum(rewards) / count if count else 0.0
        self.curriculum.record_result(task, achieved=achieved, reward=mean_reward)

    @asynccontextmanager
    async def session(self):
        """Async context manager: connect on entry, always close on exit.

        Yields the pool itself, so ``async with pool.session() as p`` gives
        ``p is pool``.
        """
        try:
            await self.connect()
            yield self
        finally:
            await self.close()

    async def __aenter__(self) -> "GrpoPool":
        """Connect the pool and return it for use in ``async with``."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the pool; exceptions from the ``async with`` body propagate."""
        await self.close()