<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Visual Question Answering
[[open-in-colab]]
Visual Question Answering (VQA) is the task of answering open-ended questions based on an image. The input to models supporting this task is typically a combination of an image and a question, and the output is an answer expressed in natural language.
Some noteworthy use cases for VQA include:
* Accessibility applications for visually impaired people.
* Education: answering questions about visual material presented in lectures or textbooks. VQA can also be used in interactive museum exhibits or at historical sites.
* Customer service and e-commerce: VQA can improve the user experience by letting users ask questions about products.
* Image retrieval: VQA models can be used to retrieve images with specific characteristics. For example, a user can ask "Is there a dog?" to find all images containing a dog in a given set of images.
In this guide you'll learn how to:
- Fine-tune a classification VQA model, [ViLT](../../en/model_doc/vilt), on the [`Graphcore/vqa` dataset](https://huggingface.co/datasets/Graphcore/vqa).
- Run inference with your fine-tuned ViLT model.
- Run zero-shot VQA inference with a generative model such as BLIP-2.
## Fine-tuning ViLT [[finetuning-vilt]]
ViLT incorporates text embeddings into a Vision Transformer (ViT), which gives it a minimal design for Vision-and-Language Pre-training (VLP). The model can be used for several downstream tasks. For VQA, a classifier head (a linear layer on top of the final hidden state of the `[CLS]` token) is placed on top of the model and randomly initialized.
Visual question answering is therefore treated here as a **classification problem**.
More recent models, such as BLIP, BLIP-2, and InstructBLIP, treat VQA as a generative task. Later in this guide we show how to use such models for zero-shot VQA inference.
Before you begin, make sure you have all the necessary libraries installed.
```bash
pip install -q transformers datasets
```
We encourage you to share your model with the community. Log in to your Hugging Face account to upload it to the 🤗 Hub.
When prompted, enter your token to log in:
```py
>>> from huggingface_hub import notebook_login
>>> notebook_login()
```
Define the model checkpoint as a global variable.
```py
>>> model_checkpoint = "dandelin/vilt-b32-mlm"
```
## Load the data [[load-the-data]]
For illustration purposes, this guide uses a small sample of the `Graphcore/vqa` dataset. You can find the full dataset on the [🤗 Hub](https://huggingface.co/datasets/Graphcore/vqa).
As an alternative to the [`Graphcore/vqa` dataset](https://huggingface.co/datasets/Graphcore/vqa), you can download the same data manually from the official [VQA dataset page](https://visualqa.org/download.html). If you prefer to follow the tutorial with your own data, see the [Create an image dataset](https://huggingface.co/docs/datasets/image_dataset#loading-script) guide in the 🤗 Datasets documentation.
Let's load the first 200 examples from the validation split and explore the dataset's features:
```python
>>> from datasets import load_dataset
>>> dataset = load_dataset("Graphcore/vqa", split="validation[:200]")
>>> dataset
Dataset({
    features: ['question', 'question_type', 'question_id', 'image_id', 'answer_type', 'label'],
    num_rows: 200
})
```
Let's look at one example to understand the dataset's features:
```py
>>> dataset[0]
{'question': 'Where is he looking?',
 'question_type': 'none of the above',
 'question_id': 262148000,
 'image_id': '/root/.cache/huggingface/datasets/downloads/extracted/ca733e0e000fb2d7a09fbcc94dbfe7b5a30750681d0e965f8e0a23b1c2f98c75/val2014/COCO_val2014_000000262148.jpg',
 'answer_type': 'other',
 'label': {'ids': ['at table', 'down', 'skateboard', 'table'],
  'weights': [0.30000001192092896,
   1.0,
   0.30000001192092896,
   0.30000001192092896]}}
```
The features relevant to the task include:
* `question`: the question to be answered from the image
* `image_id`: the path to the image the question refers to
* `label`: the annotations
The remaining features aren't needed, so you can remove them:
```py
>>> dataset = dataset.remove_columns(['question_type', 'question_id', 'answer_type'])
```
As you can see, the `label` feature contains several answers to the same question (called `ids` here), because they were collected from different human annotators. The answer to a question can be subjective: here the question is "Where is he looking?", and some annotators answered "down", while others answered "at table", "skateboard", and so on.
Take a look at the image and consider which answer you would give:
```python
>>> from PIL import Image
>>> image = Image.open(dataset[0]['image_id'])
>>> image
```
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/vqa-example.png" alt="VQA Image Example"/>
</div>
Because of this ambiguity, datasets like this one are treated as a multi-label classification problem, since several answers may be valid. Moreover, rather than a one-hot encoded vector, a soft encoding is created based on how many times a given answer appears in the annotations.
For example, in the sample above the answer "down" was chosen far more often than the other answers, so it has a score (called `weight` in the dataset) of 1.0, while the remaining answers have scores below 1.0.
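To make the soft encoding concrete, here is a minimal plain-Python sketch. This guide does not say how `Graphcore/vqa` computed its weights; the rule below (0.3 per annotator who gave an answer, capped at 1.0) is an assumption chosen only because it is consistent with the 0.3 and 1.0 weights shown above. `soft_score` and `soft_labels` are hypothetical helpers, and the annotation counts are made up.

```py
from collections import Counter


def soft_score(count: int) -> float:
    """Hypothetical scoring rule (assumption): 0.3 per annotator, capped at 1.0."""
    return min(1.0, round(0.3 * count, 1))


def soft_labels(raw_answers: list) -> dict:
    """Map each distinct raw answer to its soft score."""
    counts = Counter(raw_answers)
    return {answer: soft_score(n) for answer, n in counts.items()}


# Made-up annotations: "down" given 7 times, the other answers once each
answers = ["down"] * 7 + ["at table", "skateboard", "table"]
print(soft_labels(answers))
# {'down': 1.0, 'at table': 0.3, 'skateboard': 0.3, 'table': 0.3}
```

Unlike a one-hot vector, every plausible answer keeps a non-zero target, so the model is not penalized as heavily for predicting a less popular but still valid answer.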
To later instantiate the model with an appropriate classification head, create two dictionaries: one that maps each label name to an integer, and one that maps each integer back to its label name:
```py
>>> import itertools
>>> labels = [item['ids'] for item in dataset['label']]
>>> flattened_labels = list(itertools.chain(*labels))
>>> unique_labels = list(set(flattened_labels))
>>> label2id = {label: idx for idx, label in enumerate(unique_labels)}
>>> id2label = {idx: label for label, idx in label2id.items()}
```
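The same mapping logic can be tried on toy data without loading the dataset. In this sketch the label lists are made up, and the `toy_` prefixes keep it from overwriting the real `label2id`/`id2label` in your session. It uses `sorted(set(...))` instead of `list(set(...))`: because string hashing is randomized per Python process, set order can differ between sessions, and sorting makes the id assignment reproducible. The original approach still works, since the mapping is saved in the model config.

```py
import itertools

# Toy stand-in for dataset["label"]: made-up values with the same structure as above
toy_labels = [
    {"ids": ["at table", "down", "skateboard", "table"], "weights": [0.3, 1.0, 0.3, 0.3]},
    {"ids": ["down", "yes"], "weights": [1.0, 0.3]},
]
toy_flattened = list(itertools.chain(*[item["ids"] for item in toy_labels]))
# sorted() gives the same ids on every run; set() iteration order can vary between sessions
toy_unique_labels = sorted(set(toy_flattened))
toy_label2id = {label: idx for idx, label in enumerate(toy_unique_labels)}
toy_id2label = {idx: label for label, idx in toy_label2id.items()}
print(toy_label2id)
# {'at table': 0, 'down': 1, 'skateboard': 2, 'table': 3, 'yes': 4}
```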
Now that we have the mappings, we can replace the string answers with their ids and flatten the dataset to make further preprocessing more convenient.
```python
>>> def replace_ids(inputs):
...     inputs["label"]["ids"] = [label2id[x] for x in inputs["label"]["ids"]]
...     return inputs
>>> dataset = dataset.map(replace_ids)
>>> flat_dataset = dataset.flatten()
>>> flat_dataset.features
{'question': Value(dtype='string', id=None),
'image_id': Value(dtype='string', id=None),
'label.ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
'label.weights': Sequence(feature=Value(dtype='float64', id=None), length=-1, id=None)}
```
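The effect of `replace_ids` followed by `flatten()` on a single record can be illustrated in plain Python. `flatten_record` is a hypothetical helper that mimics how a nested `label` struct becomes the dotted `label.ids` and `label.weights` columns; it is not how 🤗 Datasets implements `flatten()`, and the record and mapping are toy values.

```py
def flatten_record(record: dict, parent: str = "") -> dict:
    """Hypothetical helper: turn nested dicts into dotted keys, like `label.ids`."""
    flat = {}
    for key, value in record.items():
        name = f"{parent}.{key}" if parent else key
        if isinstance(value, dict):
            flat.update(flatten_record(value, name))
        else:
            flat[name] = value
    return flat


toy_label2id = {"at table": 0, "down": 1, "skateboard": 2, "table": 3}
record = {
    "question": "Where is he looking?",
    "label": {"ids": ["at table", "down", "skateboard", "table"], "weights": [0.3, 1.0, 0.3, 0.3]},
}
# Same replacement as replace_ids above, applied to one plain dict
record["label"]["ids"] = [toy_label2id[x] for x in record["label"]["ids"]]
print(flatten_record(record))
# {'question': 'Where is he looking?', 'label.ids': [0, 1, 2, 3], 'label.weights': [0.3, 1.0, 0.3, 0.3]}
```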
## Preprocessing data [[preprocessing-data]]
The next step is to load a ViLT processor to prepare the image and text data for the model.
[`ViltProcessor`] conveniently wraps a BERT tokenizer and a ViLT image processor into a single processor:
```py
>>> from transformers import ViltProcessor
>>> processor = ViltProcessor.from_pretrained(model_checkpoint)
```
To preprocess the data, we need to encode the images and questions with the [`ViltProcessor`]. The processor uses [`BertTokenizerFast`] to tokenize the text and create `input_ids`, `attention_mask`, and `token_type_ids` for the text data.
For images, the processor uses [`ViltImageProcessor`] to resize and normalize each image, and creates `pixel_values` and `pixel_mask`.
All of these preprocessing steps happen under the hood, so we only need to call the `processor`. However, we still need to prepare the target labels. In the target representation, each element corresponds to a possible answer (label). For the correct answers, the element holds the answer's score (weight), and all remaining elements are set to zero.
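As a small worked example of this target representation, the sketch below builds the dense vector for the first example (ids 0 to 3 with the weights shown earlier) in a hypothetical label space of 6 answers. It uses plain Python lists instead of `torch.zeros` so it runs without PyTorch; `build_target` is an illustrative name, not a library function.

```py
def build_target(label_ids, weights, num_labels):
    """Dense soft-target vector: each answer's weight at its id, 0.0 elsewhere."""
    target = [0.0] * num_labels
    for label_id, weight in zip(label_ids, weights):
        target[label_id] = weight
    return target


print(build_target([0, 1, 2, 3], [0.3, 1.0, 0.3, 0.3], num_labels=6))
# [0.3, 1.0, 0.3, 0.3, 0.0, 0.0]
```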
The following function applies the `processor` to the images and questions and formats the labels as described above:
```py
>>> import torch
>>> def preprocess_data(examples):
...     # Load the images referenced by each example and collect the questions
...     image_paths = examples['image_id']
...     images = [Image.open(image_path) for image_path in image_paths]
...     texts = examples['question']
...
...     # Tokenize the questions and process the images in a single call
...     encoding = processor(images, texts, padding="max_length", truncation=True, return_tensors="pt")
...
...     for k, v in encoding.items():
...         encoding[k] = v.squeeze()
...
...     # Dense soft target per example: each answer's weight at its id, 0 elsewhere
...     targets = []
...     for labels, scores in zip(examples['label.ids'], examples['label.weights']):
...         target = torch.zeros(len(id2label))
...         for label, score in zip(labels, scores):
...             target[label] = score
...         targets.append(target)
...
...     encoding["labels"] = targets
...     return encoding
```
To apply the preprocessing function over the entire dataset, use the 🤗 Datasets [`~datasets.map`] function. You can speed up `map` by setting `batched=True` so that multiple elements of the dataset are processed at once. At this point, also remove the columns you no longer need.
```py
>>> processed_dataset = flat_dataset.map(preprocess_data, batched=True, remove_columns=['question', 'image_id', 'label.ids', 'label.weights'])
>>> processed_dataset
Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'pixel_values', 'pixel_mask', 'labels'],
    num_rows: 200
})
```
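With `batched=True`, the mapped function receives a batch as a dictionary of column name to list of values, and must return lists with one entry per row. The sketch below imitates that calling convention in plain Python to show why `preprocess_data` indexes `examples['question']` as a list. `iter_batches` and `add_question_length` are hypothetical helpers, not 🤗 Datasets internals (the library's own `map` uses a default `batch_size` of 1000).

```py
def iter_batches(columns: dict, batch_size: int):
    """Yield dict-of-lists slices, the batch format a batched map function receives."""
    num_rows = len(next(iter(columns.values())))
    for start in range(0, num_rows, batch_size):
        yield {name: values[start:start + batch_size] for name, values in columns.items()}


def add_question_length(batch: dict) -> dict:
    # A batched function returns one value per row in the batch
    return {"question_length": [len(q.split()) for q in batch["question"]]}


columns = {
    "question": ["Where is he looking?", "What color is the bus?", "Is there a dog?"],
    "image_id": ["a.jpg", "b.jpg", "c.jpg"],
}
outputs = [add_question_length(batch) for batch in iter_batches(columns, batch_size=2)]
print(outputs)
# [{'question_length': [4, 5]}, {'question_length': [4]}]
```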
As a final step, create a batch of examples using [`DefaultDataCollator`]:
```py
>>> from transformers import DefaultDataCollator
>>> data_collator = DefaultDataCollator()
```
## Train the model [[train-the-model]]
You're ready to start training your model now! Load ViLT with [`ViltForQuestionAnswering`], and specify the number of labels along with the label mappings:
```py
>>> from transformers import ViltForQuestionAnswering
>>> model = ViltForQuestionAnswering.from_pretrained(model_checkpoint, num_labels=len(id2label), id2label=id2label, label2id=label2id)
```
At this point, only three steps remain:
1. Define your training hyperparameters in [`TrainingArguments`]:
```py
>>> from transformers import TrainingArguments
>>> repo_id = "MariaK/vilt_finetuned_200"
>>> training_args = TrainingArguments(
... output_dir=repo_id,
... per_device_train_batch_size=4,
... num_train_epochs=20,
... save_steps=200,
... logging_steps=50,
... learning_rate=5e-5,
... save_total_limit=2,
... remove_unused_columns=False,
... push_to_hub=True,
... )
```
2. Pass the training arguments to [`Trainer`] along with the model, dataset, processor, and data collator:
```py
>>> from transformers import Trainer
>>> trainer = Trainer(
... model=model,
... args=training_args,
... data_collator=data_collator,
... train_dataset=processed_dataset,
... processing_class=processor,
... )
```
3. Call [`~Trainer.train`] to fine-tune your model:
```py
>>> trainer.train()
```
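To get a feel for what the `TrainingArguments` above imply, here is a back-of-the-envelope calculation. It assumes a single device, no gradient accumulation, and that the last partial batch is kept; `estimate_schedule` is a hypothetical helper, and the real [`Trainer`] may differ in edge cases.

```py
import math


def estimate_schedule(num_examples, batch_size, num_epochs, save_steps, logging_steps, save_total_limit):
    """Rough Trainer step arithmetic (assumes 1 device, no gradient accumulation)."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    total_steps = steps_per_epoch * num_epochs
    # A checkpoint is written every `save_steps` optimizer steps
    checkpoints = list(range(save_steps, total_steps + 1, save_steps))
    return {
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "log_events": total_steps // logging_steps,
        "checkpoints_written": checkpoints,
        # save_total_limit keeps only the most recent checkpoints on disk
        "checkpoints_kept": checkpoints[-save_total_limit:],
    }


schedule = estimate_schedule(num_examples=200, batch_size=4, num_epochs=20,
                             save_steps=200, logging_steps=50, save_total_limit=2)
print(schedule)
# {'steps_per_epoch': 50, 'total_steps': 1000, 'log_events': 20,
#  'checkpoints_written': [200, 400, 600, 800, 1000], 'checkpoints_kept': [800, 1000]}
```

With only 200 examples, 20 epochs amount to roughly 1000 optimizer steps, which is why the model can fit this tiny sample at all.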
Once training is complete, share your model on the 🤗 Hub with the [`~Trainer.push_to_hub`] method:
```py
>>> trainer.push_to_hub()
```
## Inference [[inference]]
Now that you have fine-tuned a ViLT model and uploaded it to the 🤗 Hub, you can use it for inference. The simplest way to try out your fine-tuned model for inference is to use it in a [`Pipeline`].
```py
>>> from transformers import pipeline
>>> pipe = pipeline("visual-question-answering", model="MariaK/vilt_finetuned_200")
```
The model in this guide has only been trained on 200 examples, so don't expect too much from it. Let's use the first example from the dataset to illustrate inference:
```py
>>> example = dataset[0]
>>> image = Image.open(example['image_id'])
>>> question = example['question']
>>> print(question)
Where is he looking?
>>> pipe(image, question, top_k=1)
[{'score': 0.5498199462890625, 'answer': 'down'}]
```
Even though it isn't very confident, the model has indeed learned something. With more examples and longer training, you can expect considerably better results!
You can also manually replicate the results of the pipeline if you'd like:
1. Take an image and a question and prepare them for the model with the processor.
2. Pass the preprocessed inputs through the model.
3. From the logits, get the id of the most likely answer and look up the actual answer in `id2label`.
```py
>>> processor = ViltProcessor.from_pretrained("MariaK/vilt_finetuned_200")
>>> image = Image.open(example['image_id'])
>>> question = example['question']
>>> # prepare inputs
>>> inputs = processor(image, question, return_tensors="pt")
>>> model = ViltForQuestionAnswering.from_pretrained("MariaK/vilt_finetuned_200")
>>> # forward pass
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> logits = outputs.logits
>>> idx = logits.argmax(-1).item()
>>> print("Predicted answer:", model.config.id2label[idx])
Predicted answer: down
```
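Because ViLT's VQA head is trained as a multi-label classifier, per-answer probabilities come from an independent sigmoid per label rather than a softmax over all answers. To our understanding the `visual-question-answering` pipeline also applies a sigmoid before ranking answers, but treat that as an assumption. The plain-Python sketch below ranks made-up logits over a toy label map; since the sigmoid is monotonic, its top answer always matches `logits.argmax(-1)` above. `top_k_answers` is a hypothetical helper.

```py
import math


def top_k_answers(logits, id2label, k=1):
    """Score every answer independently with a sigmoid, then keep the k best."""
    probs = [1.0 / (1.0 + math.exp(-x)) for x in logits]
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [{"score": probs[i], "answer": id2label[i]} for i in ranked[:k]]


demo_id2label = {0: "at table", 1: "down", 2: "skateboard", 3: "table"}
demo_logits = [-1.2, 0.2, -2.0, -0.5]  # made-up values, not real model outputs
for result in top_k_answers(demo_logits, demo_id2label, k=2):
    print(result["answer"], round(result["score"], 3))
# down 0.55
# table 0.378
```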
## Zero-shot VQA [[zeroshot-vqa]]
The previous model treated VQA as a classification task. Some recent models, such as BLIP, BLIP-2, and InstructBLIP, approach VQA as a generative task instead. Let's take [BLIP-2](../../en/model_doc/blip-2) as an example. It introduced a new vision-language pre-training paradigm in which any combination of a pre-trained vision encoder and an LLM can be used (learn more in the [BLIP-2 blog post](https://huggingface.co/blog/blip-2)).
This enables state-of-the-art results on multiple vision-language tasks, including visual question answering.
Let's illustrate how you can use this model for VQA. First, load the model. Here we explicitly move the model to a GPU if one is available; we didn't need to do this earlier during training because [`Trainer`] handles it automatically:
```py
>>> from transformers import AutoProcessor, Blip2ForConditionalGeneration
>>> import torch
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
>>> model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", dtype=torch.float16)
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model.to(device)
```
The model takes an image and text as input, so let's use the same image/question pair as in the first example of the VQA dataset:
```py
>>> example = dataset[0]
>>> image = Image.open(example['image_id'])
>>> question = example['question']
```
To use BLIP-2 for visual question answering, the text prompt must follow the `Question: {} Answer:` format.
```py
>>> prompt = f"Question: {question} Answer:"
```
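If you plan to ask several questions, a small helper keeps the prompt format consistent. `build_blip2_vqa_prompt` is a hypothetical name, not part of 🤗 Transformers; it only reproduces the `Question: {} Answer:` template described above and strips stray whitespace from the question.

```py
def build_blip2_vqa_prompt(question: str) -> str:
    """Wrap a question in the BLIP-2 VQA template shown above."""
    return f"Question: {question.strip()} Answer:"


questions = ["Where is he looking?", "  What is he riding? "]
prompts = [build_blip2_vqa_prompt(q) for q in questions]
print(prompts)
# ['Question: Where is he looking? Answer:', 'Question: What is he riding? Answer:']
```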
Now preprocess the image/prompt with the model's processor, pass the processed inputs through the model, and decode the output:
```py
>>> inputs = processor(image, text=prompt, return_tensors="pt").to(device, torch.float16)
>>> generated_ids = model.generate(**inputs, max_new_tokens=10)
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
>>> print(generated_text)
He is looking at the crowd
```
As you can see, the model recognized the crowd and the direction of the face (looking down), but it missed the fact that the crowd is behind the skater. Still, when human-annotated datasets are not available, this approach can quickly produce useful results.