Spaces:
Runtime error
Runtime error
File size: 2,241 Bytes
"""Rollout function for TRL GRPOTrainer integration.
DEFERRED -- this is a stub. The environment must work first; anyone can
plug in TRL/Unsloth/SkyRL later via this rollout_func.
Usage with GRPOTrainer::
from open_range.training.rollout import rollout_func
trainer = GRPOTrainer(
...,
rollout_func=rollout_func,
)
"""
from __future__ import annotations
from typing import Any, Protocol
class AgentCallable(Protocol):
    """Structural type for a rollout agent.

    Any object that can be called with a single observation and hands back
    an action satisfies this protocol; no inheritance is required.
    """

    def __call__(self, observation: Any) -> Any:
        ...
def rollout_func(
    env: Any,
    agent: AgentCallable,
    num_steps: int = 100,
    mode: str = "red",
) -> dict[str, Any]:
    """Play one episode of ``env`` with ``agent`` and summarise it.

    The environment is reset once, then the agent and the environment
    alternate for at most ``num_steps`` steps. The loop stops early as soon
    as a step's observation reports a truthy ``done`` attribute.

    Args:
        env: Object exposing ``reset()`` and ``step(action)``; both are
            expected to return an observation (per the module's intent, an
            OpenRange RangeEnvironment or EnvClient).
        agent: Callable invoked with the latest observation; its return
            value is passed to ``env.step`` as the action.
        num_steps: Upper bound on the number of steps taken.
        mode: Value written to ``action.mode`` whenever the action already
            has a ``mode`` attribute ("red" or "blue").

    Returns:
        A dictionary with four keys:

        * ``"trajectory"`` -- list of per-step records, each holding
          ``step``, ``action``, ``observation``, ``reward`` and ``done``.
        * ``"total_reward"`` -- sum of the per-step rewards.
        * ``"steps"`` -- number of records in the trajectory.
        * ``"done"`` -- ``done`` attribute of the last observation seen
          (``False`` when the observation has no such attribute).
    """
    current_obs = env.reset()
    history: list[dict[str, Any]] = []
    reward_sum = 0.0

    for step_index in range(num_steps):
        chosen_action = agent(current_obs)

        # Only actions that already carry a ``mode`` attribute get it
        # overwritten; anything else is forwarded to the env untouched.
        if hasattr(chosen_action, "mode"):
            chosen_action.mode = mode

        current_obs = env.step(chosen_action)

        # A missing or falsy (e.g. ``None``) reward counts as 0.0.
        step_reward = getattr(current_obs, "reward", 0.0) or 0.0
        reward_sum += step_reward

        record = {
            "step": step_index,
            "action": chosen_action,
            "observation": current_obs,
            "reward": step_reward,
            "done": getattr(current_obs, "done", False),
        }
        history.append(record)

        if getattr(current_obs, "done", False):
            break

    summary = {
        "trajectory": history,
        "total_reward": reward_sum,
        "steps": len(history),
        "done": getattr(current_obs, "done", False),
    }
    return summary
def rollout_func_sync(
    env: Any,
    agent: AgentCallable,
    num_steps: int = 100,
    mode: str = "red",
) -> dict[str, Any]:
    """Backward-compatible alias for :func:`rollout_func`.

    This wrapper adds nothing of its own: it forwards all four arguments to
    ``rollout_func`` and returns that call's result unchanged. It is kept so
    that existing imports of this name keep working.
    """
    return rollout_func(env=env, agent=agent, num_steps=num_steps, mode=mode)
|