Buckets:
| <meta charset="utf-8" /><meta name="hf:doc:metadata" content="{"title":"GPU","local":"gpu","sections":[{"title":"bitsandbytes","local":"bitsandbytes","sections":[],"depth":2},{"title":"Optimum","local":"optimum","sections":[],"depth":2},{"title":"Scaled dot product attention (SDPA)","local":"scaled-dot-product-attention-sdpa","sections":[],"depth":2},{"title":"FlashAttention","local":"flashattention","sections":[{"title":"Benchmarks","local":"benchmarks","sections":[],"depth":3}],"depth":2}],"depth":1}"> | |
| <link href="/docs/transformers/pr_33892/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/entry/start.b2c4257a.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/scheduler.31fdf58d.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/singletons.9860629f.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/index.252883d5.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/paths.e85c0ec8.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/entry/app.05ef1f97.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/preload-helper.40847a0e.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/index.2f76fdf0.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/nodes/0.ca4aafa4.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/nodes/506.d9e4fd22.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/CopyLLMTxtMenu.ff482081.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.71f274cc.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/IconCopy.ac192424.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/CodeBlock.ab12f8e1.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/HfOption.fb051768.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"GPU","local":"gpu","sections":[{"title":"bitsandbytes","local":"bitsandbytes","sections":[],"depth":2},{"title":"Optimum","local":"optimum","sections":[],"depth":2},{"title":"Scaled dot product attention (SDPA)","local":"scaled-dot-product-attention-sdpa","sections":[],"depth":2},{"title":"FlashAttention","local":"flashattention","sections":[{"title":"Benchmarks","local":"benchmarks","sections":[],"depth":3}],"depth":2}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 max-sm:gap-0.5 h-6 max-sm:h-5 px-2 max-sm:px-1.5 text-[11px] max-sm:text-[9px] font-medium text-gray-800 border border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0"><svg class="w-3 h-3 max-sm:w-2.5 max-sm:h-2.5" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-6 max-sm:h-5 disabled:pointer-events-none text-sm 
text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible w-3 h-3 max-sm:w-2.5 max-sm:h-2.5 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="gpu" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#gpu"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>GPU</span></h1> <p data-svelte-h="svelte-1fwjvck">GPUs are the standard hardware for machine learning because they’re optimized for memory bandwidth and parallelism. 
With the increasing size of modern models, it’s more important than ever to make sure GPUs are used efficiently and deliver the best possible performance.</p> <p data-svelte-h="svelte-1abs0sh">This guide demonstrates a few ways to optimize inference on a GPU. The optimization methods shown below can be combined with each other to achieve even better performance, and they also work for distributed GPUs.</p> <h2 class="relative group"><a id="bitsandbytes" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#bitsandbytes"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>bitsandbytes</span></h2> <p data-svelte-h="svelte-1n5togo"><a href="https://hf.co/docs/bitsandbytes/index" rel="nofollow">bitsandbytes</a> is a quantization library that supports 8-bit and 4-bit quantization. Quantization represents weights in a lower precision compared to the original full precision format. 
It reduces memory requirements and makes it easier to fit large models into memory.</p> <p data-svelte-h="svelte-1b4ixl9">Make sure bitsandbytes and Accelerate are installed first.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pip install bitsandbytes accelerate<!-- HTML_TAG_END --></pre></div> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">8-bit </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">4-bit </div></div> <div class="language-select"><p data-svelte-h="svelte-1117s4d">For text generation with 
8-bit quantization, you should use <a href="/docs/transformers/pr_33892/en/main_classes/text_generation#transformers.GenerationMixin.generate">generate()</a> instead of the high-level <a href="/docs/transformers/pr_33892/en/main_classes/pipelines#transformers.Pipeline">Pipeline</a> API. The <a href="/docs/transformers/pr_33892/en/main_classes/pipelines#transformers.Pipeline">Pipeline</a> is slower because it isn’t optimized for 8-bit models, and some sampling strategies (such as nucleus sampling) aren’t supported.</p> <p data-svelte-h="svelte-p5oeh0">Set up a <a href="/docs/transformers/pr_33892/en/main_classes/quantization#transformers.BitsAndBytesConfig">BitsAndBytesConfig</a> and set <code>load_in_8bit=True</code> to load a model in 8-bit precision. The <a href="/docs/transformers/pr_33892/en/main_classes/quantization#transformers.BitsAndBytesConfig">BitsAndBytesConfig</a> is passed to the <code>quantization_config</code> parameter in <a href="/docs/transformers/pr_33892/en/main_classes/model#transformers.PreTrainedModel.from_pretrained">from_pretrained()</a>.</p> <p data-svelte-h="svelte-189jfcc">Allow Accelerate to automatically distribute the model across your available hardware by setting <a href="https://hf.co/docs/accelerate/concept_guides/big_model_inference#designing-a-device-map" rel="nofollow"><code>device_map="auto"</code></a>.</p> <p data-svelte-h="svelte-6y3wld">Place all inputs on the same device as the model.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path 
d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> BitsAndBytesConfig, AutoTokenizer, AutoModelForCausalLM | |
| quantization_config = BitsAndBytesConfig(load_in_8bit=<span class="hljs-literal">True</span>) | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"meta-llama/Llama-3.1-8B"</span>) | |
| model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">"meta-llama/Llama-3.1-8B"</span>, device_map=<span class="hljs-string">"auto"</span>, quantization_config=quantization_config) | |
| prompt = <span class="hljs-string">"Hello, my llama is cute"</span> | |
| inputs = tokenizer(prompt, return_tensors=<span class="hljs-string">"pt"</span>).to(model.device) | |
| generated_ids = model.generate(**inputs) | |
| outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=<span class="hljs-literal">True</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-m5ityn">For distributed setups, use the <code>max_memory</code> parameter to create a mapping of the amount of memory to allocate to each GPU. The example below distributes 16GB of memory to the first GPU and 16GB of memory to the second GPU.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->max_memory_mapping = {<span class="hljs-number">0</span>: <span class="hljs-string">"16GB"</span>, <span class="hljs-number">1</span>: <span class="hljs-string">"16GB"</span>} | |
| model_8bit = AutoModelForCausalLM.from_pretrained( | |
| <span class="hljs-string">"meta-llama/Llama-3.1-8B"</span>, device_map=<span class="hljs-string">"auto"</span>, quantization_config=quantization_config, max_memory=max_memory_mapping | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1xjsovs">Learn in more detail the concepts underlying 8-bit quantization in the <a href="https://hf.co/blog/hf-bitsandbytes-integration" rel="nofollow">Gentle Introduction to 8-bit Matrix Multiplication for transformers at scale using Hugging Face Transformers, Accelerate and bitsandbytes</a> blog post.</p> </div> <h2 class="relative group"><a id="optimum" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#optimum"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Optimum</span></h2> <p data-svelte-h="svelte-1p718jw"><a href="https://hf.co/docs/optimum/en/index" rel="nofollow">Optimum</a> is a Hugging Face library focused on optimizing model performance across various hardware. 
It supports <a href="https://onnxruntime.ai/docs/" rel="nofollow">ONNX Runtime</a> (ORT), a model accelerator, for a wide range of hardware and frameworks including NVIDIA GPUs and AMD GPUs that use the <a href="https://www.amd.com/en/products/software/rocm.html" rel="nofollow">ROCm</a> stack.</p> <p data-svelte-h="svelte-pvrbx">ORT uses optimization techniques, such as fusing common operations into a single node and constant folding, to reduce the number of computations. ORT also places the most computationally intensive operations on the GPU and the rest on the CPU to intelligently distribute the workload between the two devices.</p> <p data-svelte-h="svelte-1uimet6">Optimum provides the <code>ORTModel</code> class for loading ONNX models. Set the <code>provider</code> parameter according to the table below.</p> <table data-svelte-h="svelte-1wm1a3t"><thead><tr><th>provider</th> <th>hardware</th></tr></thead> <tbody><tr><td><a href="https://hf.co/docs/optimum/main/en/onnxruntime/usage_guides/gpu#cudaexecutionprovider" rel="nofollow">CUDAExecutionProvider</a></td> <td>CUDA-enabled GPUs</td></tr> <tr><td><a href="https://hf.co/docs/optimum/onnxruntime/usage_guides/amdgpu" rel="nofollow">ROCMExecutionProvider</a></td> <td>AMD Instinct, Radeon Pro, Radeon GPUs</td></tr> <tr><td><a href="https://hf.co/docs/optimum/onnxruntime/usage_guides/gpu#tensorrtexecutionprovider" rel="nofollow">TensorrtExecutionProvider</a></td> <td>TensorRT</td></tr></tbody></table> <p data-svelte-h="svelte-1pvihnl">For example, load the <a href="https://hf.co/distilbert/distilbert-base-uncased-finetuned-sst-2-english" rel="nofollow">distilbert/distilbert-base-uncased-finetuned-sst-2-english</a> checkpoint for sequence classification. This checkpoint contains a <a href="https://hf.co/distilbert/distilbert-base-uncased-finetuned-sst-2-english/blob/main/onnx/model.onnx" rel="nofollow">model.onnx</a> file. 
If a checkpoint doesn’t have a <code>model.onnx</code> file, set <code>export=True</code> to convert a checkpoint on the fly to the ONNX format.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> optimum.onnxruntime <span class="hljs-keyword">import</span> ORTModelForSequenceClassification | |
| ort_model = ORTModelForSequenceClassification.from_pretrained( | |
| <span class="hljs-string">"distilbert/distilbert-base-uncased-finetuned-sst-2-english"</span>, | |
| <span class="hljs-comment">#export=True,</span> | |
| provider=<span class="hljs-string">"CUDAExecutionProvider"</span>, | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-s7pmol">Now you can use the model for inference in a <a href="/docs/transformers/pr_33892/en/main_classes/pipelines#transformers.Pipeline">Pipeline</a>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> optimum.pipelines <span class="hljs-keyword">import</span> pipeline | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"distilbert/distilbert-base-uncased-finetuned-sst-2-english"</span>) | |
| pipeline = pipeline(task=<span class="hljs-string">"text-classification"</span>, model=ort_model, tokenizer=tokenizer, device=<span class="hljs-string">"cuda:0"</span>) | |
| result = pipeline(<span class="hljs-string">"Both the music and visual were astounding, not to mention the actors performance."</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-tp1950">Learn more details about using ORT with Optimum in the <a href="https://hf.co/docs/optimum/onnxruntime/usage_guides/gpu#accelerated-inference-on-nvidia-gpus" rel="nofollow">Accelerated inference on NVIDIA GPUs</a> and <a href="https://hf.co/docs/optimum/onnxruntime/usage_guides/amdgpu#accelerated-inference-on-amd-gpus" rel="nofollow">Accelerated inference on AMD GPUs</a> guides.</p> <h2 class="relative group"><a id="scaled-dot-product-attention-sdpa" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#scaled-dot-product-attention-sdpa"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Scaled dot product attention (SDPA)</span></h2> <p data-svelte-h="svelte-1vc1lf7">PyTorch’s <a href="https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html" rel="nofollow">torch.nn.functional.scaled_dot_product_attention</a> (SDPA) is a native implementation of the scaled dot product attention mechanism. 
SDPA is a more efficient and optimized version of the attention mechanism used in transformer models.</p> <p data-svelte-h="svelte-1wn2wyw">There are three supported implementations available.</p> <ul data-svelte-h="svelte-n7w2td"><li><a href="https://github.com/Dao-AILab/flash-attention" rel="nofollow">FlashAttention2</a> only supports models with the fp16 or bf16 torch type. Make sure to cast your model to the appropriate type first.</li> <li><a href="https://github.com/facebookresearch/xformers" rel="nofollow">xFormers</a> or Memory-Efficient Attention supports models with the fp32 torch type.</li> <li>A C++ implementation of scaled dot product attention.</li></ul> <p data-svelte-h="svelte-17qunum">SDPA is used by default for PyTorch v2.1.1 and greater when an implementation is available. You can also explicitly enable SDPA by setting <code>attn_implementation="sdpa"</code> in <a href="/docs/transformers/pr_33892/en/main_classes/model#transformers.PreTrainedModel.from_pretrained">from_pretrained()</a>. 
Certain attention parameters, such as <code>output_attentions=True</code>, are unsupported and emit a warning that Transformers will fall back to the (slower) eager implementation.</p> <p data-svelte-h="svelte-xjpd78">Refer to the <a href="./attention_interface">AttentionInterface</a> guide to learn how to change the attention implementation after loading a model.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM | |
| model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">"meta-llama/Llama-3.1-8B"</span>, device_map=<span class="hljs-string">"auto"</span>, attn_implementation=<span class="hljs-string">"sdpa"</span>) | |
| <span class="hljs-comment"># Change the model's attention dynamically after loading it</span> | |
| model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">"meta-llama/Llama-3.1-8B"</span>, device_map=<span class="hljs-string">"auto"</span>) | |
| model.set_attention_implementation(<span class="hljs-string">"sdpa"</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1kdwdpa">SDPA selects the most performant implementation available, but you can also explicitly select an implementation with <a href="https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html" rel="nofollow">torch.nn.attention.sdpa_kernel</a> as a context manager. The example below shows how to enable the FlashAttention2 implementation by passing <code>SDPBackend.FLASH_ATTENTION</code> to <code>sdpa_kernel</code>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> torch.nn.attention <span class="hljs-keyword">import</span> SDPBackend, sdpa_kernel | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, AutoTokenizer | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"meta-llama/Llama-3.1-8B"</span>) | |
| model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">"meta-llama/Llama-3.1-8B"</span>, device_map=<span class="hljs-string">"auto"</span>) | |
| input_text = <span class="hljs-string">"Hello, my llama is cute"</span> | |
| inputs = tokenizer(input_text, return_tensors=<span class="hljs-string">"pt"</span>).to(model.device) | |
| <span class="hljs-keyword">with</span> sdpa_kernel(SDPBackend.FLASH_ATTENTION): | |
| outputs = model.generate(**inputs) | |
| <span class="hljs-built_in">print</span>(tokenizer.decode(outputs[<span class="hljs-number">0</span>], skip_special_tokens=<span class="hljs-literal">True</span>))<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1r8c7l7">If you encounter the following <code>RuntimeError</code>, try installing the nightly version of PyTorch which has broader coverage for FlashAttention.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->RuntimeError: No available kernel. Aborting execution. | |
| pip3 install -U --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="flashattention" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#flashattention"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>FlashAttention</span></h2> <p data-svelte-h="svelte-1a1ytd4"><a href="https://github.com/Dao-AILab/flash-attention" rel="nofollow">FlashAttention</a> is also available as a standalone package. 
It can significantly speed up inference by:</p> <ol data-svelte-h="svelte-1t56p9w"><li>additionally parallelizing the attention computation over sequence length</li> <li>partitioning the work between GPU threads to reduce communication and shared memory reads/writes between them</li></ol> <p data-svelte-h="svelte-1b1vtgy">Install FlashAttention first for the hardware you’re using.</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">NVIDIA </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">AMD </div></div> <div class="language-select"><div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: 
transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pip install flash-attn --no-build-isolation<!-- HTML_TAG_END --></pre></div> </div> <p data-svelte-h="svelte-12r58pl">Enable FlashAttention2 by setting <code>attn_implementation="flash_attention_2"</code> in <a href="/docs/transformers/pr_33892/en/main_classes/model#transformers.PreTrainedModel.from_pretrained">from_pretrained()</a> or by calling <code>model.set_attn_implementation("flash_attention_2")</code> to dynamically update the <a href="./attention_interface">attention interface</a>. FlashAttention2 is only supported for models with the fp16 or bf16 torch dtype. Make sure to cast your model to the appropriate data type first.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers 
<span class="hljs-keyword">import</span> AutoModelForCausalLM | |
| <span class="hljs-keyword">import</span> torch | |
| model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">"meta-llama/Llama-3.1-8B"</span>, device_map=<span class="hljs-string">"auto"</span>, dtype=torch.bfloat16, attn_implementation=<span class="hljs-string">"flash_attention_2"</span>)<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="benchmarks" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#benchmarks"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Benchmarks</span></h3> <p data-svelte-h="svelte-1dcldcy">FlashAttention2 speeds up inference considerably especially for inputs with long sequences. However, since FlashAttention2 doesn’t support computing attention scores with padding tokens, you must manually pad and unpad the attention scores for batched inference if a sequence contains padding tokens. 
The downside is that batched generation is slower with padding tokens.</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">short sequence length </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">long sequence length </div></div> <div class="language-select"><p data-svelte-h="svelte-1c98iqw">With a relatively short sequence length, the overhead of a single forward pass means the speedup is small. The graph below shows the expected speedup for a single forward pass with <a href="https://hf.co/meta-llama/Llama-2-7b-hf" rel="nofollow">meta-llama/Llama-2-7b-hf</a> with padding.</p> <div class="flex justify-center" data-svelte-h="svelte-yereht"><img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-2-small-seqlen-padding.png" alt="Expected FlashAttention2 speedup for a single forward pass with meta-llama/Llama-2-7b-hf at a short sequence length, with padding"></div> </div> <p data-svelte-h="svelte-ox9wk">To avoid this slowdown, use FlashAttention2 without padding tokens in the sequence during training. 
Pack the dataset or concatenate sequences until reaching the maximum sequence length.</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">tiiuae/falcon-7b </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">meta-llama/Llama-2-7b-hf </div></div> <div class="language-select"><p data-svelte-h="svelte-1dnv8z">The graph below shows the expected speedup for a single forward pass with <a href="https://hf.co/tiiuae/falcon-7b" rel="nofollow">tiiuae/falcon-7b</a> with a sequence length of 4096 and various batch sizes without padding tokens.</p> <div class="flex justify-center" data-svelte-h="svelte-s5vof8"><img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/falcon-7b-inference-large-seqlen.png" alt="Expected FlashAttention2 speedup for a single forward pass with tiiuae/falcon-7b at a sequence length of 4096 across batch sizes, without padding"></div> </div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/transformers/blob/main/docs/source/en/perf_infer_gpu_one.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p> | |
| <script> | |
| // SvelteKit client bootstrap for this page. Hash-suffixed names suggest generated build | |
| // output — NOTE(review): change the docs source / build, not this emitted script. | |
| { | |
| // Block statement: keeps `element` and `data` out of the script's top-level lexical scope. | |
| // Publishes the runtime path config as an implicit global (no declaration keyword). The | |
| // hash-suffixed name is presumably what the generated bundles look up — do not rename by hand. | |
| __sveltekit_16tnnm8 = { | |
| assets: "/docs/transformers/pr_33892/en", | |
| base: "/docs/transformers/pr_33892/en", | |
| env: {} | |
| }; | |
| // Must be read synchronously: document.currentScript is only set while this script is | |
| // executing (it would be null inside the .then callback below). This script's parent | |
| // element is what gets passed to kit.start. | |
| const element = document.currentScript.parentElement; | |
| // One entry per id in node_ids below; null presumably means no server-loaded data for | |
| // that node — TODO confirm. | |
| const data = [null,null]; | |
| // Load the kit runtime (start.*.js) and app entry (app.*.js) in parallel — both are also | |
| // modulepreloaded in the document head — then start the client on `element`. | |
| // NOTE(review): no rejection handler; if either import fails, kit.start never runs. | |
| Promise.all([ | |
| import("/docs/transformers/pr_33892/en/_app/immutable/entry/start.b2c4257a.js"), | |
| import("/docs/transformers/pr_33892/en/_app/immutable/entry/app.05ef1f97.js") | |
| ]).then(([kit, app]) => { | |
| kit.start(app, element, { | |
| // Presumably the root layout (nodes/0.*.js) and this page (nodes/506.*.js), both | |
| // modulepreloaded in the document head — TODO confirm. | |
| node_ids: [0, 506], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |
Xet Storage Details
- Size:
- 43.7 kB
- Xet hash:
- f2fb279e8aceb1f3373fce9b404724e19c3a1a9a32f55f6248ea898be7599fa3
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.