| <!--Copyright 2023 The HuggingFace Team. All rights reserved. | |
| Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | |
| the License. You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | |
| an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | |
| specific language governing permissions and limitations under the License. | |
| โ ๏ธ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be | |
| rendered properly in your Markdown viewer. | |
| --> | |
| # ์๊ฐ์ ์ง์์๋ต (Visual Question Answering) | |
| [[open-in-colab]] | |
| ์๊ฐ์ ์ง์์๋ต(VQA)์ ์ด๋ฏธ์ง๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ๊ฐ๋ฐฉํ ์ง๋ฌธ์ ๋์ํ๋ ์์ ์ ๋๋ค. ์ด ์์ ์ ์ง์ํ๋ ๋ชจ๋ธ์ ์ ๋ ฅ์ ๋๋ถ๋ถ ์ด๋ฏธ์ง์ ์ง๋ฌธ์ ์กฐํฉ์ด๋ฉฐ, ์ถ๋ ฅ์ ์์ฐ์ด๋ก ๋ ๋ต๋ณ์ ๋๋ค. | |
| VQA์ ์ฃผ์ ์ฌ์ฉ ์ฌ๋ก๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค: | |
| * ์๊ฐ ์ฅ์ ์ธ์ ์ํ ์ ๊ทผ์ฑ ์ ํ๋ฆฌ์ผ์ด์ ์ ๊ตฌ์ถํ ์ ์์ต๋๋ค. | |
| * ๊ต์ก: ๊ฐ์๋ ๊ต๊ณผ์์ ๋์จ ์๊ฐ ์๋ฃ์ ๋ํ ์ง๋ฌธ์ ๋ตํ ์ ์์ต๋๋ค. ๋ํ ์ฒดํํ ์ ์์ ์ ์ ๋ฑ์์๋ VQA๋ฅผ ํ์ฉํ ์ ์์ต๋๋ค. | |
| * ๊ณ ๊ฐ ์๋น์ค ๋ฐ ์ ์์๊ฑฐ๋: VQA๋ ์ฌ์ฉ์๊ฐ ์ ํ์ ๋ํด ์ง๋ฌธํ ์ ์๊ฒ ํจ์ผ๋ก์จ ์ฌ์ฉ์ ๊ฒฝํ์ ํฅ์์ํฌ ์ ์์ต๋๋ค. | |
| * ์ด๋ฏธ์ง ๊ฒ์: VQA ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ์ํ๋ ํน์ฑ์ ๊ฐ์ง ์ด๋ฏธ์ง๋ฅผ ๊ฒ์ํ ์ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด ์ฌ์ฉ์๋ "๊ฐ์์ง๊ฐ ์์ด?"๋ผ๊ณ ๋ฌผ์ด๋ด์ ์ฃผ์ด์ง ์ด๋ฏธ์ง ๋ฌถ์์์ ๊ฐ์์ง๊ฐ ์๋ ๋ชจ๋ ์ด๋ฏธ์ง๋ฅผ ๋ฐ์๋ณผ ์ ์์ต๋๋ค. | |
| ์ด ๊ฐ์ด๋์์ ํ์ตํ ๋ด์ฉ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค: | |
| - VQA ๋ชจ๋ธ ์ค ํ๋์ธ [ViLT](../../en/model_doc/vilt)๋ฅผ [`Graphcore/vqa` ๋ฐ์ดํฐ์ ](https://huggingface.co/datasets/Graphcore/vqa) ์์ ๋ฏธ์ธ์กฐ์ ํ๋ ๋ฐฉ๋ฒ | |
| - ๋ฏธ์ธ์กฐ์ ๋ ViLT ๋ชจ๋ธ๋ก ์ถ๋ก ํ๋ ๋ฐฉ๋ฒ | |
| - BLIP-2 ๊ฐ์ ์์ฑ ๋ชจ๋ธ๋ก ์ ๋ก์ท VQA ์ถ๋ก ์ ์คํํ๋ ๋ฐฉ๋ฒ | |
| ## ViLT ๋ฏธ์ธ ์กฐ์ [[finetuning-vilt]] | |
| ViLT๋ Vision Transformer (ViT) ๋ด์ ํ ์คํธ ์๋ฒ ๋ฉ์ ํฌํจํ์ฌ ๋น์ /์์ฐ์ด ์ฌ์ ํ๋ จ(VLP; Vision-and-Language Pretraining)์ ์ํ ๊ธฐ๋ณธ ๋์์ธ์ ์ ๊ณตํฉ๋๋ค. | |
| ViLT ๋ชจ๋ธ์ ๋น์ ํธ๋์คํฌ๋จธ(ViT)์ ํ ์คํธ ์๋ฒ ๋ฉ์ ๋ฃ์ด ๋น์ /์ธ์ด ์ฌ์ ํ๋ จ(VLP; Vision-and-Language Pre-training)์ ์ํ ๊ธฐ๋ณธ์ ์ธ ๋์์ธ์ ๊ฐ์ท์ต๋๋ค. ์ด ๋ชจ๋ธ์ ์ฌ๋ฌ ๋ค์ด์คํธ๋ฆผ ์์ ์ ์ฌ์ฉํ ์ ์์ต๋๋ค. VQA ํ์คํฌ์์๋ (`[CLS]` ํ ํฐ์ ์ต์ข ์๋ ์ํ ์์ ์ ํ ๋ ์ด์ด์ธ) ๋ถ๋ฅ ํค๋๊ฐ ์์ผ๋ฉฐ ๋ฌด์์๋ก ์ด๊ธฐํ๋ฉ๋๋ค. | |
| ๋ฐ๋ผ์ ์ฌ๊ธฐ์์ ์๊ฐ์ ์ง์์๋ต์ **๋ถ๋ฅ ๋ฌธ์ **๋ก ์ทจ๊ธ๋ฉ๋๋ค. | |
| ์ต๊ทผ์ BLIP, BLIP-2, InstructBLIP์ ๊ฐ์ ๋ชจ๋ธ๋ค์ VQA๋ฅผ ์์ฑํ ์์ ์ผ๋ก ๊ฐ์ฃผํฉ๋๋ค. ๊ฐ์ด๋์ ํ๋ฐ๋ถ์์๋ ์ด๋ฐ ๋ชจ๋ธ๋ค์ ์ฌ์ฉํ์ฌ ์ ๋ก์ท VQA ์ถ๋ก ์ ํ๋ ๋ฐฉ๋ฒ์ ๋ํด ์ค๋ช ํ๊ฒ ์ต๋๋ค. | |
| ์์ํ๊ธฐ ์ ํ์ํ ๋ชจ๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ค์นํ๋์ง ํ์ธํ์ธ์. | |
| ```bash | |
| pip install -q transformers datasets | |
| ``` | |
| ์ปค๋ฎค๋ํฐ์ ๋ชจ๋ธ์ ๊ณต์ ํ๋ ๊ฒ์ ๊ถ์ฅ ๋๋ฆฝ๋๋ค. Hugging Face ๊ณ์ ์ ๋ก๊ทธ์ธํ์ฌ ๐ค Hub์ ์ ๋ก๋ํ ์ ์์ต๋๋ค. | |
| ๋ฉ์์ง๊ฐ ๋ํ๋๋ฉด ๋ก๊ทธ์ธํ ํ ํฐ์ ์ ๋ ฅํ์ธ์: | |
| ```py | |
| >>> from huggingface_hub import notebook_login | |
| >>> notebook_login() | |
| ``` | |
| ๋ชจ๋ธ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ ์ญ ๋ณ์๋ก ์ ์ธํ์ธ์. | |
| ```py | |
| >>> model_checkpoint = "dandelin/vilt-b32-mlm" | |
| ``` | |
| ## ๋ฐ์ดํฐ ๊ฐ์ ธ์ค๊ธฐ [[load-the-data]] | |
| ์ด ๊ฐ์ด๋์์๋ `Graphcore/vqa` ๋ฐ์ดํฐ์ธํธ์ ์์ ์ํ์ ์ฌ์ฉํฉ๋๋ค. ์ ์ฒด ๋ฐ์ดํฐ์ธํธ๋ [๐ค Hub](https://huggingface.co/datasets/Graphcore/vqa) ์์ ํ์ธํ ์ ์์ต๋๋ค. | |
| [`Graphcore/vqa` ๋ฐ์ดํฐ์ธํธ](https://huggingface.co/datasets/Graphcore/vqa) ์ ๋์์ผ๋ก ๊ณต์ [VQA ๋ฐ์ดํฐ์ธํธ ํ์ด์ง](https://visualqa.org/download.html) ์์ ๋์ผํ ๋ฐ์ดํฐ๋ฅผ ์๋์ผ๋ก ๋ค์ด๋ก๋ํ ์ ์์ต๋๋ค. ์ง์ ๊ณต์ํ ๋ฐ์ดํฐ๋ก ํํ ๋ฆฌ์ผ์ ๋ฐ๋ฅด๊ณ ์ถ๋ค๋ฉด [์ด๋ฏธ์ง ๋ฐ์ดํฐ์ธํธ ๋ง๋ค๊ธฐ](https://huggingface.co/docs/datasets/image_dataset#loading-script) ๋ผ๋ | |
| ๐ค Datasets ๋ฌธ์๋ฅผ ์ฐธ์กฐํ์ธ์. | |
| ๊ฒ์ฆ ๋ฐ์ดํฐ์ ์ฒซ 200๊ฐ ํญ๋ชฉ์ ๋ถ๋ฌ์ ๋ฐ์ดํฐ์ธํธ์ ํน์ฑ์ ํ์ธํด ๋ณด๊ฒ ์ต๋๋ค: | |
| ```python | |
| >>> from datasets import load_dataset | |
| >>> dataset = load_dataset("Graphcore/vqa", split="validation[:200]") | |
| >>> dataset | |
| Dataset({ | |
| features: ['question', 'question_type', 'question_id', 'image_id', 'answer_type', 'label'], | |
| num_rows: 200 | |
| }) | |
| ``` | |
| ์์ ๋ฅผ ํ๋ ๋ฝ์ ๋ฐ์ดํฐ์ธํธ์ ํน์ฑ์ ์ดํดํด ๋ณด๊ฒ ์ต๋๋ค. | |
| ```py | |
| >>> dataset[0] | |
| {'question': 'Where is he looking?', | |
| 'question_type': 'none of the above', | |
| 'question_id': 262148000, | |
| 'image_id': '/root/.cache/huggingface/datasets/downloads/extracted/ca733e0e000fb2d7a09fbcc94dbfe7b5a30750681d0e965f8e0a23b1c2f98c75/val2014/COCO_val2014_000000262148.jpg', | |
| 'answer_type': 'other', | |
| 'label': {'ids': ['at table', 'down', 'skateboard', 'table'], | |
| 'weights': [0.30000001192092896, | |
| 1.0, | |
| 0.30000001192092896, | |
| 0.30000001192092896]}} | |
| ``` | |
| ๋ฐ์ดํฐ์ธํธ์๋ ๋ค์๊ณผ ๊ฐ์ ํน์ฑ์ด ํฌํจ๋์ด ์์ต๋๋ค: | |
| * `question`: ์ด๋ฏธ์ง์ ๋ํ ์ง๋ฌธ | |
| * `image_id`: ์ง๋ฌธ๊ณผ ๊ด๋ จ๋ ์ด๋ฏธ์ง์ ๊ฒฝ๋ก | |
| * `label`: ๋ฐ์ดํฐ์ ๋ ์ด๋ธ (annotations) | |
| ๋๋จธ์ง ํน์ฑ๋ค์ ํ์ํ์ง ์๊ธฐ ๋๋ฌธ์ ์ญ์ ํด๋ ๋ฉ๋๋ค: | |
| ```py | |
| >>> dataset = dataset.remove_columns(['question_type', 'question_id', 'answer_type']) | |
| ``` | |
| ๋ณด์๋ค์ํผ `label` ํน์ฑ์ ๊ฐ์ ์ง๋ฌธ๋ง๋ค ๋ต๋ณ์ด ์ฌ๋ฌ ๊ฐ ์์ ์ ์์ต๋๋ค. ๋ชจ๋ ๋ค๋ฅธ ๋ฐ์ดํฐ ๋ผ๋ฒจ๋ฌ๋ค๋ก๋ถํฐ ์์ง๋์๊ธฐ ๋๋ฌธ์ธ๋ฐ์. ์ง๋ฌธ์ ๋ต๋ณ์ ์ฃผ๊ด์ ์ผ ์ ์์ต๋๋ค. ์ด ๊ฒฝ์ฐ ์ง๋ฌธ์ "๊ทธ๋ ์ด๋๋ฅผ ๋ณด๊ณ ์๋์?" ์์ง๋ง, ์ด๋ค ์ฌ๋๋ค์ "์๋"๋ก ๋ ์ด๋ธ์ ๋ฌ์๊ณ , ๋ค๋ฅธ ์ฌ๋๋ค์ "ํ ์ด๋ธ" ๋๋ "์ค์ผ์ดํธ๋ณด๋" ๋ฑ์ผ๋ก ์ฃผ์์ ๋ฌ์์ต๋๋ค. | |
| ์๋์ ์ด๋ฏธ์ง๋ฅผ ๋ณด๊ณ ์ด๋ค ๋ต๋ณ์ ์ ํํ ๊ฒ์ธ์ง ์๊ฐํด ๋ณด์ธ์: | |
| ```python | |
| >>> from PIL import Image | |
| >>> image = Image.open(dataset[0]['image_id']) | |
| >>> image | |
| ``` | |
| <div class="flex justify-center"> | |
| <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/vqa-example.png" alt="VQA Image Example"/> | |
| </div> | |
| ์ง๋ฌธ๊ณผ ๋ต๋ณ์ ๋ชจํธ์ฑ์ผ๋ก ์ธํด ์ด๋ฌํ ๋ฐ์ดํฐ์ธํธ๋ ์ฌ๋ฌ ๊ฐ์ ๋ต๋ณ์ด ๊ฐ๋ฅํ๋ฏ๋ก ๋ค์ค ๋ ์ด๋ธ ๋ถ๋ฅ ๋ฌธ์ ๋ก ์ฒ๋ฆฌ๋ฉ๋๋ค. ๊ฒ๋ค๊ฐ, ์ํซ(one-hot) ์ธ์ฝ๋ฉ ๋ฒกํฐ๋ฅผ ์์ฑํ๊ธฐ๋ณด๋ค๋ ๋ ์ด๋ธ์์ ํน์ ๋ต๋ณ์ด ๋ํ๋๋ ํ์๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ํํธ ์ธ์ฝ๋ฉ์ ์์ฑํฉ๋๋ค. | |
| ์์ ์์์์ "์๋"๋ผ๋ ๋ต๋ณ์ด ๋ค๋ฅธ ๋ต๋ณ๋ณด๋ค ํจ์ฌ ๋ ์์ฃผ ์ ํ๋์๊ธฐ ๋๋ฌธ์ ๋ฐ์ดํฐ์ธํธ์์ `weight`๋ผ๊ณ ๋ถ๋ฆฌ๋ ์ ์๋ก 1.0์ ๊ฐ์ง๋ฉฐ, ๋๋จธ์ง ๋ต๋ณ๋ค์ 1.0 ๋ฏธ๋ง์ ์ ์๋ฅผ ๊ฐ์ง๋๋ค. | |
| ์ ์ ํ ๋ถ๋ฅ ํค๋๋ก ๋ชจ๋ธ์ ๋์ค์ ์ธ์คํด์คํํ๊ธฐ ์ํด ๋ ์ด๋ธ์ ์ ์๋ก ๋งคํํ ๋์ ๋๋ฆฌ ํ๋, ๋ฐ๋๋ก ์ ์๋ฅผ ๋ ์ด๋ธ๋ก ๋งคํํ ๋์ ๋๋ฆฌ ํ๋ ์ด 2๊ฐ์ ๋์ ๋๋ฆฌ๋ฅผ ์์ฑํ์ธ์: | |
| ```py | |
| >>> import itertools | |
| >>> labels = [item['ids'] for item in dataset['label']] | |
| >>> flattened_labels = list(itertools.chain(*labels)) | |
| >>> unique_labels = list(set(flattened_labels)) | |
| >>> label2id = {label: idx for idx, label in enumerate(unique_labels)} | |
| >>> id2label = {idx: label for label, idx in label2id.items()} | |
| ``` | |
| ์ด์ ๋งคํ์ด ์๋ฃ๋์์ผ๋ฏ๋ก ๋ฌธ์์ด ๋ต๋ณ์ ํด๋น id๋ก ๊ต์ฒดํ๊ณ , ๋ฐ์ดํฐ์ธํธ์ ๋ ํธ๋ฆฌํ ํ์ฒ๋ฆฌ๋ฅผ ์ํด ํธํํ ํ ์ ์์ต๋๋ค. | |
| ```python | |
| >>> def replace_ids(inputs): | |
| ... inputs["label"]["ids"] = [label2id[x] for x in inputs["label"]["ids"]] | |
| ... return inputs | |
| >>> dataset = dataset.map(replace_ids) | |
| >>> flat_dataset = dataset.flatten() | |
| >>> flat_dataset.features | |
| {'question': Value(dtype='string', id=None), | |
| 'image_id': Value(dtype='string', id=None), | |
| 'label.ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), | |
| 'label.weights': Sequence(feature=Value(dtype='float64', id=None), length=-1, id=None)} | |
| ``` | |
| ## ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ [[preprocessing-data]] | |
| ๋ค์ ๋จ๊ณ๋ ๋ชจ๋ธ์ ์ํด ์ด๋ฏธ์ง์ ํ ์คํธ ๋ฐ์ดํฐ๋ฅผ ์ค๋นํ๊ธฐ ์ํด ViLT ํ๋ก์ธ์๋ฅผ ๊ฐ์ ธ์ค๋ ๊ฒ์ ๋๋ค. | |
| [`ViltProcessor`]๋ BERT ํ ํฌ๋์ด์ ์ ViLT ์ด๋ฏธ์ง ํ๋ก์ธ์๋ฅผ ํธ๋ฆฌํ๊ฒ ํ๋์ ํ๋ก์ธ์๋ก ๋ฌถ์ต๋๋ค: | |
| ```py | |
| >>> from transformers import ViltProcessor | |
| >>> processor = ViltProcessor.from_pretrained(model_checkpoint) | |
| ``` | |
| ๋ฐ์ดํฐ๋ฅผ ์ ์ฒ๋ฆฌํ๋ ค๋ฉด ์ด๋ฏธ์ง์ ์ง๋ฌธ์ [`ViltProcessor`]๋ก ์ธ์ฝ๋ฉํด์ผ ํฉ๋๋ค. ํ๋ก์ธ์๋ [`BertTokenizerFast`]๋ก ํ ์คํธ๋ฅผ ํ ํฌ๋์ด์ฆํ๊ณ ํ ์คํธ ๋ฐ์ดํฐ๋ฅผ ์ํด `input_ids`, `attention_mask` ๋ฐ `token_type_ids`๋ฅผ ์์ฑํฉ๋๋ค. | |
| ์ด๋ฏธ์ง๋ [`ViltImageProcessor`]๋ก ์ด๋ฏธ์ง๋ฅผ ํฌ๊ธฐ ์กฐ์ ํ๊ณ ์ ๊ทํํ๋ฉฐ, `pixel_values`์ `pixel_mask`๋ฅผ ์์ฑํฉ๋๋ค. | |
| ์ด๋ฐ ์ ์ฒ๋ฆฌ ๋จ๊ณ๋ ๋ชจ๋ ๋ด๋ถ์์ ์ด๋ฃจ์ด์ง๋ฏ๋ก, `processor`๋ฅผ ํธ์ถํ๊ธฐ๋ง ํ๋ฉด ๋ฉ๋๋ค. ํ์ง๋ง ์์ง ํ๊ฒ ๋ ์ด๋ธ์ด ์์ฑ๋์ง ์์์ต๋๋ค. ํ๊ฒ์ ํํ์์ ๊ฐ ์์๋ ๊ฐ๋ฅํ ๋ต๋ณ(๋ ์ด๋ธ)์ ํด๋นํฉ๋๋ค. ์ ํํ ๋ต๋ณ์ ์์๋ ํด๋น ์ ์(weight)๋ฅผ ์ ์ง์ํค๊ณ ๋๋จธ์ง ์์๋ 0์ผ๋ก ์ค์ ํด์ผ ํฉ๋๋ค. | |
| ์๋ ํจ์๊ฐ ์์์ ์ค๋ช ํ๋๋ก ์ด๋ฏธ์ง์ ์ง๋ฌธ์ `processor`๋ฅผ ์ ์ฉํ๊ณ ๋ ์ด๋ธ์ ํ์์ ๋ง์ถฅ๋๋ค: | |
| ```py | |
| >>> import torch | |
| >>> def preprocess_data(examples): | |
| ... image_paths = examples['image_id'] | |
| ... images = [Image.open(image_path) for image_path in image_paths] | |
| ... texts = examples['question'] | |
| ... encoding = processor(images, texts, padding="max_length", truncation=True, return_tensors="pt") | |
| ... for k, v in encoding.items(): | |
| ... encoding[k] = v.squeeze() | |
| ... targets = [] | |
| ... for labels, scores in zip(examples['label.ids'], examples['label.weights']): | |
| ... target = torch.zeros(len(id2label)) | |
| ... for label, score in zip(labels, scores): | |
| ... target[label] = score | |
| ... targets.append(target) | |
| ... encoding["labels"] = targets | |
| ... return encoding | |
| ``` | |
| ์ ์ฒด ๋ฐ์ดํฐ์ธํธ์ ์ ์ฒ๋ฆฌ ํจ์๋ฅผ ์ ์ฉํ๋ ค๋ฉด ๐ค Datasets์ [`~datasets.map`] ํจ์๋ฅผ ์ฌ์ฉํ์ญ์์ค. `batched=True`๋ฅผ ์ค์ ํ์ฌ ๋ฐ์ดํฐ์ธํธ์ ์ฌ๋ฌ ์์๋ฅผ ํ ๋ฒ์ ์ฒ๋ฆฌํจ์ผ๋ก์จ `map`์ ๋ ๋น ๋ฅด๊ฒ ํ ์ ์์ต๋๋ค. ์ด ์์ ์์ ํ์ํ์ง ์์ ์ด์ ์ ๊ฑฐํ์ธ์. | |
| ```py | |
| >>> processed_dataset = flat_dataset.map(preprocess_data, batched=True, remove_columns=['question','question_type', 'question_id', 'image_id', 'answer_type', 'label.ids', 'label.weights']) | |
| >>> processed_dataset | |
| Dataset({ | |
| features: ['input_ids', 'token_type_ids', 'attention_mask', 'pixel_values', 'pixel_mask', 'labels'], | |
| num_rows: 200 | |
| }) | |
| ``` | |
| ๋ง์ง๋ง ๋จ๊ณ๋ก, [`DefaultDataCollator`]๋ฅผ ์ฌ์ฉํ์ฌ ์์ ๋ก ์ธ ๋ฐฐ์น๋ฅผ ์์ฑํ์ธ์: | |
| ```py | |
| >>> from transformers import DefaultDataCollator | |
| >>> data_collator = DefaultDataCollator() | |
| ``` | |
| ## ๋ชจ๋ธ ํ๋ จ [[train-the-model]] | |
| ์ด์ ๋ชจ๋ธ์ ํ๋ จํ๊ธฐ ์ํด ์ค๋น๋์์ต๋๋ค! [`ViltForQuestionAnswering`]์ผ๋ก ViLT๋ฅผ ๊ฐ์ ธ์ฌ ์ฐจ๋ก์ ๋๋ค. ๋ ์ด๋ธ์ ์์ ๋ ์ด๋ธ ๋งคํ์ ์ง์ ํ์ธ์: | |
| ```py | |
| >>> from transformers import ViltForQuestionAnswering | |
| >>> model = ViltForQuestionAnswering.from_pretrained(model_checkpoint, num_labels=len(id2label), id2label=id2label, label2id=label2id) | |
| ``` | |
| ์ด ์์ ์์๋ ๋ค์ ์ธ ๋จ๊ณ๋ง ๋จ์์ต๋๋ค: | |
| 1. [`TrainingArguments`]์์ ํ๋ จ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ์ ์ํ์ธ์: | |
| ```py | |
| >>> from transformers import TrainingArguments | |
| >>> repo_id = "MariaK/vilt_finetuned_200" | |
| >>> training_args = TrainingArguments( | |
| ... output_dir=repo_id, | |
| ... per_device_train_batch_size=4, | |
| ... num_train_epochs=20, | |
| ... save_steps=200, | |
| ... logging_steps=50, | |
| ... learning_rate=5e-5, | |
| ... save_total_limit=2, | |
| ... remove_unused_columns=False, | |
| ... push_to_hub=True, | |
| ... ) | |
| ``` | |
| 2. ๋ชจ๋ธ, ๋ฐ์ดํฐ์ธํธ, ํ๋ก์ธ์, ๋ฐ์ดํฐ ์ฝ๋ ์ดํฐ์ ํจ๊ป ํ๋ จ ์ธ์๋ฅผ [`Trainer`]์ ์ ๋ฌํ์ธ์: | |
| ```py | |
| >>> from transformers import Trainer | |
| >>> trainer = Trainer( | |
| ... model=model, | |
| ... args=training_args, | |
| ... data_collator=data_collator, | |
| ... train_dataset=processed_dataset, | |
| ... processing_class=processor, | |
| ... ) | |
| ``` | |
| 3. [`~Trainer.train`]์ ํธ์ถํ์ฌ ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํ์ธ์: | |
| ```py | |
| >>> trainer.train() | |
| ``` | |
| ํ๋ จ์ด ์๋ฃ๋๋ฉด, [`~Trainer.push_to_hub`] ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ ๐ค Hub์ ๋ชจ๋ธ์ ๊ณต์ ํ์ธ์: | |
| ```py | |
| >>> trainer.push_to_hub() | |
| ``` | |
| ## ์ถ๋ก [[inference]] | |
| ViLT ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํ๊ณ ๐ค Hub์ ์ ๋ก๋ํ๋ค๋ฉด ์ถ๋ก ์ ์ฌ์ฉํ ์ ์์ต๋๋ค. ๋ฏธ์ธ ์กฐ์ ๋ ๋ชจ๋ธ์ ์ถ๋ก ์ ์ฌ์ฉํด๋ณด๋ ๊ฐ์ฅ ๊ฐ๋จํ ๋ฐฉ๋ฒ์ [`Pipeline`]์์ ์ฌ์ฉํ๋ ๊ฒ์ ๋๋ค. | |
| ```py | |
| >>> from transformers import pipeline | |
| >>> pipe = pipeline("visual-question-answering", model="MariaK/vilt_finetuned_200") | |
| ``` | |
| ์ด ๊ฐ์ด๋์ ๋ชจ๋ธ์ 200๊ฐ์ ์์ ์์๋ง ํ๋ จ๋์์ผ๋ฏ๋ก ๊ทธ๋ค์ง ๋ง์ ๊ฒ์ ๊ธฐ๋ํ ์๋ ์์ต๋๋ค. ๋ฐ์ดํฐ์ธํธ์ ์ฒซ ๋ฒ์งธ ์์ ๋ฅผ ์ฌ์ฉํ์ฌ ์ถ๋ก ๊ฒฐ๊ณผ๋ฅผ ์ค๋ช ํด๋ณด๊ฒ ์ต๋๋ค: | |
| ```py | |
| >>> example = dataset[0] | |
| >>> image = Image.open(example['image_id']) | |
| >>> question = example['question'] | |
| >>> print(question) | |
| >>> pipe(image, question, top_k=1) | |
| "Where is he looking?" | |
| [{'score': 0.5498199462890625, 'answer': 'down'}] | |
| ``` | |
| ๋น๋ก ํ์ ์ ๋ณ๋ก ์์ง๋ง, ๋ชจ๋ธ์ ์ค์ ๋ก ๋ฌด์ธ๊ฐ๋ฅผ ๋ฐฐ์ ์ต๋๋ค. ๋ ๋ง์ ์์ ์ ๋ ๊ธด ํ๋ จ ๊ธฐ๊ฐ์ด ์ฃผ์ด์ง๋ค๋ฉด ๋ถ๋ช ๋ ๋์ ๊ฒฐ๊ณผ๋ฅผ ์ป์ ์ ์์ ๊ฒ์ ๋๋ค! | |
| ์ํ๋ค๋ฉด ํ์ดํ๋ผ์ธ์ ๊ฒฐ๊ณผ๋ฅผ ์๋์ผ๋ก ๋ณต์ ํ ์๋ ์์ต๋๋ค: | |
| 1. ์ด๋ฏธ์ง์ ์ง๋ฌธ์ ๊ฐ์ ธ์์ ํ๋ก์ธ์๋ฅผ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ์ค๋นํฉ๋๋ค. | |
| 2. ์ ์ฒ๋ฆฌ๋ ๊ฒฐ๊ณผ๋ฅผ ๋ชจ๋ธ์ ์ ๋ฌํฉ๋๋ค. | |
| 3. ๋ก์ง์์ ๊ฐ์ฅ ๊ฐ๋ฅ์ฑ ์๋ ๋ต๋ณ์ id๋ฅผ ๊ฐ์ ธ์์ `id2label`์์ ์ค์ ๋ต๋ณ์ ์ฐพ์ต๋๋ค. | |
| ```py | |
| >>> processor = ViltProcessor.from_pretrained("MariaK/vilt_finetuned_200") | |
| >>> image = Image.open(example['image_id']) | |
| >>> question = example['question'] | |
| >>> # prepare inputs | |
| >>> inputs = processor(image, question, return_tensors="pt") | |
| >>> model = ViltForQuestionAnswering.from_pretrained("MariaK/vilt_finetuned_200") | |
| >>> # forward pass | |
| >>> with torch.no_grad(): | |
| ... outputs = model(**inputs) | |
| >>> logits = outputs.logits | |
| >>> idx = logits.argmax(-1).item() | |
| >>> print("Predicted answer:", model.config.id2label[idx]) | |
| Predicted answer: down | |
| ``` | |
| ## ์ ๋ก์ท VQA [[zeroshot-vqa]] | |
| ์ด์ ๋ชจ๋ธ์ VQA๋ฅผ ๋ถ๋ฅ ๋ฌธ์ ๋ก ์ฒ๋ฆฌํ์ต๋๋ค. BLIP, BLIP-2 ๋ฐ InstructBLIP์ ๊ฐ์ ์ต๊ทผ์ ๋ชจ๋ธ์ VQA๋ฅผ ์์ฑ ์์ ์ผ๋ก ์ ๊ทผํฉ๋๋ค. [BLIP-2](../../en/model_doc/blip-2)๋ฅผ ์๋ก ๋ค์ด ๋ณด๊ฒ ์ต๋๋ค. ์ด ๋ชจ๋ธ์ ์ฌ์ ํ๋ จ๋ ๋น์ ์ธ์ฝ๋์ LLM์ ๋ชจ๋ ์กฐํฉ์ ์ฌ์ฉํ ์ ์๋ ์๋ก์ด ๋น์ -์์ฐ์ด ์ฌ์ ํ์ต ํจ๋ฌ๋ค์์ ๋์ ํ์ต๋๋ค. ([BLIP-2 ๋ธ๋ก๊ทธ ํฌ์คํธ](https://huggingface.co/blog/blip-2)๋ฅผ ํตํด ๋ ์์ธํ ์์๋ณผ ์ ์์ด์) | |
| ์ด๋ฅผ ํตํด ์๊ฐ์ ์ง์์๋ต์ ํฌํจํ ์ฌ๋ฌ ๋น์ -์์ฐ์ด ์์ ์์ SOTA๋ฅผ ๋ฌ์ฑํ ์ ์์์ต๋๋ค. | |
| ์ด ๋ชจ๋ธ์ ์ด๋ป๊ฒ VQA์ ์ฌ์ฉํ ์ ์๋์ง ์ค๋ช ํด ๋ณด๊ฒ ์ต๋๋ค. ๋จผ์ ๋ชจ๋ธ์ ๊ฐ์ ธ์ ๋ณด๊ฒ ์ต๋๋ค. ์ฌ๊ธฐ์ GPU๊ฐ ์ฌ์ฉ ๊ฐ๋ฅํ ๊ฒฝ์ฐ ๋ชจ๋ธ์ ๋ช ์์ ์ผ๋ก GPU๋ก ์ ์กํ ๊ฒ์ ๋๋ค. ์ด์ ์๋ ํ๋ จํ ๋ ์ฐ์ง ์์ ์ด์ ๋ [`Trainer`]๊ฐ ์ด ๋ถ๋ถ์ ์๋์ผ๋ก ์ฒ๋ฆฌํ๊ธฐ ๋๋ฌธ์ ๋๋ค: | |
| ```py | |
| >>> from transformers import AutoProcessor, Blip2ForConditionalGeneration | |
| >>> import torch | |
| >>> processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b") | |
| >>> model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", dtype=torch.float16) | |
| >>> device = "cuda" if torch.cuda.is_available() else "cpu" | |
| >>> model.to(device) | |
| ``` | |
| ๋ชจ๋ธ์ ์ด๋ฏธ์ง์ ํ ์คํธ๋ฅผ ์ ๋ ฅ์ผ๋ก ๋ฐ์ผ๋ฏ๋ก, VQA ๋ฐ์ดํฐ์ธํธ์ ์ฒซ ๋ฒ์งธ ์์ ์์์ ๋์ผํ ์ด๋ฏธ์ง/์ง๋ฌธ ์์ ์ฌ์ฉํด ๋ณด๊ฒ ์ต๋๋ค: | |
| ```py | |
| >>> example = dataset[0] | |
| >>> image = Image.open(example['image_id']) | |
| >>> question = example['question'] | |
| ``` | |
| BLIP-2๋ฅผ ์๊ฐ์ ์ง์์๋ต ์์ ์ ์ฌ์ฉํ๋ ค๋ฉด ํ ์คํธ ํ๋กฌํํธ๊ฐ `Question: {} Answer:` ํ์์ ๋ฐ๋ผ์ผ ํฉ๋๋ค. | |
| ```py | |
| >>> prompt = f"Question: {question} Answer:" | |
| ``` | |
| ์ด์ ๋ชจ๋ธ์ ํ๋ก์ธ์๋ก ์ด๋ฏธ์ง/ํ๋กฌํํธ๋ฅผ ์ ์ฒ๋ฆฌํ๊ณ , ์ฒ๋ฆฌ๋ ์ ๋ ฅ์ ๋ชจ๋ธ์ ํตํด ์ ๋ฌํ๊ณ , ์ถ๋ ฅ์ ๋์ฝ๋ํด์ผ ํฉ๋๋ค: | |
| ```py | |
| >>> inputs = processor(image, text=prompt, return_tensors="pt").to(device, torch.float16) | |
| >>> generated_ids = model.generate(**inputs, max_new_tokens=10) | |
| >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() | |
| >>> print(generated_text) | |
| "He is looking at the crowd" | |
| ``` | |
| ๋ณด์๋ค์ํผ ๋ชจ๋ธ์ ๊ตฐ์ค์ ์ธ์ํ๊ณ , ์ผ๊ตด์ ๋ฐฉํฅ(์๋์ชฝ์ ๋ณด๊ณ ์์)์ ์ธ์ํ์ง๋ง, ๊ตฐ์ค์ด ์ค์ผ์ดํฐ ๋ค์ ์๋ค๋ ์ฌ์ค์ ๋์ณค์ต๋๋ค. ๊ทธ๋ฌ๋ ์ฌ๋์ด ์ง์ ๋ผ๋ฒจ๋งํ ๋ฐ์ดํฐ์ ์ ์ป์ ์ ์๋ ๊ฒฝ์ฐ์, ์ด ์ ๊ทผ๋ฒ์ ๋น ๋ฅด๊ฒ ์ ์ฉํ ๊ฒฐ๊ณผ๋ฅผ ์์ฑํ ์ ์์ต๋๋ค. | |