File size: 2,241 Bytes
8c486a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ea4118
8c486a8
 
 
 
 
 
 
 
3ea4118
8c486a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ea4118
8c486a8
3ea4118
8c486a8
3ea4118
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""Rollout function for TRL GRPOTrainer integration.

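DEFERRED -- this is a stub. The environment must work first; anyone can
plug in TRL/Unsloth/SkyRL later via this rollout_func.

NOTE: ``rollout_func`` below is called as ``rollout_func(env, agent, ...)``
and returns an episode summary dict. Confirm that this matches the callback
signature your TRL version expects before using the snippet below as-is.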

Usage with GRPOTrainer::

    from open_range.training.rollout import rollout_func
    trainer = GRPOTrainer(
        ...,
        rollout_func=rollout_func,
    )
"""

from __future__ import annotations

from typing import Any, Protocol


class AgentCallable(Protocol):
    """Structural type for rollout agents.

    Any object that can be called with a single observation and hands back
    an action satisfies this protocol; no inheritance is required.
    """

    def __call__(self, observation: Any) -> Any:
        """Return the action to take for ``observation``."""
        ...


def rollout_func(
    env: Any,
    agent: AgentCallable,
    num_steps: int = 100,
    mode: str = "red",
) -> dict[str, Any]:
    """Run a single episode rollout.

    The environment is reset once, then the agent and the environment are
    alternated until the observation reports ``done`` or ``num_steps`` steps
    have been taken. If the observation returned by ``reset()`` is already
    marked ``done``, no step is taken and the trajectory is empty.

    Args:
        env: Object exposing ``reset()`` and ``step(action)``; both return an
            observation. ``reward`` and ``done`` are read from the
            observation via ``getattr`` and default to ``0.0`` / ``False``
            when absent.
        agent: Callable that takes an observation and returns an action.
        num_steps: Maximum steps per episode.
        mode: Value written to ``action.mode`` on every action that has a
            ``mode`` attribute (e.g. "red" or "blue").

    Returns:
        Dictionary with the episode summary:

        * ``trajectory``: list of per-step dicts with keys ``step``,
          ``action``, ``observation``, ``reward`` and ``done``.
        * ``total_reward``: sum of the per-step rewards.
        * ``steps``: number of steps actually taken.
        * ``done``: ``done`` flag of the last observation seen.
    """
    obs = env.reset()
    trajectory: list[dict[str, Any]] = []
    total_reward = 0.0
    done = getattr(obs, "done", False)

    for step in range(num_steps):
        # Checked at the top so that an episode which is already terminal
        # right after reset() is never stepped.
        if done:
            break

        action = agent(obs)

        # Force the requested mode onto actions that carry one.
        if hasattr(action, "mode"):
            action.mode = mode

        obs = env.step(action)

        # A missing or None reward counts as zero.
        reward = getattr(obs, "reward", 0.0) or 0.0
        total_reward += reward

        # Read the flag once so the recorded value and the loop exit agree.
        done = getattr(obs, "done", False)

        trajectory.append({
            "step": step,
            "action": action,
            "observation": obs,
            "reward": reward,
            "done": done,
        })

    return {
        "trajectory": trajectory,
        "total_reward": total_reward,
        "steps": len(trajectory),
        "done": done,
    }


def rollout_func_sync(
    env: Any,
    agent: AgentCallable,
    num_steps: int = 100,
    mode: str = "red",
) -> dict[str, Any]:
    """Backward-compatible alias for :func:`rollout_func`.

    Forwards every argument unchanged to ``rollout_func`` and returns its
    result. The name is retained only so existing imports keep working.
    """
    summary = rollout_func(
        env=env,
        agent=agent,
        num_steps=num_steps,
        mode=mode,
    )
    return summary