stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="torchao" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#torchao"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>torchao</span></h1> <p data-svelte-h="svelte-8yu6d7"><a href="https://colab.research.google.com/github/huggingface/notebooks/blob/main/transformers_doc/en/quantization/torchao.ipynb" rel="nofollow"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab: Torchao Demo"></a></p> <p data-svelte-h="svelte-hptymg"><a href="https://github.com/pytorch/ao" rel="nofollow">torchao</a> is a PyTorch architecture optimization library with support for custom high performance data types, quantization, and sparsity. It is composable with native PyTorch features such as <a href="https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html" rel="nofollow">torch.compile</a> for even faster inference and training.</p> <p data-svelte-h="svelte-k1mb66">See the table below for additional torchao features.</p> <table data-svelte-h="svelte-1x8g49v"><thead><tr><th>Feature</th> <th>Description</th></tr></thead> <tbody><tr><td><strong>Quantization Aware Training (QAT)</strong></td> <td>Train quantized models with minimal accuracy loss (see <a href="https://github.com/pytorch/ao/blob/main/torchao/quantization/qat/README.md" rel="nofollow">QAT README</a>)</td></tr> <tr><td><strong>Float8 Training</strong></td> <td>High-throughput training with float8 formats (see <a href="https://github.com/pytorch/torchtitan/blob/main/docs/float8.md" rel="nofollow">torchtitan</a> and <a href="https://huggingface.co/docs/accelerate/usage_guides/low_precision_training#configuring-torchao" rel="nofollow">Accelerate</a> docs)</td></tr> <tr><td><strong>Sparsity Support</strong></td> <td>Semi-structured (2:4) sparsity for faster inference (see <a href="https://pytorch.org/blog/accelerating-neural-network-training/" rel="nofollow">Accelerating Neural Network Training with Semi-Structured (2:4) Sparsity</a> blog post)</td></tr> <tr><td><strong>Optimizer Quantization</strong></td> <td>Reduce optimizer state memory with 4 and 8-bit variants of Adam</td></tr> <tr><td><strong>KV Cache Quantization</strong></td> <td>Enables long context inference with lower memory (see <a href="https://github.com/pytorch/ao/blob/main/torchao/_models/llama/README.md" rel="nofollow">KV Cache Quantization</a>)</td></tr> <tr><td><strong>Custom Kernels Support</strong></td> <td>use your own <code>torch.compile</code> compatible ops</td></tr> <tr><td><strong>FSDP2</strong></td> <td>Composable with FSDP2 for training</td></tr></tbody></table> <blockquote class="tip" 
data-svelte-h="svelte-esb3t5"><p>Refer to the torchao <a href="https://github.com/pytorch/ao#torchao-pytorch-architecture-optimization" rel="nofollow">README.md</a> for more details about the library.</p></blockquote> <p data-svelte-h="svelte-bctle4">torchao supports the <a href="https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md" rel="nofollow">quantization techniques</a> below.</p> <ul data-svelte-h="svelte-1c7onxq"><li>A16W8 Float8 Dynamic Quantization</li> <li>A16W8 Float8 WeightOnly Quantization</li> <li>A8W8 Int8 Dynamic Quantization</li> <li>A16W8 Int8 Weight Only Quantization</li> <li>A16W4 Int4 Weight Only Quantization</li> <li>A16W4 Int4 Weight Only Quantization + 2:4 Sparsity</li> <li>Autoquantization</li></ul> <p data-svelte-h="svelte-1u913mc">torchao also supports module level configuration by specifying a dictionary from fully qualified name of module and its corresponding quantization config. This allows skip quantizing certain layers and using different quantization config for different modules.</p> <p data-svelte-h="svelte-1o9cwfq">Check the table below to see if your hardware is compatible.</p> <table data-svelte-h="svelte-k3clyq"><thead><tr><th>Component</th> <th>Compatibility</th></tr></thead> <tbody><tr><td>CUDA Versions</td> <td>✅ cu118, cu126, cu128</td></tr> <tr><td>XPU Versions</td> <td>✅ pytorch2.8</td></tr> <tr><td>CPU</td> <td>✅ change <code>device_map="cpu"</code> (see examples below)</td></tr></tbody></table> <p data-svelte-h="svelte-1wnb12y">Install torchao from PyPi or the PyTorch index with the following commands.</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">PyPi </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">PyTorch Index </div></div> <div class="language-select"><div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># Updating 🤗 Transformers to the latest version, as the example script below uses the new auto compilation</span> | |
| <span class="hljs-comment"># Stable release from Pypi which will default to CUDA 12.6</span> | |
| pip install --upgrade torchao transformers<!-- HTML_TAG_END --></pre></div> </div> <p data-svelte-h="svelte-1v1t5ji">If your torchao version is below 0.10.0, you need to upgrade it, please refer to the <a href="#deprecation-notice">deprecation notice</a> for more details.</p> <h2 class="relative group"><a id="quantization-examples" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#quantization-examples"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Quantization examples</span></h2> <p data-svelte-h="svelte-mfxtr3">TorchAO provides a variety of quantization configurations. Each configuration can be further customized with parameters such as <code>group_size</code>, <code>scheme</code>, and <code>layout</code> to optimize for specific hardware and model architectures.</p> <p data-svelte-h="svelte-1omxb57">For a complete list of available configurations, see the <a href="https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py" rel="nofollow">quantization API documentation</a>.</p> <p data-svelte-h="svelte-hxxsaz">You can manually choose the quantization types and settings or automatically select the quantization types.</p> <p data-svelte-h="svelte-vrhqds">Create a <a href="/docs/transformers/pr_33892/en/main_classes/quantization#transformers.TorchAoConfig">TorchAoConfig</a> and specify the quantization type and <code>group_size</code> of the weights to quantize (for int8 weight only and int4 weight only). Set the <code>cache_implementation</code> to <code>"static"</code> to automatically <a href="https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html" rel="nofollow">torch.compile</a> the forward method.</p> <p data-svelte-h="svelte-yb6i3u">We’ll show examples for recommended quantization methods based on hardwares, e.g. 
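
For example, a minimal sketch of a customized int4 weight-only config; the `group_size` value here is only an illustrative choice and should divide the input dimension of the quantized layers.

```py
from transformers import TorchAoConfig
from torchao.quantization import Int4WeightOnlyConfig

# int4 weight-only quantization with a custom group size
# (smaller groups store more scales, which usually helps accuracy at a small memory cost)
quant_config = Int4WeightOnlyConfig(group_size=64)
quantization_config = TorchAoConfig(quant_type=quant_config)
```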

The following examples show the recommended quantization methods for different hardware, such as an H100 GPU, A100 GPU, Intel XPU, and CPU.

### H100 GPU

<hfoptions id="h100-gpu">
<hfoption id="float8-dynamic-and-weight-only">

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, Float8WeightOnlyConfig

quant_config = Float8DynamicActivationFloat8WeightConfig()
# or float8 weight-only quantization
# quant_config = Float8WeightOnlyConfig()
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)

# auto-compile the quantized model with `cache_implementation="static"` to get a speedup
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

</hfoption>
<hfoption id="int4-weight-only-24sparse">
| <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TorchAoConfig, AutoModelForCausalLM, AutoTokenizer | |
| <span class="hljs-keyword">from</span> torchao.quantization <span class="hljs-keyword">import</span> Int4WeightOnlyConfig | |
| <span class="hljs-keyword">from</span> torchao.dtypes <span class="hljs-keyword">import</span> MarlinSparseLayout | |
| quant_config = Int4WeightOnlyConfig(layout=MarlinSparseLayout()) | |
| quantization_config = TorchAoConfig(quant_type=quant_config) | |
| <span class="hljs-comment"># Load and quantize the model with sparsity. A sparse checkpoint is needed to accelerate without accuracy loss</span> | |
| quantized_model = AutoModelForCausalLM.from_pretrained( | |
| <span class="hljs-string">"RedHatAI/Sparse-Llama-3.1-8B-2of4"</span>, | |
| dtype=torch.float16, | |
| device_map=<span class="hljs-string">"auto"</span>, | |
| quantization_config=quantization_config | |
| ) | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"RedHatAI/Sparse-Llama-3.1-8B-2of4"</span>) | |
| input_text = <span class="hljs-string">"What are we having for dinner?"</span> | |
| input_ids = tokenizer(input_text, return_tensors=<span class="hljs-string">"pt"</span>).to(model.device) | |
| <span class="hljs-comment"># auto-compile the quantized model with `cache_implementation="static"` to get speed up</span> | |
| output = quantized_model.generate(**input_ids, max_new_tokens=<span class="hljs-number">10</span>, cache_implementation=<span class="hljs-string">"static"</span>) | |
| <span class="hljs-built_in">print</span>(tokenizer.decode(output[<span class="hljs-number">0</span>], skip_special_tokens=<span class="hljs-literal">True</span>))<!-- HTML_TAG_END --></pre></div> | |

</hfoption>
</hfoptions>
| <h3 class="relative group"><a id="a100-gpu" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#a100-gpu"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>A100 GPU</span></h3> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">int8-dynamic-and-weight-only </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">int4-weight-only </div></div> <div class="language-select"><div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TorchAoConfig, AutoModelForCausalLM, AutoTokenizer | |
| <span class="hljs-keyword">from</span> torchao.quantization <span class="hljs-keyword">import</span> Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig | |
| quant_config = Int8DynamicActivationInt8WeightConfig() | |
| <span class="hljs-comment"># or int8 weight only quantization</span> | |
| <span class="hljs-comment"># quant_config = Int8WeightOnlyConfig()</span> | |
| quantization_config = TorchAoConfig(quant_type=quant_config) | |
| <span class="hljs-comment"># Load and quantize the model</span> | |
| quantized_model = AutoModelForCausalLM.from_pretrained( | |
| <span class="hljs-string">"meta-llama/Llama-3.1-8B-Instruct"</span>, | |
| dtype=<span class="hljs-string">"auto"</span>, | |
| device_map=<span class="hljs-string">"auto"</span>, | |
| quantization_config=quantization_config | |
| ) | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"meta-llama/Llama-3.1-8B-Instruct"</span>) | |
| input_text = <span class="hljs-string">"What are we having for dinner?"</span> | |
| input_ids = tokenizer(input_text, return_tensors=<span class="hljs-string">"pt"</span>).to(model.device) | |
| <span class="hljs-comment"># auto-compile the quantized model with `cache_implementation="static"` to get speed up</span> | |
| output = quantized_model.generate(**input_ids, max_new_tokens=<span class="hljs-number">10</span>, cache_implementation=<span class="hljs-string">"static"</span>) | |
| <span class="hljs-built_in">print</span>(tokenizer.decode(output[<span class="hljs-number">0</span>], skip_special_tokens=<span class="hljs-literal">True</span>))<!-- HTML_TAG_END --></pre></div> </div> | |

</hfoption>
<hfoption id="int4-weight-only-24sparse">
| <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TorchAoConfig, AutoModelForCausalLM, AutoTokenizer | |
| <span class="hljs-keyword">from</span> torchao.quantization <span class="hljs-keyword">import</span> Int4WeightOnlyConfig | |
| <span class="hljs-keyword">from</span> torchao.dtypes <span class="hljs-keyword">import</span> MarlinSparseLayout | |
| quant_config = Int4WeightOnlyConfig(layout=MarlinSparseLayout()) | |
| quantization_config = TorchAoConfig(quant_type=quant_config) | |
| <span class="hljs-comment"># Load and quantize the model with sparsity. A sparse checkpoint is needed to accelerate without accuracy loss</span> | |
| quantized_model = AutoModelForCausalLM.from_pretrained( | |
| <span class="hljs-string">"RedHatAI/Sparse-Llama-3.1-8B-2of4"</span>, | |
| dtype=torch.float16, | |
| device_map=<span class="hljs-string">"auto"</span>, | |
| quantization_config=quantization_config | |
| ) | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"RedHatAI/Sparse-Llama-3.1-8B-2of4"</span>) | |
| input_text = <span class="hljs-string">"What are we having for dinner?"</span> | |
| input_ids = tokenizer(input_text, return_tensors=<span class="hljs-string">"pt"</span>).to(model.device) | |
| <span class="hljs-comment"># auto-compile the quantized model with `cache_implementation="static"` to get speed up</span> | |
| output = quantized_model.generate(**input_ids, max_new_tokens=<span class="hljs-number">10</span>, cache_implementation=<span class="hljs-string">"static"</span>) | |
| <span class="hljs-built_in">print</span>(tokenizer.decode(output[<span class="hljs-number">0</span>], skip_special_tokens=<span class="hljs-literal">True</span>))<!-- HTML_TAG_END --></pre></div> | |

</hfoption>
</hfoptions>
| <h3 class="relative group"><a id="intel-xpu" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#intel-xpu"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Intel XPU</span></h3> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">int8-dynamic-and-weight-only </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">int4-weight-only </div></div> <div class="language-select"><div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TorchAoConfig, AutoModelForCausalLM, AutoTokenizer | |
| <span class="hljs-keyword">from</span> torchao.quantization <span class="hljs-keyword">import</span> Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig | |
| quant_config = Int8DynamicActivationInt8WeightConfig() | |
| <span class="hljs-comment"># or int8 weight only quantization</span> | |
| <span class="hljs-comment"># quant_config = Int8WeightOnlyConfig()</span> | |
| quantization_config = TorchAoConfig(quant_type=quant_config) | |
| <span class="hljs-comment"># Load and quantize the model</span> | |
| quantized_model = AutoModelForCausalLM.from_pretrained( | |
| <span class="hljs-string">"meta-llama/Llama-3.1-8B-Instruct"</span>, | |
| dtype=<span class="hljs-string">"auto"</span>, | |
| device_map=<span class="hljs-string">"auto"</span>, | |
| quantization_config=quantization_config | |
| ) | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"meta-llama/Llama-3.1-8B-Instruct"</span>) | |
| input_text = <span class="hljs-string">"What are we having for dinner?"</span> | |
| input_ids = tokenizer(input_text, return_tensors=<span class="hljs-string">"pt"</span>).to(model.device) | |
| <span class="hljs-comment"># auto-compile the quantized model with `cache_implementation="static"` to get speed up</span> | |
| output = quantized_model.generate(**input_ids, max_new_tokens=<span class="hljs-number">10</span>, cache_implementation=<span class="hljs-string">"static"</span>) | |
| <span class="hljs-built_in">print</span>(tokenizer.decode(output[<span class="hljs-number">0</span>], skip_special_tokens=<span class="hljs-literal">True</span>))<!-- HTML_TAG_END --></pre></div> </div> <h3 class="relative group"><a id="cpu" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#cpu"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>CPU</span></h3> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">int8-dynamic-and-weight-only </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">int4-weight-only </div></div> <div class="language-select"><div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TorchAoConfig, AutoModelForCausalLM, AutoTokenizer | |
| <span class="hljs-keyword">from</span> torchao.quantization <span class="hljs-keyword">import</span> Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig | |
| quant_config = Int8DynamicActivationInt8WeightConfig() | |
| <span class="hljs-comment"># quant_config = Int8WeightOnlyConfig()</span> | |
| quantization_config = TorchAoConfig(quant_type=quant_config) | |
| <span class="hljs-comment"># Load and quantize the model</span> | |
| quantized_model = AutoModelForCausalLM.from_pretrained( | |
| <span class="hljs-string">"meta-llama/Llama-3.1-8B-Instruct"</span>, | |
| dtype=<span class="hljs-string">"auto"</span>, | |
| device_map=<span class="hljs-string">"cpu"</span>, | |
| quantization_config=quantization_config | |
| ) | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"meta-llama/Llama-3.1-8B-Instruct"</span>) | |
| input_text = <span class="hljs-string">"What are we having for dinner?"</span> | |
| input_ids = tokenizer(input_text, return_tensors=<span class="hljs-string">"pt"</span>) | |
| <span class="hljs-comment"># auto-compile the quantized model with `cache_implementation="static"` to get speed up</span> | |
| output = quantized_model.generate(**input_ids, max_new_tokens=<span class="hljs-number">10</span>, cache_implementation=<span class="hljs-string">"static"</span>) | |
| <span class="hljs-built_in">print</span>(tokenizer.decode(output[<span class="hljs-number">0</span>], skip_special_tokens=<span class="hljs-literal">True</span>))<!-- HTML_TAG_END --></pre></div> </div> <h3 class="relative group"><a id="per-module-quantization" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#per-module-quantization"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Per Module Quantization</span></h3> <h4 class="relative group"><a id="1-skip-quantization-for-certain-layers" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#1-skip-quantization-for-certain-layers"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>1. 

#### 1. Skip quantization for certain layers

With `ModuleFqnToConfig` we can specify a default configuration for all layers while skipping quantization for certain layers.

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from torchao.quantization import Int4WeightOnlyConfig, ModuleFqnToConfig

model_id = "meta-llama/Llama-3.1-8B-Instruct"

config = Int4WeightOnlyConfig(group_size=128)
# set the default config to int4 (for linear layers), and skip quantizing `model.layers.0.self_attn.q_proj`
quant_config = ModuleFqnToConfig({"_default": config, "model.layers.0.self_attn.q_proj": None})
quantization_config = TorchAoConfig(quant_type=quant_config)
quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", dtype=torch.bfloat16, quantization_config=quantization_config)
# lm_head is not quantized and model.layers.0.self_attn.q_proj is not quantized
print("quantized model:", quantized_model)

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Manual testing
prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt").to(quantized_model.device.type)
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
output_text = tokenizer.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
```

#### 2. Quantizing different layers with different quantization configs (no regex)

The example below quantizes the embedding layer (`model.decoder.embed_tokens`) with an int8 weight-only config, uses int8 dynamic activation with int4 weight quantization as the default for linear layers, and skips `model.decoder.embed_positions`.

```py
import torch
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, AutoTokenizer, TorchAoConfig | |
| model_id = <span class="hljs-string">"facebook/opt-125m"</span> | |
| <span class="hljs-keyword">from</span> torchao.quantization <span class="hljs-keyword">import</span> Int4WeightOnlyConfig, ModuleFqnToConfig, Int8DynamicActivationInt4WeightConfig, IntxWeightOnlyConfig, PerAxis, MappingType | |
| weight_dtype = torch.int8 | |
| granularity = PerAxis(<span class="hljs-number">0</span>) | |
| mapping_type = MappingType.ASYMMETRIC | |
| embedding_config = IntxWeightOnlyConfig( | |
| weight_dtype=weight_dtype, | |
| granularity=granularity, | |
| mapping_type=mapping_type, | |
| ) | |
| linear_config = Int8DynamicActivationInt4WeightConfig(group_size=<span class="hljs-number">128</span>) | |
| quant_config = ModuleFqnToConfig({<span class="hljs-string">"_default"</span>: linear_config, <span class="hljs-string">"model.decoder.embed_tokens"</span>: embedding_config, <span class="hljs-string">"model.decoder.embed_positions"</span>: <span class="hljs-literal">None</span>}) | |
| <span class="hljs-comment"># set `include_embedding` to True in order to include embedding in quantization</span> | |
| <span class="hljs-comment"># when `include_embedding` is True, we'll remove input embedding from `modules_not_to_convert` as well</span> | |
| quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=<span class="hljs-literal">True</span>) | |
| quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map=<span class="hljs-string">"cpu"</span>, dtype=torch.bfloat16, quantization_config=quantization_config) | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">"quantized model:"</span>, quantized_model) | |
| <span class="hljs-comment"># make sure embedding is quantized</span> | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">"embed_tokens weight:"</span>, quantized_model.model.decoder.embed_tokens.weight) | |
| tokenizer = AutoTokenizer.from_pretrained(model_id) | |
| <span class="hljs-comment"># Manual Testing</span> | |
| prompt = <span class="hljs-string">"Hey, are you conscious? Can you talk to me?"</span> | |
| inputs = tokenizer(prompt, return_tensors=<span class="hljs-string">"pt"</span>).to(<span class="hljs-string">"cpu"</span>) | |
| generated_ids = quantized_model.generate(**inputs, max_new_tokens=<span class="hljs-number">128</span>, cache_implementation=<span class="hljs-string">"static"</span>) | |
| output_text = tokenizer.batch_decode( | |
| generated_ids, skip_special_tokens=<span class="hljs-literal">True</span>, clean_up_tokenization_spaces=<span class="hljs-literal">False</span> | |
| ) | |
| <span class="hljs-built_in">print</span>(output_text)<!-- HTML_TAG_END --></pre></div> <h4 class="relative group"><a id="3-quantizing-different-layers-with-different-quantization-configs-with-regex" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#3-quantizing-different-layers-with-different-quantization-configs-with-regex"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>3. Quantizing different layers with different quantization configs (with regex)</span></h4> <p data-svelte-h="svelte-1fu5667">We can also use regex to specify the config for all modules that has <code>module_fqn</code> that | |
| matches the regex, all regex should start with <code>re:</code>, for example <code>re:layers\..*\.gate_proj</code> will | |
| match all layers like <code>layers.0.gate_proj</code>. See <a href="https://github.com/pytorch/ao/blob/2fe0ca0899c730c528efdbec8886feaa38879f39/torchao/quantization/quant_api.py#L2392" rel="nofollow">here</a> for docs.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> logging | |
| <span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, AutoTokenizer, TorchAoConfig | |
| <span class="hljs-comment"># Configure logging to see warnings and debug information</span> | |
| logging.basicConfig( | |
| level=logging.INFO, <span class="hljs-built_in">format</span>=<span class="hljs-string">"%(name)s - %(levelname)s - %(message)s"</span> | |
| ) | |
| <span class="hljs-comment"># Enable specific loggers that might contain the serialization warnings</span> | |
| logging.getLogger(<span class="hljs-string">"transformers"</span>).setLevel(logging.INFO) | |
| logging.getLogger(<span class="hljs-string">"torchao"</span>).setLevel(logging.INFO) | |
| logging.getLogger(<span class="hljs-string">"safetensors"</span>).setLevel(logging.INFO) | |
| logging.getLogger(<span class="hljs-string">"huggingface_hub"</span>).setLevel(logging.INFO) | |
| model_id = <span class="hljs-string">"facebook/opt-125m"</span> | |
| <span class="hljs-keyword">from</span> torchao.quantization <span class="hljs-keyword">import</span> ( | |
| Float8DynamicActivationFloat8WeightConfig, | |
| Int4WeightOnlyConfig, | |
| IntxWeightOnlyConfig, | |
| PerRow, | |
| PerAxis, | |
| ModuleFqnToConfig, | |
| Float8Tensor, | |
| Int4TilePackedTo4dTensor, | |
| IntxUnpackedToInt8Tensor, | |
| ) | |
| float8dyn = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) | |
| int4wo = Int4WeightOnlyConfig(int4_packing_format=<span class="hljs-string">"tile_packed_to_4d"</span>) | |
| intxwo = IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(<span class="hljs-number">0</span>)) | |
| qconfig_dict = { | |
| <span class="hljs-comment"># highest priority</span> | |
| <span class="hljs-string">"model.decoder.layers.3.self_attn.q_proj"</span>: int4wo, | |
| <span class="hljs-string">"model.decoder.layers.3.self_attn.k_proj"</span>: int4wo, | |
| <span class="hljs-string">"model.decoder.layers.3.self_attn.v_proj"</span>: int4wo, | |
| <span class="hljs-comment"># vllm</span> | |
| <span class="hljs-string">"model.decoder.layers.3.self_attn.qkv_proj"</span>: int4wo, | |
| <span class="hljs-string">"re:model\.decoder\.layers\..+\.self_attn\.q_proj"</span>: float8dyn, | |
| <span class="hljs-string">"re:model\.decoder\.layers\..+\.self_attn\.k_proj"</span>: float8dyn, | |
| <span class="hljs-string">"re:model\.decoder\.layers\..+\.self_attn\.v_proj"</span>: float8dyn, | |
| <span class="hljs-comment"># this should not take effect and we'll fallback to _default</span> | |
| <span class="hljs-comment"># since no full mach (missing `j` in the end)</span> | |
| <span class="hljs-string">"re:model\.decoder\.layers\..+\.self_attn\.out_pro"</span>: float8dyn, | |
| <span class="hljs-comment"># vllm</span> | |
| <span class="hljs-string">"re:model\.decoder\.layers\..+\.self_attn\.qkv_proj"</span>: float8dyn, | |
| <span class="hljs-string">"_default"</span>: intxwo, | |
| } | |
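| <span class="hljs-comment"># Optional sketch: config keys are matched against the full module FQN, which is why the</span> | |
| <span class="hljs-comment"># truncated "out_pro" pattern above falls back to `_default` (illustrated here with Python's `re`)</span> | |
| <span class="hljs-keyword">import</span> re | |
| <span class="hljs-keyword">assert</span> re.fullmatch(<span class="hljs-string">r"model\.decoder\.layers\..+\.self_attn\.out_pro"</span>, <span class="hljs-string">"model.decoder.layers.0.self_attn.out_proj"</span>) <span class="hljs-keyword">is</span> <span class="hljs-literal">None</span> | |
| <span class="hljs-keyword">assert</span> re.fullmatch(<span class="hljs-string">r"model\.decoder\.layers\..+\.self_attn\.q_proj"</span>, <span class="hljs-string">"model.decoder.layers.0.self_attn.q_proj"</span>) <span class="hljs-keyword">is</span> <span class="hljs-keyword">not</span> <span class="hljs-literal">None</span> | |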
| quant_config = ModuleFqnToConfig(qconfig_dict) | |
| quantization_config = TorchAoConfig(quant_type=quant_config) | |
| quantized_model = AutoModelForCausalLM.from_pretrained( | |
| model_id, | |
| device_map=<span class="hljs-string">"auto"</span>, | |
| torch_dtype=torch.bfloat16, | |
| quantization_config=quantization_config, | |
| ) | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">"quantized model:"</span>, quantized_model) | |
| tokenizer = AutoTokenizer.from_pretrained(model_id) | |
| <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">12</span>): | |
| <span class="hljs-keyword">if</span> i == <span class="hljs-number">3</span>: | |
| <span class="hljs-keyword">assert</span> <span class="hljs-built_in">isinstance</span>(quantized_model.model.decoder.layers[i].self_attn.q_proj.weight, Int4TilePackedTo4dTensor) | |
| <span class="hljs-keyword">assert</span> <span class="hljs-built_in">isinstance</span>(quantized_model.model.decoder.layers[i].self_attn.k_proj.weight, Int4TilePackedTo4dTensor) | |
| <span class="hljs-keyword">assert</span> <span class="hljs-built_in">isinstance</span>(quantized_model.model.decoder.layers[i].self_attn.v_proj.weight, Int4TilePackedTo4dTensor) | |
| <span class="hljs-keyword">else</span>: | |
| <span class="hljs-keyword">assert</span> <span class="hljs-built_in">isinstance</span>(quantized_model.model.decoder.layers[i].self_attn.q_proj.weight, Float8Tensor) | |
| <span class="hljs-keyword">assert</span> <span class="hljs-built_in">isinstance</span>(quantized_model.model.decoder.layers[i].self_attn.k_proj.weight, Float8Tensor) | |
| <span class="hljs-keyword">assert</span> <span class="hljs-built_in">isinstance</span>(quantized_model.model.decoder.layers[i].self_attn.v_proj.weight, Float8Tensor) | |
| <span class="hljs-keyword">assert</span> <span class="hljs-built_in">isinstance</span>(quantized_model.model.decoder.layers[i].self_attn.out_proj.weight, IntxUnpackedToInt8Tensor) | |
| <span class="hljs-comment"># Manual Testing</span> | |
| prompt = <span class="hljs-string">"What are we having for dinner?"</span> | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">"Prompt:"</span>, prompt) | |
| inputs = tokenizer( | |
| prompt, | |
| return_tensors=<span class="hljs-string">"pt"</span>, | |
| ).to(<span class="hljs-string">"cuda"</span>) | |
| <span class="hljs-comment"># setting temperature to 0 to make sure result deterministic</span> | |
| generated_ids = quantized_model.generate(**inputs, max_new_tokens=<span class="hljs-number">128</span>, temperature=<span class="hljs-number">0</span>) | |
| correct_output_text = tokenizer.batch_decode( | |
| generated_ids, skip_special_tokens=<span class="hljs-literal">True</span>, clean_up_tokenization_spaces=<span class="hljs-literal">False</span> | |
| ) | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">"Response:"</span>, correct_output_text[<span class="hljs-number">0</span>][<span class="hljs-built_in">len</span>(prompt) :]) | |
| <span class="hljs-comment"># Load model from saved checkpoint</span> | |
| reloaded_model = AutoModelForCausalLM.from_pretrained( | |
| save_to, | |
| device_map=<span class="hljs-string">"cuda:0"</span>, | |
| torch_dtype=torch.bfloat16, | |
| <span class="hljs-comment"># quantization_config=quantization_config,</span> | |
| ) | |
| generated_ids = reloaded_model.generate(**inputs, max_new_tokens=<span class="hljs-number">128</span>, temperature=<span class="hljs-number">0</span>) | |
| output_text = tokenizer.batch_decode( | |
| generated_ids, skip_special_tokens=<span class="hljs-literal">True</span>, clean_up_tokenization_spaces=<span class="hljs-literal">False</span> | |
| ) | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">"Response:"</span>, output_text[<span class="hljs-number">0</span>][<span class="hljs-built_in">len</span>(prompt) :]) | |
| <span class="hljs-keyword">assert</span>(correct_output_text == output_text)<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="autoquant" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#autoquant"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Autoquant</span></h3> <p data-svelte-h="svelte-1fr0k2r">If you want to automatically choose a quantization type for quantizable layers (<code>nn.Linear</code>) you can use the <a href="https://pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant" rel="nofollow">autoquant</a> API.</p> <p data-svelte-h="svelte-axxx0x">The <code>autoquant</code> API automatically chooses a quantization type by micro-benchmarking on input type and shape and compiling a single linear layer.</p> <p data-svelte-h="svelte-19318ol">Note: autoquant is for GPU only right now.</p> <p data-svelte-h="svelte-19ysi5o">Create a <a href="/docs/transformers/pr_33892/en/main_classes/quantization#transformers.TorchAoConfig">TorchAoConfig</a> and set to <code>"autoquant"</code>. Set the <code>cache_implementation</code> to <code>"static"</code> to automatically <a href="https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html" rel="nofollow">torch.compile</a> the forward method. Finally, call <code>finalize_autoquant</code> on the quantized model to finalize the quantization and log the input shapes.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TorchAoConfig, AutoModelForCausalLM, AutoTokenizer | |
| quantization_config = TorchAoConfig(<span class="hljs-string">"autoquant"</span>, min_sqnr=<span class="hljs-literal">None</span>) | |
| quantized_model = AutoModelForCausalLM.from_pretrained( | |
| <span class="hljs-string">"meta-llama/Llama-3.1-8B-Instruct"</span>, | |
| dtype=<span class="hljs-string">"auto"</span>, | |
| device_map=<span class="hljs-string">"auto"</span>, | |
| quantization_config=quantization_config | |
| ) | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"meta-llama/Llama-3.1-8B-Instruct"</span>) | |
| input_text = <span class="hljs-string">"What are we having for dinner?"</span> | |
| input_ids = tokenizer(input_text, return_tensors=<span class="hljs-string">"pt"</span>).to(quantized_model.device.<span class="hljs-built_in">type</span>) | |
| <span class="hljs-comment"># auto-compile the quantized model with `cache_implementation="static"` to get speed up</span> | |
| output = quantized_model.generate(**input_ids, max_new_tokens=<span class="hljs-number">10</span>, cache_implementation=<span class="hljs-string">"static"</span>) | |
| <span class="hljs-comment"># explicitly call `finalize_autoquant` (may be refactored and removed in the future)</span> | |
| quantized_model.finalize_autoquant() | |
| <span class="hljs-built_in">print</span>(tokenizer.decode(output[<span class="hljs-number">0</span>], skip_special_tokens=<span class="hljs-literal">True</span>))<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="serialization" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#serialization"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Serialization</span></h2> <p data-svelte-h="svelte-1g3oift">torchao implements <a href="https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor" rel="nofollow">torch.Tensor subclasses</a> for maximum flexibility in supporting new quantized torch.Tensor formats. <a href="https://huggingface.co/docs/safetensors/en/index" rel="nofollow">Safetensors</a> serialization and deserialization does not work with torchao.</p> <p data-svelte-h="svelte-5ma9bd">To avoid arbitrary user code execution, torchao sets <code>weights_only=True</code> in <a href="https://pytorch.org/docs/stable/generated/torch.load.html" rel="nofollow">torch.load</a> to ensure only tensors are loaded. 
Any known user functions can be whitelisted with <a href="https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals" rel="nofollow">add_safe_globals</a>.</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">save-locally </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">push-to-huggingface-hub </div></div> <div class="language-select"><div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># don't serialize model with Safetensors</span> | |
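| <span class="hljs-comment"># Note: checkpoints are reloaded through torch.load(weights_only=True), so only</span> | |
| <span class="hljs-comment"># known-safe globals are allowed. If a checkpoint referenced extra user-defined</span> | |
| <span class="hljs-comment"># objects, they could be whitelisted first (the name below is purely illustrative):</span> | |
| <span class="hljs-comment"># torch.serialization.add_safe_globals([my_custom_callable])</span> | |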
| output_dir = <span class="hljs-string">"llama3-8b-int4wo-128"</span> | |
| quantized_model.save_pretrained(<span class="hljs-string">"llama3-8b-int4wo-128"</span>, safe_serialization=<span class="hljs-literal">False</span>)<!-- HTML_TAG_END --></pre></div> </div> <h2 class="relative group"><a id="loading-quantized-models" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#loading-quantized-models"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Loading quantized models</span></h2> <p data-svelte-h="svelte-9pklwu">Loading a quantized model depends on the quantization scheme. For quantization schemes, like int8 and float8, you can quantize the model on any device and also load it on any device. The example below demonstrates quantizing a model on the CPU and then loading it on CUDA or XPU.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TorchAoConfig, AutoModelForCausalLM, AutoTokenizer | |
| <span class="hljs-keyword">from</span> torchao.quantization <span class="hljs-keyword">import</span> Int8WeightOnlyConfig | |
| quant_config = Int8WeightOnlyConfig(group_size=<span class="hljs-number">128</span>) | |
| quantization_config = TorchAoConfig(quant_type=quant_config) | |
| <span class="hljs-comment"># Load and quantize the model</span> | |
| quantized_model = AutoModelForCausalLM.from_pretrained( | |
| <span class="hljs-string">"meta-llama/Llama-3.1-8B-Instruct"</span>, | |
| dtype=<span class="hljs-string">"auto"</span>, | |
| device_map=<span class="hljs-string">"cpu"</span>, | |
| quantization_config=quantization_config | |
| ) | |
| <span class="hljs-comment"># save the quantized model</span> | |
| output_dir = <span class="hljs-string">"llama-3.1-8b-torchao-int8"</span> | |
| quantized_model.save_pretrained(output_dir, safe_serialization=<span class="hljs-literal">False</span>) | |
| <span class="hljs-comment"># reload the quantized model</span> | |
| reloaded_model = AutoModelForCausalLM.from_pretrained( | |
| output_dir, | |
| device_map=<span class="hljs-string">"auto"</span>, | |
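| <span class="hljs-comment"># "auto" places the reloaded weights on CUDA/XPU when available, otherwise on CPU</span> | |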
| dtype=torch.bfloat16 | |
| ) | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"meta-llama/Llama-3.1-8B-Instruct"</span>) | |
| input_text = <span class="hljs-string">"What are we having for dinner?"</span> | |
| input_ids = tokenizer(input_text, return_tensors=<span class="hljs-string">"pt"</span>).to(reloaded_model.device.<span class="hljs-built_in">type</span>) | |
| output = reloaded_model.generate(**input_ids, max_new_tokens=<span class="hljs-number">10</span>) | |
| <span class="hljs-built_in">print</span>(tokenizer.decode(output[<span class="hljs-number">0</span>], skip_special_tokens=<span class="hljs-literal">True</span>)) | |
| <!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1oj3qy1">For int4, the model can only be loaded on the same device it was quantized on because the layout is specific to the device. The example below demonstrates quantizing and loading a model on the CPU.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TorchAoConfig, AutoModelForCausalLM, AutoTokenizer | |
| <span class="hljs-keyword">from</span> torchao.quantization <span class="hljs-keyword">import</span> Int4WeightOnlyConfig | |
| <span class="hljs-keyword">from</span> torchao.dtypes <span class="hljs-keyword">import</span> Int4CPULayout | |
| quant_config = Int4WeightOnlyConfig(group_size=<span class="hljs-number">128</span>, layout=Int4CPULayout()) | |
| quantization_config = TorchAoConfig(quant_type=quant_config) | |
| <span class="hljs-comment"># Load and quantize the model</span> | |
| quantized_model = AutoModelForCausalLM.from_pretrained( | |
| <span class="hljs-string">"meta-llama/Llama-3.1-8B-Instruct"</span>, | |
| dtype=<span class="hljs-string">"auto"</span>, | |
| device_map=<span class="hljs-string">"cpu"</span>, | |
| quantization_config=quantization_config | |
| ) | |
| <span class="hljs-comment"># save the quantized model</span> | |
| output_dir = <span class="hljs-string">"llama-3.1-8b-torchao-int4-cpu"</span> | |
| quantized_model.save_pretrained(output_dir, safe_serialization=<span class="hljs-literal">False</span>) | |
| <span class="hljs-comment"># reload the quantized model</span> | |
| reloaded_model = AutoModelForCausalLM.from_pretrained( | |
| output_dir, | |
| device_map=<span class="hljs-string">"cpu"</span>, | |
| dtype=torch.bfloat16 | |
| ) | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"meta-llama/Llama-3.1-8B-Instruct"</span>) | |
| input_text = <span class="hljs-string">"What are we having for dinner?"</span> | |
| input_ids = tokenizer(input_text, return_tensors=<span class="hljs-string">"pt"</span>) | |
| output = reloaded_model.generate(**input_ids, max_new_tokens=<span class="hljs-number">10</span>) | |
| <span class="hljs-built_in">print</span>(tokenizer.decode(output[<span class="hljs-number">0</span>], skip_special_tokens=<span class="hljs-literal">True</span>)) | |
| <!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="-deprecation-notice" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#-deprecation-notice"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>⚠️ Deprecation Notice</span></h2> <blockquote><p data-svelte-h="svelte-jmxllg">Starting with version 0.10.0, the string-based API for quantization configuration (e.g., <code>TorchAoConfig("int4_weight_only", group_size=128)</code>) is <strong>deprecated</strong> and will be removed in a future release.</p> <p data-svelte-h="svelte-1qak5au">Please use the new <code>AOBaseConfig</code>-based approach instead:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># Old way (deprecated)</span> | |
| quantization_config = TorchAoConfig(<span class="hljs-string">"int4_weight_only"</span>, group_size=<span class="hljs-number">128</span>) | |
| <span class="hljs-comment"># New way (recommended)</span> | |
| <span class="hljs-keyword">from</span> torchao.quantization <span class="hljs-keyword">import</span> Int4WeightOnlyConfig | |
| quant_config = Int4WeightOnlyConfig(group_size=<span class="hljs-number">128</span>) | |
| quantization_config = TorchAoConfig(quant_type=quant_config)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-8tiw44">The new API offers greater flexibility, better type safety, and access to the full range of features available in torchao.</p> <p data-svelte-h="svelte-1wt1icj"><a href="#migration-guide">Migration Guide</a></p> <p data-svelte-h="svelte-1kyext4">Here’s how to migrate from common string identifiers to their <code>AOBaseConfig</code> equivalents:</p> <table data-svelte-h="svelte-1w9xui3"><thead><tr><th>Old String API</th> <th>New <code>AOBaseConfig</code> API</th></tr></thead> <tbody><tr><td><code>"int4_weight_only"</code></td> <td><code>Int4WeightOnlyConfig()</code></td></tr> <tr><td><code>"int8_weight_only"</code></td> <td><code>Int8WeightOnlyConfig()</code></td></tr> <tr><td><code>"int8_dynamic_activation_int8_weight"</code></td> <td><code>Int8DynamicActivationInt8WeightConfig()</code></td></tr></tbody></table> <p data-svelte-h="svelte-9z9ctj">All configuration objects accept parameters for customization (e.g., <code>group_size</code>, <code>scheme</code>, <code>layout</code>).</p></blockquote> <h2 class="relative group"><a id="resources" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#resources"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Resources</span></h2> <p data-svelte-h="svelte-1cgj2bs">For a better sense of expected performance, view the <a href="https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks" rel="nofollow">benchmarks</a> for various models with CUDA and XPU backends. 
You can also run the code below to benchmark a model yourself.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torch._inductor.utils <span class="hljs-keyword">import</span> do_bench_using_profiling | |
| <span class="hljs-keyword">from</span> typing <span class="hljs-keyword">import</span> <span class="hljs-type">Callable</span> | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">benchmark_fn</span>(<span class="hljs-params">func: <span class="hljs-type">Callable</span>, *args, **kwargs</span>) -> <span class="hljs-built_in">float</span>: | |
| <span class="hljs-string">"""Thin wrapper around do_bench_using_profiling"""</span> | |
| no_args = <span class="hljs-keyword">lambda</span>: func(*args, **kwargs) | |
| time = do_bench_using_profiling(no_args) | |
| <span class="hljs-keyword">return</span> time * <span class="hljs-number">1e3</span> | |
| <span class="hljs-comment"># assumes `quantized_model`, `tokenizer`, and `input_ids` from the earlier examples;</span> | |
| <span class="hljs-comment"># `model_name` should match the checkpoint that was quantized</span> | |
| model_name = <span class="hljs-string">"meta-llama/Llama-3.1-8B-Instruct"</span> | |
| MAX_NEW_TOKENS = <span class="hljs-number">1000</span> | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">"int4wo-128 model:"</span>, benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation=<span class="hljs-string">"static"</span>)) | |
| bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=<span class="hljs-string">"auto"</span>, dtype=torch.bfloat16) | |
| output = bf16_model.generate(**input_ids, max_new_tokens=<span class="hljs-number">10</span>, cache_implementation=<span class="hljs-string">"static"</span>) <span class="hljs-comment"># auto-compile</span> | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">"bf16 model:"</span>, benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation=<span class="hljs-string">"static"</span>))<!-- HTML_TAG_END --></pre></div> <blockquote class="tip" data-svelte-h="svelte-vhjtng"><p>For best performance, you can use recommended settings by calling <code>torchao.quantization.utils.recommended_inductor_config_setter()</code></p></blockquote> <p data-svelte-h="svelte-fj0t1q">Refer to <a href="https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques" rel="nofollow">Other Available Quantization Techniques</a> for more examples and documentation.</p> <h2 class="relative group"><a id="issues" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#issues"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Issues</span></h2> <p data-svelte-h="svelte-1auvp72">If you encounter any issues with the Transformers integration, please open an issue on the <a href="https://github.com/huggingface/transformers/issues" rel="nofollow">Transformers</a> repository. For issues directly related to torchao, please open an issue on the <a href="https://github.com/pytorch/ao/issues" rel="nofollow">torchao</a> repository.</p> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/transformers/blob/main/docs/source/en/quantization/torchao.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p> | |