<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Quickstart&quot;,&quot;local&quot;:&quot;quickstart&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;How does it work?&quot;,&quot;local&quot;:&quot;how-does-it-work&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Minimal example&quot;,&quot;local&quot;:&quot;minimal-example&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;How to use a trained model&quot;,&quot;local&quot;:&quot;how-to-use-a-trained-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/trl/main/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/entry/start.183b226a.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/scheduler.85c25b89.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/singletons.98fe034d.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/paths.eb9df337.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/entry/app.9853b7f5.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/index.c142fe32.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/nodes/0.5efac18d.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/nodes/28.1faa995a.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/CodeBlock.a5e95a57.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/EditOnGithub.a592e7aa.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Quickstart&quot;,&quot;local&quot;:&quot;quickstart&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;How does it work?&quot;,&quot;local&quot;:&quot;how-does-it-work&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Minimal example&quot;,&quot;local&quot;:&quot;minimal-example&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;How to use a trained model&quot;,&quot;local&quot;:&quot;how-to-use-a-trained-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="quickstart" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#quickstart"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Quickstart</span></h1> <h2 class="relative group"><a id="how-does-it-work" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 
with-hover:group-hover:opacity-100 with-hover:right-full" href="#how-does-it-work"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>How does it work?</span></h2> <p data-svelte-h="svelte-79li59">Fine-tuning a language model via PPO consists of roughly three steps:</p> <ol data-svelte-h="svelte-1uekaot"><li><strong>Rollout</strong>: The language model generates a response or continuation based on a query, which could be the start of a sentence.</li> <li><strong>Evaluation</strong>: The query and response are evaluated with a function, a model, human feedback, or some combination of these. What matters is that this process yields a scalar value for each query/response pair, which the optimization step then tries to maximize.</li> <li><strong>Optimization</strong>: This is the most complex part. In the optimization step, the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with both the model being trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don’t deviate too far from the reference language model. 
The active language model is then trained with PPO.</li></ol> <p data-svelte-h="svelte-62xfzx">The full process is illustrated in the following figure:</p> <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_overview.png"> <h2 class="relative group"><a id="minimal-example" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#minimal-example"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Minimal example</span></h2> <p data-svelte-h="svelte-11erqnm">The following code illustrates the steps above.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" 
transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># 0. imports</span>
<span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> GPT2Tokenizer
<span class="hljs-keyword">from</span> trl <span class="hljs-keyword">import</span> AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
<span class="hljs-comment"># 1. load a pretrained model</span>
model = AutoModelForCausalLMWithValueHead.from_pretrained(<span class="hljs-string">&quot;gpt2&quot;</span>)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(<span class="hljs-string">&quot;gpt2&quot;</span>)
tokenizer = GPT2Tokenizer.from_pretrained(<span class="hljs-string">&quot;gpt2&quot;</span>)
tokenizer.pad_token = tokenizer.eos_token
<span class="hljs-comment"># 2. initialize trainer</span>
ppo_config = {<span class="hljs-string">&quot;mini_batch_size&quot;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&quot;batch_size&quot;</span>: <span class="hljs-number">1</span>}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)
<span class="hljs-comment"># 3. encode a query</span>
query_txt = <span class="hljs-string">&quot;This morning I went to the &quot;</span>
query_tensor = tokenizer.encode(query_txt, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>).to(model.pretrained_model.device)
<span class="hljs-comment"># 4. generate model response</span>
generation_kwargs = {
    <span class="hljs-string">&quot;min_length&quot;</span>: -<span class="hljs-number">1</span>,
    <span class="hljs-string">&quot;top_k&quot;</span>: <span class="hljs-number">0</span>,
    <span class="hljs-string">&quot;top_p&quot;</span>: <span class="hljs-number">1.0</span>,
    <span class="hljs-string">&quot;do_sample&quot;</span>: <span class="hljs-literal">True</span>,
    <span class="hljs-string">&quot;pad_token_id&quot;</span>: tokenizer.eos_token_id,
    <span class="hljs-string">&quot;max_new_tokens&quot;</span>: <span class="hljs-number">20</span>,
}
response_tensor = ppo_trainer.generate([item <span class="hljs-keyword">for</span> item <span class="hljs-keyword">in</span> query_tensor], return_prompt=<span class="hljs-literal">False</span>, **generation_kwargs)
response_txt = tokenizer.decode(response_tensor[<span class="hljs-number">0</span>])
<span class="hljs-comment"># 5. define a reward for response</span>
<span class="hljs-comment"># (this could be any reward such as human feedback or output from another model)</span>
reward = [torch.tensor(<span class="hljs-number">1.0</span>, device=model.pretrained_model.device)]
<span class="hljs-comment"># 6. train model with ppo</span>
train_stats = ppo_trainer.step([query_tensor[<span class="hljs-number">0</span>]], [response_tensor[<span class="hljs-number">0</span>]], reward)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1rp76uw">In general, you would run steps 3-6 in a for-loop over many diverse queries. You can find more realistic examples in the examples section.</p> <h2 class="relative group"><a id="how-to-use-a-trained-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#how-to-use-a-trained-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>How to use a trained model</span></h2> <p data-svelte-h="svelte-znbdcu">After training an <code>AutoModelForCausalLMWithValueHead</code>, you can use the model directly in <code>transformers</code>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" 
focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->
<span class="hljs-comment"># .. Let&#x27;s assume we have a trained model using `PPOTrainer` and `AutoModelForCausalLMWithValueHead`</span>
<span class="hljs-comment"># push the model to the Hub</span>
model.push_to_hub(<span class="hljs-string">&quot;my-fine-tuned-model-ppo&quot;</span>)
<span class="hljs-comment"># or save it locally</span>
model.save_pretrained(<span class="hljs-string">&quot;my-fine-tuned-model-ppo&quot;</span>)
<span class="hljs-comment"># load the model from the Hub</span>
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">&quot;my-fine-tuned-model-ppo&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-12cg6c8">You can also load your model with <code>AutoModelForCausalLMWithValueHead</code> if you want to use the value head, for example to continue training.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> trl <span class="hljs-keyword">import</span> AutoModelForCausalLMWithValueHead
model = AutoModelForCausalLMWithValueHead.from_pretrained(<span class="hljs-string">&quot;my-fine-tuned-model-ppo&quot;</span>)<!-- HTML_TAG_END --></pre></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/trl/blob/main/docs/source/quickstart.mdx" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p>
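<p>The optimization step described under "How does it work?" uses the KL-divergence between the trained model and the reference model as an additional reward signal. The plain-Python sketch below illustrates one common way such a signal can be combined with the scalar score. It is not taken from TRL's source: the helper name <code>kl_penalized_rewards</code>, the coefficient <code>beta</code> and its value, and the choice to add the score to the last token are all assumptions made for illustration.</p>

```python
# Hypothetical helper, not part of TRL: a sketch of a KL-penalised reward.
def kl_penalized_rewards(logprobs, ref_logprobs, score, beta=0.1):
    """Return per-token rewards for one response.

    logprobs / ref_logprobs: log-probabilities of the response tokens under
    the trained model and the reference model (same length, non-empty).
    score: scalar reward for the whole query/response pair.
    beta: assumed weight of the KL penalty (illustrative value).
    """
    if len(logprobs) != len(ref_logprobs) or not logprobs:
        raise ValueError("expected two non-empty sequences of equal length")
    # Per-token KL estimate: log pi(token) - log pi_ref(token).
    # Tokens the trained model prefers more than the reference are penalised.
    rewards = [-beta * (lp - ref_lp) for lp, ref_lp in zip(logprobs, ref_logprobs)]
    # Assumption: the scalar score is credited to the final response token.
    rewards[-1] += score
    return rewards


rewards = kl_penalized_rewards(
    logprobs=[-1.0, -2.0, -0.5],
    ref_logprobs=[-1.5, -2.0, -1.0],
    score=1.0,
    beta=0.1,
)
print(rewards)
```

<p>With a larger <code>beta</code> the penalty dominates and the policy stays closer to the reference model. With <code>beta</code> set to zero only the scalar score remains.</p>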
<script>
{
__sveltekit_5yobsv = {
assets: "/docs/trl/main/en",
base: "/docs/trl/main/en",
env: {}
};
const element = document.currentScript.parentElement;
const data = [null,null];
Promise.all([
import("/docs/trl/main/en/_app/immutable/entry/start.183b226a.js"),
import("/docs/trl/main/en/_app/immutable/entry/app.9853b7f5.js")
]).then(([kit, app]) => {
kit.start(app, element, {
node_ids: [0, 28],
data,
form: null,
error: null
});
});
}
</script>
