| <meta charset="utf-8" /><meta name="hf:doc:metadata" content="{"title":"XLA Integration for TensorFlow Models","local":"xla-integration-for-tensorflow-models","sections":[{"title":"Running TF functions with XLA","local":"running-tf-functions-with-xla","sections":[],"depth":2},{"title":"Running a TF text generation model with XLA from 🤗 Transformers","local":"running-a-tf-text-generation-model-with-xla-from--transformers","sections":[],"depth":2},{"title":"Gotchas to be aware of","local":"gotchas-to-be-aware-of","sections":[],"depth":2},{"title":"Additional Resources","local":"additional-resources","sections":[],"depth":2}],"depth":1}"> | |
| <link href="/docs/transformers/main/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/entry/start.2135b7e6.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/scheduler.25b97de1.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/singletons.0f2b7d5f.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/index.e188933d.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/paths.3d04d2c6.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/entry/app.24372c84.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/index.d9030fc9.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/nodes/0.026d2fdd.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/nodes/430.0d6dbdc3.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/CodeBlock.e6cd0d95.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/DocNotebookDropdown.5ea6cb78.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/globals.7f7f1b26.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/EditOnGithub.91d95064.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"XLA Integration for TensorFlow Models","local":"xla-integration-for-tensorflow-models","sections":[{"title":"Running TF functions with XLA","local":"running-tf-functions-with-xla","sections":[],"depth":2},{"title":"Running a TF text generation model with XLA from 🤗 Transformers","local":"running-a-tf-text-generation-model-with-xla-from--transformers","sections":[],"depth":2},{"title":"Gotchas to be aware of","local":"gotchas-to-be-aware-of","sections":[],"depth":2},{"title":"Additional Resources","local":"additional-resources","sections":[],"depth":2}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="xla-integration-for-tensorflow-models" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#xla-integration-for-tensorflow-models"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>XLA Integration for TensorFlow Models</span></h1> <div class="flex space-x-1 absolute z-10 right-0 top-0"> <div class="relative 
colab-dropdown "> <button class=" " type="button"> <img alt="Open In Colab" class="!m-0" src="https://colab.research.google.com/assets/colab-badge.svg"> </button> </div> <div class="relative colab-dropdown "> <button class=" " type="button"> <img alt="Open In Studio Lab" class="!m-0" src="https://studiolab.sagemaker.aws/studiolab.svg"> </button> </div></div> <p data-svelte-h="svelte-afit4t">Accelerated Linear Algebra, dubbed XLA, is a compiler for accelerating the runtime of TensorFlow Models. From the <a href="https://www.tensorflow.org/xla" rel="nofollow">official documentation</a>:</p> <p data-svelte-h="svelte-1rjg69l">XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that can accelerate TensorFlow models with potentially no source code changes.</p> <p data-svelte-h="svelte-8rwpb7">Using XLA in TensorFlow is simple – it comes packaged inside the <code>tensorflow</code> library, and it can be triggered with the <code>jit_compile</code> argument in any graph-creating function such as <a href="https://www.tensorflow.org/guide/intro_to_graphs" rel="nofollow"><code>tf.function</code></a>. When using Keras methods like <code>fit()</code> and <code>predict()</code>, you can enable XLA simply by passing the <code>jit_compile</code> argument to <code>model.compile()</code>. 
However, XLA is not limited to these methods - it can also be used to accelerate any arbitrary <code>tf.function</code>.</p> <p data-svelte-h="svelte-1sup5o3">Several TensorFlow methods in 🤗 Transformers have been rewritten to be XLA-compatible, including text generation for models such as <a href="https://huggingface.co/docs/transformers/model_doc/gpt2" rel="nofollow">GPT2</a>, <a href="https://huggingface.co/docs/transformers/model_doc/t5" rel="nofollow">T5</a> and <a href="https://huggingface.co/docs/transformers/model_doc/opt" rel="nofollow">OPT</a>, as well as speech processing for models such as <a href="https://huggingface.co/docs/transformers/model_doc/whisper" rel="nofollow">Whisper</a>.</p> <p data-svelte-h="svelte-v22hyz">While the exact amount of speed-up is very much model-dependent, for TensorFlow text generation models inside 🤗 Transformers, we noticed a speed-up of ~100x. This document will explain how you can use XLA for these models to get the maximum amount of performance. 
We’ll also provide links to additional resources if you’re interested in learning more about the benchmarks and the design philosophy behind the XLA integration.</p> <h2 class="relative group"><a id="running-tf-functions-with-xla" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#running-tf-functions-with-xla"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Running TF functions with XLA</span></h2> <p data-svelte-h="svelte-1n8d5xl">Let us consider the following model in TensorFlow:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path 
d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> tensorflow <span class="hljs-keyword">as</span> tf | |
| model = tf.keras.Sequential( | |
| [tf.keras.layers.Dense(<span class="hljs-number">10</span>, input_shape=(<span class="hljs-number">10</span>,), activation=<span class="hljs-string">"relu"</span>), tf.keras.layers.Dense(<span class="hljs-number">5</span>, activation=<span class="hljs-string">"softmax"</span>)] | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-137qaqi">The above model accepts inputs having a dimension of <code>(10, )</code>. We can use the model for running a forward pass like so:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># Generate random inputs for the model.</span> | |
| batch_size = <span class="hljs-number">16</span> | |
| input_vector_dim = <span class="hljs-number">10</span> | |
| random_inputs = tf.random.normal((batch_size, input_vector_dim)) | |
| <span class="hljs-comment"># Run a forward pass.</span> | |
| _ = model(random_inputs)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1xkaplb">In order to run the forward pass with an XLA-compiled function, we’d need to do:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->xla_fn = tf.function(model, jit_compile=<span class="hljs-literal">True</span>) | |
| _ = xla_fn(random_inputs)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-3tnenm">The default <code>call()</code> function of the <code>model</code> is used for compiling the XLA graph. But if there’s any other model function you want to compile into XLA that’s also possible with:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->my_xla_fn = tf.function(model.my_xla_fn, jit_compile=<span class="hljs-literal">True</span>)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="running-a-tf-text-generation-model-with-xla-from--transformers" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#running-a-tf-text-generation-model-with-xla-from--transformers"><span><svg class="" 
xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Running a TF text generation model with XLA from 🤗 Transformers</span></h2> <p data-svelte-h="svelte-1tqra4o">To enable XLA-accelerated generation within 🤗 Transformers, you need to have a recent version of <code>transformers</code> installed. You can install it by running:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div 
class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pip install transformers --upgrade<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-5mrt6v">And then you can run the following code:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> tensorflow <span class="hljs-keyword">as</span> tf | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, TFAutoModelForCausalLM | |
| <span class="hljs-comment"># Will error if the minimal version of Transformers is not installed.</span> | |
| <span class="hljs-keyword">from</span> transformers.utils <span class="hljs-keyword">import</span> check_min_version | |
| check_min_version(<span class="hljs-string">"4.21.0"</span>) | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"openai-community/gpt2"</span>, padding_side=<span class="hljs-string">"left"</span>, pad_token=<span class="hljs-string">"</s>"</span>) | |
| model = TFAutoModelForCausalLM.from_pretrained(<span class="hljs-string">"openai-community/gpt2"</span>) | |
| input_string = [<span class="hljs-string">"TensorFlow is"</span>] | |
| <span class="hljs-comment"># One line to create an XLA generation function</span> | |
| xla_generate = tf.function(model.generate, jit_compile=<span class="hljs-literal">True</span>) | |
| tokenized_input = tokenizer(input_string, return_tensors=<span class="hljs-string">"tf"</span>) | |
| generated_tokens = xla_generate(**tokenized_input, num_beams=<span class="hljs-number">2</span>) | |
| decoded_text = tokenizer.decode(generated_tokens[<span class="hljs-number">0</span>], skip_special_tokens=<span class="hljs-literal">True</span>) | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">f"Generated -- <span class="hljs-subst">{decoded_text}</span>"</span>) | |
| <span class="hljs-comment"># Generated -- TensorFlow is an open-source, open-source, distributed-source application # framework for the</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-xz55ih">As you can see, enabling XLA on <code>generate()</code> takes just a single line of code. The rest of the code remains unchanged. However, there are a couple of gotchas in the above code snippet that are specific to XLA. You need to be aware of them to realize the speed-ups that XLA can bring. We discuss these in the following section.</p> <h2 class="relative group"><a id="gotchas-to-be-aware-of" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#gotchas-to-be-aware-of"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Gotchas to be aware of</span></h2> <p data-svelte-h="svelte-1tw7ufk">When you execute an XLA-enabled function (like <code>xla_generate()</code> above) for the first time, it internally tries to infer the computation graph, which is time-consuming. 
This process is known as <a href="https://www.tensorflow.org/guide/intro_to_graphs#when_is_a_function_tracing" rel="nofollow">“tracing”</a>.</p> <p data-svelte-h="svelte-1v603pg">You might notice that this first call is slow. Successive calls to <code>xla_generate()</code> (or any other XLA-enabled function) won’t have to infer the computation graph again, provided the inputs to the function have the same shape as those with which the computation graph was initially built. While this is not a problem for modalities with fixed input shapes (e.g., images), you must pay attention if you are working with modalities whose input shapes vary (e.g., text).</p> <p data-svelte-h="svelte-1s9agvy">To ensure <code>xla_generate()</code> always operates with the same input shapes, you can specify <code>padding</code> arguments when calling the tokenizer.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> 
Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> tensorflow <span class="hljs-keyword">as</span> tf | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, TFAutoModelForCausalLM | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"openai-community/gpt2"</span>, padding_side=<span class="hljs-string">"left"</span>, pad_token=<span class="hljs-string">"</s>"</span>) | |
| model = TFAutoModelForCausalLM.from_pretrained(<span class="hljs-string">"openai-community/gpt2"</span>) | |
| input_string = [<span class="hljs-string">"TensorFlow is"</span>] | |
| xla_generate = tf.function(model.generate, jit_compile=<span class="hljs-literal">True</span>) | |
| <span class="hljs-comment"># Here, we call the tokenizer with padding options.</span> | |
| tokenized_input = tokenizer(input_string, pad_to_multiple_of=<span class="hljs-number">8</span>, padding=<span class="hljs-literal">True</span>, return_tensors=<span class="hljs-string">"tf"</span>) | |
| generated_tokens = xla_generate(**tokenized_input, num_beams=<span class="hljs-number">2</span>) | |
| decoded_text = tokenizer.decode(generated_tokens[<span class="hljs-number">0</span>], skip_special_tokens=<span class="hljs-literal">True</span>) | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">f"Generated -- <span class="hljs-subst">{decoded_text}</span>"</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1nbul0x">This way, you can ensure that <code>xla_generate()</code> always receives inputs with the shape it was traced with, which leads to faster generation. You can verify this with the code below:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> time | |
| <span class="hljs-keyword">import</span> tensorflow <span class="hljs-keyword">as</span> tf | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, TFAutoModelForCausalLM | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"openai-community/gpt2"</span>, padding_side=<span class="hljs-string">"left"</span>, pad_token=<span class="hljs-string">"</s>"</span>) | |
| model = TFAutoModelForCausalLM.from_pretrained(<span class="hljs-string">"openai-community/gpt2"</span>) | |
| xla_generate = tf.function(model.generate, jit_compile=<span class="hljs-literal">True</span>) | |
| <span class="hljs-keyword">for</span> input_string <span class="hljs-keyword">in</span> [<span class="hljs-string">"TensorFlow is"</span>, <span class="hljs-string">"TensorFlow is a"</span>, <span class="hljs-string">"TFLite is a"</span>]: | |
|     tokenized_input = tokenizer(input_string, pad_to_multiple_of=<span class="hljs-number">8</span>, padding=<span class="hljs-literal">True</span>, return_tensors=<span class="hljs-string">"tf"</span>) | |
|     start = time.time_ns() | |
|     generated_tokens = xla_generate(**tokenized_input, num_beams=<span class="hljs-number">2</span>) | |
|     end = time.time_ns() | |
|     <span class="hljs-built_in">print</span>(<span class="hljs-string">f"Execution time -- <span class="hljs-subst">{(end - start) / <span class="hljs-number">1e6</span>:<span class="hljs-number">.1</span>f}</span> ms\n"</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-zm3wt3">On a Tesla T4 GPU, you can expect output like this:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->Execution time -- 30819.6 ms | |
| Execution time -- 79.0 ms | |
| Execution time -- 78.9 ms<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1hsy3qe">The first call to <code>xla_generate()</code> is time-consuming because of tracing, but the successive calls are orders of magnitude faster. Keep in mind that any change in the generation options at any point will trigger re-tracing and thus slow down generation.</p> <p data-svelte-h="svelte-149gqik">We didn’t cover all the text generation options 🤗 Transformers provides in this document. We encourage you to read the documentation for advanced use cases.</p> <h2 class="relative group"><a id="additional-resources" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#additional-resources"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Additional Resources</span></h2> <p data-svelte-h="svelte-k2oxa5">Here are some additional resources if you want to delve deeper into XLA, both in 🤗 Transformers and in general.</p> <ul data-svelte-h="svelte-1eyvtjn"><li><a href="https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/91_tf_xla_generate.ipynb" rel="nofollow">This Colab Notebook</a> 
provides an interactive demonstration if you want to fiddle with the XLA-compatible encoder-decoder (like <a href="https://huggingface.co/docs/transformers/model_doc/t5" rel="nofollow">T5</a>) and decoder-only (like <a href="https://huggingface.co/docs/transformers/model_doc/gpt2" rel="nofollow">GPT2</a>) text generation models.</li> <li><a href="https://huggingface.co/blog/tf-xla-generate" rel="nofollow">This blog post</a> provides an overview of the comparison benchmarks for XLA-compatible models along with a friendly introduction to XLA in TensorFlow.</li> <li><a href="https://blog.tensorflow.org/2022/11/how-hugging-face-improved-text-generation-performance-with-xla.html" rel="nofollow">This blog post</a> discusses our design philosophy behind adding XLA support to the TensorFlow models in 🤗 Transformers.</li> <li>Recommended posts for learning more about XLA and TensorFlow graphs in general:<ul><li><a href="https://www.tensorflow.org/xla" rel="nofollow">XLA: Optimizing Compiler for Machine Learning</a></li> <li><a href="https://www.tensorflow.org/guide/intro_to_graphs" rel="nofollow">Introduction to graphs and tf.function</a></li> <li><a href="https://www.tensorflow.org/guide/function" rel="nofollow">Better performance with tf.function</a></li></ul></li></ul> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/transformers/blob/main/docs/source/en/tf_xla.md" target="_blank"><span data-svelte-h="svelte-1kd6by1"><</span> <span data-svelte-h="svelte-x0xyl0">></span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p> | |
| <script> | |
| { | |
| __sveltekit_1xexzbk = { | |
| assets: "/docs/transformers/main/en", | |
| base: "/docs/transformers/main/en", | |
| env: {} | |
| }; | |
| const element = document.currentScript.parentElement; | |
| const data = [null,null]; | |
| Promise.all([ | |
| import("/docs/transformers/main/en/_app/immutable/entry/start.2135b7e6.js"), | |
| import("/docs/transformers/main/en/_app/immutable/entry/app.24372c84.js") | |
| ]).then(([kit, app]) => { | |
| kit.start(app, element, { | |
| node_ids: [0, 430], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |