<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{"title":"Fine-tune BERT for Text Classification on AWS Trainium","local":"fine-tune-bert-for-text-classification-on-aws-trainium","sections":[{"title":"Quick intro: AWS Trainium","local":"quick-intro-aws-trainium","sections":[],"depth":2},{"title":"1. Setup AWS environment","local":"1-setup-aws-environment","sections":[],"depth":2},{"title":"2. Load and process the dataset","local":"2-load-and-process-the-dataset","sections":[],"depth":2},{"title":"3. Fine-tune BERT using Hugging Face Transformers","local":"3-fine-tune-bert-using-hugging-face-transformers","sections":[],"depth":2}],"depth":1}">
<link href="/docs/optimum.neuron/main/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/optimum.neuron/main/en/_app/immutable/entry/start.ae7452d0.js">
<link rel="modulepreload" href="/docs/optimum.neuron/main/en/_app/immutable/chunks/scheduler.a2b4ca8e.js">
<link rel="modulepreload" href="/docs/optimum.neuron/main/en/_app/immutable/chunks/singletons.afcc50b4.js">
<link rel="modulepreload" href="/docs/optimum.neuron/main/en/_app/immutable/chunks/paths.c3d1ecd8.js">
<link rel="modulepreload" href="/docs/optimum.neuron/main/en/_app/immutable/entry/app.f4665957.js">
<link rel="modulepreload" href="/docs/optimum.neuron/main/en/_app/immutable/chunks/index.d2f673cc.js">
<link rel="modulepreload" href="/docs/optimum.neuron/main/en/_app/immutable/nodes/0.35a7fce3.js">
<link rel="modulepreload" href="/docs/optimum.neuron/main/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/optimum.neuron/main/en/_app/immutable/nodes/25.38c7d538.js">
<link rel="modulepreload" href="/docs/optimum.neuron/main/en/_app/immutable/chunks/CodeBlock.792343a6.js">
<link rel="modulepreload" href="/docs/optimum.neuron/main/en/_app/immutable/chunks/Heading.675d4c1e.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"Fine-tune BERT for Text Classification on AWS Trainium","local":"fine-tune-bert-for-text-classification-on-aws-trainium","sections":[{"title":"Quick intro: AWS Trainium","local":"quick-intro-aws-trainium","sections":[],"depth":2},{"title":"1. Setup AWS environment","local":"1-setup-aws-environment","sections":[],"depth":2},{"title":"2. Load and process the dataset","local":"2-load-and-process-the-dataset","sections":[],"depth":2},{"title":"3. Fine-tune BERT using Hugging Face Transformers","local":"3-fine-tune-bert-using-hugging-face-transformers","sections":[],"depth":2}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="fine-tune-bert-for-text-classification-on-aws-trainium" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#fine-tune-bert-for-text-classification-on-aws-trainium"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Fine-tune BERT for Text Classification on AWS Trainium</span></h1> <p
data-svelte-h="svelte-zlj320"><em>There is a notebook version of this tutorial <a href="https://github.com/huggingface/optimum-neuron/blob/main/notebooks/text-classification/notebook.ipynb" rel="nofollow">here</a></em>.</p> <p data-svelte-h="svelte-fq2j94">This tutorial will help you get started with <a href="https://aws.amazon.com/machine-learning/trainium/?nc1=h_ls" rel="nofollow">AWS Trainium</a> and Hugging Face Transformers. It covers how to set up a Trainium instance on AWS and how to load and fine-tune a Transformers model for text classification.</p> <p data-svelte-h="svelte-1hahfn0">You will learn how to:</p> <ol data-svelte-h="svelte-i2dg6s"><li><a href="#1-setup-aws-environment">Setup AWS environment</a></li> <li><a href="#2-load-and-process-the-dataset">Load and process the dataset</a></li> <li><a href="#3-fine-tune-bert-using-hugging-face-transformers">Fine-tune BERT using Hugging Face Transformers and Optimum Neuron</a></li></ol> <p data-svelte-h="svelte-1dtnb6s">Before we start, make sure you have a <a href="https://huggingface.co/join" rel="nofollow">Hugging Face account</a> to save artifacts and experiments.</p> <h2 class="relative group"><a id="quick-intro-aws-trainium" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#quick-intro-aws-trainium"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0
11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Quick intro: AWS Trainium</span></h2> <p data-svelte-h="svelte-i1bq8"><a href="https://aws.amazon.com/de/ec2/instance-types/trn1/" rel="nofollow">AWS Trainium (Trn1)</a> instances are purpose-built EC2 instances for deep learning (DL) training workloads. Trainium is the successor of <a href="https://aws.amazon.com/ec2/instance-types/inf1/?nc1=h_ls" rel="nofollow">AWS Inferentia</a> and focuses on high-performance training; AWS claims up to 50% cost-to-train savings over comparable GPU-based instances.</p> <p data-svelte-h="svelte-1xmmjdv">Trainium has been optimized for training natural language processing, computer vision, and recommender models. The accelerator supports a wide range of data types, including FP32, TF32, BF16, FP16, UINT8, and configurable FP8.</p> <p data-svelte-h="svelte-f3jgz1">The largest Trainium instance, the <code>trn1.32xlarge</code>, comes with 512 GB of accelerator memory, making it feasible to fine-tune models with ~10B parameters on a single instance. Below you will find an overview of the available instance types. More details are available <a href="https://aws.amazon.com/de/ec2/instance-types/trn1/#Product_details" rel="nofollow">here</a>:</p> <table data-svelte-h="svelte-1ch8aud"><thead><tr><th>instance size</th> <th>accelerators</th> <th>accelerator memory (GB)</th> <th>vCPUs</th> <th>CPU memory (GB)</th> <th>price per hour (USD)</th></tr></thead> <tbody><tr><td>trn1.2xlarge</td> <td>1</td> <td>32</td> <td>8</td> <td>32</td> <td>$1.34</td></tr> <tr><td>trn1.32xlarge</td> <td>16</td> <td>512</td> <td>128</td> <td>512</td> <td>$21.50</td></tr> <tr><td>trn1n.32xlarge (2x network bandwidth)</td> <td>16</td> <td>512</td> <td>128</td> <td>512</td> <td>$24.78</td></tr></tbody></table> <hr> <p data-svelte-h="svelte-6n93f4">Now that we know what Trainium offers, let’s get started.
🚀</p> <p data-svelte-h="svelte-xlynlq"><em>Note: This tutorial was created on a trn1.2xlarge AWS EC2 instance.</em></p> <h2 class="relative group"><a id="1-setup-aws-environment" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#1-setup-aws-environment"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>1. Setup AWS environment</span></h2> <p data-svelte-h="svelte-6ygvsn">In this example, we will use a <code>trn1.2xlarge</code> instance, which has one Trainium accelerator with two Neuron cores, together with the <a href="https://aws.amazon.com/marketplace/pp/prodview-gr3e6yiscria2" rel="nofollow">Hugging Face Neuron Deep Learning AMI</a>.</p> <p data-svelte-h="svelte-1p685c9">This tutorial doesn’t cover how to create the instance in detail. You can check out my previous blog post, <a href="https://www.philschmid.de/setup-aws-trainium" rel="nofollow">“Setting up AWS Trainium for Hugging Face Transformers”</a>, which includes a step-by-step guide to setting up the environment.</p> <p data-svelte-h="svelte-1w4fiht">Once the instance is up and running, we can SSH into it.
However, instead of developing inside a terminal, we want to use a <code>Jupyter</code> environment to prepare our dataset and launch the training. For this, we add a port-forwarding option (<code>-L</code>) to the <code>ssh</code> command, which tunnels traffic from port 8080 on our local machine to port 8080 on the Trainium instance.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->PUBLIC_DNS=<span class="hljs-string">""</span> <span class="hljs-comment"># public DNS name of the instance, e.g. ec2-3-80-....</span>
KEY_PATH=<span class="hljs-string">""</span> <span class="hljs-comment"># local path to the key file, e.g. ssh/trn.pem</span>
ssh -L 8080:localhost:8080 -i <span class="hljs-variable">$KEY_PATH</span> ubuntu@<span class="hljs-variable">$PUBLIC_DNS</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-cmik5i">We can now start our <strong><code>jupyter</code></strong> server on the instance.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->python -m notebook --allow-root --port=8080<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-si97e0">You should see the familiar <strong><code>jupyter</code></strong> output, including a URL to the notebook, for example:</p> <p data-svelte-h="svelte-7s5jat"><strong><code>http://localhost:8080/?token=8c1739aff1755bd7958c4cfccc8d08cb5da5234f61f129a9</code></strong></p> <p data-svelte-h="svelte-yorl3n">Clicking the URL opens the <strong><code>jupyter</code></strong> environment in our local browser.</p> <p
data-svelte-h="svelte-krn90s"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/optimum/neuron/tutorial-fine-tune-bert-jupyter.png" alt="Jupyter environment opened in the local browser"></p> <p data-svelte-h="svelte-1op3czr">We will use the Jupyter environment only to prepare the dataset, and then use <code>torchrun</code> to launch our training script on both Neuron cores for distributed training. Let’s create a new notebook and get started.</p> <h2 class="relative group"><a id="2-load-and-process-the-dataset" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#2-load-and-process-the-dataset"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>2. Load and process the dataset</span></h2> <p data-svelte-h="svelte-nodub0">To keep the example straightforward, we train a text classification model on the <a href="https://huggingface.co/datasets/philschmid/emotion" rel="nofollow">emotion</a> dataset.
The <code>emotion</code> dataset consists of English Twitter messages labeled with six basic emotions: anger, fear, joy, love, sadness, and surprise.</p> <p data-svelte-h="svelte-1ab0i1h">We use the <code>load_dataset()</code> function from the <a href="https://huggingface.co/docs/datasets/index" rel="nofollow">🤗 Datasets</a> library to load the <code>emotion</code> dataset.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
<span class="hljs-comment"># Dataset id from huggingface.co/datasets</span>
dataset_id = <span class="hljs-string">"philschmid/emotion"</span>
<span class="hljs-comment"># Load raw dataset</span>
raw_dataset = load_dataset(dataset_id)
<span class="hljs-built_in">print</span>(<span class="hljs-string">f"Train dataset size: <span class="hljs-subst">{<span class="hljs-built_in">len</span>(raw_dataset[<span class="hljs-string">'train'</span>])}</span>"</span>)
<span class="hljs-built_in">print</span>(<span class="hljs-string">f"Test dataset size: <span class="hljs-subst">{<span class="hljs-built_in">len</span>(raw_dataset[<span class="hljs-string">'test'</span>])}</span>"</span>)
<span class="hljs-comment"># Train dataset size: 16000</span>
<span class="hljs-comment"># Test dataset size: 2000</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-udg7sq">Let’s look at a random example from the dataset.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> random <span class="hljs-keyword">import</span> randrange
random_id = randrange(<span class="hljs-built_in">len</span>(raw_dataset[<span class="hljs-string">'train'</span>]))
raw_dataset[<span class="hljs-string">'train'</span>][random_id]
<span class="hljs-comment"># e.g. {'text': 'i feel isolated and alone in my trade', 'label': 0}</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1vv5g70">To train our model, we must convert the natural-language text into token IDs. This is done by a tokenizer, which splits the inputs into tokens and converts them to their corresponding IDs in the pre-trained vocabulary. If you want to learn more about this, check out <a href="https://huggingface.co/course/chapter6/1?fw=pt" rel="nofollow">chapter 6</a> of the <a href="https://huggingface.co/course/chapter1/1" rel="nofollow">Hugging Face Course</a>.</p> <p data-svelte-h="svelte-1iixhur">The Neuron accelerator expects inputs with a fixed shape, because a change in input shape would trigger a recompilation of the model graph. We therefore pad or truncate all samples to the same length.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span
class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer
<span class="hljs-keyword">import</span> os
<span class="hljs-comment"># Model id to load the tokenizer</span>
model_id = <span class="hljs-string">"bert-base-uncased"</span>
save_dataset_path = <span class="hljs-string">"lm_dataset"</span>
<span class="hljs-comment"># Load Tokenizer</span>
tokenizer = AutoTokenizer.from_pretrained(model_id)
<span class="hljs-comment"># Tokenize helper function</span>
<span class="hljs-keyword">def</span> <span class="hljs-title function_">tokenize</span>(<span class="hljs-params">batch</span>):
    <span class="hljs-comment"># pad/truncate every sample to the model's max length (512 tokens for bert-base-uncased)</span>
    <span class="hljs-keyword">return</span> tokenizer(batch[<span class="hljs-string">'text'</span>], padding=<span class="hljs-string">'max_length'</span>, truncation=<span class="hljs-literal">True</span>)
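# Illustration only (hypothetical helper, not part of the tutorial code):
# padding="max_length" plus truncation=True gives every sample exactly
# max_length token IDs, i.e. the static shape the Neuron compiler needs.
# Simplified sketch: the real tokenizer also keeps [CLS]/[SEP] and builds
# an attention_mask; pad_id=0 is the ID of BERT's [PAD] token.
def pad_or_truncate(ids, max_length, pad_id=0):
    ids = list(ids)[:max_length]  # drop IDs beyond max_length
    return ids + [pad_id] * (max_length - len(ids))  # right-pad short inputs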
<span class="hljs-comment"># Tokenize dataset</span>
raw_dataset = raw_dataset.rename_column(<span class="hljs-string">"label"</span>, <span class="hljs-string">"labels"</span>) <span class="hljs-comment"># the model's forward() expects a "labels" argument</span>
tokenized_dataset = raw_dataset.<span class="hljs-built_in">map</span>(tokenize, batched=<span class="hljs-literal">True</span>, remove_columns=[<span class="hljs-string">"text"</span>])
tokenized_dataset = tokenized_dataset.with_format(<span class="hljs-string">"torch"</span>)
<span class="hljs-comment"># Save both splits to disk (the test split is saved as "eval")</span>
tokenized_dataset[<span class="hljs-string">"train"</span>].save_to_disk(os.path.join(save_dataset_path, <span class="hljs-string">"train"</span>))
tokenized_dataset[<span class="hljs-string">"test"</span>].save_to_disk(os.path.join(save_dataset_path, <span class="hljs-string">"eval"</span>))<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="3-fine-tune-bert-using-hugging-face-transformers" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#3-fine-tune-bert-using-hugging-face-transformers"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>3.
Fine-tune BERT using Hugging Face Transformers</span></h2> <p data-svelte-h="svelte-spmis1">Normally, you would use the <a href="https://huggingface.co/docs/transformers/v4.19.4/en/main_classes/trainer#transformers.Trainer" rel="nofollow">Trainer</a> and <a href="https://huggingface.co/docs/transformers/v4.19.4/en/main_classes/trainer#transformers.TrainingArguments" rel="nofollow">TrainingArguments</a> to fine-tune PyTorch-based Transformers models.</p> <p data-svelte-h="svelte-1ctxk4j">Together with AWS, however, we have developed a <a href="https://huggingface.co/docs/optimum-neuron/package_reference/trainer" rel="nofollow">NeuronTrainer</a> that improves performance, robustness, and safety when training on Trainium or Inferentia2 instances. The <code>NeuronTrainer</code> also comes with a <a href="https://www.notion.so/Getting-started-with-AWS-Trainium-and-Hugging-Face-Transformers-8428c72556194aed9c393de101229dcf" rel="nofollow">model cache</a>, which lets us use precompiled models and configurations from the Hugging Face Hub and skip the compilation step that would otherwise run at the beginning of training. This can reduce the total training time by ~3x.</p> <p data-svelte-h="svelte-c0yi33">The <code>NeuronTrainer</code> is part of the <code>optimum-neuron</code> library and can be used as a drop-in replacement for the <code>Trainer</code>.
You only have to adjust the imports in your training script:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-deletion">- from transformers import Trainer, TrainingArguments</span>
<span class="hljs-addition">+ from optimum.neuron import NeuronTrainer as Trainer</span>
<span class="hljs-addition">+ from optimum.neuron import NeuronTrainingArguments as TrainingArguments</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-5i5vsb">We prepared a simple <a href="https://github.com/huggingface/optimum-neuron/blob/main/notebooks/text-classification/scripts/train.py" rel="nofollow">train.py</a> training script that uses the <code>NeuronTrainer</code>, based on the <a href="https://www.philschmid.de/getting-started-pytorch-2-0-transformers#3-fine-tune--evaluate-bert-model-with-the-hugging-face-trainer" rel="nofollow">“Getting started with Pytorch 2.0 and Hugging Face Transformers”</a> blog post. Below is an excerpt:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> os
<span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_from_disk
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForSequenceClassification
<span class="hljs-keyword">from</span> optimum.neuron <span class="hljs-keyword">import</span> NeuronTrainingArguments <span class="hljs-keyword">as</span> TrainingArguments
<span class="hljs-keyword">from</span> optimum.neuron <span class="hljs-keyword">import</span> NeuronTrainer <span class="hljs-keyword">as</span> Trainer
<span class="hljs-keyword">def</span> <span class="hljs-title function_">parse_args</span>():
    ...
<span class="hljs-keyword">def</span> <span class="hljs-title function_">training_function</span>(<span class="hljs-params">args</span>):
    <span class="hljs-comment"># load dataset from disk and tokenizer</span>
    train_dataset = load_from_disk(os.path.join(args.dataset_path, <span class="hljs-string">"train"</span>))
    ...
    <span class="hljs-comment"># Download the model from huggingface.co/models</span>
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_id, num_labels=num_labels, label2id=label2id, id2label=id2label
    )
    training_args = TrainingArguments(
        ...
    )
    <span class="hljs-comment"># Create Trainer instance</span>
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )
    <span class="hljs-comment"># Start training</span>
    trainer.train()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-z4rrh7">We can download the training script into our environment with the <code>wget</code> command, or manually copy it into the notebook from <a href="https://github.com/huggingface/optimum-neuron/blob/main/notebooks/text-classification/scripts/train.py" rel="nofollow">here</a>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->!wget https://raw.githubusercontent.com/huggingface/optimum-neuron/main/notebooks/text-classification/scripts/train.py<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1tr68b8">We will use <code>torchrun</code> to launch our training script on both Neuron cores for distributed training. <code>torchrun</code> starts <code>nproc_per_node</code> worker processes (here, one per Neuron core) and sets up the environment they need to join the same distributed, data-parallel training run. We pass the number of processes as the <code>nproc_per_node</code> argument and our hyperparameters as arguments to the training script.</p> <p data-svelte-h="svelte-folbj9">We’ll use the following command to launch training:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->!torchrun --nproc_per_node=2 train.py \
| --model_id bert-base-uncased \ | |
| --dataset_path lm_dataset \ | |
| --lr 5e-5 \ | |
| --per_device_train_batch_size 16 \ | |
| --bf16 True \ | |
| --epochs 3<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-161fdea"><em><strong>Note</strong>: If you see unexpectedly poor accuracy, try disabling <code>bf16</code> for now.</em></p> <p data-svelte-h="svelte-dotc0h">Training completed in about 9 minutes (<code>train_runtime = 0:08:30</code>) and reached an evaluation F1 score of <code>0.914</code>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->***** train metrics ***** | |
| epoch = 3.0 | |
| train_runtime = 0:08:30 | |
| train_samples = 16000 | |
| train_samples_per_second = 96.337 | |
| ***** <span class="hljs-built_in">eval</span> metrics ***** | |
| eval_f1 = 0.914 | |
| eval_runtime = 0:00:08<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1iox2bz">Last but not least, terminate the EC2 instance to avoid unnecessary charges. In terms of price-performance, the training run cost only about <strong><code>$0.20</code></strong> (<strong><code>$1.34/h × 0.15 h ≈ $0.20</code></strong>, with 0.15 h covering the roughly 9-minute run).</p> <p></p> | |
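<p>The launch configuration and the reported numbers above can be sanity-checked with a few lines of Python. This is an illustrative sketch, not part of the training script: the helper names (<code>torchrun_context</code>, <code>optimizer_steps</code>, <code>runtime_hours</code>, <code>training_cost</code>) are hypothetical, and the step count assumes no gradient accumulation and that the 16,000 training samples are split evenly across the two processes. The environment variable names are the ones <code>torchrun</code> exports to each worker.</p>

```python
import math
import os


def torchrun_context(env=None):
    """Read the rank information torchrun exports to each worker process.

    Falls back to a single-process setup when the script is started
    without torchrun (the variables are then absent).
    """
    env = os.environ if env is None else env
    rank = int(env.get("RANK", 0))              # global index of this process
    local_rank = int(env.get("LOCAL_RANK", 0))  # index on this machine, e.g. NeuronCore 0 or 1
    world_size = int(env.get("WORLD_SIZE", 1))  # total number of worker processes
    return rank, local_rank, world_size


def optimizer_steps(num_samples, per_device_batch, world_size, epochs, grad_accum=1):
    """Approximate total optimizer steps for data-parallel training.

    Assumption (hypothetical helper): every step consumes per_device_batch
    samples on each process (times grad_accum), and a final partial batch
    still counts as one step.
    """
    global_batch = per_device_batch * world_size * grad_accum
    return math.ceil(num_samples / global_batch) * epochs


def runtime_hours(hms):
    """Convert an 'H:MM:SS' runtime string, as printed in the metrics, to hours."""
    hours, minutes, seconds = (int(part) for part in hms.split(":"))
    return hours + minutes / 60 + seconds / 3600


def training_cost(hourly_usd, hours):
    """On-demand cost of running the instance for the given number of hours."""
    return hourly_usd * hours


if __name__ == "__main__":
    # (0, 0, 1) when run directly; per-worker values when launched via torchrun.
    print(torchrun_context())
    # Values taken from the torchrun command and the train metrics above.
    print(optimizer_steps(16000, 16, 2, 3))                          # 1500
    print(round(training_cost(1.34, 0.15), 2))                       # 0.2
    print(round(training_cost(1.34, runtime_hours("0:08:30")), 2))  # 0.19
```

<p>With 2 processes and a per-device batch size of 16, each optimizer step sees 32 samples, so 3 epochs over 16,000 samples come to roughly 1,500 steps. Using the exact runtime of 8 min 30 s instead of the rounded 0.15 h gives about $0.19, consistent with the ≈$0.20 estimate.</p>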
| <script> | |
| { | |
| __sveltekit_w1gmpk = { | |
| assets: "/docs/optimum.neuron/main/en", | |
| base: "/docs/optimum.neuron/main/en", | |
| env: {} | |
| }; | |
| const element = document.currentScript.parentElement; | |
| const data = [null,null]; | |
| Promise.all([ | |
| import("/docs/optimum.neuron/main/en/_app/immutable/entry/start.ae7452d0.js"), | |
| import("/docs/optimum.neuron/main/en/_app/immutable/entry/app.f4665957.js") | |
| ]).then(([kit, app]) => { | |
| kit.start(app, element, { | |
| node_ids: [0, 25], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |