File size: 4,880 Bytes
e56d042
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""GRPO rollout pool helper — designed to run from a Google Colab notebook.

Opens N persistent WebSocket sessions against a single server deployed with
AWS_RL_ENV_POOL_SIZE=N. All rollouts in a group share the same task (picked by
one central Curriculum) and run concurrently via asyncio.gather.

Usage (Colab cell; assumes ``AwsRlAction`` and a ``policy`` callable are already in scope):
    from scripts.grpo_pool import GrpoPool

    async def rollout(env, task):
        res = await env.reset(task=task)
        done = False
        total = 0.0
        while not done:
            action = AwsRlAction(command=policy(res.observation))
            res = await env.step(action)
            total += res.reward
            done = res.done
        return total

    async with GrpoPool(base_url="https://tunnel.example.com", size=8) as pool:
        for _ in range(num_grpo_steps):
            task = pool.curriculum.next_task()
            rewards = await pool.run_group(lambda e: rollout(e, task))
            pool.record_group_result(task, rewards)
"""

from __future__ import annotations

import asyncio
import logging
from contextlib import asynccontextmanager
from typing import (
    Any,
    AsyncIterator,
    Awaitable,
    Callable,
    Iterable,
    List,
    Optional,
    Sequence,
)

from client import AwsRlEnv
from models import Task
from server.services.curriculum import Curriculum

logger = logging.getLogger(__name__)


class GrpoPool:
    """Manages N AwsRlEnv clients against a pooled server for GRPO rollouts.

    Group operations (connect, reset, rollout) fan out over every client
    concurrently. They are all-or-nothing: if one member fails, the siblings
    that are still running are cancelled and awaited before the error
    propagates. No orphaned coroutine is left driving an env in the
    background.

    Attributes:
        base_url: Server URL every AwsRlEnv client is created with.
        size: Number of clients opened by ``connect()``.
        curriculum: Central curriculum that receives group-level results.
        envs: Connected clients. Empty until ``connect()`` succeeds and again
            after ``close()``.
    """

    def __init__(
        self,
        base_url: str,
        size: int = 8,
        curriculum: Optional[Curriculum] = None,
    ) -> None:
        """Configure the pool. No connection is opened until ``connect()``.

        Args:
            base_url: Server URL passed to every AwsRlEnv client.
            size: Number of concurrent sessions; must be >= 1.
            curriculum: Curriculum to use; a new ``Curriculum()`` is created
                when omitted.

        Raises:
            ValueError: If ``size`` is less than 1.
        """
        if size < 1:
            raise ValueError("size must be >= 1")
        self.base_url = base_url
        self.size = size
        # `is None` rather than `or`: a caller-supplied curriculum must never
        # be replaced just because it happens to evaluate as falsy.
        self.curriculum = curriculum if curriculum is not None else Curriculum()
        self.envs: List[AwsRlEnv] = []

    @staticmethod
    async def _gather_or_cancel(aws: Iterable[Awaitable[Any]]) -> List[Any]:
        """Await all of ``aws`` concurrently, cancelling the rest on any failure.

        A bare ``asyncio.gather`` leaves sibling awaitables running when one
        of them raises. Here, every scheduled task is finished (completed or
        cancelled) by the time this coroutine returns or raises. Results are
        returned in input order.
        """
        tasks: List[asyncio.Future] = []
        try:
            for aw in aws:
                tasks.append(asyncio.ensure_future(aw))
            return list(await asyncio.gather(*tasks))
        except BaseException:
            # BaseException so that cancellation of the caller also tears the
            # whole group down instead of leaving tasks running.
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            raise

    @staticmethod
    async def _close_all(envs: Sequence[AwsRlEnv]) -> None:
        """Close ``envs`` concurrently. Failures are logged, never raised."""
        results = await asyncio.gather(
            *(env.close() for env in envs), return_exceptions=True
        )
        for idx, result in enumerate(results):
            if isinstance(result, BaseException):
                logger.warning("GrpoPool: failed to close session %d: %r", idx, result)

    def _require_connected(self) -> List[AwsRlEnv]:
        """Return the connected envs, raising if ``connect()`` has not succeeded."""
        if not self.envs:
            raise RuntimeError("GrpoPool is not connected; call connect() first")
        return self.envs

    async def connect(self) -> None:
        """Open N persistent WebSocket sessions. Each binds to its own MiniStack.

        All-or-nothing: if any single session fails to connect, the other
        in-flight connects are cancelled. Every session is then closed before
        the original error is re-raised, so the server's pool does not leak
        slots and callers never see a half-initialised pool. Calling this on
        an already-connected pool is a no-op.
        """
        if self.envs:
            return
        envs = [AwsRlEnv(base_url=self.base_url) for _ in range(self.size)]
        try:
            # Cancel-on-failure: with a bare gather, a sibling connect() could
            # finish *after* the rollback below and leave a session open.
            await self._gather_or_cancel(env.connect() for env in envs)
        except BaseException:
            await self._close_all(envs)
            raise
        # Only publish the pool after the entire group connected successfully.
        self.envs = envs
        logger.info(
            "GrpoPool connected: %d sessions against %s", self.size, self.base_url
        )

    async def close(self) -> None:
        """Close all WebSocket sessions. Server releases MiniStacks back to pool.

        Best-effort: close failures are logged rather than raised. A no-op
        when the pool is not connected.
        """
        if not self.envs:
            return
        await self._close_all(self.envs)
        self.envs = []

    async def reset_group(self, task: Task) -> None:
        """Reset all N envs onto the same task. Runs concurrently.

        The full Task is serialised to the server, so envs do not have to
        look the task up through their own curriculum. If any reset fails,
        the remaining resets are cancelled before the error propagates.

        Raises:
            RuntimeError: If the pool is not connected.
        """
        envs = self._require_connected()
        await self._gather_or_cancel(env.reset(task=task) for env in envs)

    async def run_group(
        self,
        rollout_fn: Callable[[AwsRlEnv], Awaitable[float]],
    ) -> List[float]:
        """Run `rollout_fn` on each of the N envs concurrently, return rewards.

        Rewards are returned in env order. The caller is responsible for
        calling reset_group() beforehand, or for doing the reset inside
        rollout_fn with the same task. If any rollout raises, the other
        rollouts are cancelled and awaited before the error propagates. No env
        is left being driven in the background into the next group.

        Raises:
            RuntimeError: If the pool is not connected. Otherwise the empty
                reward list would be recorded as a failed group by
                record_group_result().
        """
        envs = self._require_connected()
        return await self._gather_or_cancel(rollout_fn(env) for env in envs)

    def record_group_result(
        self,
        task: Task,
        rewards: Sequence[float],
        success_threshold: float = 0.99,
    ) -> None:
        """Feed one group-level result back to the central curriculum.

        A group is considered "achieved" if at least one rollout scored at or
        above ``success_threshold``. The recorded reward is the group mean.
        An empty ``rewards`` sequence is recorded as not achieved with
        reward 0.0.

        Args:
            task: The task every rollout in the group ran on.
            rewards: Per-rollout rewards, e.g. as returned by ``run_group``.
            success_threshold: Minimum reward for a rollout to count as a
                success.
        """
        achieved = any(r >= success_threshold for r in rewards)
        mean_reward = sum(rewards) / len(rewards) if rewards else 0.0
        self.curriculum.record_result(task, achieved=achieved, reward=mean_reward)

    @asynccontextmanager
    async def session(self) -> AsyncIterator["GrpoPool"]:
        """Connect on entry and always close on exit, yielding the pool.

        Usage: ``async with pool.session() as p: ...``. This mirrors
        ``__aenter__``/``__aexit__``.
        """
        try:
            await self.connect()
            yield self
        finally:
            await self.close()

    async def __aenter__(self) -> "GrpoPool":
        """Connect all sessions and return the pool."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close all sessions. Exceptions from the body are not suppressed."""
        await self.close()