Buckets:

hf-doc-build/doc-dev / setfit /pr_618 /en /quickstart.html
rtrm's picture
download
raw
50.6 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Quickstart&quot;,&quot;local&quot;:&quot;quickstart&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;SetFit&quot;,&quot;local&quot;:&quot;setfit&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Training&quot;,&quot;local&quot;:&quot;training&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Saving a 🤗 SetFit model&quot;,&quot;local&quot;:&quot;saving-a--setfit-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Loading a 🤗 SetFit model&quot;,&quot;local&quot;:&quot;loading-a--setfit-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Inference&quot;,&quot;local&quot;:&quot;inference&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;What’s next?&quot;,&quot;local&quot;:&quot;whats-next&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;End-to-end&quot;,&quot;local&quot;:&quot;end-to-end&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/setfit/pr_618/en/_app/immutable/assets/0.e3b0c442.css" rel="stylesheet">
<link rel="modulepreload" href="/docs/setfit/pr_618/en/_app/immutable/entry/start.bb9a9e95.js">
<link rel="modulepreload" href="/docs/setfit/pr_618/en/_app/immutable/chunks/scheduler.c59d9fbb.js">
<link rel="modulepreload" href="/docs/setfit/pr_618/en/_app/immutable/chunks/singletons.9ab168b2.js">
<link rel="modulepreload" href="/docs/setfit/pr_618/en/_app/immutable/chunks/paths.f1826cfc.js">
<link rel="modulepreload" href="/docs/setfit/pr_618/en/_app/immutable/entry/app.5270c2f1.js">
<link rel="modulepreload" href="/docs/setfit/pr_618/en/_app/immutable/chunks/index.a47918e3.js">
<link rel="modulepreload" href="/docs/setfit/pr_618/en/_app/immutable/nodes/0.9ae94413.js">
<link rel="modulepreload" href="/docs/setfit/pr_618/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/setfit/pr_618/en/_app/immutable/nodes/17.a556be73.js">
<link rel="modulepreload" href="/docs/setfit/pr_618/en/_app/immutable/chunks/CodeBlock.f26209eb.js">
<link rel="modulepreload" href="/docs/setfit/pr_618/en/_app/immutable/chunks/DocNotebookDropdown.537a75c5.js">
<link rel="modulepreload" href="/docs/setfit/pr_618/en/_app/immutable/chunks/globals.7f7f1b26.js">
<link rel="modulepreload" href="/docs/setfit/pr_618/en/_app/immutable/chunks/getInferenceSnippets.7ed99fe5.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Quickstart&quot;,&quot;local&quot;:&quot;quickstart&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;SetFit&quot;,&quot;local&quot;:&quot;setfit&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Training&quot;,&quot;local&quot;:&quot;training&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Saving a 🤗 SetFit model&quot;,&quot;local&quot;:&quot;saving-a--setfit-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Loading a 🤗 SetFit model&quot;,&quot;local&quot;:&quot;loading-a--setfit-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Inference&quot;,&quot;local&quot;:&quot;inference&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;What’s next?&quot;,&quot;local&quot;:&quot;whats-next&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;End-to-end&quot;,&quot;local&quot;:&quot;end-to-end&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="quickstart" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#quickstart"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 
79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Quickstart</span></h1> <div class="flex space-x-1 absolute z-10 right-0 top-0"> <div class="relative colab-dropdown "> <button class=" " type="button"> <img alt="Open In Colab" class="!m-0" src="https://colab.research.google.com/assets/colab-badge.svg"> </button> </div> <div class="relative colab-dropdown "> <button class=" " type="button"> <img alt="Open In Studio Lab" class="!m-0" src="https://studiolab.sagemaker.aws/studiolab.svg"> </button> </div></div> <p data-svelte-h="svelte-yf6m1a">This quickstart is intended for developers who are ready to dive into the code and see an example of how to train and use 🤗 SetFit models. We recommend starting with this quickstart, and then proceeding to the <a href="./tutorials/overview">tutorials</a> or <a href="./how_to/overview">how-to guides</a> for additional material. 
Additionally, the <a href="./conceptual_guides/setfit">conceptual guides</a> help explain exactly how SetFit works.</p> <p data-svelte-h="svelte-14v6c2i">Start by installing 🤗 SetFit:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pip install setfit<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-zdbrrm">If you have a CUDA-capable graphics card, then it is recommended to <a href="https://pytorch.org/get-started/locally/" rel="nofollow">install <code>torch</code> with CUDA support</a> to train and perform inference much more quickly:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" 
type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pip install torch --index-url https://download.pytorch.org/whl/cu118<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="setfit" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#setfit"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" 
fill="currentColor"></path></svg></span></a> <span>SetFit</span></h2> <p data-svelte-h="svelte-yv05pn">SetFit is an efficient framework to train low-latency text classification models using little training data. In this Quickstart, you’ll learn how to train a SetFit model, how to perform inference with it, and how to save it to the Hugging Face Hub.</p> <h3 class="relative group"><a id="training" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#training"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Training</span></h3> <p data-svelte-h="svelte-55wyuo">In this section, you’ll load a <a href="https://huggingface.co/models?library=sentence-transformers" rel="nofollow">Sentence Transformer model</a> and further finetune it for classifying movie reviews as positive or negative. To train a model, we will need to prepare the following three components: 1) a <strong>model</strong>, 2) a <strong>dataset</strong>, and 3) <strong>training arguments</strong>.</p> <p data-svelte-h="svelte-b5pnqy"><strong>1</strong>. Initialize a SetFit model using a Sentence Transformer model of our choice. 
Consider using the <a href="https://huggingface.co/spaces/mteb/leaderboard" rel="nofollow">MTEB Leaderboard</a> to guide your decision on which Sentence Transformer model to choose. We will use <a href="https://huggingface.co/BAAI/bge-small-en-v1.5" rel="nofollow">BAAI/bge-small-en-v1.5</a>, a small but performant model.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> setfit <span class="hljs-keyword">import</span> SetFitModel
<span class="hljs-meta">&gt;&gt;&gt; </span>model = SetFitModel.from_pretrained(<span class="hljs-string">&quot;BAAI/bge-small-en-v1.5&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1dmjuby"><strong>2a</strong>. Next, load both the “train” and “test” splits of the <a href="https://huggingface.co/datasets/SetFit/sst2" rel="nofollow">SetFit/sst2</a> dataset. Note that the dataset has <code>&quot;text&quot;</code> and <code>&quot;label&quot;</code> columns: this is exactly the format that 🤗 SetFit expects. If your dataset has different columns, then you can use the <code>column_mapping</code> argument of the <a href="/docs/setfit/pr_618/en/reference/trainer#setfit.Trainer">Trainer</a> in step 4 to map the column names to <code>&quot;text&quot;</code> and <code>&quot;label&quot;</code>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- 
HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
<span class="hljs-meta">&gt;&gt;&gt; </span>dataset = load_dataset(<span class="hljs-string">&quot;SetFit/sst2&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>dataset
DatasetDict({
train: Dataset({
features: [<span class="hljs-string">&#x27;text&#x27;</span>, <span class="hljs-string">&#x27;label&#x27;</span>, <span class="hljs-string">&#x27;label_text&#x27;</span>],
num_rows: <span class="hljs-number">6920</span>
})
test: Dataset({
features: [<span class="hljs-string">&#x27;text&#x27;</span>, <span class="hljs-string">&#x27;label&#x27;</span>, <span class="hljs-string">&#x27;label_text&#x27;</span>],
num_rows: <span class="hljs-number">1821</span>
})
validation: Dataset({
features: [<span class="hljs-string">&#x27;text&#x27;</span>, <span class="hljs-string">&#x27;label&#x27;</span>, <span class="hljs-string">&#x27;label_text&#x27;</span>],
num_rows: <span class="hljs-number">872</span>
})
})<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-10w3ory"><strong>2b</strong>. In real-world scenarios it is very uncommon to have ~7,000 high-quality labeled training samples, so we will heavily shrink the training dataset to give a better idea of how 🤗 SetFit would work in real settings. To be specific, the <code>sample_dataset</code> function will select only 8 examples per class, leaving 16 training samples in total. The testing set is left unaffected for a more reliable evaluation.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> setfit <span class="hljs-keyword">import</span> sample_dataset
<span class="hljs-meta">&gt;&gt;&gt; </span>train_dataset = sample_dataset(dataset[<span class="hljs-string">&quot;train&quot;</span>], label_column=<span class="hljs-string">&quot;label&quot;</span>, num_samples=<span class="hljs-number">8</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>train_dataset
Dataset({
features: [<span class="hljs-string">&#x27;text&#x27;</span>, <span class="hljs-string">&#x27;label&#x27;</span>, <span class="hljs-string">&#x27;label_text&#x27;</span>],
num_rows: <span class="hljs-number">16</span>
})<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>test_dataset = dataset[<span class="hljs-string">&quot;test&quot;</span>]
<span class="hljs-meta">&gt;&gt;&gt; </span>test_dataset
Dataset({
features: [<span class="hljs-string">&#x27;text&#x27;</span>, <span class="hljs-string">&#x27;label&#x27;</span>, <span class="hljs-string">&#x27;label_text&#x27;</span>],
num_rows: <span class="hljs-number">1821</span>
})<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1x016ab"><strong>2c</strong>. We can assign the dataset’s label names to the model, so that its predictions are returned as readable class names. You can also provide the labels directly to <code>SetFitModel.from_pretrained()</code>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>model.labels = [<span class="hljs-string">&quot;negative&quot;</span>, <span class="hljs-string">&quot;positive&quot;</span>]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1z44i4"><strong>3</strong>. Prepare the <a href="/docs/setfit/pr_618/en/reference/trainer#setfit.TrainingArguments">TrainingArguments</a> for training. 
Note that training with 🤗 SetFit consists of two phases behind the scenes: <strong>finetuning embeddings</strong> and <strong>training a classification head</strong>. As a result, some of the training arguments can be tuples, where the two values are used for each of the two phases, respectively.</p> <p data-svelte-h="svelte-1z01zd6">The <code>num_epochs</code> and <code>max_steps</code> arguments are frequently used to increase or decrease the total number of training steps. Consider that with SetFit, better performance is reached with <strong>more data, not more training</strong>! Don’t be afraid to train for less than 1 epoch if you have a lot of data.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> setfit <span 
class="hljs-keyword">import</span> TrainingArguments
<span class="hljs-meta">&gt;&gt;&gt; </span>args = TrainingArguments(
<span class="hljs-meta">... </span> batch_size=<span class="hljs-number">32</span>,
<span class="hljs-meta">... </span> num_epochs=<span class="hljs-number">10</span>,
<span class="hljs-meta">... </span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-18fs6ca"><strong>4</strong>. Initialize the <a href="/docs/setfit/pr_618/en/reference/trainer#setfit.Trainer">Trainer</a> and perform training.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> setfit <span class="hljs-keyword">import</span> Trainer
<span class="hljs-meta">&gt;&gt;&gt; </span>trainer = Trainer(
<span class="hljs-meta">... </span> model=model,
<span class="hljs-meta">... </span> args=args,
<span class="hljs-meta">... </span> train_dataset=train_dataset,
<span class="hljs-meta">... </span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>trainer.train()
***** Running training *****
Num examples = <span class="hljs-number">5</span>
Num epochs = <span class="hljs-number">10</span>
Total optimization steps = <span class="hljs-number">50</span>
Total train batch size = <span class="hljs-number">32</span>
{<span class="hljs-string">&#x27;embedding_loss&#x27;</span>: <span class="hljs-number">0.2077</span>, <span class="hljs-string">&#x27;learning_rate&#x27;</span>: <span class="hljs-number">4.000000000000001e-06</span>, <span class="hljs-string">&#x27;epoch&#x27;</span>: <span class="hljs-number">0.2</span>}
{<span class="hljs-string">&#x27;embedding_loss&#x27;</span>: <span class="hljs-number">0.0097</span>, <span class="hljs-string">&#x27;learning_rate&#x27;</span>: <span class="hljs-number">0.0</span>, <span class="hljs-string">&#x27;epoch&#x27;</span>: <span class="hljs-number">10.0</span>}
{<span class="hljs-string">&#x27;train_runtime&#x27;</span>: <span class="hljs-number">14.705</span>, <span class="hljs-string">&#x27;train_samples_per_second&#x27;</span>: <span class="hljs-number">108.807</span>, <span class="hljs-string">&#x27;train_steps_per_second&#x27;</span>: <span class="hljs-number">3.4</span>, <span class="hljs-string">&#x27;epoch&#x27;</span>: <span class="hljs-number">10.0</span>}
<span class="hljs-number">100</span>%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| <span class="hljs-number">50</span>/<span class="hljs-number">50</span> [<span class="hljs-number">00</span>:08&lt;<span class="hljs-number">00</span>:<span class="hljs-number">00</span>, <span class="hljs-number">5.70</span>it/s]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-4ybstn"><strong>5</strong>. Perform evaluation using the provided testing dataset.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>trainer.evaluate(test_dataset)
***** Running evaluation *****
{<span class="hljs-string">&#x27;accuracy&#x27;</span>: <span class="hljs-number">0.8511806699615596</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1svsu67">Feel free to experiment with increasing the number of samples per class to observe the improvements in accuracy. As a challenge, you can play with the samples per class, learning rate, number of epochs, maximum number of steps, and the base Sentence Transformer model to try and improve the accuracy over 90% using very little data.</p> <h3 class="relative group"><a id="saving-a--setfit-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#saving-a--setfit-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Saving a 🤗 SetFit model</span></h3> <p data-svelte-h="svelte-1jhyq77">After training, you can save a 🤗 SetFit model to your local filesystem or to the Hugging Face Hub. 
Save a model to a local directory using <code>SetFitModel.save_pretrained()</code> by providing a <code>save_directory</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>model.save_pretrained(<span class="hljs-string">&quot;setfit-bge-small-v1.5-sst2-8-shot&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1hpf3g">Alternatively, push a model to the Hugging Face Hub using <code>SetFitModel.push_to_hub()</code> by providing a <code>repo_id</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" 
xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>model.push_to_hub(<span class="hljs-string">&quot;tomaarsen/setfit-bge-small-v1.5-sst2-8-shot&quot;</span>)<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="loading-a--setfit-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#loading-a--setfit-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 
28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Loading a 🤗 SetFit model</span></h3> <p data-svelte-h="svelte-1di90ms">A 🤗 SetFit model can be loaded using <code>SetFitModel.from_pretrained()</code> by providing 1) a <code>repo_id</code> from the Hugging Face Hub or 2) a path to a local directory:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>model = SetFitModel.from_pretrained(<span class="hljs-string">&quot;tomaarsen/setfit-bge-small-v1.5-sst2-8-shot&quot;</span>) <span class="hljs-comment"># Load from the Hugging Face Hub</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>model = SetFitModel.from_pretrained(<span class="hljs-string">&quot;setfit-bge-small-v1.5-sst2-8-shot&quot;</span>) <span class="hljs-comment"># Load from a local directory</span><!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="inference" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#inference"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Inference</span></h3> <p data-svelte-h="svelte-sb7ks6">Once a 🤗 SetFit model has been trained, it can be used for inference to classify reviews using <a href="/docs/setfit/pr_618/en/reference/main#setfit.SetFitModel.predict">SetFitModel.predict()</a> or <a href="/docs/setfit/pr_618/en/reference/main#setfit.SetFitModel.__call__">SetFitModel.__call__()</a>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" 
aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>preds = model.predict([
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;It&#x27;s a charming and often affecting journey.&quot;</span>,
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;It&#x27;s slow -- very, very slow.&quot;</span>,
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;A sometimes tedious film.&quot;</span>,
<span class="hljs-meta">... </span>])
<span class="hljs-meta">&gt;&gt;&gt; </span>preds
[<span class="hljs-string">&#x27;positive&#x27;</span> <span class="hljs-string">&#x27;negative&#x27;</span> <span class="hljs-string">&#x27;negative&#x27;</span>]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-yxp1nk">These predictions rely on <code>model.labels</code>. If the labels are not set, predictions are returned in the format that was used during training, e.g. <code>tensor([1, 0, 0])</code>.</p> <h2 class="relative group"><a id="whats-next" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#whats-next"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>What’s next?</span></h2> <p data-svelte-h="svelte-153smf1">You’ve completed the 🤗 SetFit quickstart! You can train, save, load and perform inference with 🤗 SetFit models!</p> <p data-svelte-h="svelte-1h8ndzm">For your next steps, take a look at our <a href="./how_to/overview">How-to guides</a> and learn how to do more specific things like hyperparameter search, knowledge distillation, or zero-shot text classification. 
If you’re interested in learning more about how 🤗 SetFit works, grab a cup of coffee and read our <a href="./conceptual_guides/setfit">Conceptual Guides</a>!</p> <h2 class="relative group"><a id="end-to-end" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#end-to-end"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>End-to-end</span></h2> <p data-svelte-h="svelte-1lco45x">This snippet shows the entire quickstart in an end-to-end example:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" 
transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> setfit <span class="hljs-keyword">import</span> SetFitModel, Trainer, TrainingArguments, sample_dataset
<span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
<span class="hljs-comment"># Initializing a new SetFit model</span>
model = SetFitModel.from_pretrained(<span class="hljs-string">&quot;BAAI/bge-small-en-v1.5&quot;</span>, labels=[<span class="hljs-string">&quot;negative&quot;</span>, <span class="hljs-string">&quot;positive&quot;</span>])
<span class="hljs-comment"># Preparing the dataset</span>
dataset = load_dataset(<span class="hljs-string">&quot;SetFit/sst2&quot;</span>)
train_dataset = sample_dataset(dataset[<span class="hljs-string">&quot;train&quot;</span>], label_column=<span class="hljs-string">&quot;label&quot;</span>, num_samples=<span class="hljs-number">8</span>)
test_dataset = dataset[<span class="hljs-string">&quot;test&quot;</span>]
<span class="hljs-comment"># Preparing the training arguments</span>
args = TrainingArguments(
batch_size=<span class="hljs-number">32</span>,
num_epochs=<span class="hljs-number">10</span>,
)
<span class="hljs-comment"># Preparing the trainer</span>
trainer = Trainer(
model=model,
args=args,
train_dataset=train_dataset,
)
trainer.train()
<span class="hljs-comment"># Evaluating</span>
metrics = trainer.evaluate(test_dataset)
<span class="hljs-built_in">print</span>(metrics)
<span class="hljs-comment"># =&gt; {&#x27;accuracy&#x27;: 0.8511806699615596}</span>
<span class="hljs-comment"># Saving the trained model</span>
model.save_pretrained(<span class="hljs-string">&quot;setfit-bge-small-v1.5-sst2-8-shot&quot;</span>)
<span class="hljs-comment"># or</span>
model.push_to_hub(<span class="hljs-string">&quot;tomaarsen/setfit-bge-small-v1.5-sst2-8-shot&quot;</span>)
<span class="hljs-comment"># Loading a trained model</span>
model = SetFitModel.from_pretrained(<span class="hljs-string">&quot;tomaarsen/setfit-bge-small-v1.5-sst2-8-shot&quot;</span>) <span class="hljs-comment"># Load from the Hugging Face Hub</span>
<span class="hljs-comment"># or</span>
model = SetFitModel.from_pretrained(<span class="hljs-string">&quot;setfit-bge-small-v1.5-sst2-8-shot&quot;</span>) <span class="hljs-comment"># Load from a local directory</span>
<span class="hljs-comment"># Performing inference</span>
preds = model.predict([
<span class="hljs-string">&quot;It&#x27;s a charming and often affecting journey.&quot;</span>,
<span class="hljs-string">&quot;It&#x27;s slow -- very, very slow.&quot;</span>,
<span class="hljs-string">&quot;A sometimes tedious film.&quot;</span>,
])
<span class="hljs-built_in">print</span>(preds)
<span class="hljs-comment"># =&gt; [&quot;positive&quot;, &quot;negative&quot;, &quot;negative&quot;]</span><!-- HTML_TAG_END --></pre></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/setfit/blob/main/docs/source/en/quickstart.mdx" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p>
<script>
{
// Client bootstrap for this page. NOTE(review): this looks like SvelteKit
// build output, so hand edits here are presumably overwritten on rebuild.
// The surrounding braces form a block statement, so the `const` bindings
// below stay block-scoped instead of becoming script-level globals.
// Runtime configuration object. It is assigned without var/let/const, so in
// this classic (non-module) script it becomes a property of the global
// object. Presumably the SvelteKit client reads it under this exact generated
// name; do not rename it.
__sveltekit_j4bv9w = {
// URL prefix for static assets (identical to `base` on this page).
assets: "/docs/setfit/pr_618/en",
// Base path the app is served under.
base: "/docs/setfit/pr_618/en",
// Client-visible environment values (none for this page).
env: {}
};
// The parent element of the script element currently executing; the app is
// started inside it. document.currentScript is only set while a classic
// script runs synchronously, so it must be read here rather than inside the
// promise callback below.
const element = document.currentScript.parentElement;
// Presumably one data slot per entry in node_ids (null = no inlined data);
// TODO confirm against the SvelteKit client.
const data = [null,null];
// Load the kit runtime entry ("start") and the app entry in parallel; both
// URLs live under the same base path configured above.
Promise.all([
import("/docs/setfit/pr_618/en/_app/immutable/entry/start.bb9a9e95.js"),
import("/docs/setfit/pr_618/en/_app/immutable/entry/app.5270c2f1.js")
]).then(([kit, app]) => {
// Start the app inside `element`.
kit.start(app, element, {
// Route node ids to load. Node 0 matches the preloaded nodes/0.*.js in the
// document head; 17 is presumably this page's node.
node_ids: [0, 17],
data,
// Presumably no form-action result and no error state for this page.
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
50.6 kB
·
Xet hash:
65871ef29ddf7b44d3cbcf1730d99b7c1ab8ff21472c940beb60c926d59ca923

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.