| <h1 class="relative group"><a id="ia3" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#ia3"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>IA3</span></h1> <p data-svelte-h="svelte-s2qd7"><a href="../conceptual_guides/ia3">IA3</a> multiplies the model’s activations (the keys and values in the self-attention and encoder-decoder attention blocks, and the intermediate activation of the position-wise feedforward network) by three learned vectors. 
This PEFT method introduces even fewer trainable parameters than LoRA, which introduces weight matrices instead of vectors. The original model’s parameters are kept frozen and only these vectors are updated. As a result, it is faster, cheaper, and more efficient to finetune for a new downstream task.</p> <p data-svelte-h="svelte-xoc01f">This guide will show you how to train a sequence-to-sequence model with IA3 to <em>generate a sentiment</em> given some financial news.</p> <blockquote class="tip" data-svelte-h="svelte-teq66z"><p>Some familiarity with the general process of training a sequence-to-sequence model is helpful and lets you focus on how to apply IA3. If you’re new, we recommend first taking a look at the <a href="https://huggingface.co/docs/transformers/tasks/translation" rel="nofollow">Translation</a> and <a href="https://huggingface.co/docs/transformers/tasks/summarization" rel="nofollow">Summarization</a> guides from the Transformers documentation. When you’re ready, come back and see how easy it is to drop PEFT into your training!</p></blockquote> <h2 class="relative group"><a id="dataset" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#dataset"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Dataset</span></h2> <p data-svelte-h="svelte-1nk73qc">You’ll use the <a href="https://huggingface.co/datasets/zeroshot/twitter-financial-news-sentiment" rel="nofollow">zeroshot/twitter-financial-news-sentiment</a> dataset. This dataset contains financial tweets labeled with sentiment (bearish, bullish, or neutral). Take a look at the <a href="https://huggingface.co/datasets/zeroshot/twitter-financial-news-sentiment/viewer" rel="nofollow">dataset viewer</a> for a better idea of the data and sentences you’ll be working with.</p> <p data-svelte-h="svelte-hq1ruh">Load the dataset with the <a href="https://huggingface.co/docs/datasets/main/en/package_reference/loading_methods#datasets.load_dataset" rel="nofollow">load_dataset</a> function. This dataset only contains a train split, so use the <code>train_test_split</code> function to create a train and validation split. Create a new <code>text_label</code> column so it is easier to understand what the <code>label</code> values <code>0</code>, <code>1</code>, and <code>2</code> mean.</p> <div class="code-block relative "><pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset | |
| ds = load_dataset(<span class="hljs-string">"zeroshot/twitter-financial-news-sentiment"</span>) | |
| ds = ds[<span class="hljs-string">"train"</span>].train_test_split(test_size=<span class="hljs-number">0.1</span>) | |
| ds[<span class="hljs-string">"validation"</span>] = ds[<span class="hljs-string">"test"</span>] | |
| <span class="hljs-keyword">del</span> ds[<span class="hljs-string">"test"</span>] | |
| classes = ds[<span class="hljs-string">"train"</span>].features[<span class="hljs-string">"label"</span>].names | |
| ds = ds.<span class="hljs-built_in">map</span>( | |
|     <span class="hljs-keyword">lambda</span> x: {<span class="hljs-string">"text_label"</span>: [classes[label] <span class="hljs-keyword">for</span> label <span class="hljs-keyword">in</span> x[<span class="hljs-string">"label"</span>]]}, | |
|     batched=<span class="hljs-literal">True</span>, | |
|     num_proc=<span class="hljs-number">1</span>, | |
| ) | |
| ds[<span class="hljs-string">"train"</span>][<span class="hljs-number">0</span>] | |
| {<span class="hljs-string">'text'</span>: <span class="hljs-string">'Morrisons reports first sales rise in four years'</span>, | |
| <span class="hljs-string">'label'</span>: <span class="hljs-number">1</span>, | |
| <span class="hljs-string">'text_label'</span>: <span class="hljs-string">'bullish'</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-aqyjx7">Load a tokenizer and create a preprocessing function that:</p> <ol data-svelte-h="svelte-1xhy766"><li>tokenizes the inputs, padding and truncating each sequence to <code>max_length</code></li> <li>applies the same tokenizer to the labels, but with a shorter <code>max_length</code> that fits the label text</li> <li>masks the padding tokens in the labels by replacing their ids with <code>-100</code>, so the loss ignores them</li></ol> <div class="code-block relative "><pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer | |
| text_column = <span class="hljs-string">"text"</span> | |
| label_column = <span class="hljs-string">"text_label"</span> | |
| max_length = <span class="hljs-number">128</span> | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"bigscience/mt0-large"</span>) | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">preprocess_function</span>(<span class="hljs-params">examples</span>): | |
|     inputs = examples[text_column] | |
|     targets = examples[label_column] | |
|     model_inputs = tokenizer(inputs, max_length=max_length, padding=<span class="hljs-string">"max_length"</span>, truncation=<span class="hljs-literal">True</span>, return_tensors=<span class="hljs-string">"pt"</span>) | |
|     labels = tokenizer(targets, max_length=<span class="hljs-number">3</span>, padding=<span class="hljs-string">"max_length"</span>, truncation=<span class="hljs-literal">True</span>, return_tensors=<span class="hljs-string">"pt"</span>) | |
|     labels = labels[<span class="hljs-string">"input_ids"</span>] | |
|     labels[labels == tokenizer.pad_token_id] = -<span class="hljs-number">100</span> | |
|     model_inputs[<span class="hljs-string">"labels"</span>] = labels | |
|     <span class="hljs-keyword">return</span> model_inputs<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ww7yef">Use the <a href="https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset.map" rel="nofollow">map</a> function to apply the preprocessing function to the entire dataset.</p> <div class="code-block relative "><pre class=""><!-- HTML_TAG_START -->processed_ds = ds.<span class="hljs-built_in">map</span>( | |
|     preprocess_function, | |
|     batched=<span class="hljs-literal">True</span>, | |
|     num_proc=<span class="hljs-number">1</span>, | |
|     remove_columns=ds[<span class="hljs-string">"train"</span>].column_names, | |
|     load_from_cache_file=<span class="hljs-literal">False</span>, | |
|     desc=<span class="hljs-string">"Running tokenizer on dataset"</span>, | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-kwwkxd">Create a training and evaluation <a href="https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader" rel="nofollow"><code>DataLoader</code></a>, and set <code>pin_memory=True</code> to speed up data transfer to the accelerator during training if your dataset samples are on a CPU.</p> <div class="code-block relative "><pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torch.utils.data <span class="hljs-keyword">import</span> DataLoader | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> default_data_collator | |
| train_ds = processed_ds[<span class="hljs-string">"train"</span>] | |
| eval_ds = processed_ds[<span class="hljs-string">"validation"</span>] | |
| batch_size = <span class="hljs-number">8</span> | |
| train_dataloader = DataLoader( | |
|     train_ds, shuffle=<span class="hljs-literal">True</span>, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=<span class="hljs-literal">True</span> | |
| ) | |
| eval_dataloader = DataLoader(eval_ds, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=<span class="hljs-literal">True</span>)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Model</span></h2> <p data-svelte-h="svelte-1d06z94">Now you can load a pretrained model to use as the base model for IA3. 
This guide uses the <a href="https://huggingface.co/bigscience/mt0-large" rel="nofollow">bigscience/mt0-large</a> model, but you can use any sequence-to-sequence model you like.</p> <div class="code-block relative "><pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForSeq2SeqLM | |
| model = AutoModelForSeq2SeqLM.from_pretrained(<span class="hljs-string">"bigscience/mt0-large"</span>)<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="peft-configuration-and-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#peft-configuration-and-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>PEFT configuration and model</span></h3> <p data-svelte-h="svelte-rlstu1">Every PEFT method needs a configuration that specifies how the method should be applied. Create an <a href="/docs/peft/pr_3207/en/package_reference/ia3#peft.IA3Config">IA3Config</a> with the task type and keep the inference mode set to <code>False</code> (the default) for training. 
You can find additional parameters for this configuration in the <a href="../package_reference/ia3#ia3config">API reference</a>.</p> <blockquote class="tip" data-svelte-h="svelte-1gmi3mo"><p>Call the <a href="/docs/peft/pr_3207/en/package_reference/peft_model#peft.PeftModel.print_trainable_parameters">print_trainable_parameters()</a> method to compare the number of trainable parameters of <a href="/docs/peft/pr_3207/en/package_reference/peft_model#peft.PeftModel">PeftModel</a> versus the number of parameters in the base model!</p></blockquote> <p data-svelte-h="svelte-1ie59t8">Once the configuration is set up, pass it to the <a href="/docs/peft/pr_3207/en/package_reference/peft_model#peft.get_peft_model">get_peft_model()</a> function along with the base model to create a trainable <a href="/docs/peft/pr_3207/en/package_reference/peft_model#peft.PeftModel">PeftModel</a>.</p> <div class="code-block relative "><pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> peft <span class="hljs-keyword">import</span> IA3Config, get_peft_model | |
| peft_config = IA3Config(task_type=<span class="hljs-string">"SEQ_2_SEQ_LM"</span>) | |
| model = get_peft_model(model, peft_config) | |
| model.print_trainable_parameters() | |
| <span class="hljs-string">"trainable params: 282,624 || all params: 1,229,863,936 || trainable%: 0.022980103060766553"</span><!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="training" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#training"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Training</span></h3> <p data-svelte-h="svelte-tlkvop">Set up an optimizer and learning rate scheduler.</p> <div class="code-block relative "><pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> get_linear_schedule_with_warmup | |
| lr = <span class="hljs-number">8e-3</span> | |
| num_epochs = <span class="hljs-number">3</span> | |
| optimizer = torch.optim.AdamW(model.parameters(), lr=lr) | |
| lr_scheduler = get_linear_schedule_with_warmup( | |
|     optimizer=optimizer, | |
|     num_warmup_steps=<span class="hljs-number">0</span>, | |
|     num_training_steps=(<span class="hljs-built_in">len</span>(train_dataloader) * num_epochs), | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-166cxn8">Move the model to the accelerator and create a training loop that reports the loss and perplexity for each epoch.</p> <div class="code-block relative "><pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> tqdm <span class="hljs-keyword">import</span> tqdm | |
| device = torch.accelerator.current_accelerator().<span class="hljs-built_in">type</span> <span class="hljs-keyword">if</span> <span class="hljs-built_in">hasattr</span>(torch, <span class="hljs-string">"accelerator"</span>) <span class="hljs-keyword">else</span> <span class="hljs-string">"cuda"</span> | |
| model = model.to(device) | |
| <span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(num_epochs): | |
|     model.train() | |
|     total_loss = <span class="hljs-number">0</span> | |
|     <span class="hljs-keyword">for</span> step, batch <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(tqdm(train_dataloader)): | |
|         batch = {k: v.to(device) <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> batch.items()} | |
|         outputs = model(**batch) | |
|         loss = outputs.loss | |
|         total_loss += loss.detach().<span class="hljs-built_in">float</span>() | |
|         loss.backward() | |
|         optimizer.step() | |
|         lr_scheduler.step() | |
|         optimizer.zero_grad() | |
|     model.<span class="hljs-built_in">eval</span>() | |
|     eval_loss = <span class="hljs-number">0</span> | |
|     eval_preds = [] | |
|     <span class="hljs-keyword">for</span> step, batch <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(tqdm(eval_dataloader)): | |
|         batch = {k: v.to(device) <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> batch.items()} | |
|         <span class="hljs-keyword">with</span> torch.no_grad(): | |
|             outputs = model(**batch) | |
|         loss = outputs.loss | |
|         eval_loss += loss.detach().<span class="hljs-built_in">float</span>() | |
|         eval_preds.extend( | |
|             tokenizer.batch_decode(torch.argmax(outputs.logits, -<span class="hljs-number">1</span>).detach().cpu().numpy(), skip_special_tokens=<span class="hljs-literal">True</span>) | |
|         ) | |
|     eval_epoch_loss = eval_loss / <span class="hljs-built_in">len</span>(eval_dataloader) | |
|     eval_ppl = torch.exp(eval_epoch_loss) | |
|     train_epoch_loss = total_loss / <span class="hljs-built_in">len</span>(train_dataloader) | |
|     train_ppl = torch.exp(train_epoch_loss) | |
|     <span class="hljs-built_in">print</span>(<span class="hljs-string">f"<span class="hljs-subst">{epoch=}</span>: <span class="hljs-subst">{train_ppl=}</span> <span class="hljs-subst">{train_epoch_loss=}</span> <span class="hljs-subst">{eval_ppl=}</span> <span class="hljs-subst">{eval_epoch_loss=}</span>"</span>)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="share-your-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#share-your-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Share your model</span></h2> <p data-svelte-h="svelte-q5liq1">After training is complete, you can upload your model to the Hub with the <a href="https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.push_to_hub" rel="nofollow">push_to_hub</a> method. 
You’ll need to log in to your Hugging Face account first and enter your token when prompted.</p> <div class="code-block relative "> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> huggingface_hub <span class="hljs-keyword">import</span> notebook_login | |
| notebook_login()  <span class="hljs-comment"># prompts for your Hugging Face token</span> | |
| account = <span class="hljs-string">"<your-hf-account-name>"</span>  <span class="hljs-comment"># replace with your Hub username</span> | |
| peft_model_id = <span class="hljs-string">f"<span class="hljs-subst">{account}</span>/mt0-large-ia3"</span> | |
| model.push_to_hub(peft_model_id)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="inference" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#inference"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Inference</span></h2> <p data-svelte-h="svelte-uypnpm">To load the model for inference, use the <a href="/docs/peft/pr_3207/en/package_reference/auto_class#peft.AutoPeftModel.from_pretrained">from_pretrained()</a> method. 
Let’s also load a sentence of financial news from the dataset to generate a sentiment for.</p> <div class="code-block relative "> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> peft <span class="hljs-keyword">import</span> AutoPeftModelForSeq2SeqLM | |
| device = torch.accelerator.current_accelerator().<span class="hljs-built_in">type</span> <span class="hljs-keyword">if</span> <span class="hljs-built_in">hasattr</span>(torch, <span class="hljs-string">"accelerator"</span>) <span class="hljs-keyword">else</span> <span class="hljs-string">"cuda"</span> | |
| model = AutoPeftModelForSeq2SeqLM.from_pretrained(<span class="hljs-string">"<your-hf-account-name>/mt0-large-ia3"</span>).to(device) | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"bigscience/mt0-large"</span>) | |
| i = <span class="hljs-number">15</span> | |
| inputs = tokenizer(ds[<span class="hljs-string">"validation"</span>][text_column][i], return_tensors=<span class="hljs-string">"pt"</span>) | |
| <span class="hljs-built_in">print</span>(ds[<span class="hljs-string">"validation"</span>][text_column][i]) | |
| <span class="hljs-string">"The robust growth was the result of the inclusion of clothing chain Lindex in the Group in December 2007 ."</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1dkx89x">Call the <a href="https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate" rel="nofollow">generate</a> method to generate the predicted sentiment label.</p> <div class="code-block relative "> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">with</span> torch.no_grad(): | |
|     inputs = {k: v.to(device) <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> inputs.items()} | |
|     outputs = model.generate(input_ids=inputs[<span class="hljs-string">"input_ids"</span>], max_new_tokens=<span class="hljs-number">10</span>) | |
|     <span class="hljs-built_in">print</span>(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=<span class="hljs-literal">True</span>)) | |
| [<span class="hljs-string">'positive'</span>]<!-- HTML_TAG_END --></pre></div> <p></p> | |
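The training loop in this guide reports perplexity as the exponential of the mean per-batch loss for each epoch. The sketch below reproduces only that arithmetic in plain Python, without a model or dataloader. The helper name `perplexity` and the loss values are assumptions made for illustration and are not part of PEFT or of this guide's code.

```python
import math


def perplexity(batch_losses):
    """Return exp(mean loss) for a list of per-batch cross-entropy losses."""
    if not batch_losses:
        raise ValueError("need at least one batch loss")
    mean_loss = sum(batch_losses) / len(batch_losses)
    return math.exp(mean_loss)


# Hypothetical per-batch losses, invented purely for illustration.
example_losses = [0.9, 0.7, 0.5]
print(f"mean loss = {sum(example_losses) / len(example_losses):.3f}")
print(f"perplexity = {perplexity(example_losses):.3f}")
```

A mean loss of zero corresponds to a perplexity of 1, and both quantities move in the same direction, so either one can be used to compare epochs.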