<h1 class="relative group"><a id="gpu-inference" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#gpu-inference"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z"
fill="currentColor"></path></svg></span></a> <span>GPU inference</span></h1> <p data-svelte-h="svelte-1rqqrlq">Unlike CPUs, GPUs are the standard hardware choice for machine learning because they are optimized for memory bandwidth and parallelism. To keep up with the growing size of modern models, or to run these large models on existing and older hardware, you can apply several optimizations to speed up GPU inference. In this guide, you’ll learn how to use FlashAttention-2 (a more memory-efficient attention mechanism), PyTorch’s scaled dot product attention (SDPA), BetterTransformer (PyTorch’s native fastpath execution), and bitsandbytes to quantize your model to a lower precision. Finally, you’ll learn how to use 🤗 Optimum to accelerate inference with ONNX Runtime on NVIDIA and AMD GPUs.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1s2laj2">Most of the optimizations described here also apply to multi-GPU setups!</p></div> <h2 class="relative group"><a id="flashattention-2" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#flashattention-2"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1
56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>FlashAttention-2</span></h2> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-f638sf">FlashAttention-2 is experimental and may change considerably in future versions.</p></div> <p data-svelte-h="svelte-1uzubb3"><a href="https://huggingface.co/papers/2307.08691" rel="nofollow">FlashAttention-2</a> is a faster and more memory-efficient implementation of the standard attention mechanism that can significantly speed up inference by:</p> <ol data-svelte-h="svelte-1t56p9w"><li>additionally parallelizing the attention computation over the sequence length</li> <li>partitioning the work between GPU threads to reduce communication and shared-memory reads/writes between them</li></ol> <p data-svelte-h="svelte-ivbulk">FlashAttention-2 is currently supported for the following architectures:</p> <ul data-svelte-h="svelte-paac0g"><li><a href="https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel" rel="nofollow">Bark</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel" rel="nofollow">Bart</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon" rel="nofollow">Chameleon</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel" rel="nofollow">CLIP</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel" rel="nofollow">Cohere</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel" rel="nofollow">Dbrx</a></li> <li><a
href="https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel" rel="nofollow">DistilBert</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel" rel="nofollow">Gemma</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model" rel="nofollow">Gemma2</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/gpt2" rel="nofollow">GPT2</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel" rel="nofollow">GPTBigCode</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel" rel="nofollow">GPTNeo</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel" rel="nofollow">GPTNeoX</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/gptj#transformers.GPTJModel" rel="nofollow">GPT-J</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/granite#transformers.GraniteModel" rel="nofollow">Granite</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/idefics2#transformers.Idefics2Model" rel="nofollow">Idefics2</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel" rel="nofollow">Falcon</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel" rel="nofollow">JetMoe</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel" rel="nofollow">Jamba</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel" rel="nofollow">Llama</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/llava" rel="nofollow">Llava</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/llava_next" rel="nofollow">Llava-NeXT</a></li> <li><a 
href="https://huggingface.co/docs/transformers/model_doc/llava_next_video" rel="nofollow">Llava-NeXT-Video</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/llava_onevision" rel="nofollow">LLaVA-Onevision</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/vipllava" rel="nofollow">VipLlava</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/video_llava" rel="nofollow">VideoLlava</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/m2m_100" rel="nofollow">M2M100</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel" rel="nofollow">MBart</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel" rel="nofollow">Mistral</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel" rel="nofollow">Mixtral</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel" rel="nofollow">Musicgen</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel" rel="nofollow">MusicGen Melody</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/nemotron" rel="nofollow">Nemotron</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/nllb" rel="nofollow">NLLB</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel" rel="nofollow">OLMo</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/olmoe#transformers.OlmoeModel" rel="nofollow">OLMoE</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel" rel="nofollow">OPT</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel" rel="nofollow">Phi</a></li> <li><a 
href="https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model" rel="nofollow">Phi3</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel" rel="nofollow">StableLm</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model" rel="nofollow">Starcoder2</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model" rel="nofollow">Qwen2</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/qwen2_audio#transformers.Qwen2AudioEncoder" rel="nofollow">Qwen2Audio</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel" rel="nofollow">Qwen2MoE</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/qwen2_vl#transformers.Qwen2VLModel" rel="nofollow">Qwen2VL</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel" rel="nofollow">Whisper</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model" rel="nofollow">Wav2Vec2</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel" rel="nofollow">Hubert</a></li> <li><a href="https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel" rel="nofollow">data2vec_audio</a></li> <li><a href="https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel" rel="nofollow">Sew</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/siglip" rel="nofollow">SigLIP</a></li> <li><a href="https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel" rel="nofollow">UniSpeech</a></li> <li><a href="https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel" 
rel="nofollow">unispeech_sat</a></li></ul> <p data-svelte-h="svelte-1n0m8yy">You can request FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.</p> <p data-svelte-h="svelte-1u2y52a">Before you begin, make sure you have FlashAttention-2 installed.</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">NVIDIA </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">AMD </div></div> <div class="language-select"><div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pip install flash-attn --no-build-isolation<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-188nv">We strongly suggest referring to the detailed <a href="https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features" rel="nofollow">installation instructions</a> to learn more about supported hardware and data types!</p> </div> <p data-svelte-h="svelte-115y5mk">To enable FlashAttention-2, pass the argument <code>attn_implementation="flash_attention_2"</code> to <a href="/docs/transformers/main/en/model_doc/auto#transformers.AutoModel.from_pretrained">from_pretrained()</a>:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, AutoTokenizer

model_id = <span class="hljs-string">"tiiuae/falcon-7b"</span>
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map=<span class="hljs-string">"auto"</span>,  <span class="hljs-comment"># FlashAttention-2 needs the weights on a GPU</span>
    attn_implementation=<span class="hljs-string">"flash_attention_2"</span>,
)<!-- HTML_TAG_END --></pre></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1npttsz">FlashAttention-2 can only be used when the model’s dtype is <code>fp16</code> or <code>bf16</code>. Make sure to cast your model to the appropriate dtype and load it on a supported device before using FlashAttention-2.</p> <br> <p data-svelte-h="svelte-19zw97w">You can also set <code>use_flash_attention_2=True</code> to enable FlashAttention-2, but it is deprecated in favor of <code>attn_implementation="flash_attention_2"</code>.</p></div> <p data-svelte-h="svelte-g96vmn">FlashAttention-2 can be combined with other optimization techniques like quantization to speed up inference further. For example, you can combine FlashAttention-2 with 8-bit or 4-bit quantization:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = <span class="hljs-string">"tiiuae/falcon-7b"</span>
tokenizer = AutoTokenizer.from_pretrained(model_id)

<span class="hljs-comment"># load in 8-bit</span>
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=BitsAndBytesConfig(load_in_8bit=<span class="hljs-literal">True</span>),
    torch_dtype=torch.float16,  <span class="hljs-comment"># FlashAttention-2 requires fp16 or bf16</span>
    attn_implementation=<span class="hljs-string">"flash_attention_2"</span>,
)

<span class="hljs-comment"># load in 4-bit</span>
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=BitsAndBytesConfig(load_in_4bit=<span class="hljs-literal">True</span>),
    torch_dtype=torch.float16,
    attn_implementation=<span class="hljs-string">"flash_attention_2"</span>,
)<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="expected-speedups" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#expected-speedups"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Expected speedups</span></h3> <p data-svelte-h="svelte-16n4x7g">You can benefit from considerable inference speedups, especially for inputs with long sequences. However, since FlashAttention-2 does not support computing attention scores with padding tokens, the inputs must be unpadded and re-padded for batched inference when the sequence contains padding tokens.
This overhead significantly slows down batched generation with padding tokens.</p> <p data-svelte-h="svelte-1pddegu">To avoid this overhead during training, use FlashAttention-2 on sequences without padding tokens, for example by packing the dataset or <a href="https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py#L516" rel="nofollow">concatenating sequences</a> until they reach the maximum sequence length.</p> <p data-svelte-h="svelte-19bwxdm">For a single forward pass on <a href="https://hf.co/tiiuae/falcon-7b" rel="nofollow">tiiuae/falcon-7b</a> with a sequence length of 4096 and various batch sizes without padding tokens, the expected speedup is:</p> <div style="text-align: center" data-svelte-h="svelte-u3wzwi"><img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/falcon-7b-inference-large-seqlen.png"></div> <p data-svelte-h="svelte-vlhl7y">For a single forward pass on <a href="https://hf.co/meta-llama/Llama-7b-hf" rel="nofollow">meta-llama/Llama-7b-hf</a> with a sequence length of 4096 and various batch sizes without padding tokens, the expected speedup is:</p> <div style="text-align: center" data-svelte-h="svelte-1yuov6e"><img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-7b-inference-large-seqlen.png"></div> <p data-svelte-h="svelte-hsv6vu">For sequences with padding tokens (such as batched generation), the input sequences have to be unpadded and re-padded to compute the attention scores correctly. With a relatively short sequence length, this overhead limits the speedup of a single forward pass (in the example below, 30% of the input is padding tokens):</p> <div style="text-align: center" data-svelte-h="svelte-cixhj1"><img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-2-small-seqlen-padding.png"></div> <p data-svelte-h="svelte-lalk4m">For longer sequence lengths, you can expect even larger speedups:</p> <div style="text-align: center" data-svelte-h="svelte-13f0ql9"><img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-2-large-seqlen-padding.png"></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-ww329q">FlashAttention is more memory efficient, meaning you can train on much longer sequences without running into out-of-memory issues. For long sequences, memory usage can potentially be reduced by up to 20x. Take a look at the <a href="https://github.com/Dao-AILab/flash-attention" rel="nofollow">flash-attention</a> repository for more details.</p></div> <h2 class="relative group"><a id="pytorch-scaled-dot-product-attention" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#pytorch-scaled-dot-product-attention"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0
1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>PyTorch scaled dot product attention</span></h2> <p data-svelte-h="svelte-vmq8tc">PyTorch’s <a href="https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html" rel="nofollow"><code>torch.nn.functional.scaled_dot_product_attention</code></a> (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for <code>torch>=2.1.1</code> when an implementation is available. You may also set <code>attn_implementation="sdpa"</code> in <code>from_pretrained()</code> to explicitly request SDPA to be used.</p> <p data-svelte-h="svelte-17lnkhg">For now, Transformers supports SDPA inference and training for the following architectures:</p> <ul data-svelte-h="svelte-1vtzs3j"><li><a href="https://huggingface.co/docs/transformers/model_doc/albert#transformers.AlbertModel" rel="nofollow">Albert</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel" rel="nofollow">Audio Spectrogram Transformer</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel" rel="nofollow">Bart</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel" rel="nofollow">Bert</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/camembert#transformers.CamembertModel" rel="nofollow">CamemBERT</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon" 
rel="nofollow">Chameleon</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel" rel="nofollow">CLIP</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel" rel="nofollow">Cohere</a></li> <li><a href="https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel" rel="nofollow">data2vec_audio</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel" rel="nofollow">Dbrx</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel" rel="nofollow">DeiT</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader" rel="nofollow">Dpr</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel" rel="nofollow">Falcon</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel" rel="nofollow">Gemma</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model" rel="nofollow">Gemma2</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/gpt2" rel="nofollow">GPT2</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel" rel="nofollow">GPTBigCode</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel" rel="nofollow">GPTNeoX</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel" rel="nofollow">Hubert</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel" rel="nofollow">Idefics</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/granite#transformers.GraniteModel" rel="nofollow">Granite</a></li> <li><a 
href="https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel" rel="nofollow">JetMoe</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel" rel="nofollow">Jamba</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel" rel="nofollow">Llama</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/llava_onevision" rel="nofollow">LLaVA-Onevision</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel" rel="nofollow">Mistral</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel" rel="nofollow">Mixtral</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel" rel="nofollow">Musicgen</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel" rel="nofollow">MusicGen Melody</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel" rel="nofollow">OLMo</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/olmoe#transformers.OlmoeModel" rel="nofollow">OLMoE</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration" rel="nofollow">PaliGemma</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel" rel="nofollow">Phi</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model" rel="nofollow">Phi3</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel" rel="nofollow">Idefics</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel" rel="nofollow">Whisper</a></li> <li><a 
href="https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel" rel="nofollow">mBart</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model" rel="nofollow">Qwen2</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/qwen2_audio#transformers.Qwen2AudioEncoder" rel="nofollow">Qwen2Audio</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel" rel="nofollow">Qwen2MoE</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel" rel="nofollow">RoBERTa</a></li> <li><a href="https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel" rel="nofollow">Sew</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/siglip" rel="nofollow">SigLIP</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel" rel="nofollow">StableLm</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model" rel="nofollow">Starcoder2</a></li> <li><a href="https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel" rel="nofollow">UniSpeech</a></li> <li><a href="https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel" rel="nofollow">unispeech_sat</a></li> 
<li><a href="https://huggingface.co/docs/transformers/model_doc/qwen2_vl#transformers.Qwen2VLModel" rel="nofollow">Qwen2VL</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/nemotron" rel="nofollow">Nemotron</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/vit#transformers.ViTModel" rel="nofollow">ViT</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/vit_hybrid#transformers.ViTHybridModel" rel="nofollow">ViTHybrid</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/vit_mae#transformers.ViTMAEModel" rel="nofollow">ViTMAE</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/vit_msn#transformers.ViTMSNModel" rel="nofollow">ViTMSN</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/videomae#transformers.VideoMAEModel" rel="nofollow">VideoMAE</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model" rel="nofollow">wav2vec2</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel" rel="nofollow">Whisper</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel" rel="nofollow">XLM-RoBERTa</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel" rel="nofollow">XLM-RoBERTa-XL</a></li> <li><a href="https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel" rel="nofollow">YOLOS</a></li></ul> <div 
class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-ygpnc8">FlashAttention only supports models in the <code>fp16</code> or <code>bf16</code> torch dtype, so make sure to cast your model to one of these dtypes first. The memory-efficient attention backend can also handle <code>fp32</code> models.</p></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-tm3gu3">SDPA does not support certain attention arguments, such as <code>head_mask</code> and <code>output_attentions=True</code>.
When these are used, Transformers emits a warning and falls back to the slower eager implementation.</p></div> <p data-svelte-h="svelte-10n5yk">By default, SDPA selects the most performant kernel available, but you can check whether a specific backend is available for a given setting (hardware, problem size) by restricting SDPA to that backend with the <a href="https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel" rel="nofollow"><code>torch.backends.cuda.sdp_kernel</code></a> context manager (recent PyTorch releases deprecate it in favor of <code>torch.nn.attention.sdpa_kernel</code>):</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16).to("cuda")
input_text = "Hello my dog is cute and"
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
# restrict SDPA to the FlashAttention backend only
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    outputs = model.generate(**inputs)
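# Separate, self-contained CPU sketch (assumes PyTorch 2.3 or later, where
# torch.nn.attention.sdpa_kernel is available and takes the allowed backend(s) directly).
# It pins SDPA to the math backend and compares the result with the textbook formula
# softmax(Q K^T / sqrt(d)) V. The tensor shapes are arbitrary illustrative values.
import math

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))  # (batch, heads, seq_len, head_dim)
with sdpa_kernel(SDPBackend.MATH):  # only the math backend may run inside this block
    fused = F.scaled_dot_product_attention(q, k, v)
# default SDPA scale is 1 / sqrt(head_dim), no mask, no dropout
reference = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)), dim=-1) @ v
print(torch.allclose(fused, reference, atol=1e-5))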
print(tokenizer.decode(outputs[0], skip_special_tokens=True))<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1pg2hal">If you see the error below, try installing a nightly build of PyTorch, which may have broader coverage for FlashAttention:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->RuntimeError: No available kernel. Aborting execution.
<span class="hljs-comment"># install PyTorch nightly</span>
pip3 install -U --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="bettertransformer" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#bettertransformer"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>BetterTransformer</span></h2> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-1fq80om">Some BetterTransformer features are being upstreamed to Transformers with default support for native <code>torch.nn.functional.scaled_dot_product_attention</code>.
BetterTransformer still covers more architectures than the Transformers SDPA integration, but you can expect more and more architectures to support SDPA natively in Transformers.</p></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1gyhh6">Check out our benchmarks with BetterTransformer and scaled dot product attention in the <a href="https://pytorch.org/blog/out-of-the-box-acceleration/" rel="nofollow">Out of the box acceleration and memory savings of 🤗 decoder models with PyTorch 2.0</a> and learn more about the fastpath execution in the <a href="https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2" rel="nofollow">BetterTransformer</a> blog post.</p></div> <p data-svelte-h="svelte-mytrcu">BetterTransformer accelerates inference with fastpath execution, a native PyTorch implementation specialized for Transformer functions.
The fastpath execution relies on two optimizations:</p> <ol data-svelte-h="svelte-1b2ln7l"><li>fusion, which combines multiple sequential operations into a single “kernel” to reduce the number of computation steps</li> <li>skipping the inherent sparsity of padding tokens with nested tensors to avoid unnecessary computation</li></ol> <p data-svelte-h="svelte-kdzh4e">BetterTransformer also converts all attention operations to use the more memory-efficient <a href="https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention" rel="nofollow">scaled dot product attention (SDPA)</a>, and it calls optimized kernels like <a href="https://huggingface.co/papers/2205.14135" rel="nofollow">FlashAttention</a> under the hood.</p> <p data-svelte-h="svelte-1qwvxdv">Before you start, make sure you have 🤗 Optimum <a href="https://huggingface.co/docs/optimum/installation" rel="nofollow">installed</a>.</p> <p data-svelte-h="svelte-uvzvcj">Then enable BetterTransformer with the <a href="/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.to_bettertransformer">PreTrainedModel.to_bettertransformer()</a> method:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white 
py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->model = model.to_bettertransformer()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ofzbkg">You can restore the original Transformers model with the <a href="/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.reverse_bettertransformer">reverse_bettertransformer()</a> method. Call it before saving your model so that the saved model uses the canonical Transformers modeling:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->model = model.reverse_bettertransformer()
model.save_pretrained(<span class="hljs-string">"saved_model"</span>)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="bitsandbytes" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#bitsandbytes"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>bitsandbytes</span></h2> <p data-svelte-h="svelte-b0kgtg">bitsandbytes is a quantization library that supports 4-bit and 8-bit quantization. 
Quantization reduces your model size compared to its native full-precision version, making it easier to fit large models onto GPUs with limited memory.</p> <p data-svelte-h="svelte-5yw5rr">Make sure you have bitsandbytes and 🤗 Accelerate installed:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># these versions support 8-bit and 4-bit (quotes keep the shell from treating >= as a redirect)</span>
pip install "bitsandbytes>=0.39.0" "accelerate>=0.20.0"
<span class="hljs-comment"># install Transformers</span>
pip install transformers<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="4-bit" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#4-bit"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>4-bit</span></h3> <p data-svelte-h="svelte-85jp6k">To load a model in 4-bit for inference, use the <code>load_in_4bit</code> parameter. 
The <code>device_map</code> parameter is optional, but we recommend setting it to <code>"auto"</code> so that 🤗 Accelerate automatically allocates the model across the resources available in your environment.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, BitsAndBytesConfig
model_name = <span class="hljs-string">"bigscience/bloom-2b5"</span>
model_4bit = AutoModelForCausalLM.from_pretrained(model_name, device_map=<span class="hljs-string">"auto"</span>, quantization_config=BitsAndBytesConfig(load_in_4bit=<span class="hljs-literal">True</span>))<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-6x6hw6">To load a model in 4-bit for inference with multiple GPUs, you can control how much GPU memory to allocate to each GPU with the <code>max_memory</code> parameter. For example, to allocate 600MB of memory to the first GPU and 1GB to the second GPU:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->max_memory_mapping = {<span class="hljs-number">0</span>: <span class="hljs-string">"600MB"</span>, <span class="hljs-number">1</span>: <span class="hljs-string">"1GB"</span>}
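# Rough lower-bound sizing sketch for choosing max_memory values (an illustrative
# helper, not a Transformers or Accelerate API). Assumptions: bloom-3b has about 3e9
# parameters and every parameter is stored at the given bit width. In practice
# bitsandbytes only quantizes the linear layers (embeddings stay in higher precision)
# and activations need headroom too, so the real footprint is larger than this.
def estimate_weight_memory_mb(n_params, bits_per_param):
    # bytes = parameters * bits / 8, converted to decimal megabytes
    return n_params * bits_per_param / 8 / 1e6

print(estimate_weight_memory_mb(3e9, 16))  # 6000.0 (fp16 weights)
print(estimate_weight_memory_mb(3e9, 4))  # 1500.0 (4-bit weights)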
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, BitsAndBytesConfig
model_name = <span class="hljs-string">"bigscience/bloom-3b"</span>
model_4bit = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=<span class="hljs-string">"auto"</span>,
    quantization_config=BitsAndBytesConfig(load_in_4bit=<span class="hljs-literal">True</span>),
    max_memory=max_memory_mapping,
)<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="8-bit" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#8-bit"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>8-bit</span></h3> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-elufrf">To learn more about the concepts underlying 8-bit quantization, read the <a href="https://huggingface.co/blog/hf-bitsandbytes-integration" rel="nofollow">Gentle Introduction to 8-bit Matrix Multiplication for transformers at scale using Hugging Face Transformers, Accelerate and bitsandbytes</a> blog post.</p></div> <p data-svelte-h="svelte-jt61vs">To load a model in 8-bit for inference, use the <code>load_in_8bit</code> parameter. 
The <code>device_map</code> parameter is optional, but we recommend setting it to <code>"auto"</code> so that 🤗 Accelerate automatically allocates the model across the resources available in your environment:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, BitsAndBytesConfig
model_name = <span class="hljs-string">"bigscience/bloom-2b5"</span>
model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map=<span class="hljs-string">"auto"</span>, quantization_config=BitsAndBytesConfig(load_in_8bit=<span class="hljs-literal">True</span>))<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-9zkweh">If you’re loading a model in 8-bit for text generation, use the <a href="/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate">generate()</a> method instead of the <a href="/docs/transformers/main/en/main_classes/pipelines#transformers.Pipeline">Pipeline</a>, which is not optimized for 8-bit models and is slower. Some sampling strategies, like nucleus sampling, are also not supported by the <a href="/docs/transformers/main/en/main_classes/pipelines#transformers.Pipeline">Pipeline</a> for 8-bit models. You should also place all inputs on the same device as the model:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; 
"></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
model_name = <span class="hljs-string">"bigscience/bloom-2b5"</span>
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map=<span class="hljs-string">"auto"</span>, quantization_config=BitsAndBytesConfig(load_in_8bit=<span class="hljs-literal">True</span>))
prompt = <span class="hljs-string">"Hello, my llama is cute"</span>
<span class="hljs-comment"># place the inputs on the same device as the model</span>
inputs = tokenizer(prompt, return_tensors=<span class="hljs-string">"pt"</span>).to(model_8bit.device)
generated_ids = model_8bit.generate(**inputs)
outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=<span class="hljs-literal">True</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-cgkjhe">To load a model in 8-bit for inference with multiple GPUs, you can control how much GPU memory to allocate to each GPU with the <code>max_memory</code> parameter. For example, to allocate 1GB of memory to the first GPU and 2GB to the second GPU:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->max_memory_mapping = {<span class="hljs-number">0</span>: <span class="hljs-string">"1GB"</span>, <span class="hljs-number">1</span>: <span class="hljs-string">"2GB"</span>}
| model_name = <span class="hljs-string">"bigscience/bloom-3b"</span> | |
| model_8bit = AutoModelForCausalLM.from_pretrained( | |
| model_name, device_map=<span class="hljs-string">"auto"</span>, load_in_8bit=<span class="hljs-literal">True</span>, max_memory=max_memory_mapping | |
| )<!-- HTML_TAG_END --></pre></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-w7bqaw">Feel free to try running a 11 billion parameter <a href="https://colab.research.google.com/drive/1YORPWx4okIHXnjW7MSAidXN29mPVNT7F?usp=sharing" rel="nofollow">T5 model</a> or the 3 billion parameter <a href="https://colab.research.google.com/drive/1qOjXfQIAULfKvZqwCen8-MoWKGdSatZ4?usp=sharing" rel="nofollow">BLOOM model</a> for inference on Google Colab’s free tier GPUs!</p></div> <h2 class="relative group"><a id="-optimum" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#-optimum"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>🤗 Optimum</span></h2> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1wgg2ae">Learn more 
about using ORT with 🤗 Optimum in the <a href="https://huggingface.co/docs/optimum/onnxruntime/usage_guides/gpu#accelerated-inference-on-nvidia-gpus" rel="nofollow">Accelerated inference on NVIDIA GPUs</a> and <a href="https://huggingface.co/docs/optimum/onnxruntime/usage_guides/amdgpu#accelerated-inference-on-amd-gpus" rel="nofollow">Accelerated inference on AMD GPUs</a> guides. This section only provides a brief example.</p></div> <p data-svelte-h="svelte-dwwncr">ONNX Runtime (ORT) is a model accelerator that supports accelerated inference on NVIDIA GPUs and on AMD GPUs that use the <a href="https://www.amd.com/en/products/software/rocm.html" rel="nofollow">ROCm</a> stack. ORT uses optimization techniques such as fusing common operations into a single node and constant folding to reduce the number of computations performed and speed up inference. ORT also places the most computationally intensive operations on the GPU and the rest on the CPU, distributing the workload between the two devices.</p> <p data-svelte-h="svelte-bcmzuc">ORT is supported by 🤗 Optimum, which can be used with 🤗 Transformers. You’ll need to use an <a href="https://huggingface.co/docs/optimum/main/en/onnxruntime/package_reference/modeling_ort#optimum.onnxruntime.ORTModel" rel="nofollow">ORTModel</a> for the task you’re solving and specify the <code>provider</code> parameter, which can be set to <a href="https://huggingface.co/docs/optimum/onnxruntime/usage_guides/gpu#cudaexecutionprovider" rel="nofollow"><code>CUDAExecutionProvider</code></a>, <a href="https://huggingface.co/docs/optimum/onnxruntime/usage_guides/amdgpu" rel="nofollow"><code>ROCMExecutionProvider</code></a>, or <a href="https://huggingface.co/docs/optimum/onnxruntime/usage_guides/gpu#tensorrtexecutionprovider" rel="nofollow"><code>TensorrtExecutionProvider</code></a>. 
If the model hasn’t been exported to ONNX yet, set <code>export=True</code> to convert it to the ONNX format on the fly:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> optimum.onnxruntime <span class="hljs-keyword">import</span> ORTModelForSequenceClassification | |
| ort_model = ORTModelForSequenceClassification.from_pretrained( | |
| <span class="hljs-string">"distilbert/distilbert-base-uncased-finetuned-sst-2-english"</span>, | |
| export=<span class="hljs-literal">True</span>, | |
| provider=<span class="hljs-string">"CUDAExecutionProvider"</span>, | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-90y3yl">Now you’re free to use the model for inference:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> optimum.pipelines <span class="hljs-keyword">import</span> pipeline | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"distilbert/distilbert-base-uncased-finetuned-sst-2-english"</span>) | |
| <span class="hljs-comment"># name the pipeline object "pipe" so it doesn't shadow the imported pipeline() function</span> | |
| pipe = pipeline(task=<span class="hljs-string">"text-classification"</span>, model=ort_model, tokenizer=tokenizer, device=<span class="hljs-string">"cuda:0"</span>) | |
| result = pipe(<span class="hljs-string">"Both the music and visuals were astounding, not to mention the actors' performance."</span>)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="combine-optimizations" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#combine-optimizations"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Combine optimizations</span></h2> <p data-svelte-h="svelte-mv8bmm">It is often possible to combine several of the optimization techniques described above to get the best inference performance for your model. 
For example, you can load a model in 4-bit and then enable BetterTransformer with FlashAttention:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
| <span class="hljs-comment"># load model in 4-bit</span> | |
| quantization_config = BitsAndBytesConfig( | |
| load_in_4bit=<span class="hljs-literal">True</span>, | |
| bnb_4bit_compute_dtype=torch.float16 | |
| ) | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"facebook/opt-350m"</span>) | |
| model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">"facebook/opt-350m"</span>, quantization_config=quantization_config) | |
| <span class="hljs-comment"># enable BetterTransformer</span> | |
| model = model.to_bettertransformer() | |
| input_text = <span class="hljs-string">"Hello my dog is cute and"</span> | |
| inputs = tokenizer(input_text, return_tensors=<span class="hljs-string">"pt"</span>).to(<span class="hljs-string">"cuda"</span>) | |
| <span class="hljs-comment"># enable FlashAttention</span> | |
| <span class="hljs-keyword">with</span> torch.backends.cuda.sdp_kernel(enable_flash=<span class="hljs-literal">True</span>, enable_math=<span class="hljs-literal">False</span>, enable_mem_efficient=<span class="hljs-literal">False</span>): | |
| outputs = model.generate(**inputs) | |
| <span class="hljs-built_in">print</span>(tokenizer.decode(outputs[<span class="hljs-number">0</span>], skip_special_tokens=<span class="hljs-literal">True</span>))<!-- HTML_TAG_END --></pre></div> <p></p> | |
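<p>As a rough illustration of why loading in 4-bit matters when you combine optimizations, the sketch below estimates how much memory the model <em>weights</em> alone take at different precisions. It is a back-of-the-envelope calculation, not a measurement: it assumes every parameter is stored at exactly the given bit width, and it ignores activations, the KV cache, the CUDA context, the extra storage for quantization constants, and any modules that stay unquantized. The helper name <code>weight_memory_gb</code> is hypothetical, and the parameter count of roughly 350 million is an assumption taken from the <code>facebook/opt-350m</code> checkpoint name.</p>

```python
# Back-of-the-envelope estimate of weight memory; not a measurement.
# Assumption: every parameter is stored at exactly `bits_per_param` bits,
# with no overhead for quantization constants or unquantized modules.
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Return the approximate size of the weights in GB (1 GB = 1e9 bytes)."""
    bytes_per_param = bits_per_param / 8
    return num_params * bytes_per_param / 1e9


# Hypothetical parameter count, taken from the "opt-350m" checkpoint name;
# the real model's count differs slightly.
NUM_PARAMS = 350e6

# Compare full precision, half precision, 8-bit, and 4-bit storage.
for label, bits in [("fp32", 32), ("fp16", 16), ("int8", 8), ("4-bit", 4)]:
    print(f"{label:>5}: ~{weight_memory_gb(NUM_PARAMS, bits):.3f} GB")
```

<p>By the same arithmetic, the 11 billion parameter T5 model mentioned earlier needs about 22 GB for its weights in fp16 but about 11 GB in 8-bit, which helps explain why 8-bit loading can let it fit on a single GPU with 16 GB of memory once the remaining overhead is accounted for.</p>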
| <script> | |
| { | |
| __sveltekit_1xexzbk = { | |
| assets: "/docs/transformers/main/en", | |
| base: "/docs/transformers/main/en", | |
| env: {} | |
| }; | |
| const element = document.currentScript.parentElement; | |
| const data = [null,null]; | |
| Promise.all([ | |
| import("/docs/transformers/main/en/_app/immutable/entry/start.2135b7e6.js"), | |
| import("/docs/transformers/main/en/_app/immutable/entry/app.24372c84.js") | |
| ]).then(([kit, app]) => { | |
| kit.start(app, element, { | |
| node_ids: [0, 367], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |