| """Minimal Open Reward Standard environment for testing TRL Γ OpenReward. |
| |
| Deployed at: |
| https://huggingface.co/spaces/trl-internal-testing/openreward-echo-env |
| |
| The model is given a target string and must call ``echo(text=...)`` with |
| exactly that string. Reward is 1.0 on match, 0.0 otherwise; the episode |
| finishes on a correct echo. |
| |
Pure Python — no sandbox, no external state — so the response is
| deterministic and the env can run thousands of concurrent sessions on |
| free-tier hardware. Useful as: |
| |
| 1. The fixture target for `tests/experimental/test_openreward.py` in TRL. |
| 2. A reference implementation for "what does a minimal ORS env look like?" |
| |
Run locally:

    pip install fastapi uvicorn openreward
    python server.py               # listens on :8080 (or $PORT if set)

Then:

    from trl.experimental.openreward import OpenRewardSpec
    spec = OpenRewardSpec("http://localhost:8080")
    print(spec.train_dataset)
| """ |
|
|
| from pydantic import BaseModel |
|
|
| from openreward.environments import ( |
| Environment, |
| JSONObject, |
| Server, |
| TextBlock, |
| ToolOutput, |
| tool, |
| ) |
|
|
|
|
| |
|
|
| |
| |
# The single "train" split: eight fixed (id, target) pairs. Kept tiny and
# static so episodes are deterministic and the task list needs no I/O.
TRAIN_TASKS: list[JSONObject] = [
    {"id": "echo-0", "target": "hello"},
    {"id": "echo-1", "target": "world"},
    {"id": "echo-2", "target": "trl"},
    {"id": "echo-3", "target": "openreward"},
    {"id": "echo-4", "target": "spec"},
    {"id": "echo-5", "target": "factory"},
    {"id": "echo-6", "target": "dataset"},
    {"id": "echo-7", "target": "reward"},
]
|
|
|
|
class EchoTaskSpec(BaseModel):
    """Validated shape of one ``TRAIN_TASKS`` entry."""

    # Unique task identifier, e.g. "echo-0".
    id: str
    # The exact string the model must echo to earn reward 1.0.
    target: str
|
|
|
|
class EchoParams(BaseModel):
    """Arguments accepted by the ``echo`` tool."""

    # The string submitted by the model; compared verbatim to the task target.
    text: str
|
|
|
|
class EchoEnvironment(Environment):
    """A tiny ORS env: echo the target string to win.

    The single ``echo`` tool returns reward=1.0 with finished=True iff the
    submitted ``text`` exactly matches the task's target, else reward=0.0
    with finished=False so the model can keep trying.
    """

    def __init__(self, task_spec: JSONObject | None = None, secrets: dict[str, str] | None = None):
        """Create an episode for one task.

        Args:
            task_spec: One entry of ``TRAIN_TASKS`` (``{"id": ..., "target": ...}``).
                An omitted/empty spec fails validation below, as before —
                callers are expected to pass a real task.
            secrets: Accepted for ORS interface compatibility; unused here.
        """
        # Use None sentinels instead of mutable `{}` defaults, which would be
        # shared across all calls (classic Python pitfall).
        if task_spec is None:
            task_spec = {}
        super().__init__(task_spec)
        # Validate eagerly so a malformed task fails at construction time.
        self.config = EchoTaskSpec.model_validate(task_spec)

    @classmethod
    def list_splits(cls) -> list[str]:
        """This env ships a single ``train`` split."""
        return ["train"]

    @classmethod
    def list_tasks(cls, split: str) -> list[JSONObject]:
        """Return the task specs for *split*.

        Raises:
            ValueError: If *split* is anything other than ``"train"``.
        """
        if split != "train":
            raise ValueError(f"unknown split: {split}")
        # Return a copy so callers cannot mutate the module-level task list.
        return list(TRAIN_TASKS)

    def get_prompt(self) -> list[TextBlock]:
        """Build the initial prompt telling the model what to echo."""
        return [TextBlock(
            type="text",
            text=(
                f"Call the `echo` tool with text='{self.config.target}' to win. "
                f"You get reward=1.0 on an exact match and the episode finishes."
            ),
        )]

    @tool
    async def echo(self, params: EchoParams) -> ToolOutput:
        """Submit a string. Reward 1.0 + finished if it matches the target.

        Args:
            text: The string to echo back.
        """
        correct = params.text == self.config.target
        return ToolOutput(
            blocks=[TextBlock(
                type="text",
                text="match" if correct else f"no match (got {params.text!r}, expected {self.config.target!r})",
            )],
            reward=1.0 if correct else 0.0,
            finished=correct,
        )
|
|
|
|
| if __name__ == "__main__": |
| import os |
| port = int(os.environ.get("PORT", "8080")) |
| Server([EchoEnvironment]).run(host="0.0.0.0", port=port) |
|
|