Buckets:

hf-doc-build/doc-dev / timm /pr_2213 /en /quickstart.html
rtrm's picture
download
raw
50.2 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Quickstart&quot;,&quot;local&quot;:&quot;quickstart&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Load a Pretrained Model&quot;,&quot;local&quot;:&quot;load-a-pretrained-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;List Models with Pretrained Weights&quot;,&quot;local&quot;:&quot;list-models-with-pretrained-weights&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Fine-Tune a Pretrained Model&quot;,&quot;local&quot;:&quot;fine-tune-a-pretrained-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Use a Pretrained Model for Feature Extraction&quot;,&quot;local&quot;:&quot;use-a-pretrained-model-for-feature-extraction&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Image Augmentation&quot;,&quot;local&quot;:&quot;image-augmentation&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Using Pretrained Models for Inference&quot;,&quot;local&quot;:&quot;using-pretrained-models-for-inference&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/timm/pr_2213/en/_app/immutable/assets/0.e3b0c442.css" rel="preload" as="style">
<link rel="modulepreload" href="/docs/timm/pr_2213/en/_app/immutable/entry/start.7b6e956d.js">
<link rel="modulepreload" href="/docs/timm/pr_2213/en/_app/immutable/chunks/scheduler.85c25b89.js">
<link rel="modulepreload" href="/docs/timm/pr_2213/en/_app/immutable/chunks/singletons.1da28543.js">
<link rel="modulepreload" href="/docs/timm/pr_2213/en/_app/immutable/chunks/paths.dfb1c1a6.js">
<link rel="modulepreload" href="/docs/timm/pr_2213/en/_app/immutable/entry/app.9d25b188.js">
<link rel="modulepreload" href="/docs/timm/pr_2213/en/_app/immutable/chunks/index.c9837788.js">
<link rel="modulepreload" href="/docs/timm/pr_2213/en/_app/immutable/nodes/0.62ba91e5.js">
<link rel="modulepreload" href="/docs/timm/pr_2213/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/timm/pr_2213/en/_app/immutable/nodes/70.82ee2316.js">
<link rel="modulepreload" href="/docs/timm/pr_2213/en/_app/immutable/chunks/Tip.6bd863a0.js">
<link rel="modulepreload" href="/docs/timm/pr_2213/en/_app/immutable/chunks/CodeBlock.52fa569e.js">
<link rel="modulepreload" href="/docs/timm/pr_2213/en/_app/immutable/chunks/EditOnGithub.b65eee75.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Quickstart&quot;,&quot;local&quot;:&quot;quickstart&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Load a Pretrained Model&quot;,&quot;local&quot;:&quot;load-a-pretrained-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;List Models with Pretrained Weights&quot;,&quot;local&quot;:&quot;list-models-with-pretrained-weights&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Fine-Tune a Pretrained Model&quot;,&quot;local&quot;:&quot;fine-tune-a-pretrained-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Use a Pretrained Model for Feature Extraction&quot;,&quot;local&quot;:&quot;use-a-pretrained-model-for-feature-extraction&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Image Augmentation&quot;,&quot;local&quot;:&quot;image-augmentation&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Using Pretrained Models for Inference&quot;,&quot;local&quot;:&quot;using-pretrained-models-for-inference&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="quickstart" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#quickstart"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 
1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Quickstart</span></h1> <p data-svelte-h="svelte-1mt2wpi">This quickstart is intended for developers who are ready to dive into the code and see an example of how to integrate <code>timm</code> into their model training workflow.</p> <p data-svelte-h="svelte-15vunib">First, you’ll need to install <code>timm</code>. For more information on installation, see <a href="installation">Installation</a>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START 
-->pip install timm<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="load-a-pretrained-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#load-a-pretrained-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Load a Pretrained Model</span></h2> <p data-svelte-h="svelte-q915xz">Pretrained models can be loaded using <a href="/docs/timm/pr_2213/en/reference/models#timm.create_model">create_model()</a>.</p> <p data-svelte-h="svelte-11h2wfj">Here, we load the pretrained <code>mobilenetv3_large_100</code> model.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" 
transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> timm
<span class="hljs-meta">&gt;&gt;&gt; </span>m = timm.create_model(<span class="hljs-string">&#x27;mobilenetv3_large_100&#x27;</span>, pretrained=<span class="hljs-literal">True</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>m.<span class="hljs-built_in">eval</span>()<!-- HTML_TAG_END --></pre></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400">Note: The returned PyTorch model is set to train mode by default, so you must call .eval() on it if you plan to use it for inference.</div> <h2 class="relative group"><a id="list-models-with-pretrained-weights" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#list-models-with-pretrained-weights"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>List Models with Pretrained Weights</span></h2> <p data-svelte-h="svelte-vl2nlp">To list models packaged with <code>timm</code>, you can use <a href="/docs/timm/pr_2213/en/reference/models#timm.list_models">list_models()</a>. 
If you specify <code>pretrained=True</code>, this function will only return model names that have associated pretrained weights available.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> timm
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> pprint <span class="hljs-keyword">import</span> pprint
<span class="hljs-meta">&gt;&gt;&gt; </span>model_names = timm.list_models(pretrained=<span class="hljs-literal">True</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>pprint(model_names)
[
<span class="hljs-string">&#x27;adv_inception_v3&#x27;</span>,
<span class="hljs-string">&#x27;cspdarknet53&#x27;</span>,
<span class="hljs-string">&#x27;cspresnext50&#x27;</span>,
<span class="hljs-string">&#x27;densenet121&#x27;</span>,
<span class="hljs-string">&#x27;densenet161&#x27;</span>,
<span class="hljs-string">&#x27;densenet169&#x27;</span>,
<span class="hljs-string">&#x27;densenet201&#x27;</span>,
<span class="hljs-string">&#x27;densenetblur121d&#x27;</span>,
<span class="hljs-string">&#x27;dla34&#x27;</span>,
<span class="hljs-string">&#x27;dla46_c&#x27;</span>,
]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-15bjls3">You can also list models with a specific pattern in their name.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> timm
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> pprint <span class="hljs-keyword">import</span> pprint
<span class="hljs-meta">&gt;&gt;&gt; </span>model_names = timm.list_models(<span class="hljs-string">&#x27;*resne*t*&#x27;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>pprint(model_names)
[
<span class="hljs-string">&#x27;cspresnet50&#x27;</span>,
<span class="hljs-string">&#x27;cspresnet50d&#x27;</span>,
<span class="hljs-string">&#x27;cspresnet50w&#x27;</span>,
<span class="hljs-string">&#x27;cspresnext50&#x27;</span>,
...
]<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="fine-tune-a-pretrained-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#fine-tune-a-pretrained-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Fine-Tune a Pretrained Model</span></h2> <p data-svelte-h="svelte-9sr7nh">You can finetune any of the pre-trained models just by changing the classifier (the last layer).</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" 
height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>model = timm.create_model(<span class="hljs-string">&#x27;mobilenetv3_large_100&#x27;</span>, pretrained=<span class="hljs-literal">True</span>, num_classes=NUM_FINETUNE_CLASSES)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-17jj7hb">To fine-tune on your own dataset, you have to write a PyTorch training loop or adapt <code>timm</code>’s <a href="training_script">training script</a> to use your dataset.</p> <h2 class="relative group"><a id="use-a-pretrained-model-for-feature-extraction" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#use-a-pretrained-model-for-feature-extraction"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" 
fill="currentColor"></path></svg></span></a> <span>Use a Pretrained Model for Feature Extraction</span></h2> <p data-svelte-h="svelte-1bs7nra">Without modifying the network, one can call <code>model.forward_features(input)</code> on any model instead of the usual <code>model(input)</code>. This will bypass the network’s classifier head and global pooling.</p> <p data-svelte-h="svelte-1ufypal">For a more in-depth guide to using <code>timm</code> for feature extraction, see <a href="feature_extraction">Feature Extraction</a>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> timm
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> torch
<span class="hljs-meta">&gt;&gt;&gt; </span>x = torch.randn(<span class="hljs-number">1</span>, <span class="hljs-number">3</span>, <span class="hljs-number">224</span>, <span class="hljs-number">224</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>model = timm.create_model(<span class="hljs-string">&#x27;mobilenetv3_large_100&#x27;</span>, pretrained=<span class="hljs-literal">True</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>features = model.forward_features(x)
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-built_in">print</span>(features.shape)
torch.Size([<span class="hljs-number">1</span>, <span class="hljs-number">960</span>, <span class="hljs-number">7</span>, <span class="hljs-number">7</span>])<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="image-augmentation" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#image-augmentation"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Image Augmentation</span></h2> <p data-svelte-h="svelte-1brjagi">To transform images into valid inputs for a model, you can use <a href="/docs/timm/pr_2213/en/reference/data#timm.data.create_transform">timm.data.create_transform()</a>, providing the desired <code>input_size</code> that the model expects.</p> <p data-svelte-h="svelte-14s5fqt">This will return a generic transform that uses reasonable defaults.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" 
fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>timm.data.create_transform((<span class="hljs-number">3</span>, <span class="hljs-number">224</span>, <span class="hljs-number">224</span>))
Compose(
Resize(size=<span class="hljs-number">256</span>, interpolation=bilinear, max_size=<span class="hljs-literal">None</span>, antialias=<span class="hljs-literal">None</span>)
CenterCrop(size=(<span class="hljs-number">224</span>, <span class="hljs-number">224</span>))
ToTensor()
Normalize(mean=tensor([<span class="hljs-number">0.4850</span>, <span class="hljs-number">0.4560</span>, <span class="hljs-number">0.4060</span>]), std=tensor([<span class="hljs-number">0.2290</span>, <span class="hljs-number">0.2240</span>, <span class="hljs-number">0.2250</span>]))
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-jniymz">Pretrained models have specific transforms that were applied to images fed into them while training. If you use the wrong transform on your image, the model won’t understand what it’s seeing!</p> <p data-svelte-h="svelte-1qp841i">To figure out which transformations were used for a given pretrained model, we can start by taking a look at its <code>pretrained_cfg</code>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>model.pretrained_cfg
{<span class="hljs-string">&#x27;url&#x27;</span>: <span class="hljs-string">&#x27;https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth&#x27;</span>,
<span class="hljs-string">&#x27;num_classes&#x27;</span>: <span class="hljs-number">1000</span>,
<span class="hljs-string">&#x27;input_size&#x27;</span>: (<span class="hljs-number">3</span>, <span class="hljs-number">224</span>, <span class="hljs-number">224</span>),
<span class="hljs-string">&#x27;pool_size&#x27;</span>: (<span class="hljs-number">7</span>, <span class="hljs-number">7</span>),
<span class="hljs-string">&#x27;crop_pct&#x27;</span>: <span class="hljs-number">0.875</span>,
<span class="hljs-string">&#x27;interpolation&#x27;</span>: <span class="hljs-string">&#x27;bicubic&#x27;</span>,
<span class="hljs-string">&#x27;mean&#x27;</span>: (<span class="hljs-number">0.485</span>, <span class="hljs-number">0.456</span>, <span class="hljs-number">0.406</span>),
<span class="hljs-string">&#x27;std&#x27;</span>: (<span class="hljs-number">0.229</span>, <span class="hljs-number">0.224</span>, <span class="hljs-number">0.225</span>),
<span class="hljs-string">&#x27;first_conv&#x27;</span>: <span class="hljs-string">&#x27;conv_stem&#x27;</span>,
<span class="hljs-string">&#x27;classifier&#x27;</span>: <span class="hljs-string">&#x27;classifier&#x27;</span>,
<span class="hljs-string">&#x27;architecture&#x27;</span>: <span class="hljs-string">&#x27;mobilenetv3_large_100&#x27;</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1i34x3c">We can then resolve only the data-related configuration by using <a href="/docs/timm/pr_2213/en/reference/data#timm.data.resolve_data_config">timm.data.resolve_data_config()</a>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>timm.data.resolve_data_config(model.pretrained_cfg)
{<span class="hljs-string">&#x27;input_size&#x27;</span>: (<span class="hljs-number">3</span>, <span class="hljs-number">224</span>, <span class="hljs-number">224</span>),
<span class="hljs-string">&#x27;interpolation&#x27;</span>: <span class="hljs-string">&#x27;bicubic&#x27;</span>,
<span class="hljs-string">&#x27;mean&#x27;</span>: (<span class="hljs-number">0.485</span>, <span class="hljs-number">0.456</span>, <span class="hljs-number">0.406</span>),
<span class="hljs-string">&#x27;std&#x27;</span>: (<span class="hljs-number">0.229</span>, <span class="hljs-number">0.224</span>, <span class="hljs-number">0.225</span>),
<span class="hljs-string">&#x27;crop_pct&#x27;</span>: <span class="hljs-number">0.875</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-hn1gnh">We can pass this data config to <a href="/docs/timm/pr_2213/en/reference/data#timm.data.create_transform">timm.data.create_transform()</a> to initialize the model’s associated transform.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
<span class="hljs-meta">&gt;&gt;&gt; </span>transform = timm.data.create_transform(**data_cfg)
<span class="hljs-meta">&gt;&gt;&gt; </span>transform
Compose(
Resize(size=<span class="hljs-number">256</span>, interpolation=bicubic, max_size=<span class="hljs-literal">None</span>, antialias=<span class="hljs-literal">None</span>)
CenterCrop(size=(<span class="hljs-number">224</span>, <span class="hljs-number">224</span>))
ToTensor()
Normalize(mean=tensor([<span class="hljs-number">0.4850</span>, <span class="hljs-number">0.4560</span>, <span class="hljs-number">0.4060</span>]), std=tensor([<span class="hljs-number">0.2290</span>, <span class="hljs-number">0.2240</span>, <span class="hljs-number">0.2250</span>]))
)<!-- HTML_TAG_END --></pre></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400">Note: Here, the pretrained model&#39;s config happens to be the same as the generic config we made earlier. This is not always the case. So, it&#39;s safer to use the data config to create the transform as we did here instead of using the generic transform.</div> <h2 class="relative group"><a id="using-pretrained-models-for-inference" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#using-pretrained-models-for-inference"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Using Pretrained Models for Inference</span></h2> <p data-svelte-h="svelte-1vfr0rj">Here, we will put together the above sections and use a pretrained model for inference.</p> <p data-svelte-h="svelte-1u7llub">First we’ll need an image to do inference on. 
Here we load a sample image from the web:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> requests
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> PIL <span class="hljs-keyword">import</span> Image
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> io <span class="hljs-keyword">import</span> BytesIO
<span class="hljs-meta">&gt;&gt;&gt; </span>url = <span class="hljs-string">&#x27;https://datasets-server.huggingface.co/assets/imagenet-1k/--/default/test/12/image/image.jpg&#x27;</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>image = Image.<span class="hljs-built_in">open</span>(requests.get(url, stream=<span class="hljs-literal">True</span>).raw)
<span class="hljs-meta">&gt;&gt;&gt; </span>image<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-nk223p">Here’s the image we loaded:</p> <img src="https://datasets-server.huggingface.co/assets/imagenet-1k/--/default/test/12/image/image.jpg" alt="Sample image from the ImageNet-1k test set, used as the inference input" width="300"> <p data-svelte-h="svelte-1lr3iea">Now, we’ll create our model and transforms again. This time, we make sure to set our model in evaluation mode.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>model = timm.create_model(<span class="hljs-string">&#x27;mobilenetv3_large_100&#x27;</span>, pretrained=<span class="hljs-literal">True</span>).<span class="hljs-built_in">eval</span>()
<span class="hljs-meta">&gt;&gt;&gt; </span>transform = timm.data.create_transform(
**timm.data.resolve_data_config(model.pretrained_cfg)
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-8vriuk">We can prepare this image for the model by passing it to the transform.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>image_tensor = transform(image)
<span class="hljs-meta">&gt;&gt;&gt; </span>image_tensor.shape
torch.Size([<span class="hljs-number">3</span>, <span class="hljs-number">224</span>, <span class="hljs-number">224</span>])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ql7fkr">Now we can pass that image to the model to get the predictions. We use <code>unsqueeze(0)</code> in this case, as the model is expecting a batch dimension.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>output = model(image_tensor.unsqueeze(<span class="hljs-number">0</span>))
<span class="hljs-meta">&gt;&gt;&gt; </span>output.shape
torch.Size([<span class="hljs-number">1</span>, <span class="hljs-number">1000</span>])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1sn7t3l">To get the predicted probabilities, we apply softmax to the output. This leaves us with a tensor of shape <code>(num_classes,)</code>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>probabilities = torch.nn.functional.softmax(output[<span class="hljs-number">0</span>], dim=<span class="hljs-number">0</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>probabilities.shape
torch.Size([<span class="hljs-number">1000</span>])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1bct4ll">Now we’ll find the top 5 predicted class indexes and values using <code>torch.topk</code>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>values, indices = torch.topk(probabilities, <span class="hljs-number">5</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>indices
tensor([<span class="hljs-number">162</span>, <span class="hljs-number">166</span>, <span class="hljs-number">161</span>, <span class="hljs-number">164</span>, <span class="hljs-number">167</span>])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-9q87ex">If we look up the ImageNet labels for the top indices, we can see what the model predicted…</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>IMAGENET_1k_URL = <span class="hljs-string">&#x27;https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt&#x27;</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split(<span class="hljs-string">&#x27;\n&#x27;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>[{<span class="hljs-string">&#x27;label&#x27;</span>: IMAGENET_1k_LABELS[idx], <span class="hljs-string">&#x27;value&#x27;</span>: val.item()} <span class="hljs-keyword">for</span> val, idx <span class="hljs-keyword">in</span> <span class="hljs-built_in">zip</span>(values, indices)]
[{<span class="hljs-string">&#x27;label&#x27;</span>: <span class="hljs-string">&#x27;beagle&#x27;</span>, <span class="hljs-string">&#x27;value&#x27;</span>: <span class="hljs-number">0.8486220836639404</span>},
{<span class="hljs-string">&#x27;label&#x27;</span>: <span class="hljs-string">&#x27;Walker_hound, Walker_foxhound&#x27;</span>, <span class="hljs-string">&#x27;value&#x27;</span>: <span class="hljs-number">0.03753996267914772</span>},
{<span class="hljs-string">&#x27;label&#x27;</span>: <span class="hljs-string">&#x27;basset, basset_hound&#x27;</span>, <span class="hljs-string">&#x27;value&#x27;</span>: <span class="hljs-number">0.024628572165966034</span>},
{<span class="hljs-string">&#x27;label&#x27;</span>: <span class="hljs-string">&#x27;bluetick&#x27;</span>, <span class="hljs-string">&#x27;value&#x27;</span>: <span class="hljs-number">0.010317106731235981</span>},
{<span class="hljs-string">&#x27;label&#x27;</span>: <span class="hljs-string">&#x27;English_foxhound&#x27;</span>, <span class="hljs-string">&#x27;value&#x27;</span>: <span class="hljs-number">0.006958036217838526</span>}]<!-- HTML_TAG_END --></pre></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/pytorch-image-models/blob/main/hfdocs/source/quickstart.mdx" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p>
<script>
{
// Page bootstrap: publishes path configuration, then loads the two entry
// modules and starts the client app with this script's parent as its element.
// NOTE(review): the identifier suffix and the hashed file names look
// build-generated — presumably rewritten on every docs build; confirm before
// editing by hand.
//
// Assigned without a declaration; the enclosing <script> has no type="module",
// so this creates a global rather than a block-scoped binding.
__sveltekit_15v4h5o = {
// `assets` and `base` carry the same URL prefix for this docs build.
assets: "/docs/timm/pr_2213/en",
base: "/docs/timm/pr_2213/en",
env: {}
};
// Must be read here, synchronously: document.currentScript is only set while
// this script is executing, so it would be null inside the promise callback.
const element = document.currentScript.parentElement;
// Two null entries, one per id in `node_ids` below.
// NOTE(review): presumably "no preloaded data for either node" — confirm.
const data = [null,null];
// Load both entry modules in parallel; the callback runs only after both resolve.
// NOTE(review): there is no rejection handler, so a failed import surfaces only
// as an unhandled promise rejection.
Promise.all([
import("/docs/timm/pr_2213/en/_app/immutable/entry/start.7b6e956d.js"),
import("/docs/timm/pr_2213/en/_app/immutable/entry/app.9d25b188.js")
]).then(([kit, app]) => {
// Pass the app module, the target element, and the initial state to start().
kit.start(app, element, {
node_ids: [0, 70],
data,
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
50.2 kB
·
Xet hash:
ff3dc57486e904f0a85b12d385aa30ec2575d0178b9719bba86c1ca3ecaebe56

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.