Buckets:

rtrm's picture
download
raw
68.1 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Fine-tune and Test Llama-3 8B on AWS Trainium&quot;,&quot;local&quot;:&quot;fine-tune-and-test-llama-3-8b-on-aws-trainium&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;1. Setup AWS Environment&quot;,&quot;local&quot;:&quot;1-setup-aws-environment&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;2. Load and prepare the dataset&quot;,&quot;local&quot;:&quot;2-load-and-prepare-the-dataset&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;3. Fine-tune Llama on AWS Trainium using the NeuronTrainer&quot;,&quot;local&quot;:&quot;3-fine-tune-llama-on-aws-trainium-using-the-neurontrainer&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;4. Launch Training&quot;,&quot;local&quot;:&quot;4-launch-training&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Precompilation&quot;,&quot;local&quot;:&quot;precompilation&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Actual Training&quot;,&quot;local&quot;:&quot;actual-training&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Consolidate the Checkpoint&quot;,&quot;local&quot;:&quot;consolidate-the-checkpoint&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;5. Evaluate and test fine-tuned Llama model&quot;,&quot;local&quot;:&quot;5-evaluate-and-test-fine-tuned-llama-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/optimum.neuron/v0.0.28.dev2/en/_app/immutable/assets/0.e3b0c442.css" rel="stylesheet">
<link rel="modulepreload" href="/docs/optimum.neuron/v0.0.28.dev2/en/_app/immutable/entry/start.13c1f5a3.js">
<link rel="modulepreload" href="/docs/optimum.neuron/v0.0.28.dev2/en/_app/immutable/chunks/scheduler.9039eef2.js">
<link rel="modulepreload" href="/docs/optimum.neuron/v0.0.28.dev2/en/_app/immutable/chunks/singletons.a8905f53.js">
<link rel="modulepreload" href="/docs/optimum.neuron/v0.0.28.dev2/en/_app/immutable/chunks/paths.f5262881.js">
<link rel="modulepreload" href="/docs/optimum.neuron/v0.0.28.dev2/en/_app/immutable/entry/app.a71d5dce.js">
<link rel="modulepreload" href="/docs/optimum.neuron/v0.0.28.dev2/en/_app/immutable/chunks/index.cdcc3d35.js">
<link rel="modulepreload" href="/docs/optimum.neuron/v0.0.28.dev2/en/_app/immutable/nodes/0.68fa8611.js">
<link rel="modulepreload" href="/docs/optimum.neuron/v0.0.28.dev2/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/optimum.neuron/v0.0.28.dev2/en/_app/immutable/nodes/27.d60f0b03.js">
<link rel="modulepreload" href="/docs/optimum.neuron/v0.0.28.dev2/en/_app/immutable/chunks/Tip.6f74db41.js">
<link rel="modulepreload" href="/docs/optimum.neuron/v0.0.28.dev2/en/_app/immutable/chunks/CodeBlock.e3ac94d9.js">
<link rel="modulepreload" href="/docs/optimum.neuron/v0.0.28.dev2/en/_app/immutable/chunks/Heading.96ce3702.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Fine-tune and Test Llama-3 8B on AWS Trainium&quot;,&quot;local&quot;:&quot;fine-tune-and-test-llama-3-8b-on-aws-trainium&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;1. Setup AWS Environment&quot;,&quot;local&quot;:&quot;1-setup-aws-environment&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;2. Load and prepare the dataset&quot;,&quot;local&quot;:&quot;2-load-and-prepare-the-dataset&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;3. Fine-tune Llama on AWS Trainium using the NeuronTrainer&quot;,&quot;local&quot;:&quot;3-fine-tune-llama-on-aws-trainium-using-the-neurontrainer&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;4. Launch Training&quot;,&quot;local&quot;:&quot;4-launch-training&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Precompilation&quot;,&quot;local&quot;:&quot;precompilation&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Actual Training&quot;,&quot;local&quot;:&quot;actual-training&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Consolidate the Checkpoint&quot;,&quot;local&quot;:&quot;consolidate-the-checkpoint&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;5. 
Evaluate and test fine-tuned Llama model&quot;,&quot;local&quot;:&quot;5-evaluate-and-test-fine-tuned-llama-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="fine-tune-and-test-llama-3-8b-on-aws-trainium" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#fine-tune-and-test-llama-3-8b-on-aws-trainium"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Fine-tune and Test Llama-3 8B on AWS Trainium</span></h1> <p data-svelte-h="svelte-1dtzs79"><em>Note: The complete script for this tutorial can be downloaded <a href="https://github.com/huggingface/optimum-neuron/blob/main/docs/source/training_tutorials/finetune_llm.py" rel="nofollow">here</a>.</em></p> <p data-svelte-h="svelte-1w2bqvx">This tutorial will teach you how to fine-tune open source LLMs like <a href="https://huggingface.co/meta-llama/Meta-Llama-3-8B" rel="nofollow">Llama 3</a> on AWS Trainium. 
In our example, we are going to leverage the <a href="https://huggingface.co/docs/optimum-neuron/index" rel="nofollow">Optimum Neuron</a>, <a href="https://huggingface.co/docs/transformers/index" rel="nofollow">Transformers</a> and <a href="https://huggingface.co/docs/datasets/index" rel="nofollow">Datasets</a> libraries.</p> <p data-svelte-h="svelte-1hahfn0">You will learn how to:</p> <ol data-svelte-h="svelte-1yf8yg7"><li><a href="#1-setup-aws-environment">Setup AWS Environment</a></li> <li><a href="#2-load-and-prepare-the-dataset">Load and prepare the dataset</a></li> <li><a href="#3-fine-tune-llama-on-aws-trainium-using-the-neurontrainer">Fine-tune Llama on AWS Trainium using the <code>NeuronTrainer</code></a></li> <li><a href="#4-launch-training">Launch Training</a></li> <li><a href="#5-evaluate-and-test-fine-tuned-llama-model">Evaluate and test fine-tuned Llama model</a></li></ol> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1f24nyx">While we will use <code>Llama-3 8B</code> in this tutorial, it is completely possible to use other models, simply by switching the <code>model_id</code>.
For instance, it is possible to fine-tune:</p> <ul data-svelte-h="svelte-1fsabqd"><li>Mistral models, such as <a href="https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3" rel="nofollow">Mistral 7b (<code>mistralai/Mistral-7B-Instruct-v0.3</code>)</a></li> <li>Llama-2 models, such as <a href="https://huggingface.co/meta-llama/Llama-2-7b-hf" rel="nofollow">Llama-2 7b (<code>meta-llama/Llama-2-7b-hf</code>)</a></li></ul> <p data-svelte-h="svelte-yii95m">And many others!</p></div> <h2 class="relative group"><a id="1-setup-aws-environment" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#1-setup-aws-environment"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>1. Setup AWS Environment</span></h2> <p data-svelte-h="svelte-w52dpm">Before starting this tutorial, you will need to set up your environment:</p> <ol data-svelte-h="svelte-15ega2l"><li>Create an AWS Trainium instance. 
<strong>You will need a <code>trn1.32xlarge</code>, which contains 16 Neuron Devices.</strong> You can follow this <a href="https://huggingface.co/docs/optimum-neuron/guides/setup_aws_instance" rel="nofollow">guide</a> to create one.</li> <li>Make sure you are logged in on the Hugging Face Hub:</li></ol> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->huggingface-cli login --token YOUR_TOKEN<!-- HTML_TAG_END --></pre></div> <ol start="3" data-svelte-h="svelte-ksb30n"><li>Check that you have access to the model. Some open source models are gated, meaning that users need to apply to the model owner to be able to use the model weights. 
Here we will be training Llama-3 8B, for which there are two possibilities:</li></ol> <ul data-svelte-h="svelte-1qiwkz0"><li>The official gated repo: <a href="https://huggingface.co/meta-llama/Meta-Llama-3-8B" rel="nofollow"><code>meta-llama/Meta-Llama-3-8B</code></a></li> <li>The non-official un-gated repo: <a href="https://huggingface.co/NousResearch/Meta-Llama-3-8B" rel="nofollow"><code>NousResearch/Meta-Llama-3-8B</code></a></li></ul> <ol start="4" data-svelte-h="svelte-ch5yfe"><li>Clone the Optimum Neuron repository, <strong>which contains the <a href="https://github.com/huggingface/optimum-neuron/blob/main/docs/source/training_tutorials/finetune_llm.py" rel="nofollow">complete script</a> described in this tutorial:</strong></li></ol> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->git <span 
class="hljs-built_in">clone</span> https://github.com/huggingface/optimum-neuron.git<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="2-load-and-prepare-the-dataset" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#2-load-and-prepare-the-dataset"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>2. 
Load and prepare the dataset</span></h2> <p data-svelte-h="svelte-s6kf7j">For this tutorial, we will use <a href="https://huggingface.co/datasets/databricks/databricks-dolly-15k" rel="nofollow">Dolly</a>, an open source dataset of instruction-following records on categories outlined in the <a href="https://arxiv.org/abs/2203.02155" rel="nofollow">InstructGPT paper</a>, including brainstorming, classification, closed QA, generation, information extraction, open QA, and summarization.</p> <p data-svelte-h="svelte-11lpom8">Example:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{
<span class="hljs-string">&quot;instruction&quot;</span>: <span class="hljs-string">&quot;What is world of warcraft&quot;</span>,
<span class="hljs-string">&quot;context&quot;</span>: <span class="hljs-string">&quot;&quot;</span>,
<span class="hljs-string">&quot;response&quot;</span>: (
<span class="hljs-string">&quot;World of warcraft is a massive online multi player role playing game. &quot;</span>
<span class="hljs-string">&quot;It was released in 2004 by blizarre entertainment&quot;</span>
)
}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1snwaj0">We can use the <code>load_dataset()</code> method from the 🤗 Datasets library to load the <code>dolly</code> dataset very easily.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
<span class="hljs-keyword">from</span> random <span class="hljs-keyword">import</span> randrange
<span class="hljs-comment"># Load dataset from the hub</span>
dataset = load_dataset(<span class="hljs-string">&quot;databricks/databricks-dolly-15k&quot;</span>, split=<span class="hljs-string">&quot;train&quot;</span>)
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;dataset size: <span class="hljs-subst">{<span class="hljs-built_in">len</span>(dataset)}</span>&quot;</span>)
<span class="hljs-built_in">print</span>(dataset[randrange(<span class="hljs-built_in">len</span>(dataset))])
<span class="hljs-comment"># dataset size: 15011</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-voh434">To instruction-tune our model, we need to convert our structured examples into a collection of tasks described via instructions. We define a <code>format_dolly</code> function that takes a raw sample and returns a string in our instruction format.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">format_dolly</span>(<span class="hljs-params">sample</span>):
instruction = <span class="hljs-string">f&quot;### Instruction\n<span class="hljs-subst">{sample[<span class="hljs-string">&#x27;instruction&#x27;</span>]}</span>&quot;</span>
context = <span class="hljs-string">f&quot;### Context\n<span class="hljs-subst">{sample[<span class="hljs-string">&#x27;context&#x27;</span>]}</span>&quot;</span> <span class="hljs-keyword">if</span> <span class="hljs-built_in">len</span>(sample[<span class="hljs-string">&quot;context&quot;</span>]) &gt; <span class="hljs-number">0</span> <span class="hljs-keyword">else</span> <span class="hljs-literal">None</span>
response = <span class="hljs-string">f&quot;### Answer\n<span class="hljs-subst">{sample[<span class="hljs-string">&#x27;response&#x27;</span>]}</span>&quot;</span>
<span class="hljs-comment"># join all the parts together</span>
prompt = <span class="hljs-string">&quot;\n\n&quot;</span>.join([i <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> [instruction, context, response] <span class="hljs-keyword">if</span> i <span class="hljs-keyword">is</span> <span class="hljs-keyword">not</span> <span class="hljs-literal">None</span>])
<span class="hljs-keyword">return</span> prompt<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1qw7fs2">In addition to formatting our samples, we also want to pack multiple samples into one sequence to make training more efficient. In other words, we are stacking multiple samples into one sequence and separating them with an EOS token. Packing/stacking samples can be done during training or before.</p> <p data-svelte-h="svelte-1tlab3t">The following function <code>pack_dataset</code> takes a <code>dataset</code> and a <code>chunk_length</code> and returns a packed dataset:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> functools <span class="hljs-keyword">import</span> partial
<span class="hljs-keyword">from</span> itertools <span class="hljs-keyword">import</span> chain
<span class="hljs-comment"># empty list to save remainder from batches to use in next batch</span>
remainder = {<span class="hljs-string">&quot;input_ids&quot;</span>: [], <span class="hljs-string">&quot;attention_mask&quot;</span>: [], <span class="hljs-string">&quot;token_type_ids&quot;</span>: []}
<span class="hljs-keyword">def</span> <span class="hljs-title function_">pack_dataset</span>(<span class="hljs-params">dataset, chunk_length=<span class="hljs-number">2048</span></span>):
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Chunking dataset into chunks of <span class="hljs-subst">{chunk_length}</span> tokens.&quot;</span>)
<span class="hljs-keyword">def</span> <span class="hljs-title function_">chunk</span>(<span class="hljs-params">sample, chunk_length=chunk_length</span>):
<span class="hljs-comment"># define global remainder variable to save remainder from batches to use in next batch</span>
<span class="hljs-keyword">global</span> remainder
<span class="hljs-comment"># Concatenate all texts and add remainder from previous batch</span>
concatenated_examples = {k: <span class="hljs-built_in">list</span>(chain(*sample[k])) <span class="hljs-keyword">for</span> k <span class="hljs-keyword">in</span> sample.keys()}
concatenated_examples = {k: remainder[k] + concatenated_examples[k] <span class="hljs-keyword">for</span> k <span class="hljs-keyword">in</span> concatenated_examples.keys()}
<span class="hljs-comment"># get total number of tokens for batch</span>
batch_total_length = <span class="hljs-built_in">len</span>(concatenated_examples[<span class="hljs-built_in">list</span>(sample.keys())[<span class="hljs-number">0</span>]])
<span class="hljs-comment"># get max number of chunks for batch</span>
<span class="hljs-keyword">if</span> batch_total_length &gt;= chunk_length:
batch_chunk_length = (batch_total_length // chunk_length) * chunk_length
<span class="hljs-comment"># Split by chunks of max_len.</span>
result = {
k: [t[i : i + chunk_length] <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">0</span>, batch_chunk_length, chunk_length)]
<span class="hljs-keyword">for</span> k, t <span class="hljs-keyword">in</span> concatenated_examples.items()
}
<span class="hljs-comment"># add remainder to global variable for next batch</span>
remainder = {k: concatenated_examples[k][batch_chunk_length:] <span class="hljs-keyword">for</span> k <span class="hljs-keyword">in</span> concatenated_examples.keys()}
<span class="hljs-comment"># prepare labels</span>
result[<span class="hljs-string">&quot;labels&quot;</span>] = result[<span class="hljs-string">&quot;input_ids&quot;</span>].copy()
<span class="hljs-keyword">return</span> result
<span class="hljs-comment"># tokenize and chunk dataset</span>
lm_dataset = dataset.<span class="hljs-built_in">map</span>(
partial(chunk, chunk_length=chunk_length),
batched=<span class="hljs-literal">True</span>,
)
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Total number of samples: <span class="hljs-subst">{<span class="hljs-built_in">len</span>(lm_dataset)}</span>&quot;</span>)
<span class="hljs-keyword">return</span> lm_dataset<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-vnokz2">To summarize, to prepare our dataset we will:</p> <ol data-svelte-h="svelte-v6jbub"><li>Format our samples using the template method and add an EOS token at the end of each sample</li> <li>Tokenize our dataset to convert it from text to tokens</li> <li>Pack our dataset into chunks of 2048 tokens</li></ol> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer
<span class="hljs-keyword">from</span> random <span class="hljs-keyword">import</span> randint
<span class="hljs-comment"># Hugging Face Hub model id </span>
<span class="hljs-comment"># model_id = &quot;meta-llama/Meta-Llama-3-8B&quot; # gated</span>
model_id = <span class="hljs-string">&quot;NousResearch/Meta-Llama-3-8B&quot;</span> <span class="hljs-comment"># ungated</span>
tokenizer = AutoTokenizer.from_pretrained(model_id)
<span class="hljs-comment"># template dataset to add prompt to each sample</span>
<span class="hljs-keyword">def</span> <span class="hljs-title function_">template_dataset</span>(<span class="hljs-params">sample</span>):
sample[<span class="hljs-string">&quot;text&quot;</span>] = <span class="hljs-string">f&quot;<span class="hljs-subst">{format_dolly(sample)}</span><span class="hljs-subst">{tokenizer.eos_token}</span>&quot;</span>
<span class="hljs-keyword">return</span> sample
<span class="hljs-comment"># apply prompt template per sample</span>
dataset = dataset.<span class="hljs-built_in">map</span>(template_dataset, remove_columns=<span class="hljs-built_in">list</span>(dataset.features))
<span class="hljs-comment"># print random sample</span>
<span class="hljs-built_in">print</span>(dataset[randint(<span class="hljs-number">0</span>, <span class="hljs-built_in">len</span>(dataset) - <span class="hljs-number">1</span>)][<span class="hljs-string">&quot;text&quot;</span>])
<span class="hljs-comment"># tokenize dataset</span>
dataset = dataset.<span class="hljs-built_in">map</span>(
<span class="hljs-keyword">lambda</span> sample: tokenizer(sample[<span class="hljs-string">&quot;text&quot;</span>]), batched=<span class="hljs-literal">True</span>, remove_columns=<span class="hljs-built_in">list</span>(dataset.features)
)
<span class="hljs-comment"># chunk dataset</span>
lm_dataset = pack_dataset(dataset, chunk_length=<span class="hljs-number">2048</span>) <span class="hljs-comment"># We use 2048 as the maximum length for packing</span><!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="3-fine-tune-llama-on-aws-trainium-using-the-neurontrainer" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#3-fine-tune-llama-on-aws-trainium-using-the-neurontrainer"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>3. 
Fine-tune Llama on AWS Trainium using the NeuronTrainer</span></h2> <p data-svelte-h="svelte-14g3or9">Normally you would use the <strong><a href="https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer" rel="nofollow">Trainer</a></strong> and <strong><a href="https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments" rel="nofollow">TrainingArguments</a></strong> classes to fine-tune PyTorch-based transformer models.</p> <p data-svelte-h="svelte-b4fdk9">But together with AWS, we have developed the <code>optimum.neuron.NeuronTrainer</code> class to improve performance, robustness, and ease of use when training on Trainium instances. It can be used as a 1-to-1 replacement for the <code>Trainer</code>.</p> <p data-svelte-h="svelte-lw48ct">Since Llama-3 8B is a big model, it will not fit on a single Neuron core, so we need distributed training. In Optimum Neuron we support:</p> <ol data-svelte-h="svelte-xm927i"><li><a href="https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/tutorials/training/zero1_gpt2.html" rel="nofollow">ZeRO-1</a>: It is an optimization of data-parallelism which consists in sharding the optimizer state (which usually represents half or more of the memory needed on the device) over the data-parallel ranks.</li> <li><a href="https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/neuronx-distributed/tensor_parallelism_overview.html" rel="nofollow">Tensor Parallelism</a>: It is a technique which consists in sharding each of your model matrix-multiplications along a given axis (row or column) on multiple devices. It is also known as intra-layer model parallelism.
The number of devices to shard your parameters on is called the <code>tensor_parallel_size</code>.</li> <li><a href="https://arxiv.org/pdf/2205.05198.pdf" rel="nofollow">Sequence parallelism</a>: It is an optimization over Tensor Parallelism which shards the activations on the sequence axis outside of the tensor parallel regions. It is useful because it saves memory by sharding the activations.</li> <li><a href="https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/neuronx-distributed/pipeline_parallelism_overview.html" rel="nofollow">Pipeline Parallelism</a>: It consists of sharding the model's layers across multiple devices. It is also known as inter-layer model parallelism. The number of devices to shard your layers on is called the <code>pipeline_parallel_size</code>.</li></ol> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-yp02re">If you want to know more about distributed training, you can take a look at the <a href="https://huggingface.co/docs/optimum-neuron/guides/distributed_training" rel="nofollow">documentation</a>.</p></div> <p data-svelte-h="svelte-1lnowtd">Here, since we want to fine-tune an 8B model, we will not need to use pipeline parallelism.
Our training code will look as follows:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> optimum.neuron <span class="hljs-keyword">import</span> NeuronTrainer <span class="hljs-keyword">as</span> Trainer
<span class="hljs-keyword">from</span> optimum.neuron.distributed <span class="hljs-keyword">import</span> lazy_load_for_parallelism
<span class="hljs-comment"># Define the tensor_parallel_size (must match the --tensor_parallel_size used to launch training)</span>
tensor_parallel_size = <span class="hljs-number">8</span>
<span class="hljs-comment"># Load the model from the Hugging Face Hub, loading only the weights each worker needs</span>
<span class="hljs-keyword">with</span> lazy_load_for_parallelism(tensor_parallel_size=tensor_parallel_size):
    model = AutoModelForCausalLM.from_pretrained(model_id)
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=lm_dataset,  <span class="hljs-comment"># the packed dataset returned by pack_dataset</span>
    data_collator=default_data_collator,  <span class="hljs-comment"># no special collator needed since we stacked the dataset</span>
)
<span class="hljs-comment"># Start training</span>
trainer.train()
trainer.save_model() <span class="hljs-comment"># saves the tokenizer too for easy upload</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-t16ze3">The key points here are:</p> <ul data-svelte-h="svelte-1ou6zrr"><li>We use the <code>lazy_load_for_parallelism</code> context manager to lazily load the model. This will not load the full model weights on each worker, but instead only load the required weights (sharded or full). <strong>This is much more memory efficient, and often mandatory to use.</strong></li> <li>We use the <code>NeuronTrainer</code> to perform training. It will take the lazily loaded model, along with the <code>training_args</code>, which are an instance of <code>NeuronTrainingArguments</code>, and will handle all the parallelization and training on the Neuron cores.</li></ul> <h2 class="relative group"><a id="4-launch-training" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#4-launch-training"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>4. 
Launch Training</span></h2> <p data-svelte-h="svelte-618yad">We prepared a script called <a href="https://github.com/huggingface/optimum-neuron/blob/main/docs/source/training_tutorials/finetune_llm.py" rel="nofollow">finetune_llm.py</a> summing up everything mentioned in this tutorial.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1hrw5ii">This script is a minimalistic version of our official example training script to run causal language modeling fine-tuning, called <a href="https://github.com/huggingface/optimum-neuron/blob/main/examples/language-modeling/run_clm.py" rel="nofollow">run_clm.py</a>. For the sake of this tutorial, we tried to get rid of anything that is not necessary, and added the formatting step necessary for fine-tuning, but if you want to do more custom things, maybe the solution is already implemented in <code>run_clm.py</code>!</p> <p data-svelte-h="svelte-1k8wfq8">Also, these scripts are designed more as templates than as final scripts. Feel free to take <code>finetune_llm.py</code> or <code>run_clm.py</code> and adapt them to your own needs!</p></div> <p data-svelte-h="svelte-bne5rp">PyTorch Neuron uses <code>torch_xla</code>. It evaluates operations lazily during execution of the training loops, which means it builds a symbolic graph in the background and the graph is executed on the hardware only when the tensor is printed, transferred to CPU, or <code>xm.mark_step()</code> is called. During execution, multiple graphs can be built depending on control flow, and it can take time to compile each graph sequentially. To alleviate that, the Neuron SDK provides <code>neuron_parallel_compile</code>, a tool which performs a fast trial run that builds all the graphs and compiles them in parallel. 
This step is usually called precompilation.</p> <h3 class="relative group"><a id="precompilation" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#precompilation"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Precompilation</span></h3> <p data-svelte-h="svelte-1u0hmu2">When training models on AWS Trainium, we first need to compile our model with our training arguments, which can take a while.</p> <p data-svelte-h="svelte-1gsk3gh">To overcome this, we added a <a href="https://huggingface.co/docs/optimum-neuron/guides/cache_system" rel="nofollow">model cache repository</a>, which allows us to use precompiled models from the Hugging Face Hub to skip the compilation step. 
But be careful: every change in the model configuration might cause a cache miss and trigger a new compilation.</p> <p data-svelte-h="svelte-jypm7b"><em>Note: If your model configuration is not cached, please open an issue on <a href="https://github.com/huggingface/optimum-neuron/issues" rel="nofollow">GitHub</a> and we will be happy to include it.</em></p> <p data-svelte-h="svelte-t9z3bd">The compilation command simply consists of calling your script as an input to the <code>neuron_parallel_compile</code> utility:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->MALLOC_ARENA_MAX=64 XLA_USE_BF16=1 neuron_parallel_compile torchrun --nproc_per_node=32 finetune_llm.py \
--model_id meta-llama/Meta-Llama-3-8B \
--bf16 True \
--learning_rate 5e-5 \
--output_dir dolly_llama \
--overwrite_output_dir True \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--gradient_checkpointing True \
--tensor_parallel_size 8 \
--max_steps 10 \
--logging_steps 10<!-- HTML_TAG_END --></pre></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1k7qt99">Make sure to run this precompilation phase for around 10 training steps. It is usually enough to accumulate and compile all the graphs that will be needed during the actual training.</p></div> <p data-svelte-h="svelte-11pd151"><em>Note: Compiling without a cache can take a while. It will also create dummy files in the <code>dolly_llama</code> directory during compilation, which you will have to remove afterwards. We also need to add <code>MALLOC_ARENA_MAX=64</code> to limit the CPU allocation and avoid potential crashes; don’t remove it for now.</em></p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: 
transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># remove dummy artifacts which are created by the precompilation command</span>
<span class="hljs-built_in">rm</span> -rf dolly_llama<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="actual-training" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#actual-training"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Actual Training</span></h3> <p data-svelte-h="svelte-1yiqpri">After compilation is done, we can start the actual training with a similar command; we just need to remove <code>neuron_parallel_compile</code>.</p> <p data-svelte-h="svelte-1anw499">We will use <code>torchrun</code> to launch our training script. <code>torchrun</code> is PyTorch’s launcher for distributed training: it starts and coordinates multiple worker processes. 
We pass the number of processes to launch with the <code>nproc_per_node</code> argument, alongside our hyperparameters.</p> <p data-svelte-h="svelte-ctfnvw">Compared to the compilation command, we changed <code>max_steps=10</code> to <code>num_train_epochs=3</code> and added <code>--skip_cache_push True</code>.</p> <p data-svelte-h="svelte-17lv8z9">Launch the training with the following command:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->MALLOC_ARENA_MAX=64 XLA_USE_BF16=1 torchrun --nproc_per_node=32 finetune_llm.py \
--model_id meta-llama/Meta-Llama-3-8B \
--bf16 True \
--learning_rate 5e-5 \
--output_dir dolly_llama \
--overwrite_output_dir True \
--skip_cache_push True \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--gradient_checkpointing True \
--tensor_parallel_size 8 \
--num_train_epochs 3 \
--logging_steps 10<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1v2czzn">That’s it, we successfully trained Llama-3 8B on AWS Trainium!</p> <p data-svelte-h="svelte-12iaeqv">But before we can share and test our model, we need to consolidate it. Since we used Tensor Parallelism during training, we saved sharded versions of the checkpoints. We need to consolidate them now.</p> <h3 class="relative group"><a id="consolidate-the-checkpoint" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#consolidate-the-checkpoint"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Consolidate the Checkpoint</span></h3> <p data-svelte-h="svelte-66io75">The Optimum CLI provides a way of doing that very easily via the <code>optimum-cli neuron consolidate [sharded_checkpoint] [output_dir]</code> command:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" 
aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->optimum-cli neuron consolidate dolly_llama dolly_llama<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="5-evaluate-and-test-fine-tuned-llama-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#5-evaluate-and-test-fine-tuned-llama-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" 
fill="currentColor"></path></svg></span></a> <span>5. Evaluate and test fine-tuned Llama model</span></h2> <p data-svelte-h="svelte-emhotm">As for training, to be able to run inference on AWS Trainium or AWS Inferentia2, we need to compile our model. In this case, we will use our Trainium instance for the inference test, but we recommend switching to Inferentia2 for inference.</p> <p data-svelte-h="svelte-1pgrxrm">Optimum Neuron implements classes similar to the Transformers AutoModel classes for easy inference. We will use the <code>NeuronModelForCausalLM</code> class to load our vanilla Transformers checkpoint and convert it to Neuron.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> optimum.neuron <span class="hljs-keyword">import</span> NeuronModelForCausalLM
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer
<span class="hljs-comment"># Compilation settings: number of Neuron cores and auto-cast type</span>
compiler_args = {<span class="hljs-string">&quot;num_cores&quot;</span>: <span class="hljs-number">2</span>, <span class="hljs-string">&quot;auto_cast_type&quot;</span>: <span class="hljs-string">&quot;fp16&quot;</span>}
<span class="hljs-comment"># Input shapes the model is compiled for</span>
input_shapes = {<span class="hljs-string">&quot;batch_size&quot;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&quot;sequence_length&quot;</span>: <span class="hljs-number">2048</span>}
tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">&quot;dolly_llama&quot;</span>)
<span class="hljs-comment"># export=True converts (compiles) the consolidated checkpoint for Neuron</span>
model = NeuronModelForCausalLM.from_pretrained(
    <span class="hljs-string">&quot;dolly_llama&quot;</span>,
    export=<span class="hljs-literal">True</span>,
    **compiler_args,
    **input_shapes,
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1sinvz7"><em>Note: Inference compilation can take ~25 minutes. Luckily, you only need to run it once, since you can save the compiled model afterwards. If you are going to run on Inferentia2, you will need to recompile, because compilation is parameter- and hardware-specific.</em></p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># Uncomment the next line to save the compiled model</span>
<span class="hljs-comment"># model.save_pretrained(&quot;compiled_dolly_llama&quot;)</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-jx2yqv">We can now test inference, but we have to make sure we format our input with the same prompt format we used for fine-tuning. Therefore, we created a helper function that accepts a <code>dict</code> with our <code>instruction</code> and, optionally, a <code>context</code>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">format_dolly_inference</span>(<span class="hljs-params">sample</span>):
    <span class="hljs-comment"># Build the prompt in the fine-tuning format, leaving the answer section empty for the model to fill</span>
    instruction = <span class="hljs-string">f&quot;### Instruction\n<span class="hljs-subst">{sample[<span class="hljs-string">&#x27;instruction&#x27;</span>]}</span>&quot;</span>
    context = <span class="hljs-string">f&quot;### Context\n<span class="hljs-subst">{sample[<span class="hljs-string">&#x27;context&#x27;</span>]}</span>&quot;</span> <span class="hljs-keyword">if</span> <span class="hljs-string">&quot;context&quot;</span> <span class="hljs-keyword">in</span> sample <span class="hljs-keyword">else</span> <span class="hljs-literal">None</span>
    response = <span class="hljs-string">f&quot;### Answer\n&quot;</span>
    prompt = <span class="hljs-string">&quot;\n\n&quot;</span>.join([i <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> [instruction, context, response] <span class="hljs-keyword">if</span> i <span class="hljs-keyword">is</span> <span class="hljs-keyword">not</span> <span class="hljs-literal">None</span>])
    <span class="hljs-keyword">return</span> prompt

<span class="hljs-keyword">def</span> <span class="hljs-title function_">generate</span>(<span class="hljs-params">sample</span>):
    prompt = format_dolly_inference(sample)
    inputs = tokenizer(prompt, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>)
    outputs = model.generate(
        **inputs,
        max_new_tokens=<span class="hljs-number">512</span>,
        do_sample=<span class="hljs-literal">True</span>,
        temperature=<span class="hljs-number">0.9</span>,
        top_k=<span class="hljs-number">50</span>,
        top_p=<span class="hljs-number">0.9</span>
    )
    <span class="hljs-comment"># The returned sequence starts with the prompt tokens: decode only the newly generated ones.</span>
    <span class="hljs-comment"># Slicing the decoded text by len(prompt) breaks when the tokenizer adds special tokens (e.g. BOS).</span>
    prompt_length = inputs[<span class="hljs-string">&quot;input_ids&quot;</span>].shape[-<span class="hljs-number">1</span>]
    <span class="hljs-keyword">return</span> tokenizer.decode(outputs[<span class="hljs-number">0</span>][prompt_length:], skip_special_tokens=<span class="hljs-literal">True</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ytr5g6">Let’s test inference. First we test without a context.</p> <p data-svelte-h="svelte-1j8w2o0"><em>Note: Inference is not expected to be super fast on AWS Trainium using 2 cores. For inference, we recommend using Inferentia2.</em></p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->prompt = {
<span class="hljs-string">&quot;instruction&quot;</span>: <span class="hljs-string">&quot;Can you tell me something about AWS?&quot;</span>
}
res = generate(prompt)
<span class="hljs-built_in">print</span>(res)<!-- HTML_TAG_END --></pre></div> <blockquote data-svelte-h="svelte-6l4k0q"><p>AWS stands for Amazon Web Services. AWS is a suite of remote computing services offered by Amazon. The most widely used of these include Amazon Elastic Compute Cloud (Amazon EC2), which provides resizable compute capacity in the cloud; Amazon Simple Storage Service (Amazon S3), which is an object storage service; and Amazon Elastic Block Store (Amazon EBS), which is designed to provide high performance, durable block storage volumes for use with AWS instances. AWS also provides other services, such as AWS Identity and Access Management (IAM), a service that enables organizations to control access to their AWS resources, and AWS Key Management Service (AWS KMS), which helps customers create and control the use of encryption keys.</p></blockquote> <p data-svelte-h="svelte-ovzcvi">That looks correct. Now, let’s add some context, e.g., as you would for RAG applications:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 
transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->prompt = {
<span class="hljs-string">&quot;instruction&quot;</span>: <span class="hljs-string">&quot;How can I train models on AWS Trainium?&quot;</span>,
<span class="hljs-string">&quot;context&quot;</span>: <span class="hljs-string">&quot;🤗 Optimum Neuron is the interface between the 🤗 Transformers library and AWS Accelerators including [AWS Trainium](https://aws.amazon.com/machine-learning/trainium/?nc1=h_ls) and [AWS Inferentia](https://aws.amazon.com/machine-learning/inferentia/?nc1=h_ls). It provides a set of tools enabling easy model loading, training and inference on single- and multi-Accelerator settings for different downstream tasks.&quot;</span>
}
res = generate(prompt)
<span class="hljs-built_in">print</span>(res)<!-- HTML_TAG_END --></pre></div> <blockquote data-svelte-h="svelte-15xrtpx"><p>You can use the Optimum Neuron interface to train models on AWS Trainium.</p></blockquote> <p data-svelte-h="svelte-q9f4rf">Awesome, our model also correctly uses the provided context. We are done. Congrats on fine-tuning Llama on AWS Trainium.</p> <p></p>
<script>
// Client bootstrap for this docs page. The hashed asset names suggest it is generated by the docs build (presumably should not be edited by hand).
{
// Global config object; presumably read by the SvelteKit runtime. It holds the asset and base URL prefixes for this docs version.
__sveltekit_1shq1jv = {
assets: "/docs/optimum.neuron/v0.0.28.dev2/en",
base: "/docs/optimum.neuron/v0.0.28.dev2/en",
env: {}
};
// The app is started on the element that contains this script.
const element = document.currentScript.parentElement;
// NOTE(review): one entry per route node; both are null here, presumably meaning no server-loaded data. Confirm.
const data = [null,null];
// Load the runtime entry (start) and the app entry in parallel, then start the client on `element`.
Promise.all([
import("/docs/optimum.neuron/v0.0.28.dev2/en/_app/immutable/entry/start.13c1f5a3.js"),
import("/docs/optimum.neuron/v0.0.28.dev2/en/_app/immutable/entry/app.a71d5dce.js")
]).then(([kit, app]) => {
// NOTE(review): node_ids presumably index the layout/page nodes for this route. TODO confirm against the build output.
kit.start(app, element, {
node_ids: [0, 27],
data,
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
68.1 kB
·
Xet hash:
4d812eecb5cd9940d1e83022e19314a3aac2052955687f0013bbe7f2fe2801a3

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.