File size: 3,678 Bytes
fc2f931
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""Minimal Open Reward Standard environment for testing TRL × OpenReward.

Deployed at:
    https://huggingface.co/spaces/trl-internal-testing/openreward-echo-env

The model is given a target string and must call ``echo(text=...)`` with
exactly that string. Reward is 1.0 on match, 0.0 otherwise; the episode
finishes on a correct echo.

Pure Python — no sandbox, no external state — so the response is
deterministic and the env can run thousands of concurrent sessions on
free-tier hardware. Useful as:

  1. The fixture target for `tests/experimental/test_openreward.py` in TRL.
  2. A reference implementation for "what does a minimal ORS env look like?"

Run locally:

    pip install fastapi uvicorn openreward
    python server.py            # listens on :8080 (override with the PORT env var)

Then:

    from trl.experimental.openreward import OpenRewardSpec
    spec = OpenRewardSpec("http://localhost:8080")
    print(spec.train_dataset)
"""

from typing import Optional

from pydantic import BaseModel

from openreward.environments import (
    Environment,
    JSONObject,
    Server,
    TextBlock,
    ToolOutput,
    tool,
)


# ── Tasks ────────────────────────────────────────────────────────────

# Every task asks the model to echo one fixed target string back. The set is
# deliberately tiny and deterministic — it is a CI fixture, not a benchmark.
# Task ids are "echo-<position>", following the order of the targets below.
TRAIN_TASKS: list[JSONObject] = [
    {"id": f"echo-{position}", "target": word}
    for position, word in enumerate(
        ("hello", "world", "trl", "openreward", "spec", "factory", "dataset", "reward")
    )
]


# Typed view of one task spec dict (e.g. an entry of TRAIN_TASKS);
# EchoEnvironment.__init__ builds it with ``model_validate``.
class EchoTaskSpec(BaseModel):
    id: str  # task identifier, e.g. "echo-0"
    target: str  # exact string the model must submit to the `echo` tool


# Arguments accepted by the `echo` tool (see EchoEnvironment.echo).
class EchoParams(BaseModel):
    text: str  # submitted string; compared with ``==`` against the task target


class EchoEnvironment(Environment):
    """A tiny ORS env: echo the target string to win.

    The single ``echo`` tool returns reward=1.0 with finished=True iff the
    submitted ``text`` exactly matches the task's target, else reward=0.0
    with finished=False so the model can keep trying.
    """

    def __init__(
        self,
        task_spec: Optional[JSONObject] = None,
        secrets: Optional[dict[str, str]] = None,
    ):
        """Validate ``task_spec`` into an ``EchoTaskSpec`` stored on ``self.config``.

        Args:
            task_spec: Task dict with ``id`` and ``target`` string fields (e.g.
                one of ``TRAIN_TASKS``). ``None`` is treated as an empty dict,
                which fails validation.
            secrets: Accepted for interface compatibility; not used by this env.

        Raises:
            pydantic.ValidationError: If ``task_spec`` is missing ``id`` or
                ``target``, or they are not valid strings.
        """
        # ``None`` sentinel instead of a shared mutable ``{}`` default.
        if task_spec is None:
            task_spec = {}
        super().__init__(task_spec)
        self.config = EchoTaskSpec.model_validate(task_spec)

    @classmethod
    def list_splits(cls) -> list[str]:
        """Return the available splits; only ``"train"`` exists."""
        return ["train"]

    @classmethod
    def list_tasks(cls, split: str) -> list[JSONObject]:
        """Return the task specs for ``split``.

        The result holds fresh copies of the ``TRAIN_TASKS`` dicts, so a caller
        that mutates it cannot corrupt the shared module-level fixture.

        Raises:
            ValueError: If ``split`` is not ``"train"``.
        """
        if split != "train":
            raise ValueError(f"unknown split: {split}")
        return [dict(task) for task in TRAIN_TASKS]

    def get_prompt(self) -> list[TextBlock]:
        """Build the single text block telling the model which string to echo."""
        target = self.config.target
        return [
            TextBlock(
                type="text",
                text=(
                    # ``!r`` quotes and escapes the target, so it stays unambiguous
                    # even when it contains quote characters.
                    f"Call the `echo` tool with text={target!r} to win. "
                    "You get reward=1.0 on an exact match and the episode finishes."
                ),
            )
        ]

    @tool
    async def echo(self, params: EchoParams) -> ToolOutput:
        """Submit a string. Reward 1.0 + finished if it matches the target.

        Args:
            text: The string to echo back.
        """
        target = self.config.target
        # Exact, case-sensitive comparison; no normalisation is applied.
        correct = params.text == target
        feedback = (
            "match"
            if correct
            else f"no match (got {params.text!r}, expected {target!r})"
        )
        return ToolOutput(
            blocks=[TextBlock(type="text", text=feedback)],
            reward=1.0 if correct else 0.0,
            # Only a correct echo ends the episode; otherwise the model may retry.
            finished=correct,
        )


if __name__ == "__main__":
    import os

    # The PORT environment variable overrides the default listen port of 8080.
    listen_port = int(os.environ.get("PORT", "8080"))
    # Serve the echo environment on all interfaces.
    server = Server([EchoEnvironment])
    server.run(host="0.0.0.0", port=listen_port)