<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
โš ๏ธ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# ๋ฌธ์„œ ์งˆ์˜ ์‘๋‹ต(Document Question Answering) [[document_question_answering]]
[[open-in-colab]]
๋ฌธ์„œ ์‹œ๊ฐ์  ์งˆ์˜ ์‘๋‹ต(Document Visual Question Answering)์ด๋ผ๊ณ ๋„ ํ•˜๋Š”
๋ฌธ์„œ ์งˆ์˜ ์‘๋‹ต(Document Question Answering)์€ ๋ฌธ์„œ ์ด๋ฏธ์ง€์— ๋Œ€ํ•œ ์งˆ๋ฌธ์— ๋‹ต๋ณ€์„ ์ฃผ๋Š” ํƒœ์Šคํฌ์ž…๋‹ˆ๋‹ค.
์ด ํƒœ์Šคํฌ๋ฅผ ์ง€์›ํ•˜๋Š” ๋ชจ๋ธ์˜ ์ž…๋ ฅ์€ ์ผ๋ฐ˜์ ์œผ๋กœ ์ด๋ฏธ์ง€์™€ ์งˆ๋ฌธ์˜ ์กฐํ•ฉ์ด๊ณ , ์ถœ๋ ฅ์€ ์ž์—ฐ์–ด๋กœ ๋œ ๋‹ต๋ณ€์ž…๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ๋ชจ๋ธ์€ ํ…์ŠคํŠธ, ๋‹จ์–ด์˜ ์œ„์น˜(๋ฐ”์šด๋”ฉ ๋ฐ•์Šค), ์ด๋ฏธ์ง€ ๋“ฑ ๋‹ค์–‘ํ•œ ๋ชจ๋‹ฌ๋ฆฌํ‹ฐ๋ฅผ ํ™œ์šฉํ•ฉ๋‹ˆ๋‹ค.
์ด ๊ฐ€์ด๋“œ๋Š” ๋‹ค์Œ ๋‚ด์šฉ์„ ์„ค๋ช…ํ•ฉ๋‹ˆ๋‹ค:
- [DocVQA dataset](https://huggingface.co/datasets/nielsr/docvqa_1200_examples_donut)์„ ์‚ฌ์šฉํ•ด [LayoutLMv2](../model_doc/layoutlmv2) ๋ฏธ์„ธ ์กฐ์ •ํ•˜๊ธฐ
- ์ถ”๋ก ์„ ์œ„ํ•ด ๋ฏธ์„ธ ์กฐ์ •๋œ ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜๊ธฐ
<Tip>
To see all architectures and checkpoints compatible with this task, we recommend checking the [task page](https://huggingface.co/tasks/document-question-answering).
</Tip>
LayoutLMv2 solves the document question answering task by adding a question answering head on top of the final hidden states of the tokens, to predict the positions of the start token and the end token of the answer. In other words, the problem is treated as extractive question answering: given the context, extract which piece of information answers the question.
The context comes from the output of an OCR engine, here it is Google's Tesseract.
Before you begin, make sure you have all the necessary libraries installed. LayoutLMv2 depends on detectron2, torchvision and tesseract.
```bash
pip install -q transformers datasets
```
```bash
pip install 'git+https://github.com/facebookresearch/detectron2.git'
pip install torchvision
```
```bash
sudo apt install tesseract-ocr
pip install -q pytesseract
```
ํ•„์š”ํ•œ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋“ค์„ ๋ชจ๋‘ ์„ค์น˜ํ•œ ํ›„ ๋Ÿฐํƒ€์ž„์„ ๋‹ค์‹œ ์‹œ์ž‘ํ•ฉ๋‹ˆ๋‹ค.
์ปค๋ฎค๋‹ˆํ‹ฐ์— ๋‹น์‹ ์˜ ๋ชจ๋ธ์„ ๊ณต์œ ํ•˜๋Š” ๊ฒƒ์„ ๊ถŒ์žฅํ•ฉ๋‹ˆ๋‹ค. Hugging Face ๊ณ„์ •์— ๋กœ๊ทธ์ธํ•ด์„œ ๋ชจ๋ธ์„ ๐Ÿค— Hub์— ์—…๋กœ๋“œํ•˜์„ธ์š”.
ํ”„๋กฌํ”„ํŠธ๊ฐ€ ์‹คํ–‰๋˜๋ฉด, ๋กœ๊ทธ์ธ์„ ์œ„ํ•ด ํ† ํฐ์„ ์ž…๋ ฅํ•˜์„ธ์š”:
```py
>>> from huggingface_hub import notebook_login
>>> notebook_login()
```
Let's define some global variables.
```py
>>> model_checkpoint = "microsoft/layoutlmv2-base-uncased"
>>> batch_size = 4
```
## ๋ฐ์ดํ„ฐ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ [[load-the-data]]
์ด ๊ฐ€์ด๋“œ์—์„œ๋Š” ๐Ÿค— Hub์—์„œ ์ฐพ์„ ์ˆ˜ ์žˆ๋Š” ์ „์ฒ˜๋ฆฌ๋œ DocVQA์˜ ์ž‘์€ ์ƒ˜ํ”Œ์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.
DocVQA์˜ ์ „์ฒด ๋ฐ์ดํ„ฐ ์„ธํŠธ๋ฅผ ์‚ฌ์šฉํ•˜๊ณ  ์‹ถ๋‹ค๋ฉด, [DocVQA homepage](https://rrc.cvc.uab.es/?ch=17)์— ๊ฐ€์ž… ํ›„ ๋‹ค์šด๋กœ๋“œ ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ „์ฒด ๋ฐ์ดํ„ฐ ์„ธํŠธ๋ฅผ ๋‹ค์šด๋กœ๋“œ ํ–ˆ๋‹ค๋ฉด, ์ด ๊ฐ€์ด๋“œ๋ฅผ ๊ณ„์† ์ง„ํ–‰ํ•˜๊ธฐ ์œ„ํ•ด [๐Ÿค— dataset์— ํŒŒ์ผ์„ ๊ฐ€์ ธ์˜ค๋Š” ๋ฐฉ๋ฒ•](https://huggingface.co/docs/datasets/loading#local-and-remote-files)์„ ํ™•์ธํ•˜์„ธ์š”.
```py
>>> from datasets import load_dataset
>>> dataset = load_dataset("nielsr/docvqa_1200_examples")
>>> dataset
DatasetDict({
train: Dataset({
features: ['id', 'image', 'query', 'answers', 'words', 'bounding_boxes', 'answer'],
num_rows: 1000
})
test: Dataset({
features: ['id', 'image', 'query', 'answers', 'words', 'bounding_boxes', 'answer'],
num_rows: 200
})
})
```
As you can see, the dataset is already split into train and test sets. Take a look at a random example to familiarize yourself with the features.
```py
>>> dataset["train"].features
```
๊ฐ ํ•„๋“œ๊ฐ€ ๋‚˜ํƒ€๋‚ด๋Š” ๋‚ด์šฉ์€ ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค:
* `id`: ์˜ˆ์ œ์˜ id
* `image`: ๋ฌธ์„œ ์ด๋ฏธ์ง€๋ฅผ ํฌํ•จํ•˜๋Š” PIL.Image.Image ๊ฐ์ฒด
* `query`: ์งˆ๋ฌธ ๋ฌธ์ž์—ด - ์—ฌ๋Ÿฌ ์–ธ์–ด์˜ ์ž์—ฐ์–ด๋กœ ๋œ ์งˆ๋ฌธ
* `answers`: ์‚ฌ๋žŒ์ด ์ฃผ์„์„ ๋‹จ ์ •๋‹ต ๋ฆฌ์ŠคํŠธ
* `words` and `bounding_boxes`: OCR์˜ ๊ฒฐ๊ณผ๊ฐ’๋“ค์ด๋ฉฐ ์ด ๊ฐ€์ด๋“œ์—์„œ๋Š” ์‚ฌ์šฉํ•˜์ง€ ์•Š์„ ์˜ˆ์ •
* `answer`: ๋‹ค๋ฅธ ๋ชจ๋ธ๊ณผ ์ผ์น˜ํ•˜๋Š” ๋‹ต๋ณ€์ด๋ฉฐ ์ด ๊ฐ€์ด๋“œ์—์„œ๋Š” ์‚ฌ์šฉํ•˜์ง€ ์•Š์„ ์˜ˆ์ •
์˜์–ด๋กœ ๋œ ์งˆ๋ฌธ๋งŒ ๋‚จ๊ธฐ๊ณ  ๋‹ค๋ฅธ ๋ชจ๋ธ์— ๋Œ€ํ•œ ์˜ˆ์ธก์„ ํฌํ•จํ•˜๋Š” `answer` ํŠน์„ฑ์„ ์‚ญ์ œํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค.
๊ทธ๋ฆฌ๊ณ  ์ฃผ์„ ์ž‘์„ฑ์ž๊ฐ€ ์ œ๊ณตํ•œ ๋ฐ์ดํ„ฐ ์„ธํŠธ์—์„œ ์ฒซ ๋ฒˆ์งธ ๋‹ต๋ณ€์„ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค. ๋˜๋Š” ๋ฌด์ž‘์œ„๋กœ ์ƒ˜ํ”Œ์„ ์ถ”์ถœํ•  ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค.
```py
>>> updated_dataset = dataset.map(lambda example: {"question": example["query"]["en"]}, remove_columns=["query"])
>>> updated_dataset = updated_dataset.map(
...     lambda example: {"answer": example["answers"][0]}, remove_columns=["answer", "answers"]
... )
```
์ด ๊ฐ€์ด๋“œ์—์„œ ์‚ฌ์šฉํ•˜๋Š” LayoutLMv2 ์ฒดํฌํฌ์ธํŠธ๋Š” `max_position_embeddings = 512`๋กœ ํ›ˆ๋ จ๋˜์—ˆ์Šต๋‹ˆ๋‹ค(์ด ์ •๋ณด๋Š” [์ฒดํฌํฌ์ธํŠธ์˜ `config.json` ํŒŒ์ผ](https://huggingface.co/microsoft/layoutlmv2-base-uncased/blob/main/config.json#L18)์—์„œ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค).
๋ฐ”๋กœ ์˜ˆ์ œ๋ฅผ ์ž˜๋ผ๋‚ผ ์ˆ˜๋„ ์žˆ์ง€๋งŒ, ๊ธด ๋ฌธ์„œ์˜ ๋์— ๋‹ต๋ณ€์ด ์žˆ์–ด ์ž˜๋ฆฌ๋Š” ์ƒํ™ฉ์„ ํ”ผํ•˜๊ธฐ ์œ„ํ•ด ์—ฌ๊ธฐ์„œ๋Š” ์ž„๋ฒ ๋”ฉ์ด 512๋ณด๋‹ค ๊ธธ์–ด์งˆ ๊ฐ€๋Šฅ์„ฑ์ด ์žˆ๋Š” ๋ช‡ ๊ฐ€์ง€ ์˜ˆ์ œ๋ฅผ ์ œ๊ฑฐํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค.
๋ฐ์ดํ„ฐ ์„ธํŠธ์— ์žˆ๋Š” ๋Œ€๋ถ€๋ถ„์˜ ๋ฌธ์„œ๊ฐ€ ๊ธด ๊ฒฝ์šฐ ์Šฌ๋ผ์ด๋”ฉ ์œˆ๋„์šฐ ๋ฐฉ๋ฒ•์„ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค - ์ž์„ธํ•œ ๋‚ด์šฉ์„ ํ™•์ธํ•˜๊ณ  ์‹ถ์œผ๋ฉด ์ด [๋…ธํŠธ๋ถ](https://github.com/huggingface/notebooks/blob/main/examples/question_answering.ipynb)์„ ํ™•์ธํ•˜์„ธ์š”.
```py
>>> updated_dataset = updated_dataset.filter(lambda x: len(x["words"]) + len(x["question"].split()) < 512)
```
์ด ์‹œ์ ์—์„œ ์ด ๋ฐ์ดํ„ฐ ์„ธํŠธ์˜ OCR ํŠน์„ฑ๋„ ์ œ๊ฑฐํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค. OCR ํŠน์„ฑ์€ ๋‹ค๋ฅธ ๋ชจ๋ธ์„ ๋ฏธ์„ธ ์กฐ์ •ํ•˜๊ธฐ ์œ„ํ•œ ๊ฒƒ์œผ๋กœ, ์ด ๊ฐ€์ด๋“œ์—์„œ ์‚ฌ์šฉํ•˜๋Š” ๋ชจ๋ธ์˜ ์ž…๋ ฅ ์š”๊ตฌ ์‚ฌํ•ญ๊ณผ ์ผ์น˜ํ•˜์ง€ ์•Š๊ธฐ ๋•Œ๋ฌธ์— ์ด ํŠน์„ฑ์„ ์‚ฌ์šฉํ•˜๊ธฐ ์œ„ํ•ด์„œ๋Š” ์ผ๋ถ€ ์ฒ˜๋ฆฌ๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.
๋Œ€์‹ , ์›๋ณธ ๋ฐ์ดํ„ฐ์— [`LayoutLMv2Processor`]๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ OCR ๋ฐ ํ† ํฐํ™”๋ฅผ ๋ชจ๋‘ ์ˆ˜ํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
์ด๋ ‡๊ฒŒ ํ•˜๋ฉด ๋ชจ๋ธ์ด ์š”๊ตฌํ•˜๋Š” ์ž…๋ ฅ์„ ์–ป์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
์ด๋ฏธ์ง€๋ฅผ ์ˆ˜๋™์œผ๋กœ ์ฒ˜๋ฆฌํ•˜๋ ค๋ฉด, [`LayoutLMv2` model documentation](../model_doc/layoutlmv2)์—์„œ ๋ชจ๋ธ์ด ์š”๊ตฌํ•˜๋Š” ์ž…๋ ฅ ํฌ๋งท์„ ํ™•์ธํ•ด๋ณด์„ธ์š”.
```py
>>> updated_dataset = updated_dataset.remove_columns("words")
>>> updated_dataset = updated_dataset.remove_columns("bounding_boxes")
```
๋งˆ์ง€๋ง‰์œผ๋กœ, ๋ฐ์ดํ„ฐ ํƒ์ƒ‰์„ ์™„๋ฃŒํ•˜๊ธฐ ์œ„ํ•ด ์ด๋ฏธ์ง€ ์˜ˆ์‹œ๋ฅผ ์‚ดํŽด๋ด…์‹œ๋‹ค.
```py
>>> updated_dataset["train"][11]["image"]
```
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/docvqa_example.jpg" alt="DocVQA Image Example"/>
</div>
## ๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ [[preprocess-the-data]]
๋ฌธ์„œ ์งˆ์˜ ์‘๋‹ต ํƒœ์Šคํฌ๋Š” ๋ฉ€ํ‹ฐ๋ชจ๋‹ฌ ํƒœ์Šคํฌ์ด๋ฉฐ, ๊ฐ ๋ชจ๋‹ฌ๋ฆฌํ‹ฐ์˜ ์ž…๋ ฅ์ด ๋ชจ๋ธ์˜ ์š”๊ตฌ์— ๋งž๊ฒŒ ์ „์ฒ˜๋ฆฌ ๋˜์—ˆ๋Š”์ง€ ํ™•์ธํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
์ด๋ฏธ์ง€ ๋ฐ์ดํ„ฐ๋ฅผ ์ฒ˜๋ฆฌํ•  ์ˆ˜ ์žˆ๋Š” ์ด๋ฏธ์ง€ ํ”„๋กœ์„ธ์„œ์™€ ํ…์ŠคํŠธ ๋ฐ์ดํ„ฐ๋ฅผ ์ธ์ฝ”๋”ฉํ•  ์ˆ˜ ์žˆ๋Š” ํ† ํฌ๋‚˜์ด์ €๋ฅผ ๊ฒฐํ•ฉํ•œ [`LayoutLMv2Processor`]๋ฅผ ๊ฐ€์ ธ์˜ค๋Š” ๊ฒƒ๋ถ€ํ„ฐ ์‹œ์ž‘ํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.
```py
>>> from transformers import AutoProcessor
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
```
### ๋ฌธ์„œ ์ด๋ฏธ์ง€ ์ „์ฒ˜๋ฆฌ [[preprocessing-document-images]]
๋จผ์ €, ํ”„๋กœ์„ธ์„œ์˜ `image_processor`๋ฅผ ์‚ฌ์šฉํ•ด ๋ชจ๋ธ์— ๋Œ€ํ•œ ๋ฌธ์„œ ์ด๋ฏธ์ง€๋ฅผ ์ค€๋น„ํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.
๊ธฐ๋ณธ๊ฐ’์œผ๋กœ, ์ด๋ฏธ์ง€ ํ”„๋กœ์„ธ์„œ๋Š” ์ด๋ฏธ์ง€ ํฌ๊ธฐ๋ฅผ 224x224๋กœ ์กฐ์ •ํ•˜๊ณ  ์ƒ‰์ƒ ์ฑ„๋„์˜ ์ˆœ์„œ๊ฐ€ ์˜ฌ๋ฐ”๋ฅธ์ง€ ํ™•์ธํ•œ ํ›„ ๋‹จ์–ด์™€ ์ •๊ทœํ™”๋œ ๋ฐ”์šด๋”ฉ ๋ฐ•์Šค๋ฅผ ์–ป๊ธฐ ์œ„ํ•ด ํ…Œ์„œ๋ž™ํŠธ๋ฅผ ์‚ฌ์šฉํ•ด OCR๋ฅผ ์ ์šฉํ•ฉ๋‹ˆ๋‹ค.
์ด ํŠœํ† ๋ฆฌ์–ผ์—์„œ ์šฐ๋ฆฌ๊ฐ€ ํ•„์š”ํ•œ ๊ฒƒ๊ณผ ๊ธฐ๋ณธ๊ฐ’์€ ์™„์ „ํžˆ ๋™์ผํ•ฉ๋‹ˆ๋‹ค. ์ด๋ฏธ์ง€ ๋ฐฐ์น˜์— ๊ธฐ๋ณธ ์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ๋ฅผ ์ ์šฉํ•˜๊ณ  OCR์˜ ๊ฒฐ๊ณผ๋ฅผ ๋ณ€ํ™˜ํ•˜๋Š” ํ•จ์ˆ˜๋ฅผ ์ž‘์„ฑํ•ฉ๋‹ˆ๋‹ค.
```py
>>> image_processor = processor.image_processor
>>> def get_ocr_words_and_boxes(examples):
...     images = [image.convert("RGB") for image in examples["image"]]
...     encoded_inputs = image_processor(images)
...     examples["image"] = encoded_inputs.pixel_values
...     examples["words"] = encoded_inputs.words
...     examples["boxes"] = encoded_inputs.boxes
...     return examples
```
To apply this preprocessing to the entire dataset in a fast way, use [`~datasets.Dataset.map`].
```py
>>> dataset_with_ocr = updated_dataset.map(get_ocr_words_and_boxes, batched=True, batch_size=2)
```
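Under the hood, the boxes returned by the OCR step are normalized to the 0-1000 grid that LayoutLMv2 expects, regardless of the page's pixel size. A minimal sketch of that normalization (`normalize_box` is our own illustrative helper, not a transformers API):

```python
def normalize_box(box, width, height):
    """Scale an (x0, y0, x1, y1) pixel box to LayoutLMv2's 0-1000 grid."""
    x0, y0, x1, y1 = box
    return [
        int(1000 * x0 / width),
        int(1000 * y0 / height),
        int(1000 * x1 / width),
        int(1000 * y1 / height),
    ]


# A word occupying pixels (170, 200) to (340, 240) on an 850x1100 page:
normalize_box((170, 200, 340, 240), width=850, height=1100)
# → [200, 181, 400, 218]
```

Because the boxes are resolution-independent, the same coordinates stay valid after the image itself is resized to 224x224.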
### Preprocessing text data [[preprocessing-text-data]]
Once we have applied OCR to the images, we need to encode the text part of the dataset to prepare it for the model.
This involves converting the words and boxes that we got in the previous step into token-level `input_ids`, `attention_mask`, `token_type_ids` and `bbox`.
For preprocessing text, we'll need the `tokenizer` from the processor.
```py
>>> tokenizer = processor.tokenizer
```
์œ„์—์„œ ์–ธ๊ธ‰ํ•œ ์ „์ฒ˜๋ฆฌ ์™ธ์—๋„ ๋ชจ๋ธ์„ ์œ„ํ•ด ๋ ˆ์ด๋ธ”์„ ์ถ”๊ฐ€ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ๐Ÿค— Transformers์˜ `xxxForQuestionAnswering` ๋ชจ๋ธ์˜ ๊ฒฝ์šฐ, ๋ ˆ์ด๋ธ”์€ `start_positions`์™€ `end_positions`๋กœ ๊ตฌ์„ฑ๋˜๋ฉฐ ์–ด๋–ค ํ† ํฐ์ด ๋‹ต๋ณ€์˜ ์‹œ์ž‘๊ณผ ๋์— ์žˆ๋Š”์ง€๋ฅผ ๋‚˜ํƒ€๋ƒ…๋‹ˆ๋‹ค.
๋ ˆ์ด๋ธ” ์ถ”๊ฐ€๋ฅผ ์œ„ํ•ด์„œ, ๋จผ์ € ๋” ํฐ ๋ฆฌ์ŠคํŠธ(๋‹จ์–ด ๋ฆฌ์ŠคํŠธ)์—์„œ ํ•˜์œ„ ๋ฆฌ์ŠคํŠธ(๋‹จ์–ด๋กœ ๋ถ„ํ• ๋œ ๋‹ต๋ณ€)์„ ์ฐพ์„ ์ˆ˜ ์žˆ๋Š” ํ—ฌํผ ํ•จ์ˆ˜๋ฅผ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค.
์ด ํ•จ์ˆ˜๋Š” `words_list`์™€ `answer_list`, ์ด๋ ‡๊ฒŒ ๋‘ ๋ฆฌ์ŠคํŠธ๋ฅผ ์ž…๋ ฅ์œผ๋กœ ๋ฐ›์Šต๋‹ˆ๋‹ค.
๊ทธ๋Ÿฐ ๋‹ค์Œ `words_list`๋ฅผ ๋ฐ˜๋ณตํ•˜์—ฌ `words_list`์˜ ํ˜„์žฌ ๋‹จ์–ด(words_list[i])๊ฐ€ `answer_list`์˜ ์ฒซ ๋ฒˆ์งธ ๋‹จ์–ด(answer_list[0])์™€ ๊ฐ™์€์ง€,
ํ˜„์žฌ ๋‹จ์–ด์—์„œ ์‹œ์ž‘ํ•ด `answer_list`์™€ ๊ฐ™์€ ๊ธธ์ด๋งŒํผ์˜ `words_list`์˜ ํ•˜์œ„ ๋ฆฌ์ŠคํŠธ๊ฐ€ `answer_list`์™€ ์ผ์น˜ํ•˜๋Š”์ง€ ํ™•์ธํ•ฉ๋‹ˆ๋‹ค.
์ด ์กฐ๊ฑด์ด ์ฐธ์ด๋ผ๋ฉด ์ผ์น˜ํ•˜๋Š” ํ•ญ๋ชฉ์„ ๋ฐœ๊ฒฌํ–ˆ์Œ์„ ์˜๋ฏธํ•˜๋ฉฐ, ํ•จ์ˆ˜๋Š” ์ผ์น˜ ํ•ญ๋ชฉ, ์‹œ์ž‘ ์ธ๋ฑ์Šค(idx) ๋ฐ ์ข…๋ฃŒ ์ธ๋ฑ์Šค(idx + len(answer_list) - 1)๋ฅผ ๊ธฐ๋กํ•ฉ๋‹ˆ๋‹ค. ์ผ์น˜ํ•˜๋Š” ํ•ญ๋ชฉ์ด ๋‘ ๊ฐœ ์ด์ƒ ๋ฐœ๊ฒฌ๋˜๋ฉด ํ•จ์ˆ˜๋Š” ์ฒซ ๋ฒˆ์งธ ํ•ญ๋ชฉ๋งŒ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค. ์ผ์น˜ํ•˜๋Š” ํ•ญ๋ชฉ์ด ์—†๋‹ค๋ฉด ํ•จ์ˆ˜๋Š” (`None`, 0, 0)์„ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
```py
>>> def subfinder(words_list, answer_list):
...     matches = []
...     start_indices = []
...     end_indices = []
...     for idx, i in enumerate(range(len(words_list))):
...         if words_list[i] == answer_list[0] and words_list[i : i + len(answer_list)] == answer_list:
...             matches.append(answer_list)
...             start_indices.append(idx)
...             end_indices.append(idx + len(answer_list) - 1)
...     if matches:
...         return matches[0], start_indices[0], end_indices[0]
...     else:
...         return None, 0, 0
```
To illustrate how this function finds the position of the answer, let's use it on an example:
```py
>>> example = dataset_with_ocr["train"][1]
>>> words = [word.lower() for word in example["words"]]
>>> match, word_idx_start, word_idx_end = subfinder(words, example["answer"].lower().split())
>>> print("Question: ", example["question"])
>>> print("Words:", words)
>>> print("Answer: ", example["answer"])
>>> print("start_index", word_idx_start)
>>> print("end_index", word_idx_end)
Question: Who is in cc in this letter?
Words: ['wie', 'baw', 'brown', '&', 'williamson', 'tobacco', 'corporation', 'research', '&', 'development', 'internal', 'correspondence', 'to:', 'r.', 'h.', 'honeycutt', 'ce:', 't.f.', 'riehl', 'from:', '.', 'c.j.', 'cook', 'date:', 'may', '8,', '1995', 'subject:', 'review', 'of', 'existing', 'brainstorming', 'ideas/483', 'the', 'major', 'function', 'of', 'the', 'product', 'innovation', 'graup', 'is', 'to', 'develop', 'marketable', 'nove!', 'products', 'that', 'would', 'be', 'profitable', 'to', 'manufacture', 'and', 'sell.', 'novel', 'is', 'defined', 'as:', 'of', 'a', 'new', 'kind,', 'or', 'different', 'from', 'anything', 'seen', 'or', 'known', 'before.', 'innovation', 'is', 'defined', 'as:', 'something', 'new', 'or', 'different', 'introduced;', 'act', 'of', 'innovating;', 'introduction', 'of', 'new', 'things', 'or', 'methods.', 'the', 'products', 'may', 'incorporate', 'the', 'latest', 'technologies,', 'materials', 'and', 'know-how', 'available', 'to', 'give', 'then', 'a', 'unique', 'taste', 'or', 'look.', 'the', 'first', 'task', 'of', 'the', 'product', 'innovation', 'group', 'was', 'to', 'assemble,', 'review', 'and', 'categorize', 'a', 'list', 'of', 'existing', 'brainstorming', 'ideas.', 'ideas', 'were', 'grouped', 'into', 'two', 'major', 'categories', 'labeled', 'appearance', 'and', 'taste/aroma.', 'these', 'categories', 'are', 'used', 'for', 'novel', 'products', 'that', 'may', 'differ', 'from', 'a', 'visual', 'and/or', 'taste/aroma', 'point', 'of', 'view', 'compared', 'to', 'canventional', 'cigarettes.', 'other', 'categories', 'include', 'a', 'combination', 'of', 'the', 'above,', 'filters,', 'packaging', 'and', 'brand', 'extensions.', 'appearance', 'this', 'category', 'is', 'used', 'for', 'novel', 'cigarette', 'constructions', 'that', 'yield', 'visually', 'different', 'products', 'with', 'minimal', 'changes', 'in', 'smoke', 'chemistry', 'two', 'cigarettes', 'in', 'cne.', 'emulti-plug', 'te', 'build', 'yaur', 'awn', 'cigarette.', 'eswitchable', 'menthol', 'or', 
'non', 'menthol', 'cigarette.', '*cigarettes', 'with', 'interspaced', 'perforations', 'to', 'enable', 'smoker', 'to', 'separate', 'unburned', 'section', 'for', 'future', 'smoking.', 'ยซshort', 'cigarette,', 'tobacco', 'section', '30', 'mm.', 'ยซextremely', 'fast', 'buming', 'cigarette.', 'ยซnovel', 'cigarette', 'constructions', 'that', 'permit', 'a', 'significant', 'reduction', 'iretobacco', 'weight', 'while', 'maintaining', 'smoking', 'mechanics', 'and', 'visual', 'characteristics.', 'higher', 'basis', 'weight', 'paper:', 'potential', 'reduction', 'in', 'tobacco', 'weight.', 'ยซmore', 'rigid', 'tobacco', 'column;', 'stiffing', 'agent', 'for', 'tobacco;', 'e.g.', 'starch', '*colored', 'tow', 'and', 'cigarette', 'papers;', 'seasonal', 'promotions,', 'e.g.', 'pastel', 'colored', 'cigarettes', 'for', 'easter', 'or', 'in', 'an', 'ebony', 'and', 'ivory', 'brand', 'containing', 'a', 'mixture', 'of', 'all', 'black', '(black', 'paper', 'and', 'tow)', 'and', 'ail', 'white', 'cigarettes.', '499150498']
Answer: T.F. Riehl
start_index 17
end_index 18
```
Once the example above is encoded, however, it will look like this:
```py
>>> encoding = tokenizer(example["question"], example["words"], example["boxes"])
>>> tokenizer.decode(encoding["input_ids"])
[CLS] who is in cc in this letter? [SEP] wie baw brown & williamson tobacco corporation research & development ...
```
์ด์ œ ์ธ์ฝ”๋”ฉ๋œ ์ž…๋ ฅ์—์„œ ์ •๋‹ต์˜ ์œ„์น˜๋ฅผ ์ฐพ์•„์•ผ ํ•ฉ๋‹ˆ๋‹ค.
* `token_type_ids`๋Š” ์–ด๋–ค ํ† ํฐ์ด ์งˆ๋ฌธ์— ์†ํ•˜๋Š”์ง€, ๊ทธ๋ฆฌ๊ณ  ์–ด๋–ค ํ† ํฐ์ด ๋ฌธ์„œ์˜ ๋‹จ์–ด์— ํฌํ•จ๋˜๋Š”์ง€๋ฅผ ์•Œ๋ ค์ค๋‹ˆ๋‹ค.
* `tokenizer.cls_token_id` ์ž…๋ ฅ์˜ ์‹œ์ž‘ ๋ถ€๋ถ„์— ์žˆ๋Š” ํŠน์ˆ˜ ํ† ํฐ์„ ์ฐพ๋Š” ๋ฐ ๋„์›€์„ ์ค๋‹ˆ๋‹ค.
* `word_ids`๋Š” ์›๋ณธ `words`์—์„œ ์ฐพ์€ ๋‹ต๋ณ€์„ ์ „์ฒด ์ธ์ฝ”๋”ฉ๋œ ์ž…๋ ฅ์˜ ๋™์ผํ•œ ๋‹ต๊ณผ ์ผ์น˜์‹œํ‚ค๊ณ  ์ธ์ฝ”๋”ฉ๋œ ์ž…๋ ฅ์—์„œ ๋‹ต๋ณ€์˜ ์‹œ์ž‘/๋ ์œ„์น˜๋ฅผ ๊ฒฐ์ •ํ•ฉ๋‹ˆ๋‹ค.
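How `word_ids` bridges word-level and token-level positions can be illustrated on synthetic data (the lists below are made up for illustration; a real `word_ids` comes from the tokenizer's encoding):

```python
# word_ids maps each document token back to the index of the word it came
# from; a word split into several subword tokens repeats its index.
# (Question and special tokens are already excluded in this toy example.)
word_ids = [0, 1, 1, 2, 2, 2, 3, 4]  # word 1 → 2 tokens, word 2 → 3 tokens

# Suppose subfinder reported the answer spans words 2..3:
word_idx_start, word_idx_end = 2, 3

# First token whose word index equals word_idx_start...
token_start = next(i for i, wid in enumerate(word_ids) if wid == word_idx_start)
# ...and last token whose word index equals word_idx_end.
token_end = max(i for i, wid in enumerate(word_ids) if wid == word_idx_end)

print(token_start, token_end)  # → 3 6
```

The encoding function below performs this same alignment, offset by where the document tokens actually start in the full encoded sequence.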
์œ„ ๋‚ด์šฉ๋“ค์„ ์—ผ๋‘์— ๋‘๊ณ  ๋ฐ์ดํ„ฐ ์„ธํŠธ ์˜ˆ์ œ์˜ ๋ฐฐ์น˜๋ฅผ ์ธ์ฝ”๋”ฉํ•˜๋Š” ํ•จ์ˆ˜๋ฅผ ๋งŒ๋“ค์–ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค:
```py
>>> def encode_dataset(examples, max_length=512):
...     questions = examples["question"]
...     words = examples["words"]
...     boxes = examples["boxes"]
...     answers = examples["answer"]
...     # Encode the batch of examples and initialize the start_positions and end_positions
...     encoding = tokenizer(questions, words, boxes, max_length=max_length, padding="max_length", truncation=True)
...     start_positions = []
...     end_positions = []
...     # Loop through the examples in the batch
...     for i in range(len(questions)):
...         cls_index = encoding["input_ids"][i].index(tokenizer.cls_token_id)
...         # Find the position of the answer in the example's words
...         words_example = [word.lower() for word in words[i]]
...         answer = answers[i]
...         match, word_idx_start, word_idx_end = subfinder(words_example, answer.lower().split())
...         if match:
...             # If a match is found, use `token_type_ids` to find where the words start in the encoding
...             token_type_ids = encoding["token_type_ids"][i]
...             token_start_index = 0
...             while token_type_ids[token_start_index] != 1:
...                 token_start_index += 1
...             token_end_index = len(encoding["input_ids"][i]) - 1
...             while token_type_ids[token_end_index] != 1:
...                 token_end_index -= 1
...             word_ids = encoding.word_ids(i)[token_start_index : token_end_index + 1]
...             start_position = cls_index
...             end_position = cls_index
...             # Loop over word_ids and increase `token_start_index` until it matches the answer position in words;
...             # once it matches, save `token_start_index` as the `start_position` of the answer in the encoding
...             for id in word_ids:
...                 if id == word_idx_start:
...                     start_position = token_start_index
...                 else:
...                     token_start_index += 1
...             # Similarly, loop over `word_ids` starting from the end to find the `end_position` of the answer
...             for id in word_ids[::-1]:
...                 if id == word_idx_end:
...                     end_position = token_end_index
...                 else:
...                     token_end_index -= 1
...             start_positions.append(start_position)
...             end_positions.append(end_position)
...         else:
...             start_positions.append(cls_index)
...             end_positions.append(cls_index)
...     encoding["image"] = examples["image"]
...     encoding["start_positions"] = start_positions
...     encoding["end_positions"] = end_positions
...     return encoding
```
์ด์ œ ์ด ์ „์ฒ˜๋ฆฌ ํ•จ์ˆ˜๊ฐ€ ์žˆ์œผ๋‹ˆ ์ „์ฒด ๋ฐ์ดํ„ฐ ์„ธํŠธ๋ฅผ ์ธ์ฝ”๋”ฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:
```py
>>> encoded_train_dataset = dataset_with_ocr["train"].map(
...     encode_dataset, batched=True, batch_size=2, remove_columns=dataset_with_ocr["train"].column_names
... )
>>> encoded_test_dataset = dataset_with_ocr["test"].map(
...     encode_dataset, batched=True, batch_size=2, remove_columns=dataset_with_ocr["test"].column_names
... )
```
์ธ์ฝ”๋”ฉ๋œ ๋ฐ์ดํ„ฐ ์„ธํŠธ์˜ ํŠน์„ฑ์ด ์–ด๋–ป๊ฒŒ ์ƒ๊ฒผ๋Š”์ง€ ํ™•์ธํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค:
```py
>>> encoded_train_dataset.features
{'image': Sequence(feature=Sequence(feature=Sequence(feature=Value(dtype='uint8', id=None), length=-1, id=None), length=-1, id=None), length=-1, id=None),
'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
'bbox': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
'start_positions': Value(dtype='int64', id=None),
'end_positions': Value(dtype='int64', id=None)}
```
## Evaluation [[evaluation]]
Evaluation for document question answering requires a significant amount of postprocessing. To keep this guide from taking up too much of your time, it skips the evaluation step.
The [`Trainer`] still calculates the evaluation loss during training, so you're not completely in the dark about your model's performance.
Extractive question answering is typically evaluated using F1/exact match.
If you'd like to implement it yourself, check out the [Question Answering chapter](https://huggingface.co/course/chapter7/7?fw=pt#postprocessing) of the Hugging Face course for inspiration.
## Train [[train]]
Congratulations! You've successfully navigated the hardest part of this guide and now you are ready to train your own model.
Training involves the following steps:
* Load the model with [`AutoModelForDocumentQuestionAnswering`] using the same checkpoint as in the preprocessing.
* Define your training hyperparameters in [`TrainingArguments`].
* Define a function to batch examples together; here the [`DefaultDataCollator`] will do just fine.
* Pass the training arguments to [`Trainer`] along with the model, dataset, and data collator.
* Call [`~Trainer.train`] to fine-tune your model.
```py
>>> from transformers import AutoModelForDocumentQuestionAnswering
>>> model = AutoModelForDocumentQuestionAnswering.from_pretrained(model_checkpoint)
```
[`TrainingArguments`]์—์„œ `output_dir`์„ ์‚ฌ์šฉํ•˜์—ฌ ๋ชจ๋ธ์„ ์ €์žฅํ•  ์œ„์น˜๋ฅผ ์ง€์ •ํ•˜๊ณ , ์ ์ ˆํ•œ ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค.
๋ชจ๋ธ์„ ์ปค๋ฎค๋‹ˆํ‹ฐ์™€ ๊ณต์œ ํ•˜๋ ค๋ฉด `push_to_hub`๋ฅผ `True`๋กœ ์„ค์ •ํ•˜์„ธ์š” (๋ชจ๋ธ์„ ์—…๋กœ๋“œํ•˜๋ ค๋ฉด Hugging Face์— ๋กœ๊ทธ์ธํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค).
์ด ๊ฒฝ์šฐ `output_dir`์€ ๋ชจ๋ธ์˜ ์ฒดํฌํฌ์ธํŠธ๋ฅผ ํ‘ธ์‹œํ•  ๋ ˆํฌ์ง€ํ† ๋ฆฌ์˜ ์ด๋ฆ„์ด ๋ฉ๋‹ˆ๋‹ค.
```py
>>> from transformers import TrainingArguments
>>> # ๋ณธ์ธ์˜ ๋ ˆํฌ์ง€ํ† ๋ฆฌ ID๋กœ ๋ฐ”๊พธ์„ธ์š”
>>> repo_id = "MariaK/layoutlmv2-base-uncased_finetuned_docvqa"
>>> training_args = TrainingArguments(
... output_dir=repo_id,
... per_device_train_batch_size=4,
... num_train_epochs=20,
... save_steps=200,
... logging_steps=50,
... eval_strategy="steps",
... learning_rate=5e-5,
... save_total_limit=2,
... remove_unused_columns=False,
... push_to_hub=True,
... )
```
Define a simple data collator to batch examples together.
```py
>>> from transformers import DefaultDataCollator
>>> data_collator = DefaultDataCollator()
```
Finally, bring everything together, and call [`~Trainer.train`]:
```py
>>> from transformers import Trainer
>>> trainer = Trainer(
...     model=model,
...     args=training_args,
...     data_collator=data_collator,
...     train_dataset=encoded_train_dataset,
...     eval_dataset=encoded_test_dataset,
...     processing_class=processor,
... )
>>> trainer.train()
```
์ตœ์ข… ๋ชจ๋ธ์„ ๐Ÿค— Hub์— ์ถ”๊ฐ€ํ•˜๋ ค๋ฉด, ๋ชจ๋ธ ์นด๋“œ๋ฅผ ์ƒ์„ฑํ•˜๊ณ  `push_to_hub`๋ฅผ ํ˜ธ์ถœํ•ฉ๋‹ˆ๋‹ค:
```py
>>> trainer.create_model_card()
>>> trainer.push_to_hub()
```
## Inference [[inference]]
Now that you have fine-tuned a LayoutLMv2 model and uploaded it to the 🤗 Hub, you can use it for inference.
The simplest way to try out your fine-tuned model for inference is to use it in a [`Pipeline`].
Let's take an example:
```py
>>> example = dataset["test"][2]
>>> question = example["query"]["en"]
>>> image = example["image"]
>>> print(question)
>>> print(example["answers"])
'Who is 'presiding' TRRF GENERAL SESSION (PART 1)?'
['TRRF Vice President', 'lee a. waller']
```
๊ทธ ๋‹ค์Œ, ๋ชจ๋ธ๋กœ ๋ฌธ์„œ ์งˆ์˜ ์‘๋‹ต์„ ํ•˜๊ธฐ ์œ„ํ•ด ํŒŒ์ดํ”„๋ผ์ธ์„ ์ธ์Šคํ„ด์Šคํ™”ํ•˜๊ณ  ์ด๋ฏธ์ง€ + ์งˆ๋ฌธ ์กฐํ•ฉ์„ ์ „๋‹ฌํ•ฉ๋‹ˆ๋‹ค.
```py
>>> from transformers import pipeline
>>> qa_pipeline = pipeline("document-question-answering", model="MariaK/layoutlmv2-base-uncased_finetuned_docvqa")
>>> qa_pipeline(image, question)
[{'score': 0.9949808120727539,
'answer': 'Lee A. Waller',
'start': 55,
'end': 57}]
```
You can also manually replicate the results of the pipeline if you'd like:
1. Take an image and a question, and prepare them for the model using the processor from your model.
2. Forward the result of the preprocessing through the model.
3. The model returns `start_logits` and `end_logits`, which indicate which token is at the start of the answer and which token is at the end of the answer. Both have shape (batch_size, sequence_length).
4. Take an argmax on the last dimension of both `start_logits` and `end_logits` to get the predicted `start_idx` and `end_idx`.
5. Decode the answer with the tokenizer.
```py
>>> import torch
>>> from transformers import AutoProcessor
>>> from transformers import AutoModelForDocumentQuestionAnswering
>>> processor = AutoProcessor.from_pretrained("MariaK/layoutlmv2-base-uncased_finetuned_docvqa")
>>> model = AutoModelForDocumentQuestionAnswering.from_pretrained("MariaK/layoutlmv2-base-uncased_finetuned_docvqa")
>>> with torch.no_grad():
...     encoding = processor(image.convert("RGB"), question, return_tensors="pt")
...     outputs = model(**encoding)
...     start_logits = outputs.start_logits
...     end_logits = outputs.end_logits
...     predicted_start_idx = start_logits.argmax(-1).item()
...     predicted_end_idx = end_logits.argmax(-1).item()
>>> processor.tokenizer.decode(encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1])
'lee a. waller'
```