Buckets:

rtrm's picture
download
raw
53.2 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Efficiently run SetFit Models with Optimum&quot;,&quot;local&quot;:&quot;efficiently-run-setfit-models-with-optimum&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;1. Setup development environment&quot;,&quot;local&quot;:&quot;1-setup-development-environment&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;2. Create a performance benchmark&quot;,&quot;local&quot;:&quot;2-create-a-performance-benchmark&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;3. Train/evaluate bge-small SetFit models&quot;,&quot;local&quot;:&quot;3-trainevaluate-bge-small-setfit-models&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;4. Compressing with Optimum ONNX and CUDAExecutionProvider&quot;,&quot;local&quot;:&quot;4-compressing-with-optimum-onnx-and-cudaexecutionprovider&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/setfit/pr_618/en/_app/immutable/assets/0.e3b0c442.css" rel="stylesheet">
<link rel="modulepreload" href="/docs/setfit/pr_618/en/_app/immutable/entry/start.bb9a9e95.js">
<link rel="modulepreload" href="/docs/setfit/pr_618/en/_app/immutable/chunks/scheduler.c59d9fbb.js">
<link rel="modulepreload" href="/docs/setfit/pr_618/en/_app/immutable/chunks/singletons.9ab168b2.js">
<link rel="modulepreload" href="/docs/setfit/pr_618/en/_app/immutable/chunks/paths.f1826cfc.js">
<link rel="modulepreload" href="/docs/setfit/pr_618/en/_app/immutable/entry/app.5270c2f1.js">
<link rel="modulepreload" href="/docs/setfit/pr_618/en/_app/immutable/chunks/index.a47918e3.js">
<link rel="modulepreload" href="/docs/setfit/pr_618/en/_app/immutable/nodes/0.9ae94413.js">
<link rel="modulepreload" href="/docs/setfit/pr_618/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/setfit/pr_618/en/_app/immutable/nodes/21.e8d53e40.js">
<link rel="modulepreload" href="/docs/setfit/pr_618/en/_app/immutable/chunks/CodeBlock.f26209eb.js">
<link rel="modulepreload" href="/docs/setfit/pr_618/en/_app/immutable/chunks/getInferenceSnippets.7ed99fe5.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Efficiently run SetFit Models with Optimum&quot;,&quot;local&quot;:&quot;efficiently-run-setfit-models-with-optimum&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;1. Setup development environment&quot;,&quot;local&quot;:&quot;1-setup-development-environment&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;2. Create a performance benchmark&quot;,&quot;local&quot;:&quot;2-create-a-performance-benchmark&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;3. Train/evaluate bge-small SetFit models&quot;,&quot;local&quot;:&quot;3-trainevaluate-bge-small-setfit-models&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;4. Compressing with Optimum ONNX and CUDAExecutionProvider&quot;,&quot;local&quot;:&quot;4-compressing-with-optimum-onnx-and-cudaexecutionprovider&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="efficiently-run-setfit-models-with-optimum" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#efficiently-run-setfit-models-with-optimum"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 
43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Efficiently run SetFit Models with Optimum</span></h1> <p data-svelte-h="svelte-n86os7"><a href="https://github.com/huggingface/setfit" rel="nofollow">SetFit</a> is a technique for few-shot text classification that uses contrastive learning to fine-tune Sentence Transformers in domains where little to no labeled data is available. It achieves comparable performance to existing state-of-the-art methods based on large language models, yet requires no prompts and is efficient to train (typically a few seconds on a GPU to minutes on a CPU).</p> <p data-svelte-h="svelte-1bd8w5">In this notebook you’ll learn how to further compress SetFit models for faster inference &amp; deployment on GPU using Optimum ONNX.</p> <h2 class="relative group"><a id="1-setup-development-environment" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#1-setup-development-environment"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" 
fill="currentColor"></path></svg></span></a> <span>1. Setup development environment</span></h2> <p data-svelte-h="svelte-yduaqs">Our first step is to install SetFit. Running the following cell will install all the required packages for us.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->!pip <span class="hljs-keyword">install</span> setfit accelerate -qqq<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="2-create-a-performance-benchmark" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#2-create-a-performance-benchmark"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" 
preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>2. Create a performance benchmark</span></h2> <p data-svelte-h="svelte-13l2fse">Before we train and optimize any models, let’s define a performance benchmark that we can use to compare our models. In general, deploying ML models in production environments involves a tradeoff among several constraints:</p> <ul data-svelte-h="svelte-29gjki"><li>Model performance: how well does the model perform on a well-crafted test set?</li> <li>Latency: how fast can our model deliver predictions?</li> <li>Memory: on what cloud instance or device can we store and load our model?</li></ul> <p data-svelte-h="svelte-z4wtvv">The class below defines a simple benchmark that measures each quantity for a given SetFit model and test dataset:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path 
d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> pathlib <span class="hljs-keyword">import</span> Path
<span class="hljs-keyword">from</span> time <span class="hljs-keyword">import</span> perf_counter
<span class="hljs-keyword">import</span> evaluate
<span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np
<span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> tqdm.auto <span class="hljs-keyword">import</span> tqdm
metric = evaluate.load(<span class="hljs-string">&quot;accuracy&quot;</span>)
<span class="hljs-keyword">class</span> <span class="hljs-title class_">PerformanceBenchmark</span>:
<span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, model, dataset, optim_type</span>):
self.model = model
self.dataset = dataset
self.optim_type = optim_type
<span class="hljs-keyword">def</span> <span class="hljs-title function_">compute_accuracy</span>(<span class="hljs-params">self</span>):
preds = self.model.predict(self.dataset[<span class="hljs-string">&quot;text&quot;</span>])
labels = self.dataset[<span class="hljs-string">&quot;label&quot;</span>]
accuracy = metric.compute(predictions=preds, references=labels)
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Accuracy on test set - <span class="hljs-subst">{accuracy[<span class="hljs-string">&#x27;accuracy&#x27;</span>]:<span class="hljs-number">.3</span>f}</span>&quot;</span>)
<span class="hljs-keyword">return</span> accuracy
<span class="hljs-keyword">def</span> <span class="hljs-title function_">compute_size</span>(<span class="hljs-params">self</span>):
state_dict = self.model.model_body.state_dict()
tmp_path = Path(<span class="hljs-string">&quot;model.pt&quot;</span>)
torch.save(state_dict, tmp_path)
<span class="hljs-comment"># Calculate size in megabytes</span>
size_mb = Path(tmp_path).stat().st_size / (<span class="hljs-number">1024</span> * <span class="hljs-number">1024</span>)
<span class="hljs-comment"># Delete temporary file</span>
tmp_path.unlink()
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Model size (MB) - <span class="hljs-subst">{size_mb:<span class="hljs-number">.2</span>f}</span>&quot;</span>)
<span class="hljs-keyword">return</span> {<span class="hljs-string">&quot;size_mb&quot;</span>: size_mb}
<span class="hljs-keyword">def</span> <span class="hljs-title function_">time_model</span>(<span class="hljs-params">self, query=<span class="hljs-string">&quot;that loves its characters and communicates something rather beautiful about human nature&quot;</span></span>):
latencies = []
<span class="hljs-comment"># Warmup</span>
<span class="hljs-keyword">for</span> _ <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">10</span>):
_ = self.model([query])
<span class="hljs-comment"># Timed run</span>
<span class="hljs-keyword">for</span> _ <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">100</span>):
start_time = perf_counter()
_ = self.model([query])
latency = perf_counter() - start_time
latencies.append(latency)
<span class="hljs-comment"># Compute run statistics</span>
time_avg_ms = <span class="hljs-number">1000</span> * np.mean(latencies)
time_std_ms = <span class="hljs-number">1000</span> * np.std(latencies)
<span class="hljs-built_in">print</span>(<span class="hljs-string">rf&quot;Average latency (ms) - <span class="hljs-subst">{time_avg_ms:<span class="hljs-number">.2</span>f}</span> +\- <span class="hljs-subst">{time_std_ms:<span class="hljs-number">.2</span>f}</span>&quot;</span>)
<span class="hljs-keyword">return</span> {<span class="hljs-string">&quot;time_avg_ms&quot;</span>: time_avg_ms, <span class="hljs-string">&quot;time_std_ms&quot;</span>: time_std_ms}
<span class="hljs-keyword">def</span> <span class="hljs-title function_">run_benchmark</span>(<span class="hljs-params">self</span>):
metrics = {}
metrics[self.optim_type] = self.compute_size()
metrics[self.optim_type].update(self.compute_accuracy())
metrics[self.optim_type].update(self.time_model())
<span class="hljs-keyword">return</span> metrics<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-6tflkv">Beyond that, we’ll create a simple function to plot the performances reported by this benchmark.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> matplotlib.pyplot <span class="hljs-keyword">as</span> plt
<span class="hljs-keyword">import</span> pandas <span class="hljs-keyword">as</span> pd
<span class="hljs-keyword">def</span> <span class="hljs-title function_">plot_metrics</span>(<span class="hljs-params">perf_metrics</span>):
df = pd.DataFrame.from_dict(perf_metrics, orient=<span class="hljs-string">&quot;index&quot;</span>)
<span class="hljs-keyword">for</span> idx <span class="hljs-keyword">in</span> df.index:
df_opt = df.loc[idx]
plt.errorbar(
df_opt[<span class="hljs-string">&quot;time_avg_ms&quot;</span>],
df_opt[<span class="hljs-string">&quot;accuracy&quot;</span>] * <span class="hljs-number">100</span>,
xerr=df_opt[<span class="hljs-string">&quot;time_std_ms&quot;</span>],
fmt=<span class="hljs-string">&quot;o&quot;</span>,
alpha=<span class="hljs-number">0.5</span>,
ms=df_opt[<span class="hljs-string">&quot;size_mb&quot;</span>] / <span class="hljs-number">15</span>,
label=idx,
capsize=<span class="hljs-number">5</span>,
capthick=<span class="hljs-number">1</span>,
)
legend = plt.legend(loc=<span class="hljs-string">&quot;lower right&quot;</span>)
plt.ylim(<span class="hljs-number">63</span>, <span class="hljs-number">95</span>)
<span class="hljs-comment"># Use the slowest model to define the x-axis range</span>
xlim = <span class="hljs-built_in">max</span>([metrics[<span class="hljs-string">&quot;time_avg_ms&quot;</span>] <span class="hljs-keyword">for</span> metrics <span class="hljs-keyword">in</span> perf_metrics.values()]) * <span class="hljs-number">1.2</span>
plt.xlim(<span class="hljs-number">0</span>, xlim)
plt.ylabel(<span class="hljs-string">&quot;Accuracy (%)&quot;</span>)
plt.xlabel(<span class="hljs-string">&quot;Average latency with batch_size=1 (ms)&quot;</span>)
plt.show()<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="3-trainevaluate-bge-small-setfit-models" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#3-trainevaluate-bge-small-setfit-models"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>3. Train/evaluate bge-small SetFit models</span></h2> <p data-svelte-h="svelte-15nd2qq">Before we optimize any models, let’s train a few baselines as a point of reference. 
We’ll use the <a href="https://huggingface.co/datasets/SetFit/sst2" rel="nofollow">sst-2</a> dataset, which is a collection of sentiment texts categorized into two classes: positive and negative.</p> <p data-svelte-h="svelte-saerdd">Let’s start by loading the dataset from the Hub:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->from datasets import load_dataset
<span class="hljs-attribute">dataset</span> <span class="hljs-operator">=</span> load_dataset(<span class="hljs-string">&quot;SetFit/sst2&quot;</span>)
dataset<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->DatasetDict({
train: <span class="hljs-built_in">Dataset</span>({
features: [<span class="hljs-string">&#x27;text&#x27;</span>, <span class="hljs-string">&#x27;label&#x27;</span>, <span class="hljs-string">&#x27;label_text&#x27;</span>],
num_rows: <span class="hljs-number">6920</span>
})
validation: <span class="hljs-built_in">Dataset</span>({
features: [<span class="hljs-string">&#x27;text&#x27;</span>, <span class="hljs-string">&#x27;label&#x27;</span>, <span class="hljs-string">&#x27;label_text&#x27;</span>],
num_rows: <span class="hljs-number">872</span>
})
test: <span class="hljs-built_in">Dataset</span>({
features: [<span class="hljs-string">&#x27;text&#x27;</span>, <span class="hljs-string">&#x27;label&#x27;</span>, <span class="hljs-string">&#x27;label_text&#x27;</span>],
num_rows: <span class="hljs-number">1821</span>
})
})<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-cffqw7">We train a SetFit model with the full dataset. Recall that SetFit excels in few-shot scenarios, but this time we are interested in achieving maximum accuracy.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->train_dataset = dataset[<span class="hljs-string">&quot;train&quot;</span>]
test_dataset = dataset[<span class="hljs-string">&quot;validation&quot;</span>]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-15w8xvk">Use the following lines of code to download the <a href="https://huggingface.co/moshew/bge-small-en-v1.5_setfit-sst2-english" rel="nofollow">already fine-tuned model</a> and evaluate it. Alternatively, uncomment the code below it to fine-tune the base model from scratch.</p> <p data-svelte-h="svelte-1vdmtnq">Note that we perform the evaluations on Google Colab using the free T4 GPU.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># Evaluate the uploaded model!</span>
<span class="hljs-keyword">from</span> setfit <span class="hljs-keyword">import</span> SetFitModel
small_model = SetFitModel.from_pretrained(<span class="hljs-string">&quot;moshew/bge-small-en-v1.5_setfit-sst2-english&quot;</span>)
pb = PerformanceBenchmark(model=small_model, dataset=test_dataset, optim_type=<span class="hljs-string">&quot;bge-small (PyTorch)&quot;</span>)
perf_metrics = pb.run_benchmark()<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-attribute">Model</span> size (MB) - <span class="hljs-number">127</span>.<span class="hljs-number">33</span>
<span class="hljs-attribute">Accuracy</span> <span class="hljs-literal">on</span> test set - <span class="hljs-number">0</span>.<span class="hljs-number">906</span>
<span class="hljs-attribute">Average</span> latency (ms) - <span class="hljs-number">17</span>.<span class="hljs-number">42</span> +\- <span class="hljs-number">4</span>.<span class="hljs-number">47</span><!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># # Fine-tune the base model and Evaluate!</span>
<span class="hljs-comment"># from setfit import SetFitModel, Trainer, TrainingArguments</span>
<span class="hljs-comment"># # Load pretrained model from the Hub</span>
<span class="hljs-comment"># small_model = SetFitModel.from_pretrained(</span>
<span class="hljs-comment"># &quot;BAAI/bge-small-en-v1.5&quot;</span>
<span class="hljs-comment"># )</span>
<span class="hljs-comment"># args = TrainingArguments(num_iterations=20)</span>
<span class="hljs-comment"># # Create trainer</span>
<span class="hljs-comment"># small_trainer = Trainer(</span>
<span class="hljs-comment"># model=small_model, args=args, train_dataset=train_dataset</span>
<span class="hljs-comment"># )</span>
<span class="hljs-comment"># # Train!</span>
<span class="hljs-comment"># small_trainer.train()</span>
<span class="hljs-comment"># # Evaluate!</span>
<span class="hljs-comment"># pb = PerformanceBenchmark(</span>
<span class="hljs-comment"># model=small_trainer.model, dataset=test_dataset, optim_type=&quot;bge-small (base)&quot;</span>
<span class="hljs-comment"># )</span>
<span class="hljs-comment"># perf_metrics = pb.run_benchmark()</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1r87202">Let’s plot the results to visualise the performance:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-function"><span class="hljs-title">plot_metrics</span><span class="hljs-params">(perf_metrics)</span></span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-3yqnh2"><img src="https://github.com/huggingface/setfit/assets/37621491/4786eee6-88c8-46ca-95be-801514697a9d" alt="Plot of accuracy against average latency for the bge-small (PyTorch) SetFit model"></p> <h2 class="relative group"><a id="4-compressing-with-optimum-onnx-and-cudaexecutionprovider" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" 
href="#4-compressing-with-optimum-onnx-and-cudaexecutionprovider"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>4. Compressing with Optimum ONNX and CUDAExecutionProvider</span></h2> <p data-svelte-h="svelte-1tl60qh">We’ll be using Optimum’s ONNX Runtime support with <code>CUDAExecutionProvider</code> <a href="https://github.com/huggingface/optimum-benchmark/tree/main/examples/fast-mteb#notes" rel="nofollow">because it’s fast while also supporting dynamic shapes</a>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none 
transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->!pip <span class="hljs-keyword">install</span> optimum[onnxruntime-gpu] -qqq<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-4oyakg"><a href="https://huggingface.co/docs/optimum/onnxruntime/usage_guides/optimization#optimizing-a-model-during-the-onnx-export" rel="nofollow"><code>optimum-cli</code></a> makes it extremely easy to export a model to ONNX and apply SOTA graph optimizations / kernel fusions.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> 
Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->!optimum-cli export onnx \
--model moshew/bge-small-en-v1<span class="hljs-number">.5</span>_setfit-sst2-english \
--task feature-extraction \
--optimize O4 \
--device cuda \
bge_auto_opt_O4<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-hhd52e">We may see some warnings, but these are nothing to be concerned about. We’ll see later that they do not affect the model performance.</p> <p data-svelte-h="svelte-2m011h">First of all, we’ll create a subclass of our performance benchmark to also allow benchmarking ONNX models.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">class</span> <span class="hljs-title class_">OnnxPerformanceBenchmark</span>(<span class="hljs-title class_ inherited__">PerformanceBenchmark</span>):
<span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, *args, model_path, **kwargs</span>):
<span class="hljs-built_in">super</span>().__init__(*args, **kwargs)
self.model_path = model_path
<span class="hljs-keyword">def</span> <span class="hljs-title function_">compute_size</span>(<span class="hljs-params">self</span>):
size_mb = Path(self.model_path).stat().st_size / (<span class="hljs-number">1024</span> * <span class="hljs-number">1024</span>)
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Model size (MB) - <span class="hljs-subst">{size_mb:<span class="hljs-number">.2</span>f}</span>&quot;</span>)
<span class="hljs-keyword">return</span> {<span class="hljs-string">&quot;size_mb&quot;</span>: size_mb}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1n59tl1">Then, we can load the converted SentenceTransformer model with the <code>&quot;CUDAExecutionProvider&quot;</code> provider. Feel free to also experiment with other providers, such as <code>&quot;TensorrtExecutionProvider&quot;</code> and <code>&quot;CPUExecutionProvider&quot;</code>. The former may be even faster than <code>&quot;CUDAExecutionProvider&quot;</code>, but requires more installation.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer
<span class="hljs-keyword">from</span> optimum.onnxruntime <span class="hljs-keyword">import</span> ORTModelForFeatureExtraction
<span class="hljs-comment"># Load the tokenizer and ONNX model from the local export directory</span>
tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">&#x27;bge_auto_opt_O4&#x27;</span>, model_max_length=<span class="hljs-number">512</span>)
ort_model = ORTModelForFeatureExtraction.from_pretrained(<span class="hljs-string">&#x27;bge_auto_opt_O4&#x27;</span>, provider=<span class="hljs-string">&quot;CUDAExecutionProvider&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1fs19fo">And let’s make a class that uses the tokenizer, ONNX Runtime (ORT) model and a SetFit model head.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> setfit.exporters.utils <span class="hljs-keyword">import</span> mean_pooling
<span class="hljs-keyword">class</span> <span class="hljs-title class_">OnnxSetFitModel</span>:
<span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, ort_model, tokenizer, model_head</span>):
self.ort_model = ort_model
self.tokenizer = tokenizer
self.model_head = model_head
<span class="hljs-keyword">def</span> <span class="hljs-title function_">predict</span>(<span class="hljs-params">self, inputs</span>):
encoded_inputs = self.tokenizer(
inputs, padding=<span class="hljs-literal">True</span>, truncation=<span class="hljs-literal">True</span>, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>
).to(self.ort_model.device)
outputs = self.ort_model(**encoded_inputs)
embeddings = mean_pooling(
outputs[<span class="hljs-string">&quot;last_hidden_state&quot;</span>], encoded_inputs[<span class="hljs-string">&quot;attention_mask&quot;</span>]
)
<span class="hljs-keyword">return</span> self.model_head.predict(embeddings.cpu())
<span class="hljs-keyword">def</span> <span class="hljs-title function_">__call__</span>(<span class="hljs-params">self, inputs</span>):
<span class="hljs-keyword">return</span> self.predict(inputs)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-v502qs">We can initialize this model like so:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->model = SetFitModel.from_pretrained(<span class="hljs-string">&quot;moshew/bge-small-en-v1.5_setfit-sst2-english&quot;</span>)
onnx_setfit_model = OnnxSetFitModel(ort_model, tokenizer, model.model_head)
<span class="hljs-comment"># Perform inference</span>
onnx_setfit_model(test_dataset[<span class="hljs-string">&quot;text&quot;</span>][:<span class="hljs-number">2</span>])<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-function"><span class="hljs-title">array</span><span class="hljs-params">([<span class="hljs-number">0</span>, <span class="hljs-number">0</span>])</span></span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1tr4hv1">Time to benchmark this ONNX model.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" 
aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pb = OnnxPerformanceBenchmark(
onnx_setfit_model,
test_dataset,
<span class="hljs-string">&quot;bge-small (optimum ONNX)&quot;</span>,
model_path=<span class="hljs-string">&quot;bge_auto_opt_O4/model.onnx&quot;</span>,
)
perf_metrics.update(pb.run_benchmark())<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->plot_metrics(perf_metrics)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1yk7uwi"><img src="https://github.com/huggingface/setfit/assets/37621491/9907ec1d-d4c6-431d-8695-1adc4247a576" alt="setfit_onnx"></p> <p data-svelte-h="svelte-ph4797">By applying ONNX, we were able to improve the latency from 13.43ms per sample to 2.19ms per sample, for a speedup of 6.13x!</p> <p data-svelte-h="svelte-rbx6kg">For further improvements, we recommend increasing the inference batch size, as this may also heavily improve the throughput. 
For example, setting the batch size to 128 reduces the latency further down to 0.3ms, and down to 0.2ms at a batch size of 2048.</p> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/setfit/blob/main/docs/source/en/tutorials/onnx.mdx" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p>
<script>
{
  // Page-level runtime configuration, published as an implicit global
  // (presumably consumed by the client modules imported below).
  __sveltekit_j4bv9w = {
    assets: "/docs/setfit/pr_618/en",
    base: "/docs/setfit/pr_618/en",
    env: {}
  };

  // Capture the mount target synchronously: document.currentScript is only
  // available while this script is executing, not inside the promise callback.
  const mountTarget = document.currentScript.parentElement;

  // Options handed to the client entry point once both modules have loaded.
  const startOptions = {
    node_ids: [0, 21],
    data: [null, null],
    form: null,
    error: null
  };

  // Kick off both dynamic imports before waiting on either of them.
  const pendingModules = [
    import("/docs/setfit/pr_618/en/_app/immutable/entry/start.bb9a9e95.js"),
    import("/docs/setfit/pr_618/en/_app/immutable/entry/app.5270c2f1.js")
  ];

  Promise.all(pendingModules).then(function (loaded) {
    const kit = loaded[0];
    const app = loaded[1];
    kit.start(app, mountTarget, startOptions);
  });
}
</script>

Xet Storage Details

Size:
53.2 kB
·
Xet hash:
7b711366e0f5672edd43ee34d2d52aac184748ec3899c27da1f2486cd21f8ed4

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.