<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;torchao&quot;,&quot;local&quot;:&quot;torchao&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Available Quantization Schemes&quot;,&quot;local&quot;:&quot;available-quantization-schemes&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Migration Guide&quot;,&quot;local&quot;:&quot;migration-guide&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Serialization&quot;,&quot;local&quot;:&quot;serialization&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Resources&quot;,&quot;local&quot;:&quot;resources&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/transformers/pr_36839/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/entry/start.6be8d590.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/scheduler.01eeda35.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/singletons.177df05e.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/index.4862150a.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/paths.517376d1.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/entry/app.09748b4b.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/index.6dd51b66.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/nodes/0.8897c14d.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/nodes/447.3395dd78.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/Tip.de9bae2b.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/CodeBlock.864da1b0.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/EditOnGithub.7faefd25.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/HfOption.f7f04550.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/stores.318eade7.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;torchao&quot;,&quot;local&quot;:&quot;torchao&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Available Quantization Schemes&quot;,&quot;local&quot;:&quot;available-quantization-schemes&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Migration Guide&quot;,&quot;local&quot;:&quot;migration-guide&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Serialization&quot;,&quot;local&quot;:&quot;serialization&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Resources&quot;,&quot;local&quot;:&quot;resources&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="torchao" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#torchao"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>torchao</span></h1> <p data-svelte-h="svelte-hptymg"><a href="https://github.com/pytorch/ao" rel="nofollow">torchao</a> is a PyTorch architecture optimization library with support for custom high-performance data types, quantization, and sparsity.
It is composable with native PyTorch features such as <a href="https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html" rel="nofollow">torch.compile</a> for even faster inference and training.</p> <p data-svelte-h="svelte-q2xfio">Install or upgrade torchao, PyTorch, and Transformers with the following command.</p> <div class="code-block relative"> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># Upgrade 🤗 Transformers to the latest version; the example below relies on its automatic compilation support</span>
pip install --upgrade torch torchao transformers<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ka4w7">torchao supports many quantization types, such as int4, float8, and weight-only quantization.
Starting with version 0.10.0, torchao offers more flexibility through the <code>AOBaseConfig</code> API, which allows more customized quantization configurations and gives full access to the techniques available in the torchao library.</p> <p data-svelte-h="svelte-hxxsaz">You can either choose the quantization type and settings manually or select the quantization type automatically.</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">manual </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">automatic </div></div> <div class="language-select"><p data-svelte-h="svelte-yenbf5">Create a <a href="/docs/transformers/pr_36839/en/main_classes/quantization#transformers.TorchAoConfig">TorchAoConfig</a> and specify the quantization type and the <code>group_size</code> of the weights to quantize. Set <code>cache_implementation</code> to <code>"static"</code> to automatically <a href="https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html" rel="nofollow">torch.compile</a> the forward method.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-b8nyn5">To run the quantized model on a CPU, set <code>device_map</code> to <code>"cpu"</code> and <code>layout</code> to <code>Int4CPULayout()</code>.
This is only available in torchao 0.8.0+.</p></div> <p data-svelte-h="svelte-141idma">In torchao 0.10.0+, you can pass a more flexible <code>AOBaseConfig</code> instance instead of a string identifier:</p> <div class="code-block relative"> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
<span class="hljs-keyword">from</span> torchao.quantization <span class="hljs-keyword">import</span> Int4WeightOnlyConfig
<span class="hljs-comment"># Using an AOBaseConfig instance (torchao >= 0.10.0)</span>
quant_config = Int4WeightOnlyConfig(group_size=<span class="hljs-number">128</span>)
quantization_config = TorchAoConfig(quant_type=quant_config)
<span class="hljs-comment"># Load and quantize the model</span>
quantized_model = AutoModelForCausalLM.from_pretrained(
    <span class="hljs-string">"meta-llama/Meta-Llama-3-8B"</span>,
    torch_dtype=<span class="hljs-string">"auto"</span>,
    device_map=<span class="hljs-string">"auto"</span>,
    quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"meta-llama/Meta-Llama-3-8B"</span>)
input_text = <span class="hljs-string">"What are we having for dinner?"</span>
<span class="hljs-comment"># Move the inputs to the device the model was placed on</span>
input_ids = tokenizer(input_text, return_tensors=<span class="hljs-string">"pt"</span>).to(quantized_model.device)
<span class="hljs-comment"># cache_implementation="static" auto-compiles the forward pass for faster generation</span>
output = quantized_model.generate(**input_ids, max_new_tokens=<span class="hljs-number">10</span>, cache_implementation=<span class="hljs-string">"static"</span>)
<span class="hljs-built_in">print</span>(tokenizer.decode(output[<span class="hljs-number">0</span>], skip_special_tokens=<span class="hljs-literal">True</span>))<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="available-quantization-schemes" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#available-quantization-schemes"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Available Quantization Schemes</span></h2> <p data-svelte-h="svelte-vq290e">torchao provides a variety of quantization configurations, including:</p> <ul data-svelte-h="svelte-1uraxyr"><li><code>Int4WeightOnlyConfig</code></li> <li><code>Int8WeightOnlyConfig</code></li> <li><code>Int8DynamicActivationInt8WeightConfig</code></li> <li><code>Float8WeightOnlyConfig</code></li></ul> <p data-svelte-h="svelte-fedjfn">Depending on the configuration, parameters such as <code>group_size</code>, <code>scheme</code>, and <code>layout</code> let you tune it for specific hardware and model architectures.</p> <p data-svelte-h="svelte-15ibe5k">For a complete list of available configurations, see the <a href="https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py" rel="nofollow">torchao quantization API source (<code>quant_api.py</code>)</a>.</p> <blockquote><p data-svelte-h="svelte-11pv5zk"><strong>⚠️ DEPRECATION WARNING</strong></p> <p data-svelte-h="svelte-jmxllg">Starting with version 0.10.0, the string-based API for quantization configuration (e.g., <code>TorchAoConfig("int4_weight_only", group_size=128)</code>) is <strong>deprecated</strong> and will be removed in a future release.</p> <p data-svelte-h="svelte-1qak5au">Use the new <code>AOBaseConfig</code>-based approach instead:</p> <div class="code-block relative"> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TorchAoConfig
<span class="hljs-comment"># Old way (deprecated)</span>
quantization_config = TorchAoConfig(<span class="hljs-string">"int4_weight_only"</span>, group_size=<span class="hljs-number">128</span>)
<span class="hljs-comment"># New way (recommended)</span>
<span class="hljs-keyword">from</span> torchao.quantization <span class="hljs-keyword">import</span> Int4WeightOnlyConfig
quant_config = Int4WeightOnlyConfig(group_size=<span class="hljs-number">128</span>)
quantization_config = TorchAoConfig(quant_type=quant_config)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-8tiw44">The new API offers greater flexibility, better type safety, and access to the full range of features available in torchao.</p> <h2 class="relative group"><a id="migration-guide" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#migration-guide"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Migration Guide</span></h2> <p data-svelte-h="svelte-1kyext4">Here’s how to migrate from common string identifiers to their <code>AOBaseConfig</code> equivalents:</p> <table data-svelte-h="svelte-1w9xui3"><thead><tr><th>Old String API</th> <th>New <code>AOBaseConfig</code> API</th></tr></thead> <tbody><tr><td><code>"int4_weight_only"</code></td> <td><code>Int4WeightOnlyConfig()</code></td></tr> <tr><td><code>"int8_weight_only"</code></td> <td><code>Int8WeightOnlyConfig()</code></td></tr> <tr><td><code>"int8_dynamic_activation_int8_weight"</code></td> <td><code>Int8DynamicActivationInt8WeightConfig()</code></td></tr></tbody></table> <p data-svelte-h="svelte-9z9ctj">Each configuration object accepts its own customization parameters (for example, <code>group_size</code>, <code>scheme</code>, or <code>layout</code>).</p></blockquote> <p data-svelte-h="svelte-1gt4wva">Below is the equivalent example using the string-based API, for torchao &lt; <code>0.9.0</code>:</p> <div class="code-block relative"> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
quantization_config = TorchAoConfig(<span class="hljs-string">"int4_weight_only"</span>, group_size=<span class="hljs-number">128</span>)
quantized_model = AutoModelForCausalLM.from_pretrained(
    <span class="hljs-string">"meta-llama/Meta-Llama-3-8B"</span>,
    torch_dtype=<span class="hljs-string">"auto"</span>,
    device_map=<span class="hljs-string">"auto"</span>,
    quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"meta-llama/Meta-Llama-3-8B"</span>)
input_text = <span class="hljs-string">"What are we having for dinner?"</span>
<span class="hljs-comment"># Move the inputs to the device the model was placed on</span>
input_ids = tokenizer(input_text, return_tensors=<span class="hljs-string">"pt"</span>).to(quantized_model.device)
<span class="hljs-comment"># cache_implementation="static" auto-compiles the forward pass for faster generation</span>
output = quantized_model.generate(**input_ids, max_new_tokens=<span class="hljs-number">10</span>, cache_implementation=<span class="hljs-string">"static"</span>)
<span class="hljs-built_in">print</span>(tokenizer.decode(output[<span class="hljs-number">0</span>], skip_special_tokens=<span class="hljs-literal">True</span>))<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-thutbg">Run the code below to benchmark the quantized model's performance against a bfloat16 baseline.</p> <div class="code-block relative"> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> torch._inductor.utils <span class="hljs-keyword">import</span> do_bench_using_profiling
<span class="hljs-keyword">from</span> typing <span class="hljs-keyword">import</span> <span class="hljs-type">Callable</span>
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM

model_name = <span class="hljs-string">"meta-llama/Meta-Llama-3-8B"</span>

<span class="hljs-keyword">def</span> <span class="hljs-title function_">benchmark_fn</span>(<span class="hljs-params">func: <span class="hljs-type">Callable</span>, *args, **kwargs</span>) -> <span class="hljs-built_in">float</span>:
    <span class="hljs-string">"""Thin wrapper around do_bench_using_profiling that returns the runtime in microseconds"""</span>
    no_args = <span class="hljs-keyword">lambda</span>: func(*args, **kwargs)
    elapsed_ms = do_bench_using_profiling(no_args)  <span class="hljs-comment"># profiled runtime in milliseconds</span>
    <span class="hljs-keyword">return</span> elapsed_ms * <span class="hljs-number">1e3</span>  <span class="hljs-comment"># convert milliseconds to microseconds</span>

MAX_NEW_TOKENS = <span class="hljs-number">1000</span>
<span class="hljs-built_in">print</span>(<span class="hljs-string">"int4wo-128 model:"</span>, benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation=<span class="hljs-string">"static"</span>))

bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=<span class="hljs-string">"auto"</span>, torch_dtype=torch.bfloat16)
output = bf16_model.generate(**input_ids, max_new_tokens=<span class="hljs-number">10</span>, cache_implementation=<span class="hljs-string">"static"</span>)  <span class="hljs-comment"># warm-up run triggers auto-compilation</span>
<span class="hljs-built_in">print</span>(<span class="hljs-string">"bf16 model:"</span>, benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation=<span class="hljs-string">"static"</span>))<!-- HTML_TAG_END --></pre></div> </div> <h2 class="relative group"><a id="serialization" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#serialization"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Serialization</span></h2> <p data-svelte-h="svelte-1g3oift">torchao implements <a href="https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor" rel="nofollow">torch.Tensor subclasses</a> for maximum flexibility in supporting new quantized torch.Tensor formats. <a href="https://huggingface.co/docs/safetensors/en/index" rel="nofollow">Safetensors</a> serialization and deserialization do not work with torchao.</p> <p data-svelte-h="svelte-5ma9bd">To avoid arbitrary user code execution, torchao sets <code>weights_only=True</code> in <a href="https://pytorch.org/docs/stable/generated/torch.load.html" rel="nofollow">torch.load</a> to ensure only tensors are loaded.
Any known user functions can be allowlisted with <a href="https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals" rel="nofollow">add_safe_globals</a>.</p> <div class="code-block relative"> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># Safetensors is not supported, so save the weights with safe_serialization=False</span>
output_dir = <span class="hljs-string">"llama3-8b-int4wo-128"</span>
quantized_model.save_pretrained(output_dir, safe_serialization=<span class="hljs-literal">False</span>)
tokenizer.save_pretrained(output_dir)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="resources" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#resources"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Resources</span></h2> <p data-svelte-h="svelte-1b589wj">For a better sense of expected performance, see the <a href="https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks" rel="nofollow">benchmarks</a> for various models on CUDA and XPU backends.</p> <p data-svelte-h="svelte-fj0t1q">Refer to <a href="https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques" rel="nofollow">Other Available Quantization Techniques</a> for more examples and documentation.</p> <p></p>
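<p>To build intuition for the <code>group_size</code> parameter used throughout this page, here is a minimal pure-Python sketch of group-wise asymmetric 4-bit quantization: every run of <code>group_size</code> consecutive weights gets its own scale and minimum, and each weight is rounded to one of 16 levels. This is a simplified illustration, not torchao's implementation. The function names are hypothetical, and torchao's <code>Int4WeightOnlyConfig</code> packs the 4-bit values, runs optimized kernels, and may use a different zero-point convention.</p>

```python
def quantize_int4_groupwise(weights: list[float], group_size: int):
    """Hypothetical helper: asymmetric 4-bit quantization with one (scale, minimum) pair per group."""
    if not group_size > 0 or len(weights) % group_size != 0:
        raise ValueError("len(weights) must be a non-zero multiple of group_size")
    codes, scales, minimums = [], [], []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        lo, hi = min(group), max(group)
        # 4 bits give 16 levels (0..15); a constant group falls back to a scale of 1.0
        scale = (hi - lo) / 15 or 1.0
        scales.append(scale)
        minimums.append(lo)
        for w in group:
            # Round to the nearest level and clamp into the 4-bit range
            codes.append(max(0, min(15, round((w - lo) / scale))))
    return codes, scales, minimums


def dequantize_int4_groupwise(codes: list[int], scales: list[float],
                              minimums: list[float], group_size: int) -> list[float]:
    """Hypothetical helper: map each 4-bit code back to a float using its group's scale and minimum."""
    return [c * scales[i // group_size] + minimums[i // group_size] for i, c in enumerate(codes)]


if __name__ == "__main__":
    weights = [0.1, -0.2, 0.35, 0.0, 1.5, 1.2, -1.0, 0.4]
    codes, scales, minimums = quantize_int4_groupwise(weights, group_size=4)
    restored = dequantize_int4_groupwise(codes, scales, minimums, group_size=4)
    # Each restored weight is within half a quantization step of the original
    for w, r in zip(weights, restored):
        print(f"{w:+.3f} -> {r:+.3f}")
```

<p>Smaller groups follow the local range of the weights more closely, which generally lowers quantization error, at the cost of storing more scales and minimums. This is the trade-off that <code>group_size</code> controls.</p>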
<script>
	{
		__sveltekit_1bm5psi = {
			assets: "/docs/transformers/pr_36839/en",
			base: "/docs/transformers/pr_36839/en",
			env: {}
		};
		const element = document.currentScript.parentElement;
		const data = [null,null];
		Promise.all([
			import("/docs/transformers/pr_36839/en/_app/immutable/entry/start.6be8d590.js"),
			import("/docs/transformers/pr_36839/en/_app/immutable/entry/app.09748b4b.js")
		]).then(([kit, app]) => {
			kit.start(app, element, {
				node_ids: [0, 447],
				data,
				form: null,
				error: null
			});
		});
	}
</script>