Buckets:

hf-doc-build/doc-dev / transformers /main /en /tasks /visual_question_answering.html
rtrm's picture
download
raw
73 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Visual Question Answering&quot;,&quot;local&quot;:&quot;visual-question-answering&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Fine-tuning ViLT&quot;,&quot;local&quot;:&quot;fine-tuning-vilt&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Load the data&quot;,&quot;local&quot;:&quot;load-the-data&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Preprocessing data&quot;,&quot;local&quot;:&quot;preprocessing-data&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Train the model&quot;,&quot;local&quot;:&quot;train-the-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Inference&quot;,&quot;local&quot;:&quot;inference&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Zero-shot VQA&quot;,&quot;local&quot;:&quot;zero-shot-vqa&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/transformers/main/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/entry/start.2135b7e6.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/scheduler.25b97de1.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/singletons.0f2b7d5f.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/index.e188933d.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/paths.3d04d2c6.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/entry/app.24372c84.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/index.d9030fc9.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/nodes/0.026d2fdd.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/nodes/425.1b4c0a8d.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/CodeBlock.e6cd0d95.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/DocNotebookDropdown.5ea6cb78.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/globals.7f7f1b26.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/EditOnGithub.91d95064.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Visual Question Answering&quot;,&quot;local&quot;:&quot;visual-question-answering&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Fine-tuning ViLT&quot;,&quot;local&quot;:&quot;fine-tuning-vilt&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Load the data&quot;,&quot;local&quot;:&quot;load-the-data&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Preprocessing data&quot;,&quot;local&quot;:&quot;preprocessing-data&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Train the model&quot;,&quot;local&quot;:&quot;train-the-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Inference&quot;,&quot;local&quot;:&quot;inference&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Zero-shot VQA&quot;,&quot;local&quot;:&quot;zero-shot-vqa&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="visual-question-answering" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#visual-question-answering"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Visual Question Answering</span></h1> <div class="flex space-x-1 absolute z-10 right-0 top-0"> <div class="relative colab-dropdown "> <button class=" " type="button"> <img alt="Open In Colab" class="!m-0" src="https://colab.research.google.com/assets/colab-badge.svg"> </button> </div> <div class="relative colab-dropdown "> <button class=" " type="button"> <img alt="Open In Studio Lab" class="!m-0" src="https://studiolab.sagemaker.aws/studiolab.svg"> </button> </div></div> <p data-svelte-h="svelte-vljp1h">Visual Question Answering (VQA) is the task of answering open-ended questions based on an image.
The input to models supporting this task is typically a combination of an image and a question, and the output is an
answer expressed in natural language.</p> <p data-svelte-h="svelte-1r09kty">Some noteworthy use case examples for VQA include:</p> <ul data-svelte-h="svelte-geftd8"><li>Accessibility applications for visually impaired individuals.</li> <li>Education: posing questions about visual materials presented in lectures or textbooks. VQA can also be utilized in interactive museum exhibits or historical sites.</li> <li>Customer service and e-commerce: VQA can enhance user experience by letting users ask questions about products.</li> <li>Image retrieval: VQA models can be used to retrieve images with specific characteristics. For example, the user can ask “Is there a dog?” to find all images with dogs from a set of images.</li></ul> <p data-svelte-h="svelte-jr2b5g">In this guide you’ll learn how to:</p> <ul data-svelte-h="svelte-168wjvr"><li>Fine-tune a classification VQA model, specifically <a href="../model_doc/vilt">ViLT</a>, on the <a href="https://huggingface.co/datasets/Graphcore/vqa" rel="nofollow"><code>Graphcore/vqa</code> dataset</a>.</li> <li>Use your fine-tuned ViLT for inference.</li> <li>Run zero-shot VQA inference with a generative model, like BLIP-2.</li></ul> <h2 class="relative group"><a id="fine-tuning-vilt" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#fine-tuning-vilt"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Fine-tuning ViLT</span></h2> <p data-svelte-h="svelte-132kvya">ViLT model incorporates text embeddings into a Vision Transformer (ViT), allowing it to have a minimal design for
Vision-and-Language Pre-training (VLP). This model can be used for several downstream tasks. For the VQA task, a classifier
head is placed on top (a linear layer on top of the final hidden state of the <code>[CLS]</code> token) and randomly initialized.
Visual Question Answering is thus treated as a <strong>classification problem</strong>.</p> <p data-svelte-h="svelte-x085ix">More recent models, such as BLIP, BLIP-2, and InstructBLIP, treat VQA as a generative task. Later in this guide we
illustrate how to use them for zero-shot VQA inference.</p> <p data-svelte-h="svelte-qn4ey1">Before you begin, make sure you have all the necessary libraries installed.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pip install -q transformers datasets<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1yqpblu">We encourage you to share your model with the community. Log in to your Hugging Face account to upload it to the 🤗 Hub.
When prompted, enter your token to log in:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> huggingface_hub <span class="hljs-keyword">import</span> notebook_login
<span class="hljs-meta">&gt;&gt;&gt; </span>notebook_login()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-hhk89w">Let’s define the model checkpoint as a global variable.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>model_checkpoint = <span class="hljs-string">&quot;dandelin/vilt-b32-mlm&quot;</span><!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="load-the-data" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#load-the-data"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Load the data</span></h2> <p data-svelte-h="svelte-n8sxxt">For illustration purposes, in this guide we use a very small sample of the annotated visual question answering <code>Graphcore/vqa</code> dataset.
You can find the full dataset on <a href="https://huggingface.co/datasets/Graphcore/vqa" rel="nofollow">🤗 Hub</a>.</p> <p data-svelte-h="svelte-b7b4c0">As an alternative to the <a href="https://huggingface.co/datasets/Graphcore/vqa" rel="nofollow"><code>Graphcore/vqa</code> dataset</a>, you can download the
same data manually from the official <a href="https://visualqa.org/download.html" rel="nofollow">VQA dataset page</a>. If you prefer to follow the
tutorial with your custom data, check out how to <a href="https://huggingface.co/docs/datasets/image_dataset#loading-script" rel="nofollow">Create an image dataset</a>
guide in the 🤗 Datasets documentation.</p> <p data-svelte-h="svelte-16i329z">Let’s load the first 200 examples from the validation split and explore the dataset’s features:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
<span class="hljs-meta">&gt;&gt;&gt; </span>dataset = load_dataset(<span class="hljs-string">&quot;Graphcore/vqa&quot;</span>, split=<span class="hljs-string">&quot;validation[:200]&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>dataset
Dataset({
features: [<span class="hljs-string">&#x27;question&#x27;</span>, <span class="hljs-string">&#x27;question_type&#x27;</span>, <span class="hljs-string">&#x27;question_id&#x27;</span>, <span class="hljs-string">&#x27;image_id&#x27;</span>, <span class="hljs-string">&#x27;answer_type&#x27;</span>, <span class="hljs-string">&#x27;label&#x27;</span>],
num_rows: <span class="hljs-number">200</span>
})<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-10yp249">Let’s take a look at an example to understand the dataset’s features:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>dataset[<span class="hljs-number">0</span>]
{<span class="hljs-string">&#x27;question&#x27;</span>: <span class="hljs-string">&#x27;Where is he looking?&#x27;</span>,
<span class="hljs-string">&#x27;question_type&#x27;</span>: <span class="hljs-string">&#x27;none of the above&#x27;</span>,
<span class="hljs-string">&#x27;question_id&#x27;</span>: <span class="hljs-number">262148000</span>,
<span class="hljs-string">&#x27;image_id&#x27;</span>: <span class="hljs-string">&#x27;/root/.cache/huggingface/datasets/downloads/extracted/ca733e0e000fb2d7a09fbcc94dbfe7b5a30750681d0e965f8e0a23b1c2f98c75/val2014/COCO_val2014_000000262148.jpg&#x27;</span>,
<span class="hljs-string">&#x27;answer_type&#x27;</span>: <span class="hljs-string">&#x27;other&#x27;</span>,
<span class="hljs-string">&#x27;label&#x27;</span>: {<span class="hljs-string">&#x27;ids&#x27;</span>: [<span class="hljs-string">&#x27;at table&#x27;</span>, <span class="hljs-string">&#x27;down&#x27;</span>, <span class="hljs-string">&#x27;skateboard&#x27;</span>, <span class="hljs-string">&#x27;table&#x27;</span>],
<span class="hljs-string">&#x27;weights&#x27;</span>: [<span class="hljs-number">0.30000001192092896</span>,
<span class="hljs-number">1.0</span>,
<span class="hljs-number">0.30000001192092896</span>,
<span class="hljs-number">0.30000001192092896</span>]}}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1g3cog6">The features relevant to the task include:</p> <ul data-svelte-h="svelte-o5tqko"><li><code>question</code>: the question to be answered from the image</li> <li><code>image_id</code>: the path to the image the question refers to</li> <li><code>label</code>: the annotations</li></ul> <p data-svelte-h="svelte-73ooet">We can remove the rest of the features as they won’t be necessary:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>dataset = dataset.remove_columns([<span class="hljs-string">&#x27;question_type&#x27;</span>, <span class="hljs-string">&#x27;question_id&#x27;</span>, <span class="hljs-string">&#x27;answer_type&#x27;</span>])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ozbkx3">As you can see, the <code>label</code> feature contains several answers to the same question (called <code>ids</code> here) collected by different human annotators.
This is because the answer to a question can be subjective. In this case, the question is “where is he looking?“. Some people
annotated this with “down”, others with “at table”, another one with “skateboard”, etc.</p> <p data-svelte-h="svelte-1iqjyc4">Take a look at the image and consider which answer would you give:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> PIL <span class="hljs-keyword">import</span> Image
<span class="hljs-meta">&gt;&gt;&gt; </span>image = Image.<span class="hljs-built_in">open</span>(dataset[<span class="hljs-number">0</span>][<span class="hljs-string">&#x27;image_id&#x27;</span>])
<span class="hljs-meta">&gt;&gt;&gt; </span>image<!-- HTML_TAG_END --></pre></div> <div class="flex justify-center" data-svelte-h="svelte-1tjg4st"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/vqa-example.png" alt="VQA Image Example"></div> <p data-svelte-h="svelte-7dyn3e">Due to the questions’ and answers’ ambiguity, datasets like this are treated as a multi-label classification problem (as
multiple answers are possibly valid). Moreover, rather than just creating a one-hot encoded vector, one creates a
soft encoding, based on the number of times a certain answer appeared in the annotations.</p> <p data-svelte-h="svelte-1g9udhi">For instance, in the example above, because the answer “down” is selected way more often than other answers, it has a
score (called <code>weight</code> in the dataset) of 1.0, and the rest of the answers have scores &lt; 1.0.</p> <p data-svelte-h="svelte-11t4fgj">To later instantiate the model with an appropriate classification head, let’s create two dictionaries: one that maps
the label name to an integer and vice versa:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> itertools
<span class="hljs-meta">&gt;&gt;&gt; </span>labels = [item[<span class="hljs-string">&#x27;ids&#x27;</span>] <span class="hljs-keyword">for</span> item <span class="hljs-keyword">in</span> dataset[<span class="hljs-string">&#x27;label&#x27;</span>]]
<span class="hljs-meta">&gt;&gt;&gt; </span>flattened_labels = <span class="hljs-built_in">list</span>(itertools.chain(*labels))
<span class="hljs-meta">&gt;&gt;&gt; </span>unique_labels = <span class="hljs-built_in">list</span>(<span class="hljs-built_in">set</span>(flattened_labels))
<span class="hljs-meta">&gt;&gt;&gt; </span>label2id = {label: idx <span class="hljs-keyword">for</span> idx, label <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(unique_labels)}
<span class="hljs-meta">&gt;&gt;&gt; </span>id2label = {idx: label <span class="hljs-keyword">for</span> label, idx <span class="hljs-keyword">in</span> label2id.items()} <!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1p4inua">Now that we have the mappings, we can replace the string answers with their ids, and flatten the dataset for a more convenient further preprocessing.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">def</span> <span class="hljs-title function_">replace_ids</span>(<span class="hljs-params">inputs</span>):
<span class="hljs-meta">... </span> inputs[<span class="hljs-string">&quot;label&quot;</span>][<span class="hljs-string">&quot;ids&quot;</span>] = [label2id[x] <span class="hljs-keyword">for</span> x <span class="hljs-keyword">in</span> inputs[<span class="hljs-string">&quot;label&quot;</span>][<span class="hljs-string">&quot;ids&quot;</span>]]
<span class="hljs-meta">... </span> <span class="hljs-keyword">return</span> inputs
<span class="hljs-meta">&gt;&gt;&gt; </span>dataset = dataset.<span class="hljs-built_in">map</span>(replace_ids)
<span class="hljs-meta">&gt;&gt;&gt; </span>flat_dataset = dataset.flatten()
<span class="hljs-meta">&gt;&gt;&gt; </span>flat_dataset.features
{<span class="hljs-string">&#x27;question&#x27;</span>: Value(dtype=<span class="hljs-string">&#x27;string&#x27;</span>, <span class="hljs-built_in">id</span>=<span class="hljs-literal">None</span>),
<span class="hljs-string">&#x27;image_id&#x27;</span>: Value(dtype=<span class="hljs-string">&#x27;string&#x27;</span>, <span class="hljs-built_in">id</span>=<span class="hljs-literal">None</span>),
<span class="hljs-string">&#x27;label.ids&#x27;</span>: <span class="hljs-type">Sequence</span>(feature=Value(dtype=<span class="hljs-string">&#x27;int64&#x27;</span>, <span class="hljs-built_in">id</span>=<span class="hljs-literal">None</span>), length=-<span class="hljs-number">1</span>, <span class="hljs-built_in">id</span>=<span class="hljs-literal">None</span>),
<span class="hljs-string">&#x27;label.weights&#x27;</span>: <span class="hljs-type">Sequence</span>(feature=Value(dtype=<span class="hljs-string">&#x27;float64&#x27;</span>, <span class="hljs-built_in">id</span>=<span class="hljs-literal">None</span>), length=-<span class="hljs-number">1</span>, <span class="hljs-built_in">id</span>=<span class="hljs-literal">None</span>)}<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="preprocessing-data" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#preprocessing-data"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Preprocessing data</span></h2> <p data-svelte-h="svelte-md8uyj">The next step is to load a ViLT processor to prepare the image and text data for the model.
<a href="/docs/transformers/main/en/model_doc/vilt#transformers.ViltProcessor">ViltProcessor</a> wraps a BERT tokenizer and ViLT image processor into a convenient single processor:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> ViltProcessor
<span class="hljs-meta">&gt;&gt;&gt; </span>processor = ViltProcessor.from_pretrained(model_checkpoint)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-vdqbed">To preprocess the data we need to encode the images and questions using the <a href="/docs/transformers/main/en/model_doc/vilt#transformers.ViltProcessor">ViltProcessor</a>. The processor will use
the <a href="/docs/transformers/main/en/model_doc/bert#transformers.BertTokenizerFast">BertTokenizerFast</a> to tokenize the text and create <code>input_ids</code>, <code>attention_mask</code> and <code>token_type_ids</code> for the text data.
As for images, the processor will leverage <a href="/docs/transformers/main/en/model_doc/vilt#transformers.ViltImageProcessor">ViltImageProcessor</a> to resize and normalize the image, and create <code>pixel_values</code> and <code>pixel_mask</code>.</p> <p data-svelte-h="svelte-kuuiuf">All these preprocessing steps are done under the hood, we only need to call the <code>processor</code>. However, we still need to
prepare the target labels. In this representation, each element corresponds to a possible answer (label). For correct answers, the element holds
their respective score (weight), while the remaining elements are set to zero.</p> <p data-svelte-h="svelte-cgn1al">The following function applies the <code>processor</code> to the images and questions and formats the labels as described above:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> torch
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">def</span> <span class="hljs-title function_">preprocess_data</span>(<span class="hljs-params">examples</span>):
<span class="hljs-meta">... </span> image_paths = examples[<span class="hljs-string">&#x27;image_id&#x27;</span>]
<span class="hljs-meta">... </span> images = [Image.<span class="hljs-built_in">open</span>(image_path) <span class="hljs-keyword">for</span> image_path <span class="hljs-keyword">in</span> image_paths]
<span class="hljs-meta">... </span> texts = examples[<span class="hljs-string">&#x27;question&#x27;</span>]
<span class="hljs-meta">... </span> encoding = processor(images, texts, padding=<span class="hljs-string">&quot;max_length&quot;</span>, truncation=<span class="hljs-literal">True</span>, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>)
<span class="hljs-meta">... </span> <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> encoding.items():
<span class="hljs-meta">... </span> encoding[k] = v.squeeze()
<span class="hljs-meta">... </span> targets = []
<span class="hljs-meta">... </span> <span class="hljs-keyword">for</span> labels, scores <span class="hljs-keyword">in</span> <span class="hljs-built_in">zip</span>(examples[<span class="hljs-string">&#x27;label.ids&#x27;</span>], examples[<span class="hljs-string">&#x27;label.weights&#x27;</span>]):
<span class="hljs-meta">... </span> target = torch.zeros(<span class="hljs-built_in">len</span>(id2label))
<span class="hljs-meta">... </span> <span class="hljs-keyword">for</span> label, score <span class="hljs-keyword">in</span> <span class="hljs-built_in">zip</span>(labels, scores):
<span class="hljs-meta">... </span> target[label] = score
<span class="hljs-meta">... </span> targets.append(target)
<span class="hljs-meta">... </span> encoding[<span class="hljs-string">&quot;labels&quot;</span>] = targets
<span class="hljs-meta">... </span> <span class="hljs-keyword">return</span> encoding<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1og8o1g">To apply the preprocessing function over the entire dataset, use 🤗 Datasets <code>map</code> function. You can speed up <code>map</code> by
setting <code>batched=True</code> to process multiple elements of the dataset at once. At this point, feel free to remove the columns you don’t need.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>processed_dataset = flat_dataset.<span class="hljs-built_in">map</span>(preprocess_data, batched=<span class="hljs-literal">True</span>, remove_columns=[<span class="hljs-string">&#x27;question&#x27;</span>,<span class="hljs-string">&#x27;question_type&#x27;</span>, <span class="hljs-string">&#x27;question_id&#x27;</span>, <span class="hljs-string">&#x27;image_id&#x27;</span>, <span class="hljs-string">&#x27;answer_type&#x27;</span>, <span class="hljs-string">&#x27;label.ids&#x27;</span>, <span class="hljs-string">&#x27;label.weights&#x27;</span>])
<span class="hljs-meta">&gt;&gt;&gt; </span>processed_dataset
Dataset({
features: [<span class="hljs-string">&#x27;input_ids&#x27;</span>, <span class="hljs-string">&#x27;token_type_ids&#x27;</span>, <span class="hljs-string">&#x27;attention_mask&#x27;</span>, <span class="hljs-string">&#x27;pixel_values&#x27;</span>, <span class="hljs-string">&#x27;pixel_mask&#x27;</span>, <span class="hljs-string">&#x27;labels&#x27;</span>],
num_rows: <span class="hljs-number">200</span>
})<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1jw8zeo">As a final step, create a batch of examples using <a href="/docs/transformers/main/en/main_classes/data_collator#transformers.DefaultDataCollator">DefaultDataCollator</a>:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> DefaultDataCollator
<span class="hljs-meta">&gt;&gt;&gt; </span>data_collator = DefaultDataCollator()<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="train-the-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#train-the-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Train the model</span></h2> <p data-svelte-h="svelte-1iw20rl">You’re ready to start training your model now! Load ViLT with <a href="/docs/transformers/main/en/model_doc/vilt#transformers.ViltForQuestionAnswering">ViltForQuestionAnswering</a>. Specify the number of labels
along with the label mappings:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> ViltForQuestionAnswering
<span class="hljs-meta">&gt;&gt;&gt; </span>model = ViltForQuestionAnswering.from_pretrained(model_checkpoint, num_labels=<span class="hljs-built_in">len</span>(id2label), id2label=id2label, label2id=label2id)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-l42k0i">At this point, only three steps remain:</p> <ol data-svelte-h="svelte-8gsikg"><li>Define your training hyperparameters in <a href="/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a>:</li></ol> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments
<span class="hljs-meta">&gt;&gt;&gt; </span>repo_id = <span class="hljs-string">&quot;MariaK/vilt_finetuned_200&quot;</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>training_args = TrainingArguments(
<span class="hljs-meta">... </span> output_dir=repo_id,
<span class="hljs-meta">... </span> per_device_train_batch_size=<span class="hljs-number">4</span>,
<span class="hljs-meta">... </span> num_train_epochs=<span class="hljs-number">20</span>,
<span class="hljs-meta">... </span> save_steps=<span class="hljs-number">200</span>,
<span class="hljs-meta">... </span> logging_steps=<span class="hljs-number">50</span>,
<span class="hljs-meta">... </span> learning_rate=<span class="hljs-number">5e-5</span>,
<span class="hljs-meta">... </span> save_total_limit=<span class="hljs-number">2</span>,
<span class="hljs-meta">... </span> remove_unused_columns=<span class="hljs-literal">False</span>,
<span class="hljs-meta">... </span> push_to_hub=<span class="hljs-literal">True</span>,
<span class="hljs-meta">... </span>)<!-- HTML_TAG_END --></pre></div> <ol start="2" data-svelte-h="svelte-6ix8zl"><li>Pass the training arguments to <a href="/docs/transformers/main/en/main_classes/trainer#transformers.Trainer">Trainer</a> along with the model, dataset, processor, and data collator.</li></ol> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> Trainer
<span class="hljs-meta">&gt;&gt;&gt; </span>trainer = Trainer(
<span class="hljs-meta">... </span> model=model,
<span class="hljs-meta">... </span> args=training_args,
<span class="hljs-meta">... </span> data_collator=data_collator,
<span class="hljs-meta">... </span> train_dataset=processed_dataset,
<span class="hljs-meta">... </span> tokenizer=processor,
<span class="hljs-meta">... </span>)<!-- HTML_TAG_END --></pre></div> <ol start="3" data-svelte-h="svelte-18cisw2"><li>Call <a href="/docs/transformers/main/en/main_classes/trainer#transformers.Trainer.train">train()</a> to finetune your model.</li></ol> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>trainer.train() <!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-13v7w4m">Once training is completed, share your model to the Hub with the <a href="/docs/transformers/main/en/main_classes/trainer#transformers.Trainer.push_to_hub">push_to_hub()</a> method to share your final model on the 🤗 Hub:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>trainer.push_to_hub()<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="inference" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#inference"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Inference</span></h2> <p data-svelte-h="svelte-y390us">Now that you have fine-tuned a ViLT model, and uploaded it to the 🤗 Hub, you can use it for inference. The simplest
way to try out your fine-tuned model for inference is to use it in a <a href="/docs/transformers/main/en/main_classes/pipelines#transformers.Pipeline">Pipeline</a>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> pipeline
<span class="hljs-meta">&gt;&gt;&gt; </span>pipe = pipeline(<span class="hljs-string">&quot;visual-question-answering&quot;</span>, model=<span class="hljs-string">&quot;MariaK/vilt_finetuned_200&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1lvulzt">The model in this guide has only been trained on 200 examples, so don’t expect a lot from it. Let’s see if it at least
learned something from the data and take the first example from the dataset to illustrate inference:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>example = dataset[<span class="hljs-number">0</span>]
<span class="hljs-meta">&gt;&gt;&gt; </span>image = Image.<span class="hljs-built_in">open</span>(example[<span class="hljs-string">&#x27;image_id&#x27;</span>])
<span class="hljs-meta">&gt;&gt;&gt; </span>question = example[<span class="hljs-string">&#x27;question&#x27;</span>]
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-built_in">print</span>(question)
<span class="hljs-meta">&gt;&gt;&gt; </span>pipe(image, question, top_k=<span class="hljs-number">1</span>)
<span class="hljs-string">&quot;Where is he looking?&quot;</span>
[{<span class="hljs-string">&#x27;score&#x27;</span>: <span class="hljs-number">0.5498199462890625</span>, <span class="hljs-string">&#x27;answer&#x27;</span>: <span class="hljs-string">&#x27;down&#x27;</span>}]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-mc0cvg">Even though not very confident, the model indeed has learned something. With more examples and longer training, you’ll get far better results!</p> <p data-svelte-h="svelte-o6117l">You can also manually replicate the results of the pipeline if you’d like:</p> <ol data-svelte-h="svelte-346f3f"><li>Take an image and a question, prepare them for the model using the processor from your model.</li> <li>Forward the result or preprocessing through the model.</li> <li>From the logits, get the most likely answer’s id, and find the actual answer in the <code>id2label</code>.</li></ol> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>processor = ViltProcessor.from_pretrained(<span class="hljs-string">&quot;MariaK/vilt_finetuned_200&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>image = Image.<span class="hljs-built_in">open</span>(example[<span class="hljs-string">&#x27;image_id&#x27;</span>])
<span class="hljs-meta">&gt;&gt;&gt; </span>question = example[<span class="hljs-string">&#x27;question&#x27;</span>]
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-comment"># prepare inputs</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>inputs = processor(image, question, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>model = ViltForQuestionAnswering.from_pretrained(<span class="hljs-string">&quot;MariaK/vilt_finetuned_200&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-comment"># forward pass</span>
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">with</span> torch.no_grad():
<span class="hljs-meta">... </span> outputs = model(**inputs)
<span class="hljs-meta">&gt;&gt;&gt; </span>logits = outputs.logits
<span class="hljs-meta">&gt;&gt;&gt; </span>idx = logits.argmax(-<span class="hljs-number">1</span>).item()
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;Predicted answer:&quot;</span>, model.config.id2label[idx])
Predicted answer: down<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="zero-shot-vqa" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#zero-shot-vqa"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Zero-shot VQA</span></h2> <p data-svelte-h="svelte-4nhby3">The previous model treated VQA as a classification task. Some recent models, such as BLIP, BLIP-2, and InstructBLIP approach
VQA as a generative task. Let’s take <a href="../model_doc/blip-2">BLIP-2</a> as an example. It introduced a new visual-language pre-training
paradigm in which any combination of pre-trained vision encoder and LLM can be used (learn more in the <a href="https://huggingface.co/blog/blip-2" rel="nofollow">BLIP-2 blog post</a>).
This enables achieving state-of-the-art results on multiple visual-language tasks including visual question answering.</p> <p data-svelte-h="svelte-eq4cdk">Let’s illustrate how you can use this model for VQA. First, let’s load the model. Here we’ll explicitly send the model to a
GPU, if available, which we didn’t need to do earlier when training, as <a href="/docs/transformers/main/en/main_classes/trainer#transformers.Trainer">Trainer</a> handles this automatically:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoProcessor, Blip2ForConditionalGeneration
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> torch
<span class="hljs-meta">&gt;&gt;&gt; </span>processor = AutoProcessor.from_pretrained(<span class="hljs-string">&quot;Salesforce/blip2-opt-2.7b&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>model = Blip2ForConditionalGeneration.from_pretrained(<span class="hljs-string">&quot;Salesforce/blip2-opt-2.7b&quot;</span>, torch_dtype=torch.float16)
<span class="hljs-meta">&gt;&gt;&gt; </span>device = <span class="hljs-string">&quot;cuda&quot;</span> <span class="hljs-keyword">if</span> torch.cuda.is_available() <span class="hljs-keyword">else</span> <span class="hljs-string">&quot;cpu&quot;</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>model.to(device)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-w113ee">The model takes image and text as input, so let’s use the exact same image/question pair from the first example in the VQA dataset:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>example = dataset[<span class="hljs-number">0</span>]
<span class="hljs-meta">&gt;&gt;&gt; </span>image = Image.<span class="hljs-built_in">open</span>(example[<span class="hljs-string">&#x27;image_id&#x27;</span>])
<span class="hljs-meta">&gt;&gt;&gt; </span>question = example[<span class="hljs-string">&#x27;question&#x27;</span>]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-n2zykh">To use BLIP-2 for visual question answering task, the textual prompt has to follow a specific format: <code>Question: {} Answer:</code>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>prompt = <span class="hljs-string">f&quot;Question: <span class="hljs-subst">{question}</span> Answer:&quot;</span> <!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-10l7d2y">Now we need to preprocess the image/prompt with the model’s processor, pass the processed input through the model, and decode the output:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>inputs = processor(image, text=prompt, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>).to(device, torch.float16)
<span class="hljs-meta">&gt;&gt;&gt; </span>generated_ids = model.generate(**inputs, max_new_tokens=<span class="hljs-number">10</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>generated_text = processor.batch_decode(generated_ids, skip_special_tokens=<span class="hljs-literal">True</span>)[<span class="hljs-number">0</span>].strip()
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-built_in">print</span>(generated_text)
<span class="hljs-string">&quot;He is looking at the crowd&quot;</span> <!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1wjg6co">As you can see, the model recognized the crowd, and the direction of the face (looking down), however, it seems to miss
the fact the crowd is behind the skater. Still, in cases where acquiring human-annotated datasets is not feasible, this
approach can quickly produce useful results.</p> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/transformers/blob/main/docs/source/en/tasks/visual_question_answering.md" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p>
<script>
{
__sveltekit_1xexzbk = {
assets: "/docs/transformers/main/en",
base: "/docs/transformers/main/en",
env: {}
};
const element = document.currentScript.parentElement;
const data = [null,null];
Promise.all([
import("/docs/transformers/main/en/_app/immutable/entry/start.2135b7e6.js"),
import("/docs/transformers/main/en/_app/immutable/entry/app.24372c84.js")
]).then(([kit, app]) => {
kit.start(app, element, {
node_ids: [0, 425],
data,
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
73 kB
·
Xet hash:
0b3db472268a8476bd94d2f149de120347e9b96f1a9eb215e0979163231e09f5

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.