Buckets:

rtrm's picture
download
raw
15.6 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;TorchAO&quot;,&quot;local&quot;:&quot;torchao&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Serialization and Deserialization&quot;,&quot;local&quot;:&quot;serialization-and-deserialization&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/transformers/pr_33913/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/entry/start.b67f883f.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/scheduler.25b97de1.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/singletons.62a184e0.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/index.e188933d.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/paths.51881b9e.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/entry/app.e436b1f2.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/index.d9030fc9.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/nodes/0.05e395f5.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/nodes/409.687e5631.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/CodeBlock.e6cd0d95.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/EditOnGithub.91d95064.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;TorchAO&quot;,&quot;local&quot;:&quot;torchao&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Serialization and Deserialization&quot;,&quot;local&quot;:&quot;serialization-and-deserialization&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="torchao" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#torchao"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>TorchAO</span></h1> <p data-svelte-h="svelte-1w5ww4g"><a href="https://github.com/pytorch/ao" rel="nofollow">TorchAO</a> is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, and it composes with native PyTorch features such as <code>torch.compile</code> and FSDP.
Some benchmark numbers can be found <a href="https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks" rel="nofollow">here</a>.</p> <p data-svelte-h="svelte-1wrsrbm">Before you begin, make sure the following libraries are installed with their latest version:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pip install --upgrade torch torchao<!-- HTML_TAG_END --></pre></div> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" 
preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
model_name = <span class="hljs-string">&quot;meta-llama/Meta-Llama-3-8B&quot;</span>
<span class="hljs-comment"># We support int4_weight_only, int8_weight_only and int8_dynamic_activation_int8_weight</span>
<span class="hljs-comment"># More examples and documentation for the arguments can be found at https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques</span>
quantization_config = TorchAoConfig(<span class="hljs-string">&quot;int4_weight_only&quot;</span>, group_size=<span class="hljs-number">128</span>)
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=<span class="hljs-string">&quot;auto&quot;</span>, quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = <span class="hljs-string">&quot;What are we having for dinner?&quot;</span>
input_ids = tokenizer(input_text, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>).to(<span class="hljs-string">&quot;cuda&quot;</span>)
<span class="hljs-comment"># compile the quantized model to get speedup</span>
<span class="hljs-keyword">import</span> torchao
torchao.quantization.utils.recommended_inductor_config_setter()
quantized_model = torch.<span class="hljs-built_in">compile</span>(quantized_model, mode=<span class="hljs-string">&quot;max-autotune&quot;</span>)
output = quantized_model.generate(**input_ids, max_new_tokens=<span class="hljs-number">10</span>)
<span class="hljs-built_in">print</span>(tokenizer.decode(output[<span class="hljs-number">0</span>], skip_special_tokens=<span class="hljs-literal">True</span>))
<span class="hljs-comment"># benchmark the performance</span>
<span class="hljs-keyword">import</span> torch.utils.benchmark <span class="hljs-keyword">as</span> benchmark
<span class="hljs-keyword">def</span> <span class="hljs-title function_">benchmark_fn</span>(<span class="hljs-params">f, *args, **kwargs</span>):
    <span class="hljs-comment"># Manual warmup</span>
    <span class="hljs-keyword">for</span> _ <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">5</span>):
        f(*args, **kwargs)
    t0 = benchmark.Timer(
        stmt=<span class="hljs-string">&quot;f(*args, **kwargs)&quot;</span>,
        <span class="hljs-built_in">globals</span>={<span class="hljs-string">&quot;args&quot;</span>: args, <span class="hljs-string">&quot;kwargs&quot;</span>: kwargs, <span class="hljs-string">&quot;f&quot;</span>: f},
        num_threads=torch.get_num_threads(),
    )
    <span class="hljs-keyword">return</span> <span class="hljs-string">f&quot;<span class="hljs-subst">{(t0.blocked_autorange().mean):<span class="hljs-number">.3</span>f}</span>&quot;</span>
