<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
โš ๏ธ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Video classification [[video-classification]]
[[open-in-colab]]
์˜์ƒ ๋ถ„๋ฅ˜๋Š” ์˜์ƒ ์ „์ฒด์— ๋ ˆ์ด๋ธ” ๋˜๋Š” ํด๋ž˜์Šค๋ฅผ ์ง€์ •ํ•˜๋Š” ์ž‘์—…์ž…๋‹ˆ๋‹ค. ๊ฐ ์˜์ƒ์—๋Š” ํ•˜๋‚˜์˜ ํด๋ž˜์Šค๊ฐ€ ์žˆ์„ ๊ฒƒ์œผ๋กœ ์˜ˆ์ƒ๋ฉ๋‹ˆ๋‹ค. ์˜์ƒ ๋ถ„๋ฅ˜ ๋ชจ๋ธ์€ ์˜์ƒ์„ ์ž…๋ ฅ์œผ๋กœ ๋ฐ›์•„ ์–ด๋А ํด๋ž˜์Šค์— ์†ํ•˜๋Š”์ง€์— ๋Œ€ํ•œ ์˜ˆ์ธก์„ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ๋ชจ๋ธ์€ ์˜์ƒ์ด ์–ด๋–ค ๋‚ด์šฉ์ธ์ง€ ๋ถ„๋ฅ˜ํ•˜๋Š” ๋ฐ ์‚ฌ์šฉ๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์˜์ƒ ๋ถ„๋ฅ˜์˜ ์‹ค์ œ ์‘์šฉ ์˜ˆ๋Š” ํ”ผํŠธ๋‹ˆ์Šค ์•ฑ์—์„œ ์œ ์šฉํ•œ ๋™์ž‘ / ์šด๋™ ์ธ์‹ ์„œ๋น„์Šค๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋Š” ๋˜ํ•œ ์‹œ๊ฐ ์žฅ์• ์ธ์ด ์ด๋™ํ•  ๋•Œ ๋ณด์กฐํ•˜๋Š”๋ฐ ์‚ฌ์šฉ๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค
์ด ๊ฐ€์ด๋“œ์—์„œ๋Š” ๋‹ค์Œ์„ ์ˆ˜ํ–‰ํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค:
1. [UCF101](https://www.crcv.ucf.edu/data/UCF101.php) ๋ฐ์ดํ„ฐ ์„ธํŠธ์˜ ํ•˜์œ„ ์ง‘ํ•ฉ์„ ํ†ตํ•ด [VideoMAE](https://huggingface.co/docs/transformers/main/en/model_doc/videomae) ๋ชจ๋ธ์„ ๋ฏธ์„ธ ์กฐ์ •ํ•˜๊ธฐ.
2. ๋ฏธ์„ธ ์กฐ์ •ํ•œ ๋ชจ๋ธ์„ ์ถ”๋ก ์— ์‚ฌ์šฉํ•˜๊ธฐ.
<Tip>
The task illustrated in this tutorial is supported by the following model architectures:
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
[TimeSformer](../model_doc/timesformer), [VideoMAE](../model_doc/videomae)
<!--End of the generated tip-->
</Tip>
Before you begin, make sure all the necessary libraries are installed:
```bash
pip install -q pytorchvideo transformers evaluate
```
์˜์ƒ์„ ์ฒ˜๋ฆฌํ•˜๊ณ  ์ค€๋น„ํ•˜๊ธฐ ์œ„ํ•ด [PyTorchVideo](https://pytorchvideo.org/)(์ดํ•˜ `pytorchvideo`)๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.
์ปค๋ฎค๋‹ˆํ‹ฐ์— ๋ชจ๋ธ์„ ์—…๋กœ๋“œํ•˜๊ณ  ๊ณต์œ ํ•  ์ˆ˜ ์žˆ๋„๋ก Hugging Face ๊ณ„์ •์— ๋กœ๊ทธ์ธํ•˜๋Š” ๊ฒƒ์„ ๊ถŒ์žฅํ•ฉ๋‹ˆ๋‹ค. ํ”„๋กฌํ”„ํŠธ๊ฐ€ ๋‚˜ํƒ€๋‚˜๋ฉด ํ† ํฐ์„ ์ž…๋ ฅํ•˜์—ฌ ๋กœ๊ทธ์ธํ•˜์„ธ์š”:
```py
>>> from huggingface_hub import notebook_login
>>> notebook_login()
```
## Load UCF101 dataset [[load-ufc101-dataset]]

Start by loading a subset of the [UCF-101](https://www.crcv.ucf.edu/data/UCF101.php) dataset. Working with a subset gives you a chance to experiment and make sure everything works before spending more time training on the full dataset.
```py
>>> from huggingface_hub import hf_hub_download
>>> hf_dataset_identifier = "sayakpaul/ucf101-subset"
>>> filename = "UCF101_subset.tar.gz"
>>> file_path = hf_hub_download(repo_id=hf_dataset_identifier, filename=filename, repo_type="dataset")
```
๋ฐ์ดํ„ฐ ์„ธํŠธ์˜ ํ•˜์œ„ ์ง‘ํ•ฉ์ด ๋‹ค์šด๋กœ๋“œ ๋˜๋ฉด, ์••์ถ•๋œ ํŒŒ์ผ์˜ ์••์ถ•์„ ํ•ด์ œํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค:
```py
>>> import tarfile
>>> with tarfile.open(file_path) as t:
...     t.extractall(".")
```
์ „์ฒด ๋ฐ์ดํ„ฐ ์„ธํŠธ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์ด ๊ตฌ์„ฑ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค.
```bash
UCF101_subset/
    train/
        BandMarching/
            video_1.mp4
            video_2.mp4
            ...
        Archery/
            video_1.mp4
            video_2.mp4
            ...
        ...
    val/
        BandMarching/
            video_1.mp4
            video_2.mp4
            ...
        Archery/
            video_1.mp4
            video_2.mp4
            ...
        ...
    test/
        BandMarching/
            video_1.mp4
            video_2.mp4
            ...
        Archery/
            video_1.mp4
            video_2.mp4
            ...
        ...
```
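Next, enumerate all the video files so you can count them per split and, later, extract the class labels. The snippet below is a minimal sketch that assumes the archive was extracted into the current working directory; the resulting `all_video_file_paths` list is used in the following steps:

```py
>>> import pathlib

>>> dataset_root_path = pathlib.Path("UCF101_subset")
>>> all_video_file_paths = (
...     list(dataset_root_path.glob("train/*/*.avi"))
...     + list(dataset_root_path.glob("val/*/*.avi"))
...     + list(dataset_root_path.glob("test/*/*.avi"))
... )
```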
์ •๋ ฌ๋œ ์˜์ƒ์˜ ๊ฒฝ๋กœ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค:
```bash
...
'UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g07_c04.avi',
'UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g07_c06.avi',
'UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g08_c01.avi',
'UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g09_c02.avi',
'UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g09_c06.avi'
...
```
๋™์ผํ•œ ๊ทธ๋ฃน/์žฅ๋ฉด์— ์†ํ•˜๋Š” ์˜์ƒ ํด๋ฆฝ์€ ํŒŒ์ผ ๊ฒฝ๋กœ์—์„œ `g`๋กœ ํ‘œ์‹œ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค๋ฉด, `v_ApplyEyeMakeup_g07_c04.avi`์™€ `v_ApplyEyeMakeup_g07_c06.avi` ์ด ์žˆ์Šต๋‹ˆ๋‹ค. ์ด ๋‘˜์€ ๊ฐ™์€ ๊ทธ๋ฃน์ž…๋‹ˆ๋‹ค.
๊ฒ€์ฆ ๋ฐ ํ‰๊ฐ€ ๋ฐ์ดํ„ฐ ๋ถ„ํ• ์„ ํ•  ๋•Œ, [๋ฐ์ดํ„ฐ ๋ˆ„์ถœ(data leakage)](https://www.kaggle.com/code/alexisbcook/data-leakage)์„ ๋ฐฉ์ง€ํ•˜๊ธฐ ์œ„ํ•ด ๋™์ผํ•œ ๊ทธ๋ฃน / ์žฅ๋ฉด์˜ ์˜์ƒ ํด๋ฆฝ์„ ์‚ฌ์šฉํ•˜์ง€ ์•Š์•„์•ผ ํ•ฉ๋‹ˆ๋‹ค. ์ด ํŠœํ† ๋ฆฌ์–ผ์—์„œ ์‚ฌ์šฉํ•˜๋Š” ํ•˜์œ„ ์ง‘ํ•ฉ์€ ์ด๋Ÿฌํ•œ ์ •๋ณด๋ฅผ ๊ณ ๋ คํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.
๊ทธ ๋‹ค์Œ์œผ๋กœ, ๋ฐ์ดํ„ฐ ์„ธํŠธ์— ์กด์žฌํ•˜๋Š” ๋ผ๋ฒจ์„ ์ถ”์ถœํ•ฉ๋‹ˆ๋‹ค. ๋˜ํ•œ, ๋ชจ๋ธ์„ ์ดˆ๊ธฐํ™”ํ•  ๋•Œ ๋„์›€์ด ๋  ๋”•์…”๋„ˆ๋ฆฌ(dictionary data type)๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
* `label2id`: ํด๋ž˜์Šค ์ด๋ฆ„์„ ์ •์ˆ˜์— ๋งคํ•‘ํ•ฉ๋‹ˆ๋‹ค.
* `id2label`: ์ •์ˆ˜๋ฅผ ํด๋ž˜์Šค ์ด๋ฆ„์— ๋งคํ•‘ํ•ฉ๋‹ˆ๋‹ค.
```py
>>> class_labels = sorted({str(path).split("/")[2] for path in all_video_file_paths})
>>> label2id = {label: i for i, label in enumerate(class_labels)}
>>> id2label = {i: label for label, i in label2id.items()}
>>> print(f"Unique classes: {list(label2id.keys())}.")
# Unique classes: ['ApplyEyeMakeup', 'ApplyLipstick', 'Archery', 'BabyCrawling', 'BalanceBeam', 'BandMarching', 'BaseballPitch', 'Basketball', 'BasketballDunk', 'BenchPress'].
```
์ด ๋ฐ์ดํ„ฐ ์„ธํŠธ์—๋Š” ์ด 10๊ฐœ์˜ ๊ณ ์œ ํ•œ ํด๋ž˜์Šค๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค. ๊ฐ ํด๋ž˜์Šค๋งˆ๋‹ค 30๊ฐœ์˜ ์˜์ƒ์ด ํ›ˆ๋ จ ์„ธํŠธ์— ์žˆ์Šต๋‹ˆ๋‹ค
## ๋ฏธ์„ธ ์กฐ์ •ํ•˜๊ธฐ ์œ„ํ•ด ๋ชจ๋ธ ๊ฐ€์ ธ์˜ค๊ธฐ [[load-a-model-to-fine-tune]]
์‚ฌ์ „ ํ›ˆ๋ จ๋œ ์ฒดํฌํฌ์ธํŠธ์™€ ์ฒดํฌํฌ์ธํŠธ์— ์—ฐ๊ด€๋œ ์ด๋ฏธ์ง€ ํ”„๋กœ์„ธ์„œ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์˜์ƒ ๋ถ„๋ฅ˜ ๋ชจ๋ธ์„ ์ธ์Šคํ„ด์Šคํ™”ํ•ฉ๋‹ˆ๋‹ค. ๋ชจ๋ธ์˜ ์ธ์ฝ”๋”์—๋Š” ๋ฏธ๋ฆฌ ํ•™์Šต๋œ ๋งค๊ฐœ๋ณ€์ˆ˜๊ฐ€ ์ œ๊ณต๋˜๋ฉฐ, ๋ถ„๋ฅ˜ ํ—ค๋“œ(๋ฐ์ดํ„ฐ๋ฅผ ๋ถ„๋ฅ˜ํ•˜๋Š” ๋งˆ์ง€๋ง‰ ๋ ˆ์ด์–ด)๋Š” ๋ฌด์ž‘์œ„๋กœ ์ดˆ๊ธฐํ™”๋ฉ๋‹ˆ๋‹ค. ๋ฐ์ดํ„ฐ ์„ธํŠธ์˜ ์ „์ฒ˜๋ฆฌ ํŒŒ์ดํ”„๋ผ์ธ์„ ์ž‘์„ฑํ•  ๋•Œ๋Š” ์ด๋ฏธ์ง€ ํ”„๋กœ์„ธ์„œ๊ฐ€ ์œ ์šฉํ•ฉ๋‹ˆ๋‹ค.
```py
>>> from transformers import VideoMAEImageProcessor, VideoMAEForVideoClassification
>>> model_ckpt = "MCG-NJU/videomae-base"
>>> image_processor = VideoMAEImageProcessor.from_pretrained(model_ckpt)
>>> model = VideoMAEForVideoClassification.from_pretrained(
...     model_ckpt,
...     label2id=label2id,
...     id2label=id2label,
...     ignore_mismatched_sizes=True,  # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
... )
```
๋ชจ๋ธ์„ ๊ฐ€์ ธ์˜ค๋Š” ๋™์•ˆ, ๋‹ค์Œ๊ณผ ๊ฐ™์€ ๊ฒฝ๊ณ ๋ฅผ ๋งˆ์ฃผ์น  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:
```bash
Some weights of the model checkpoint at MCG-NJU/videomae-base were not used when initializing VideoMAEForVideoClassification: [..., 'decoder.decoder_layers.1.attention.output.dense.bias', 'decoder.decoder_layers.2.attention.attention.key.weight']
- This IS expected if you are initializing VideoMAEForVideoClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing VideoMAEForVideoClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of VideoMAEForVideoClassification were not initialized from the model checkpoint at MCG-NJU/videomae-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
The warning is telling us that some weights are being thrown away (e.g., the weights and bias of the old `classifier` layer) and that the weights and bias of a new `classifier` layer are being randomly initialized. This is expected here: we are adding a new head for which we don't have pretrained weights, so the library warns us to fine-tune the model before using it for inference, which is exactly what we are going to do.

**Note** that [this checkpoint](https://huggingface.co/MCG-NJU/videomae-base-finetuned-kinetics) may lead to better performance on this task, since it was obtained by fine-tuning on a similar downstream task with significant domain overlap. There is also [a checkpoint](https://huggingface.co/sayakpaul/videomae-base-finetuned-kinetics-finetuned-ucf101-subset) that was obtained by fine-tuning `MCG-NJU/videomae-base-finetuned-kinetics`.

## Prepare the datasets for training [[prepare-the-datasets-for-training]]

For preprocessing the videos, you will leverage the [PyTorchVideo library](https://pytorchvideo.org/). Start by importing the required dependencies.
```py
>>> import os
>>> import torch

>>> import pytorchvideo.data
>>> from pytorchvideo.transforms import (
...     ApplyTransformToKey,
...     Normalize,
...     RandomShortSideScale,
...     RemoveKey,
...     ShortSideScale,
...     UniformTemporalSubsample,
... )
>>> from torchvision.transforms import (
...     Compose,
...     Lambda,
...     RandomCrop,
...     RandomHorizontalFlip,
...     Resize,
... )
```
ํ•™์Šต ๋ฐ์ดํ„ฐ ์„ธํŠธ ๋ณ€ํ™˜์—๋Š” '๊ท ์ผํ•œ ์‹œ๊ฐ„ ์ƒ˜ํ”Œ๋ง(uniform temporal subsampling)', 'ํ”ฝ์…€ ์ •๊ทœํ™”(pixel normalization)', '๋žœ๋ค ์ž˜๋ผ๋‚ด๊ธฐ(random cropping)' ๋ฐ '๋žœ๋ค ์ˆ˜ํ‰ ๋’ค์ง‘๊ธฐ(random horizontal flipping)'์˜ ์กฐํ•ฉ์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. ๊ฒ€์ฆ ๋ฐ ํ‰๊ฐ€ ๋ฐ์ดํ„ฐ ์„ธํŠธ ๋ณ€ํ™˜์—๋Š” '๋žœ๋ค ์ž˜๋ผ๋‚ด๊ธฐ'์™€ '๋žœ๋ค ๋’ค์ง‘๊ธฐ'๋ฅผ ์ œ์™ธํ•œ ๋™์ผํ•œ ๋ณ€ํ™˜ ์ฒด์ธ์„ ์œ ์ง€ํ•ฉ๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ๋ณ€ํ™˜์— ๋Œ€ํ•ด ์ž์„ธํžˆ ์•Œ์•„๋ณด๋ ค๋ฉด [PyTorchVideo ๊ณต์‹ ๋ฌธ์„œ](https://pytorchvideo.org)๋ฅผ ํ™•์ธํ•˜์„ธ์š”.
์‚ฌ์ „ ํ›ˆ๋ จ๋œ ๋ชจ๋ธ๊ณผ ๊ด€๋ จ๋œ ์ด๋ฏธ์ง€ ํ”„๋กœ์„ธ์„œ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋‹ค์Œ ์ •๋ณด๋ฅผ ์–ป์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:
* ์˜์ƒ ํ”„๋ ˆ์ž„ ํ”ฝ์…€์„ ์ •๊ทœํ™”ํ•˜๋Š” ๋ฐ ์‚ฌ์šฉ๋˜๋Š” ์ด๋ฏธ์ง€ ํ‰๊ท ๊ณผ ํ‘œ์ค€ ํŽธ์ฐจ
* ์˜์ƒ ํ”„๋ ˆ์ž„์ด ์กฐ์ •๋  ๊ณต๊ฐ„ ํ•ด์ƒ๋„
๋จผ์ €, ๋ช‡ ๊ฐ€์ง€ ์ƒ์ˆ˜๋ฅผ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค.
```py
>>> mean = image_processor.image_mean
>>> std = image_processor.image_std
>>> if "shortest_edge" in image_processor.size:
...     height = width = image_processor.size["shortest_edge"]
... else:
...     height = image_processor.size["height"]
...     width = image_processor.size["width"]
>>> resize_to = (height, width)
>>> num_frames_to_sample = model.config.num_frames
>>> sample_rate = 4
>>> fps = 30
>>> clip_duration = num_frames_to_sample * sample_rate / fps
```
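For reference, the base VideoMAE checkpoint samples `num_frames = 16`, so with a sampling rate of 4 and an assumed 30 fps this works out to `16 * 4 / 30 ≈ 2.13` seconds of video per clip.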
์ด์ œ ๋ฐ์ดํ„ฐ ์„ธํŠธ์— ํŠนํ™”๋œ ์ „์ฒ˜๋ฆฌ(transform)๊ณผ ๋ฐ์ดํ„ฐ ์„ธํŠธ ์ž์ฒด๋ฅผ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค. ๋จผ์ € ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ ์„ธํŠธ๋กœ ์‹œ์ž‘ํ•ฉ๋‹ˆ๋‹ค:
```py
>>> train_transform = Compose(
...     [
...         ApplyTransformToKey(
...             key="video",
...             transform=Compose(
...                 [
...                     UniformTemporalSubsample(num_frames_to_sample),
...                     Lambda(lambda x: x / 255.0),
...                     Normalize(mean, std),
...                     RandomShortSideScale(min_size=256, max_size=320),
...                     RandomCrop(resize_to),
...                     RandomHorizontalFlip(p=0.5),
...                 ]
...             ),
...         ),
...     ]
... )

>>> train_dataset = pytorchvideo.data.Ucf101(
...     data_path=os.path.join(dataset_root_path, "train"),
...     clip_sampler=pytorchvideo.data.make_clip_sampler("random", clip_duration),
...     decode_audio=False,
...     transform=train_transform,
... )
```
๊ฐ™์€ ๋ฐฉ์‹์˜ ์ž‘์—… ํ๋ฆ„์„ ๊ฒ€์ฆ๊ณผ ํ‰๊ฐ€ ์„ธํŠธ์—๋„ ์ ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
```py
>>> val_transform = Compose(
...     [
...         ApplyTransformToKey(
...             key="video",
...             transform=Compose(
...                 [
...                     UniformTemporalSubsample(num_frames_to_sample),
...                     Lambda(lambda x: x / 255.0),
...                     Normalize(mean, std),
...                     Resize(resize_to),
...                 ]
...             ),
...         ),
...     ]
... )

>>> val_dataset = pytorchvideo.data.Ucf101(
...     data_path=os.path.join(dataset_root_path, "val"),
...     clip_sampler=pytorchvideo.data.make_clip_sampler("uniform", clip_duration),
...     decode_audio=False,
...     transform=val_transform,
... )

>>> test_dataset = pytorchvideo.data.Ucf101(
...     data_path=os.path.join(dataset_root_path, "test"),
...     clip_sampler=pytorchvideo.data.make_clip_sampler("uniform", clip_duration),
...     decode_audio=False,
...     transform=val_transform,
... )
```
**Note**: The above dataset pipelines are taken from the [official PyTorchVideo example](https://pytorchvideo.org/docs/tutorial_classification#dataset). We're using the [`pytorchvideo.data.Ucf101()`](https://pytorchvideo.readthedocs.io/en/latest/api/data/data.html#pytorchvideo.data.Ucf101) function because it's tailored for the UCF-101 dataset. Under the hood, it returns a [`pytorchvideo.data.labeled_video_dataset.LabeledVideoDataset`](https://pytorchvideo.readthedocs.io/en/latest/api/data/data.html#pytorchvideo.data.LabeledVideoDataset) object. `LabeledVideoDataset` is the base class for all video datasets in PyTorchVideo, so if you want to use a custom dataset not supported out of the box, you can extend this class accordingly. Refer to the `data` API [documentation](https://pytorchvideo.readthedocs.io/en/latest/api/data/data.html) to learn more. Also, if your dataset follows a similar structure to the one shown above, `pytorchvideo.data.Ucf101()` should work just fine.
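For instance, if your videos don't follow the UCF-101 layout, you could construct a `LabeledVideoDataset` directly from `(video path, metadata)` pairs. The paths below are hypothetical and the keyword arguments mirror the ones used above; treat this as a sketch rather than a drop-in recipe:

```py
>>> labeled_video_paths = [
...     ("my_videos/clip_0001.mp4", {"label": 0}),
...     ("my_videos/clip_0002.mp4", {"label": 3}),
... ]
>>> custom_dataset = pytorchvideo.data.LabeledVideoDataset(
...     labeled_video_paths=labeled_video_paths,
...     clip_sampler=pytorchvideo.data.make_clip_sampler("random", clip_duration),
...     decode_audio=False,
...     transform=train_transform,
... )
```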
๋ฐ์ดํ„ฐ ์„ธํŠธ์— ์˜์ƒ์˜ ๊ฐœ์ˆ˜๋ฅผ ์•Œ๊ธฐ ์œ„ํ•ด `num_videos` ์ธ์ˆ˜์— ์ ‘๊ทผํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
```py
>>> print(train_dataset.num_videos, val_dataset.num_videos, test_dataset.num_videos)
# (300, 30, 75)
```
## ๋” ๋‚˜์€ ๋””๋ฒ„๊น…์„ ์œ„ํ•ด ์ „์ฒ˜๋ฆฌ ์˜์ƒ ์‹œ๊ฐํ™”ํ•˜๊ธฐ[[visualize-the-preprocessed-video-for-better-debugging]]
```py
>>> import imageio
>>> import numpy as np
>>> from IPython.display import Image
>>> def unnormalize_img(img):
...     """Un-normalizes the image pixels."""
...     img = (img * std) + mean
...     img = (img * 255).astype("uint8")
...     return img.clip(0, 255)

>>> def create_gif(video_tensor, filename="sample.gif"):
...     """Prepares a GIF from a video tensor.
...
...     The video tensor is expected to have the following shape:
...     (num_frames, num_channels, height, width).
...     """
...     frames = []
...     for video_frame in video_tensor:
...         frame_unnormalized = unnormalize_img(video_frame.permute(1, 2, 0).numpy())
...         frames.append(frame_unnormalized)
...     kargs = {"duration": 0.25}
...     imageio.mimsave(filename, frames, "GIF", **kargs)
...     return filename

>>> def display_gif(video_tensor, gif_name="sample.gif"):
...     """Prepares and displays a GIF from a video tensor."""
...     video_tensor = video_tensor.permute(1, 0, 2, 3)
...     gif_filename = create_gif(video_tensor, gif_name)
...     return Image(filename=gif_filename)
>>> sample_video = next(iter(train_dataset))
>>> video_tensor = sample_video["video"]
>>> display_gif(video_tensor)
```
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/sample_gif.gif" alt="Person playing basketball"/>
</div>
## ๋ชจ๋ธ ํ›ˆ๋ จํ•˜๊ธฐ[[train-the-model]]
๐Ÿค— Transformers์˜ [`Trainer`](https://huggingface.co/docs/transformers/main_classes/trainer)๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ชจ๋ธ์„ ํ›ˆ๋ จ์‹œ์ผœ๋ณด์„ธ์š”. `Trainer`๋ฅผ ์ธ์Šคํ„ด์Šคํ™”ํ•˜๋ ค๋ฉด ํ›ˆ๋ จ ์„ค์ •๊ณผ ํ‰๊ฐ€ ์ง€ํ‘œ๋ฅผ ์ •์˜ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ๊ฐ€์žฅ ์ค‘์š”ํ•œ ๊ฒƒ์€ [`TrainingArguments`](https://huggingface.co/transformers/main_classes/trainer.html#transformers.TrainingArguments)์ž…๋‹ˆ๋‹ค. ์ด ํด๋ž˜์Šค๋Š” ํ›ˆ๋ จ์„ ๊ตฌ์„ฑํ•˜๋Š” ๋ชจ๋“  ์†์„ฑ์„ ํฌํ•จํ•˜๋ฉฐ, ํ›ˆ๋ จ ์ค‘ ์ฒดํฌํฌ์ธํŠธ๋ฅผ ์ €์žฅํ•  ์ถœ๋ ฅ ํด๋” ์ด๋ฆ„์„ ํ•„์š”๋กœ ํ•ฉ๋‹ˆ๋‹ค. ๋˜ํ•œ ๐Ÿค— Hub์˜ ๋ชจ๋ธ ์ €์žฅ์†Œ์˜ ๋ชจ๋“  ์ •๋ณด๋ฅผ ๋™๊ธฐํ™”ํ•˜๋Š” ๋ฐ ๋„์›€์ด ๋ฉ๋‹ˆ๋‹ค.
๋Œ€๋ถ€๋ถ„์˜ ํ›ˆ๋ จ ์ธ์ˆ˜๋Š” ๋”ฐ๋กœ ์„ค๋ช…ํ•  ํ•„์š”๋Š” ์—†์Šต๋‹ˆ๋‹ค. ํ•˜์ง€๋งŒ ์—ฌ๊ธฐ์—์„œ ์ค‘์š”ํ•œ ์ธ์ˆ˜๋Š” `remove_unused_columns=False` ์ž…๋‹ˆ๋‹ค. ์ด ์ธ์ž๋Š” ๋ชจ๋ธ์˜ ํ˜ธ์ถœ ํ•จ์ˆ˜์—์„œ ์‚ฌ์šฉ๋˜์ง€ ์•Š๋Š” ๋ชจ๋“  ์†์„ฑ ์—ด(columns)์„ ์‚ญ์ œํ•ฉ๋‹ˆ๋‹ค. ๊ธฐ๋ณธ๊ฐ’์€ ์ผ๋ฐ˜์ ์œผ๋กœ True์ž…๋‹ˆ๋‹ค. ์ด๋Š” ์‚ฌ์šฉ๋˜์ง€ ์•Š๋Š” ๊ธฐ๋Šฅ ์—ด์„ ์‚ญ์ œํ•˜๋Š” ๊ฒƒ์ด ์ด์ƒ์ ์ด๋ฉฐ, ์ž…๋ ฅ์„ ๋ชจ๋ธ์˜ ํ˜ธ์ถœ ํ•จ์ˆ˜๋กœ ํ’€๊ธฐ(unpack)๊ฐ€ ์‰ฌ์›Œ์ง€๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค. ํ•˜์ง€๋งŒ ์ด ๊ฒฝ์šฐ์—๋Š” `pixel_values`(๋ชจ๋ธ์˜ ์ž…๋ ฅ์œผ๋กœ ํ•„์ˆ˜์ ์ธ ํ‚ค)๋ฅผ ์ƒ์„ฑํ•˜๊ธฐ ์œ„ํ•ด ์‚ฌ์šฉ๋˜์ง€ ์•Š๋Š” ๊ธฐ๋Šฅ('video'๊ฐ€ ํŠนํžˆ ๊ทธ๋ ‡์Šต๋‹ˆ๋‹ค)์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ remove_unused_columns์„ False๋กœ ์„ค์ •ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
```py
>>> from transformers import TrainingArguments, Trainer
>>> model_name = model_ckpt.split("/")[-1]
>>> new_model_name = f"{model_name}-finetuned-ucf101-subset"
>>> num_epochs = 4
>>> batch_size = 8  # adjust to fit your hardware

>>> args = TrainingArguments(
...     new_model_name,
...     remove_unused_columns=False,
...     evaluation_strategy="epoch",
...     save_strategy="epoch",
...     learning_rate=5e-5,
...     per_device_train_batch_size=batch_size,
...     per_device_eval_batch_size=batch_size,
...     warmup_ratio=0.1,
...     logging_steps=10,
...     load_best_model_at_end=True,
...     metric_for_best_model="accuracy",
...     push_to_hub=True,
...     max_steps=(train_dataset.num_videos // batch_size) * num_epochs,
... )
```
The dataset returned by `pytorchvideo.data.Ucf101()` doesn't implement the `__len__` method, which is why `max_steps` has to be defined when instantiating `TrainingArguments`. With 300 training videos, a batch size of 8, and 4 epochs, that works out to `(300 // 8) * 4 = 148` steps.

Next, load a metric and define a function to compute it from the predictions. The only preprocessing needed is taking the argmax of the predicted logits:
```py
import evaluate
metric = evaluate.load("accuracy")
def compute_metrics(eval_pred):
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return metric.compute(predictions=predictions, references=eval_pred.label_ids)
```
**A note on evaluation**:

In the [VideoMAE paper](https://arxiv.org/abs/2203.12602), the authors use the following evaluation strategy: they evaluate the model on several clips taken from test videos, apply different crops to those clips, and report the aggregate score. In the interest of simplicity and brevity, we don't consider that strategy in this tutorial.
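If you did want to approximate that protocol, a simple version is to score several clips from the same video and average the class probabilities. The sketch below assumes you already have a list of preprocessed clip tensors shaped `(num_frames, num_channels, height, width)`; the helper name is hypothetical:

```py
>>> def aggregate_clip_scores(model, clips):
...     """Average softmax scores over multiple clips of one video."""
...     with torch.no_grad():
...         scores = [
...             model(pixel_values=clip.unsqueeze(0)).logits.softmax(-1)
...             for clip in clips
...         ]
...     return torch.stack(scores).mean(dim=0)
```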
Also define a `collate_fn`, which is used to batch examples together. Each batch consists of 2 keys, namely `pixel_values` and `labels`.
```py
>>> def collate_fn(examples):
...     # permute to (num_frames, num_channels, height, width)
...     pixel_values = torch.stack(
...         [example["video"].permute(1, 0, 2, 3) for example in examples]
...     )
...     labels = torch.tensor([example["label"] for example in examples])
...     return {"pixel_values": pixel_values, "labels": labels}
```
๊ทธ๋Ÿฐ ๋‹ค์Œ ์ด ๋ชจ๋“  ๊ฒƒ์„ ๋ฐ์ดํ„ฐ ์„ธํŠธ์™€ ํ•จ๊ป˜ `Trainer`์— ์ „๋‹ฌํ•˜๊ธฐ๋งŒ ํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค:
```py
>>> trainer = Trainer(
...     model,
...     args,
...     train_dataset=train_dataset,
...     eval_dataset=val_dataset,
...     tokenizer=image_processor,
...     compute_metrics=compute_metrics,
...     data_collator=collate_fn,
... )
```
๋ฐ์ดํ„ฐ๋ฅผ ์ด๋ฏธ ์ฒ˜๋ฆฌํ–ˆ๋Š”๋ฐ๋„ ๋ถˆ๊ตฌํ•˜๊ณ  `image_processor`๋ฅผ ํ† ํฌ๋‚˜์ด์ € ์ธ์ˆ˜๋กœ ๋„ฃ์€ ์ด์œ ๋Š” JSON์œผ๋กœ ์ €์žฅ๋˜๋Š” ์ด๋ฏธ์ง€ ํ”„๋กœ์„ธ์„œ ๊ตฌ์„ฑ ํŒŒ์ผ์ด Hub์˜ ์ €์žฅ์†Œ์— ์—…๋กœ๋“œ๋˜๋„๋ก ํ•˜๊ธฐ ์œ„ํ•จ์ž…๋‹ˆ๋‹ค.
`train` ๋ฉ”์†Œ๋“œ๋ฅผ ํ˜ธ์ถœํ•˜์—ฌ ๋ชจ๋ธ์„ ๋ฏธ์„ธ ์กฐ์ •ํ•˜์„ธ์š”:
```py
>>> train_results = trainer.train()
```
ํ•™์Šต์ด ์™„๋ฃŒ๋˜๋ฉด, ๋ชจ๋ธ์„ [`~transformers.Trainer.push_to_hub`] ๋ฉ”์†Œ๋“œ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ํ—ˆ๋ธŒ์— ๊ณต์œ ํ•˜์—ฌ ๋ˆ„๊ตฌ๋‚˜ ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•ฉ๋‹ˆ๋‹ค:
```py
>>> trainer.push_to_hub()
```
## Inference [[inference]]

Great, now that you have fine-tuned a model, you can use it for inference!

Load a video for inference:
```py
>>> sample_test_video = next(iter(test_dataset))
```
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/sample_gif_two.gif" alt="Teams playing basketball"/>
</div>
๋ฏธ์„ธ ์กฐ์ •๋œ ๋ชจ๋ธ์„ ์ถ”๋ก ์— ์‚ฌ์šฉํ•˜๋Š” ๊ฐ€์žฅ ๊ฐ„๋‹จํ•œ ๋ฐฉ๋ฒ•์€ [`pipeline`](https://huggingface.co/docs/transformers/main/en/main_classes/pipelines#transformers.VideoClassificationPipeline)์—์„œ ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ๋ชจ๋ธ๋กœ ์˜์ƒ ๋ถ„๋ฅ˜๋ฅผ ํ•˜๊ธฐ ์œ„ํ•ด `pipeline`์„ ์ธ์Šคํ„ด์Šคํ™”ํ•˜๊ณ  ์˜์ƒ์„ ์ „๋‹ฌํ•˜์„ธ์š”:
```py
>>> from transformers import pipeline
>>> video_cls = pipeline(model="my_awesome_video_cls_model")
>>> video_cls("https://huggingface.co/datasets/sayakpaul/ucf101-subset/resolve/main/v_BasketballDunk_g14_c06.avi")
[{'score': 0.9272987842559814, 'label': 'BasketballDunk'},
{'score': 0.017777055501937866, 'label': 'BabyCrawling'},
{'score': 0.01663011871278286, 'label': 'BalanceBeam'},
{'score': 0.009560945443809032, 'label': 'BandMarching'},
{'score': 0.0068979403004050255, 'label': 'BaseballPitch'}]
```
You can also manually replicate the results of the `pipeline` if you'd like:
```py
>>> def run_inference(model, video):
...     # (num_frames, num_channels, height, width)
...     permuted_sample_test_video = video.permute(1, 0, 2, 3)
...     inputs = {
...         "pixel_values": permuted_sample_test_video.unsqueeze(0),
...         "labels": torch.tensor(
...             [sample_test_video["label"]]
...         ),  # this can be skipped if you don't have labels available.
...     }
...     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
...     inputs = {k: v.to(device) for k, v in inputs.items()}
...     model = model.to(device)
...     # forward pass
...     with torch.no_grad():
...         outputs = model(**inputs)
...         logits = outputs.logits
...     return logits
```
๋ชจ๋ธ์— ์ž…๋ ฅ๊ฐ’์„ ๋„ฃ๊ณ  `logits`์„ ๋ฐ˜ํ™˜๋ฐ›์œผ์„ธ์š”:
```py
>>> logits = run_inference(model, sample_test_video["video"])
```
`logits`์„ ๋””์ฝ”๋”ฉํ•˜๋ฉด, ์šฐ๋ฆฌ๋Š” ๋‹ค์Œ ๊ฒฐ๊ณผ๋ฅผ ์–ป์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:
```py
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
# Predicted class: BasketballDunk
```