"""Rollout function for TRL GRPOTrainer integration. DEFERRED -- this is a stub. The environment must work first; anyone can plug in TRL/Unsloth/SkyRL later via this rollout_func. Usage with GRPOTrainer:: from open_range.training.rollout import rollout_func trainer = GRPOTrainer( ..., rollout_func=rollout_func, ) """ from __future__ import annotations from typing import Any, Protocol class AgentCallable(Protocol): """Minimal agent interface for rollout.""" def __call__(self, observation: Any) -> Any: ... def rollout_func( env: Any, agent: AgentCallable, num_steps: int = 100, mode: str = "red", ) -> dict[str, Any]: """Run a single episode rollout. Args: env: An OpenRange environment (RangeEnvironment or EnvClient). agent: Callable that takes an observation and returns an action. num_steps: Maximum steps per episode. mode: Agent mode ("red" or "blue"). Returns: Dictionary with episode summary: observations, actions, rewards, total_reward, steps, done. """ obs = env.reset() trajectory: list[dict[str, Any]] = [] total_reward = 0.0 for step in range(num_steps): action = agent(obs) # Ensure mode is set if hasattr(action, "mode"): action.mode = mode obs = env.step(action) reward = getattr(obs, "reward", 0.0) or 0.0 total_reward += reward trajectory.append({ "step": step, "action": action, "observation": obs, "reward": reward, "done": getattr(obs, "done", False), }) if getattr(obs, "done", False): break return { "trajectory": trajectory, "total_reward": total_reward, "steps": len(trajectory), "done": getattr(obs, "done", False), } def rollout_func_sync( env: Any, agent: AgentCallable, num_steps: int = 100, mode: str = "red", ) -> dict[str, Any]: """Synchronous wrapper — now just delegates to rollout_func directly. Kept for backward compatibility with callers that import this name. """ return rollout_func(env, agent, num_steps, mode)