Buckets:

hf-doc-build/doc-dev / transformers /pr_33913 /en /how_to_hack_models.html
rtrm's picture
download
raw
36.1 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;How to Hack Any Transformers Model&quot;,&quot;local&quot;:&quot;how-to-hack-any-transformers-model&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Example: Modifying the Attention Mechanism in the Segment Anything Model (SAM)&quot;,&quot;local&quot;:&quot;example-modifying-the-attention-mechanism-in-the-segment-anything-model-sam&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Motivation&quot;,&quot;local&quot;:&quot;motivation&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Implementation&quot;,&quot;local&quot;:&quot;implementation&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Step 1: Create a Custom Attention Class&quot;,&quot;local&quot;:&quot;step-1-create-a-custom-attention-class&quot;,&quot;sections&quot;:[],&quot;depth&quot;:4},{&quot;title&quot;:&quot;Step 2: Replace the Original Attention Class&quot;,&quot;local&quot;:&quot;step-2-replace-the-original-attention-class&quot;,&quot;sections&quot;:[],&quot;depth&quot;:4},{&quot;title&quot;:&quot;Step 3: Apply LoRA to Specific Projections&quot;,&quot;local&quot;:&quot;step-3-apply-lora-to-specific-projections&quot;,&quot;sections&quot;:[],&quot;depth&quot;:4},{&quot;title&quot;:&quot;Step 4: Verify the Number of Trainable Parameters&quot;,&quot;local&quot;:&quot;step-4-verify-the-number-of-trainable-parameters&quot;,&quot;sections&quot;:[],&quot;depth&quot;:4}],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Contributing Your Own Hacks&quot;,&quot;local&quot;:&quot;contributing-your-own-hacks&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/transformers/pr_33913/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/entry/start.b67f883f.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/scheduler.25b97de1.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/singletons.62a184e0.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/index.e188933d.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/paths.51881b9e.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/entry/app.e436b1f2.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/index.d9030fc9.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/nodes/0.05e395f5.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/nodes/25.295145f9.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/CodeBlock.e6cd0d95.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/EditOnGithub.91d95064.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;How to Hack Any Transformers Model&quot;,&quot;local&quot;:&quot;how-to-hack-any-transformers-model&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Example: Modifying the Attention Mechanism in the Segment Anything Model (SAM)&quot;,&quot;local&quot;:&quot;example-modifying-the-attention-mechanism-in-the-segment-anything-model-sam&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Motivation&quot;,&quot;local&quot;:&quot;motivation&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Implementation&quot;,&quot;local&quot;:&quot;implementation&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Step 1: Create a Custom Attention Class&quot;,&quot;local&quot;:&quot;step-1-create-a-custom-attention-class&quot;,&quot;sections&quot;:[],&quot;depth&quot;:4},{&quot;title&quot;:&quot;Step 2: Replace the Original Attention Class&quot;,&quot;local&quot;:&quot;step-2-replace-the-original-attention-class&quot;,&quot;sections&quot;:[],&quot;depth&quot;:4},{&quot;title&quot;:&quot;Step 3: Apply LoRA to Specific Projections&quot;,&quot;local&quot;:&quot;step-3-apply-lora-to-specific-projections&quot;,&quot;sections&quot;:[],&quot;depth&quot;:4},{&quot;title&quot;:&quot;Step 4: Verify the Number of Trainable Parameters&quot;,&quot;local&quot;:&quot;step-4-verify-the-number-of-trainable-parameters&quot;,&quot;sections&quot;:[],&quot;depth&quot;:4}],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Contributing Your Own Hacks&quot;,&quot;local&quot;:&quot;contributing-your-own-hacks&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="how-to-hack-any-transformers-model" class="header-link block pr-1.5 text-lg no-hover:hidden 
with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#how-to-hack-any-transformers-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>How to Hack Any Transformers Model</span></h1> <p data-svelte-h="svelte-4p0w22">The <a href="https://github.com/huggingface/transformers" rel="nofollow">🤗 Transformers</a> library offers a collection of pre-trained models and tools for natural language processing, vision, and beyond. While these models cover a wide range of applications, you might encounter use cases that aren’t supported out of the box. Customizing models can unlock new possibilities, such as adding new layers, altering architectures, or optimizing attention mechanisms. This guide will show you how to modify existing Transformers models to fit your specific needs. The great thing is, you don’t have to step away from the Transformers framework to make these changes. 
You can actually modify models directly in Transformers and still take advantage of features like the <a href="https://huggingface.co/docs/transformers/main/en/main_classes/trainer" rel="nofollow">Trainer API</a>, <a href="https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel" rel="nofollow">PreTrainedModel</a>, and efficient fine-tuning with tools like <a href="https://huggingface.co/docs/peft/index" rel="nofollow">PEFT</a>.</p> <p data-svelte-h="svelte-1kaw87e">In this guide, we’ll walk you through how to customize existing Transformers models to meet your requirements—without losing the benefits of the ecosystem.</p> <p data-svelte-h="svelte-16olrhx">You’ll learn how to:</p> <ul data-svelte-h="svelte-1219jh3"><li>Modify a model’s architecture by changing its attention mechanism.</li> <li>Apply techniques like Low-Rank Adaptation (LoRA) to specific model components.</li></ul> <p data-svelte-h="svelte-1fl0iuo">We encourage you to contribute your own hacks and share them here with the community!</p> <h2 class="relative group"><a id="example-modifying-the-attention-mechanism-in-the-segment-anything-model-sam" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#example-modifying-the-attention-mechanism-in-the-segment-anything-model-sam"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 
28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Example: Modifying the Attention Mechanism in the Segment Anything Model (SAM)</span></h2> <p data-svelte-h="svelte-pqt08">The <strong>Segment Anything Model (SAM)</strong> is a state-of-the-art model for image segmentation. In its default implementation, SAM uses a combined query-key-value (<code>qkv</code>) projection in its attention mechanism. However, you might want to fine-tune only specific components of the attention mechanism, such as the query (<code>q</code>) and value (<code>v</code>) projections, to reduce the number of trainable parameters and computational resources required.</p> <h3 class="relative group"><a id="motivation" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#motivation"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Motivation</span></h3> <p data-svelte-h="svelte-1fee5zh">By splitting the combined <code>qkv</code> projection into separate <code>q</code>, <code>k</code>, and <code>v</code> 
projections, you can apply techniques like <strong>LoRA</strong> (Low-Rank Adaptation) to only the <code>q</code> and <code>v</code> projections. This approach allows you to:</p> <ul data-svelte-h="svelte-ajxk4v"><li>Fine-tune fewer parameters, reducing computational overhead.</li> <li>Potentially achieve better performance by focusing on specific components.</li> <li>Experiment with different adaptation strategies in the attention mechanism.</li></ul> <h3 class="relative group"><a id="implementation" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#implementation"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Implementation</span></h3> <h4 class="relative group"><a id="step-1-create-a-custom-attention-class" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#step-1-create-a-custom-attention-class"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 
256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Step 1: Create a Custom Attention Class</span></h4> <p data-svelte-h="svelte-183kvjr">Next, subclass the original <code>SamVisionAttention</code> class and modify it to have separate <code>q</code>, <code>k</code>, and <code>v</code> projections.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre 
class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">import</span> torch.nn <span class="hljs-keyword">as</span> nn
<span class="hljs-keyword">from</span> transformers.models.sam.modeling_sam <span class="hljs-keyword">import</span> SamVisionAttention
<span class="hljs-keyword">class</span> <span class="hljs-title class_">SamVisionAttentionSplit</span>(SamVisionAttention, nn.Module):
<span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, config, window_size</span>):
<span class="hljs-built_in">super</span>().__init__(config, window_size)
<span class="hljs-keyword">del</span> self.qkv
<span class="hljs-comment"># Separate q, k, v projections</span>
self.q = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
self.k = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
self.v = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
self._register_load_state_dict_pre_hook(self.split_q_k_v_load_hook)
<span class="hljs-keyword">def</span> <span class="hljs-title function_">split_q_k_v_load_hook</span>(<span class="hljs-params">self, state_dict, prefix, *args</span>):
keys_to_delete = []
<span class="hljs-keyword">for</span> key <span class="hljs-keyword">in</span> <span class="hljs-built_in">list</span>(state_dict.keys()):
<span class="hljs-keyword">if</span> <span class="hljs-string">&quot;qkv.&quot;</span> <span class="hljs-keyword">in</span> key:
<span class="hljs-comment"># Split q, k, v from the combined projection</span>
q, k, v = state_dict[key].chunk(<span class="hljs-number">3</span>, dim=<span class="hljs-number">0</span>)
<span class="hljs-comment"># Replace with individual q, k, v projections</span>
state_dict[key.replace(<span class="hljs-string">&quot;qkv.&quot;</span>, <span class="hljs-string">&quot;q.&quot;</span>)] = q
state_dict[key.replace(<span class="hljs-string">&quot;qkv.&quot;</span>, <span class="hljs-string">&quot;k.&quot;</span>)] = k
state_dict[key.replace(<span class="hljs-string">&quot;qkv.&quot;</span>, <span class="hljs-string">&quot;v.&quot;</span>)] = v
<span class="hljs-comment"># Mark the old qkv key for deletion</span>
keys_to_delete.append(key)
<span class="hljs-comment"># Remove old qkv keys</span>
<span class="hljs-keyword">for</span> key <span class="hljs-keyword">in</span> keys_to_delete:
<span class="hljs-keyword">del</span> state_dict[key]
<span class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params">self, hidden_states: torch.Tensor, output_attentions=<span class="hljs-literal">False</span></span>) -&gt; torch.Tensor:
batch_size, height, width, _ = hidden_states.shape
qkv_shapes = (batch_size * self.num_attention_heads, height * width, -<span class="hljs-number">1</span>)
query = self.q(hidden_states).reshape((batch_size, height * width,self.num_attention_heads, -<span class="hljs-number">1</span>)).permute(<span class="hljs-number">0</span>,<span class="hljs-number">2</span>,<span class="hljs-number">1</span>,<span class="hljs-number">3</span>).reshape(qkv_shapes)
key = self.k(hidden_states).reshape((batch_size, height * width,self.num_attention_heads, -<span class="hljs-number">1</span>)).permute(<span class="hljs-number">0</span>,<span class="hljs-number">2</span>,<span class="hljs-number">1</span>,<span class="hljs-number">3</span>).reshape(qkv_shapes)
value = self.v(hidden_states).reshape((batch_size, height * width,self.num_attention_heads, -<span class="hljs-number">1</span>)).permute(<span class="hljs-number">0</span>,<span class="hljs-number">2</span>,<span class="hljs-number">1</span>,<span class="hljs-number">3</span>).reshape(qkv_shapes)
attn_weights = (query * self.scale) @ key.transpose(-<span class="hljs-number">2</span>, -<span class="hljs-number">1</span>)
<span class="hljs-keyword">if</span> self.use_rel_pos:
attn_weights = self.add_decomposed_rel_pos(
attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
)
attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-<span class="hljs-number">1</span>).to(query.dtype)
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -<span class="hljs-number">1</span>)
attn_output = attn_output.permute(<span class="hljs-number">0</span>, <span class="hljs-number">2</span>, <span class="hljs-number">3</span>, <span class="hljs-number">1</span>, <span class="hljs-number">4</span>).reshape(batch_size, height, width, -<span class="hljs-number">1</span>)
attn_output = self.proj(attn_output)
<span class="hljs-keyword">if</span> output_attentions:
outputs = (attn_output, attn_weights)
<span class="hljs-keyword">else</span>:
outputs = (attn_output, <span class="hljs-literal">None</span>)
<span class="hljs-keyword">return</span> outputs<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-qcx91s"><strong>Explanation:</strong></p> <ul data-svelte-h="svelte-17qbtt1"><li><strong>Separate Projections:</strong> The combined <code>qkv</code> projection is removed, and separate <code>q</code>, <code>k</code>, and <code>v</code> linear layers are created.</li> <li><strong>Weight Loading Hook:</strong> The <code>split_q_k_v_load_hook</code> method splits the pre-trained <code>qkv</code> weights into separate <code>q</code>, <code>k</code>, and <code>v</code> weights when loading the model. This ensures compatibility with any pre-trained model.</li> <li><strong>Forward Pass:</strong> Queries, keys, and values are computed separately, and the attention mechanism proceeds as usual.</li></ul> <h4 class="relative group"><a id="step-2-replace-the-original-attention-class" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#step-2-replace-the-original-attention-class"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Step 2: Replace the Original Attention Class</span></h4> <p data-svelte-h="svelte-sp0tsy">Replace the original 
<code>SamVisionAttention</code> class with your custom class so that the model uses the modified attention mechanism.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> SamModel
<span class="hljs-keyword">from</span> transformers.models.sam <span class="hljs-keyword">import</span> modeling_sam
<span class="hljs-comment"># Replace the attention class in the modeling_sam module</span>
modeling_sam.SamVisionAttention = SamVisionAttentionSplit
<span class="hljs-comment"># Load the pre-trained SAM model</span>
model = SamModel.from_pretrained(<span class="hljs-string">&quot;facebook/sam-vit-base&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-qcx91s"><strong>Explanation:</strong></p> <ul data-svelte-h="svelte-1qee5e9"><li><strong>Class Replacement:</strong> By assigning your custom class to <code>modeling_sam.SamVisionAttention</code>, any instances of <code>SamVisionAttention</code> in the model will use the modified version. Thus when you call <code>SamModel</code>, it will use the newly defined <code>SamVisionAttentionSplit</code>.</li> <li><strong>Model Loading:</strong> The model is loaded using <code>from_pretrained</code>, and the custom attention mechanism is integrated.</li></ul> <h4 class="relative group"><a id="step-3-apply-lora-to-specific-projections" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#step-3-apply-lora-to-specific-projections"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Step 3: Apply LoRA to Specific Projections</span></h4> <p data-svelte-h="svelte-14bh2vd">With separate <code>q</code>, <code>k</code>, and <code>v</code> projections, you can now apply LoRA to specific 
components, such as the <code>q</code> and <code>v</code> projections.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> peft <span class="hljs-keyword">import</span> LoraConfig, get_peft_model
config = LoraConfig(
r=<span class="hljs-number">16</span>,
lora_alpha=<span class="hljs-number">32</span>,
target_modules=[<span class="hljs-string">&quot;q&quot;</span>, <span class="hljs-string">&quot;v&quot;</span>], <span class="hljs-comment"># Apply LoRA to q and v projections</span>
lora_dropout=<span class="hljs-number">0.1</span>,
task_type=<span class="hljs-string">&quot;mask-generation&quot;</span>
)
<span class="hljs-comment"># Apply LoRA to the model</span>
model = get_peft_model(model, config)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-qcx91s"><strong>Explanation:</strong></p> <ul data-svelte-h="svelte-ampkeh"><li><strong>LoRA Configuration:</strong> The <code>LoraConfig</code> specifies the rank <code>r</code>, scaling factor <code>lora_alpha</code>, target modules (<code>&quot;q&quot;</code> and <code>&quot;v&quot;</code>), dropout, and task type.</li> <li><strong>Applying LoRA:</strong> The <code>get_peft_model</code> function applies LoRA to the specified modules in the model.</li> <li><strong>Parameter Reduction:</strong> By focusing on <code>q</code> and <code>v</code>, you reduce the number of trainable parameters, leading to faster training and lower memory usage.</li></ul> <h4 class="relative group"><a id="step-4-verify-the-number-of-trainable-parameters" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#step-4-verify-the-number-of-trainable-parameters"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Step 4: Verify the Number of Trainable Parameters</span></h4> <p data-svelte-h="svelte-a867g8">It’s simple to verify the number of trainable parameters and 
see what impact your modification had.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->model.print_trainable_parameters()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-cry7xq"><strong>Expected Output:</strong></p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" 
transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-attribute">trainable</span> params: <span class="hljs-number">608</span>,<span class="hljs-number">256</span> || <span class="hljs-literal">all</span> params: <span class="hljs-number">94</span>,<span class="hljs-number">343</span>,<span class="hljs-number">728</span> || trainable%: <span class="hljs-number">0</span>.<span class="hljs-number">6447</span>
<span class="hljs-attribute">trainable</span> params: <span class="hljs-number">912</span>,<span class="hljs-number">384</span> || <span class="hljs-literal">all</span> params: <span class="hljs-number">94</span>,<span class="hljs-number">647</span>,<span class="hljs-number">856</span> || trainable%: <span class="hljs-number">0</span>.<span class="hljs-number">9640</span> # with k <!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="contributing-your-own-hacks" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#contributing-your-own-hacks"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Contributing Your Own Hacks</span></h2> <p data-svelte-h="svelte-1rmqfnq">Modifying pre-trained models can open up new avenues for research and application. 
By understanding and adjusting the internal mechanisms of models like SAM, you can tailor them to your specific needs, optimize performance, and experiment with new ideas.</p> <p data-svelte-h="svelte-1vdc3k7">If you’ve developed your own hacks for Transformers models and would like to share them, consider contributing to this doc.</p> <ul data-svelte-h="svelte-1934g57"><li><strong>Open a Pull Request:</strong> Share your code changes and improvements directly in the repository.</li> <li><strong>Write Documentation:</strong> Provide clear explanations and examples of your modifications.</li> <li><strong>Engage with the Community:</strong> Discuss your ideas and get feedback from other developers and researchers by opening an issue.</li></ul> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/transformers/blob/main/docs/source/en/how_to_hack_models.md" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p>
<script>
{
  // Assigned without a declaration, so in this classic (non-module) script it
  // becomes a property of the global object. Presumably the bundled client
  // code looks this name up to learn its asset/base paths -- confirm against
  // the generated entry chunks before renaming.
  __sveltekit_z647wz = {
    assets: "/docs/transformers/pr_33913/en",
    base: "/docs/transformers/pr_33913/en",
    env: {}
  };

  // Capture the mount point before any await: document.currentScript is only
  // set while this script element is executing synchronously.
  const mountTarget = document.currentScript.parentElement;

  // One slot per entry in node_ids below; both are null here.
  const nodeData = [null, null];

  (async () => {
    // Both dynamic imports are started together and awaited as a pair.
    const [kit, app] = await Promise.all([
      import("/docs/transformers/pr_33913/en/_app/immutable/entry/start.b67f883f.js"),
      import("/docs/transformers/pr_33913/en/_app/immutable/entry/app.e436b1f2.js")
    ]);

    // Hand the app module, the mount point and the page state to the start
    // module's start() export.
    kit.start(app, mountTarget, {
      node_ids: [0, 25],
      data: nodeData,
      form: null,
      error: null
    });
  })();
}
</script>

Xet Storage Details

Size:
36.1 kB
·
Xet hash:
3cc4a777f77c426a12889c889423c425a6b5f4df9955c7b846204ab23cce4749

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.