Buckets:

hf-doc-build/doc-dev / peft /pr_3207 /en /task_guides /lora_based_methods.html
download
raw
52.5 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;LoRA methods&quot;,&quot;local&quot;:&quot;lora-methods&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Dataset&quot;,&quot;local&quot;:&quot;dataset&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Model&quot;,&quot;local&quot;:&quot;model&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;PEFT configuration and model&quot;,&quot;local&quot;:&quot;peft-configuration-and-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Training&quot;,&quot;local&quot;:&quot;training&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Share your model&quot;,&quot;local&quot;:&quot;share-your-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Inference&quot;,&quot;local&quot;:&quot;inference&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/peft/pr_3207/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/peft/pr_3207/en/_app/immutable/entry/start.dbacf4b8.js">
<link rel="modulepreload" href="/docs/peft/pr_3207/en/_app/immutable/chunks/scheduler.78382b47.js">
<link rel="modulepreload" href="/docs/peft/pr_3207/en/_app/immutable/chunks/singletons.2397427a.js">
<link rel="modulepreload" href="/docs/peft/pr_3207/en/_app/immutable/chunks/index.fadd215c.js">
<link rel="modulepreload" href="/docs/peft/pr_3207/en/_app/immutable/chunks/paths.6fea2ad6.js">
<link rel="modulepreload" href="/docs/peft/pr_3207/en/_app/immutable/entry/app.cb03e0d4.js">
<link rel="modulepreload" href="/docs/peft/pr_3207/en/_app/immutable/chunks/preload-helper.2f407600.js">
<link rel="modulepreload" href="/docs/peft/pr_3207/en/_app/immutable/chunks/index.6dd35eb6.js">
<link rel="modulepreload" href="/docs/peft/pr_3207/en/_app/immutable/nodes/0.18735116.js">
<link rel="modulepreload" href="/docs/peft/pr_3207/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/peft/pr_3207/en/_app/immutable/nodes/71.cae5f1db.js">
<link rel="modulepreload" href="/docs/peft/pr_3207/en/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.d25d6883.js">
<link rel="modulepreload" href="/docs/peft/pr_3207/en/_app/immutable/chunks/CodeBlock.147ab5db.js">
<link rel="modulepreload" href="/docs/peft/pr_3207/en/_app/immutable/chunks/HfOption.a1db6210.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;LoRA methods&quot;,&quot;local&quot;:&quot;lora-methods&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Dataset&quot;,&quot;local&quot;:&quot;dataset&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Model&quot;,&quot;local&quot;:&quot;model&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;PEFT configuration and model&quot;,&quot;local&quot;:&quot;peft-configuration-and-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Training&quot;,&quot;local&quot;:&quot;training&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Share your model&quot;,&quot;local&quot;:&quot;share-your-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Inference&quot;,&quot;local&quot;:&quot;inference&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 h-7 max-sm:h-7 px-2 max-sm:px-1.5 text-sm font-medium text-gray-800 border border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0 hover:text-gray-800 dark:hover:text-gray-200"><svg class="sm:size-3.5 size-3" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-7 max-sm:h-7 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible sm:size-3.5 size-3 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="lora-methods" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#lora-methods"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>LoRA methods</span></h1> <p data-svelte-h="svelte-3c4mx2">A popular way to efficiently train large models is to insert (typically in the attention blocks) smaller trainable matrices that are a low-rank decomposition of the delta weight matrix to be learnt during finetuning. The pretrained model’s original weight matrix is frozen and only the smaller matrices are updated during training. This reduces the number of trainable parameters, reducing memory usage and training time which can be very expensive for large models.</p> <p data-svelte-h="svelte-1mec8a1">There are several different ways to express the weight matrix as a low-rank decomposition, but <a href="../conceptual_guides/adapter#low-rank-adaptation-lora">Low-Rank Adaptation (LoRA)</a> is the most common method. The PEFT library supports several other LoRA variants, such as <a href="../conceptual_guides/adapter#low-rank-hadamard-product-loha">Low-Rank Hadamard Product (LoHa)</a>, <a href="../conceptual_guides/adapter#low-rank-kronecker-product-lokr">Low-Rank Kronecker Product (LoKr)</a>, and <a href="../conceptual_guides/adapter#adaptive-low-rank-adaptation-adalora">Adaptive Low-Rank Adaptation (AdaLoRA)</a>. You can learn more about how these methods work conceptually in the <a href="../conceptual_guides/adapter">Adapters</a> guide. If you’re interested in applying these methods to other tasks and use cases like semantic segmentation, token classification, take a look at our <a href="https://huggingface.co/collections/PEFT/notebooks-6573b28b33e5a4bf5b157fc1" rel="nofollow">notebook collection</a>!</p> <p data-svelte-h="svelte-wawyyc">Additionally, PEFT supports the <a href="../conceptual_guides/adapter#mixture-of-lora-experts-x-lora">X-LoRA</a> Mixture of LoRA Experts method.</p> <p data-svelte-h="svelte-3yvnz9">This guide will show you how to quickly train an image classification model - with a low-rank decomposition method - to identify the class of food shown in an image.</p> <blockquote class="tip" data-svelte-h="svelte-bg3xwr"><p>Some familiarity with the general process of training an image classification model would be really helpful and allow you to focus on the low-rank decomposition methods. If you’re new, we recommend taking a look at the <a href="https://huggingface.co/docs/transformers/tasks/image_classification" rel="nofollow">Image classification</a> guide first from the Transformers documentation. When you’re ready, come back and see how easy it is to drop PEFT in to your training!</p></blockquote> <p data-svelte-h="svelte-1rdzhb1">Before you begin, make sure you have all the necessary libraries installed.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pip install -q peft transformers datasets<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="dataset" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#dataset"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Dataset</span></h2> <p data-svelte-h="svelte-l33v6t">In this guide, you’ll use the <a href="https://huggingface.co/datasets/food101" rel="nofollow">Food-101</a> dataset which contains images of 101 food classes (take a look at the <a href="https://huggingface.co/datasets/food101/viewer/default/train" rel="nofollow">dataset viewer</a> to get a better idea of what the dataset looks like).</p> <p data-svelte-h="svelte-zwwqy7">Load the dataset with the <a href="https://huggingface.co/docs/datasets/main/en/package_reference/loading_methods#datasets.load_dataset" rel="nofollow">load_dataset</a> function.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
ds = load_dataset(<span class="hljs-string">&quot;food101&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1a5q7dy">Each food class is labeled with an integer, so to make it easier to understand what these integers represent, you’ll create a <code>label2id</code> and <code>id2label</code> dictionary to map the integer to its class label.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->labels = ds[<span class="hljs-string">&quot;train&quot;</span>].features[<span class="hljs-string">&quot;label&quot;</span>].names
label2id, id2label = <span class="hljs-built_in">dict</span>(), <span class="hljs-built_in">dict</span>()
<span class="hljs-keyword">for</span> i, label <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(labels):
label2id[label] = i
id2label[i] = label
id2label[<span class="hljs-number">2</span>]
<span class="hljs-string">&quot;baklava&quot;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-13cytaa">Load an image processor to properly resize and normalize the pixel values of the training and evaluation images.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoImageProcessor
image_processor = AutoImageProcessor.from_pretrained(<span class="hljs-string">&quot;google/vit-base-patch16-224-in21k&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1p1x6w4">You can also use the image processor to prepare some transformation functions for data augmentation and pixel scaling.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torchvision.transforms <span class="hljs-keyword">import</span> (
CenterCrop,
Compose,
Normalize,
RandomHorizontalFlip,
RandomResizedCrop,
Resize,
ToTensor,
)
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
train_transforms = Compose(
[
RandomResizedCrop(image_processor.size[<span class="hljs-string">&quot;height&quot;</span>]),
RandomHorizontalFlip(),
ToTensor(),
normalize,
]
)
val_transforms = Compose(
[
Resize(image_processor.size[<span class="hljs-string">&quot;height&quot;</span>]),
CenterCrop(image_processor.size[<span class="hljs-string">&quot;height&quot;</span>]),
ToTensor(),
normalize,
]
)
<span class="hljs-keyword">def</span> <span class="hljs-title function_">preprocess_train</span>(<span class="hljs-params">example_batch</span>):
example_batch[<span class="hljs-string">&quot;pixel_values&quot;</span>] = [train_transforms(image.convert(<span class="hljs-string">&quot;RGB&quot;</span>)) <span class="hljs-keyword">for</span> image <span class="hljs-keyword">in</span> example_batch[<span class="hljs-string">&quot;image&quot;</span>]]
<span class="hljs-keyword">return</span> example_batch
<span class="hljs-keyword">def</span> <span class="hljs-title function_">preprocess_val</span>(<span class="hljs-params">example_batch</span>):
example_batch[<span class="hljs-string">&quot;pixel_values&quot;</span>] = [val_transforms(image.convert(<span class="hljs-string">&quot;RGB&quot;</span>)) <span class="hljs-keyword">for</span> image <span class="hljs-keyword">in</span> example_batch[<span class="hljs-string">&quot;image&quot;</span>]]
<span class="hljs-keyword">return</span> example_batch<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-xbqbtr">Define the training and validation datasets, and use the <a href="https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset.set_transform" rel="nofollow">set_transform</a> function to apply the transformations on-the-fly.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->train_ds = ds[<span class="hljs-string">&quot;train&quot;</span>]
val_ds = ds[<span class="hljs-string">&quot;validation&quot;</span>]
train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-mznco7">Finally, you’ll need a data collator to create a batch of training and evaluation data and convert the labels to <code>torch.tensor</code> objects.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">def</span> <span class="hljs-title function_">collate_fn</span>(<span class="hljs-params">examples</span>):
pixel_values = torch.stack([example[<span class="hljs-string">&quot;pixel_values&quot;</span>] <span class="hljs-keyword">for</span> example <span class="hljs-keyword">in</span> examples])
labels = torch.tensor([example[<span class="hljs-string">&quot;label&quot;</span>] <span class="hljs-keyword">for</span> example <span class="hljs-keyword">in</span> examples])
<span class="hljs-keyword">return</span> {<span class="hljs-string">&quot;pixel_values&quot;</span>: pixel_values, <span class="hljs-string">&quot;labels&quot;</span>: labels}<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Model</span></h2> <p data-svelte-h="svelte-akpunr">Now let’s load a pretrained model to use as the base model. This guide uses the <a href="https://huggingface.co/google/vit-base-patch16-224-in21k" rel="nofollow">google/vit-base-patch16-224-in21k</a> model, but you can use any image classification model you want. Pass the <code>label2id</code> and <code>id2label</code> dictionaries to the model so it knows how to map the integer labels to their class labels, and you can optionally pass the <code>ignore_mismatched_sizes=True</code> parameter if you’re finetuning a checkpoint that has already been finetuned.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForImageClassification, TrainingArguments, Trainer
model = AutoModelForImageClassification.from_pretrained(
<span class="hljs-string">&quot;google/vit-base-patch16-224-in21k&quot;</span>,
label2id=label2id,
id2label=id2label,
ignore_mismatched_sizes=<span class="hljs-literal">True</span>,
)<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="peft-configuration-and-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#peft-configuration-and-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>PEFT configuration and model</span></h3> <p data-svelte-h="svelte-19j1yy8">Every PEFT method requires a configuration that holds all the parameters specifying how the PEFT method should be applied. Once the configuration is setup, pass it to the <a href="/docs/peft/pr_3207/en/package_reference/peft_model#peft.get_peft_model">get_peft_model()</a> function along with the base model to create a trainable <a href="/docs/peft/pr_3207/en/package_reference/peft_model#peft.PeftModel">PeftModel</a>.</p> <blockquote class="tip" data-svelte-h="svelte-135jo6"><p>Call the <a href="/docs/peft/pr_3207/en/package_reference/peft_model#peft.PeftModel.print_trainable_parameters">print_trainable_parameters()</a> method to compare the number of parameters of <a href="/docs/peft/pr_3207/en/package_reference/peft_model#peft.PeftModel">PeftModel</a> versus the number of parameters in the base model!</p></blockquote> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">LoRA </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">LoHa </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">LoKr </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">AdaLoRA </div></div> <div class="language-select"><p data-svelte-h="svelte-rdh6sn"><a href="../conceptual_guides/adapter#low-rank-adaptation-lora">LoRA</a> decomposes the weight update matrix into <em>two</em> smaller matrices. The size of these low-rank matrices is determined by its <em>rank</em> or <code>r</code>. A higher rank means the model has more parameters to train, but it also means the model has more learning capacity. You’ll also want to specify the <code>target_modules</code> which determine where the smaller matrices are inserted. For this guide, you’ll target the <em>query</em> and <em>value</em> matrices of the attention blocks. Other important parameters to set are <code>lora_alpha</code> (scaling factor), <code>bias</code> (whether <code>none</code>, <code>all</code> or only the LoRA bias parameters should be trained), and <code>modules_to_save</code> (the modules apart from the LoRA layers to be trained and saved). All of these parameters - and more - are found in the <a href="/docs/peft/pr_3207/en/package_reference/lora#peft.LoraConfig">LoraConfig</a>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> peft <span class="hljs-keyword">import</span> LoraConfig, get_peft_model
config = LoraConfig(
r=<span class="hljs-number">16</span>,
lora_alpha=<span class="hljs-number">16</span>,
target_modules=[<span class="hljs-string">&quot;query&quot;</span>, <span class="hljs-string">&quot;value&quot;</span>],
lora_dropout=<span class="hljs-number">0.1</span>,
bias=<span class="hljs-string">&quot;none&quot;</span>,
modules_to_save=[<span class="hljs-string">&quot;classifier&quot;</span>],
)
model = get_peft_model(model, config)
model.print_trainable_parameters()
<span class="hljs-string">&quot;trainable params: 667,493 || all params: 86,543,818 || trainable%: 0.7712775047664294&quot;</span><!-- HTML_TAG_END --></pre></div> </div> <h3 class="relative group"><a id="training" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#training"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Training</span></h3> <p data-svelte-h="svelte-wcjxic">For training, let’s use the <a href="https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer" rel="nofollow">Trainer</a> class from Transformers. The <code>Trainer</code> contains a PyTorch training loop, and when you’re ready, call <a href="https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer.train" rel="nofollow">train</a> to start training. To customize the training run, configure the training hyperparameters in the <a href="https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments" rel="nofollow">TrainingArguments</a> class. With LoRA-like methods, you can afford to use a higher batch size and learning rate.</p> <blockquote class="warning" data-svelte-h="svelte-1324k35"><p>AdaLoRA has an <a href="/docs/peft/pr_3207/en/package_reference/adalora#peft.AdaLoraModel.update_and_allocate">update_and_allocate()</a> method that should be called at each training step to update the parameter budget and mask, otherwise the adaptation step is not performed. This requires writing a custom training loop or subclassing the <a href="https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer" rel="nofollow">Trainer</a> to incorporate this method. As an example, take a look at this <a href="https://github.com/huggingface/peft/blob/912ad41e96e03652cabf47522cd876076f7a0c4f/examples/conditional_generation/peft_adalora_seq2seq.py#L120" rel="nofollow">custom training loop</a>.</p></blockquote> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments, Trainer
account = <span class="hljs-string">&quot;stevhliu&quot;</span>
peft_model_id = <span class="hljs-string">f&quot;<span class="hljs-subst">{account}</span>/google/vit-base-patch16-224-in21k-lora&quot;</span>
batch_size = <span class="hljs-number">128</span>
args = TrainingArguments(
peft_model_id,
remove_unused_columns=<span class="hljs-literal">False</span>,
eval_strategy=<span class="hljs-string">&quot;epoch&quot;</span>,
save_strategy=<span class="hljs-string">&quot;epoch&quot;</span>,
learning_rate=<span class="hljs-number">5e-3</span>,
per_device_train_batch_size=batch_size,
gradient_accumulation_steps=<span class="hljs-number">4</span>,
per_device_eval_batch_size=batch_size,
fp16=<span class="hljs-literal">True</span>,
num_train_epochs=<span class="hljs-number">5</span>,
logging_steps=<span class="hljs-number">10</span>,
load_best_model_at_end=<span class="hljs-literal">True</span>,
label_names=[<span class="hljs-string">&quot;labels&quot;</span>],
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ntqszh">Begin training with <a href="https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer.train" rel="nofollow">train</a>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->trainer = Trainer(
model,
args,
train_dataset=train_ds,
eval_dataset=val_ds,
processing_class=image_processor,
data_collator=collate_fn,
)
trainer.train()<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="share-your-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#share-your-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Share your model</span></h2> <p data-svelte-h="svelte-vaxac6">Once training is complete, you can upload your model to the Hub with the <a href="https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.push_to_hub" rel="nofollow">push_to_hub</a> method. You’ll need to login to your Hugging Face account first and enter your token when prompted.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> huggingface_hub <span class="hljs-keyword">import</span> notebook_login
notebook_login()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1qc4exu">Call <a href="https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.push_to_hub" rel="nofollow">push_to_hub</a> to save your model to your repositoy.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->model.push_to_hub(peft_model_id)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="inference" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#inference"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Inference</span></h2> <p data-svelte-h="svelte-1whrqzq">Let’s load the model from the Hub and test it out on a food image.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> peft <span class="hljs-keyword">import</span> PeftConfig, PeftModel
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoImageProcessor
<span class="hljs-keyword">from</span> PIL <span class="hljs-keyword">import</span> Image
<span class="hljs-keyword">import</span> requests
config = PeftConfig.from_pretrained(<span class="hljs-string">&quot;stevhliu/vit-base-patch16-224-in21k-lora&quot;</span>)
model = AutoModelForImageClassification.from_pretrained(
config.base_model_name_or_path,
label2id=label2id,
id2label=id2label,
ignore_mismatched_sizes=<span class="hljs-literal">True</span>,
)
model = PeftModel.from_pretrained(model, <span class="hljs-string">&quot;stevhliu/vit-base-patch16-224-in21k-lora&quot;</span>)
url = <span class="hljs-string">&quot;https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/beignets.jpeg&quot;</span>
image = Image.<span class="hljs-built_in">open</span>(requests.get(url, stream=<span class="hljs-literal">True</span>).raw)
image<!-- HTML_TAG_END --></pre></div> <div class="flex justify-center" data-svelte-h="svelte-sucph9"><img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/beignets.jpeg"></div> <p data-svelte-h="svelte-1nlb27h">Convert the image to RGB and return the underlying PyTorch tensors.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->encoding = image_processor(image.convert(<span class="hljs-string">&quot;RGB&quot;</span>), return_tensors=<span class="hljs-string">&quot;pt&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1mvhoi">Now run the model and return the predicted class!</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">with</span> torch.no_grad():
outputs = model(**encoding)
logits = outputs.logits
predicted_class_idx = logits.argmax(-<span class="hljs-number">1</span>).item()
<span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;Predicted class:&quot;</span>, model.config.id2label[predicted_class_idx])
<span class="hljs-string">&quot;Predicted class: beignets&quot;</span><!-- HTML_TAG_END --></pre></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/peft/blob/main/docs/source/task_guides/lora_based_methods.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p>
<script>
{
__sveltekit_1wz1iqk = {
assets: "/docs/peft/pr_3207/en",
base: "/docs/peft/pr_3207/en",
env: {}
};
const element = document.currentScript.parentElement;
const data = [null,null];
Promise.all([
import("/docs/peft/pr_3207/en/_app/immutable/entry/start.dbacf4b8.js"),
import("/docs/peft/pr_3207/en/_app/immutable/entry/app.cb03e0d4.js")
]).then(([kit, app]) => {
kit.start(app, element, {
node_ids: [0, 71],
data,
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
52.5 kB
·
Xet hash:
62b85b17feb5f89e46384adde4d99a97b8961c1aa9afc4394b4c6bf52ec6ac5b

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.