Buckets:

hf-doc-build/doc-dev / transformers /main /en /tasks /video_classification.html
rtrm's picture
download
raw
88.7 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Video classification&quot;,&quot;local&quot;:&quot;video-classification&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Load UCF101 dataset&quot;,&quot;local&quot;:&quot;load-ucf101-dataset&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Load a model to fine-tune&quot;,&quot;local&quot;:&quot;load-a-model-to-fine-tune&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Prepare the datasets for training&quot;,&quot;local&quot;:&quot;prepare-the-datasets-for-training&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Visualize the preprocessed video for better debugging&quot;,&quot;local&quot;:&quot;visualize-the-preprocessed-video-for-better-debugging&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Train the model&quot;,&quot;local&quot;:&quot;train-the-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Inference&quot;,&quot;local&quot;:&quot;inference&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link rel="stylesheet" href="/docs/transformers/main/en/_app/immutable/assets/0.e3b0c442.css">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/entry/start.2135b7e6.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/scheduler.25b97de1.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/singletons.0f2b7d5f.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/index.e188933d.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/paths.3d04d2c6.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/entry/app.24372c84.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/index.d9030fc9.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/nodes/0.026d2fdd.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/nodes/423.d64aacee.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/Tip.baa67368.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/CodeBlock.e6cd0d95.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/DocNotebookDropdown.5ea6cb78.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/globals.7f7f1b26.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/EditOnGithub.91d95064.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Video classification&quot;,&quot;local&quot;:&quot;video-classification&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Load UCF101 dataset&quot;,&quot;local&quot;:&quot;load-ucf101-dataset&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Load a model to fine-tune&quot;,&quot;local&quot;:&quot;load-a-model-to-fine-tune&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Prepare the datasets for training&quot;,&quot;local&quot;:&quot;prepare-the-datasets-for-training&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Visualize the preprocessed video for better debugging&quot;,&quot;local&quot;:&quot;visualize-the-preprocessed-video-for-better-debugging&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Train the model&quot;,&quot;local&quot;:&quot;train-the-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Inference&quot;,&quot;local&quot;:&quot;inference&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="video-classification" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#video-classification"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 
0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Video classification</span></h1> <div class="flex space-x-1 absolute z-10 right-0 top-0"> <div class="relative colab-dropdown "> <button class=" " type="button"> <img alt="Open In Colab" class="!m-0" src="https://colab.research.google.com/assets/colab-badge.svg"> </button> </div> <div class="relative colab-dropdown "> <button class=" " type="button"> <img alt="Open In Studio Lab" class="!m-0" src="https://studiolab.sagemaker.aws/studiolab.svg"> </button> </div></div> <p data-svelte-h="svelte-1uw5n59">Video classification is the task of assigning a label or class to an entire video. Videos are expected to have only one class for each video. Video classification models take a video as input and return a prediction about which class the video belongs to. These models can be used to categorize what a video is all about. A real-world application of video classification is action / activity recognition, which is useful for fitness applications. 
It is also helpful for vision-impaired individuals, especially when they are commuting.</p> <p data-svelte-h="svelte-1aff4p7">This guide will show you how to:</p> <ol data-svelte-h="svelte-1qfvs25"><li>Fine-tune <a href="https://huggingface.co/docs/transformers/main/en/model_doc/videomae" rel="nofollow">VideoMAE</a> on a subset of the <a href="https://www.crcv.ucf.edu/data/UCF101.php" rel="nofollow">UCF101</a> dataset.</li> <li>Use your fine-tuned model for inference.</li></ol> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1cybrxm">To see all architectures and checkpoints compatible with this task, we recommend checking the <a href="https://huggingface.co/tasks/video-classification" rel="nofollow">task-page</a>.</p></div> <p data-svelte-h="svelte-1c9nexd">Before you begin, make sure you have all the necessary libraries installed:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 
opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pip install -q pytorchvideo transformers evaluate<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-cnicg0">You will use <a href="https://pytorchvideo.org/" rel="nofollow">PyTorchVideo</a> (dubbed <code>pytorchvideo</code>) to process and prepare the videos.</p> <p data-svelte-h="svelte-27hn0u">We encourage you to log in to your Hugging Face account so you can upload and share your model with the community. When prompted, enter your token to log in:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span 
class="hljs-keyword">from</span> huggingface_hub <span class="hljs-keyword">import</span> notebook_login
<span class="hljs-meta">&gt;&gt;&gt; </span>notebook_login()<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="load-ucf101-dataset" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#load-ucf101-dataset"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Load UCF101 dataset</span></h2> <p data-svelte-h="svelte-zo4tpc">Start by loading a subset of the <a href="https://www.crcv.ucf.edu/data/UCF101.php" rel="nofollow">UCF-101 dataset</a>. 
This will give you a chance to experiment and make sure everything works before spending more time training on the full dataset.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> huggingface_hub <span class="hljs-keyword">import</span> hf_hub_download
<span class="hljs-meta">&gt;&gt;&gt; </span>hf_dataset_identifier = <span class="hljs-string">&quot;sayakpaul/ucf101-subset&quot;</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>filename = <span class="hljs-string">&quot;UCF101_subset.tar.gz&quot;</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>file_path = hf_hub_download(repo_id=hf_dataset_identifier, filename=filename, repo_type=<span class="hljs-string">&quot;dataset&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-yyrc6l">After the subset has been downloaded, you need to extract the compressed archive:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> tarfile
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">with</span> tarfile.<span class="hljs-built_in">open</span>(file_path) <span class="hljs-keyword">as</span> t:
<span class="hljs-meta">... </span> t.extractall(<span class="hljs-string">&quot;.&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-vxwx6z">At a high level, the dataset is organized like so:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->UCF101_subset/
    train/
        BandMarching/
            video_1.avi
            video_2.avi
            ...
        Archery/
            video_1.avi
            video_2.avi
            ...
        ...
    val/
        BandMarching/
            video_1.avi
            video_2.avi
            ...
        Archery/
            video_1.avi
            video_2.avi
            ...
        ...
    <span class="hljs-built_in">test</span>/
        BandMarching/
            video_1.avi
            video_2.avi
            ...
        Archery/
            video_1.avi
            video_2.avi
            ...
        ...<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1q9jgqq">You can then count the total number of videos.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> pathlib
<span class="hljs-meta">&gt;&gt;&gt; </span>dataset_root_path = <span class="hljs-string">&quot;UCF101_subset&quot;</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>dataset_root_path = pathlib.Path(dataset_root_path)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>video_count_train = <span class="hljs-built_in">len</span>(<span class="hljs-built_in">list</span>(dataset_root_path.glob(<span class="hljs-string">&quot;train/*/*.avi&quot;</span>)))
<span class="hljs-meta">&gt;&gt;&gt; </span>video_count_val = <span class="hljs-built_in">len</span>(<span class="hljs-built_in">list</span>(dataset_root_path.glob(<span class="hljs-string">&quot;val/*/*.avi&quot;</span>)))
<span class="hljs-meta">&gt;&gt;&gt; </span>video_count_test = <span class="hljs-built_in">len</span>(<span class="hljs-built_in">list</span>(dataset_root_path.glob(<span class="hljs-string">&quot;test/*/*.avi&quot;</span>)))
<span class="hljs-meta">&gt;&gt;&gt; </span>video_total = video_count_train + video_count_val + video_count_test
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Total videos: <span class="hljs-subst">{video_total}</span>&quot;</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>all_video_file_paths = (
<span class="hljs-meta">... </span> <span class="hljs-built_in">list</span>(dataset_root_path.glob(<span class="hljs-string">&quot;train/*/*.avi&quot;</span>))
<span class="hljs-meta">... </span> + <span class="hljs-built_in">list</span>(dataset_root_path.glob(<span class="hljs-string">&quot;val/*/*.avi&quot;</span>))
<span class="hljs-meta">... </span> + <span class="hljs-built_in">list</span>(dataset_root_path.glob(<span class="hljs-string">&quot;test/*/*.avi&quot;</span>))
<span class="hljs-meta">... </span> )
<span class="hljs-meta">&gt;&gt;&gt; </span>all_video_file_paths[:<span class="hljs-number">5</span>]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-xl39ap">The (<code>sorted</code>) video paths appear like so:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->...
<span class="hljs-string">&#x27;UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g07_c04.avi&#x27;</span>,
<span class="hljs-string">&#x27;UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g07_c06.avi&#x27;</span>,
<span class="hljs-string">&#x27;UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g08_c01.avi&#x27;</span>,
<span class="hljs-string">&#x27;UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g09_c02.avi&#x27;</span>,
<span class="hljs-string">&#x27;UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g09_c06.avi&#x27;</span>
...<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1t7u230">You will notice that there are video clips belonging to the same group / scene where group is denoted by <code>g</code> in the video file paths. <code>v_ApplyEyeMakeup_g07_c04.avi</code> and <code>v_ApplyEyeMakeup_g07_c06.avi</code>, for example.</p> <p data-svelte-h="svelte-igo46q">For the validation and evaluation splits, you wouldn’t want to have video clips from the same group / scene to prevent <a href="https://www.kaggle.com/code/alexisbcook/data-leakage" rel="nofollow">data leakage</a>. The subset that you are using in this tutorial takes this information into account.</p> <p data-svelte-h="svelte-4ll1ff">Next up, you will derive the set of labels present in the dataset. Also, create two dictionaries that’ll be helpful when initializing the model:</p> <ul data-svelte-h="svelte-1y0n38a"><li><code>label2id</code>: maps the class names to integers.</li> <li><code>id2label</code>: maps the integers to class names.</li></ul> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full 
left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>class_labels = <span class="hljs-built_in">sorted</span>({<span class="hljs-built_in">str</span>(path).split(<span class="hljs-string">&quot;/&quot;</span>)[<span class="hljs-number">2</span>] <span class="hljs-keyword">for</span> path <span class="hljs-keyword">in</span> all_video_file_paths})
<span class="hljs-meta">&gt;&gt;&gt; </span>label2id = {label: i <span class="hljs-keyword">for</span> i, label <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(class_labels)}
<span class="hljs-meta">&gt;&gt;&gt; </span>id2label = {i: label <span class="hljs-keyword">for</span> label, i <span class="hljs-keyword">in</span> label2id.items()}
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Unique classes: <span class="hljs-subst">{<span class="hljs-built_in">list</span>(label2id.keys())}</span>.&quot;</span>)
<span class="hljs-comment"># Unique classes: [&#x27;ApplyEyeMakeup&#x27;, &#x27;ApplyLipstick&#x27;, &#x27;Archery&#x27;, &#x27;BabyCrawling&#x27;, &#x27;BalanceBeam&#x27;, &#x27;BandMarching&#x27;, &#x27;BaseballPitch&#x27;, &#x27;Basketball&#x27;, &#x27;BasketballDunk&#x27;, &#x27;BenchPress&#x27;].</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1z0r2k5">There are 10 unique classes. For each class, there are 30 videos in the training set.</p> <h2 class="relative group"><a id="load-a-model-to-fine-tune" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#load-a-model-to-fine-tune"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Load a model to fine-tune</span></h2> <p data-svelte-h="svelte-14088fx">Instantiate a video classification model from a pretrained checkpoint and its associated image processor. The model’s encoder comes with pre-trained parameters, and the classification head is randomly initialized. 
The image processor will come in handy when writing the preprocessing pipeline for our dataset.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> VideoMAEImageProcessor, VideoMAEForVideoClassification
<span class="hljs-meta">&gt;&gt;&gt; </span>model_ckpt = <span class="hljs-string">&quot;MCG-NJU/videomae-base&quot;</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>image_processor = VideoMAEImageProcessor.from_pretrained(model_ckpt)
<span class="hljs-meta">&gt;&gt;&gt; </span>model = VideoMAEForVideoClassification.from_pretrained(
<span class="hljs-meta">... </span> model_ckpt,
<span class="hljs-meta">... </span> label2id=label2id,
<span class="hljs-meta">... </span> id2label=id2label,
<span class="hljs-meta">... </span> ignore_mismatched_sizes=<span class="hljs-literal">True</span>, <span class="hljs-comment"># provide this in case you&#x27;re planning to fine-tune an already fine-tuned checkpoint</span>
<span class="hljs-meta">... </span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1rauzal">While the model is loading, you might notice the following warning:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->Some weights of the model checkpoint at MCG-NJU/videomae-base were not used when initializing VideoMAEForVideoClassification: [..., <span class="hljs-string">&#x27;decoder.decoder_layers.1.attention.output.dense.bias&#x27;</span>, <span class="hljs-string">&#x27;decoder.decoder_layers.2.attention.attention.key.weight&#x27;</span>]
- This IS expected <span class="hljs-keyword">if</span> you are initializing VideoMAEForVideoClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected <span class="hljs-keyword">if</span> you are initializing VideoMAEForVideoClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of VideoMAEForVideoClassification were not initialized from the model checkpoint at MCG-NJU/videomae-base and are newly initialized: [<span class="hljs-string">&#x27;classifier.bias&#x27;</span>, <span class="hljs-string">&#x27;classifier.weight&#x27;</span>]
You should probably TRAIN this model on a down-stream task to be able to use it <span class="hljs-keyword">for</span> predictions and inference.<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-19dkvgp">The warning is telling us we are throwing away some weights (e.g. the weights and biases of the pretraining <code>decoder</code> layers) and randomly initializing some others (the weights and bias of a new <code>classifier</code> layer). This is expected in this case, because we are adding a new head for which we don’t have pretrained weights, so the library warns us we should fine-tune this model before using it for inference, which is exactly what we are going to do.</p> <p data-svelte-h="svelte-4nczcs"><strong>Note</strong> that <a href="https://huggingface.co/MCG-NJU/videomae-base-finetuned-kinetics" rel="nofollow">this checkpoint</a> leads to better performance on this task as the checkpoint was obtained by fine-tuning on a similar downstream task having considerable domain overlap. You can check out <a href="https://huggingface.co/sayakpaul/videomae-base-finetuned-kinetics-finetuned-ucf101-subset" rel="nofollow">this checkpoint</a> which was obtained by fine-tuning <code>MCG-NJU/videomae-base-finetuned-kinetics</code>.</p> <h2 class="relative group"><a id="prepare-the-datasets-for-training" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#prepare-the-datasets-for-training"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 
79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Prepare the datasets for training</span></h2> <p data-svelte-h="svelte-islfqt">For preprocessing the videos, you will leverage the <a href="https://pytorchvideo.org/" rel="nofollow">PyTorchVideo library</a>. Start by importing the dependencies we need.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> pytorchvideo.data
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> pytorchvideo.transforms <span class="hljs-keyword">import</span> (
<span class="hljs-meta">... </span> ApplyTransformToKey,
<span class="hljs-meta">... </span> Normalize,
<span class="hljs-meta">... </span> RandomShortSideScale,
<span class="hljs-meta">... </span> RemoveKey,
<span class="hljs-meta">... </span> ShortSideScale,
<span class="hljs-meta">... </span> UniformTemporalSubsample,
<span class="hljs-meta">... </span>)
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> torchvision.transforms <span class="hljs-keyword">import</span> (
<span class="hljs-meta">... </span> Compose,
<span class="hljs-meta">... </span> Lambda,
<span class="hljs-meta">... </span> RandomCrop,
<span class="hljs-meta">... </span> RandomHorizontalFlip,
<span class="hljs-meta">... </span> Resize,
<span class="hljs-meta">... </span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-5m4gle">For the training dataset transformations, use a combination of uniform temporal subsampling, pixel normalization, random short-side rescaling, random cropping, and random horizontal flipping. For the validation and evaluation dataset transformations, keep the same transformation chain, except that the random rescaling and cropping are replaced by a plain resize and the horizontal flipping is dropped. To learn more about the details of these transformations, check out the <a href="https://pytorchvideo.org" rel="nofollow">official documentation of PyTorchVideo</a>.</p> <p data-svelte-h="svelte-8w7a7w">Use the <code>image_processor</code> associated with the pre-trained model to obtain the following information:</p> <ul data-svelte-h="svelte-u2neln"><li>Image mean and standard deviation with which the video frame pixels will be normalized.</li> <li>Spatial resolution to which the video frames will be resized.</li></ul> <p data-svelte-h="svelte-llv4fi">Start by defining some constants.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 
transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>mean = image_processor.image_mean
<span class="hljs-meta">&gt;&gt;&gt; </span>std = image_processor.image_std
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">if</span> <span class="hljs-string">&quot;shortest_edge&quot;</span> <span class="hljs-keyword">in</span> image_processor.size:
<span class="hljs-meta">... </span> height = width = image_processor.size[<span class="hljs-string">&quot;shortest_edge&quot;</span>]
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">else</span>:
<span class="hljs-meta">... </span> height = image_processor.size[<span class="hljs-string">&quot;height&quot;</span>]
<span class="hljs-meta">... </span> width = image_processor.size[<span class="hljs-string">&quot;width&quot;</span>]
<span class="hljs-meta">&gt;&gt;&gt; </span>resize_to = (height, width)
<span class="hljs-meta">&gt;&gt;&gt; </span>num_frames_to_sample = model.config.num_frames
<span class="hljs-meta">&gt;&gt;&gt; </span>sample_rate = <span class="hljs-number">4</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>fps = <span class="hljs-number">30</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>clip_duration = num_frames_to_sample * sample_rate / fps<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1lr33l4">Now, define the dataset-specific transformations and the datasets respectively. Starting with the training set:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>train_transform = Compose(
<span class="hljs-meta">... </span> [
<span class="hljs-meta">... </span> ApplyTransformToKey(
<span class="hljs-meta">... </span> key=<span class="hljs-string">&quot;video&quot;</span>,
<span class="hljs-meta">... </span> transform=Compose(
<span class="hljs-meta">... </span> [
<span class="hljs-meta">... </span> UniformTemporalSubsample(num_frames_to_sample),
<span class="hljs-meta">... </span> Lambda(<span class="hljs-keyword">lambda</span> x: x / <span class="hljs-number">255.0</span>),
<span class="hljs-meta">... </span> Normalize(mean, std),
<span class="hljs-meta">... </span> RandomShortSideScale(min_size=<span class="hljs-number">256</span>, max_size=<span class="hljs-number">320</span>),
<span class="hljs-meta">... </span> RandomCrop(resize_to),
<span class="hljs-meta">... </span> RandomHorizontalFlip(p=<span class="hljs-number">0.5</span>),
<span class="hljs-meta">... </span> ]
<span class="hljs-meta">... </span> ),
<span class="hljs-meta">... </span> ),
<span class="hljs-meta">... </span> ]
<span class="hljs-meta">... </span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>train_dataset = pytorchvideo.data.Ucf101(
<span class="hljs-meta">... </span> data_path=os.path.join(dataset_root_path, <span class="hljs-string">&quot;train&quot;</span>),
<span class="hljs-meta">... </span> clip_sampler=pytorchvideo.data.make_clip_sampler(<span class="hljs-string">&quot;random&quot;</span>, clip_duration),
<span class="hljs-meta">... </span> decode_audio=<span class="hljs-literal">False</span>,
<span class="hljs-meta">... </span> transform=train_transform,
<span class="hljs-meta">... </span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-92qkhv">The same sequence of workflow can be applied to the validation and evaluation sets:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>val_transform = Compose(
<span class="hljs-meta">... </span> [
<span class="hljs-meta">... </span> ApplyTransformToKey(
<span class="hljs-meta">... </span> key=<span class="hljs-string">&quot;video&quot;</span>,
<span class="hljs-meta">... </span> transform=Compose(
<span class="hljs-meta">... </span> [
<span class="hljs-meta">... </span> UniformTemporalSubsample(num_frames_to_sample),
<span class="hljs-meta">... </span> Lambda(<span class="hljs-keyword">lambda</span> x: x / <span class="hljs-number">255.0</span>),
<span class="hljs-meta">... </span> Normalize(mean, std),
<span class="hljs-meta">... </span> Resize(resize_to),
<span class="hljs-meta">... </span> ]
<span class="hljs-meta">... </span> ),
<span class="hljs-meta">... </span> ),
<span class="hljs-meta">... </span> ]
<span class="hljs-meta">... </span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>val_dataset = pytorchvideo.data.Ucf101(
<span class="hljs-meta">... </span> data_path=os.path.join(dataset_root_path, <span class="hljs-string">&quot;val&quot;</span>),
<span class="hljs-meta">... </span> clip_sampler=pytorchvideo.data.make_clip_sampler(<span class="hljs-string">&quot;uniform&quot;</span>, clip_duration),
<span class="hljs-meta">... </span> decode_audio=<span class="hljs-literal">False</span>,
<span class="hljs-meta">... </span> transform=val_transform,
<span class="hljs-meta">... </span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>test_dataset = pytorchvideo.data.Ucf101(
<span class="hljs-meta">... </span> data_path=os.path.join(dataset_root_path, <span class="hljs-string">&quot;test&quot;</span>),
<span class="hljs-meta">... </span> clip_sampler=pytorchvideo.data.make_clip_sampler(<span class="hljs-string">&quot;uniform&quot;</span>, clip_duration),
<span class="hljs-meta">... </span> decode_audio=<span class="hljs-literal">False</span>,
<span class="hljs-meta">... </span> transform=val_transform,
<span class="hljs-meta">... </span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-lyltai"><strong>Note</strong>: The above dataset pipelines are taken from the <a href="https://pytorchvideo.org/docs/tutorial_classification#dataset" rel="nofollow">official PyTorchVideo example</a>. We’re using the <a href="https://pytorchvideo.readthedocs.io/en/latest/api/data/data.html#pytorchvideo.data.Ucf101" rel="nofollow"><code>pytorchvideo.data.Ucf101()</code></a> function because it’s tailored for the UCF-101 dataset. Under the hood, it returns a <a href="https://pytorchvideo.readthedocs.io/en/latest/api/data/data.html#pytorchvideo.data.LabeledVideoDataset" rel="nofollow"><code>pytorchvideo.data.labeled_video_dataset.LabeledVideoDataset</code></a> object. The <code>LabeledVideoDataset</code> class is the base class for all labeled video datasets in PyTorchVideo. So, if you want to use a custom dataset not supported off-the-shelf by PyTorchVideo, you can extend the <code>LabeledVideoDataset</code> class accordingly. Refer to the <a href="https://pytorchvideo.readthedocs.io/en/latest/api/data/data.html" rel="nofollow"><code>data</code> API documentation</a> to learn more. 
Also, if your dataset follows a similar structure (as shown above), then using <code>pytorchvideo.data.Ucf101()</code> should work just fine.</p> <p data-svelte-h="svelte-1vli4t8">You can access the <code>num_videos</code> attribute to get the number of videos in the dataset.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-built_in">print</span>(train_dataset.num_videos, val_dataset.num_videos, test_dataset.num_videos)
<span class="hljs-comment"># (300, 30, 75)</span><!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="visualize-the-preprocessed-video-for-better-debugging" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#visualize-the-preprocessed-video-for-better-debugging"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Visualize the preprocessed video for better debugging</span></h2> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div 
class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> imageio
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> IPython.display <span class="hljs-keyword">import</span> Image
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">def</span> <span class="hljs-title function_">unnormalize_img</span>(<span class="hljs-params">img</span>):
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;&quot;&quot;Un-normalizes the image pixels.&quot;&quot;&quot;</span>
<span class="hljs-meta">... </span> img = (img * std) + mean
<span class="hljs-meta">... </span> img = (img * <span class="hljs-number">255</span>).clip(<span class="hljs-number">0</span>, <span class="hljs-number">255</span>)
<span class="hljs-meta">... </span> <span class="hljs-keyword">return</span> img.astype(<span class="hljs-string">&quot;uint8&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">def</span> <span class="hljs-title function_">create_gif</span>(<span class="hljs-params">video_tensor, filename=<span class="hljs-string">&quot;sample.gif&quot;</span></span>):
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;&quot;&quot;Prepares a GIF from a video tensor.
<span class="hljs-meta">... </span>
<span class="hljs-meta">... </span> The video tensor is expected to have the following shape:
<span class="hljs-meta">... </span> (num_frames, num_channels, height, width).
<span class="hljs-meta">... </span> &quot;&quot;&quot;</span>
<span class="hljs-meta">... </span> frames = []
<span class="hljs-meta">... </span> <span class="hljs-keyword">for</span> video_frame <span class="hljs-keyword">in</span> video_tensor:
<span class="hljs-meta">... </span> frame_unnormalized = unnormalize_img(video_frame.permute(<span class="hljs-number">1</span>, <span class="hljs-number">2</span>, <span class="hljs-number">0</span>).numpy())
<span class="hljs-meta">... </span> frames.append(frame_unnormalized)
<span class="hljs-meta">... </span> kargs = {<span class="hljs-string">&quot;duration&quot;</span>: <span class="hljs-number">0.25</span>}
<span class="hljs-meta">... </span> imageio.mimsave(filename, frames, <span class="hljs-string">&quot;GIF&quot;</span>, **kargs)
<span class="hljs-meta">... </span> <span class="hljs-keyword">return</span> filename
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">def</span> <span class="hljs-title function_">display_gif</span>(<span class="hljs-params">video_tensor, gif_name=<span class="hljs-string">&quot;sample.gif&quot;</span></span>):
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;&quot;&quot;Prepares and displays a GIF from a video tensor.&quot;&quot;&quot;</span>
<span class="hljs-meta">... </span> video_tensor = video_tensor.permute(<span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">2</span>, <span class="hljs-number">3</span>)
<span class="hljs-meta">... </span> gif_filename = create_gif(video_tensor, gif_name)
<span class="hljs-meta">... </span> <span class="hljs-keyword">return</span> Image(filename=gif_filename)
<span class="hljs-meta">&gt;&gt;&gt; </span>sample_video = <span class="hljs-built_in">next</span>(<span class="hljs-built_in">iter</span>(train_dataset))
<span class="hljs-meta">&gt;&gt;&gt; </span>video_tensor = sample_video[<span class="hljs-string">&quot;video&quot;</span>]
<span class="hljs-meta">&gt;&gt;&gt; </span>display_gif(video_tensor)<!-- HTML_TAG_END --></pre></div> <div class="flex justify-center" data-svelte-h="svelte-1mxsghh"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/sample_gif.gif" alt="Person playing basketball"></div> <h2 class="relative group"><a id="train-the-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#train-the-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Train the model</span></h2> <p data-svelte-h="svelte-rirkpj">Leverage <a href="https://huggingface.co/docs/transformers/main_classes/trainer" rel="nofollow"><code>Trainer</code></a> from 🤗 Transformers for training the model. To instantiate a <code>Trainer</code>, you need to define the training configuration and an evaluation metric. The most important is the <a href="https://huggingface.co/transformers/main_classes/trainer.html#transformers.TrainingArguments" rel="nofollow"><code>TrainingArguments</code></a>, which is a class that contains all the attributes to configure the training. 
It requires an output folder name, which will be used to save the checkpoints of the model. It also helps sync all the information in the model repository on the 🤗 Hub.</p> <p data-svelte-h="svelte-ixq0kp">Most of the training arguments are self-explanatory, but one that is quite important here is <code>remove_unused_columns=False</code>. When this option is <code>True</code>, the <code>Trainer</code> drops any features not used by the model’s call function. By default it’s <code>True</code> because usually it’s ideal to drop unused feature columns, making it easier to unpack inputs into the model’s call function. But, in this case, you need the unused features (<code>video</code> in particular) in order to create <code>pixel_values</code> (which is a mandatory key our model expects in its inputs), so it must be set to <code>False</code>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; 
</span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments, Trainer
<span class="hljs-meta">&gt;&gt;&gt; </span>model_name = model_ckpt.split(<span class="hljs-string">&quot;/&quot;</span>)[-<span class="hljs-number">1</span>]
<span class="hljs-meta">&gt;&gt;&gt; </span>new_model_name = <span class="hljs-string">f&quot;<span class="hljs-subst">{model_name}</span>-finetuned-ucf101-subset&quot;</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>num_epochs = <span class="hljs-number">4</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>args = TrainingArguments(
<span class="hljs-meta">... </span> new_model_name,
<span class="hljs-meta">... </span> remove_unused_columns=<span class="hljs-literal">False</span>,
<span class="hljs-meta">... </span> eval_strategy=<span class="hljs-string">&quot;epoch&quot;</span>,
<span class="hljs-meta">... </span> save_strategy=<span class="hljs-string">&quot;epoch&quot;</span>,
<span class="hljs-meta">... </span> learning_rate=<span class="hljs-number">5e-5</span>,
<span class="hljs-meta">... </span> per_device_train_batch_size=batch_size,
<span class="hljs-meta">... </span> per_device_eval_batch_size=batch_size,
<span class="hljs-meta">... </span> warmup_ratio=<span class="hljs-number">0.1</span>,
<span class="hljs-meta">... </span> logging_steps=<span class="hljs-number">10</span>,
<span class="hljs-meta">... </span> load_best_model_at_end=<span class="hljs-literal">True</span>,
<span class="hljs-meta">... </span> metric_for_best_model=<span class="hljs-string">&quot;accuracy&quot;</span>,
<span class="hljs-meta">... </span> push_to_hub=<span class="hljs-literal">True</span>,
<span class="hljs-meta">... </span> max_steps=(train_dataset.num_videos // batch_size) * num_epochs,
<span class="hljs-meta">... </span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1v50sum">The dataset returned by <code>pytorchvideo.data.Ucf101()</code> doesn’t implement the <code>__len__</code> method. As such, we must define <code>max_steps</code> when instantiating <code>TrainingArguments</code>.</p> <p data-svelte-h="svelte-1kbaooa">Next, you need to define a function to compute the metrics from the predictions, which will use the <code>metric</code> you’ll load now. The only preprocessing you have to do is to take the argmax of our predicted logits:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> evaluate
metric = evaluate.load(<span class="hljs-string">&quot;accuracy&quot;</span>)
<span class="hljs-keyword">def</span> <span class="hljs-title function_">compute_metrics</span>(<span class="hljs-params">eval_pred</span>):
predictions = np.argmax(eval_pred.predictions, axis=<span class="hljs-number">1</span>)
<span class="hljs-keyword">return</span> metric.compute(predictions=predictions, references=eval_pred.label_ids)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1yc7v5f"><strong>A note on evaluation</strong>:</p> <p data-svelte-h="svelte-9bk5j6">In the <a href="https://arxiv.org/abs/2203.12602" rel="nofollow">VideoMAE paper</a>, the authors use the following evaluation strategy: they evaluate the model on several clips from each test video, apply different crops to those clips, and report the aggregate score. However, in the interest of simplicity and brevity, we don’t consider that in this tutorial.</p> <p data-svelte-h="svelte-1csqroh">Also, define a <code>collate_fn</code>, which will be used to batch examples together. Each batch consists of two keys, namely <code>pixel_values</code> and <code>labels</code>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> 
Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">def</span> <span class="hljs-title function_">collate_fn</span>(<span class="hljs-params">examples</span>):
<span class="hljs-meta">... </span> <span class="hljs-comment"># permute to (num_frames, num_channels, height, width)</span>
<span class="hljs-meta">... </span> pixel_values = torch.stack(
<span class="hljs-meta">... </span> [example[<span class="hljs-string">&quot;video&quot;</span>].permute(<span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">2</span>, <span class="hljs-number">3</span>) <span class="hljs-keyword">for</span> example <span class="hljs-keyword">in</span> examples]
<span class="hljs-meta">... </span> )
<span class="hljs-meta">... </span> labels = torch.tensor([example[<span class="hljs-string">&quot;label&quot;</span>] <span class="hljs-keyword">for</span> example <span class="hljs-keyword">in</span> examples])
<span class="hljs-meta">... </span> <span class="hljs-keyword">return</span> {<span class="hljs-string">&quot;pixel_values&quot;</span>: pixel_values, <span class="hljs-string">&quot;labels&quot;</span>: labels}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-16wl6hd">Then you just pass all of this along with the datasets to <code>Trainer</code>:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>trainer = Trainer(
<span class="hljs-meta">... </span> model,
<span class="hljs-meta">... </span> args,
<span class="hljs-meta">... </span> train_dataset=train_dataset,
<span class="hljs-meta">... </span> eval_dataset=val_dataset,
<span class="hljs-meta">... </span> tokenizer=image_processor,
<span class="hljs-meta">... </span> compute_metrics=compute_metrics,
<span class="hljs-meta">... </span> data_collator=collate_fn,
<span class="hljs-meta">... </span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-g6wg27">You might wonder why you passed along the <code>image_processor</code> as a tokenizer when you preprocessed the data already. This is only to make sure the image processor configuration file (stored as JSON) will also be uploaded to the repo on the Hub.</p> <p data-svelte-h="svelte-umxz0w">Now fine-tune your model by calling the <code>train</code> method:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>train_results = trainer.train()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1v13hlo">Once training is completed, share your model to the Hub with the <a href="/docs/transformers/main/en/main_classes/trainer#transformers.Trainer.push_to_hub">push_to_hub()</a> method so 
everyone can use your model:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>trainer.push_to_hub()<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="inference" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#inference"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 
0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Inference</span></h2> <p data-svelte-h="svelte-riodtu">Great, now that you have fine-tuned a model, you can use it for inference!</p> <p data-svelte-h="svelte-w1spga">Load a video for inference:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>sample_test_video = <span class="hljs-built_in">next</span>(<span class="hljs-built_in">iter</span>(test_dataset))<!-- HTML_TAG_END --></pre></div> <div class="flex justify-center" 
data-svelte-h="svelte-556htt"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/sample_gif_two.gif" alt="Teams playing basketball"></div> <p data-svelte-h="svelte-e8z0ag">The simplest way to try out your fine-tuned model for inference is to use it in a <a href="https://huggingface.co/docs/transformers/main/en/main_classes/pipelines#transformers.VideoClassificationPipeline" rel="nofollow"><code>pipeline</code></a>. Instantiate a <code>pipeline</code> for video classification with your model, and pass your video to it:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> pipeline
<span class="hljs-meta">&gt;&gt;&gt; </span>video_cls = pipeline(model=<span class="hljs-string">&quot;my_awesome_video_cls_model&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>video_cls(<span class="hljs-string">&quot;https://huggingface.co/datasets/sayakpaul/ucf101-subset/resolve/main/v_BasketballDunk_g14_c06.avi&quot;</span>)
[{<span class="hljs-string">&#x27;score&#x27;</span>: <span class="hljs-number">0.9272987842559814</span>, <span class="hljs-string">&#x27;label&#x27;</span>: <span class="hljs-string">&#x27;BasketballDunk&#x27;</span>},
{<span class="hljs-string">&#x27;score&#x27;</span>: <span class="hljs-number">0.017777055501937866</span>, <span class="hljs-string">&#x27;label&#x27;</span>: <span class="hljs-string">&#x27;BabyCrawling&#x27;</span>},
{<span class="hljs-string">&#x27;score&#x27;</span>: <span class="hljs-number">0.01663011871278286</span>, <span class="hljs-string">&#x27;label&#x27;</span>: <span class="hljs-string">&#x27;BalanceBeam&#x27;</span>},
{<span class="hljs-string">&#x27;score&#x27;</span>: <span class="hljs-number">0.009560945443809032</span>, <span class="hljs-string">&#x27;label&#x27;</span>: <span class="hljs-string">&#x27;BandMarching&#x27;</span>},
{<span class="hljs-string">&#x27;score&#x27;</span>: <span class="hljs-number">0.0068979403004050255</span>, <span class="hljs-string">&#x27;label&#x27;</span>: <span class="hljs-string">&#x27;BaseballPitch&#x27;</span>}]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1j33lbi">You can also manually replicate the results of the <code>pipeline</code> if you’d like.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">def</span> <span class="hljs-title function_">run_inference</span>(<span class="hljs-params">model, video</span>):
<span class="hljs-meta">... </span> <span class="hljs-comment"># (num_frames, num_channels, height, width)</span>
<span class="hljs-meta">... </span> permuted_sample_test_video = video.permute(<span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">2</span>, <span class="hljs-number">3</span>)
<span class="hljs-meta">... </span> <span class="hljs-comment"># labels are left out: the logits don&#x27;t depend on them (pass them only if you also need outputs.loss)</span>
<span class="hljs-meta">... </span> inputs = {
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;pixel_values&quot;</span>: permuted_sample_test_video.unsqueeze(<span class="hljs-number">0</span>),
<span class="hljs-meta">... </span> }
<span class="hljs-meta">... </span> device = torch.device(<span class="hljs-string">&quot;cuda&quot;</span> <span class="hljs-keyword">if</span> torch.cuda.is_available() <span class="hljs-keyword">else</span> <span class="hljs-string">&quot;cpu&quot;</span>)
<span class="hljs-meta">... </span> inputs = {k: v.to(device) <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> inputs.items()}
<span class="hljs-meta">... </span> model = model.to(device)
<span class="hljs-meta">... </span> <span class="hljs-comment"># forward pass</span>
<span class="hljs-meta">... </span> <span class="hljs-keyword">with</span> torch.no_grad():
<span class="hljs-meta">... </span> outputs = model(**inputs)
<span class="hljs-meta">... </span> logits = outputs.logits
<span class="hljs-meta">... </span> <span class="hljs-keyword">return</span> logits<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-12olihs">Now, pass your input to the model and return the <code>logits</code>:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>logits = run_inference(trained_model, sample_test_video[<span class="hljs-string">&quot;video&quot;</span>])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1v8qszj">Decoding the <code>logits</code>, we get:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" 
xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>predicted_class_idx = logits.argmax(-<span class="hljs-number">1</span>).item()
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;Predicted class:&quot;</span>, trained_model.config.id2label[predicted_class_idx])
<span class="hljs-comment"># Predicted class: BasketballDunk</span><!-- HTML_TAG_END --></pre></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/transformers/blob/main/docs/source/en/tasks/video_classification.md" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p>
<script>
// Build-generated SvelteKit client bootstrap (note the hashed _app/immutable
// asset names): it publishes the app's path config on a global, loads the
// `start` and `app` entry modules, and starts the client app on this page.
// The bare block keeps the `const` bindings below block-scoped rather than
// script-global (where a second copy of this snippet on the same page would
// throw a redeclaration error).
{
// Assigned without a declaration: in this classic (non-module) script that
// creates a property on the global object. Presumably the SvelteKit runtime
// reads it to resolve `assets`/`base` URLs — TODO confirm for this kit version.
__sveltekit_1xexzbk = {
assets: "/docs/transformers/main/en",
base: "/docs/transformers/main/en",
env: {}
};
// document.currentScript is only set while this script is executing
// synchronously, so the mount target (this script's parent element) must be
// captured now, before the asynchronous imports below settle.
const element = document.currentScript.parentElement;
// Server data slots; presumably one per entry in `node_ids` below (two nodes,
// two nulls) — verify against the kit's start() API.
const data = [null,null];
// Fetch the `start` entry (bound to `kit`) and the `app` entry in parallel.
// There is no .catch(): if either import fails, the rejection goes unhandled
// and kit.start() is never called.
Promise.all([
import("/docs/transformers/main/en/_app/immutable/entry/start.2135b7e6.js"),
import("/docs/transformers/main/en/_app/immutable/entry/app.24372c84.js")
]).then(([kit, app]) => {
// node_ids presumably identify the route nodes (layout/page) to render —
// NOTE(review): confirm; form and error are passed as null.
kit.start(app, element, {
node_ids: [0, 423],
data,
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
88.7 kB
·
Xet hash:
f2517beedf0364bb4d6f541018d714d182ae1355793ac643d7f3178894d352bc

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.