MAX_NEW_TOKENS = <span class="hljs-number">1000</span>
<span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;int4wo-128 model:&quot;</span>, benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS))
bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=<span class="hljs-string">&quot;cuda&quot;</span>, torch_dtype=torch.bfloat16)
bf16_model = torch.<span class="hljs-built_in">compile</span>(bf16_model, mode=<span class="hljs-string">&quot;max-autotune&quot;</span>)
<span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;bf16 model:&quot;</span>, benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS))
<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="serialization-and-deserialization" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#serialization-and-deserialization"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Serialization and Deserialization</span></h2> <p data-svelte-h="svelte-14dwjgd">torchao quantization is implemented with <a href="https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor" rel="nofollow">tensor subclasses</a>, so it only works with Hugging Face non-safetensors serialization and deserialization.
It relies on <code>torch.load(..., weights_only=True)</code> to avoid arbitrary user code execution during load time and uses <a href="https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals" rel="nofollow">add_safe_globals</a> to allowlist some known user functions.</p> <p data-svelte-h="svelte-10wmnl9">Safetensors serialization is not supported because wrapper tensor subclasses allow maximum flexibility, which keeps the effort of supporting a new quantized tensor format low. Safetensors, by contrast, optimizes for maximum safety (no user code execution), which also means that each new quantization format would have to be supported manually.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># save quantized model locally</span>
output_dir = <span class="hljs-string">&quot;llama3-8b-int4wo-128&quot;</span>
quantized_model.save_pretrained(output_dir, safe_serialization=<span class="hljs-literal">False</span>)
<span class="hljs-comment"># push to huggingface hub</span>
<span class="hljs-comment"># save_to = &quot;{user_id}/llama3-8b-int4wo-128&quot;</span>
<span class="hljs-comment"># quantized_model.push_to_hub(save_to, safe_serialization=False)</span>
<span class="hljs-comment"># load quantized model</span>
ckpt_id = <span class="hljs-string">&quot;llama3-8b-int4wo-128&quot;</span> <span class="hljs-comment"># or huggingface hub model id</span>
loaded_quantized_model = AutoModelForCausalLM.from_pretrained(ckpt_id, device_map=<span class="hljs-string">&quot;cuda&quot;</span>)
<span class="hljs-comment"># confirm the speedup</span>
loaded_quantized_model = torch.<span class="hljs-built_in">compile</span>(loaded_quantized_model, mode=<span class="hljs-string">&quot;max-autotune&quot;</span>)
<span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;loaded int4wo-128 model:&quot;</span>, benchmark_fn(loaded_quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS))<!-- HTML_TAG_END --></pre></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/transformers/blob/main/docs/source/en/quantization/torchao.md" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p>
<script>
{
  // Publish the path configuration under the build-specific name. There is
  // no declaration keyword, so this is a plain assignment rather than a
  // block-scoped binding.
  __sveltekit_z647wz = {
    assets: "/docs/transformers/pr_33913/en",
    base: "/docs/transformers/pr_33913/en",
    env: {}
  };

  // The parent element of the currently executing <script> is passed to
  // kit.start() below as its target element.
  const mountTarget = document.currentScript.parentElement;
  const nodeData = [null, null];

  // Kick off both entry-module imports together; start only once both
  // have resolved.
  const entryModules = [
    import("/docs/transformers/pr_33913/en/_app/immutable/entry/start.b67f883f.js"),
    import("/docs/transformers/pr_33913/en/_app/immutable/entry/app.e436b1f2.js")
  ];

  Promise.all(entryModules).then((loaded) => {
    const kit = loaded[0];
    const app = loaded[1];

    kit.start(app, mountTarget, {
      node_ids: [0, 409],
      data: nodeData,
      form: null,
      error: null
    });
  });
}
</script>

Xet Storage Details

Size:
15.6 kB
·
Xet hash:
8e556ed8fe27b4381bb0533639d410d63c57585eb20d69b26843bf1d2153d2b6

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.