Buckets:
| <meta charset="utf-8" /><meta name="hf:doc:metadata" content="{"title":"Object detection","local":"object-detection","sections":[{"title":"Load the CPPE-5 dataset","local":"load-the-cppe-5-dataset","sections":[],"depth":2},{"title":"Preprocess the data","local":"preprocess-the-data","sections":[],"depth":2},{"title":"Preparing function to compute mAP","local":"preparing-function-to-compute-map","sections":[],"depth":2},{"title":"Training the detection model","local":"training-the-detection-model","sections":[],"depth":2},{"title":"Evaluate","local":"evaluate","sections":[],"depth":2},{"title":"Inference","local":"inference","sections":[],"depth":2}],"depth":1}"> | |
| <link href="/docs/transformers/pr_33913/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/entry/start.b67f883f.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/scheduler.25b97de1.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/singletons.62a184e0.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/index.e188933d.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/paths.51881b9e.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/entry/app.e436b1f2.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/index.d9030fc9.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/nodes/0.05e395f5.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/nodes/431.9e97c17b.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/Tip.baa67368.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/CodeBlock.e6cd0d95.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/DocNotebookDropdown.5ea6cb78.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/globals.7f7f1b26.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/EditOnGithub.91d95064.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"Object detection","local":"object-detection","sections":[{"title":"Load the CPPE-5 dataset","local":"load-the-cppe-5-dataset","sections":[],"depth":2},{"title":"Preprocess the data","local":"preprocess-the-data","sections":[],"depth":2},{"title":"Preparing function to compute mAP","local":"preparing-function-to-compute-map","sections":[],"depth":2},{"title":"Training the detection model","local":"training-the-detection-model","sections":[],"depth":2},{"title":"Evaluate","local":"evaluate","sections":[],"depth":2},{"title":"Inference","local":"inference","sections":[],"depth":2}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="object-detection" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#object-detection"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Object detection</span></h1> <div class="flex space-x-1 absolute z-10 right-0 top-0"> <div class="relative colab-dropdown "> <button class=" " 
type="button"> <img alt="Open In Colab" class="!m-0" src="https://colab.research.google.com/assets/colab-badge.svg"> </button> </div> <div class="relative colab-dropdown "> <button class=" " type="button"> <img alt="Open In Studio Lab" class="!m-0" src="https://studiolab.sagemaker.aws/studiolab.svg"> </button> </div></div> <p data-svelte-h="svelte-18pt6j0">Object detection is the computer vision task of detecting instances (such as humans, buildings, or cars) in an image. Object detection models receive an image as input and output | |
| coordinates of the bounding boxes and associated labels of the detected objects. An image can contain multiple objects, | |
| each with its own bounding box and a label (e.g. it can have a car and a building), and each object can | |
| be present in different parts of an image (e.g. the image can have several cars). | |
| This task is commonly used in autonomous driving for detecting things like pedestrians, road signs, and traffic lights. | |
| Other applications include counting objects in images, image search, and more.</p> <p data-svelte-h="svelte-1xy9go1">In this guide, you will learn how to:</p> <ol data-svelte-h="svelte-6qcuz8"><li>Finetune <a href="https://huggingface.co/docs/transformers/model_doc/detr" rel="nofollow">DETR</a>, a model that combines a convolutional | |
| backbone with an encoder-decoder Transformer, on the <a href="https://huggingface.co/datasets/cppe-5" rel="nofollow">CPPE-5</a> | |
| dataset.</li> <li>Use your finetuned model for inference.</li></ol> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-5wyiet">To see all architectures and checkpoints compatible with this task, we recommend checking the <a href="https://huggingface.co/tasks/object-detection" rel="nofollow">task-page</a></p></div> <p data-svelte-h="svelte-1c9nexd">Before you begin, make sure you have all the necessary libraries installed:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pip install -q datasets transformers accelerate timm | |
| pip install -q -U "albumentations>=1.4.5" torchmetrics pycocotools<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1hao5fk">You’ll use 🤗 Datasets to load a dataset from the Hugging Face Hub, 🤗 Transformers to train your model, | |
| and <code>albumentations</code> to augment the data.</p> <p data-svelte-h="svelte-1oee7b1">We encourage you to share your model with the community. Log in to your Hugging Face account to upload it to the Hub. | |
| When prompted, enter your token to log in:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> huggingface_hub <span class="hljs-keyword">import</span> notebook_login | |
| <span class="hljs-meta">>>> </span>notebook_login()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ip202a">To get started, we’ll define global constants, namely the model name and image size. For this tutorial, we’ll use the conditional DETR model due to its faster convergence. Feel free to select any object detection model available in the <code>transformers</code> library.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span>MODEL_NAME = <span class="hljs-string">"microsoft/conditional-detr-resnet-50"</span> <span class="hljs-comment"># or "facebook/detr-resnet-50"</span> | |
| <span class="hljs-meta">>>> </span>IMAGE_SIZE = <span class="hljs-number">480</span><!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="load-the-cppe-5-dataset" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#load-the-cppe-5-dataset"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Load the CPPE-5 dataset</span></h2> <p data-svelte-h="svelte-9es7uf">The <a href="https://huggingface.co/datasets/cppe-5" rel="nofollow">CPPE-5 dataset</a> contains images with | |
| annotations identifying medical personal protective equipment (PPE) in the context of the COVID-19 pandemic.</p> <p data-svelte-h="svelte-11vhiu">Start by loading the dataset and creating a <code>validation</code> split from <code>train</code>:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset | |
| <span class="hljs-meta">>>> </span>cppe5 = load_dataset(<span class="hljs-string">"cppe-5"</span>) | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">if</span> <span class="hljs-string">"validation"</span> <span class="hljs-keyword">not</span> <span class="hljs-keyword">in</span> cppe5: | |
| <span class="hljs-meta">... </span>    split = cppe5[<span class="hljs-string">"train"</span>].train_test_split(<span class="hljs-number">0.15</span>, seed=<span class="hljs-number">1337</span>) | |
| <span class="hljs-meta">... </span>    cppe5[<span class="hljs-string">"train"</span>] = split[<span class="hljs-string">"train"</span>] | |
| <span class="hljs-meta">... </span>    cppe5[<span class="hljs-string">"validation"</span>] = split[<span class="hljs-string">"test"</span>] | |
| <span class="hljs-meta">>>> </span>cppe5 | |
| DatasetDict({ | |
| train: Dataset({ | |
| features: [<span class="hljs-string">'image_id'</span>, <span class="hljs-string">'image'</span>, <span class="hljs-string">'width'</span>, <span class="hljs-string">'height'</span>, <span class="hljs-string">'objects'</span>], | |
| num_rows: <span class="hljs-number">850</span> | |
| }) | |
| test: Dataset({ | |
| features: [<span class="hljs-string">'image_id'</span>, <span class="hljs-string">'image'</span>, <span class="hljs-string">'width'</span>, <span class="hljs-string">'height'</span>, <span class="hljs-string">'objects'</span>], | |
| num_rows: <span class="hljs-number">29</span> | |
| }) | |
| validation: Dataset({ | |
| features: [<span class="hljs-string">'image_id'</span>, <span class="hljs-string">'image'</span>, <span class="hljs-string">'width'</span>, <span class="hljs-string">'height'</span>, <span class="hljs-string">'objects'</span>], | |
| num_rows: <span class="hljs-number">150</span> | |
| }) | |
| })<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-rvn8y2">You’ll see that this dataset has 1000 images for train and validation sets and a test set with 29 images.</p> <p data-svelte-h="svelte-4bevpw">To get familiar with the data, explore what the examples look like.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span>cppe5[<span class="hljs-string">"train"</span>][<span class="hljs-number">0</span>] | |
| { | |
| <span class="hljs-string">'image_id'</span>: <span class="hljs-number">366</span>, | |
| <span class="hljs-string">'image'</span>: <PIL.PngImagePlugin.PngImageFile image mode=RGBA size=500x290>, | |
| <span class="hljs-string">'width'</span>: <span class="hljs-number">500</span>, | |
| <span class="hljs-string">'height'</span>: <span class="hljs-number">500</span>, | |
| <span class="hljs-string">'objects'</span>: { | |
| <span class="hljs-string">'id'</span>: [<span class="hljs-number">1932</span>, <span class="hljs-number">1933</span>, <span class="hljs-number">1934</span>], | |
| <span class="hljs-string">'area'</span>: [<span class="hljs-number">27063</span>, <span class="hljs-number">34200</span>, <span class="hljs-number">32431</span>], | |
| <span class="hljs-string">'bbox'</span>: [[<span class="hljs-number">29.0</span>, <span class="hljs-number">11.0</span>, <span class="hljs-number">97.0</span>, <span class="hljs-number">279.0</span>], | |
| [<span class="hljs-number">201.0</span>, <span class="hljs-number">1.0</span>, <span class="hljs-number">120.0</span>, <span class="hljs-number">285.0</span>], | |
| [<span class="hljs-number">382.0</span>, <span class="hljs-number">0.0</span>, <span class="hljs-number">113.0</span>, <span class="hljs-number">287.0</span>]], | |
| <span class="hljs-string">'category'</span>: [<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>] | |
| } | |
| }<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-m0t76z">The examples in the dataset have the following fields:</p> <ul data-svelte-h="svelte-1tn0avh"><li><code>image_id</code>: the example image id</li> <li><code>image</code>: a <code>PIL.Image.Image</code> object containing the image</li> <li><code>width</code>: width of the image</li> <li><code>height</code>: height of the image</li> <li><code>objects</code>: a dictionary containing bounding box metadata for the objects in the image:<ul><li><code>id</code>: the annotation id</li> <li><code>area</code>: the area of the bounding box</li> <li><code>bbox</code>: the object’s bounding box (in the <a href="https://albumentations.ai/docs/getting_started/bounding_boxes_augmentation/#coco" rel="nofollow">COCO format</a> )</li> <li><code>category</code>: the object’s category, with possible values including <code>Coverall (0)</code>, <code>Face_Shield (1)</code>, <code>Gloves (2)</code>, <code>Goggles (3)</code> and <code>Mask (4)</code></li></ul></li></ul> <p data-svelte-h="svelte-edp0uk">You may notice that the <code>bbox</code> field follows the COCO format, which is the format that the DETR model expects. | |
| However, the grouping of the fields inside <code>objects</code> differs from the annotation format DETR requires. You will | |
| need to apply some preprocessing transformations before using this data for training.</p> <p data-svelte-h="svelte-1o4zzv7">To get an even better understanding of the data, visualize an example in the dataset.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">import</span> os | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> PIL <span class="hljs-keyword">import</span> Image, ImageDraw | |
| <span class="hljs-meta">>>> </span>image = cppe5[<span class="hljs-string">"train"</span>][<span class="hljs-number">2</span>][<span class="hljs-string">"image"</span>] | |
| <span class="hljs-meta">>>> </span>annotations = cppe5[<span class="hljs-string">"train"</span>][<span class="hljs-number">2</span>][<span class="hljs-string">"objects"</span>] | |
| <span class="hljs-meta">>>> </span>draw = ImageDraw.Draw(image) | |
| <span class="hljs-meta">>>> </span>categories = cppe5[<span class="hljs-string">"train"</span>].features[<span class="hljs-string">"objects"</span>].feature[<span class="hljs-string">"category"</span>].names | |
| <span class="hljs-meta">>>> </span>id2label = {index: x <span class="hljs-keyword">for</span> index, x <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(categories, start=<span class="hljs-number">0</span>)} | |
| <span class="hljs-meta">>>> </span>label2id = {v: k <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> id2label.items()} | |
| <span class="hljs-meta">>>> </span>width, height = image.size | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-built_in">len</span>(annotations[<span class="hljs-string">"id"</span>])): | |
| <span class="hljs-meta">... </span>    box = annotations[<span class="hljs-string">"bbox"</span>][i] | |
| <span class="hljs-meta">... </span>    class_idx = annotations[<span class="hljs-string">"category"</span>][i] | |
| <span class="hljs-meta">... </span>    x, y, w, h = <span class="hljs-built_in">tuple</span>(box) | |
| <span class="hljs-meta">... </span>    <span class="hljs-comment"># Check if coordinates are normalized or not</span> | |
| <span class="hljs-meta">... </span>    <span class="hljs-keyword">if</span> <span class="hljs-built_in">max</span>(box) > <span class="hljs-number">1.0</span>: | |
| <span class="hljs-meta">... </span>        <span class="hljs-comment"># Coordinates are un-normalized, no need to re-scale them</span> | |
| <span class="hljs-meta">... </span>        x1, y1 = <span class="hljs-built_in">int</span>(x), <span class="hljs-built_in">int</span>(y) | |
| <span class="hljs-meta">... </span>        x2, y2 = <span class="hljs-built_in">int</span>(x + w), <span class="hljs-built_in">int</span>(y + h) | |
| <span class="hljs-meta">... </span>    <span class="hljs-keyword">else</span>: | |
| <span class="hljs-meta">... </span>        <span class="hljs-comment"># Coordinates are normalized, re-scale them to pixels</span> | |
| <span class="hljs-meta">... </span>        x1 = <span class="hljs-built_in">int</span>(x * width) | |
| <span class="hljs-meta">... </span>        y1 = <span class="hljs-built_in">int</span>(y * height) | |
| <span class="hljs-meta">... </span>        x2 = <span class="hljs-built_in">int</span>((x + w) * width) | |
| <span class="hljs-meta">... </span>        y2 = <span class="hljs-built_in">int</span>((y + h) * height) | |
| <span class="hljs-meta">... </span>    draw.rectangle((x1, y1, x2, y2), outline=<span class="hljs-string">"red"</span>, width=<span class="hljs-number">1</span>) | |
| <span class="hljs-meta">... </span>    draw.text((x1, y1), id2label[class_idx], fill=<span class="hljs-string">"white"</span>) | |
| <span class="hljs-meta">>>> </span>image<!-- HTML_TAG_END --></pre></div> <div class="flex justify-center" data-svelte-h="svelte-7wscra"><img src="https://i.imgur.com/oVQb9SF.png" alt="CPPE-5 Image Example"></div> <p data-svelte-h="svelte-s62j5i">To visualize the bounding boxes with associated labels, you can get the labels from the dataset’s metadata, specifically | |
| the <code>category</code> field. | |
| You’ll also want to create dictionaries that map a label id to a label class (<code>id2label</code>) and the other way around (<code>label2id</code>). | |
| You can use them later when setting up the model. Including these maps will make your model reusable by others if you share | |
| it on the Hugging Face Hub. Note that the drawing code above assumes the boxes are in <code>COCO</code> format <code>(x_min, y_min, width, height)</code>. It has to be adjusted to work with other formats such as <code>(x_min, y_min, x_max, y_max)</code>.</p> <p data-svelte-h="svelte-1p1sy04">As a final step of getting familiar with the data, explore it for potential issues. One common problem with datasets for | |
| object detection is bounding boxes that “stretch” beyond the edge of the image. Such “runaway” bounding boxes can raise | |
| errors during training and should be addressed. There are a few examples with this issue in this dataset. | |
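</p> <p>A minimal sketch of how such boxes could be located and clipped by hand, assuming COCO-style <code>(x_min, y_min, width, height)</code> boxes in pixel units. The helpers <code>find_runaway_boxes</code> and <code>clip_box</code> and the sample boxes are hypothetical and are not part of this guide’s pipeline:</p>

```python
def find_runaway_boxes(bboxes, image_width, image_height):
    """Return indices of COCO boxes (x_min, y_min, width, height) that extend past the image."""
    runaway = []
    for index, (x, y, w, h) in enumerate(bboxes):
        if x < 0 or y < 0 or x + w > image_width or y + h > image_height:
            runaway.append(index)
    return runaway


def clip_box(bbox, image_width, image_height):
    """Clip one COCO box to the image bounds, keeping the COCO layout."""
    x, y, w, h = bbox
    x_min = min(max(x, 0.0), image_width)
    y_min = min(max(y, 0.0), image_height)
    x_max = min(max(x + w, 0.0), image_width)
    y_max = min(max(y + h, 0.0), image_height)
    return [x_min, y_min, x_max - x_min, y_max - y_min]


# Hypothetical boxes for a 500x290 image; the second one crosses the right edge.
boxes = [[29.0, 11.0, 97.0, 279.0], [450.0, 10.0, 80.0, 100.0]]
print(find_runaway_boxes(boxes, 500, 290))  # [1]
print(clip_box(boxes[1], 500, 290))  # [450.0, 10.0, 50.0, 100.0]
```

<p>Clipping keeps the annotation but shrinks it to the region that is actually inside the image.</p> <p>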
| To keep things simple in this guide, we will set <code>clip=True</code> for <code>BboxParams</code> in transformations below.</p> <h2 class="relative group"><a id="preprocess-the-data" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#preprocess-the-data"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Preprocess the data</span></h2> <p data-svelte-h="svelte-1rs0ois">To finetune a model, you must preprocess the data you plan to use to match precisely the approach used for the pre-trained model. | |
| <a href="/docs/transformers/pr_33913/en/model_doc/auto#transformers.AutoImageProcessor">AutoImageProcessor</a> takes care of processing image data to create <code>pixel_values</code>, <code>pixel_mask</code>, and | |
| <code>labels</code> that a DETR model can train with. The image processor has some attributes that you won’t have to worry about:</p> <ul data-svelte-h="svelte-9xz2l6"><li><code>image_mean = [0.485, 0.456, 0.406]</code></li> <li><code>image_std = [0.229, 0.224, 0.225]</code></li></ul> <p data-svelte-h="svelte-1uiy3io">These are the mean and standard deviation used to normalize images during model pre-training. These values are crucial | |
| to replicate when doing inference or finetuning a pre-trained image model.</p> <p data-svelte-h="svelte-1ipxopl">Instantiate the image processor from the same checkpoint as the model you want to finetune.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoImageProcessor | |
| <span class="hljs-meta">>>> </span>MAX_SIZE = IMAGE_SIZE | |
| <span class="hljs-meta">>>> </span>image_processor = AutoImageProcessor.from_pretrained( | |
| <span class="hljs-meta">... </span>    MODEL_NAME, | |
| <span class="hljs-meta">... </span>    do_resize=<span class="hljs-literal">True</span>, | |
| <span class="hljs-meta">... </span>    size={<span class="hljs-string">"max_height"</span>: MAX_SIZE, <span class="hljs-string">"max_width"</span>: MAX_SIZE}, | |
| <span class="hljs-meta">... </span>    do_pad=<span class="hljs-literal">True</span>, | |
| <span class="hljs-meta">... </span>    pad_size={<span class="hljs-string">"height"</span>: MAX_SIZE, <span class="hljs-string">"width"</span>: MAX_SIZE}, | |
| <span class="hljs-meta">... </span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-16qnm9y">Before passing the images to the <code>image_processor</code>, apply two preprocessing transformations to the dataset:</p> <ul data-svelte-h="svelte-pqfhzt"><li>Augmenting images</li> <li>Reformatting annotations to meet DETR expectations</li></ul> <p data-svelte-h="svelte-1ekhp6u">First, to make sure the model does not overfit on the training data, you can apply image augmentation with any data augmentation library. Here we use <a href="https://albumentations.ai/docs/" rel="nofollow">Albumentations</a>. | |
| This library ensures that transformations affect the image and update the bounding boxes accordingly. | |
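</p> <p>To illustrate what updating a bounding box means for one geometric transform, here is a plain-Python sketch of the arithmetic behind a horizontal flip of a COCO-format box. It shows the geometry only and is not the library’s implementation; the helper <code>hflip_coco_box</code> is a hypothetical name:</p>

```python
def hflip_coco_box(bbox, image_width):
    """Mirror a COCO box (x_min, y_min, width, height) around the vertical image axis."""
    x, y, w, h = bbox
    # After the flip, the old right edge (x + w) becomes the new left edge.
    return [image_width - (x + w), y, w, h]


# First box of the training example shown earlier, in a 500-pixel-wide image.
box = [29.0, 11.0, 97.0, 279.0]
flipped = hflip_coco_box(box, image_width=500)
print(flipped)  # [374.0, 11.0, 97.0, 279.0]
```

<p>Flipping twice returns the original box, which is a quick sanity check for any hand-written box transform.</p> <p>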
| The 🤗 Datasets library documentation has a detailed <a href="https://huggingface.co/docs/datasets/object_detection" rel="nofollow">guide on how to augment images for object detection</a>, | |
| and it uses the exact same dataset as an example. Apply some geometric and color transformations to the image. For additional augmentation options, explore the <a href="https://huggingface.co/spaces/qubvel-hf/albumentations-demo" rel="nofollow">Albumentations Demo Space</a>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">import</span> albumentations <span class="hljs-keyword">as</span> A | |
| <span class="hljs-meta">>>> </span>train_augment_and_transform = A.Compose( | |
| <span class="hljs-meta">... </span> [ | |
| <span class="hljs-meta">... </span> A.Perspective(p=<span class="hljs-number">0.1</span>), | |
| <span class="hljs-meta">... </span> A.HorizontalFlip(p=<span class="hljs-number">0.5</span>), | |
| <span class="hljs-meta">... </span> A.RandomBrightnessContrast(p=<span class="hljs-number">0.5</span>), | |
| <span class="hljs-meta">... </span> A.HueSaturationValue(p=<span class="hljs-number">0.1</span>), | |
| <span class="hljs-meta">... </span> ], | |
| <span class="hljs-meta">... </span> bbox_params=A.BboxParams(<span class="hljs-built_in">format</span>=<span class="hljs-string">"coco"</span>, label_fields=[<span class="hljs-string">"category"</span>], clip=<span class="hljs-literal">True</span>, min_area=<span class="hljs-number">25</span>), | |
| <span class="hljs-meta">... </span>) | |
| <span class="hljs-meta">>>> </span>validation_transform = A.Compose( | |
| <span class="hljs-meta">... </span> [A.NoOp()], | |
| <span class="hljs-meta">... </span> bbox_params=A.BboxParams(<span class="hljs-built_in">format</span>=<span class="hljs-string">"coco"</span>, label_fields=[<span class="hljs-string">"category"</span>], clip=<span class="hljs-literal">True</span>), | |
| <span class="hljs-meta">... </span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-bsjzql">The <code>image_processor</code> expects the annotations to be in the following format: <code>{'image_id': int, 'annotations': List[Dict]}</code>, | |
| where each dictionary is a COCO object annotation. Let’s add a function to reformat annotations for a single example:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">def</span> <span class="hljs-title function_">format_image_annotations_as_coco</span>(<span class="hljs-params">image_id, categories, areas, bboxes</span>): | |
| <span class="hljs-meta">... </span> <span class="hljs-string">"""Format one set of image annotations to the COCO format | |
| <span class="hljs-meta">... </span> Args: | |
| <span class="hljs-meta">... </span> image_id (int): image id, e.g. 688 | |
| <span class="hljs-meta">... </span> categories (List[int]): list of categories/class labels corresponding to provided bounding boxes | |
| <span class="hljs-meta">... </span> areas (List[float]): list of corresponding areas to provided bounding boxes | |
| <span class="hljs-meta">... </span> bboxes (List[Tuple[float]]): list of bounding boxes provided in COCO format | |
| <span class="hljs-meta">... </span> ([x_min, y_min, width, height] in absolute coordinates) | |
| <span class="hljs-meta">... </span> Returns: | |
| <span class="hljs-meta">... </span> dict: { | |
| <span class="hljs-meta">... </span> "image_id": image id, | |
| <span class="hljs-meta">... </span> "annotations": list of formatted annotations | |
| <span class="hljs-meta">... </span> } | |
| <span class="hljs-meta">... </span> """</span> | |
| <span class="hljs-meta">... </span> annotations = [] | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">for</span> category, area, bbox <span class="hljs-keyword">in</span> <span class="hljs-built_in">zip</span>(categories, areas, bboxes): | |
| <span class="hljs-meta">... </span> formatted_annotation = { | |
| <span class="hljs-meta">... </span> <span class="hljs-string">"image_id"</span>: image_id, | |
| <span class="hljs-meta">... </span> <span class="hljs-string">"category_id"</span>: category, | |
| <span class="hljs-meta">... </span> <span class="hljs-string">"iscrowd"</span>: <span class="hljs-number">0</span>, | |
| <span class="hljs-meta">... </span> <span class="hljs-string">"area"</span>: area, | |
| <span class="hljs-meta">... </span> <span class="hljs-string">"bbox"</span>: <span class="hljs-built_in">list</span>(bbox), | |
| <span class="hljs-meta">... </span> } | |
| <span class="hljs-meta">... </span> annotations.append(formatted_annotation) | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">return</span> { | |
| <span class="hljs-meta">... </span> <span class="hljs-string">"image_id"</span>: image_id, | |
| <span class="hljs-meta">... </span> <span class="hljs-string">"annotations"</span>: annotations, | |
| <span class="hljs-meta">... </span> } | |
| <!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-16yruro">Now you can combine the image and annotation transformations to use on a batch of examples:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">def</span> <span class="hljs-title function_">augment_and_transform_batch</span>(<span class="hljs-params">examples, transform, image_processor, return_pixel_mask=<span class="hljs-literal">False</span></span>): | |
| <span class="hljs-meta">... </span> <span class="hljs-string">"""Apply augmentations and format annotations in COCO format for object detection task"""</span> | |
| <span class="hljs-meta">... </span> images = [] | |
| <span class="hljs-meta">... </span> annotations = [] | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">for</span> image_id, image, objects <span class="hljs-keyword">in</span> <span class="hljs-built_in">zip</span>(examples[<span class="hljs-string">"image_id"</span>], examples[<span class="hljs-string">"image"</span>], examples[<span class="hljs-string">"objects"</span>]): | |
| <span class="hljs-meta">... </span> image = np.array(image.convert(<span class="hljs-string">"RGB"</span>)) | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># apply augmentations</span> | |
| <span class="hljs-meta">... </span> output = transform(image=image, bboxes=objects[<span class="hljs-string">"bbox"</span>], category=objects[<span class="hljs-string">"category"</span>]) | |
| <span class="hljs-meta">... </span> images.append(output[<span class="hljs-string">"image"</span>]) | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># format annotations in COCO format</span> | |
| <span class="hljs-meta">... </span> formatted_annotations = format_image_annotations_as_coco( | |
| <span class="hljs-meta">... </span> image_id, output[<span class="hljs-string">"category"</span>], objects[<span class="hljs-string">"area"</span>], output[<span class="hljs-string">"bboxes"</span>] | |
| <span class="hljs-meta">... </span> ) | |
| <span class="hljs-meta">... </span> annotations.append(formatted_annotations) | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># Apply the image processor transformations: resizing, rescaling, normalization</span> | |
| <span class="hljs-meta">... </span> result = image_processor(images=images, annotations=annotations, return_tensors=<span class="hljs-string">"pt"</span>) | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> <span class="hljs-keyword">not</span> return_pixel_mask: | |
| <span class="hljs-meta">... </span> result.pop(<span class="hljs-string">"pixel_mask"</span>, <span class="hljs-literal">None</span>) | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">return</span> result<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-gv4k4k">Apply this preprocessing function to the entire dataset using the 🤗 Datasets <a href="https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset.with_transform" rel="nofollow">with_transform</a> method. This method applies | |
| transformations on the fly when you load an element of the dataset.</p> <p data-svelte-h="svelte-1o4lbgk">At this point, you can check what an example from the dataset looks like after the transformations. You should see a tensor | |
| with <code>pixel_values</code> and a <code>labels</code> dictionary. The <code>pixel_mask</code> is dropped because <code>return_pixel_mask</code> defaults to <code>False</code>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> functools <span class="hljs-keyword">import</span> partial | |
| <span class="hljs-meta">>>> </span><span class="hljs-comment"># Make transform functions for batch and apply for dataset splits</span> | |
| <span class="hljs-meta">>>> </span>train_transform_batch = partial( | |
| <span class="hljs-meta">... </span> augment_and_transform_batch, transform=train_augment_and_transform, image_processor=image_processor | |
| <span class="hljs-meta">... </span>) | |
| <span class="hljs-meta">>>> </span>validation_transform_batch = partial( | |
| <span class="hljs-meta">... </span> augment_and_transform_batch, transform=validation_transform, image_processor=image_processor | |
| <span class="hljs-meta">... </span>) | |
| <span class="hljs-meta">>>> </span>cppe5[<span class="hljs-string">"train"</span>] = cppe5[<span class="hljs-string">"train"</span>].with_transform(train_transform_batch) | |
| <span class="hljs-meta">>>> </span>cppe5[<span class="hljs-string">"validation"</span>] = cppe5[<span class="hljs-string">"validation"</span>].with_transform(validation_transform_batch) | |
| <span class="hljs-meta">>>> </span>cppe5[<span class="hljs-string">"test"</span>] = cppe5[<span class="hljs-string">"test"</span>].with_transform(validation_transform_batch) | |
| <span class="hljs-meta">>>> </span>cppe5[<span class="hljs-string">"train"</span>][<span class="hljs-number">15</span>] | |
| {<span class="hljs-string">'pixel_values'</span>: tensor([[[ <span class="hljs-number">1.9235</span>, <span class="hljs-number">1.9407</span>, <span class="hljs-number">1.9749</span>, ..., -<span class="hljs-number">0.7822</span>, -<span class="hljs-number">0.7479</span>, -<span class="hljs-number">0.6965</span>], | |
| [ <span class="hljs-number">1.9578</span>, <span class="hljs-number">1.9749</span>, <span class="hljs-number">1.9920</span>, ..., -<span class="hljs-number">0.7993</span>, -<span class="hljs-number">0.7650</span>, -<span class="hljs-number">0.7308</span>], | |
| [ <span class="hljs-number">2.0092</span>, <span class="hljs-number">2.0092</span>, <span class="hljs-number">2.0263</span>, ..., -<span class="hljs-number">0.8507</span>, -<span class="hljs-number">0.8164</span>, -<span class="hljs-number">0.7822</span>], | |
| ..., | |
| [ <span class="hljs-number">0.0741</span>, <span class="hljs-number">0.0741</span>, <span class="hljs-number">0.0741</span>, ..., <span class="hljs-number">0.0741</span>, <span class="hljs-number">0.0741</span>, <span class="hljs-number">0.0741</span>], | |
| [ <span class="hljs-number">0.0741</span>, <span class="hljs-number">0.0741</span>, <span class="hljs-number">0.0741</span>, ..., <span class="hljs-number">0.0741</span>, <span class="hljs-number">0.0741</span>, <span class="hljs-number">0.0741</span>], | |
| [ <span class="hljs-number">0.0741</span>, <span class="hljs-number">0.0741</span>, <span class="hljs-number">0.0741</span>, ..., <span class="hljs-number">0.0741</span>, <span class="hljs-number">0.0741</span>, <span class="hljs-number">0.0741</span>]], | |
| [[ <span class="hljs-number">1.6232</span>, <span class="hljs-number">1.6408</span>, <span class="hljs-number">1.6583</span>, ..., <span class="hljs-number">0.8704</span>, <span class="hljs-number">1.0105</span>, <span class="hljs-number">1.1331</span>], | |
| [ <span class="hljs-number">1.6408</span>, <span class="hljs-number">1.6583</span>, <span class="hljs-number">1.6758</span>, ..., <span class="hljs-number">0.8529</span>, <span class="hljs-number">0.9930</span>, <span class="hljs-number">1.0980</span>], | |
| [ <span class="hljs-number">1.6933</span>, <span class="hljs-number">1.6933</span>, <span class="hljs-number">1.7108</span>, ..., <span class="hljs-number">0.8179</span>, <span class="hljs-number">0.9580</span>, <span class="hljs-number">1.0630</span>], | |
| ..., | |
| [ <span class="hljs-number">0.2052</span>, <span class="hljs-number">0.2052</span>, <span class="hljs-number">0.2052</span>, ..., <span class="hljs-number">0.2052</span>, <span class="hljs-number">0.2052</span>, <span class="hljs-number">0.2052</span>], | |
| [ <span class="hljs-number">0.2052</span>, <span class="hljs-number">0.2052</span>, <span class="hljs-number">0.2052</span>, ..., <span class="hljs-number">0.2052</span>, <span class="hljs-number">0.2052</span>, <span class="hljs-number">0.2052</span>], | |
| [ <span class="hljs-number">0.2052</span>, <span class="hljs-number">0.2052</span>, <span class="hljs-number">0.2052</span>, ..., <span class="hljs-number">0.2052</span>, <span class="hljs-number">0.2052</span>, <span class="hljs-number">0.2052</span>]], | |
| [[ <span class="hljs-number">1.8905</span>, <span class="hljs-number">1.9080</span>, <span class="hljs-number">1.9428</span>, ..., -<span class="hljs-number">0.1487</span>, -<span class="hljs-number">0.0964</span>, -<span class="hljs-number">0.0615</span>], | |
| [ <span class="hljs-number">1.9254</span>, <span class="hljs-number">1.9428</span>, <span class="hljs-number">1.9603</span>, ..., -<span class="hljs-number">0.1661</span>, -<span class="hljs-number">0.1138</span>, -<span class="hljs-number">0.0790</span>], | |
| [ <span class="hljs-number">1.9777</span>, <span class="hljs-number">1.9777</span>, <span class="hljs-number">1.9951</span>, ..., -<span class="hljs-number">0.2010</span>, -<span class="hljs-number">0.1138</span>, -<span class="hljs-number">0.0790</span>], | |
| ..., | |
| [ <span class="hljs-number">0.4265</span>, <span class="hljs-number">0.4265</span>, <span class="hljs-number">0.4265</span>, ..., <span class="hljs-number">0.4265</span>, <span class="hljs-number">0.4265</span>, <span class="hljs-number">0.4265</span>], | |
| [ <span class="hljs-number">0.4265</span>, <span class="hljs-number">0.4265</span>, <span class="hljs-number">0.4265</span>, ..., <span class="hljs-number">0.4265</span>, <span class="hljs-number">0.4265</span>, <span class="hljs-number">0.4265</span>], | |
| [ <span class="hljs-number">0.4265</span>, <span class="hljs-number">0.4265</span>, <span class="hljs-number">0.4265</span>, ..., <span class="hljs-number">0.4265</span>, <span class="hljs-number">0.4265</span>, <span class="hljs-number">0.4265</span>]]]), | |
| <span class="hljs-string">'labels'</span>: {<span class="hljs-string">'image_id'</span>: tensor([<span class="hljs-number">688</span>]), <span class="hljs-string">'class_labels'</span>: tensor([<span class="hljs-number">3</span>, <span class="hljs-number">4</span>, <span class="hljs-number">2</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>]), <span class="hljs-string">'boxes'</span>: tensor([[<span class="hljs-number">0.4700</span>, <span class="hljs-number">0.1933</span>, <span class="hljs-number">0.1467</span>, <span class="hljs-number">0.0767</span>], | |
| [<span class="hljs-number">0.4858</span>, <span class="hljs-number">0.2600</span>, <span class="hljs-number">0.1150</span>, <span class="hljs-number">0.1000</span>], | |
| [<span class="hljs-number">0.4042</span>, <span class="hljs-number">0.4517</span>, <span class="hljs-number">0.1217</span>, <span class="hljs-number">0.1300</span>], | |
| [<span class="hljs-number">0.4242</span>, <span class="hljs-number">0.3217</span>, <span class="hljs-number">0.3617</span>, <span class="hljs-number">0.5567</span>], | |
| [<span class="hljs-number">0.6617</span>, <span class="hljs-number">0.4033</span>, <span class="hljs-number">0.5400</span>, <span class="hljs-number">0.4533</span>]]), <span class="hljs-string">'area'</span>: tensor([ <span class="hljs-number">4048.</span>, <span class="hljs-number">4140.</span>, <span class="hljs-number">5694.</span>, <span class="hljs-number">72478.</span>, <span class="hljs-number">88128.</span>]), <span class="hljs-string">'iscrowd'</span>: tensor([<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>]), <span class="hljs-string">'orig_size'</span>: tensor([<span class="hljs-number">480</span>, <span class="hljs-number">480</span>])}}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ghsv74">You have successfully augmented the individual images and prepared their annotations. However, preprocessing isn’t | |
| complete yet. In the final step, create a custom <code>collate_fn</code> to batch images together. | |
| The image processor has already padded every image (now <code>pixel_values</code>) to the same size, so the function only needs to stack them and, when it is present, the matching <code>pixel_mask</code>, which the processor creates | |
| to indicate which pixels are real (1) and which are padding (0).</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">def</span> <span class="hljs-title function_">collate_fn</span>(<span class="hljs-params">batch</span>): | |
| <span class="hljs-meta">... </span> data = {} | |
| <span class="hljs-meta">... </span> data[<span class="hljs-string">"pixel_values"</span>] = torch.stack([x[<span class="hljs-string">"pixel_values"</span>] <span class="hljs-keyword">for</span> x <span class="hljs-keyword">in</span> batch]) | |
| <span class="hljs-meta">... </span> data[<span class="hljs-string">"labels"</span>] = [x[<span class="hljs-string">"labels"</span>] <span class="hljs-keyword">for</span> x <span class="hljs-keyword">in</span> batch] | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> <span class="hljs-string">"pixel_mask"</span> <span class="hljs-keyword">in</span> batch[<span class="hljs-number">0</span>]: | |
| <span class="hljs-meta">... </span> data[<span class="hljs-string">"pixel_mask"</span>] = torch.stack([x[<span class="hljs-string">"pixel_mask"</span>] <span class="hljs-keyword">for</span> x <span class="hljs-keyword">in</span> batch]) | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">return</span> data | |
| <!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="preparing-function-to-compute-map" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#preparing-function-to-compute-map"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Preparing function to compute mAP</span></h2> <p data-svelte-h="svelte-1vweb7y">Object detection models are commonly evaluated with a set of <a href="https://cocodataset.org/#detection-eval">COCO-style metrics</a>. We will use <code>torchmetrics</code> to compute the <code>mAP</code> (mean average precision) and <code>mAR</code> (mean average recall) metrics, and wrap the computation in a <code>compute_metrics</code> function so that <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a> can use it for evaluation.</p> <p data-svelte-h="svelte-ceh07m">The intermediate box format used for training is <code>YOLO</code> (normalized), but we will compute metrics for boxes in <code>Pascal VOC</code> (absolute) format so that box areas are handled correctly. 
Let’s define a function that converts bounding boxes to <code>Pascal VOC</code> format:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> transformers.image_transforms <span class="hljs-keyword">import</span> center_to_corners_format | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">def</span> <span class="hljs-title function_">convert_bbox_yolo_to_pascal</span>(<span class="hljs-params">boxes, image_size</span>): | |
| <span class="hljs-meta">... </span> <span class="hljs-string">""" | |
| <span class="hljs-meta">... </span> Convert bounding boxes from YOLO format (x_center, y_center, width, height) in range [0, 1] | |
| <span class="hljs-meta">... </span> to Pascal VOC format (x_min, y_min, x_max, y_max) in absolute coordinates. | |
| <span class="hljs-meta">... </span> Args: | |
| <span class="hljs-meta">... </span> boxes (torch.Tensor): Bounding boxes in YOLO format | |
| <span class="hljs-meta">... </span> image_size (Tuple[int, int]): Image size in format (height, width) | |
| <span class="hljs-meta">... </span> Returns: | |
| <span class="hljs-meta">... </span> torch.Tensor: Bounding boxes in Pascal VOC format (x_min, y_min, x_max, y_max) | |
| <span class="hljs-meta">... </span> """</span> | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># convert center to corners format</span> | |
| <span class="hljs-meta">... </span> boxes = center_to_corners_format(boxes) | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># convert to absolute coordinates</span> | |
| <span class="hljs-meta">... </span> height, width = image_size | |
| <span class="hljs-meta">... </span> boxes = boxes * torch.tensor([[width, height, width, height]]) | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">return</span> boxes<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-izy56o">Then, in the <code>compute_metrics</code> function, we collect the <code>predicted</code> and <code>target</code> bounding boxes, scores, and labels from the evaluation loop results and pass them to the scoring function.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> dataclasses <span class="hljs-keyword">import</span> dataclass | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> torchmetrics.detection.mean_ap <span class="hljs-keyword">import</span> MeanAveragePrecision | |
| <span class="hljs-meta">>>> </span>@dataclass | |
| <span class="hljs-meta">... </span><span class="hljs-keyword">class</span> <span class="hljs-title class_">ModelOutput</span>: | |
| <span class="hljs-meta">... </span> logits: torch.Tensor | |
| <span class="hljs-meta">... </span> pred_boxes: torch.Tensor | |
| <span class="hljs-meta">>>> </span>@torch.no_grad() | |
| <span class="hljs-meta">... </span><span class="hljs-keyword">def</span> <span class="hljs-title function_">compute_metrics</span>(<span class="hljs-params">evaluation_results, image_processor, threshold=<span class="hljs-number">0.0</span>, id2label=<span class="hljs-literal">None</span></span>): | |
| <span class="hljs-meta">... </span> <span class="hljs-string">""" | |
| <span class="hljs-meta">... </span> Compute mAP, mAR and their variants for the object detection task. | |
| <span class="hljs-meta">... </span> Args: | |
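| <span class="hljs-meta">... </span> evaluation_results (EvalPrediction): Predictions and targets from evaluation. | |
| <span class="hljs-meta">... </span> image_processor (BaseImageProcessor): Image processor used to post-process the predictions. | |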
| <span class="hljs-meta">... </span> threshold (float, optional): Threshold to filter predicted boxes by confidence. Defaults to 0.0. | |
| <span class="hljs-meta">... </span> id2label (Optional[dict], optional): Mapping from class id to class name. Defaults to None. | |
| <span class="hljs-meta">... </span> Returns: | |
| <span class="hljs-meta">... </span> Mapping[str, float]: Metrics in a form of dictionary {<metric_name>: <metric_value>} | |
| <span class="hljs-meta">... </span> """</span> | |
| <span class="hljs-meta">... </span> predictions, targets = evaluation_results.predictions, evaluation_results.label_ids | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># For metric computation we need to provide:</span> | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># - targets in a form of list of dictionaries with keys "boxes", "labels"</span> | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># - predictions in a form of list of dictionaries with keys "boxes", "scores", "labels"</span> | |
| <span class="hljs-meta">... </span> image_sizes = [] | |
| <span class="hljs-meta">... </span> post_processed_targets = [] | |
| <span class="hljs-meta">... </span> post_processed_predictions = [] | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># Collect targets in the required format for metric computation</span> | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> targets: | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># collect image sizes, we will need them for predictions post processing</span> | |
| <span class="hljs-meta">... </span> batch_image_sizes = torch.tensor(np.array([x[<span class="hljs-string">"orig_size"</span>] <span class="hljs-keyword">for</span> x <span class="hljs-keyword">in</span> batch])) | |
| <span class="hljs-meta">... </span> image_sizes.append(batch_image_sizes) | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># collect targets in the required format for metric computation</span> | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># boxes were converted to YOLO format needed for model training</span> | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># here we will convert them to Pascal VOC format (x_min, y_min, x_max, y_max)</span> | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">for</span> image_target <span class="hljs-keyword">in</span> batch: | |
| <span class="hljs-meta">... </span> boxes = torch.tensor(image_target[<span class="hljs-string">"boxes"</span>]) | |
| <span class="hljs-meta">... </span> boxes = convert_bbox_yolo_to_pascal(boxes, image_target[<span class="hljs-string">"orig_size"</span>]) | |
| <span class="hljs-meta">... </span> labels = torch.tensor(image_target[<span class="hljs-string">"class_labels"</span>]) | |
| <span class="hljs-meta">... </span> post_processed_targets.append({<span class="hljs-string">"boxes"</span>: boxes, <span class="hljs-string">"labels"</span>: labels}) | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># Collect predictions in the required format for metric computation,</span> | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># the model produces boxes in YOLO format, then image_processor converts them to Pascal VOC format</span> | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">for</span> batch, target_sizes <span class="hljs-keyword">in</span> <span class="hljs-built_in">zip</span>(predictions, image_sizes): | |
| <span class="hljs-meta">... </span> batch_logits, batch_boxes = batch[<span class="hljs-number">1</span>], batch[<span class="hljs-number">2</span>] | |
| <span class="hljs-meta">... </span> output = ModelOutput(logits=torch.tensor(batch_logits), pred_boxes=torch.tensor(batch_boxes)) | |
| <span class="hljs-meta">... </span> post_processed_output = image_processor.post_process_object_detection( | |
| <span class="hljs-meta">... </span> output, threshold=threshold, target_sizes=target_sizes | |
| <span class="hljs-meta">... </span> ) | |
| <span class="hljs-meta">... </span> post_processed_predictions.extend(post_processed_output) | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># Compute metrics</span> | |
| <span class="hljs-meta">... </span> metric = MeanAveragePrecision(box_format=<span class="hljs-string">"xyxy"</span>, class_metrics=<span class="hljs-literal">True</span>) | |
| <span class="hljs-meta">... </span> metric.update(post_processed_predictions, post_processed_targets) | |
| <span class="hljs-meta">... </span> metrics = metric.compute() | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># Replace list of per class metrics with separate metric for each class</span> | |
| <span class="hljs-meta">... </span> classes = metrics.pop(<span class="hljs-string">"classes"</span>) | |
| <span class="hljs-meta">... </span> map_per_class = metrics.pop(<span class="hljs-string">"map_per_class"</span>) | |
| <span class="hljs-meta">... </span> mar_100_per_class = metrics.pop(<span class="hljs-string">"mar_100_per_class"</span>) | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">for</span> class_id, class_map, class_mar <span class="hljs-keyword">in</span> <span class="hljs-built_in">zip</span>(classes, map_per_class, mar_100_per_class): | |
| <span class="hljs-meta">... </span> class_name = id2label[class_id.item()] <span class="hljs-keyword">if</span> id2label <span class="hljs-keyword">is</span> <span class="hljs-keyword">not</span> <span class="hljs-literal">None</span> <span class="hljs-keyword">else</span> class_id.item() | |
| <span class="hljs-meta">... </span> metrics[<span class="hljs-string">f"map_<span class="hljs-subst">{class_name}</span>"</span>] = class_map | |
| <span class="hljs-meta">... </span> metrics[<span class="hljs-string">f"mar_100_<span class="hljs-subst">{class_name}</span>"</span>] = class_mar | |
| <span class="hljs-meta">... </span> metrics = {k: <span class="hljs-built_in">round</span>(v.item(), <span class="hljs-number">4</span>) <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> metrics.items()} | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">return</span> metrics | |
| <span class="hljs-meta">>>> </span>eval_compute_metrics_fn = partial( | |
| <span class="hljs-meta">... </span> compute_metrics, image_processor=image_processor, id2label=id2label, threshold=<span class="hljs-number">0.0</span> | |
| <span class="hljs-meta">... </span>)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="training-the-detection-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#training-the-detection-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Training the detection model</span></h2> <p data-svelte-h="svelte-1970dhn">You have done most of the heavy lifting in the previous sections, so now you are ready to train your model! | |
| The images in this dataset are still quite large, even after resizing. This means that finetuning this model will | |
| require at least one GPU.</p> <p data-svelte-h="svelte-qp7n2l">Training involves the following steps:</p> <ol data-svelte-h="svelte-15jsdwr"><li>Load the model with <a href="/docs/transformers/pr_33913/en/model_doc/auto#transformers.AutoModelForObjectDetection">AutoModelForObjectDetection</a> using the same checkpoint as in the preprocessing.</li> <li>Define your training hyperparameters in <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a>.</li> <li>Pass the training arguments to <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a> along with the model, dataset, image processor, and data collator.</li> <li>Call <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer.train">train()</a> to finetune your model.</li></ol> <p data-svelte-h="svelte-3tgt16">When loading the model from the same checkpoint that you used for the preprocessing, remember to pass the <code>label2id</code> | |
| and <code>id2label</code> maps that you created earlier from the dataset’s metadata. Additionally, we specify <code>ignore_mismatched_sizes=True</code> to replace the existing classification head with a new one.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForObjectDetection | |
| <span class="hljs-meta">>>> </span>model = AutoModelForObjectDetection.from_pretrained( | |
| <span class="hljs-meta">... </span> MODEL_NAME, | |
| <span class="hljs-meta">... </span> id2label=id2label, | |
| <span class="hljs-meta">... </span> label2id=label2id, | |
| <span class="hljs-meta">... </span> ignore_mismatched_sizes=<span class="hljs-literal">True</span>, | |
| <span class="hljs-meta">... </span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-tw8vr8">In <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a>, use <code>output_dir</code> to specify where to save your model, then configure hyperparameters as you see fit. With <code>num_train_epochs=30</code>, training takes about 35 minutes on a Google Colab T4 GPU; increase the number of epochs to get better results.</p> <p data-svelte-h="svelte-1z0bxll">Important notes:</p> <ul data-svelte-h="svelte-18syxls"><li>Do not remove unused columns because this will drop the image column. Without the image column, you | |
| can’t create <code>pixel_values</code>. For this reason, set <code>remove_unused_columns</code> to <code>False</code>.</li> <li>Set <code>eval_do_concat_batches=False</code> to get proper evaluation results. Images have different numbers of target boxes; if batches are concatenated, we cannot determine which boxes belong to which image.</li></ul> <p data-svelte-h="svelte-m8o45s">If you wish to share your model by pushing to the Hub, set <code>push_to_hub</code> to <code>True</code> (you must be signed in to Hugging | |
| Face to upload your model).</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments | |
| <span class="hljs-meta">>>> </span>training_args = TrainingArguments( | |
| <span class="hljs-meta">... </span> output_dir=<span class="hljs-string">"detr_finetuned_cppe5"</span>, | |
| <span class="hljs-meta">... </span> num_train_epochs=<span class="hljs-number">30</span>, | |
| <span class="hljs-meta">... </span> fp16=<span class="hljs-literal">False</span>, | |
| <span class="hljs-meta">... </span> per_device_train_batch_size=<span class="hljs-number">8</span>, | |
| <span class="hljs-meta">... </span> dataloader_num_workers=<span class="hljs-number">4</span>, | |
| <span class="hljs-meta">... </span> learning_rate=<span class="hljs-number">5e-5</span>, | |
| <span class="hljs-meta">... </span> lr_scheduler_type=<span class="hljs-string">"cosine"</span>, | |
| <span class="hljs-meta">... </span> weight_decay=<span class="hljs-number">1e-4</span>, | |
| <span class="hljs-meta">... </span> max_grad_norm=<span class="hljs-number">0.01</span>, | |
| <span class="hljs-meta">... </span> metric_for_best_model=<span class="hljs-string">"eval_map"</span>, | |
| <span class="hljs-meta">... </span> greater_is_better=<span class="hljs-literal">True</span>, | |
| <span class="hljs-meta">... </span> load_best_model_at_end=<span class="hljs-literal">True</span>, | |
| <span class="hljs-meta">... </span> eval_strategy=<span class="hljs-string">"epoch"</span>, | |
| <span class="hljs-meta">... </span> save_strategy=<span class="hljs-string">"epoch"</span>, | |
| <span class="hljs-meta">... </span> save_total_limit=<span class="hljs-number">2</span>, | |
| <span class="hljs-meta">... </span> remove_unused_columns=<span class="hljs-literal">False</span>, | |
| <span class="hljs-meta">... </span> eval_do_concat_batches=<span class="hljs-literal">False</span>, | |
| <span class="hljs-meta">... </span> push_to_hub=<span class="hljs-literal">True</span>, | |
| <span class="hljs-meta">... </span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1rvg1we">Finally, bring everything together, and call <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer.train">train()</a>:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> Trainer | |
| <span class="hljs-meta">>>> </span>trainer = Trainer( | |
| <span class="hljs-meta">... </span> model=model, | |
| <span class="hljs-meta">... </span> args=training_args, | |
| <span class="hljs-meta">... </span> train_dataset=cppe5[<span class="hljs-string">"train"</span>], | |
| <span class="hljs-meta">... </span> eval_dataset=cppe5[<span class="hljs-string">"validation"</span>], | |
| <span class="hljs-meta">... </span> processing_class=image_processor, | |
| <span class="hljs-meta">... </span> data_collator=collate_fn, | |
| <span class="hljs-meta">... </span> compute_metrics=eval_compute_metrics_fn, | |
| <span class="hljs-meta">... </span>) | |
| <span class="hljs-meta">>>> </span>trainer.train()<!-- HTML_TAG_END --></pre></div> <div data-svelte-h="svelte-z5zay"><progress value="3210" max="3210" style="width:300px; height:20px; vertical-align: middle;"></progress> | |
| [3210/3210 26:07, Epoch 30/30]</div> <table border="1" class="dataframe" data-svelte-h="svelte-186gcvq"><thead><tr style="text-align: left;"><th>Epoch</th> <th>Training Loss</th> <th>Validation Loss</th> <th>Map</th> <th>Map 50</th> <th>Map 75</th> <th>Map Small</th> <th>Map Medium</th> <th>Map Large</th> <th>Mar 1</th> <th>Mar 10</th> <th>Mar 100</th> <th>Mar Small</th> <th>Mar Medium</th> <th>Mar Large</th> <th>Map Coverall</th> <th>Mar 100 Coverall</th> <th>Map Face Shield</th> <th>Mar 100 Face Shield</th> <th>Map Gloves</th> <th>Mar 100 Gloves</th> <th>Map Goggles</th> <th>Mar 100 Goggles</th> <th>Map Mask</th> <th>Mar 100 Mask</th></tr></thead> <tbody><tr><td>1</td> <td>No log</td> <td>2.629903</td> <td>0.008900</td> <td>0.023200</td> <td>0.006500</td> <td>0.001300</td> <td>0.002800</td> <td>0.020500</td> <td>0.021500</td> <td>0.070400</td> <td>0.101400</td> <td>0.007600</td> <td>0.106200</td> <td>0.096100</td> <td>0.036700</td> <td>0.232000</td> <td>0.000300</td> <td>0.019000</td> <td>0.003900</td> <td>0.125400</td> <td>0.000100</td> <td>0.003100</td> <td>0.003500</td> <td>0.127600</td></tr> <tr><td>2</td> <td>No log</td> <td>3.479864</td> <td>0.014800</td> <td>0.034600</td> <td>0.010800</td> <td>0.008600</td> <td>0.011700</td> <td>0.012500</td> <td>0.041100</td> <td>0.098700</td> <td>0.130000</td> <td>0.056000</td> <td>0.062200</td> <td>0.111900</td> <td>0.053500</td> <td>0.447300</td> <td>0.010600</td> <td>0.100000</td> <td>0.000200</td> <td>0.022800</td> <td>0.000100</td> <td>0.015400</td> <td>0.009700</td> <td>0.064400</td></tr> <tr><td>3</td> <td>No log</td> <td>2.107622</td> <td>0.041700</td> <td>0.094000</td> <td>0.034300</td> <td>0.024100</td> <td>0.026400</td> <td>0.047400</td> <td>0.091500</td> <td>0.182800</td> <td>0.225800</td> <td>0.087200</td> <td>0.199400</td> <td>0.210600</td> <td>0.150900</td> <td>0.571200</td> <td>0.017300</td> <td>0.101300</td> <td>0.007300</td> <td>0.180400</td> <td>0.002100</td> <td>0.026200</td> <td>0.031000</td> 
<td>0.250200</td></tr> <tr><td>4</td> <td>No log</td> <td>2.031242</td> <td>0.055900</td> <td>0.120600</td> <td>0.046900</td> <td>0.013800</td> <td>0.038100</td> <td>0.090300</td> <td>0.105900</td> <td>0.225600</td> <td>0.266100</td> <td>0.130200</td> <td>0.228100</td> <td>0.330000</td> <td>0.191000</td> <td>0.572100</td> <td>0.010600</td> <td>0.157000</td> <td>0.014600</td> <td>0.235300</td> <td>0.001700</td> <td>0.052300</td> <td>0.061800</td> <td>0.313800</td></tr> <tr><td>5</td> <td>3.889400</td> <td>1.883433</td> <td>0.089700</td> <td>0.201800</td> <td>0.067300</td> <td>0.022800</td> <td>0.065300</td> <td>0.129500</td> <td>0.136000</td> <td>0.272200</td> <td>0.303700</td> <td>0.112900</td> <td>0.312500</td> <td>0.424600</td> <td>0.300200</td> <td>0.585100</td> <td>0.032700</td> <td>0.202500</td> <td>0.031300</td> <td>0.271000</td> <td>0.008700</td> <td>0.126200</td> <td>0.075500</td> <td>0.333800</td></tr> <tr><td>6</td> <td>3.889400</td> <td>1.807503</td> <td>0.118500</td> <td>0.270900</td> <td>0.090200</td> <td>0.034900</td> <td>0.076700</td> <td>0.152500</td> <td>0.146100</td> <td>0.297800</td> <td>0.325400</td> <td>0.171700</td> <td>0.283700</td> <td>0.545900</td> <td>0.396900</td> <td>0.554500</td> <td>0.043000</td> <td>0.262000</td> <td>0.054500</td> <td>0.271900</td> <td>0.020300</td> <td>0.230800</td> <td>0.077600</td> <td>0.308000</td></tr> <tr><td>7</td> <td>3.889400</td> <td>1.716169</td> <td>0.143500</td> <td>0.307700</td> <td>0.123200</td> <td>0.045800</td> <td>0.097800</td> <td>0.258300</td> <td>0.165300</td> <td>0.327700</td> <td>0.352600</td> <td>0.140900</td> <td>0.336700</td> <td>0.599400</td> <td>0.442900</td> <td>0.620700</td> <td>0.069400</td> <td>0.301300</td> <td>0.081600</td> <td>0.292000</td> <td>0.011000</td> <td>0.230800</td> <td>0.112700</td> <td>0.318200</td></tr> <tr><td>8</td> <td>3.889400</td> <td>1.679014</td> <td>0.153000</td> <td>0.355800</td> <td>0.127900</td> <td>0.038700</td> <td>0.115600</td> <td>0.291600</td> 
<td>0.176000</td> <td>0.322500</td> <td>0.349700</td> <td>0.135600</td> <td>0.326100</td> <td>0.643700</td> <td>0.431700</td> <td>0.582900</td> <td>0.069800</td> <td>0.265800</td> <td>0.088600</td> <td>0.274600</td> <td>0.028300</td> <td>0.280000</td> <td>0.146700</td> <td>0.345300</td></tr> <tr><td>9</td> <td>3.889400</td> <td>1.618239</td> <td>0.172100</td> <td>0.375300</td> <td>0.137600</td> <td>0.046100</td> <td>0.141700</td> <td>0.308500</td> <td>0.194000</td> <td>0.356200</td> <td>0.386200</td> <td>0.162400</td> <td>0.359200</td> <td>0.677700</td> <td>0.469800</td> <td>0.623900</td> <td>0.102100</td> <td>0.317700</td> <td>0.099100</td> <td>0.290200</td> <td>0.029300</td> <td>0.335400</td> <td>0.160200</td> <td>0.364000</td></tr> <tr><td>10</td> <td>1.599700</td> <td>1.572512</td> <td>0.179500</td> <td>0.400400</td> <td>0.147200</td> <td>0.056500</td> <td>0.141700</td> <td>0.316700</td> <td>0.213100</td> <td>0.357600</td> <td>0.381300</td> <td>0.197900</td> <td>0.344300</td> <td>0.638500</td> <td>0.466900</td> <td>0.623900</td> <td>0.101300</td> <td>0.311400</td> <td>0.104700</td> <td>0.279500</td> <td>0.051600</td> <td>0.338500</td> <td>0.173000</td> <td>0.353300</td></tr> <tr><td>11</td> <td>1.599700</td> <td>1.528889</td> <td>0.192200</td> <td>0.415000</td> <td>0.160800</td> <td>0.053700</td> <td>0.150500</td> <td>0.378000</td> <td>0.211500</td> <td>0.371700</td> <td>0.397800</td> <td>0.204900</td> <td>0.374600</td> <td>0.684800</td> <td>0.491900</td> <td>0.632400</td> <td>0.131200</td> <td>0.346800</td> <td>0.122000</td> <td>0.300900</td> <td>0.038400</td> <td>0.344600</td> <td>0.177500</td> <td>0.364400</td></tr> <tr><td>12</td> <td>1.599700</td> <td>1.517532</td> <td>0.198300</td> <td>0.429800</td> <td>0.159800</td> <td>0.066400</td> <td>0.162900</td> <td>0.383300</td> <td>0.220700</td> <td>0.382100</td> <td>0.405400</td> <td>0.214800</td> <td>0.383200</td> <td>0.672900</td> <td>0.469000</td> <td>0.610400</td> <td>0.167800</td> <td>0.379700</td> 
<td>0.119700</td> <td>0.307100</td> <td>0.038100</td> <td>0.335400</td> <td>0.196800</td> <td>0.394200</td></tr> <tr><td>13</td> <td>1.599700</td> <td>1.488849</td> <td>0.209800</td> <td>0.452300</td> <td>0.172300</td> <td>0.094900</td> <td>0.171100</td> <td>0.437800</td> <td>0.222000</td> <td>0.379800</td> <td>0.411500</td> <td>0.203800</td> <td>0.397300</td> <td>0.707500</td> <td>0.470700</td> <td>0.620700</td> <td>0.186900</td> <td>0.407600</td> <td>0.124200</td> <td>0.306700</td> <td>0.059300</td> <td>0.355400</td> <td>0.207700</td> <td>0.367100</td></tr> <tr><td>14</td> <td>1.599700</td> <td>1.482210</td> <td>0.228900</td> <td>0.482600</td> <td>0.187800</td> <td>0.083600</td> <td>0.191800</td> <td>0.444100</td> <td>0.225900</td> <td>0.376900</td> <td>0.407400</td> <td>0.182500</td> <td>0.384800</td> <td>0.700600</td> <td>0.512100</td> <td>0.640100</td> <td>0.175000</td> <td>0.363300</td> <td>0.144300</td> <td>0.300000</td> <td>0.083100</td> <td>0.363100</td> <td>0.229900</td> <td>0.370700</td></tr> <tr><td>15</td> <td>1.326800</td> <td>1.475198</td> <td>0.216300</td> <td>0.455600</td> <td>0.174900</td> <td>0.088500</td> <td>0.183500</td> <td>0.424400</td> <td>0.226900</td> <td>0.373400</td> <td>0.404300</td> <td>0.199200</td> <td>0.396400</td> <td>0.677800</td> <td>0.496300</td> <td>0.633800</td> <td>0.166300</td> <td>0.392400</td> <td>0.128900</td> <td>0.312900</td> <td>0.085200</td> <td>0.312300</td> <td>0.205000</td> <td>0.370200</td></tr> <tr><td>16</td> <td>1.326800</td> <td>1.459697</td> <td>0.233200</td> <td>0.504200</td> <td>0.192200</td> <td>0.096000</td> <td>0.202000</td> <td>0.430800</td> <td>0.239100</td> <td>0.382400</td> <td>0.412600</td> <td>0.219500</td> <td>0.403100</td> <td>0.670400</td> <td>0.485200</td> <td>0.625200</td> <td>0.196500</td> <td>0.410100</td> <td>0.135700</td> <td>0.299600</td> <td>0.123100</td> <td>0.356900</td> <td>0.225300</td> <td>0.371100</td></tr> <tr><td>17</td> <td>1.326800</td> <td>1.407340</td> <td>0.243400</td> 
<td>0.511900</td> <td>0.204500</td> <td>0.121000</td> <td>0.215700</td> <td>0.468000</td> <td>0.246200</td> <td>0.394600</td> <td>0.424200</td> <td>0.225900</td> <td>0.416100</td> <td>0.705200</td> <td>0.494900</td> <td>0.638300</td> <td>0.224900</td> <td>0.430400</td> <td>0.157200</td> <td>0.317900</td> <td>0.115700</td> <td>0.369200</td> <td>0.224200</td> <td>0.365300</td></tr> <tr><td>18</td> <td>1.326800</td> <td>1.419522</td> <td>0.245100</td> <td>0.521500</td> <td>0.210000</td> <td>0.116100</td> <td>0.211500</td> <td>0.489900</td> <td>0.255400</td> <td>0.391600</td> <td>0.419700</td> <td>0.198800</td> <td>0.421200</td> <td>0.701400</td> <td>0.501800</td> <td>0.634200</td> <td>0.226700</td> <td>0.410100</td> <td>0.154400</td> <td>0.321400</td> <td>0.105900</td> <td>0.352300</td> <td>0.236700</td> <td>0.380400</td></tr> <tr><td>19</td> <td>1.158600</td> <td>1.398764</td> <td>0.253600</td> <td>0.519200</td> <td>0.213600</td> <td>0.135200</td> <td>0.207700</td> <td>0.491900</td> <td>0.257300</td> <td>0.397300</td> <td>0.428000</td> <td>0.241400</td> <td>0.401800</td> <td>0.703500</td> <td>0.509700</td> <td>0.631100</td> <td>0.236700</td> <td>0.441800</td> <td>0.155900</td> <td>0.330800</td> <td>0.128100</td> <td>0.352300</td> <td>0.237500</td> <td>0.384000</td></tr> <tr><td>20</td> <td>1.158600</td> <td>1.390591</td> <td>0.248800</td> <td>0.520200</td> <td>0.216600</td> <td>0.127500</td> <td>0.211400</td> <td>0.471900</td> <td>0.258300</td> <td>0.407000</td> <td>0.429100</td> <td>0.240300</td> <td>0.407600</td> <td>0.708500</td> <td>0.505800</td> <td>0.623400</td> <td>0.235500</td> <td>0.431600</td> <td>0.150000</td> <td>0.325000</td> <td>0.125700</td> <td>0.375400</td> <td>0.227200</td> <td>0.390200</td></tr> <tr><td>21</td> <td>1.158600</td> <td>1.360608</td> <td>0.262700</td> <td>0.544800</td> <td>0.222100</td> <td>0.134700</td> <td>0.230000</td> <td>0.487500</td> <td>0.269500</td> <td>0.413300</td> <td>0.436300</td> <td>0.236200</td> <td>0.419100</td> 
<td>0.709300</td> <td>0.514100</td> <td>0.637400</td> <td>0.257200</td> <td>0.450600</td> <td>0.165100</td> <td>0.338400</td> <td>0.139400</td> <td>0.372300</td> <td>0.237700</td> <td>0.382700</td></tr> <tr><td>22</td> <td>1.158600</td> <td>1.368296</td> <td>0.262800</td> <td>0.542400</td> <td>0.236400</td> <td>0.137400</td> <td>0.228100</td> <td>0.498500</td> <td>0.266500</td> <td>0.409000</td> <td>0.433000</td> <td>0.239900</td> <td>0.418500</td> <td>0.697500</td> <td>0.520500</td> <td>0.641000</td> <td>0.257500</td> <td>0.455700</td> <td>0.162600</td> <td>0.334800</td> <td>0.140200</td> <td>0.353800</td> <td>0.233200</td> <td>0.379600</td></tr> <tr><td>23</td> <td>1.158600</td> <td>1.368176</td> <td>0.264800</td> <td>0.541100</td> <td>0.233100</td> <td>0.138200</td> <td>0.223900</td> <td>0.498700</td> <td>0.272300</td> <td>0.407400</td> <td>0.434400</td> <td>0.233100</td> <td>0.418300</td> <td>0.702000</td> <td>0.524400</td> <td>0.642300</td> <td>0.262300</td> <td>0.444300</td> <td>0.159700</td> <td>0.335300</td> <td>0.140500</td> <td>0.366200</td> <td>0.236900</td> <td>0.384000</td></tr> <tr><td>24</td> <td>1.049700</td> <td>1.355271</td> <td>0.269700</td> <td>0.549200</td> <td>0.239100</td> <td>0.134700</td> <td>0.229900</td> <td>0.519200</td> <td>0.274800</td> <td>0.412700</td> <td>0.437600</td> <td>0.245400</td> <td>0.417200</td> <td>0.711200</td> <td>0.523200</td> <td>0.644100</td> <td>0.272100</td> <td>0.440500</td> <td>0.166700</td> <td>0.341500</td> <td>0.137700</td> <td>0.373800</td> <td>0.249000</td> <td>0.388000</td></tr> <tr><td>25</td> <td>1.049700</td> <td>1.355180</td> <td>0.272500</td> <td>0.547900</td> <td>0.243800</td> <td>0.149700</td> <td>0.229900</td> <td>0.523100</td> <td>0.272500</td> <td>0.415700</td> <td>0.442200</td> <td>0.256200</td> <td>0.420200</td> <td>0.705800</td> <td>0.523900</td> <td>0.639600</td> <td>0.271700</td> <td>0.451900</td> <td>0.166300</td> <td>0.346900</td> <td>0.153700</td> <td>0.383100</td> <td>0.247000</td> 
<td>0.389300</td></tr> <tr><td>26</td> <td>1.049700</td> <td>1.349337</td> <td>0.275600</td> <td>0.556300</td> <td>0.246400</td> <td>0.146700</td> <td>0.234800</td> <td>0.516300</td> <td>0.274200</td> <td>0.418300</td> <td>0.440900</td> <td>0.248700</td> <td>0.418900</td> <td>0.705800</td> <td>0.523200</td> <td>0.636500</td> <td>0.274700</td> <td>0.440500</td> <td>0.172400</td> <td>0.349100</td> <td>0.155600</td> <td>0.384600</td> <td>0.252300</td> <td>0.393800</td></tr> <tr><td>27</td> <td>1.049700</td> <td>1.350782</td> <td>0.275200</td> <td>0.548700</td> <td>0.246800</td> <td>0.147300</td> <td>0.236400</td> <td>0.527200</td> <td>0.280100</td> <td>0.416200</td> <td>0.442600</td> <td>0.253400</td> <td>0.424000</td> <td>0.710300</td> <td>0.526600</td> <td>0.640100</td> <td>0.273200</td> <td>0.445600</td> <td>0.167000</td> <td>0.346900</td> <td>0.160100</td> <td>0.387700</td> <td>0.249200</td> <td>0.392900</td></tr> <tr><td>28</td> <td>1.049700</td> <td>1.346533</td> <td>0.277000</td> <td>0.552800</td> <td>0.252900</td> <td>0.147400</td> <td>0.240000</td> <td>0.527600</td> <td>0.280900</td> <td>0.420900</td> <td>0.444100</td> <td>0.255500</td> <td>0.424500</td> <td>0.711200</td> <td>0.530200</td> <td>0.646800</td> <td>0.277400</td> <td>0.441800</td> <td>0.170900</td> <td>0.346900</td> <td>0.156600</td> <td>0.389200</td> <td>0.249600</td> <td>0.396000</td></tr> <tr><td>29</td> <td>0.993700</td> <td>1.346575</td> <td>0.277100</td> <td>0.554800</td> <td>0.252900</td> <td>0.148400</td> <td>0.239700</td> <td>0.523600</td> <td>0.278400</td> <td>0.420000</td> <td>0.443300</td> <td>0.256300</td> <td>0.424000</td> <td>0.705600</td> <td>0.529600</td> <td>0.647300</td> <td>0.273900</td> <td>0.439200</td> <td>0.174300</td> <td>0.348700</td> <td>0.157600</td> <td>0.386200</td> <td>0.250100</td> <td>0.395100</td></tr> <tr><td>30</td> <td>0.993700</td> <td>1.346446</td> <td>0.277400</td> <td>0.554700</td> <td>0.252700</td> <td>0.147900</td> <td>0.240800</td> <td>0.523600</td> 
<td>0.278800</td> <td>0.420400</td> <td>0.443300</td> <td>0.256100</td> <td>0.424200</td> <td>0.705500</td> <td>0.530100</td> <td>0.646800</td> <td>0.275600</td> <td>0.440500</td> <td>0.174500</td> <td>0.348700</td> <td>0.157300</td> <td>0.386200</td> <td>0.249200</td> <td>0.394200</td></tr></tbody> </table><p data-svelte-h="svelte-b5f4e5"></p><p data-svelte-h="svelte-19mv6nr">If you have set <code>push_to_hub</code> to <code>True</code> in the <code>training_args</code>, the training checkpoints are pushed to the | |
| Hugging Face Hub. Upon training completion, push the final model to the Hub as well by calling the <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer.push_to_hub">push_to_hub()</a> method.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span>trainer.push_to_hub()<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="evaluate" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#evaluate"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 
8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Evaluate</span></h2> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> pprint <span class="hljs-keyword">import</span> pprint | |
| <span class="hljs-meta">>>> </span>metrics = trainer.evaluate(eval_dataset=cppe5[<span class="hljs-string">"test"</span>], metric_key_prefix=<span class="hljs-string">"test"</span>) | |
| <span class="hljs-meta">>>> </span>pprint(metrics) | |
| {<span class="hljs-string">'epoch'</span>: <span class="hljs-number">30.0</span>, | |
| <span class="hljs-string">'test_loss'</span>: <span class="hljs-number">1.0877351760864258</span>, | |
| <span class="hljs-string">'test_map'</span>: <span class="hljs-number">0.4116</span>, | |
| <span class="hljs-string">'test_map_50'</span>: <span class="hljs-number">0.741</span>, | |
| <span class="hljs-string">'test_map_75'</span>: <span class="hljs-number">0.3663</span>, | |
| <span class="hljs-string">'test_map_Coverall'</span>: <span class="hljs-number">0.5937</span>, | |
| <span class="hljs-string">'test_map_Face_Shield'</span>: <span class="hljs-number">0.5863</span>, | |
| <span class="hljs-string">'test_map_Gloves'</span>: <span class="hljs-number">0.3416</span>, | |
| <span class="hljs-string">'test_map_Goggles'</span>: <span class="hljs-number">0.1468</span>, | |
| <span class="hljs-string">'test_map_Mask'</span>: <span class="hljs-number">0.3894</span>, | |
| <span class="hljs-string">'test_map_large'</span>: <span class="hljs-number">0.5637</span>, | |
| <span class="hljs-string">'test_map_medium'</span>: <span class="hljs-number">0.3257</span>, | |
| <span class="hljs-string">'test_map_small'</span>: <span class="hljs-number">0.3589</span>, | |
| <span class="hljs-string">'test_mar_1'</span>: <span class="hljs-number">0.323</span>, | |
| <span class="hljs-string">'test_mar_10'</span>: <span class="hljs-number">0.5237</span>, | |
| <span class="hljs-string">'test_mar_100'</span>: <span class="hljs-number">0.5587</span>, | |
| <span class="hljs-string">'test_mar_100_Coverall'</span>: <span class="hljs-number">0.6756</span>, | |
| <span class="hljs-string">'test_mar_100_Face_Shield'</span>: <span class="hljs-number">0.7294</span>, | |
| <span class="hljs-string">'test_mar_100_Gloves'</span>: <span class="hljs-number">0.4721</span>, | |
| <span class="hljs-string">'test_mar_100_Goggles'</span>: <span class="hljs-number">0.4125</span>, | |
| <span class="hljs-string">'test_mar_100_Mask'</span>: <span class="hljs-number">0.5038</span>, | |
| <span class="hljs-string">'test_mar_large'</span>: <span class="hljs-number">0.7283</span>, | |
| <span class="hljs-string">'test_mar_medium'</span>: <span class="hljs-number">0.4901</span>, | |
| <span class="hljs-string">'test_mar_small'</span>: <span class="hljs-number">0.4469</span>, | |
| <span class="hljs-string">'test_runtime'</span>: <span class="hljs-number">1.6526</span>, | |
| <span class="hljs-string">'test_samples_per_second'</span>: <span class="hljs-number">17.548</span>, | |
| <span class="hljs-string">'test_steps_per_second'</span>: <span class="hljs-number">2.42</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1usu20f">These results can be further improved by adjusting the hyperparameters in <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a>. Give it a go!</p> <h2 class="relative group"><a id="inference" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#inference"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Inference</span></h2> <p data-svelte-h="svelte-1awi77u">Now that you have finetuned a model, evaluated it, and uploaded it to the Hugging Face Hub, you can use it for inference.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" 
preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">import</span> requests | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> PIL <span class="hljs-keyword">import</span> Image, ImageDraw | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoImageProcessor, AutoModelForObjectDetection | |
| <span class="hljs-meta">>>> </span>url = <span class="hljs-string">"https://images.pexels.com/photos/8413299/pexels-photo-8413299.jpeg?auto=compress&cs=tinysrgb&w=630&h=375&dpr=2"</span> | |
| <span class="hljs-meta">>>> </span>image = Image.<span class="hljs-built_in">open</span>(requests.get(url, stream=<span class="hljs-literal">True</span>).raw)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-f44t9b">Load the model and image processor from the Hugging Face Hub (skip this step to use the model already trained in this session):</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> accelerate.test_utils.testing <span class="hljs-keyword">import</span> get_backend | |
| <span class="hljs-comment"># automatically detects the underlying device type (CUDA, CPU, XPU, MPS, etc.)</span> | |
| <span class="hljs-meta">>>> </span>device, _, _ = get_backend() | |
| <span class="hljs-meta">>>> </span>model_repo = <span class="hljs-string">"qubvel-hf/detr_finetuned_cppe5"</span> | |
| <span class="hljs-meta">>>> </span>image_processor = AutoImageProcessor.from_pretrained(model_repo) | |
| <span class="hljs-meta">>>> </span>model = AutoModelForObjectDetection.from_pretrained(model_repo) | |
| <span class="hljs-meta">>>> </span>model = model.to(device)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-m3ccz3">And detect bounding boxes:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --> | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">with</span> torch.no_grad(): | |
| <span class="hljs-meta">... </span> inputs = image_processor(images=[image], return_tensors=<span class="hljs-string">"pt"</span>) | |
| <span class="hljs-meta">... </span> outputs = model(**inputs.to(device)) | |
| <span class="hljs-meta">... </span> target_sizes = torch.tensor([[image.size[<span class="hljs-number">1</span>], image.size[<span class="hljs-number">0</span>]]]) | |
| <span class="hljs-meta">... </span> results = image_processor.post_process_object_detection(outputs, threshold=<span class="hljs-number">0.3</span>, target_sizes=target_sizes)[<span class="hljs-number">0</span>] | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">for</span> score, label, box <span class="hljs-keyword">in</span> <span class="hljs-built_in">zip</span>(results[<span class="hljs-string">"scores"</span>], results[<span class="hljs-string">"labels"</span>], results[<span class="hljs-string">"boxes"</span>]): | |
| <span class="hljs-meta">... </span> box = [<span class="hljs-built_in">round</span>(i, <span class="hljs-number">2</span>) <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> box.tolist()] | |
| <span class="hljs-meta">... </span> <span class="hljs-built_in">print</span>( | |
| <span class="hljs-meta">... </span> <span class="hljs-string">f"Detected <span class="hljs-subst">{model.config.id2label[label.item()]}</span> with confidence "</span> | |
| <span class="hljs-meta">... </span> <span class="hljs-string">f"<span class="hljs-subst">{<span class="hljs-built_in">round</span>(score.item(), <span class="hljs-number">3</span>)}</span> at location <span class="hljs-subst">{box}</span>"</span> | |
| <span class="hljs-meta">... </span> ) | |
| Detected Gloves <span class="hljs-keyword">with</span> confidence <span class="hljs-number">0.683</span> at location [<span class="hljs-number">244.58</span>, <span class="hljs-number">124.33</span>, <span class="hljs-number">300.35</span>, <span class="hljs-number">185.13</span>] | |
| Detected Mask <span class="hljs-keyword">with</span> confidence <span class="hljs-number">0.517</span> at location [<span class="hljs-number">143.73</span>, <span class="hljs-number">64.58</span>, <span class="hljs-number">219.57</span>, <span class="hljs-number">125.89</span>] | |
| Detected Gloves <span class="hljs-keyword">with</span> confidence <span class="hljs-number">0.425</span> at location [<span class="hljs-number">179.15</span>, <span class="hljs-number">155.57</span>, <span class="hljs-number">262.4</span>, <span class="hljs-number">226.35</span>] | |
| Detected Coverall <span class="hljs-keyword">with</span> confidence <span class="hljs-number">0.407</span> at location [<span class="hljs-number">307.13</span>, -<span class="hljs-number">1.18</span>, <span class="hljs-number">477.82</span>, <span class="hljs-number">318.06</span>] | |
| Detected Coverall <span class="hljs-keyword">with</span> confidence <span class="hljs-number">0.391</span> at location [<span class="hljs-number">68.61</span>, <span class="hljs-number">126.66</span>, <span class="hljs-number">309.03</span>, <span class="hljs-number">318.89</span>]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-7zeucu">Let’s plot the result:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span>draw = ImageDraw.Draw(image) | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">for</span> score, label, box <span class="hljs-keyword">in</span> <span class="hljs-built_in">zip</span>(results[<span class="hljs-string">"scores"</span>], results[<span class="hljs-string">"labels"</span>], results[<span class="hljs-string">"boxes"</span>]): | |
| <span class="hljs-meta">... </span> box = [<span class="hljs-built_in">round</span>(i, <span class="hljs-number">2</span>) <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> box.tolist()] | |
| <span class="hljs-meta">... </span> x, y, x2, y2 = <span class="hljs-built_in">tuple</span>(box) | |
| <span class="hljs-meta">... </span> draw.rectangle((x, y, x2, y2), outline=<span class="hljs-string">"red"</span>, width=<span class="hljs-number">1</span>) | |
| <span class="hljs-meta">... </span> draw.text((x, y), model.config.id2label[label.item()], fill=<span class="hljs-string">"white"</span>) | |
| <span class="hljs-meta">>>> </span>image<!-- HTML_TAG_END --></pre></div> <div class="flex justify-center" data-svelte-h="svelte-bozw7r"><img src="https://i.imgur.com/oDUqD0K.png" alt="Object detection result on a new image"></div> <p></p> | |
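<p>One of the Coverall boxes printed in the detection output has a slightly negative <code>y_min</code> (-1.18), so it extends just past the top edge of the image. The following is a minimal sketch, not part of the original guide, of how such boxes could be clamped to the image bounds before drawing. The helper name <code>clip_box</code> and the image size used in the example call are hypothetical and chosen only for illustration.</p>

```python
from typing import List, Sequence, Tuple


def clip_box(box: Sequence[float], image_size: Tuple[int, int]) -> List[float]:
    """Clamp an (x_min, y_min, x_max, y_max) box to the image bounds.

    `image_size` is (width, height), the order returned by PIL's `Image.size`.
    `clip_box` is a hypothetical helper, not a Transformers or Pillow API.
    """
    width, height = image_size
    x_min, y_min, x_max, y_max = box
    return [
        min(max(x_min, 0.0), float(width)),
        min(max(y_min, 0.0), float(height)),
        min(max(x_max, 0.0), float(width)),
        min(max(y_max, 0.0), float(height)),
    ]


# One of the Coverall boxes from the output above, with its negative y_min.
# The image size below is an assumed value for illustration only.
print(clip_box([307.13, -1.18, 477.82, 318.06], image_size=(1260, 750)))
```

<p>In the plotting loop, a call such as <code>clip_box(box, image.size)</code> could be applied to each box before passing it to <code>draw.rectangle</code>. Whether clamping is needed depends on how the boxes are used afterwards.</p>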
| <script> | |
| { | |
| __sveltekit_z647wz = { | |
| assets: "/docs/transformers/pr_33913/en", | |
| base: "/docs/transformers/pr_33913/en", | |
| env: {} | |
| }; | |
| const element = document.currentScript.parentElement; | |
| const data = [null,null]; | |
| Promise.all([ | |
| import("/docs/transformers/pr_33913/en/_app/immutable/entry/start.b67f883f.js"), | |
| import("/docs/transformers/pr_33913/en/_app/immutable/entry/app.e436b1f2.js") | |
| ]).then(([kit, app]) => { | |
| kit.start(app, element, { | |
| node_ids: [0, 431], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |