Buckets:
| <meta charset="utf-8" /><meta name="hf:doc:metadata" content="{"title":"Contributing a new model to Transformers","local":"contributing-a-new-model-to-transformers","sections":[{"title":"Motivation","local":"motivation","sections":[],"depth":2},{"title":"Create a modeling.py file","local":"create-a-modelingpy-file","sections":[{"title":"BERT and RoBERTa","local":"bert-and-roberta","sections":[],"depth":3}],"depth":2},{"title":"Implementing a modular file","local":"implementing-a-modular-file","sections":[{"title":"Config","local":"config","sections":[],"depth":3},{"title":"Norm","local":"norm","sections":[],"depth":3},{"title":"Attention","local":"attention","sections":[],"depth":3},{"title":"DecoderLayer","local":"decoderlayer","sections":[],"depth":3},{"title":"Model","local":"model","sections":[],"depth":3},{"title":"Model head","local":"model-head","sections":[],"depth":3},{"title":"Other classes","local":"other-classes","sections":[],"depth":3}],"depth":2},{"title":"Removing attributes","local":"removing-attributes","sections":[],"depth":2},{"title":"Calling parent methods without unravelling their definition","local":"calling-parent-methods-without-unravelling-their-definition","sections":[],"depth":2},{"title":"Deleting unused methods","local":"deleting-unused-methods","sections":[],"depth":2},{"title":"Defining new functions","local":"defining-new-functions","sections":[],"depth":2},{"title":"super_kwargs","local":"superkwargs","sections":[],"depth":2},{"title":"Docstring variables","local":"docstring-variables","sections":[],"depth":2},{"title":"Special naming","local":"special-naming","sections":[],"depth":2},{"title":"Config docstrings","local":"config-docstrings","sections":[],"depth":2}],"depth":1}"> | |
| <link href="/docs/transformers/pr_33892/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/entry/start.b2c4257a.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/scheduler.31fdf58d.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/singletons.9860629f.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/index.252883d5.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/paths.e85c0ec8.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/entry/app.05ef1f97.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/preload-helper.40847a0e.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/index.2f76fdf0.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/nodes/0.ca4aafa4.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/nodes/497.c425ac0c.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/CopyLLMTxtMenu.ff482081.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.71f274cc.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/IconCopy.ac192424.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/CodeBlock.ab12f8e1.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"Contributing a new model to Transformers","local":"contributing-a-new-model-to-transformers","sections":[{"title":"Motivation","local":"motivation","sections":[],"depth":2},{"title":"Create a modeling.py file","local":"create-a-modelingpy-file","sections":[{"title":"BERT and RoBERTa","local":"bert-and-roberta","sections":[],"depth":3}],"depth":2},{"title":"Implementing a modular file","local":"implementing-a-modular-file","sections":[{"title":"Config","local":"config","sections":[],"depth":3},{"title":"Norm","local":"norm","sections":[],"depth":3},{"title":"Attention","local":"attention","sections":[],"depth":3},{"title":"DecoderLayer","local":"decoderlayer","sections":[],"depth":3},{"title":"Model","local":"model","sections":[],"depth":3},{"title":"Model head","local":"model-head","sections":[],"depth":3},{"title":"Other classes","local":"other-classes","sections":[],"depth":3}],"depth":2},{"title":"Removing attributes","local":"removing-attributes","sections":[],"depth":2},{"title":"Calling parent methods without unravelling their definition","local":"calling-parent-methods-without-unravelling-their-definition","sections":[],"depth":2},{"title":"Deleting unused methods","local":"deleting-unused-methods","sections":[],"depth":2},{"title":"Defining new functions","local":"defining-new-functions","sections":[],"depth":2},{"title":"super_kwargs","local":"superkwargs","sections":[],"depth":2},{"title":"Docstring variables","local":"docstring-variables","sections":[],"depth":2},{"title":"Special naming","local":"special-naming","sections":[],"depth":2},{"title":"Config docstrings","local":"config-docstrings","sections":[],"depth":2}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" 
style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 max-sm:gap-0.5 h-6 max-sm:h-5 px-2 max-sm:px-1.5 text-[11px] max-sm:text-[9px] font-medium text-gray-800 border border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0"><svg class="w-3 h-3 max-sm:w-2.5 max-sm:h-2.5" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-6 max-sm:h-5 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible w-3 h-3 max-sm:w-2.5 max-sm:h-2.5 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="contributing-a-new-model-to-transformers" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 
with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#contributing-a-new-model-to-transformers"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Contributing a new model to Transformers</span></h1> <p data-svelte-h="svelte-lj81hr">Modular Transformers lowers the bar for contributing models and significantly reduces the code required to add a model by allowing imports and inheritance.</p> <p data-svelte-h="svelte-1fibcpl">One of Transformers’ core design features is the <a href="https://huggingface.co/blog/transformers-design-philosophy" rel="nofollow">single model, single file</a> policy. Because of this policy, model components — such as attention layers — are repeated across many files, and the independent copies tend to diverge as fixes and changes are applied to only some of them.</p> <p data-svelte-h="svelte-12jtllm">The <a href="./pr_checks#check-copies"><code># Copied from</code></a> statements prevent the code from diverging, and this is enforced by our continuous integration tests and local commands. 
The downside is that this approach is tedious and adds significantly more lines of code, most of which is boilerplate.</p> <h2 class="relative group"><a id="motivation" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#motivation"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Motivation</span></h2> <p data-svelte-h="svelte-wcqh4p">Modular Transformers addresses these issues by adding a <em>modular</em> file to a model folder. Unlike traditional modeling and processing files, the modular file can import code from other models and inherit code from other classes.</p> <blockquote class="tip" data-svelte-h="svelte-4vbqq6"><p>Modular Transformers isn’t meant to replace the modeling code, and if your model isn’t based on an existing model, you’ll need to add a <code>modeling.py</code> file manually. 
Likewise, if a configuration, tokenization or processing file can’t easily inherit from a similar file, you can add that file directly.</p></blockquote> <p data-svelte-h="svelte-sz29kf">A modular file contains model, processor, and configuration class code that would otherwise be in separate files under the single model, single file policy.</p> <p data-svelte-h="svelte-1n4f4or">Model users still import and use the single-file interface they’ve grown familiar with. In doing so, we hope to enable simpler contributions while sticking to our philosophy.</p> <h2 class="relative group"><a id="create-a-modelingpy-file" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#create-a-modelingpy-file"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Create a modeling.py file</span></h2> <p data-svelte-h="svelte-12wgxa8">A linter “unravels” the modular file into a <code>modeling.py</code> file to preserve the single model, single file directory structure (modeling, processor, etc.). 
Inheritance is flattened to only a <strong>single</strong> level.</p> <p data-svelte-h="svelte-jegoy5">Run the command below to automatically generate a <code>modeling.py</code> file from a modular file (assuming the lowercase snake_case name of the model you want to convert is <code>your_model</code>).</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->python utils/modular_model_converter.py your_model<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1gkqha7">For example, the linter handles the following cases automatically:</p> <ul data-svelte-h="svelte-2znzcv"><li>If a configuration class inherits from another class but adds or deletes an argument, the generated file either references the added argument directly or completely removes the deleted argument.</li> <li>If a class inherits from another, like <code>GemmaModel(LlamaModel)</code>, the dependencies are 
automatically inferred. All submodules are also automatically inferred from the superclass.</li> <li>If a new function is defined in the modular file and used inside classes, the linter automatically infers it as well.</li></ul> <p data-svelte-h="svelte-1q8tpr3">You should be able to write everything (tokenizer, image processor, model, config, etc.) in a modular file, and the corresponding single files are generated from it.</p> <p data-svelte-h="svelte-fj797c">The example below demonstrates how a model can be added with significantly fewer lines of code using Modular Transformers.</p> <h3 class="relative group"><a id="bert-and-roberta" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#bert-and-roberta"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>BERT and RoBERTa</span></h3> <p data-svelte-h="svelte-1tajku">BERT and RoBERTa, two very similar models, differ solely in how the embedding layer is implemented.</p> <p data-svelte-h="svelte-1izmokq">Instead of redefining the model entirely, consider the <code>modular_roberta.py</code> file shown below for the modeling and configuration classes (the tokenizer isn’t shown in this 
example).</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torch <span class="hljs-keyword">import</span> nn | |
| <span class="hljs-keyword">from</span> ..bert.configuration_bert <span class="hljs-keyword">import</span> BertConfig | |
| <span class="hljs-keyword">from</span> ..bert.modeling_bert <span class="hljs-keyword">import</span> ( | |
| BertModel, | |
| BertEmbeddings, | |
| BertForMaskedLM | |
| ) | |
| <span class="hljs-comment"># RoBERTa and BERT config is identical</span> | |
| <span class="hljs-keyword">class</span> <span class="hljs-title class_">RobertaConfig</span>(<span class="hljs-title class_ inherited__">BertConfig</span>): | |
| model_type = <span class="hljs-string">'roberta'</span> | |
| <span class="hljs-comment"># Redefine the embeddings to highlight the padding id difference, and redefine the position embeddings</span> | |
| <span class="hljs-keyword">class</span> <span class="hljs-title class_">RobertaEmbeddings</span>(<span class="hljs-title class_ inherited__">BertEmbeddings</span>): | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, config</span>): | |
| <span class="hljs-built_in">super</span>().__init__(config) | |
| self.padding_idx = config.pad_token_id | |
| self.position_embeddings = nn.Embedding( | |
| config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx | |
| ) | |
| <span class="hljs-comment"># RoBERTa and BERT model is identical except for the embedding layer, which is defined above, so no need for additional changes here</span> | |
| <span class="hljs-keyword">class</span> <span class="hljs-title class_">RobertaModel</span>(<span class="hljs-title class_ inherited__">BertModel</span>): | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, config</span>): | |
| <span class="hljs-built_in">super</span>().__init__(config) | |
| self.embeddings = RobertaEmbeddings(config) | |
| <span class="hljs-comment"># The model heads now only need to redefine the model inside to `RobertaModel`</span> | |
| <span class="hljs-keyword">class</span> <span class="hljs-title class_">RobertaForMaskedLM</span>(<span class="hljs-title class_ inherited__">BertForMaskedLM</span>): | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, config</span>): | |
| <span class="hljs-built_in">super</span>().__init__(config) | |
| self.model = RobertaModel(config)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1hj80v9">If you don’t use the defined dependency, you’ll receive the following error.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->ValueError: You defined `RobertaEmbeddings` in the modular_roberta.py, it should be used when you define `BertModel`, as it is one of it's direct dependencies. 
Make sure you use it in the `__init__` function.<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="implementing-a-modular-file" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#implementing-a-modular-file"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Implementing a modular file</span></h2> <p data-svelte-h="svelte-jkrdf9">The easiest way to start is by browsing Transformers for a model similar to yours in order to inherit from it. Some good starting points are <a href="./model_doc/mistral">Mistral</a>, <a href="./model_doc/qwen2">Qwen2</a>, <a href="./model_doc/cohere">Cohere</a> and <a href="./model_doc/cohere2">Cohere2</a>, and <a href="./model_doc/llama">Llama</a>. 
Refer to the table below for components your model might be using and where you can inherit from.</p> <table data-svelte-h="svelte-smzi7r"><thead><tr><th>Component</th> <th>Model</th></tr></thead> <tbody><tr><td>Mixture of experts</td> <td>SwitchTransformers or Mixtral</td></tr> <tr><td>Interleaved (and/or partial) rotary embedding</td> <td>GLM, Phi</td></tr> <tr><td>State space models</td> <td>Jamba, Bamba, Zamba, Mamba2</td></tr> <tr><td>Recurrent hidden states</td> <td>Gemma2</td></tr> <tr><td>Sliding window attention/full attention patterns per layer</td> <td>Gemma2, Cohere2</td></tr> <tr><td>QKV clipping</td> <td>Olmo</td></tr> <tr><td>QK normalization</td> <td>Olmo2, Cohere</td></tr> <tr><td>Fused QKV (not recommended)</td> <td>Phi3</td></tr></tbody></table> <p data-svelte-h="svelte-vpi0ku">This section will walk you through how to implement <a href="./model_doc/olmo2">Olmo2</a> from <a href="./model_doc/olmo">Olmo</a> with modular Transformers (you can refer to the original <a href="https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo2/modular_olmo2.py" rel="nofollow">modular_olmo2.py</a> file).</p> <h3 class="relative group"><a id="config" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#config"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0
1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Config</span></h3> <p data-svelte-h="svelte-1yjd5s7">The modular <code>Olmo2Config</code> is shown below.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> ..olmo.configuration_olmo <span class="hljs-keyword">import</span> OlmoConfig | |
| <span class="hljs-keyword">class</span> <span class="hljs-title class_">Olmo2Config</span>(<span class="hljs-title class_ inherited__">OlmoConfig</span>): | |
| <span class="hljs-string">r""" | |
| This is the configuration class to store the configuration of a [Olmo2Model](/docs/transformers/main/en/model_doc/olmo2#transformers.Olmo2Model). | |
| """</span> | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params"> | |
| self, | |
| vocab_size=<span class="hljs-number">50304</span>, | |
| hidden_size=<span class="hljs-number">4096</span>, | |
| intermediate_size=<span class="hljs-number">11008</span>, | |
| num_hidden_layers=<span class="hljs-number">32</span>, | |
| num_attention_heads=<span class="hljs-number">32</span>, | |
| num_key_value_heads=<span class="hljs-literal">None</span>, | |
| hidden_act=<span class="hljs-string">"silu"</span>, | |
| max_position_embeddings=<span class="hljs-number">2048</span>, | |
| initializer_range=<span class="hljs-number">0.02</span>, | |
| use_cache=<span class="hljs-literal">True</span>, | |
| pad_token_id=<span class="hljs-number">1</span>, | |
| bos_token_id=<span class="hljs-literal">None</span>, | |
| eos_token_id=<span class="hljs-number">50279</span>, | |
| tie_word_embeddings=<span class="hljs-literal">False</span>, | |
| rope_theta=<span class="hljs-number">10000.0</span>, | |
| rope_scaling=<span class="hljs-literal">None</span>, | |
| attention_bias=<span class="hljs-literal">False</span>, | |
| attention_dropout=<span class="hljs-number">0.0</span>, | |
| rms_norm_eps=<span class="hljs-number">1e-5</span>, | |
| **kwargs, | |
| </span>): | |
| <span class="hljs-built_in">super</span>().__init__( | |
| vocab_size=vocab_size, | |
| hidden_size=hidden_size, | |
| intermediate_size=intermediate_size, | |
| num_hidden_layers=num_hidden_layers, | |
| num_attention_heads=num_attention_heads, | |
| num_key_value_heads=num_key_value_heads, | |
| hidden_act=hidden_act, | |
| max_position_embeddings=max_position_embeddings, | |
| initializer_range=initializer_range, | |
| use_cache=use_cache, | |
| pad_token_id=pad_token_id, | |
| bos_token_id=bos_token_id, | |
| eos_token_id=eos_token_id, | |
| tie_word_embeddings=tie_word_embeddings, | |
| rope_theta=rope_theta, | |
| rope_scaling=rope_scaling, | |
| attention_bias=attention_bias, | |
| attention_dropout=attention_dropout, | |
| **kwargs, | |
| ) | |
| self.rms_norm_eps = rms_norm_eps | |
| <span class="hljs-keyword">del</span> self.clip_qkv<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-q0b1mo">There are three points where the <code>Olmo2Config</code> is different from the original <code>OlmoConfig</code>.</p> <ol data-svelte-h="svelte-139or2s"><li>The default values of most arguments have changed.</li> <li>There is a new argument, <code>rms_norm_eps</code>.</li> <li>The <code>clip_qkv</code> argument isn’t used anymore.</li></ol> <p data-svelte-h="svelte-3ks7rn">For the new default values and argument, overwrite the <code>__init__</code> function with the new default values and add <code>rms_norm_eps</code>. Assign <code>rms_norm_eps</code> to <code>self</code> in the body of <code>__init__</code>. For the <code>clip_qkv</code> argument, use <code>del self.clip_qkv</code> to remove the assignment of this attribute in the unraveled code (post-linter conversion).</p> <p data-svelte-h="svelte-11wms29">Notice how <code>super().__init__(...)</code> is used. Typically, it calls the parent <code>__init__</code>.</p> <p data-svelte-h="svelte-yznf21">But in modular Transformers, if there is a call like <code>super().my_function(...)</code>, the linter takes the body of <code>my_function</code> in the parent and unravels it where the call to <code>super().my_function(...)</code> occurred. The <code>del self.clip_qkv</code> statement removes the reference to <code>self.clip_qkv</code> in the unraveled body.</p> <p data-svelte-h="svelte-1i7o0i6"><code>del self.</code> and <code>super().my_function(...)</code> work together: the <code>del self.</code> statement should always be placed after <code>super().my_function(...)</code>. 
You can add whatever you want <em>before</em> calling <code>super()</code>, and that code is placed before the parent's unraveled body.</p> <h3 class="relative group"><a id="norm" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#norm"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Norm</span></h3> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded 
font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> ..llama.modeling_llama <span class="hljs-keyword">import</span> LlamaRMSNorm | |
| <span class="hljs-keyword">class</span> <span class="hljs-title class_">Olmo2RMSNorm</span>(<span class="hljs-title class_ inherited__">LlamaRMSNorm</span>): | |
| <span class="hljs-keyword">pass</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-12zqvql">Nothing needs to be modified in <code>LlamaRMSNorm</code>. The linter unravels the exact content of <code>LlamaRMSNorm</code> into <code>Olmo2RMSNorm</code>. References to Llama in the docstrings, type hints, and comments are also changed to Olmo2.</p> <h3 class="relative group"><a id="attention" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#attention"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Attention</span></h3> <p data-svelte-h="svelte-suxygn">The modular <code>Olmo2Attention</code> is shown below.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path 
d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> ..llama.modeling_llama <span class="hljs-keyword">import</span> eager_attention_forward | |
| <span class="hljs-keyword">from</span> ..olmo.modeling_olmo <span class="hljs-keyword">import</span> OlmoAttention, apply_rotary_pos_emb | |
| <span class="hljs-comment"># Olmo2 attention is identical to OLMo attention except:</span> | |
| <span class="hljs-comment"># - Norm is applied to attention queries and keys.</span> | |
| <span class="hljs-comment"># - No qkv clipping.</span> | |
| <span class="hljs-keyword">class</span> <span class="hljs-title class_">Olmo2Attention</span>(<span class="hljs-title class_ inherited__">OlmoAttention</span>): | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, config: Olmo2Config, layer_idx: <span class="hljs-type">Optional</span>[<span class="hljs-built_in">int</span>] = <span class="hljs-literal">None</span></span>): | |
| <span class="hljs-built_in">super</span>().__init__(config, layer_idx=layer_idx) | |
| self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) | |
| self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params"> | |
| self, | |
| hidden_states: torch.Tensor, | |
| position_embeddings: <span class="hljs-built_in">tuple</span>[torch.Tensor, torch.Tensor], | |
| attention_mask: <span class="hljs-type">Optional</span>[torch.Tensor], | |
| past_key_values: <span class="hljs-type">Optional</span>[Cache] = <span class="hljs-literal">None</span>, | |
| cache_position: <span class="hljs-type">Optional</span>[torch.LongTensor] = <span class="hljs-literal">None</span>, | |
| **kwargs, | |
| </span>) -> <span class="hljs-built_in">tuple</span>[torch.Tensor, <span class="hljs-type">Optional</span>[torch.Tensor], <span class="hljs-type">Optional</span>[<span class="hljs-built_in">tuple</span>[torch.Tensor]]]: | |
| input_shape = hidden_states.shape[:-<span class="hljs-number">1</span>] | |
| hidden_shape = (*input_shape, -<span class="hljs-number">1</span>, self.head_dim) | |
| query_states = self.q_norm(self.q_proj(hidden_states)) | |
| key_states = self.k_norm(self.k_proj(hidden_states)) | |
| value_states = self.v_proj(hidden_states) | |
| query_states = query_states.view(hidden_shape).transpose(<span class="hljs-number">1</span>, <span class="hljs-number">2</span>) | |
| key_states = key_states.view(hidden_shape).transpose(<span class="hljs-number">1</span>, <span class="hljs-number">2</span>) | |
| value_states = value_states.view(hidden_shape).transpose(<span class="hljs-number">1</span>, <span class="hljs-number">2</span>) | |
| cos, sin = position_embeddings | |
| query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) | |
| <span class="hljs-keyword">if</span> past_key_values <span class="hljs-keyword">is</span> <span class="hljs-keyword">not</span> <span class="hljs-literal">None</span>: | |
| <span class="hljs-comment"># sin and cos are specific to RoPE models; cache_position needed for the static cache</span> | |
| cache_kwargs = {<span class="hljs-string">"sin"</span>: sin, <span class="hljs-string">"cos"</span>: cos, <span class="hljs-string">"cache_position"</span>: cache_position} | |
| key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) | |
| attention_interface: <span class="hljs-type">Callable</span> = eager_attention_forward | |
| <span class="hljs-keyword">if</span> self.config._attn_implementation != <span class="hljs-string">"eager"</span>: | |
| attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] | |
| attn_output, attn_weights = attention_interface( | |
| self, | |
| query_states, | |
| key_states, | |
| value_states, | |
| attention_mask, | |
| dropout=<span class="hljs-number">0.0</span> <span class="hljs-keyword">if</span> <span class="hljs-keyword">not</span> self.training <span class="hljs-keyword">else</span> self.attention_dropout, | |
| scaling=self.scaling, | |
| **kwargs, | |
| ) | |
| attn_output = attn_output.reshape(*input_shape, -<span class="hljs-number">1</span>).contiguous() | |
| attn_output = self.o_proj(attn_output) | |
| <span class="hljs-keyword">return</span> attn_output, attn_weights<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1k0gkac">The <code>super().__init__(...)</code> call copies the parent definition and adds 2 new layers from <code>Olmo2RMSNorm</code>. The forward pass needs to be overwritten to use these 2 new layers: the outputs of <code>q_proj</code> and <code>k_proj</code> are passed through <code>q_norm</code> and <code>k_norm</code> right after projection, before they are reshaped into attention heads. To make it easier, the <code>eager_attention_forward</code> function is directly imported from Llama and the <code>apply_rotary_pos_emb</code> function is imported from Olmo.</p> <p data-svelte-h="svelte-10o1ggb">The linter automatically adds these imported functions in the final <code>modeling_olmo2.py</code> file by copying their definitions from the source files. The <code>rotate_half</code> and <code>repeat_kv</code> functions are also added because they are used inside <code>apply_rotary_pos_emb</code> and <code>eager_attention_forward</code>.</p> <p data-svelte-h="svelte-1y0dfl">The <code>Attention</code> class had to be redefined because there weren’t any existing models with an <code>Attention</code> layer that included an <code>RMSNorm</code> layer.</p> <h3 class="relative group"><a id="decoderlayer" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#decoderlayer"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 
0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>DecoderLayer</span></h3> <p data-svelte-h="svelte-jueczp">The modular <code>DecoderLayer</code> is shown below.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> ..olmo.modeling_olmo <span class="hljs-keyword">import</span> OlmoDecoderLayer | |
| <span class="hljs-comment"># The OLMo2 layers are identical to those of the OLMo model except:</span> | |
| <span class="hljs-comment"># - RMSNorm is used instead of standard layer norm.</span> | |
| <span class="hljs-comment"># - Norm is applied after attention/feedforward rather than before.</span> | |
| <span class="hljs-keyword">class</span> <span class="hljs-title class_">Olmo2DecoderLayer</span>(<span class="hljs-title class_ inherited__">OlmoDecoderLayer</span>): | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, config: Olmo2Config, layer_idx: <span class="hljs-built_in">int</span></span>): | |
| <span class="hljs-built_in">super</span>().__init__(config, layer_idx=layer_idx) | |
| self.post_attention_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) | |
| self.post_feedforward_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) | |
| self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx) | |
| <span class="hljs-keyword">del</span> self.input_layernorm | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params"> | |
| self, | |
| hidden_states: torch.Tensor, | |
| attention_mask: <span class="hljs-type">Optional</span>[torch.Tensor] = <span class="hljs-literal">None</span>, | |
| position_ids: <span class="hljs-type">Optional</span>[torch.LongTensor] = <span class="hljs-literal">None</span>, | |
| past_key_values: <span class="hljs-type">Optional</span>[Cache] = <span class="hljs-literal">None</span>, | |
| output_attentions: <span class="hljs-type">Optional</span>[<span class="hljs-built_in">bool</span>] = <span class="hljs-literal">False</span>, | |
| use_cache: <span class="hljs-type">Optional</span>[<span class="hljs-built_in">bool</span>] = <span class="hljs-literal">False</span>, | |
| cache_position: <span class="hljs-type">Optional</span>[torch.LongTensor] = <span class="hljs-literal">None</span>, | |
| position_embeddings: <span class="hljs-type">Optional</span>[<span class="hljs-built_in">tuple</span>[torch.Tensor, torch.Tensor]] = <span class="hljs-literal">None</span>, | |
| **kwargs, | |
| </span>) -> <span class="hljs-built_in">tuple</span>[torch.FloatTensor, <span class="hljs-type">Optional</span>[<span class="hljs-built_in">tuple</span>[torch.FloatTensor, torch.FloatTensor]]]: | |
| residual = hidden_states | |
| <span class="hljs-comment"># Self Attention</span> | |
| hidden_states, self_attn_weights = self.self_attn( | |
| hidden_states=hidden_states, | |
| attention_mask=attention_mask, | |
| position_ids=position_ids, | |
| past_key_values=past_key_values, | |
| output_attentions=output_attentions, | |
| use_cache=use_cache, | |
| cache_position=cache_position, | |
| position_embeddings=position_embeddings, | |
| **kwargs, | |
| ) | |
| hidden_states = self.post_attention_layernorm(hidden_states) | |
| hidden_states = residual + hidden_states | |
| <span class="hljs-comment"># Fully Connected</span> | |
| residual = hidden_states | |
| hidden_states = self.mlp(hidden_states) | |
| hidden_states = self.post_feedforward_layernorm(hidden_states) | |
| hidden_states = residual + hidden_states | |
| outputs = (hidden_states,) | |
| <span class="hljs-keyword">if</span> output_attentions: | |
| outputs += (self_attn_weights,) | |
| <span class="hljs-keyword">return</span> outputs<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1grcsyb">The norm type is switched in <code>__init__</code> by overwriting <code>self.post_attention_layernorm</code> after the call to <code>super().__init__(...)</code>. Delete the <code>self.input_layernorm</code> attribute and replace it with <code>self.post_feedforward_layernorm</code> because Olmo2 applies its norms after the attention and feedforward blocks rather than before them. The forward method is overwritten to reflect this change.</p> <p data-svelte-h="svelte-krborp">If you had only switched <code>self.post_attention_layernorm</code> and <code>self.input_layernorm</code> from <code>LayerNorm</code> to <code>RMSNorm</code>, without also changing the name and logic of <code>self.input_layernorm</code>, then you wouldn’t have to rewrite the forward method.</p> <h3 class="relative group"><a id="model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Model</span></h3> <p data-svelte-h="svelte-12qws02">The modular <code>Olmo2Model</code> class is shown below.</p> <div class="code-block relative "><div class="absolute 
top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> ..olmo.modeling_olmo <span class="hljs-keyword">import</span> OlmoModel | |
| <span class="hljs-comment"># The OLMo2 model is identical to the OLMo model, except RMSNorm is used instead of</span> | |
| <span class="hljs-comment"># standard layer norm for the output norm.</span> | |
| <span class="hljs-keyword">class</span> <span class="hljs-title class_">Olmo2Model</span>(<span class="hljs-title class_ inherited__">OlmoModel</span>): | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, config: Olmo2Config</span>): | |
| <span class="hljs-built_in">super</span>().__init__(config) | |
| self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) | |
| self.layers = nn.ModuleList( | |
| [Olmo2DecoderLayer(config, layer_idx) <span class="hljs-keyword">for</span> layer_idx <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(config.num_hidden_layers)] | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-2efkpk">You only need to change the <em>type</em> of the <code>self.norm</code> attribute to use <code>RMSNorm</code> instead of <code>LayerNorm</code>. This change doesn’t affect the logic in the forward method (layer name and usage is identical to the parent class), so you don’t need to overwrite it. The linter automatically unravels it.</p> <h3 class="relative group"><a id="model-head" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#model-head"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Model head</span></h3> <p data-svelte-h="svelte-111k955">The modular causal modeling head is shown below.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" 
viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> ..olmo.modeling_olmo <span class="hljs-keyword">import</span> OlmoForCausalLM | |
| <span class="hljs-keyword">class</span> <span class="hljs-title class_">Olmo2ForCausalLM</span>(<span class="hljs-title class_ inherited__">OlmoForCausalLM</span>): | |
| <span class="hljs-keyword">pass</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ue8kve">The logic is identical to <code>OlmoForCausalLM</code> which means you don’t need to make any changes here.</p> <h3 class="relative group"><a id="other-classes" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#other-classes"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Other classes</span></h3> <p data-svelte-h="svelte-j99j4d">The <a href="https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo2/modeling_olmo2.py" rel="nofollow">modeling_olmo2.py</a> generated by the linter also contains some classes (<code>Olmo2MLP</code>, <code>Olmo2RotaryEmbedding</code>, <code>Olmo2PreTrainedModel</code>) that weren’t explicitly defined in <code>modular_olmo2.py</code>.</p> <p data-svelte-h="svelte-1mv6aiq">Classes that are a dependency of an inherited class but aren’t explicitly defined are automatically added as a part of dependency tracing. 
This is similar to how the <code>rotate_half</code> and <code>repeat_kv</code> functions were added for the <code>Attention</code> class without being imported directly.</p> <p data-svelte-h="svelte-li8sd3">For example, <code>OlmoDecoderLayer</code> has an attribute defined as <code>self.mlp = OlmoMLP(config)</code>. This class was never explicitly redefined in <code>modular_olmo2.py</code>, so the linter automatically created an <code>Olmo2MLP</code> class similar to <code>OlmoMLP</code>. The result is identical to what you would get by explicitly writing the code below in <code>modular_olmo2.py</code>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> ..olmo.modeling_olmo <span class="hljs-keyword">import</span> OlmoMLP | |
| <span class="hljs-keyword">class</span> <span class="hljs-title class_">Olmo2MLP</span>(<span class="hljs-title class_ inherited__">OlmoMLP</span>): | |
| <span class="hljs-keyword">pass</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1qoxxsw">However, <code>Olmo2RMSNorm</code> had to be defined explicitly because the redefined <code>Attention</code> and <code>DecoderLayer</code> classes use it directly for their new layer norms. In contrast, the <code>Olmo2PreTrainedModel</code> and <code>Olmo2RotaryEmbedding</code> classes didn’t need any modifications, so you didn’t need to create them; the linter adds them automatically through dependency tracing.</p> <p data-svelte-h="svelte-16zifu2">Classes that aren’t rewritten are copied from the file where the inherited module first uses them. This means that if you wanted <code>Olmo2MLP</code> to inherit from <code>MistralMLP</code> instead, you would need to define it explicitly as shown below.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># switch to mistral definition</span> | |
| <span class="hljs-keyword">from</span> ..mistral.modeling_mistral <span class="hljs-keyword">import</span> MistralMLP | |
| <span class="hljs-keyword">class</span> <span class="hljs-title class_">Olmo2MLP</span>(<span class="hljs-title class_ inherited__">MistralMLP</span>): | |
| <span class="hljs-keyword">pass</span><!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="removing-attributes" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#removing-attributes"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Removing attributes</span></h2> <p data-svelte-h="svelte-1reud8e">You can use <code>del</code> to remove attributes defined in the parent after calling <code>super().__init__()</code>. However, this doesn’t work if the attribute is also used somewhere else, as shown below. It only suppresses the assignment. 
The <code>self.attribute = config.attribute</code> line is removed, but the <code>if</code> statement remains and references the attribute.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">class</span> <span class="hljs-title class_">DummyModel</span>(nn.Module): | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, config: DummyConfig</span>): | |
| <span class="hljs-built_in">super</span>().__init__() | |
| self.attribute = config.attribute | |
| <span class="hljs-keyword">if</span> self.attribute: | |
| <span class="hljs-comment"># do more stuff with `self.attribute` here</span> | |
| ... | |
| <span class="hljs-keyword">class</span> <span class="hljs-title class_">MyNewDummyModel</span>(<span class="hljs-title class_ inherited__">DummyModel</span>): | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, config: MyNewDummyConfig</span>): | |
| <span class="hljs-built_in">super</span>().__init__(config) | |
| <span class="hljs-keyword">del</span> self.attribute<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="calling-parent-methods-without-unravelling-their-definition" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#calling-parent-methods-without-unravelling-their-definition"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Calling parent methods without unravelling their definition</span></h2> <p data-svelte-h="svelte-1nslgmc">If you want to inherit from a module <code>DummyModule</code> and want to call <code>super()</code> WITHOUT unravelling the parent’s code (that is, you want to call <code>super()</code> on the <em>generated</em> class parent), be explicit about which class’ <code>super()</code> you’re calling. The example below shows how to call the <code>super()</code> of <code>nn.Module</code> (unravelled code shown on the right). In this example, because <code>DummyModule</code> is itself an <code>nn.Module</code>, it makes sense to call <code>nn.Module.__init__(self)</code> directly, since that was the original intention. 
It’s then unravelled as <code>super()</code> in <code>MyNewDummyModule</code> to follow Python’s best-practices.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">class</span> <span class="hljs-title class_">MyNewDummyModule</span>(<span class="hljs-title class_ inherited__">DummyModule</span>): | <span class="hljs-keyword">class</span> <span class="hljs-title class_">MyNewDummyModule</span>(nn.Module): | |
| | | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self</span>): | <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self</span>): | |
| nn.Module.__init__(self) | <span class="hljs-built_in">super</span>().__init__() | |
| self.foo = config.foo | self.foo = config.foo | |
| ... | ...<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="deleting-unused-methods" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#deleting-unused-methods"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Deleting unused methods</span></h2> <p data-svelte-h="svelte-ho137b">Remove an unused parent method by overriding it with a <code>raise AttributeError("")</code> statement, which mimics the behavior of deleting that method from the parent class in Python. 
The example below removes the methods in the unraveled code.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">class</span> <span class="hljs-title class_">GemmaTokenizer</span>(<span class="hljs-title class_ inherited__">LlamaTokenizer</span>): | |
| ... | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">get_spm_processor</span>(<span class="hljs-params">self</span>): | |
| <span class="hljs-keyword">raise</span> AttributeError(<span class="hljs-string">"Not needed for Gemma"</span>) | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">unk_token_length</span>(<span class="hljs-params">self</span>): | |
| <span class="hljs-keyword">raise</span> AttributeError(<span class="hljs-string">"Not needed for Gemma"</span>)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="defining-new-functions" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#defining-new-functions"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Defining new functions</span></h2> <p data-svelte-h="svelte-mwxhil">By default, if you inherit from a class and override a method with one or more decorators in the parent method, the decorators are also added to the unraveled code <em>only if you don’t add any yourself</em>. 
Otherwise, your redefined decorator is used.</p> <p data-svelte-h="svelte-8g3zpp">For example, suppose the parent class below decorates its <code>forward</code> method. If you override <code>forward</code> without adding a decorator of your own, the parent decorator is kept.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">class</span> <span class="hljs-title class_">DummyModel</span>(nn.Module): | |
| ... | |
| <span class="hljs-meta"> @decorator(<span class="hljs-params">...</span>)</span> | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params">...</span>) | |
| <span class="hljs-comment"># do stuff here</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-hkv2ct">Modular code is shown on the left, and the unraveled code is shown on the right.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">class</span> <span class="hljs-title class_">NewModel</span>(<span class="hljs-title class_ inherited__">DummyModel</span>): | <span class="hljs-keyword">class</span> <span class="hljs-title class_">NewModel</span>(nn.Module): | |
| ... | ... | |
| | | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params">...</span>): | @decorator(...) | |
| ... | <span class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params">...</span>): | |
| | ...<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1b9k8jx">But if you add a new decorator, your new decorator is used instead.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">class</span> <span class="hljs-title class_">NewModel</span>(<span class="hljs-title class_ inherited__">DummyModel</span>): | <span class="hljs-keyword">class</span> <span class="hljs-title class_">NewModel</span>(nn.Module): | |
| ... | ... | |
| | | |
| <span class="hljs-meta"> @my_new_decorator(<span class="hljs-params">...</span>) | @my_new_decorator(<span class="hljs-params">...</span>)</span> | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params">...</span>): | <span class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params">...</span>): | |
| ... | ...<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="superkwargs" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#superkwargs"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>super_kwargs</span></h2> <p data-svelte-h="svelte-ce3ins">In scenarios where a forward method is really long and you want to switch decorators, you don’t need to redefine everything and copy/paste the function. You can use <code>super().forward(...)</code> to unravel the parent body. When there are a lot of arguments in the function signature, use the special <code>**super_kwargs</code> syntax in the overwritten signature.</p> <p data-svelte-h="svelte-3ynt33">This syntax indicates to the linter to unravel all the parent signature arguments here. 
An example signature from an <a href="/docs/transformers/pr_33892/en/model_doc/auto#transformers.AutoModelForCausalLM">AutoModelForCausalLM</a> model, which takes many arguments, is shown below.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">class</span> <span class="hljs-title class_">LlamaForCausalLM</span>(nn.Module): | |
| ... | |
| <span class="hljs-meta"> @add_start_docstrings_to_model_forward(<span class="hljs-params">LLAMA_INPUTS_DOCSTRING</span>)</span> | |
| <span class="hljs-meta"> @replace_return_docstrings(<span class="hljs-params">output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC</span>)</span> | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params"> | |
| self, | |
| input_ids: torch.LongTensor = <span class="hljs-literal">None</span>, | |
| attention_mask: <span class="hljs-type">Optional</span>[torch.Tensor] = <span class="hljs-literal">None</span>, | |
| position_ids: <span class="hljs-type">Optional</span>[torch.LongTensor] = <span class="hljs-literal">None</span>, | |
| past_key_values: <span class="hljs-type">Optional</span>[Cache] = <span class="hljs-literal">None</span>, | |
| inputs_embeds: <span class="hljs-type">Optional</span>[torch.FloatTensor] = <span class="hljs-literal">None</span>, | |
| labels: <span class="hljs-type">Optional</span>[torch.LongTensor] = <span class="hljs-literal">None</span>, | |
| use_cache: <span class="hljs-type">Optional</span>[<span class="hljs-built_in">bool</span>] = <span class="hljs-literal">None</span>, | |
| output_attentions: <span class="hljs-type">Optional</span>[<span class="hljs-built_in">bool</span>] = <span class="hljs-literal">None</span>, | |
| output_hidden_states: <span class="hljs-type">Optional</span>[<span class="hljs-built_in">bool</span>] = <span class="hljs-literal">None</span>, | |
| return_dict: <span class="hljs-type">Optional</span>[<span class="hljs-built_in">bool</span>] = <span class="hljs-literal">None</span>, | |
| cache_position: <span class="hljs-type">Optional</span>[torch.LongTensor] = <span class="hljs-literal">None</span>, | |
| num_logits_to_keep: <span class="hljs-built_in">int</span> = <span class="hljs-number">0</span>, | |
| **kwargs: Unpack[KwargsForCausalLM], | |
| </span>) -> <span class="hljs-type">Union</span>[<span class="hljs-type">Tuple</span>, CausalLMOutputWithPast]: | |
| ...<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-161qf9e">Instead of rewriting and copying/pasting all of those arguments, use the <code>super().forward(**super_kwargs)</code> statement (modular code shown on the left, unraveled code on the right).</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">class</span> <span class="hljs-title class_">NewModelForCausalLM</span>(<span class="hljs-title class_ inherited__">LlamaForCausalLM</span>): | <span class="hljs-keyword">class</span> <span class="hljs-title class_">LlamaForCausalLM</span>(nn.Module): | |
| ... | ... | |
| | | |
| <span class="hljs-meta"> @my_new_decorator | @my_new_decorator</span> | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params">self, **super_kwargs</span>): | <span class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params"> | |
| <span class="hljs-built_in">super</span>(<span class="hljs-params"></span>).forward(<span class="hljs-params">**super_kwargs</span>) | self, | |
| | input_ids: torch.LongTensor = <span class="hljs-literal">None</span>, | |
| | attention_mask: <span class="hljs-type">Optional</span>[torch.Tensor] = <span class="hljs-literal">None</span>, | |
| | position_ids: <span class="hljs-type">Optional</span>[torch.LongTensor] = <span class="hljs-literal">None</span>, | |
| | past_key_values: <span class="hljs-type">Optional</span>[Cache] = |<span class="hljs-literal">None</span>, | |
| | inputs_embeds: <span class="hljs-type">Optional</span>[torch.FloatTensor] = <span class="hljs-literal">None</span>, | |
| | labels: <span class="hljs-type">Optional</span>[torch.LongTensor] = <span class="hljs-literal">None</span>, | |
| | use_cache: <span class="hljs-type">Optional</span>[<span class="hljs-built_in">bool</span>] = <span class="hljs-literal">None</span>, | |
| | output_attentions: <span class="hljs-type">Optional</span>[<span class="hljs-built_in">bool</span>] = <span class="hljs-literal">None</span>, | |
| | output_hidden_states: <span class="hljs-type">Optional</span>[<span class="hljs-built_in">bool</span>] = <span class="hljs-literal">None</span>, | |
| | return_dict: <span class="hljs-type">Optional</span>[<span class="hljs-built_in">bool</span>] = <span class="hljs-literal">None</span>, | |
| | cache_position: <span class="hljs-type">Optional</span>[torch.LongTensor] = <span class="hljs-literal">None</span>, | |
| | num_logits_to_keep: <span class="hljs-built_in">int</span> = <span class="hljs-number">0</span>, | |
| | **kwargs: Unpack[KwargsForCausalLM], | |
| | </span>) -> <span class="hljs-type">Union</span>[<span class="hljs-type">Tuple</span>, CausalLMOutputWithPast]: | |
| | ...<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1w35umg">This makes it very easy to switch decorators and makes it explicit that the only change you want to apply is the decorator.</p> <p data-svelte-h="svelte-1x7kbin">Don’t use <code>**super_kwargs</code> as a way to avoid writing explicit signatures when redefining methods. If you overwrite a method, you should explicitly write the signature as you normally would. The <code>**super_kwargs</code> syntax is a shortcut for switching decorators and a few other niche cases.</p> <h2 class="relative group"><a id="docstring-variables" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#docstring-variables"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Docstring variables</span></h2> <blockquote class="tip" data-svelte-h="svelte-1hpj5n3"><p>Refer to the <a href="./auto_docstring">Documenting a model</a> guide for more information about how you can use the <code>@auto_docstring</code> decorator to help automatically generate consistent docstring arguments.</p></blockquote> <p data-svelte-h="svelte-dnuz94">If an object is defined both in the modular file and in the modeling file it 
inherits from, the modular definition takes precedence, except for assignments to variables whose names contain <code>DOCSTRING</code>. These variables are typically used in <code>MODEL_START_DOCSTRING</code> and <code>MODEL_INPUTS_DOCSTRING</code> in the modeling files. They are big blocks of docstrings and the linter rewrites the names everywhere. For this reason, a <code>DOCSTRING</code> assignment can reuse the definition found in the source file without copying the whole docstring: simply set the variable to <code>None</code> in the modular file.</p> <p data-svelte-h="svelte-1a2wp3o">This is very useful if you need the variable reference somewhere but you don’t want to clutter the modular file with docstrings that are always the same. The example code below allows you to automatically use the same docstrings from <a href="./model_doc/mistral">Mistral</a> in <a href="./model_doc/starcoder2">Starcoder2</a>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 
border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->STARCODER2_INPUTS_DOCSTRING = <span class="hljs-literal">None</span> <span class="hljs-comment"># will be automatically redefined</span> | |
| <span class="hljs-keyword">class</span> <span class="hljs-title class_">Starcoder2Model</span>(<span class="hljs-title class_ inherited__">MistralModel</span>): | |
| ... | |
| <span class="hljs-meta"> @add_start_docstrings_to_model_forward(<span class="hljs-params">STARCODER2_INPUTS_DOCSTRING</span>)</span> | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params">...</span>) | |
| ...<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-zwpgdq">Setting the variable to anything other than <code>None</code> will override the docstring, so that you can customize the docstrings if needed.</p> <h2 class="relative group"><a id="special-naming" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#special-naming"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Special naming</span></h2> <p data-svelte-h="svelte-1r96udu">The linter automatically renames everything when inheriting from a class. For consistency, you should always use the same class name prefix when inheriting from different classes from the same file.</p> <p data-svelte-h="svelte-1w8gdvq">The example below is not recommended. 
It breaks the library’s naming standard (<code>MyModelIncredibleMLP</code> instead of <code>MyModelMLP</code>), and the linter doesn’t know how to rename potential higher-order dependencies (should it use <code>MyModelIncredible</code> or just <code>MyModel</code>?).</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">class</span> <span class="hljs-title class_">MyModelIncredibleMLP</span>(<span class="hljs-title class_ inherited__">LlamaMLP</span>): | |
| ... | |
| <span class="hljs-keyword">class</span> <span class="hljs-title class_">MyModelDecoderLayer</span>(<span class="hljs-title class_ inherited__">LlamaDecoderLayer</span>): | |
| ...<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ce2erv">However, if there aren’t any <a href="#other-classes">implicit dependencies</a>, then you can locally rename a single class. Make sure you still explicitly redefine every other mention of the class with the new name pattern, though. For example, all mentions of <code>LlamaMLP</code> should be renamed to <code>MyModelIncredibleMLP</code>; otherwise, the linter may add a new and unwanted <code>MyModelMLP</code> class.</p> <p data-svelte-h="svelte-1n8wdp0">The linter raises a warning if an ambiguous case is detected. It explains what is happening and which prefix is used by default for getting the dependencies. These warnings and renaming-pattern complications usually only come up when defining multimodal models, for example, when adding <code>Text</code> to class names in a multimodal model to make it clear which modality they refer to.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: 
transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->We detected multiple prefix names when inheriting <span class="hljs-keyword">from</span> transformers.models.llama.modeling_llama: (<span class="hljs-string">'Emu3Text'</span>, <span class="hljs-string">'Emu3'</span>). We will only use the most used <span class="hljs-string">'Emu3'</span> prefix when grabbing args <span class="hljs-keyword">and</span> dependencies. Make sure to subclass the intermediate classes <span class="hljs-keyword">with</span> the prefix you want (<span class="hljs-keyword">if</span> different <span class="hljs-keyword">from</span> <span class="hljs-string">'Emu3'</span>) <span class="hljs-keyword">or</span> use a single prefix <span class="hljs-keyword">in</span> <span class="hljs-built_in">all</span> the modular (best).<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-3dj83y">If there are automatic dependencies with a prefix, but you want another one, explicitly rename the classes locally with a <code>pass</code> class as shown in the following.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow 
left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">class</span> <span class="hljs-title class_">Emu3TextMLP</span>(<span class="hljs-title class_ inherited__">LlamaMLP</span>): | |
| <span class="hljs-keyword">pass</span><!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="config-docstrings" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#config-docstrings"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Config docstrings</span></h2> <p data-svelte-h="svelte-msrgaf">When inheriting a <code>Config</code> class or adding and deleting attributes, you may want to only redefine the new attributes in the docstring. However, the linter doesn’t support this yet. 
You need to add the whole docstring directly in the modular file under the class definition.</p> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/transformers/blob/main/docs/source/en/modular_transformers.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p> | |
| <script> | |
// Client-side bootstrap for this page. The names (__sveltekit_*, entry/start, kit.start)
// indicate SvelteKit; the hashed module filenames suggest generated output, so do not hand-edit.
// NOTE(review): the leading "| " and trailing " | |" on each line look like table-extraction
// artifacts and make this script invalid JavaScript as stored. Confirm against the original build.
| { | |
// Publish a global config object: asset and base URL prefixes for this docs build, and an empty env.
| __sveltekit_16tnnm8 = { | |
| assets: "/docs/transformers/pr_33892/en", | |
| base: "/docs/transformers/pr_33892/en", | |
| env: {} | |
| }; | |
// Mount target: the element that directly contains this <script> tag.
| const element = document.currentScript.parentElement; | |
// Page data passed to kit.start; both entries are null here.
// NOTE(review): presumably one slot per entry in node_ids below — confirm.
| const data = [null,null]; | |
// Load the `start` and `app` entry modules in parallel, then start the client with both.
// No rejection handler is attached, so a failed import is only surfaced as an unhandled rejection.
| Promise.all([ | |
| import("/docs/transformers/pr_33892/en/_app/immutable/entry/start.b2c4257a.js"), | |
| import("/docs/transformers/pr_33892/en/_app/immutable/entry/app.05ef1f97.js") | |
| ]).then(([kit, app]) => { | |
// Start the client on `element` with node ids 0 and 497, the data above, no form result and no error.
// NOTE(review): presumably this hydrates the server-rendered markup — confirm.
| kit.start(app, element, { | |
| node_ids: [0, 497], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |
Xet Storage Details
- Size:
- 108 kB
- Xet hash:
- 9ee50758066e83ba9a9247cd9ffd6eedffd4120e7b8bf08691ae2ad62d594b03
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.