Aaron Brown
Cleanup: fix bugs, remove dead code, add missing packages
3ea4118
"""Rollout function for TRL GRPOTrainer integration.
DEFERRED -- this is a stub. The environment must work first; anyone can
plug in TRL/Unsloth/SkyRL later via this rollout_func.
Usage with GRPOTrainer::
from open_range.training.rollout import rollout_func
trainer = GRPOTrainer(
...,
rollout_func=rollout_func,
)
"""
from __future__ import annotations
from typing import Any, Protocol
class AgentCallable(Protocol):
    """Structural type describing what a rollout agent must look like.

    Any callable object that accepts a single observation and returns an
    action satisfies this protocol; because it is a ``typing.Protocol``,
    implementations do not need to subclass it.
    """

    def __call__(self, observation: Any) -> Any:
        """Return the action to take given the current ``observation``."""
def rollout_func(
    env: Any,
    agent: AgentCallable,
    num_steps: int = 100,
    mode: str = "red",
) -> dict[str, Any]:
    """Run a single episode rollout.

    The environment is reset once. Then, up to ``num_steps`` times, the agent
    is asked for an action and the environment is stepped with it. The loop
    stops early as soon as a step's observation reports ``done``.

    Args:
        env: An OpenRange environment (RangeEnvironment or EnvClient). It must
            expose ``reset()`` and ``step(action)``. The observation returned by
            either call may carry ``reward`` and ``done`` attributes; missing
            or ``None`` values are treated as ``0.0`` and ``False``.
        agent: Callable that takes an observation and returns an action.
        num_steps: Maximum steps per episode. Values ``<= 0`` take no steps.
        mode: Agent mode ("red" or "blue"). It is written onto the action only
            when the action already has a ``mode`` attribute.

    Returns:
        Dictionary with keys:

        - ``trajectory``: list of per-step dicts with keys ``step``,
          ``action``, ``observation``, ``reward`` and ``done``.
        - ``total_reward``: sum of the per-step rewards.
        - ``steps``: number of steps actually taken.
        - ``done``: whether the most recent observation reported episode
          termination. When no steps were taken, this is the reset
          observation's flag.
    """
    obs = env.reset()
    # Track termination in one place so the per-step record, the early-exit
    # check and the summary always agree, and are always real booleans.
    done = bool(getattr(obs, "done", False))
    trajectory: list[dict[str, Any]] = []
    total_reward = 0.0

    for step in range(num_steps):
        action = agent(obs)
        # Overwrite the action's mode with the requested one, but only for
        # action objects that already define a ``mode`` attribute.
        if hasattr(action, "mode"):
            action.mode = mode

        obs = env.step(action)
        # A missing or None reward counts as zero.
        reward = getattr(obs, "reward", 0.0) or 0.0
        done = bool(getattr(obs, "done", False))
        total_reward += reward
        trajectory.append(
            {
                "step": step,
                "action": action,
                "observation": obs,
                "reward": reward,
                "done": done,
            }
        )
        if done:
            break

    return {
        "trajectory": trajectory,
        "total_reward": total_reward,
        "steps": len(trajectory),
        "done": done,
    }
def rollout_func_sync(
    env: Any,
    agent: AgentCallable,
    num_steps: int = 100,
    mode: str = "red",
) -> dict[str, Any]:
    """Backward-compatible alias for ``rollout_func``.

    This was originally a synchronous wrapper. It now simply forwards every
    argument unchanged to ``rollout_func`` and returns that function's
    episode summary dict. It is retained so callers importing this name keep
    working.
    """
    summary = rollout_func(env=env, agent=agent, num_steps=num_steps, mode=mode)
    return summary