Buckets:
| <meta charset="utf-8" /><meta name="hf:doc:metadata" content="{"title":"Building custom models","local":"building-custom-models","sections":[{"title":"Writing a custom configuration","local":"writing-a-custom-configuration","sections":[],"depth":2},{"title":"Writing a custom model","local":"writing-a-custom-model","sections":[],"depth":2},{"title":"Registering a model with custom code to the auto classes","local":"registering-a-model-with-custom-code-to-the-auto-classes","sections":[],"depth":2},{"title":"Sending the code to the Hub","local":"sending-the-code-to-the-hub","sections":[],"depth":2},{"title":"Using a model with custom code","local":"using-a-model-with-custom-code","sections":[],"depth":2}],"depth":1}"> | |
| <link href="/docs/transformers/pr_33913/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/entry/start.b67f883f.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/scheduler.25b97de1.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/singletons.62a184e0.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/index.e188933d.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/paths.51881b9e.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/entry/app.e436b1f2.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/index.d9030fc9.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/nodes/0.05e395f5.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/nodes/17.c25c2c53.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/Tip.baa67368.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/CodeBlock.e6cd0d95.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/EditOnGithub.91d95064.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"Building custom models","local":"building-custom-models","sections":[{"title":"Writing a custom configuration","local":"writing-a-custom-configuration","sections":[],"depth":2},{"title":"Writing a custom model","local":"writing-a-custom-model","sections":[],"depth":2},{"title":"Registering a model with custom code to the auto classes","local":"registering-a-model-with-custom-code-to-the-auto-classes","sections":[],"depth":2},{"title":"Sending the code to the Hub","local":"sending-the-code-to-the-hub","sections":[],"depth":2},{"title":"Using a model with custom code","local":"using-a-model-with-custom-code","sections":[],"depth":2}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="building-custom-models" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#building-custom-models"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Building custom models</span></h1> <p data-svelte-h="svelte-1bvx2nz">The 🤗 Transformers library is designed to be easily extensible. Every model is fully coded in a given subfolder | |
| of the repository with no abstraction, so you can easily copy a modeling file and tweak it to your needs.</p> <p data-svelte-h="svelte-12q72ch">If you are writing a brand new model, it might be easier to start from scratch. In this tutorial, we will show you | |
| how to write a custom model and its configuration so it can be used inside Transformers, and how you can share it | |
| with the community (with the code it relies on) so that anyone can use it, even if it’s not present in the 🤗 | |
| Transformers library. We’ll see how to build upon transformers and extend the framework with your hooks and | |
| custom code.</p> <p data-svelte-h="svelte-d8wlob">We will illustrate all of this on a ResNet model, by wrapping the ResNet class of the | |
| <a href="https://github.com/rwightman/pytorch-image-models" rel="nofollow">timm library</a> into a <a href="/docs/transformers/pr_33913/en/main_classes/model#transformers.PreTrainedModel">PreTrainedModel</a>.</p> <h2 class="relative group"><a id="writing-a-custom-configuration" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#writing-a-custom-configuration"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Writing a custom configuration</span></h2> <p data-svelte-h="svelte-15o85x">Before we dive into the model, let’s first write its configuration. The configuration of a model is an object that | |
| will contain all the necessary information to build the model. As we will see in the next section, the model can only | |
| take a <code>config</code> to be initialized, so we really need that object to be as complete as possible.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1qp2x4n">Models in the <code>transformers</code> library itself generally follow the convention that they accept a <code>config</code> object | |
| in their <code>__init__</code> method, and then pass the whole <code>config</code> to sub-layers in the model, rather than breaking the | |
| config object into multiple arguments that are all passed individually to sub-layers. Writing your model in this | |
| style results in simpler code with a clear “source of truth” for any hyperparameters, and also makes it easier | |
| to reuse code from other models in <code>transformers</code>.</p></div> <p data-svelte-h="svelte-1u8fbq1">In our example, we will take a couple of arguments of the ResNet class that we might want to tweak. Different | |
| configurations will then give us the different types of ResNets that are possible. We then just store those arguments, | |
| after checking the validity of a few of them.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> PretrainedConfig | |
| <span class="hljs-keyword">from</span> typing <span class="hljs-keyword">import</span> <span class="hljs-type">List</span> | |
| <span class="hljs-keyword">class</span> <span class="hljs-title class_">ResnetConfig</span>(<span class="hljs-title class_ inherited__">PretrainedConfig</span>): | |
| model_type = <span class="hljs-string">"resnet"</span> | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params"> | |
| self, | |
| block_type=<span class="hljs-string">"bottleneck"</span>, | |
| layers: <span class="hljs-type">List</span>[<span class="hljs-built_in">int</span>] = [<span class="hljs-number">3</span>, <span class="hljs-number">4</span>, <span class="hljs-number">6</span>, <span class="hljs-number">3</span>], | |
| num_classes: <span class="hljs-built_in">int</span> = <span class="hljs-number">1000</span>, | |
| input_channels: <span class="hljs-built_in">int</span> = <span class="hljs-number">3</span>, | |
| cardinality: <span class="hljs-built_in">int</span> = <span class="hljs-number">1</span>, | |
| base_width: <span class="hljs-built_in">int</span> = <span class="hljs-number">64</span>, | |
| stem_width: <span class="hljs-built_in">int</span> = <span class="hljs-number">64</span>, | |
| stem_type: <span class="hljs-built_in">str</span> = <span class="hljs-string">""</span>, | |
| avg_down: <span class="hljs-built_in">bool</span> = <span class="hljs-literal">False</span>, | |
| **kwargs, | |
| </span>): | |
| <span class="hljs-keyword">if</span> block_type <span class="hljs-keyword">not</span> <span class="hljs-keyword">in</span> [<span class="hljs-string">"basic"</span>, <span class="hljs-string">"bottleneck"</span>]: | |
| <span class="hljs-keyword">raise</span> ValueError(<span class="hljs-string">f"`block_type` must be 'basic' or bottleneck', got <span class="hljs-subst">{block_type}</span>."</span>) | |
| <span class="hljs-keyword">if</span> stem_type <span class="hljs-keyword">not</span> <span class="hljs-keyword">in</span> [<span class="hljs-string">""</span>, <span class="hljs-string">"deep"</span>, <span class="hljs-string">"deep-tiered"</span>]: | |
| <span class="hljs-keyword">raise</span> ValueError(<span class="hljs-string">f"`stem_type` must be '', 'deep' or 'deep-tiered', got <span class="hljs-subst">{stem_type}</span>."</span>) | |
| self.block_type = block_type | |
| self.layers = layers | |
| self.num_classes = num_classes | |
| self.input_channels = input_channels | |
| self.cardinality = cardinality | |
| self.base_width = base_width | |
| self.stem_width = stem_width | |
| self.stem_type = stem_type | |
| self.avg_down = avg_down | |
| <span class="hljs-built_in">super</span>().__init__(**kwargs)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-9tm0xv">The three important things to remember when writing you own configuration are the following:</p> <ul data-svelte-h="svelte-1l985bz"><li>you have to inherit from <code>PretrainedConfig</code>,</li> <li>the <code>__init__</code> of your <code>PretrainedConfig</code> must accept any kwargs,</li> <li>those <code>kwargs</code> need to be passed to the superclass <code>__init__</code>.</li></ul> <p data-svelte-h="svelte-ttxucl">The inheritance is to make sure you get all the functionality from the 🤗 Transformers library, while the two other | |
| constraints come from the fact a <code>PretrainedConfig</code> has more fields than the ones you are setting. When reloading a | |
| config with the <code>from_pretrained</code> method, those fields need to be accepted by your config and then sent to the | |
| superclass.</p> <p data-svelte-h="svelte-9uxjnq">Defining a <code>model_type</code> for your configuration (here <code>model_type="resnet"</code>) is not mandatory, unless you want to | |
| register your model with the auto classes (see last section).</p> <p data-svelte-h="svelte-1229zpx">With this done, you can easily create and save your configuration like you would do with any other model config of the | |
| library. Here is how we can create a resnet50d config and save it:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->resnet50d_config = ResnetConfig(block_type=<span class="hljs-string">"bottleneck"</span>, stem_width=<span class="hljs-number">32</span>, stem_type=<span class="hljs-string">"deep"</span>, avg_down=<span class="hljs-literal">True</span>) | |
| resnet50d_config.save_pretrained(<span class="hljs-string">"custom-resnet"</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1hshoid">This will save a file named <code>config.json</code> inside the folder <code>custom-resnet</code>. You can then reload your config with the | |
| <code>from_pretrained</code> method:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->resnet50d_config = ResnetConfig.from_pretrained(<span class="hljs-string">"custom-resnet"</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1s5i4tr">You can also use any other method of the <a href="/docs/transformers/pr_33913/en/main_classes/configuration#transformers.PretrainedConfig">PretrainedConfig</a> class, like <a href="/docs/transformers/pr_33913/en/main_classes/model#transformers.utils.PushToHubMixin.push_to_hub">push_to_hub()</a> to | |
| directly upload your config to the Hub.</p> <h2 class="relative group"><a id="writing-a-custom-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#writing-a-custom-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Writing a custom model</span></h2> <p data-svelte-h="svelte-qqc30j">Now that we have our ResNet configuration, we can go on writing the model. We will actually write two: one that | |
| extracts the hidden features from a batch of images (like <a href="/docs/transformers/pr_33913/en/model_doc/bert#transformers.BertModel">BertModel</a>) and one that is suitable for image | |
| classification (like <a href="/docs/transformers/pr_33913/en/model_doc/bert#transformers.BertForSequenceClassification">BertForSequenceClassification</a>).</p> <p data-svelte-h="svelte-13oyc7x">As we mentioned before, we’ll only write a loose wrapper of the model to keep it simple for this example. The only | |
| thing we need to do before writing this class is a map between the block types and actual block classes. Then the | |
| model is defined from the configuration by passing everything to the <code>ResNet</code> class:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> PreTrainedModel | |
| <span class="hljs-keyword">from</span> timm.models.resnet <span class="hljs-keyword">import</span> BasicBlock, Bottleneck, ResNet | |
| <span class="hljs-keyword">from</span> .configuration_resnet <span class="hljs-keyword">import</span> ResnetConfig | |
| BLOCK_MAPPING = {<span class="hljs-string">"basic"</span>: BasicBlock, <span class="hljs-string">"bottleneck"</span>: Bottleneck} | |
| <span class="hljs-keyword">class</span> <span class="hljs-title class_">ResnetModel</span>(<span class="hljs-title class_ inherited__">PreTrainedModel</span>): | |
| config_class = ResnetConfig | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, config</span>): | |
| <span class="hljs-built_in">super</span>().__init__(config) | |
| block_layer = BLOCK_MAPPING[config.block_type] | |
| self.model = ResNet( | |
| block_layer, | |
| config.layers, | |
| num_classes=config.num_classes, | |
| in_chans=config.input_channels, | |
| cardinality=config.cardinality, | |
| base_width=config.base_width, | |
| stem_width=config.stem_width, | |
| stem_type=config.stem_type, | |
| avg_down=config.avg_down, | |
| ) | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params">self, tensor</span>): | |
| <span class="hljs-keyword">return</span> self.model.forward_features(tensor)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-85b6xb">For the model that will classify images, we just change the forward method:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">class</span> <span class="hljs-title class_">ResnetModelForImageClassification</span>(<span class="hljs-title class_ inherited__">PreTrainedModel</span>): | |
| config_class = ResnetConfig | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, config</span>): | |
| <span class="hljs-built_in">super</span>().__init__(config) | |
| block_layer = BLOCK_MAPPING[config.block_type] | |
| self.model = ResNet( | |
| block_layer, | |
| config.layers, | |
| num_classes=config.num_classes, | |
| in_chans=config.input_channels, | |
| cardinality=config.cardinality, | |
| base_width=config.base_width, | |
| stem_width=config.stem_width, | |
| stem_type=config.stem_type, | |
| avg_down=config.avg_down, | |
| ) | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params">self, tensor, labels=<span class="hljs-literal">None</span></span>): | |
| logits = self.model(tensor) | |
| <span class="hljs-keyword">if</span> labels <span class="hljs-keyword">is</span> <span class="hljs-keyword">not</span> <span class="hljs-literal">None</span>: | |
| loss = torch.nn.functional.cross_entropy(logits, labels) | |
| <span class="hljs-keyword">return</span> {<span class="hljs-string">"loss"</span>: loss, <span class="hljs-string">"logits"</span>: logits} | |
| <span class="hljs-keyword">return</span> {<span class="hljs-string">"logits"</span>: logits}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1vktyut">In both cases, notice how we inherit from <code>PreTrainedModel</code> and call the superclass initialization with the <code>config</code> | |
| (a bit like when you write a regular <code>torch.nn.Module</code>). The line that sets the <code>config_class</code> is not mandatory, unless | |
| you want to register your model with the auto classes (see last section).</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-alg4p1">If your model is very similar to a model inside the library, you can re-use the same configuration as this model.</p></div> <p data-svelte-h="svelte-1iol4ns">You can have your model return anything you want, but returning a dictionary like we did for | |
| <code>ResnetModelForImageClassification</code>, with the loss included when labels are passed, will make your model directly | |
| usable inside the <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a> class. Using another output format is fine as long as you are planning on using your own | |
| training loop or another library for training.</p> <p data-svelte-h="svelte-zsjn9b">Now that we have our model class, let’s create one:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->resnet50d = ResnetModelForImageClassification(resnet50d_config)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-7s7g23">Again, you can use any of the methods of <a href="/docs/transformers/pr_33913/en/main_classes/model#transformers.PreTrainedModel">PreTrainedModel</a>, like <a href="/docs/transformers/pr_33913/en/main_classes/model#transformers.PreTrainedModel.save_pretrained">save_pretrained()</a> or | |
| <a href="/docs/transformers/pr_33913/en/main_classes/model#transformers.utils.PushToHubMixin.push_to_hub">push_to_hub()</a>. We will use the second in the next section, and see how to push the model weights | |
| with the code of our model. But first, let’s load some pretrained weights inside our model.</p> <p data-svelte-h="svelte-kn8bw1">In your own use case, you will probably be training your custom model on your own data. To go fast for this tutorial, | |
| we will use the pretrained version of the resnet50d. Since our model is just a wrapper around it, it’s going to be | |
| easy to transfer those weights:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> timm | |
| pretrained_model = timm.create_model(<span class="hljs-string">"resnet50d"</span>, pretrained=<span class="hljs-literal">True</span>) | |
| resnet50d.model.load_state_dict(pretrained_model.state_dict())<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-edufaa">Now let’s see how to make sure that when we do <a href="/docs/transformers/pr_33913/en/main_classes/model#transformers.PreTrainedModel.save_pretrained">save_pretrained()</a> or <a href="/docs/transformers/pr_33913/en/main_classes/model#transformers.utils.PushToHubMixin.push_to_hub">push_to_hub()</a>, the | |
| code of the model is saved.</p> <h2 class="relative group"><a id="registering-a-model-with-custom-code-to-the-auto-classes" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#registering-a-model-with-custom-code-to-the-auto-classes"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Registering a model with custom code to the auto classes</span></h2> <p data-svelte-h="svelte-bccbz6">If you are writing a library that extends 🤗 Transformers, you may want to extend the auto classes to include your own | |
| model. This is different from pushing the code to the Hub in the sense that users will need to import your library to | |
| get the custom models (contrarily to automatically downloading the model code from the Hub).</p> <p data-svelte-h="svelte-uijzdg">As long as your config has a <code>model_type</code> attribute that is different from existing model types, and that your model | |
| classes have the right <code>config_class</code> attributes, you can just add them to the auto classes like this:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoConfig, AutoModel, AutoModelForImageClassification | |
| AutoConfig.register(<span class="hljs-string">"resnet"</span>, ResnetConfig) | |
| AutoModel.register(ResnetConfig, ResnetModel) | |
| AutoModelForImageClassification.register(ResnetConfig, ResnetModelForImageClassification)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1gu0byz">Note that the first argument used when registering your custom config to <a href="/docs/transformers/pr_33913/en/model_doc/auto#transformers.AutoConfig">AutoConfig</a> needs to match the <code>model_type</code> | |
| of your custom config, and the first argument used when registering your custom models to any auto model class needs | |
| to match the <code>config_class</code> of those models.</p> <h2 class="relative group"><a id="sending-the-code-to-the-hub" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#sending-the-code-to-the-hub"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Sending the code to the Hub</span></h2> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-15rpg4">This API is experimental and may have some slight breaking changes in the next releases.</p></div> <p data-svelte-h="svelte-w79o0g">First, make sure your model is fully defined in a <code>.py</code> file. It can rely on relative imports to some other files as | |
| long as all the files are in the same directory (we don’t support submodules for this feature yet). For our example, | |
| we’ll define a <code>modeling_resnet.py</code> file and a <code>configuration_resnet.py</code> file in a folder of the current working | |
| directory named <code>resnet_model</code>. The configuration file contains the code for <code>ResnetConfig</code> and the modeling file | |
| contains the code of <code>ResnetModel</code> and <code>ResnetModelForImageClassification</code>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->. | |
| └── resnet_model | |
| ├── __init__.<span class="hljs-keyword">py</span> | |
| ├── configuration_resnet.<span class="hljs-keyword">py</span> | |
| └── modeling_resnet.<span class="hljs-keyword">py</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1orsl60">The <code>__init__.py</code> can be empty, it’s just there so that Python detects <code>resnet_model</code> can be use as a module.</p> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-1816fmc">If copying a modeling files from the library, you will need to replace all the relative imports at the top of the file | |
| to import from the <code>transformers</code> package.</p></div> <p data-svelte-h="svelte-miae0j">Note that you can re-use (or subclass) an existing configuration/model.</p> <p data-svelte-h="svelte-1kyhtnz">To share your model with the community, follow those steps: first import the ResNet model and config from the newly | |
| created files:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> resnet_model.configuration_resnet <span class="hljs-keyword">import</span> ResnetConfig | |
| <span class="hljs-keyword">from</span> resnet_model.modeling_resnet <span class="hljs-keyword">import</span> ResnetModel, ResnetModelForImageClassification<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ce24vd">Then you have to tell the library you want to copy the code files of those objects when using the <code>save_pretrained</code> | |
| method and properly register them with a given Auto class (especially for models), just run:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->ResnetConfig.register_for_auto_class() | |
| ResnetModel.register_for_auto_class(<span class="hljs-string">"AutoModel"</span>) | |
| ResnetModelForImageClassification.register_for_auto_class(<span class="hljs-string">"AutoModelForImageClassification"</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1pmi7im">Note that there is no need to specify an auto class for the configuration (there is only one auto class for them, | |
| <a href="/docs/transformers/pr_33913/en/model_doc/auto#transformers.AutoConfig">AutoConfig</a>) but it’s different for models. Your custom model could be suitable for many different tasks, so you | |
| have to specify which one of the auto classes is the correct one for your model.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1rexmhi">Use <code>register_for_auto_class()</code> if you want the code files to be copied. If you instead prefer to use code on the Hub from another repo, | |
| you don’t need to call it. In cases where there’s more than one auto class, you can modify the <code>config.json</code> directly using the | |
| following structure:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-attr">"auto_map"</span><span class="hljs-punctuation">:</span> <span class="hljs-punctuation">{</span> | |
| <span class="hljs-attr">"AutoConfig"</span><span class="hljs-punctuation">:</span> <span class="hljs-string">"<your-repo-name>--<config-name>"</span><span class="hljs-punctuation">,</span> | |
| <span class="hljs-attr">"AutoModel"</span><span class="hljs-punctuation">:</span> <span class="hljs-string">"<your-repo-name>--<config-name>"</span><span class="hljs-punctuation">,</span> | |
| <span class="hljs-attr">"AutoModelFor<Task>"</span><span class="hljs-punctuation">:</span> <span class="hljs-string">"<your-repo-name>--<config-name>"</span><span class="hljs-punctuation">,</span> | |
| <span class="hljs-punctuation">}</span><span class="hljs-punctuation">,</span><!-- HTML_TAG_END --></pre></div></div> <p data-svelte-h="svelte-19txhag">Next, let’s create the config and models as we did before:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->resnet50d_config = ResnetConfig(block_type=<span class="hljs-string">"bottleneck"</span>, stem_width=<span class="hljs-number">32</span>, stem_type=<span class="hljs-string">"deep"</span>, avg_down=<span class="hljs-literal">True</span>) | |
| resnet50d = ResnetModelForImageClassification(resnet50d_config) | |
| pretrained_model = timm.create_model(<span class="hljs-string">"resnet50d"</span>, pretrained=<span class="hljs-literal">True</span>) | |
| resnet50d.model.load_state_dict(pretrained_model.state_dict())<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1w2dqml">Now to send the model to the Hub, make sure you are logged in. Either run in your terminal:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->huggingface-cli login<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-11lw65d">or from a notebook:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> huggingface_hub <span class="hljs-keyword">import</span> notebook_login | |
| notebook_login()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1c2397x">You can then push to your own namespace (or an organization you are a member of) like this:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->resnet50d.push_to_hub(<span class="hljs-string">"custom-resnet50d"</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-djed3h">On top of the modeling weights and the configuration in json format, this also copied the modeling and | |
| configuration <code>.py</code> files in the folder <code>custom-resnet50d</code> and uploaded the result to the Hub. You can check the result | |
| in this <a href="https://huggingface.co/sgugger/custom-resnet50d" rel="nofollow">model repo</a>.</p> <p data-svelte-h="svelte-1sdq6eg">See the <a href="model_sharing">sharing tutorial</a> for more information on the push to Hub method.</p> <h2 class="relative group"><a id="using-a-model-with-custom-code" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#using-a-model-with-custom-code"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Using a model with custom code</span></h2> <p data-svelte-h="svelte-nao5c0">You can use any configuration, model or tokenizer with custom code files in its repository with the auto-classes and | |
| the <code>from_pretrained</code> method. All files and code uploaded to the Hub are scanned for malware (refer to the <a href="https://huggingface.co/docs/hub/security#malware-scanning" rel="nofollow">Hub security</a> documentation for more information), but you should still | |
| review the model code and author to avoid executing malicious code on your machine. Set <code>trust_remote_code=True</code> to use | |
| a model with custom code:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForImageClassification | |
| model = AutoModelForImageClassification.from_pretrained(<span class="hljs-string">"sgugger/custom-resnet50d"</span>, trust_remote_code=<span class="hljs-literal">True</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1i8vto1">It is also strongly encouraged to pass a commit hash as a <code>revision</code> to make sure the author of the models did not | |
| update the code with some malicious new lines (unless you fully trust the authors of the models).</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->commit_hash = <span class="hljs-string">"ed94a7c6247d8aedce4647f00f20de6875b5b292"</span> | |
| model = AutoModelForImageClassification.from_pretrained( | |
| <span class="hljs-string">"sgugger/custom-resnet50d"</span>, trust_remote_code=<span class="hljs-literal">True</span>, revision=commit_hash | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1y2zx54">Note that when browsing the commit history of the model repo on the Hub, there is a button to easily copy the commit | |
| hash of any commit.</p> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/transformers/blob/main/docs/source/en/custom_models.md" target="_blank"><span data-svelte-h="svelte-1kd6by1"><</span> <span data-svelte-h="svelte-x0xyl0">></span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p> | |
| <script> | |
| { | |
| __sveltekit_z647wz = { | |
| assets: "/docs/transformers/pr_33913/en", | |
| base: "/docs/transformers/pr_33913/en", | |
| env: {} | |
| }; | |
| const element = document.currentScript.parentElement; | |
| const data = [null,null]; | |
| Promise.all([ | |
| import("/docs/transformers/pr_33913/en/_app/immutable/entry/start.b67f883f.js"), | |
| import("/docs/transformers/pr_33913/en/_app/immutable/entry/app.e436b1f2.js") | |
| ]).then(([kit, app]) => { | |
| kit.start(app, element, { | |
| node_ids: [0, 17], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |
Xet Storage Details
- Size:
- 58.6 kB
- Xet hash:
- e49600274bd87f38c4853e5d8335f8fd9757ce2ef362b0a6e1f257d1a8e6b650
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.