Buckets:

rtrm's picture
download
raw
29.5 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;torchao&quot;,&quot;local&quot;:&quot;torchao&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Available Quantization Schemes&quot;,&quot;local&quot;:&quot;available-quantization-schemes&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Migration Guide&quot;,&quot;local&quot;:&quot;migration-guide&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Serialization&quot;,&quot;local&quot;:&quot;serialization&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Resources&quot;,&quot;local&quot;:&quot;resources&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/transformers/pr_36839/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/entry/start.6be8d590.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/scheduler.01eeda35.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/singletons.177df05e.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/index.4862150a.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/paths.517376d1.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/entry/app.09748b4b.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/index.6dd51b66.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/nodes/0.8897c14d.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/nodes/447.3395dd78.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/Tip.de9bae2b.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/CodeBlock.864da1b0.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/EditOnGithub.7faefd25.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/HfOption.f7f04550.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/stores.318eade7.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;torchao&quot;,&quot;local&quot;:&quot;torchao&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Available Quantization Schemes&quot;,&quot;local&quot;:&quot;available-quantization-schemes&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Migration Guide&quot;,&quot;local&quot;:&quot;migration-guide&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Serialization&quot;,&quot;local&quot;:&quot;serialization&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Resources&quot;,&quot;local&quot;:&quot;resources&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="torchao" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#torchao"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>torchao</span></h1> <p data-svelte-h="svelte-hptymg"><a 
href="https://github.com/pytorch/ao" rel="nofollow">torchao</a> is a PyTorch architecture optimization library with support for custom high performance data types, quantization, and sparsity. It is composable with native PyTorch features such as <a href="https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html" rel="nofollow">torch.compile</a> for even faster inference and training.</p> <p data-svelte-h="svelte-q2xfio">Install torchao with the following command.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># Updating 🤗 Transformers to the latest version, as the example script below uses the new auto compilation</span>
pip install --upgrade torch torchao transformers<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ka4w7">torchao supports many quantization types for different data types (int4, float8, weight only, etc.).
Starting with version 0.10.0, torchao provides enhanced flexibility through the <code>AOBaseConfig</code> API, allowing for more customized quantization configurations
and full access to the techniques offered in the torchao library.</p> <p data-svelte-h="svelte-hxxsaz">You can manually choose the quantization types and settings or automatically select the quantization types.</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">manual </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">automatic </div></div> <div class="language-select"><p data-svelte-h="svelte-yenbf5">Create a <a href="/docs/transformers/pr_36839/en/main_classes/quantization#transformers.TorchAoConfig">TorchAoConfig</a> and specify the quantization type and <code>group_size</code> of the weights to quantize. Set the <code>cache_implementation</code> to <code>&quot;static&quot;</code> to automatically <a href="https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html" rel="nofollow">torch.compile</a> the forward method.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-b8nyn5">Run the quantized model on a CPU by changing <code>device_map</code> to <code>&quot;cpu&quot;</code> and <code>layout</code> to <code>Int4CPULayout()</code>. 
This is only available in torchao 0.8.0+.</p></div> <p data-svelte-h="svelte-141idma">In torchao 0.10.0+, you can use the more flexible <code>AOBaseConfig</code> approach instead of string identifiers:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
<span class="hljs-keyword">from</span> torchao.quantization <span class="hljs-keyword">import</span> Int4WeightOnlyConfig
<span class="hljs-comment"># Using AOBaseConfig instance (torchao &gt;= 0.10.0)</span>
quant_config = Int4WeightOnlyConfig(group_size=<span class="hljs-number">128</span>)
quantization_config = TorchAoConfig(quant_type=quant_config)
<span class="hljs-comment"># Load and quantize the model</span>
quantized_model = AutoModelForCausalLM.from_pretrained(
<span class="hljs-string">&quot;meta-llama/Meta-Llama-3-8B&quot;</span>,
torch_dtype=<span class="hljs-string">&quot;auto&quot;</span>,
device_map=<span class="hljs-string">&quot;auto&quot;</span>,
quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">&quot;meta-llama/Meta-Llama-3-8B&quot;</span>)
input_text = <span class="hljs-string">&quot;What are we having for dinner?&quot;</span>
input_ids = tokenizer(input_text, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>).to(<span class="hljs-string">&quot;cuda&quot;</span>)
<span class="hljs-comment"># auto-compile the quantized model with `cache_implementation=&quot;static&quot;` to get speed up</span>
output = quantized_model.generate(**input_ids, max_new_tokens=<span class="hljs-number">10</span>, cache_implementation=<span class="hljs-string">&quot;static&quot;</span>)
<span class="hljs-built_in">print</span>(tokenizer.decode(output[<span class="hljs-number">0</span>], skip_special_tokens=<span class="hljs-literal">True</span>))<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="available-quantization-schemes" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#available-quantization-schemes"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Available Quantization Schemes</span></h2> <p data-svelte-h="svelte-vq290e">TorchAO provides a variety of quantization configurations:</p> <ul data-svelte-h="svelte-1uraxyr"><li><code>Int4WeightOnlyConfig</code></li> <li><code>Int8WeightOnlyConfig</code></li> <li><code>Int8DynamicActivationInt8WeightConfig</code></li> <li><code>Float8WeightOnlyConfig</code></li></ul> <p data-svelte-h="svelte-fedjfn">Each configuration can be further customized with parameters such as <code>group_size</code>, <code>scheme</code>, and <code>layout</code> to optimize for specific hardware and model architectures.</p> <p data-svelte-h="svelte-15ibe5k">For a complete list of available configurations, see our <a 
href="https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py" rel="nofollow">quantization API documentation</a>.</p> <blockquote><p data-svelte-h="svelte-11pv5zk"><strong>⚠️ DEPRECATION WARNING</strong></p> <p data-svelte-h="svelte-jmxllg">Starting with version 0.10.0, the string-based API for quantization configuration (e.g., <code>TorchAoConfig(&quot;int4_weight_only&quot;, group_size=128)</code>) is <strong>deprecated</strong> and will be removed in a future release.</p> <p data-svelte-h="svelte-1qak5au">Please use the new <code>AOBaseConfig</code>-based approach instead:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># Old way (deprecated)</span>
quantization_config = TorchAoConfig(<span class="hljs-string">&quot;int4_weight_only&quot;</span>, group_size=<span class="hljs-number">128</span>)
<span class="hljs-comment"># New way (recommended)</span>
<span class="hljs-keyword">from</span> torchao.quantization <span class="hljs-keyword">import</span> Int4WeightOnlyConfig
quant_config = Int4WeightOnlyConfig(group_size=<span class="hljs-number">128</span>)
quantization_config = TorchAoConfig(quant_type=quant_config)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-8tiw44">The new API offers greater flexibility, better type safety, and access to the full range of features available in torchao.</p> <h2 class="relative group"><a id="migration-guide" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#migration-guide"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Migration Guide</span></h2> <p data-svelte-h="svelte-1kyext4">Here’s how to migrate from common string identifiers to their <code>AOBaseConfig</code> equivalents:</p> <table data-svelte-h="svelte-1w9xui3"><thead><tr><th>Old String API</th> <th>New <code>AOBaseConfig</code> API</th></tr></thead> <tbody><tr><td><code>&quot;int4_weight_only&quot;</code></td> <td><code>Int4WeightOnlyConfig()</code></td></tr> <tr><td><code>&quot;int8_weight_only&quot;</code></td> <td><code>Int8WeightOnlyConfig()</code></td></tr> <tr><td><code>&quot;int8_dynamic_activation_int8_weight&quot;</code></td> <td><code>Int8DynamicActivationInt8WeightConfig()</code></td></tr></tbody></table> <p data-svelte-h="svelte-9z9ctj">All configuration 
objects accept parameters for customization (e.g., <code>group_size</code>, <code>scheme</code>, <code>layout</code>).</p></blockquote> <p data-svelte-h="svelte-1gt4wva">Below is the API for torchao &lt; <code>0.9.0</code>:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
quantization_config = TorchAoConfig(<span class="hljs-string">&quot;int4_weight_only&quot;</span>, group_size=<span class="hljs-number">128</span>)
quantized_model = AutoModelForCausalLM.from_pretrained(
<span class="hljs-string">&quot;meta-llama/Meta-Llama-3-8B&quot;</span>,
torch_dtype=<span class="hljs-string">&quot;auto&quot;</span>,
device_map=<span class="hljs-string">&quot;auto&quot;</span>,
quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">&quot;meta-llama/Meta-Llama-3-8B&quot;</span>)
input_text = <span class="hljs-string">&quot;What are we having for dinner?&quot;</span>
input_ids = tokenizer(input_text, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>).to(<span class="hljs-string">&quot;cuda&quot;</span>)
<span class="hljs-comment"># auto-compile the quantized model with `cache_implementation=&quot;static&quot;` to get speed up</span>
output = quantized_model.generate(**input_ids, max_new_tokens=<span class="hljs-number">10</span>, cache_implementation=<span class="hljs-string">&quot;static&quot;</span>)
<span class="hljs-built_in">print</span>(tokenizer.decode(output[<span class="hljs-number">0</span>], skip_special_tokens=<span class="hljs-literal">True</span>))<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-thutbg">Run the code below to benchmark the quantized model’s performance.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torch._inductor.utils <span class="hljs-keyword">import</span> do_bench_using_profiling
<span class="hljs-keyword">from</span> typing <span class="hljs-keyword">import</span> <span class="hljs-type">Callable</span>
<span class="hljs-keyword">def</span> <span class="hljs-title function_">benchmark_fn</span>(<span class="hljs-params">func: <span class="hljs-type">Callable</span>, *args, **kwargs</span>) -&gt; <span class="hljs-built_in">float</span>:
<span class="hljs-string">&quot;&quot;&quot;Thin wrapper around do_bench_using_profiling&quot;&quot;&quot;</span>
no_args = <span class="hljs-keyword">lambda</span>: func(*args, **kwargs)
time = do_bench_using_profiling(no_args)
<span class="hljs-keyword">return</span> time * <span class="hljs-number">1e3</span>
MAX_NEW_TOKENS = <span class="hljs-number">1000</span>
<span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;int4wo-128 model:&quot;</span>, benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation=<span class="hljs-string">&quot;static&quot;</span>))
bf16_model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">&quot;meta-llama/Meta-Llama-3-8B&quot;</span>, device_map=<span class="hljs-string">&quot;auto&quot;</span>, torch_dtype=torch.bfloat16)
output = bf16_model.generate(**input_ids, max_new_tokens=<span class="hljs-number">10</span>, cache_implementation=<span class="hljs-string">&quot;static&quot;</span>) <span class="hljs-comment"># auto-compile</span>
<span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;bf16 model:&quot;</span>, benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation=<span class="hljs-string">&quot;static&quot;</span>))<!-- HTML_TAG_END --></pre></div> </div> <h2 class="relative group"><a id="serialization" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#serialization"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Serialization</span></h2> <p data-svelte-h="svelte-1g3oift">torchao implements <a href="https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor" rel="nofollow">torch.Tensor subclasses</a> for maximum flexibility in supporting new quantized torch.Tensor formats. 
<a href="https://huggingface.co/docs/safetensors/en/index" rel="nofollow">Safetensors</a> serialization and deserialization do not work with torchao.</p> <p data-svelte-h="svelte-5ma9bd">To avoid arbitrary user code execution, torchao sets <code>weights_only=True</code> in <a href="https://pytorch.org/docs/stable/generated/torch.load.html" rel="nofollow">torch.load</a> to ensure only tensors are loaded. Any known user functions can be whitelisted with <a href="https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals" rel="nofollow">add_safe_globals</a>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># don&#x27;t serialize model with Safetensors</span>
output_dir = <span class="hljs-string">&quot;llama3-8b-int4wo-128&quot;</span>
quantized_model.save_pretrained(<span class="hljs-string">&quot;llama3-8b-int4wo-128&quot;</span>, safe_serialization=<span class="hljs-literal">False</span>)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="resources" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#resources"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Resources</span></h2> <p data-svelte-h="svelte-1b589wj">For a better sense of expected performance, view the <a href="https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks" rel="nofollow">benchmarks</a> for various models with CUDA and XPU backends.</p> <p data-svelte-h="svelte-fj0t1q">Refer to <a href="https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques" rel="nofollow">Other Available Quantization Techniques</a> for more examples and documentation.</p> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/transformers/blob/main/docs/source/en/quantization/torchao.md" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span 
data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p>
<script>
// Client bootstrap: publishes path settings on a global, then loads the two
// entry modules and hands control to `kit.start`.
// The bare block below scopes the `const` bindings to this script only.
{
// Assigned without a declaration, so in a classic, non-strict script this
// creates a property on the global object.
// NOTE(review): presumably read by the imported entry modules — confirm.
__sveltekit_1bm5psi = {
// `assets` and `base` carry the same path prefix here.
assets: "/docs/transformers/pr_36839/en",
base: "/docs/transformers/pr_36839/en",
env: {}
};
// Read synchronously on purpose: `document.currentScript` is only set while
// this script is executing, so it would not be available inside `.then`.
const element = document.currentScript.parentElement;
// Two null placeholders.
// NOTE(review): looks like one slot per entry in `node_ids` below — confirm.
const data = [null,null];
// Both dynamic imports are started before `Promise.all` waits on them; the
// callback runs only after both modules have resolved.
// NOTE(review): there is no `.catch`, so a failed import surfaces only as an
// unhandled promise rejection.
Promise.all([
import("/docs/transformers/pr_36839/en/_app/immutable/entry/start.6be8d590.js"),
import("/docs/transformers/pr_36839/en/_app/immutable/entry/app.09748b4b.js")
]).then(([kit, app]) => {
// Start the app with the script's parent element as the target, passing the
// node ids and the placeholder data defined above.
kit.start(app, element, {
node_ids: [0, 447],
data,
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
29.5 kB
·
Xet hash:
4753ea45ef9f44c16e4c9b6e6c840494c8de0b2867db65d06aece277aafc4c22

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.