Buckets:
| <meta charset="utf-8" /><meta name="hf:doc:metadata" content="{"title":"OpenEnv Integration for Training LLMs with Environments","local":"openenv-integration-for-training-llms-with-environments","sections":[{"title":"Overview","local":"overview","sections":[],"depth":2},{"title":"Installation","local":"installation","sections":[],"depth":2},{"title":"Using rollout_func with OpenEnv environments","local":"using-rolloutfunc-with-openenv-environments","sections":[{"title":"Rollout Function Signature","local":"rollout-function-signature","sections":[],"depth":3},{"title":"Integration pattern","local":"integration-pattern","sections":[],"depth":3}],"depth":2},{"title":"A simple example","local":"a-simple-example","sections":[{"title":"Running the Example","local":"running-the-example","sections":[],"depth":3}],"depth":2},{"title":"Advanced Example","local":"advanced-example","sections":[{"title":"The TextArena Environment","local":"the-textarena-environment","sections":[],"depth":3},{"title":"Wordle","local":"wordle","sections":[],"depth":3},{"title":"Rollout Function","local":"rollout-function","sections":[],"depth":3},{"title":"Reward Functions","local":"reward-functions","sections":[],"depth":3},{"title":"Training the Model","local":"training-the-model","sections":[],"depth":3},{"title":"Running the Example","local":"running-the-example","sections":[],"depth":3},{"title":"Results","local":"results","sections":[],"depth":3}],"depth":2}],"depth":1}"> | |
| <link href="/docs/trl/pr_4331/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload"> | |
| <link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/entry/start.6bbbc54b.js"> | |
| <link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/chunks/scheduler.7b731bd4.js"> | |
| <link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/chunks/singletons.55eb59f9.js"> | |
| <link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/chunks/index.ac28c20f.js"> | |
| <link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/chunks/paths.677b038d.js"> | |
| <link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/entry/app.b003256e.js"> | |
| <link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/chunks/preload-helper.71df5523.js"> | |
| <link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/chunks/index.cc268345.js"> | |
| <link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/nodes/0.c996cd3a.js"> | |
| <link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/nodes/34.d235322e.js"> | |
| <link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.d403d039.js"> | |
| <link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/chunks/CodeBlock.17bc4142.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"OpenEnv Integration for Training LLMs with Environments","local":"openenv-integration-for-training-llms-with-environments","sections":[{"title":"Overview","local":"overview","sections":[],"depth":2},{"title":"Installation","local":"installation","sections":[],"depth":2},{"title":"Using rollout_func with OpenEnv environments","local":"using-rolloutfunc-with-openenv-environments","sections":[{"title":"Rollout Function Signature","local":"rollout-function-signature","sections":[],"depth":3},{"title":"Integration pattern","local":"integration-pattern","sections":[],"depth":3}],"depth":2},{"title":"A simple example","local":"a-simple-example","sections":[{"title":"Running the Example","local":"running-the-example","sections":[],"depth":3}],"depth":2},{"title":"Advanced Example","local":"advanced-example","sections":[{"title":"The TextArena Environment","local":"the-textarena-environment","sections":[],"depth":3},{"title":"Wordle","local":"wordle","sections":[],"depth":3},{"title":"Rollout Function","local":"rollout-function","sections":[],"depth":3},{"title":"Reward Functions","local":"reward-functions","sections":[],"depth":3},{"title":"Training the Model","local":"training-the-model","sections":[],"depth":3},{"title":"Running the Example","local":"running-the-example","sections":[],"depth":3},{"title":"Results","local":"results","sections":[],"depth":3}],"depth":2}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 max-sm:gap-0.5 h-6 max-sm:h-5 px-2 max-sm:px-1.5 text-[11px] max-sm:text-[9px] font-medium 
text-gray-800 border border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0"><svg class="w-3 h-3 max-sm:w-2.5 max-sm:h-2.5" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-6 max-sm:h-5 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible w-3 h-3 max-sm:w-2.5 max-sm:h-2.5 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="openenv-integration-for-training-llms-with-environments" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#openenv-integration-for-training-llms-with-environments"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" 
width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>OpenEnv Integration for Training LLMs with Environments</span></h1> <h2 class="relative group"><a id="overview" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#overview"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Overview</span></h2> <p data-svelte-h="svelte-mz6ziy"><a href="https://github.com/meta-pytorch/OpenEnv" rel="nofollow">OpenEnv</a> is an open-source framework from Meta’s PyTorch team for defining, deploying, and interacting with environments in reinforcement learning (RL) and 
agentic workflows. It offers <a href="https://gymnasium.farama.org" rel="nofollow">Gymnasium-style APIs</a> (e.g., <code>reset()</code> and <code>step()</code>) to interface with environments in a standard manner, and supports running these environments as backend servers (for example via HTTP or containerised execution). You can find a collection of ready-to-use OpenEnv environments on the <a href="https://huggingface.co/collections/openenv/environment-hub" rel="nofollow">Hugging Face Hub</a>.</p> <p data-svelte-h="svelte-n1ld35">In this guide, we’ll focus on <strong>how to integrate OpenEnv with TRL</strong>, but feel free to explore the links above to dive deeper into OpenEnv itself.</p> <h2 class="relative group"><a id="installation" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#installation"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Installation</span></h2> <p data-svelte-h="svelte-jlmwtz">To use OpenEnv with TRL, install the framework:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer 
focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pip install openenv-core<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="using-rolloutfunc-with-openenv-environments" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#using-rolloutfunc-with-openenv-environments"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 
11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Using rollout_func with OpenEnv environments</span></h2> <p data-svelte-h="svelte-xv3ktu">TRL’s <a href="/docs/trl/pr_4331/en/grpo_trainer#trl.GRPOTrainer">GRPOTrainer</a> supports <em>custom rollout logic</em> through the <code>rollout_func</code> argument. This lets you override the trainer’s default text-generation loop and directly interact with OpenEnv environments — for instance, to compute environment-driven rewards instead of relying solely on model-based signals.</p> <h3 class="relative group"><a id="rollout-function-signature" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#rollout-function-signature"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Rollout Function Signature</span></h3> <p data-svelte-h="svelte-xffp71">A rollout function must have the following signature:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 
cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">rollout_func</span>(<span class="hljs-params"> | |
| prompts: <span class="hljs-built_in">list</span>[<span class="hljs-built_in">str</span>], | |
| args: GRPOConfig, | |
| processing_class | |
| </span>) -> <span class="hljs-built_in">dict</span>[<span class="hljs-built_in">str</span>, <span class="hljs-built_in">list</span>]: | |
| <span class="hljs-string">""" | |
| Custom rollout function for generation and reward computation. | |
| Args: | |
| prompts: List of prompts to generate from | |
| args: GRPOConfig containing sampling parameters (temperature, top_p, etc.) | |
| processing_class: Tokenizer/processor for encoding/decoding | |
| Returns: | |
| Dictionary containing: | |
| - prompt_ids: List of token IDs for each prompt | |
| - completion_ids: List of token IDs for each completion | |
| - logprobs: List of log probabilities for each token | |
| - Any additional fields are forwarded to reward functions as kwargs | |
| """</span> | |
| <span class="hljs-keyword">pass</span><!-- HTML_TAG_END --></pre></div> <blockquote class="note" data-svelte-h="svelte-dj7sqq"><p>Any extra fields in the returned dictionary (beyond the required three) are automatically forwarded to your reward functions. This makes it easy to propagate signals such as environment rewards or auxiliary metrics from the rollout step.</p></blockquote> <h3 class="relative group"><a id="integration-pattern" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#integration-pattern"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Integration pattern</span></h3> <p data-svelte-h="svelte-1r3xw2z">The typical pattern when combining OpenEnv with TRL looks like this:</p> <ol data-svelte-h="svelte-ksntjq"><li>Start or connect to an OpenEnv environment (e.g., an HTTP endpoint or Dockerized env).</li> <li>Generate completions from your model — for example, via a vLLM inference server (<code>use_vllm=True</code>, <code>vllm_mode="server"</code>).</li> <li>Step through the environment using each completion to compute rewards or metrics.</li> <li>Add environment results (e.g., <code>env_reward</code>) to the 
rollout result dict.</li> <li>Access those rewards inside your reward function via <code>**kwargs</code>.</li></ol> <p data-svelte-h="svelte-1o9klxq">By using OpenEnv in this loop, you can:</p> <ul data-svelte-h="svelte-1fcm4d"><li>Train with realistic or interactive feedback (not just static reward functions).</li> <li>Plug in custom simulators, web APIs, or evaluators as environments.</li> <li>Pass structured reward signals back into RL training seamlessly.</li></ul> <h2 class="relative group"><a id="a-simple-example" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#a-simple-example"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>A simple example</span></h2> <p data-svelte-h="svelte-vpa6kv">The <a href="https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/echo.py" rel="nofollow">echo.py</a> script demonstrates a minimal, end-to-end integration between TRL and OpenEnv. In this example, the Echo environment rewards completions based on their text length, encouraging the model to generate longer outputs. 
This pattern can be extended to any custom environment that provides structured feedback or task-based rewards:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> envs.echo_env <span class="hljs-keyword">import</span> EchoEnv, EchoAction | |
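| <span class="hljs-keyword">import</span> requests | |
| <span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset | |
| <span class="hljs-keyword">from</span> trl <span class="hljs-keyword">import</span> GRPOConfig, GRPOTrainer | |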
| <span class="hljs-comment"># Create HTTP client for Echo Environment</span> | |
| client = EchoEnv.from_docker_image(<span class="hljs-string">"echo-env:latest"</span>) | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">rollout_func</span>(<span class="hljs-params">prompts, args, processing_class</span>): | |
| <span class="hljs-comment"># 1. Generate completions via vLLM inference server (running on port 8000)</span> | |
| payload = { | |
| <span class="hljs-string">"prompts"</span>: prompts, | |
| <span class="hljs-string">"n"</span>: args.num_generations, | |
| <span class="hljs-string">"temperature"</span>: args.temperature, | |
| <span class="hljs-string">"max_tokens"</span>: args.max_completion_length, | |
| } | |
| response = requests.post(<span class="hljs-string">"http://0.0.0.0:8000/generate/"</span>, json=payload) | |
| result = response.json() | |
| completions_text = processing_class.batch_decode( | |
| result[<span class="hljs-string">"completion_ids"</span>], | |
| skip_special_tokens=<span class="hljs-literal">True</span> | |
| ) | |
| <span class="hljs-comment"># 2. Step through the environment to get rewards</span> | |
| client.reset() | |
| env_rewards = [] | |
| <span class="hljs-keyword">for</span> msg <span class="hljs-keyword">in</span> completions_text: | |
| env_result = client.step(EchoAction(message=msg)) | |
| env_rewards.append(env_result.reward) | |
| <span class="hljs-comment"># 3. Add environment rewards as extra field</span> | |
| result[<span class="hljs-string">"env_reward"</span>] = env_rewards | |
| <span class="hljs-keyword">return</span> result | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">reward_from_env</span>(<span class="hljs-params">completions, **kwargs</span>): | |
| <span class="hljs-string">"""Extract environment rewards passed via rollout_func kwargs."""</span> | |
| env_rewards = kwargs.get(<span class="hljs-string">"env_reward"</span>, []) | |
| <span class="hljs-keyword">return</span> [<span class="hljs-built_in">float</span>(reward) <span class="hljs-keyword">for</span> reward <span class="hljs-keyword">in</span> env_rewards] <span class="hljs-keyword">if</span> env_rewards <span class="hljs-keyword">else</span> [<span class="hljs-number">0.0</span>] * <span class="hljs-built_in">len</span>(completions) | |
| dataset = Dataset.from_dict({<span class="hljs-string">"prompt"</span>: [<span class="hljs-string">"You are an AI that interacts with an *Echo* environment. Word to echo:"</span>] * <span class="hljs-number">64</span>}) | |
| <span class="hljs-comment"># Setup trainer with custom rollout</span> | |
| trainer = GRPOTrainer( | |
| model=<span class="hljs-string">"Qwen/Qwen2.5-0.5B-Instruct"</span>, | |
| reward_funcs=reward_from_env, | |
| train_dataset=dataset, | |
| rollout_func=rollout_func, <span class="hljs-comment"># Use custom rollout</span> | |
| args=GRPOConfig( | |
| vllm_mode=<span class="hljs-string">"server"</span>, | |
| use_vllm=<span class="hljs-literal">True</span>, | |
| num_train_epochs=<span class="hljs-number">1</span>, | |
| num_generations=<span class="hljs-number">8</span>, | |
| max_completion_length=<span class="hljs-number">2048</span>, | |
| per_device_train_batch_size=<span class="hljs-number">8</span>, | |
| gradient_accumulation_steps=<span class="hljs-number">4</span>, | |
| ), | |
| ) | |
| trainer.train()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1m5nu2f">That’s it! Now that you’ve seen the full example, let’s unpack how the main pieces fit together.</p> <ol data-svelte-h="svelte-17lxckh"><li><strong>Environment Client:</strong> <code>EchoEnv</code> implements an HTTP interface to interact with the environment server.</li> <li><strong>Custom rollout:</strong> The <code>rollout_func</code> generates completions and steps through the environment to collect rewards.</li> <li><strong>Extra fields:</strong> The rollout adds <code>env_reward</code> to the result dictionary, which is automatically passed to reward functions.</li> <li><strong>Reward function:</strong> Extracts <code>env_reward</code> from <code>kwargs</code> to apply environment-computed rewards during training.</li></ol> <blockquote class="warning" data-svelte-h="svelte-10dmcfe"><p>The <code>rollout_func</code> is currently only supported when using vLLM in server mode (<code>use_vllm=True</code>, <code>vllm_mode="server"</code>).</p></blockquote> <h3 class="relative group"><a id="running-the-example" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#running-the-example"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 
11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Running the Example</span></h3> <p data-svelte-h="svelte-1k72py">The example requires two GPUs:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># Terminal 1: Start vLLM inference server</span> | |
| CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000 | |
| <span class="hljs-comment"># Terminal 2: Run GRPO training with OpenEnv</span> | |
| CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/echo.py<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-klksjg">Below is the reward curve from training:</p> <iframe src="https://trl-lib-trackio.hf.space?project=openenv&metrics=train/rewards/reward_from_env/mean&runs=qgallouedec-1761202871&sidebar=hidden&navbar=hidden" title="Reward curve (reward_from_env mean) for the Echo environment training run" style="width:600px; height:500px; border:0;"></iframe> <p data-svelte-h="svelte-y71kuk">To learn more about how to create custom environments, see the <a href="https://github.com/meta-pytorch/OpenEnv/blob/main/src/envs/README.md" rel="nofollow">OpenEnv documentation</a>.</p> <h2 class="relative group"><a id="advanced-example" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#advanced-example"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Advanced Example</span></h2> <p data-svelte-h="svelte-1uypb7y">Let’s level this up a bit by training a model to interact with a more complex environment. 
We’ll use the word-guessing game <a href="https://www.nytimes.com/games/wordle/index.html" rel="nofollow">Wordle</a> from the <code>textarena</code> environment.</p> <h3 class="relative group"><a id="the-textarena-environment" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#the-textarena-environment"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>The TextArena Environment</span></h3> <p data-svelte-h="svelte-amaete"><a href="https://huggingface.co/papers/2504.11442" rel="nofollow">TextArena</a> is an open-source collection of competitive text-based games, such as Wordle, Snake, Tic-Tac-Toe, and more, designed to evaluate the reasoning skills of LLMs. Research has shown that training on such games can improve model performance on reasoning tasks.</p> <p data-svelte-h="svelte-14dq4rw"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/text_arena_evals.png" alt="TextArena evaluation results"></p> <p data-svelte-h="svelte-b68nqy">We will use the <code>textarena</code> environment to train a model to play Wordle. 
The environment is a simple text-based environment that lets the model interact with the game by making guesses and receiving feedback on them.</p> <h3 class="relative group"><a id="wordle" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#wordle"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Wordle</span></h3> <p data-svelte-h="svelte-1moxxfn">Wordle is a useful game to train a model on because it requires the model to reason about the word and the feedback provided by the environment. It is also a purely language-based game that requires no external tools or knowledge. Furthermore, we found that models with 1 billion parameters or more are able to improve on Wordle and only require 8 tokens to generate a guess, which makes the game a good benchmark for experimenting with reinforcement learning environments without significant compute requirements.</p> <blockquote class="note"><p data-svelte-h="svelte-4qgqnr">How does Wordle work? | |
| Wordle is a word-guessing game where the player has to guess a 5-letter word. The player can make 6 guesses, and for each guess, the environment will provide feedback on the correctness of the guess. The player wins if they guess the word in 6 guesses or fewer. It challenges the model to generate words that are likely to be correct, and to learn from the feedback provided by the environment.</p> <p data-svelte-h="svelte-u6inub">For example, if the Wordle environment returns the following feedback:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->G U E S S | |
| <span class="hljs-keyword">X</span> G <span class="hljs-keyword">Y</span> <span class="hljs-keyword">X</span> <span class="hljs-keyword">X</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-lo3rxv">The model has guessed the word “GUESS”, and the environment has provided feedback using the letters X, G, and Y, which correspond to the gray, green, and yellow tiles of the original game. From this feedback, the model should learn that “GUESS” is not the answer. The letter “U” is correct and in the correct position. The letter “E” is in the word, but in the wrong position. The letters “G” and “S” are not in the word.</p></blockquote> <p data-svelte-h="svelte-11t6784">In the TextArena environment, reward is only given when the model wins the game: the reward is 1.0 if the model wins, and 0.0 otherwise. This sparse signal gives the model little to learn from, so we have added a number of custom reward functions to the script to help the model learn to play the game. The extensible nature of <code>reward_funcs</code> and <code>rollout_func</code> allows you to add any custom reward function you want to the script.</p> <h3 class="relative group"><a id="rollout-function" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#rollout-function"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 
8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Rollout Function</span></h3> <p data-svelte-h="svelte-11hv3wq">The rollout function runs one full Wordle episode, prompting the model for a guess each turn and capturing both environment rewards and auxiliary signals such as letter coverage and repetition penalties.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">rollout_once</span>(<span class="hljs-params"> | |
| env: TextArenaEnv, | |
| tokenizer: AutoTokenizer, | |
| args: GRPOConfig, | |
| dataset_prompt: <span class="hljs-built_in">str</span>, | |
| cli_args: argparse.Namespace, | |
| system_prompt: <span class="hljs-built_in">str</span>, | |
| </span>) -> <span class="hljs-built_in">dict</span>[<span class="hljs-built_in">str</span>, <span class="hljs-built_in">list</span>]: | |
| result = env.reset() | |
| observation = result.observation | |
| prompt_ids: <span class="hljs-built_in">list</span>[<span class="hljs-built_in">int</span>] = [] | |
| completion_ids: <span class="hljs-built_in">list</span>[<span class="hljs-built_in">int</span>] = [] | |
| logprobs: <span class="hljs-built_in">list</span>[<span class="hljs-built_in">float</span>] = [] | |
| raw_rewards: <span class="hljs-built_in">list</span>[<span class="hljs-built_in">float</span>] = [] | |
| green_scores: <span class="hljs-built_in">list</span>[<span class="hljs-built_in">float</span>] = [] | |
| yellow_scores: <span class="hljs-built_in">list</span>[<span class="hljs-built_in">float</span>] = [] | |
| repetition_scores: <span class="hljs-built_in">list</span>[<span class="hljs-built_in">float</span>] = [] | |
| correct_scores: <span class="hljs-built_in">list</span>[<span class="hljs-built_in">float</span>] = [] | |
| guess_counts: <span class="hljs-built_in">dict</span>[<span class="hljs-built_in">str</span>, <span class="hljs-built_in">int</span>] = {} | |
| <span class="hljs-keyword">for</span> _turn <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(cli_args.max_turns): | |
| <span class="hljs-comment"># when the game is over the environment will return a done=True</span> | |
| <span class="hljs-keyword">if</span> result.done: | |
| <span class="hljs-keyword">break</span> | |
| <span class="hljs-comment"># set up the prompt for the model</span> | |
| base_prompt = observation.prompt <span class="hljs-keyword">or</span> dataset_prompt | |
| user_prompt = make_user_prompt(base_prompt, observation.messages) | |
| messages = [ | |
| {<span class="hljs-string">"role"</span>: <span class="hljs-string">"system"</span>, <span class="hljs-string">"content"</span>: system_prompt}, | |
| {<span class="hljs-string">"role"</span>: <span class="hljs-string">"user"</span>, <span class="hljs-string">"content"</span>: user_prompt}, | |
| ] | |
| prompt_text = tokenizer.apply_chat_template( | |
| messages, | |
| add_generation_prompt=<span class="hljs-literal">True</span>, | |
| tokenize=<span class="hljs-literal">False</span>, | |
| enable_thinking=<span class="hljs-literal">False</span>, | |
| ) | |
| <span class="hljs-comment"># generate the completion from the model using vLLM</span> | |
| vllm_result = request_vllm_completion( | |
| prompt_text, | |
| args, | |
| endpoint=cli_args.vllm_endpoint, | |
| timeout=cli_args.request_timeout, | |
| fallback=cli_args, | |
| ) | |
| prompt_ids.extend(vllm_result[<span class="hljs-string">"prompt_ids"</span>]) | |
| completion_ids.extend(vllm_result[<span class="hljs-string">"completion_ids"</span>]) | |
| logprobs.extend(vllm_result[<span class="hljs-string">"logprobs"</span>]) | |
| completion_text = vllm_result.get(<span class="hljs-string">"text"</span>) <span class="hljs-keyword">or</span> tokenizer.decode( | |
| vllm_result[<span class="hljs-string">"completion_ids"</span>], skip_special_tokens=<span class="hljs-literal">True</span> | |
| ) | |
| <span class="hljs-comment"># extract the guess from the completion</span> | |
| guess = extract_guess(completion_text) | |
| <span class="hljs-comment"># step the environment with the guess</span> | |
| result = env.step(TextArenaAction(message=guess)) | |
| raw_rewards.append(<span class="hljs-built_in">float</span>(result.reward <span class="hljs-keyword">or</span> <span class="hljs-number">0.0</span>)) | |
| observation = result.observation | |
| correct_score = <span class="hljs-built_in">float</span>(result.reward <span class="hljs-keyword">or</span> <span class="hljs-number">0.0</span>) | |
| feedback = extract_wordle_feedback(observation) | |
| <span class="hljs-comment"># Update guess counts</span> | |
| previous_occurrences = guess_counts.get(guess, <span class="hljs-number">0</span>) | |
| repetition_score = scale_repetition_score(previous_occurrences, <span class="hljs-built_in">len</span>(guess_counts)) | |
| guess_counts[guess] = previous_occurrences + <span class="hljs-number">1</span> | |
| <span class="hljs-comment"># calculate custom reward signals from the feedback</span> | |
| <span class="hljs-keyword">if</span> <span class="hljs-keyword">not</span> feedback: | |
| green_score = <span class="hljs-number">0.0</span> | |
| yellow_score = <span class="hljs-number">0.0</span> | |
| <span class="hljs-keyword">else</span>: | |
| green_count, yellow_count = extract_feedback_counts(feedback) | |
| green_score = green_count / <span class="hljs-number">5.0</span> | |
| yellow_score = yellow_count / <span class="hljs-number">5.0</span> | |
| repetition_scores.append(repetition_score) | |
| green_scores.append(green_score) | |
| yellow_scores.append(yellow_score) | |
| correct_scores.append(correct_score) | |
| correct_reward_value = correct_scores[-<span class="hljs-number">1</span>] <span class="hljs-keyword">if</span> correct_scores <span class="hljs-keyword">else</span> (raw_rewards[-<span class="hljs-number">1</span>] <span class="hljs-keyword">if</span> raw_rewards <span class="hljs-keyword">else</span> <span class="hljs-number">0.0</span>) | |
| <span class="hljs-keyword">return</span> { | |
| <span class="hljs-string">"prompt_ids"</span>: prompt_ids, | |
| <span class="hljs-string">"completion_ids"</span>: completion_ids, | |
| <span class="hljs-string">"logprobs"</span>: logprobs, | |
| <span class="hljs-string">"raw_rewards"</span>: raw_rewards, | |
| <span class="hljs-string">"correct_reward"</span>: correct_reward_value, | |
| <span class="hljs-string">"green_reward"</span>: green_scores[-<span class="hljs-number">1</span>] <span class="hljs-keyword">if</span> green_scores <span class="hljs-keyword">else</span> <span class="hljs-number">0.0</span>, | |
| <span class="hljs-string">"yellow_reward"</span>: yellow_scores[-<span class="hljs-number">1</span>] <span class="hljs-keyword">if</span> yellow_scores <span class="hljs-keyword">else</span> <span class="hljs-number">0.0</span>, | |
| <span class="hljs-string">"repetition_reward"</span>: repetition_scores[-<span class="hljs-number">1</span>] <span class="hljs-keyword">if</span> repetition_scores <span class="hljs-keyword">else</span> <span class="hljs-number">0.0</span>, | |
| }<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1c2qvit">The environment has a reward signal based on the completion of the game. We found that most models struggle to ever win the game, so we have added a number of custom reward functions to the script to help the model learn to play the game more iteratively. At first, the model will learn to cover new letters and avoid repeating guesses. As it improves, it will learn to win the game.</p> <h3 class="relative group"><a id="reward-functions" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#reward-functions"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Reward Functions</span></h3> <p data-svelte-h="svelte-modcgz">We log four reward streams that encourage the model to solve the puzzle, cover new letters, and avoid repeating guesses:</p> <ul data-svelte-h="svelte-1tq1s16"><li><code>reward_correct</code>: final win/loss signal from the environment.</li> <li><code>reward_greens</code>: density of green letters in the last feedback.</li> <li><code>reward_yellows</code>: density of yellow letters in the last feedback.</li> <li><code>reward_repetition</code>: penalty 
for guessing the same token multiple times.</li></ul> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">reward_correct</span>(<span class="hljs-params">completions: <span class="hljs-type">List</span>[<span class="hljs-built_in">str</span>], **kwargs: <span class="hljs-type">Optional</span>[<span class="hljs-type">Dict</span>]</span>) -> <span class="hljs-type">List</span>[<span class="hljs-built_in">float</span>]: | |
| rewards = kwargs.get(<span class="hljs-string">"correct_reward"</span>) <span class="hljs-keyword">if</span> kwargs <span class="hljs-keyword">else</span> <span class="hljs-literal">None</span> | |
| <span class="hljs-keyword">return</span> [<span class="hljs-built_in">float</span>(r) <span class="hljs-keyword">for</span> r <span class="hljs-keyword">in</span> rewards] <span class="hljs-keyword">if</span> rewards <span class="hljs-keyword">is</span> <span class="hljs-keyword">not</span> <span class="hljs-literal">None</span> <span class="hljs-keyword">else</span> [<span class="hljs-number">0.0</span>] * <span class="hljs-built_in">len</span>(completions) | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">reward_greens</span>(<span class="hljs-params">completions: <span class="hljs-type">List</span>[<span class="hljs-built_in">str</span>], **kwargs: <span class="hljs-type">Optional</span>[<span class="hljs-type">Dict</span>]</span>) -> <span class="hljs-type">List</span>[<span class="hljs-built_in">float</span>]: | |
| rewards = kwargs.get(<span class="hljs-string">"green_reward"</span>) <span class="hljs-keyword">if</span> kwargs <span class="hljs-keyword">else</span> <span class="hljs-literal">None</span> | |
| <span class="hljs-keyword">return</span> [<span class="hljs-built_in">float</span>(r) <span class="hljs-keyword">for</span> r <span class="hljs-keyword">in</span> rewards] <span class="hljs-keyword">if</span> rewards <span class="hljs-keyword">is</span> <span class="hljs-keyword">not</span> <span class="hljs-literal">None</span> <span class="hljs-keyword">else</span> [<span class="hljs-number">0.0</span>] * <span class="hljs-built_in">len</span>(completions) | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">reward_yellows</span>(<span class="hljs-params">completions: <span class="hljs-type">List</span>[<span class="hljs-built_in">str</span>], **kwargs: <span class="hljs-type">Optional</span>[<span class="hljs-type">Dict</span>]</span>) -> <span class="hljs-type">List</span>[<span class="hljs-built_in">float</span>]: | |
| rewards = kwargs.get(<span class="hljs-string">"yellow_reward"</span>) <span class="hljs-keyword">if</span> kwargs <span class="hljs-keyword">else</span> <span class="hljs-literal">None</span> | |
| <span class="hljs-keyword">return</span> [<span class="hljs-built_in">float</span>(r) <span class="hljs-keyword">for</span> r <span class="hljs-keyword">in</span> rewards] <span class="hljs-keyword">if</span> rewards <span class="hljs-keyword">is</span> <span class="hljs-keyword">not</span> <span class="hljs-literal">None</span> <span class="hljs-keyword">else</span> [<span class="hljs-number">0.0</span>] * <span class="hljs-built_in">len</span>(completions) | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">reward_repetition</span>(<span class="hljs-params">completions: <span class="hljs-type">List</span>[<span class="hljs-built_in">str</span>], **kwargs: <span class="hljs-type">Optional</span>[<span class="hljs-type">Dict</span>]</span>) -> <span class="hljs-type">List</span>[<span class="hljs-built_in">float</span>]: | |
| rewards = kwargs.get(<span class="hljs-string">"repetition_reward"</span>) <span class="hljs-keyword">if</span> kwargs <span class="hljs-keyword">else</span> <span class="hljs-literal">None</span> | |
| <span class="hljs-keyword">return</span> [<span class="hljs-built_in">float</span>(r) <span class="hljs-keyword">for</span> r <span class="hljs-keyword">in</span> rewards] <span class="hljs-keyword">if</span> rewards <span class="hljs-keyword">is</span> <span class="hljs-keyword">not</span> <span class="hljs-literal">None</span> <span class="hljs-keyword">else</span> [<span class="hljs-number">0.0</span>] * <span class="hljs-built_in">len</span>(completions)<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="training-the-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#training-the-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Training the Model</span></h3> <p data-svelte-h="svelte-3udbcx">The training script wires the custom rollout and rewards into <code>GRPOTrainer</code>. 
The CLI exposes the configuration used during development as defaults, so you can override endpoints or hyperparameters at launch time.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->parser = argparse.ArgumentParser() | |
| <span class="hljs-comment"># ... add CLI arguments with sensible defaults ...</span> | |
| cli_args = parser.parse_args() | |
| trainer = GRPOTrainer( | |
| model=cli_args.model_id, | |
| processing_class=tokenizer, | |
| reward_funcs=[ | |
| reward_correct, | |
| reward_greens, | |
| reward_yellows, | |
| reward_repetition, | |
| ], | |
| train_dataset=dataset, | |
| args=grpo_config, | |
| rollout_func=<span class="hljs-keyword">lambda</span> prompts, args, processing_class: rollout_func( | |
| env=env, | |
| tokenizer=tokenizer, | |
| prompts=prompts, | |
| args=args, | |
| cli_args=cli_args, | |
| system_prompt=system_prompt, | |
| ), | |
| ) | |
| trainer.train()<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="running-the-example" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#running-the-example"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Running the Example</span></h3> <p data-svelte-h="svelte-1k72py">The example requires two GPUs:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black 
text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># Terminal 1: Start vLLM inference server</span> | |
| CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3-1.7B --host 0.0.0.0 --port 8000 | |
| <span class="hljs-comment"># Terminal 2: Run GRPO training with OpenEnv</span> | |
| CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/wordle.py<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="results" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#results"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Results</span></h3> <p data-svelte-h="svelte-clpkc9">The resulting model improves its performance on the game, both by reducing the number of repetitions and by increasing the number of correct guesses. However, the Qwen3-1.7B model we trained is not able to win the game consistently. The following reward curves show the coverage of the model’s guesses and the coverage of correct yellow (Y) and green (G) letters.</p> <iframe src="https://burtenshaw-wordle-grpo.hf.space?project=group-Qwen-Qwen3-17B&metrics=reward&runs=run-2025-10-26_09-39-49,run-2025-10-26_08-04-49&sidebar=hidden&navbar=hidden" title="Wordle GRPO training reward curves" style="width:1600px; height:500px; border:0;"></iframe> <p data-svelte-h="svelte-xgom2y">We experimented with larger models such as <code>gpt-oss-20b</code> and found that they were able to consistently win the game. However, training them requires a lot of compute. 
Why not try this out yourself?</p> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/trl/blob/main/docs/source/openenv.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p> | |
| <script> | |
| { | |
| // Generated page bootstrap (appears to be SvelteKit build output); regenerate it rather than editing by hand. | |
| // The surrounding braces make `element` and `data` block-scoped, so they do not leak into the page's global scope. | |
| // Config object assigned without a declaration keyword: asset/base path prefixes for this docs build and an empty env. | |
| __sveltekit_l5mord = { | |
| assets: "/docs/trl/pr_4331/en", | |
| base: "/docs/trl/pr_4331/en", | |
| env: {} | |
| }; | |
| // The element that contains this script tag; it is passed to kit.start() below as the app's target element. | |
| const element = document.currentScript.parentElement; | |
| // NOTE(review): presumably one data slot per entry in node_ids below (both have two entries); confirm against the generator. | |
| const data = [null,null]; | |
| // Load the runtime entry and the app entry in parallel; the start.* URL matches the modulepreload link in the document head. | |
| Promise.all([ | |
| import("/docs/trl/pr_4331/en/_app/immutable/entry/start.6bbbc54b.js"), | |
| import("/docs/trl/pr_4331/en/_app/immutable/entry/app.b003256e.js") | |
| ]).then(([kit, app]) => { | |
| // Start the client app with the imported runtime (`kit`) and app module (`app`), targeting `element`. | |
| kit.start(app, element, { | |
| node_ids: [0, 34], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |
Xet Storage Details
- Size:
- 66 kB
- Xet hash:
- a4fb34952ee6aacc07e1838845df647e26f7ae55c8a75d6ab754344f45f48efd
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.