# openreward-echo-env / server.py
# Author: AdithyaSK (HF Staff) — commit fc2f931 (verified)
# Initial echo env (TRL × OpenReward test fixture)
"""Minimal Open Reward Standard environment for testing TRL Γ— OpenReward.
Deployed at:
https://huggingface.co/spaces/trl-internal-testing/openreward-echo-env
The model is given a target string and must call ``echo(text=...)`` with
exactly that string. Reward is 1.0 on match, 0.0 otherwise; the episode
finishes on a correct echo.
Pure Python — no sandbox, no external state — so the response is
deterministic and the env can run thousands of concurrent sessions on
free-tier hardware. Useful as:
1. The fixture target for `tests/experimental/test_openreward.py` in TRL.
2. A reference implementation for "what does a minimal ORS env look like?"
Run locally:
pip install fastapi uvicorn openreward
    python server.py          # listens on :8080 (or $PORT if set)
Then:
from trl.experimental.openreward import OpenRewardSpec
    spec = OpenRewardSpec("http://localhost:8080")
print(spec.train_dataset)
"""
from pydantic import BaseModel
from openreward.environments import (
Environment,
JSONObject,
Server,
TextBlock,
ToolOutput,
tool,
)
# ── Tasks ────────────────────────────────────────────────────────────
# Each task hands the model one target string to echo back. This is a CI
# fixture, not a benchmark — keep the list short and deterministic.
_TARGET_WORDS = ["hello", "world", "trl", "openreward", "spec", "factory", "dataset", "reward"]
TRAIN_TASKS: list[JSONObject] = [
    {"id": f"echo-{index}", "target": word}
    for index, word in enumerate(_TARGET_WORDS)
]
class EchoTaskSpec(BaseModel):
    """Validated shape of one task dict from ``TRAIN_TASKS``."""

    # Stable task identifier, e.g. "echo-0".
    id: str
    # The exact string the model must echo to earn reward 1.0.
    target: str
class EchoParams(BaseModel):
    """Arguments accepted by the ``echo`` tool."""

    # The string the model submits; compared exactly against the task target.
    text: str
class EchoEnvironment(Environment):
    """A tiny ORS env: echo the target string to win.

    The single ``echo`` tool returns reward=1.0 with finished=True iff the
    submitted ``text`` exactly matches the task's target, else reward=0.0
    with finished=False so the model can keep trying.
    """

    def __init__(self, task_spec: JSONObject | None = None, secrets: dict[str, str] | None = None):
        # ``None`` sentinels instead of ``{}`` defaults: a mutable default
        # would be one shared dict across every instantiation.
        task_spec = {} if task_spec is None else task_spec
        super().__init__(task_spec)
        # Validate eagerly so malformed tasks (missing id/target) fail fast.
        self.config = EchoTaskSpec.model_validate(task_spec)
        # ``secrets`` is accepted for interface compatibility but unused —
        # this env needs no credentials.

    @classmethod
    def list_splits(cls) -> list[str]:
        """Return the available splits — only ``train`` for this fixture."""
        return ["train"]

    @classmethod
    def list_tasks(cls, split: str) -> list[JSONObject]:
        """Return the task dicts for *split*.

        Raises:
            ValueError: If *split* is anything other than ``"train"``.
        """
        if split != "train":
            raise ValueError(f"unknown split: {split}")
        # Copy each task so callers cannot mutate the module-level fixture.
        return [dict(task) for task in TRAIN_TASKS]

    def get_prompt(self) -> list[TextBlock]:
        """Return the initial prompt telling the model exactly what to echo."""
        return [TextBlock(
            type="text",
            text=(
                f"Call the `echo` tool with text='{self.config.target}' to win. "
                f"You get reward=1.0 on an exact match and the episode finishes."
            ),
        )]

    @tool
    async def echo(self, params: EchoParams) -> ToolOutput:
        """Submit a string. Reward 1.0 + finished if it matches the target.

        Args:
            text: The string to echo back.
        """
        # Exact string equality — no stripping or case folding, so the
        # episode stays deterministic.
        correct = params.text == self.config.target
        return ToolOutput(
            blocks=[TextBlock(
                type="text",
                text="match" if correct else f"no match (got {params.text!r}, expected {self.config.target!r})",
            )],
            reward=1.0 if correct else 0.0,
            # Only a correct echo ends the episode; otherwise keep trying.
            finished=correct,
        )
if __name__ == "__main__":
    import os

    # Honor the platform-assigned port (e.g. set by the hosting Space),
    # defaulting to 8080 for local runs.
    listen_port = int(os.environ.get("PORT", "8080"))
    server = Server([EchoEnvironment])
    server.run(host="0.0.0.0", port=listen_port)