| <meta charset="utf-8" /><meta name="hf:doc:metadata" content="{"title":"用于 TensorFlow 模型的 XLA 集成","local":"用于-tensorflow-模型的-xla-集成","sections":[{"title":"使用 XLA 运行 TensorFlow 函数","local":"使用-xla-运行-tensorflow-函数","sections":[],"depth":2},{"title":"在🤗 Transformers库中使用XLA运行TensorFlow文本生成模型","local":"在-transformers库中使用xla运行tensorflow文本生成模型","sections":[],"depth":2},{"title":"需要关注的注意事项","local":"需要关注的注意事项","sections":[],"depth":2},{"title":"附加资源","local":"附加资源","sections":[],"depth":2}],"depth":1}"> | |
| <link href="/docs/transformers/main/zh/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload"> | |
| <link rel="modulepreload" href="/docs/transformers/main/zh/_app/immutable/entry/start.a61b9c50.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/zh/_app/immutable/chunks/scheduler.9991993c.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/zh/_app/immutable/chunks/singletons.2822fe91.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/zh/_app/immutable/chunks/index.02cfeb18.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/zh/_app/immutable/chunks/paths.d66588b4.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/zh/_app/immutable/entry/app.99775688.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/zh/_app/immutable/chunks/index.7fc9a5e7.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/zh/_app/immutable/nodes/0.f4c5a5c1.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/zh/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/zh/_app/immutable/nodes/59.ff730382.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/zh/_app/immutable/chunks/CodeBlock.e11cba92.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/zh/_app/immutable/chunks/DocNotebookDropdown.a0cb4c0f.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/zh/_app/immutable/chunks/globals.7f7f1b26.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/zh/_app/immutable/chunks/EditOnGithub.84ab7f0e.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"用于 TensorFlow 模型的 XLA 集成","local":"用于-tensorflow-模型的-xla-集成","sections":[{"title":"使用 XLA 运行 TensorFlow 函数","local":"使用-xla-运行-tensorflow-函数","sections":[],"depth":2},{"title":"在🤗 Transformers库中使用XLA运行TensorFlow文本生成模型","local":"在-transformers库中使用xla运行tensorflow文本生成模型","sections":[],"depth":2},{"title":"需要关注的注意事项","local":"需要关注的注意事项","sections":[],"depth":2},{"title":"附加资源","local":"附加资源","sections":[],"depth":2}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="用于-tensorflow-模型的-xla-集成" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#用于-tensorflow-模型的-xla-集成"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>XLA Integration for TensorFlow Models</span></h1> <div class="flex space-x-1 absolute z-10 right-0 top-0"> <div class="relative colab-dropdown "> <button class=" " type="button"> <img alt="Open In Colab" class="!m-0" src="https://colab.research.google.com/assets/colab-badge.svg"> </button> </div> <div 
class="relative colab-dropdown "> <button class=" " type="button"> <img alt="Open In Studio Lab" class="!m-0" src="https://studiolab.sagemaker.aws/studiolab.svg"> </button> </div></div> <p data-svelte-h="svelte-1ab31dt">Accelerated Linear Algebra (XLA) is a compiler that speeds up the runtime of TensorFlow models. The <a href="https://www.tensorflow.org/xla" rel="nofollow">official documentation</a> describes it as follows:</p> <p data-svelte-h="svelte-1pvm1ip">XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that can accelerate TensorFlow models with potentially no source code changes.</p> <p data-svelte-h="svelte-r8rbc5">Using XLA in TensorFlow is simple: it ships inside the <code>tensorflow</code> library and can be triggered with the <code>jit_compile</code> argument of any graph-creating function, such as <a href="https://www.tensorflow.org/guide/intro_to_graphs" rel="nofollow"><code>tf.function</code></a>. When using Keras methods such as <code>fit()</code> and <code>predict()</code>, you can enable XLA simply by passing the <code>jit_compile</code> argument to <code>model.compile()</code>. XLA is not limited to these methods, however: it can also accelerate any arbitrary <code>tf.function</code>.</p> <p data-svelte-h="svelte-p7rlbq">Several TensorFlow methods in 🤗 Transformers have been rewritten to be XLA-compatible, including text generation for models such as <a href="https://huggingface.co/docs/transformers/model_doc/gpt2" rel="nofollow">GPT2</a>, <a href="https://huggingface.co/docs/transformers/model_doc/t5" rel="nofollow">T5</a> and <a href="https://huggingface.co/docs/transformers/model_doc/opt" rel="nofollow">OPT</a>, as well as speech processing for models such as <a href="https://huggingface.co/docs/transformers/model_doc/whisper" rel="nofollow">Whisper</a>.</p> <p data-svelte-h="svelte-1xuz7ym">While the exact speed-up depends heavily on the model, we observed a speed-up of roughly 100x for TensorFlow text generation models in 🤗 Transformers. This document explains how to use XLA with these models to get the best performance. It also links to additional resources in case you want to learn more about the benchmarks and the design philosophy behind the XLA integration.</p> <h2 class="relative group"><a id="使用-xla-运行-tensorflow-函数" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#使用-xla-运行-tensorflow-函数"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 
88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Running TensorFlow functions with XLA</span></h2> <p data-svelte-h="svelte-1vn6z9m">Consider the following TensorFlow model:</p> <div class="code-block relative"><pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> tensorflow <span class="hljs-keyword">as</span> tf | |
| model = tf.keras.Sequential( | |
| [tf.keras.layers.Dense(<span class="hljs-number">10</span>, input_shape=(<span class="hljs-number">10</span>,), activation=<span class="hljs-string">"relu"</span>), tf.keras.layers.Dense(<span class="hljs-number">5</span>, activation=<span class="hljs-string">"softmax"</span>)] | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-edbqcr">The model above accepts inputs of dimension <code>(10,)</code>. We can run a forward pass with it as follows:</p> <div class="code-block relative"><pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># Generate random inputs for the model.</span> | |
| batch_size = <span class="hljs-number">16</span> | |
| input_vector_dim = <span class="hljs-number">10</span> | |
| random_inputs = tf.random.normal((batch_size, input_vector_dim)) | |
| <span class="hljs-comment"># Run a forward pass.</span> | |
| _ = model(random_inputs)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1567hs3">To run the forward pass with an XLA-compiled function, wrap the model in <code>tf.function</code> with <code>jit_compile=True</code>:</p> <div class="code-block relative"><pre class=""><!-- HTML_TAG_START -->xla_fn = tf.function(model, jit_compile=<span class="hljs-literal">True</span>) | |
| _ = xla_fn(random_inputs)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1d2xanc">By default, the <code>call()</code> function of <code>model</code> is used to compile the XLA graph. You can also compile any other model function with XLA, as shown here (<code>my_xla_fn</code> is a placeholder for a method defined on your own model):</p> <div class="code-block relative"><pre class=""><!-- HTML_TAG_START -->my_xla_fn = tf.function(model.my_xla_fn, jit_compile=<span class="hljs-literal">True</span>)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="在-transformers库中使用xla运行tensorflow文本生成模型" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#在-transformers库中使用xla运行tensorflow文本生成模型"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" 
viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Running a TensorFlow text generation model with XLA in 🤗 Transformers</span></h2> <p data-svelte-h="svelte-1jrv2bf">To enable XLA-accelerated generation in 🤗 Transformers, you need a recent version of <code>transformers</code> installed. You can install it by running:</p> <div class="code-block relative"><pre class=""><!-- HTML_TAG_START -->pip install transformers --upgrade<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1xbae7q">Then you can run the following code:</p> <div class="code-block relative"><pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> tensorflow <span class="hljs-keyword">as</span> tf | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, TFAutoModelForCausalLM | |
| <span class="hljs-comment"># Will error if the minimal version of Transformers is not installed.</span> | |
| <span class="hljs-keyword">from</span> transformers.utils <span class="hljs-keyword">import</span> check_min_version | |
| check_min_version(<span class="hljs-string">"4.21.0"</span>) | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"openai-community/gpt2"</span>, padding_side=<span class="hljs-string">"left"</span>, pad_token=<span class="hljs-string">"&lt;/s&gt;"</span>) | |
| model = TFAutoModelForCausalLM.from_pretrained(<span class="hljs-string">"openai-community/gpt2"</span>) | |
| input_string = [<span class="hljs-string">"TensorFlow is"</span>] | |
| <span class="hljs-comment"># One line to create an XLA generation function</span> | |
| xla_generate = tf.function(model.generate, jit_compile=<span class="hljs-literal">True</span>) | |
| tokenized_input = tokenizer(input_string, return_tensors=<span class="hljs-string">"tf"</span>) | |
| generated_tokens = xla_generate(**tokenized_input, num_beams=<span class="hljs-number">2</span>) | |
| decoded_text = tokenizer.decode(generated_tokens[<span class="hljs-number">0</span>], skip_special_tokens=<span class="hljs-literal">True</span>) | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">f"Generated -- <span class="hljs-subst">{decoded_text}</span>"</span>) | |
| <span class="hljs-comment"># Generated -- TensorFlow is an open-source, open-source, distributed-source application # framework for the</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-k1gawp">As you can see, enabling XLA on <code>generate()</code> takes a single line of code. The rest of the code remains unchanged. However, the snippet above has a few XLA-specific gotchas that you need to be aware of to realize the speed-ups XLA can bring. They are discussed in the next section.</p> <h2 class="relative group"><a id="需要关注的注意事项" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#需要关注的注意事项"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Gotchas to be aware of</span></h2> <p data-svelte-h="svelte-1ccko4b">When you execute an XLA-enabled function (such as <code>xla_generate()</code> above) for the first time, it internally tries to infer the computation graph, which is time-consuming. This process is known as <a href="https://www.tensorflow.org/guide/intro_to_graphs#when_is_a_function_tracing" rel="nofollow">“tracing”</a>.</p> <p data-svelte-h="svelte-1eerun3">As a result, you may notice that the first generation call is not fast. Successive calls to <code>xla_generate()</code> (or any other XLA-enabled function) do not need to infer the computation graph again, provided the inputs have the same shape as the inputs the graph was originally built with. This is not a problem for modalities with fixed input shapes (for example, images), but you must pay attention when working with modalities whose input shapes vary (for example, text).</p> <p data-svelte-h="svelte-nx5jow">To ensure that <code>xla_generate()</code> always operates on the same input shapes, you can specify the <code>padding</code> arguments when calling the <code>tokenizer</code>.</p> <div class="code-block relative"><pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> tensorflow <span class="hljs-keyword">as</span> tf | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, TFAutoModelForCausalLM | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"openai-community/gpt2"</span>, padding_side=<span class="hljs-string">"left"</span>, pad_token=<span class="hljs-string">"&lt;/s&gt;"</span>) | |
| model = TFAutoModelForCausalLM.from_pretrained(<span class="hljs-string">"openai-community/gpt2"</span>) | |
| input_string = [<span class="hljs-string">"TensorFlow is"</span>] | |
| xla_generate = tf.function(model.generate, jit_compile=<span class="hljs-literal">True</span>) | |
| <span class="hljs-comment"># Here, we call the tokenizer with padding options.</span> | |
| tokenized_input = tokenizer(input_string, pad_to_multiple_of=<span class="hljs-number">8</span>, padding=<span class="hljs-literal">True</span>, return_tensors=<span class="hljs-string">"tf"</span>) | |
| generated_tokens = xla_generate(**tokenized_input, num_beams=<span class="hljs-number">2</span>) | |
| decoded_text = tokenizer.decode(generated_tokens[<span class="hljs-number">0</span>], skip_special_tokens=<span class="hljs-literal">True</span>) | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">f"Generated -- <span class="hljs-subst">{decoded_text}</span>"</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-a0s1z1">This way, the inputs to <code>xla_generate()</code> always have the shape it was traced with, which speeds up generation. You can verify this with the following code:</p> <div class="code-block relative"><pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> time | |
| <span class="hljs-keyword">import</span> tensorflow <span class="hljs-keyword">as</span> tf | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, TFAutoModelForCausalLM | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"openai-community/gpt2"</span>, padding_side=<span class="hljs-string">"left"</span>, pad_token=<span class="hljs-string">"&lt;/s&gt;"</span>) | |
| model = TFAutoModelForCausalLM.from_pretrained(<span class="hljs-string">"openai-community/gpt2"</span>) | |
| xla_generate = tf.function(model.generate, jit_compile=<span class="hljs-literal">True</span>) | |
| <span class="hljs-keyword">for</span> input_string <span class="hljs-keyword">in</span> [<span class="hljs-string">"TensorFlow is"</span>, <span class="hljs-string">"TensorFlow is a"</span>, <span class="hljs-string">"TFLite is a"</span>]: | |
|     tokenized_input = tokenizer(input_string, pad_to_multiple_of=<span class="hljs-number">8</span>, padding=<span class="hljs-literal">True</span>, return_tensors=<span class="hljs-string">"tf"</span>) | |
|     start = time.time_ns() | |
|     generated_tokens = xla_generate(**tokenized_input, num_beams=<span class="hljs-number">2</span>) | |
|     end = time.time_ns() | |
|     <span class="hljs-built_in">print</span>(<span class="hljs-string">f"Execution time -- <span class="hljs-subst">{(end - start) / <span class="hljs-number">1e6</span>:<span class="hljs-number">.1</span>f}</span> ms\n"</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-15o96gl">On a Tesla T4 GPU, you can expect output like the following:</p> <div class="code-block relative"><pre class=""><!-- HTML_TAG_START -->Execution time -- 30819.6 ms | |
| Execution time -- 79.0 ms | |
| Execution time -- 78.9 ms<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ctf36t">The first call to <code>xla_generate()</code> is slow because of <code>tracing</code>, but successive calls are much faster (about 79 ms versus roughly 30.8 s in the output above). Keep in mind that any change to the generation options, at any point, triggers re-<code>tracing</code> and therefore slows generation down.</p> <p data-svelte-h="svelte-1qb1iwp">This document does not cover all the text generation options that 🤗 Transformers provides. We encourage you to read the documentation for advanced use cases.</p> <h2 class="relative group"><a id="附加资源" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#附加资源"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Additional resources</span></h2> <p data-svelte-h="svelte-nhayyz">Here are some additional resources if you want to dig deeper into using XLA in 🤗 Transformers and in other libraries:</p> <ul data-svelte-h="svelte-11ijxsh"><li><p><a href="https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/91_tf_xla_generate.ipynb" rel="nofollow">This Colab Notebook</a> provides an interactive demonstration for experimenting with XLA-compatible encoder-decoder (for example, <a href="https://huggingface.co/docs/transformers/model_doc/t5" rel="nofollow">T5</a>) and decoder-only (for example, <a href="https://huggingface.co/docs/transformers/model_doc/gpt2" rel="nofollow">GPT2</a>) text generation models.</p></li> <li><p><a href="https://huggingface.co/blog/tf-xla-generate" rel="nofollow">This blog post</a> gives an overview of comparison benchmarks for XLA-compatible models, along with a friendly introduction to XLA in TensorFlow.</p></li> <li><p><a 
href="https://blog.tensorflow.org/2022/11/how-hugging-face-improved-text-generation-performance-with-xla.html" rel="nofollow">This blog post</a> discusses the design philosophy behind adding XLA support to TensorFlow models in 🤗 Transformers.</p></li> <li><p>Recommended resources for learning more about XLA and TensorFlow graphs:</p> <ul><li><a href="https://www.tensorflow.org/xla" rel="nofollow">XLA: Optimizing Compiler for Machine Learning</a></li> <li><a href="https://www.tensorflow.org/guide/intro_to_graphs" rel="nofollow">Introduction to graphs and tf.function</a></li> <li><a href="https://www.tensorflow.org/guide/function" rel="nofollow">Better performance with tf.function</a></li></ul></li></ul> <p></p> | |
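<p>The padding gotcha discussed in this document can be illustrated without TensorFlow. The sketch below is a simplified model, not the actual tracing logic: it assumes a batch size of 1 and assumes that every distinct input shape costs exactly one trace. The helper names <code>padded_length</code> and <code>count_traces</code> and the prompt lengths are hypothetical, introduced only for illustration. It shows how rounding sequence lengths up to a multiple of 8, as <code>pad_to_multiple_of=8</code> does, reduces the number of distinct shapes a compiled function would see.</p>

```python
import math


def padded_length(num_tokens: int, multiple: int = 8) -> int:
    """Round a sequence length up to the next multiple (mirrors pad_to_multiple_of)."""
    if num_tokens <= 0:
        raise ValueError("num_tokens must be positive")
    return math.ceil(num_tokens / multiple) * multiple


def count_traces(lengths, multiple=None):
    """Count distinct (batch, seq_len) shapes.

    Assumption: each new shape triggers one trace, and a shape seen before is reused.
    """
    seen_shapes = set()
    for num_tokens in lengths:
        seq_len = padded_length(num_tokens, multiple) if multiple else num_tokens
        seen_shapes.add((1, seq_len))  # batch size fixed to 1 in this sketch
    return len(seen_shapes)


# Hypothetical token counts for six prompts.
prompt_lengths = [3, 4, 5, 7, 9, 12]

print(count_traces(prompt_lengths))              # 6 distinct shapes without padding
print(count_traces(prompt_lengths, multiple=8))  # 2 distinct shapes (8 and 16)
```

<p>Under these assumptions, six prompts of different lengths would each pay the tracing cost without padding, while padding to a multiple of 8 leaves only two shapes to trace. A larger multiple would reduce the number of shapes further at the price of more padding tokens per input.</p>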
| <script> | |
| { | |
| __sveltekit_173uja2 = { | |
| assets: "/docs/transformers/main/zh", | |
| base: "/docs/transformers/main/zh", | |
| env: {} | |
| }; | |
| const element = document.currentScript.parentElement; | |
| const data = [null,null]; | |
| Promise.all([ | |
| import("/docs/transformers/main/zh/_app/immutable/entry/start.a61b9c50.js"), | |
| import("/docs/transformers/main/zh/_app/immutable/entry/app.99775688.js") | |
| ]).then(([kit, app]) => { | |
| kit.start(app, element, { | |
| node_ids: [0, 59], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |