| <!--Copyright 2023 The HuggingFace Team. All rights reserved. | |
| Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | |
| the License. You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | |
| an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | |
| specific language governing permissions and limitations under the License. | |
| โ ๏ธ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be | |
| rendered properly in your Markdown viewer. | |
| --> | |
| # ๋ฌธ์ ์ง์ ์๋ต(Document Question Answering) [[document_question_answering]] | |
| [[open-in-colab]] | |
| ๋ฌธ์ ์๊ฐ์ ์ง์ ์๋ต(Document Visual Question Answering)์ด๋ผ๊ณ ๋ ํ๋ | |
| ๋ฌธ์ ์ง์ ์๋ต(Document Question Answering)์ ๋ฌธ์ ์ด๋ฏธ์ง์ ๋ํ ์ง๋ฌธ์ ๋ต๋ณ์ ์ฃผ๋ ํ์คํฌ์ ๋๋ค. | |
| ์ด ํ์คํฌ๋ฅผ ์ง์ํ๋ ๋ชจ๋ธ์ ์ ๋ ฅ์ ์ผ๋ฐ์ ์ผ๋ก ์ด๋ฏธ์ง์ ์ง๋ฌธ์ ์กฐํฉ์ด๊ณ , ์ถ๋ ฅ์ ์์ฐ์ด๋ก ๋ ๋ต๋ณ์ ๋๋ค. ์ด๋ฌํ ๋ชจ๋ธ์ ํ ์คํธ, ๋จ์ด์ ์์น(๋ฐ์ด๋ฉ ๋ฐ์ค), ์ด๋ฏธ์ง ๋ฑ ๋ค์ํ ๋ชจ๋ฌ๋ฆฌํฐ๋ฅผ ํ์ฉํฉ๋๋ค. | |
| ์ด ๊ฐ์ด๋๋ ๋ค์ ๋ด์ฉ์ ์ค๋ช ํฉ๋๋ค: | |
| - [DocVQA dataset](https://huggingface.co/datasets/nielsr/docvqa_1200_examples_donut)์ ์ฌ์ฉํด [LayoutLMv2](../model_doc/layoutlmv2) ๋ฏธ์ธ ์กฐ์ ํ๊ธฐ | |
| - ์ถ๋ก ์ ์ํด ๋ฏธ์ธ ์กฐ์ ๋ ๋ชจ๋ธ์ ์ฌ์ฉํ๊ธฐ | |
| <Tip> | |
| ์ด ์์ ๊ณผ ํธํ๋๋ ๋ชจ๋ ์ํคํ ์ฒ์ ์ฒดํฌํฌ์ธํธ๋ฅผ ๋ณด๋ ค๋ฉด [์์ ํ์ด์ง](https://huggingface.co/tasks/image-to-text)๋ฅผ ํ์ธํ๋ ๊ฒ์ด ์ข์ต๋๋ค. | |
| </Tip> | |
| LayoutLMv2๋ ํ ํฐ์ ๋ง์ง๋ง ์๋์ธต ์์ ์ง์ ์๋ต ํค๋๋ฅผ ์ถ๊ฐํด ๋ต๋ณ์ ์์ ํ ํฐ๊ณผ ๋ ํ ํฐ์ ์์น๋ฅผ ์์ธกํจ์ผ๋ก์จ ๋ฌธ์ ์ง์ ์๋ต ํ์คํฌ๋ฅผ ํด๊ฒฐํฉ๋๋ค. ์ฆ, ๋ฌธ๋งฅ์ด ์ฃผ์ด์ก์ ๋ ์ง๋ฌธ์ ๋ตํ๋ ์ ๋ณด๋ฅผ ์ถ์ถํ๋ ์ถ์ถํ ์ง์ ์๋ต(Extractive question answering)์ผ๋ก ๋ฌธ์ ๋ฅผ ์ฒ๋ฆฌํฉ๋๋ค. | |
| ๋ฌธ๋งฅ์ OCR ์์ง์ ์ถ๋ ฅ์์ ๊ฐ์ ธ์ค๋ฉฐ, ์ฌ๊ธฐ์๋ Google์ Tesseract๋ฅผ ์ฌ์ฉํฉ๋๋ค. | |
| ์์ํ๊ธฐ ์ ์ ํ์ํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ๋ชจ๋ ์ค์น๋์ด ์๋์ง ํ์ธํ์ธ์. LayoutLMv2๋ detectron2, torchvision ๋ฐ ํ ์๋ํธ๋ฅผ ํ์๋ก ํฉ๋๋ค. | |
| ```bash | |
| pip install -q transformers datasets | |
| ``` | |
| ```bash | |
| pip install 'git+https://github.com/facebookresearch/detectron2.git' | |
| pip install torchvision | |
| ``` | |
| ```bash | |
| sudo apt install tesseract-ocr | |
| pip install -q pytesseract | |
| ``` | |
| ํ์ํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ค์ ๋ชจ๋ ์ค์นํ ํ ๋ฐํ์์ ๋ค์ ์์ํฉ๋๋ค. | |
| ์ปค๋ฎค๋ํฐ์ ๋น์ ์ ๋ชจ๋ธ์ ๊ณต์ ํ๋ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค. Hugging Face ๊ณ์ ์ ๋ก๊ทธ์ธํด์ ๋ชจ๋ธ์ ๐ค Hub์ ์ ๋ก๋ํ์ธ์. | |
| ํ๋กฌํํธ๊ฐ ์คํ๋๋ฉด, ๋ก๊ทธ์ธ์ ์ํด ํ ํฐ์ ์ ๋ ฅํ์ธ์: | |
| ```py | |
| >>> from huggingface_hub import notebook_login | |
| >>> notebook_login() | |
| ``` | |
| ๋ช ๊ฐ์ง ์ ์ญ ๋ณ์๋ฅผ ์ ์ํด ๋ณด๊ฒ ์ต๋๋ค. | |
| ```py | |
| >>> model_checkpoint = "microsoft/layoutlmv2-base-uncased" | |
| >>> batch_size = 4 | |
| ``` | |
| ## ๋ฐ์ดํฐ ๋ถ๋ฌ์ค๊ธฐ [[load-the-data]] | |
| ์ด ๊ฐ์ด๋์์๋ ๐ค Hub์์ ์ฐพ์ ์ ์๋ ์ ์ฒ๋ฆฌ๋ DocVQA์ ์์ ์ํ์ ์ฌ์ฉํฉ๋๋ค. | |
| DocVQA์ ์ ์ฒด ๋ฐ์ดํฐ ์ธํธ๋ฅผ ์ฌ์ฉํ๊ณ ์ถ๋ค๋ฉด, [DocVQA homepage](https://rrc.cvc.uab.es/?ch=17)์ ๊ฐ์ ํ ๋ค์ด๋ก๋ ํ ์ ์์ต๋๋ค. ์ ์ฒด ๋ฐ์ดํฐ ์ธํธ๋ฅผ ๋ค์ด๋ก๋ ํ๋ค๋ฉด, ์ด ๊ฐ์ด๋๋ฅผ ๊ณ์ ์งํํ๊ธฐ ์ํด [๐ค dataset์ ํ์ผ์ ๊ฐ์ ธ์ค๋ ๋ฐฉ๋ฒ](https://huggingface.co/docs/datasets/loading#local-and-remote-files)์ ํ์ธํ์ธ์. | |
| ```py | |
| >>> from datasets import load_dataset | |
| >>> dataset = load_dataset("nielsr/docvqa_1200_examples") | |
| >>> dataset | |
| DatasetDict({ | |
| train: Dataset({ | |
| features: ['id', 'image', 'query', 'answers', 'words', 'bounding_boxes', 'answer'], | |
| num_rows: 1000 | |
| }) | |
| test: Dataset({ | |
| features: ['id', 'image', 'query', 'answers', 'words', 'bounding_boxes', 'answer'], | |
| num_rows: 200 | |
| }) | |
| }) | |
| ``` | |
| ๋ณด์๋ค์ํผ, ๋ฐ์ดํฐ ์ธํธ๋ ์ด๋ฏธ ํ๋ จ ์ธํธ์ ํ ์คํธ ์ธํธ๋ก ๋๋์ด์ ธ ์์ต๋๋ค. ๋ฌด์์๋ก ์์ ๋ฅผ ์ดํด๋ณด๋ฉด์ ํน์ฑ์ ํ์ธํด๋ณด์ธ์. | |
| ```py | |
| >>> dataset["train"].features | |
| ``` | |
| ๊ฐ ํ๋๊ฐ ๋ํ๋ด๋ ๋ด์ฉ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค: | |
| * `id`: ์์ ์ id | |
| * `image`: ๋ฌธ์ ์ด๋ฏธ์ง๋ฅผ ํฌํจํ๋ PIL.Image.Image ๊ฐ์ฒด | |
| * `query`: ์ง๋ฌธ ๋ฌธ์์ด - ์ฌ๋ฌ ์ธ์ด์ ์์ฐ์ด๋ก ๋ ์ง๋ฌธ | |
| * `answers`: ์ฌ๋์ด ์ฃผ์์ ๋จ ์ ๋ต ๋ฆฌ์คํธ | |
| * `words` and `bounding_boxes`: OCR์ ๊ฒฐ๊ณผ๊ฐ๋ค์ด๋ฉฐ ์ด ๊ฐ์ด๋์์๋ ์ฌ์ฉํ์ง ์์ ์์ | |
| * `answer`: ๋ค๋ฅธ ๋ชจ๋ธ๊ณผ ์ผ์นํ๋ ๋ต๋ณ์ด๋ฉฐ ์ด ๊ฐ์ด๋์์๋ ์ฌ์ฉํ์ง ์์ ์์ | |
| ์์ด๋ก ๋ ์ง๋ฌธ๋ง ๋จ๊ธฐ๊ณ ๋ค๋ฅธ ๋ชจ๋ธ์ ๋ํ ์์ธก์ ํฌํจํ๋ `answer` ํน์ฑ์ ์ญ์ ํ๊ฒ ์ต๋๋ค. | |
| ๊ทธ๋ฆฌ๊ณ ์ฃผ์ ์์ฑ์๊ฐ ์ ๊ณตํ ๋ฐ์ดํฐ ์ธํธ์์ ์ฒซ ๋ฒ์งธ ๋ต๋ณ์ ๊ฐ์ ธ์ต๋๋ค. ๋๋ ๋ฌด์์๋ก ์ํ์ ์ถ์ถํ ์๋ ์์ต๋๋ค. | |
| ```py | |
| >>> updated_dataset = dataset.map(lambda example: {"question": example["query"]["en"]}, remove_columns=["query"]) | |
| >>> updated_dataset = updated_dataset.map( | |
| ... lambda example: {"answer": example["answers"][0]}, remove_columns=["answer", "answers"] | |
| ... ) | |
| ``` | |
| ์ด ๊ฐ์ด๋์์ ์ฌ์ฉํ๋ LayoutLMv2 ์ฒดํฌํฌ์ธํธ๋ `max_position_embeddings = 512`๋ก ํ๋ จ๋์์ต๋๋ค(์ด ์ ๋ณด๋ [์ฒดํฌํฌ์ธํธ์ `config.json` ํ์ผ](https://huggingface.co/microsoft/layoutlmv2-base-uncased/blob/main/config.json#L18)์์ ํ์ธํ ์ ์์ต๋๋ค). | |
| ๋ฐ๋ก ์์ ๋ฅผ ์๋ผ๋ผ ์๋ ์์ง๋ง, ๊ธด ๋ฌธ์์ ๋์ ๋ต๋ณ์ด ์์ด ์๋ฆฌ๋ ์ํฉ์ ํผํ๊ธฐ ์ํด ์ฌ๊ธฐ์๋ ์๋ฒ ๋ฉ์ด 512๋ณด๋ค ๊ธธ์ด์ง ๊ฐ๋ฅ์ฑ์ด ์๋ ๋ช ๊ฐ์ง ์์ ๋ฅผ ์ ๊ฑฐํ๊ฒ ์ต๋๋ค. | |
| ๋ฐ์ดํฐ ์ธํธ์ ์๋ ๋๋ถ๋ถ์ ๋ฌธ์๊ฐ ๊ธด ๊ฒฝ์ฐ ์ฌ๋ผ์ด๋ฉ ์๋์ฐ ๋ฐฉ๋ฒ์ ์ฌ์ฉํ ์ ์์ต๋๋ค - ์์ธํ ๋ด์ฉ์ ํ์ธํ๊ณ ์ถ์ผ๋ฉด ์ด [๋ ธํธ๋ถ](https://github.com/huggingface/notebooks/blob/main/examples/question_answering.ipynb)์ ํ์ธํ์ธ์. | |
| ```py | |
| >>> updated_dataset = updated_dataset.filter(lambda x: len(x["words"]) + len(x["question"].split()) < 512) | |
| ``` | |
| ์ด ์์ ์์ ์ด ๋ฐ์ดํฐ ์ธํธ์ OCR ํน์ฑ๋ ์ ๊ฑฐํด ๋ณด๊ฒ ์ต๋๋ค. OCR ํน์ฑ์ ๋ค๋ฅธ ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํ๊ธฐ ์ํ ๊ฒ์ผ๋ก, ์ด ๊ฐ์ด๋์์ ์ฌ์ฉํ๋ ๋ชจ๋ธ์ ์ ๋ ฅ ์๊ตฌ ์ฌํญ๊ณผ ์ผ์นํ์ง ์๊ธฐ ๋๋ฌธ์ ์ด ํน์ฑ์ ์ฌ์ฉํ๊ธฐ ์ํด์๋ ์ผ๋ถ ์ฒ๋ฆฌ๊ฐ ํ์ํฉ๋๋ค. | |
| ๋์ , ์๋ณธ ๋ฐ์ดํฐ์ [`LayoutLMv2Processor`]๋ฅผ ์ฌ์ฉํ์ฌ OCR ๋ฐ ํ ํฐํ๋ฅผ ๋ชจ๋ ์ํํ ์ ์์ต๋๋ค. | |
| ์ด๋ ๊ฒ ํ๋ฉด ๋ชจ๋ธ์ด ์๊ตฌํ๋ ์ ๋ ฅ์ ์ป์ ์ ์์ต๋๋ค. | |
| ์ด๋ฏธ์ง๋ฅผ ์๋์ผ๋ก ์ฒ๋ฆฌํ๋ ค๋ฉด, [`LayoutLMv2` model documentation](../model_doc/layoutlmv2)์์ ๋ชจ๋ธ์ด ์๊ตฌํ๋ ์ ๋ ฅ ํฌ๋งท์ ํ์ธํด๋ณด์ธ์. | |
| ```py | |
| >>> updated_dataset = updated_dataset.remove_columns("words") | |
| >>> updated_dataset = updated_dataset.remove_columns("bounding_boxes") | |
| ``` | |
| ๋ง์ง๋ง์ผ๋ก, ๋ฐ์ดํฐ ํ์์ ์๋ฃํ๊ธฐ ์ํด ์ด๋ฏธ์ง ์์๋ฅผ ์ดํด๋ด ์๋ค. | |
| ```py | |
| >>> updated_dataset["train"][11]["image"] | |
| ``` | |
| <div class="flex justify-center"> | |
| <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/docvqa_example.jpg" alt="DocVQA Image Example"/> | |
| </div> | |
| ## ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ [[preprocess-the-data]] | |
| ๋ฌธ์ ์ง์ ์๋ต ํ์คํฌ๋ ๋ฉํฐ๋ชจ๋ฌ ํ์คํฌ์ด๋ฉฐ, ๊ฐ ๋ชจ๋ฌ๋ฆฌํฐ์ ์ ๋ ฅ์ด ๋ชจ๋ธ์ ์๊ตฌ์ ๋ง๊ฒ ์ ์ฒ๋ฆฌ ๋์๋์ง ํ์ธํด์ผ ํฉ๋๋ค. | |
| ์ด๋ฏธ์ง ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ ์ ์๋ ์ด๋ฏธ์ง ํ๋ก์ธ์์ ํ ์คํธ ๋ฐ์ดํฐ๋ฅผ ์ธ์ฝ๋ฉํ ์ ์๋ ํ ํฌ๋์ด์ ๋ฅผ ๊ฒฐํฉํ [`LayoutLMv2Processor`]๋ฅผ ๊ฐ์ ธ์ค๋ ๊ฒ๋ถํฐ ์์ํด ๋ณด๊ฒ ์ต๋๋ค. | |
| ```py | |
| >>> from transformers import AutoProcessor | |
| >>> processor = AutoProcessor.from_pretrained(model_checkpoint) | |
| ``` | |
| ### ๋ฌธ์ ์ด๋ฏธ์ง ์ ์ฒ๋ฆฌ [[preprocessing-document-images]] | |
| ๋จผ์ , ํ๋ก์ธ์์ `image_processor`๋ฅผ ์ฌ์ฉํด ๋ชจ๋ธ์ ๋ํ ๋ฌธ์ ์ด๋ฏธ์ง๋ฅผ ์ค๋นํด ๋ณด๊ฒ ์ต๋๋ค. | |
| ๊ธฐ๋ณธ๊ฐ์ผ๋ก, ์ด๋ฏธ์ง ํ๋ก์ธ์๋ ์ด๋ฏธ์ง ํฌ๊ธฐ๋ฅผ 224x224๋ก ์กฐ์ ํ๊ณ ์์ ์ฑ๋์ ์์๊ฐ ์ฌ๋ฐ๋ฅธ์ง ํ์ธํ ํ ๋จ์ด์ ์ ๊ทํ๋ ๋ฐ์ด๋ฉ ๋ฐ์ค๋ฅผ ์ป๊ธฐ ์ํด ํ ์๋ํธ๋ฅผ ์ฌ์ฉํด OCR๋ฅผ ์ ์ฉํฉ๋๋ค. | |
| ์ด ํํ ๋ฆฌ์ผ์์ ์ฐ๋ฆฌ๊ฐ ํ์ํ ๊ฒ๊ณผ ๊ธฐ๋ณธ๊ฐ์ ์์ ํ ๋์ผํฉ๋๋ค. ์ด๋ฏธ์ง ๋ฐฐ์น์ ๊ธฐ๋ณธ ์ด๋ฏธ์ง ์ฒ๋ฆฌ๋ฅผ ์ ์ฉํ๊ณ OCR์ ๊ฒฐ๊ณผ๋ฅผ ๋ณํํ๋ ํจ์๋ฅผ ์์ฑํฉ๋๋ค. | |
| ```py | |
| >>> image_processor = processor.image_processor | |
| >>> def get_ocr_words_and_boxes(examples): | |
| ... images = [image.convert("RGB") for image in examples["image"]] | |
| ... encoded_inputs = image_processor(images) | |
| ... examples["image"] = encoded_inputs.pixel_values | |
| ... examples["words"] = encoded_inputs.words | |
| ... examples["boxes"] = encoded_inputs.boxes | |
| ... return examples | |
| ``` | |
| ์ด ์ ์ฒ๋ฆฌ๋ฅผ ๋ฐ์ดํฐ ์ธํธ ์ ์ฒด์ ๋น ๋ฅด๊ฒ ์ ์ฉํ๋ ค๋ฉด [`~datasets.Dataset.map`]๋ฅผ ์ฌ์ฉํ์ธ์. | |
| ```py | |
| >>> dataset_with_ocr = updated_dataset.map(get_ocr_words_and_boxes, batched=True, batch_size=2) | |
| ``` | |
| ### ํ ์คํธ ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ [[preprocessing-text-data]] | |
| ์ด๋ฏธ์ง์ OCR์ ์ ์ฉํ์ผ๋ฉด ๋ฐ์ดํฐ ์ธํธ์ ํ ์คํธ ๋ถ๋ถ์ ๋ชจ๋ธ์ ๋ง๊ฒ ์ธ์ฝ๋ฉํด์ผ ํฉ๋๋ค. | |
| ์ด ์ธ์ฝ๋ฉ์๋ ์ด์ ๋จ๊ณ์์ ๊ฐ์ ธ์จ ๋จ์ด์ ๋ฐ์ค๋ฅผ ํ ํฐ ์์ค์ `input_ids`, `attention_mask`, `token_type_ids` ๋ฐ `bbox`๋ก ๋ณํํ๋ ์์ ์ด ํฌํจ๋ฉ๋๋ค. | |
| ํ ์คํธ๋ฅผ ์ ์ฒ๋ฆฌํ๋ ค๋ฉด ํ๋ก์ธ์์ `tokenizer`๊ฐ ํ์ํฉ๋๋ค. | |
| ```py | |
| >>> tokenizer = processor.tokenizer | |
| ``` | |
| ์์์ ์ธ๊ธํ ์ ์ฒ๋ฆฌ ์ธ์๋ ๋ชจ๋ธ์ ์ํด ๋ ์ด๋ธ์ ์ถ๊ฐํด์ผ ํฉ๋๋ค. ๐ค Transformers์ `xxxForQuestionAnswering` ๋ชจ๋ธ์ ๊ฒฝ์ฐ, ๋ ์ด๋ธ์ `start_positions`์ `end_positions`๋ก ๊ตฌ์ฑ๋๋ฉฐ ์ด๋ค ํ ํฐ์ด ๋ต๋ณ์ ์์๊ณผ ๋์ ์๋์ง๋ฅผ ๋ํ๋ ๋๋ค. | |
| ๋ ์ด๋ธ ์ถ๊ฐ๋ฅผ ์ํด์, ๋จผ์ ๋ ํฐ ๋ฆฌ์คํธ(๋จ์ด ๋ฆฌ์คํธ)์์ ํ์ ๋ฆฌ์คํธ(๋จ์ด๋ก ๋ถํ ๋ ๋ต๋ณ)์ ์ฐพ์ ์ ์๋ ํฌํผ ํจ์๋ฅผ ์ ์ํฉ๋๋ค. | |
| ์ด ํจ์๋ `words_list`์ `answer_list`, ์ด๋ ๊ฒ ๋ ๋ฆฌ์คํธ๋ฅผ ์ ๋ ฅ์ผ๋ก ๋ฐ์ต๋๋ค. | |
| ๊ทธ๋ฐ ๋ค์ `words_list`๋ฅผ ๋ฐ๋ณตํ์ฌ `words_list`์ ํ์ฌ ๋จ์ด(words_list[i])๊ฐ `answer_list`์ ์ฒซ ๋ฒ์งธ ๋จ์ด(answer_list[0])์ ๊ฐ์์ง, | |
| ํ์ฌ ๋จ์ด์์ ์์ํด `answer_list`์ ๊ฐ์ ๊ธธ์ด๋งํผ์ `words_list`์ ํ์ ๋ฆฌ์คํธ๊ฐ `answer_list`์ ์ผ์นํ๋์ง ํ์ธํฉ๋๋ค. | |
| ์ด ์กฐ๊ฑด์ด ์ฐธ์ด๋ผ๋ฉด ์ผ์นํ๋ ํญ๋ชฉ์ ๋ฐ๊ฒฌํ์์ ์๋ฏธํ๋ฉฐ, ํจ์๋ ์ผ์น ํญ๋ชฉ, ์์ ์ธ๋ฑ์ค(idx) ๋ฐ ์ข ๋ฃ ์ธ๋ฑ์ค(idx + len(answer_list) - 1)๋ฅผ ๊ธฐ๋กํฉ๋๋ค. ์ผ์นํ๋ ํญ๋ชฉ์ด ๋ ๊ฐ ์ด์ ๋ฐ๊ฒฌ๋๋ฉด ํจ์๋ ์ฒซ ๋ฒ์งธ ํญ๋ชฉ๋ง ๋ฐํํฉ๋๋ค. ์ผ์นํ๋ ํญ๋ชฉ์ด ์๋ค๋ฉด ํจ์๋ (`None`, 0, 0)์ ๋ฐํํฉ๋๋ค. | |
| ```py | |
| >>> def subfinder(words_list, answer_list): | |
| ... matches = [] | |
| ... start_indices = [] | |
| ... end_indices = [] | |
| ... for idx, i in enumerate(range(len(words_list))): | |
| ... if words_list[i] == answer_list[0] and words_list[i : i + len(answer_list)] == answer_list: | |
| ... matches.append(answer_list) | |
| ... start_indices.append(idx) | |
| ... end_indices.append(idx + len(answer_list) - 1) | |
| ... if matches: | |
| ... return matches[0], start_indices[0], end_indices[0] | |
| ... else: | |
| ... return None, 0, 0 | |
| ``` | |
| ์ด ํจ์๊ฐ ์ด๋ป๊ฒ ์ ๋ต์ ์์น๋ฅผ ์ฐพ๋์ง ์ค๋ช ํ๊ธฐ ์ํด ๋ค์ ์์ ์์ ํจ์๋ฅผ ์ฌ์ฉํด ๋ณด๊ฒ ์ต๋๋ค: | |
| ```py | |
| >>> example = dataset_with_ocr["train"][1] | |
| >>> words = [word.lower() for word in example["words"]] | |
| >>> match, word_idx_start, word_idx_end = subfinder(words, example["answer"].lower().split()) | |
| >>> print("Question: ", example["question"]) | |
| >>> print("Words:", words) | |
| >>> print("Answer: ", example["answer"]) | |
| >>> print("start_index", word_idx_start) | |
| >>> print("end_index", word_idx_end) | |
| Question: Who is in cc in this letter? | |
| Words: ['wie', 'baw', 'brown', '&', 'williamson', 'tobacco', 'corporation', 'research', '&', 'development', 'internal', 'correspondence', 'to:', 'r.', 'h.', 'honeycutt', 'ce:', 't.f.', 'riehl', 'from:', '.', 'c.j.', 'cook', 'date:', 'may', '8,', '1995', 'subject:', 'review', 'of', 'existing', 'brainstorming', 'ideas/483', 'the', 'major', 'function', 'of', 'the', 'product', 'innovation', 'graup', 'is', 'to', 'develop', 'marketable', 'nove!', 'products', 'that', 'would', 'be', 'profitable', 'to', 'manufacture', 'and', 'sell.', 'novel', 'is', 'defined', 'as:', 'of', 'a', 'new', 'kind,', 'or', 'different', 'from', 'anything', 'seen', 'or', 'known', 'before.', 'innovation', 'is', 'defined', 'as:', 'something', 'new', 'or', 'different', 'introduced;', 'act', 'of', 'innovating;', 'introduction', 'of', 'new', 'things', 'or', 'methods.', 'the', 'products', 'may', 'incorporate', 'the', 'latest', 'technologies,', 'materials', 'and', 'know-how', 'available', 'to', 'give', 'then', 'a', 'unique', 'taste', 'or', 'look.', 'the', 'first', 'task', 'of', 'the', 'product', 'innovation', 'group', 'was', 'to', 'assemble,', 'review', 'and', 'categorize', 'a', 'list', 'of', 'existing', 'brainstorming', 'ideas.', 'ideas', 'were', 'grouped', 'into', 'two', 'major', 'categories', 'labeled', 'appearance', 'and', 'taste/aroma.', 'these', 'categories', 'are', 'used', 'for', 'novel', 'products', 'that', 'may', 'differ', 'from', 'a', 'visual', 'and/or', 'taste/aroma', 'point', 'of', 'view', 'compared', 'to', 'canventional', 'cigarettes.', 'other', 'categories', 'include', 'a', 'combination', 'of', 'the', 'above,', 'filters,', 'packaging', 'and', 'brand', 'extensions.', 'appearance', 'this', 'category', 'is', 'used', 'for', 'novel', 'cigarette', 'constructions', 'that', 'yield', 'visually', 'different', 'products', 'with', 'minimal', 'changes', 'in', 'smoke', 'chemistry', 'two', 'cigarettes', 'in', 'cne.', 'emulti-plug', 'te', 'build', 'yaur', 'awn', 'cigarette.', 'eswitchable', 'menthol', 'or', 'non', 'menthol', 'cigarette.', '*cigarettes', 'with', 'interspaced', 'perforations', 'to', 'enable', 'smoker', 'to', 'separate', 'unburned', 'section', 'for', 'future', 'smoking.', 'ยซshort', 'cigarette,', 'tobacco', 'section', '30', 'mm.', 'ยซextremely', 'fast', 'buming', 'cigarette.', 'ยซnovel', 'cigarette', 'constructions', 'that', 'permit', 'a', 'significant', 'reduction', 'iretobacco', 'weight', 'while', 'maintaining', 'smoking', 'mechanics', 'and', 'visual', 'characteristics.', 'higher', 'basis', 'weight', 'paper:', 'potential', 'reduction', 'in', 'tobacco', 'weight.', 'ยซmore', 'rigid', 'tobacco', 'column;', 'stiffing', 'agent', 'for', 'tobacco;', 'e.g.', 'starch', '*colored', 'tow', 'and', 'cigarette', 'papers;', 'seasonal', 'promotions,', 'e.g.', 'pastel', 'colored', 'cigarettes', 'for', 'easter', 'or', 'in', 'an', 'ebony', 'and', 'ivory', 'brand', 'containing', 'a', 'mixture', 'of', 'all', 'black', '(black', 'paper', 'and', 'tow)', 'and', 'ail', 'white', 'cigarettes.', '499150498'] | |
| Answer: T.F. Riehl | |
| start_index 17 | |
| end_index 18 | |
| ``` | |
| ํํธ, ์ ์์ ๊ฐ ์ธ์ฝ๋ฉ๋๋ฉด ๋ค์๊ณผ ๊ฐ์ด ํ์๋ฉ๋๋ค: | |
| ```py | |
| >>> encoding = tokenizer(example["question"], example["words"], example["boxes"]) | |
| >>> tokenizer.decode(encoding["input_ids"]) | |
| [CLS] who is in cc in this letter? [SEP] wie baw brown & williamson tobacco corporation research & development ... | |
| ``` | |
| ์ด์ ์ธ์ฝ๋ฉ๋ ์ ๋ ฅ์์ ์ ๋ต์ ์์น๋ฅผ ์ฐพ์์ผ ํฉ๋๋ค. | |
| * `token_type_ids`๋ ์ด๋ค ํ ํฐ์ด ์ง๋ฌธ์ ์ํ๋์ง, ๊ทธ๋ฆฌ๊ณ ์ด๋ค ํ ํฐ์ด ๋ฌธ์์ ๋จ์ด์ ํฌํจ๋๋์ง๋ฅผ ์๋ ค์ค๋๋ค. | |
| * `tokenizer.cls_token_id` ์ ๋ ฅ์ ์์ ๋ถ๋ถ์ ์๋ ํน์ ํ ํฐ์ ์ฐพ๋ ๋ฐ ๋์์ ์ค๋๋ค. | |
| * `word_ids`๋ ์๋ณธ `words`์์ ์ฐพ์ ๋ต๋ณ์ ์ ์ฒด ์ธ์ฝ๋ฉ๋ ์ ๋ ฅ์ ๋์ผํ ๋ต๊ณผ ์ผ์น์ํค๊ณ ์ธ์ฝ๋ฉ๋ ์ ๋ ฅ์์ ๋ต๋ณ์ ์์/๋ ์์น๋ฅผ ๊ฒฐ์ ํฉ๋๋ค. | |
| ์ ๋ด์ฉ๋ค์ ์ผ๋์ ๋๊ณ ๋ฐ์ดํฐ ์ธํธ ์์ ์ ๋ฐฐ์น๋ฅผ ์ธ์ฝ๋ฉํ๋ ํจ์๋ฅผ ๋ง๋ค์ด ๋ณด๊ฒ ์ต๋๋ค: | |
| ```py | |
| >>> def encode_dataset(examples, max_length=512): | |
| ... questions = examples["question"] | |
| ... words = examples["words"] | |
| ... boxes = examples["boxes"] | |
| ... answers = examples["answer"] | |
| ... # ์์ ๋ฐฐ์น๋ฅผ ์ธ์ฝ๋ฉํ๊ณ start_positions์ end_positions๋ฅผ ์ด๊ธฐํํฉ๋๋ค | |
| ... encoding = tokenizer(questions, words, boxes, max_length=max_length, padding="max_length", truncation=True) | |
| ... start_positions = [] | |
| ... end_positions = [] | |
| ... # ๋ฐฐ์น์ ์์ ๋ฅผ ๋ฐ๋ณตํฉ๋๋ค | |
| ... for i in range(len(questions)): | |
| ... cls_index = encoding["input_ids"][i].index(tokenizer.cls_token_id) | |
| ... # ์์ ์ words์์ ๋ต๋ณ์ ์์น๋ฅผ ์ฐพ์ต๋๋ค | |
| ... words_example = [word.lower() for word in words[i]] | |
| ... answer = answers[i] | |
| ... match, word_idx_start, word_idx_end = subfinder(words_example, answer.lower().split()) | |
| ... if match: | |
| ... # ์ผ์นํ๋ ํญ๋ชฉ์ ๋ฐ๊ฒฌํ๋ฉด, `token_type_ids`๋ฅผ ์ฌ์ฉํด ์ธ์ฝ๋ฉ์์ ๋จ์ด๊ฐ ์์ํ๋ ์์น๋ฅผ ์ฐพ์ต๋๋ค | |
| ... token_type_ids = encoding["token_type_ids"][i] | |
| ... token_start_index = 0 | |
| ... while token_type_ids[token_start_index] != 1: | |
| ... token_start_index += 1 | |
| ... token_end_index = len(encoding["input_ids"][i]) - 1 | |
| ... while token_type_ids[token_end_index] != 1: | |
| ... token_end_index -= 1 | |
| ... word_ids = encoding.word_ids(i)[token_start_index : token_end_index + 1] | |
| ... start_position = cls_index | |
| ... end_position = cls_index | |
| ... # words์ ๋ต๋ณ ์์น์ ์ผ์นํ ๋๊น์ง word_ids๋ฅผ ๋ฐ๋ณตํ๊ณ `token_start_index`๋ฅผ ๋๋ฆฝ๋๋ค | |
| ... # ์ผ์นํ๋ฉด `token_start_index`๋ฅผ ์ธ์ฝ๋ฉ์์ ๋ต๋ณ์ `start_position`์ผ๋ก ์ ์ฅํฉ๋๋ค | |
| ... for id in word_ids: | |
| ... if id == word_idx_start: | |
| ... start_position = token_start_index | |
| ... else: | |
| ... token_start_index += 1 | |
| ... # ๋น์ทํ๊ฒ, ๋์์ ์์ํด `word_ids`๋ฅผ ๋ฐ๋ณตํ๋ฉฐ ๋ต๋ณ์ `end_position`์ ์ฐพ์ต๋๋ค | |
| ... for id in word_ids[::-1]: | |
| ... if id == word_idx_end: | |
| ... end_position = token_end_index | |
| ... else: | |
| ... token_end_index -= 1 | |
| ... start_positions.append(start_position) | |
| ... end_positions.append(end_position) | |
| ... else: | |
| ... start_positions.append(cls_index) | |
| ... end_positions.append(cls_index) | |
| ... encoding["image"] = examples["image"] | |
| ... encoding["start_positions"] = start_positions | |
| ... encoding["end_positions"] = end_positions | |
| ... return encoding | |
| ``` | |
| ์ด์ ์ด ์ ์ฒ๋ฆฌ ํจ์๊ฐ ์์ผ๋ ์ ์ฒด ๋ฐ์ดํฐ ์ธํธ๋ฅผ ์ธ์ฝ๋ฉํ ์ ์์ต๋๋ค: | |
| ```py | |
| >>> encoded_train_dataset = dataset_with_ocr["train"].map( | |
| ... encode_dataset, batched=True, batch_size=2, remove_columns=dataset_with_ocr["train"].column_names | |
| ... ) | |
| >>> encoded_test_dataset = dataset_with_ocr["test"].map( | |
| ... encode_dataset, batched=True, batch_size=2, remove_columns=dataset_with_ocr["test"].column_names | |
| ... ) | |
| ``` | |
| ์ธ์ฝ๋ฉ๋ ๋ฐ์ดํฐ ์ธํธ์ ํน์ฑ์ด ์ด๋ป๊ฒ ์๊ฒผ๋์ง ํ์ธํด ๋ณด๊ฒ ์ต๋๋ค: | |
| ```py | |
| >>> encoded_train_dataset.features | |
| {'image': Sequence(feature=Sequence(feature=Sequence(feature=Value(dtype='uint8', id=None), length=-1, id=None), length=-1, id=None), length=-1, id=None), | |
| 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), | |
| 'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), | |
| 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), | |
| 'bbox': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None), | |
| 'start_positions': Value(dtype='int64', id=None), | |
| 'end_positions': Value(dtype='int64', id=None)} | |
| ``` | |
| ## ํ๊ฐ [[evaluation]] | |
| ๋ฌธ์ ์ง์ ์๋ต์ ํ๊ฐํ๋ ค๋ฉด ์๋นํ ์์ ํ์ฒ๋ฆฌ๊ฐ ํ์ํฉ๋๋ค. ์๊ฐ์ด ๋๋ฌด ๋ง์ด ๊ฑธ๋ฆฌ์ง ์๋๋ก ์ด ๊ฐ์ด๋์์๋ ํ๊ฐ ๋จ๊ณ๋ฅผ ์๋ตํฉ๋๋ค. | |
| [`Trainer`]๊ฐ ํ๋ จ ๊ณผ์ ์์ ํ๊ฐ ์์ค(evaluation loss)์ ๊ณ์ ๊ณ์ฐํ๊ธฐ ๋๋ฌธ์ ๋ชจ๋ธ์ ์ฑ๋ฅ์ ๋๋ต์ ์ผ๋ก ์ ์ ์์ต๋๋ค. | |
| ์ถ์ถ์ (Extractive) ์ง์ ์๋ต์ ๋ณดํต F1/exact match ๋ฐฉ๋ฒ์ ์ฌ์ฉํด ํ๊ฐ๋ฉ๋๋ค. | |
| ์ง์ ๊ตฌํํด๋ณด๊ณ ์ถ์ผ์๋ค๋ฉด, Hugging Face course์ [Question Answering chapter](https://huggingface.co/course/chapter7/7?fw=pt#postprocessing)์ ์ฐธ๊ณ ํ์ธ์. | |
| ## ํ๋ จ [[train]] | |
| ์ถํํฉ๋๋ค! ์ด ๊ฐ์ด๋์ ๊ฐ์ฅ ์ด๋ ค์ด ๋ถ๋ถ์ ์ฑ๊ณต์ ์ผ๋ก ์ฒ๋ฆฌํ์ผ๋ ์ด์ ๋๋ง์ ๋ชจ๋ธ์ ํ๋ จํ ์ค๋น๊ฐ ๋์์ต๋๋ค. | |
| ํ๋ จ์ ๋ค์๊ณผ ๊ฐ์ ๋จ๊ณ๋ก ์ด๋ฃจ์ด์ ธ ์์ต๋๋ค: | |
| * ์ ์ฒ๋ฆฌ์์์ ๋์ผํ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ฌ์ฉํ๊ธฐ ์ํด [`AutoModelForDocumentQuestionAnswering`]์ผ๋ก ๋ชจ๋ธ์ ๊ฐ์ ธ์ต๋๋ค. | |
| * [`TrainingArguments`]๋ก ํ๋ จ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ์ ํฉ๋๋ค. | |
| * ์์ ๋ฅผ ๋ฐฐ์น ์ฒ๋ฆฌํ๋ ํจ์๋ฅผ ์ ์ํฉ๋๋ค. ์ฌ๊ธฐ์๋ [`DefaultDataCollator`]๊ฐ ์ ๋นํฉ๋๋ค. | |
| * ๋ชจ๋ธ, ๋ฐ์ดํฐ ์ธํธ, ๋ฐ์ดํฐ ์ฝ๋ ์ดํฐ(Data collator)์ ํจ๊ป [`Trainer`]์ ํ๋ จ ์ธ์๋ค์ ์ ๋ฌํฉ๋๋ค. | |
| * [`~Trainer.train`]์ ํธ์ถํด์ ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํฉ๋๋ค. | |
| ```py | |
| >>> from transformers import AutoModelForDocumentQuestionAnswering | |
| >>> model = AutoModelForDocumentQuestionAnswering.from_pretrained(model_checkpoint) | |
| ``` | |
| [`TrainingArguments`]์์ `output_dir`์ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ์ ์ฅํ ์์น๋ฅผ ์ง์ ํ๊ณ , ์ ์ ํ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ์ค์ ํฉ๋๋ค. | |
| ๋ชจ๋ธ์ ์ปค๋ฎค๋ํฐ์ ๊ณต์ ํ๋ ค๋ฉด `push_to_hub`๋ฅผ `True`๋ก ์ค์ ํ์ธ์ (๋ชจ๋ธ์ ์ ๋ก๋ํ๋ ค๋ฉด Hugging Face์ ๋ก๊ทธ์ธํด์ผ ํฉ๋๋ค). | |
| ์ด ๊ฒฝ์ฐ `output_dir`์ ๋ชจ๋ธ์ ์ฒดํฌํฌ์ธํธ๋ฅผ ํธ์ํ ๋ ํฌ์งํ ๋ฆฌ์ ์ด๋ฆ์ด ๋ฉ๋๋ค. | |
| ```py | |
| >>> from transformers import TrainingArguments | |
| >>> # ๋ณธ์ธ์ ๋ ํฌ์งํ ๋ฆฌ ID๋ก ๋ฐ๊พธ์ธ์ | |
| >>> repo_id = "MariaK/layoutlmv2-base-uncased_finetuned_docvqa" | |
| >>> training_args = TrainingArguments( | |
| ... output_dir=repo_id, | |
| ... per_device_train_batch_size=4, | |
| ... num_train_epochs=20, | |
| ... save_steps=200, | |
| ... logging_steps=50, | |
| ... eval_strategy="steps", | |
| ... learning_rate=5e-5, | |
| ... save_total_limit=2, | |
| ... remove_unused_columns=False, | |
| ... push_to_hub=True, | |
| ... ) | |
| ``` | |
| ๊ฐ๋จํ ๋ฐ์ดํฐ ์ฝ๋ ์ดํฐ๋ฅผ ์ ์ํ์ฌ ์์ ๋ฅผ ํจ๊ป ๋ฐฐ์นํฉ๋๋ค. | |
| ```py | |
| >>> from transformers import DefaultDataCollator | |
| >>> data_collator = DefaultDataCollator() | |
| ``` | |
| ๋ง์ง๋ง์ผ๋ก, ๋ชจ๋ ๊ฒ์ ํ ๊ณณ์ ๋ชจ์ [`~Trainer.train`]์ ํธ์ถํฉ๋๋ค: | |
| ```py | |
| >>> from transformers import Trainer | |
| >>> trainer = Trainer( | |
| ... model=model, | |
| ... args=training_args, | |
| ... data_collator=data_collator, | |
| ... train_dataset=encoded_train_dataset, | |
| ... eval_dataset=encoded_test_dataset, | |
| ... processing_class=processor, | |
| ... ) | |
| >>> trainer.train() | |
| ``` | |
| ์ต์ข ๋ชจ๋ธ์ ๐ค Hub์ ์ถ๊ฐํ๋ ค๋ฉด, ๋ชจ๋ธ ์นด๋๋ฅผ ์์ฑํ๊ณ `push_to_hub`๋ฅผ ํธ์ถํฉ๋๋ค: | |
| ```py | |
| >>> trainer.create_model_card() | |
| >>> trainer.push_to_hub() | |
| ``` | |
| ## ์ถ๋ก [[inference]] | |
| ์ด์ LayoutLMv2 ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํ๊ณ ๐ค Hub์ ์ ๋ก๋ํ์ผ๋ ์ถ๋ก ์๋ ์ฌ์ฉํ ์ ์์ต๋๋ค. | |
| ์ถ๋ก ์ ์ํด ๋ฏธ์ธ ์กฐ์ ๋ ๋ชจ๋ธ์ ์ฌ์ฉํด ๋ณด๋ ๊ฐ์ฅ ๊ฐ๋จํ ๋ฐฉ๋ฒ์ [`Pipeline`]์ ์ฌ์ฉํ๋ ๊ฒ ์ ๋๋ค. | |
| ์๋ฅผ ๋ค์ด ๋ณด๊ฒ ์ต๋๋ค: | |
| ```py | |
| >>> example = dataset["test"][2] | |
| >>> question = example["query"]["en"] | |
| >>> image = example["image"] | |
| >>> print(question) | |
| >>> print(example["answers"]) | |
| 'Who is โpresidingโ TRRF GENERAL SESSION (PART 1)?' | |
| ['TRRF Vice President', 'lee a. waller'] | |
| ``` | |
| ๊ทธ ๋ค์, ๋ชจ๋ธ๋ก ๋ฌธ์ ์ง์ ์๋ต์ ํ๊ธฐ ์ํด ํ์ดํ๋ผ์ธ์ ์ธ์คํด์คํํ๊ณ ์ด๋ฏธ์ง + ์ง๋ฌธ ์กฐํฉ์ ์ ๋ฌํฉ๋๋ค. | |
| ```py | |
| >>> from transformers import pipeline | |
| >>> qa_pipeline = pipeline("document-question-answering", model="MariaK/layoutlmv2-base-uncased_finetuned_docvqa") | |
| >>> qa_pipeline(image, question) | |
| [{'score': 0.9949808120727539, | |
| 'answer': 'Lee A. Waller', | |
| 'start': 55, | |
| 'end': 57}] | |
| ``` | |
| ์ํ๋ค๋ฉด ํ์ดํ๋ผ์ธ์ ๊ฒฐ๊ณผ๋ฅผ ์๋์ผ๋ก ๋ณต์ ํ ์๋ ์์ต๋๋ค: | |
| 1. ์ด๋ฏธ์ง์ ์ง๋ฌธ์ ๊ฐ์ ธ์ ๋ชจ๋ธ์ ํ๋ก์ธ์๋ฅผ ์ฌ์ฉํด ๋ชจ๋ธ์ ๋ง๊ฒ ์ค๋นํฉ๋๋ค. | |
| 2. ๋ชจ๋ธ์ ํตํด ๊ฒฐ๊ณผ ๋๋ ์ ์ฒ๋ฆฌ๋ฅผ ์ ๋ฌํฉ๋๋ค. | |
| 3. ๋ชจ๋ธ์ ์ด๋ค ํ ํฐ์ด ๋ต๋ณ์ ์์์ ์๋์ง, ์ด๋ค ํ ํฐ์ด ๋ต๋ณ์ด ๋์ ์๋์ง๋ฅผ ๋ํ๋ด๋ `start_logits`์ `end_logits`๋ฅผ ๋ฐํํฉ๋๋ค. ๋ ๋ค (batch_size, sequence_length) ํํ๋ฅผ ๊ฐ์ต๋๋ค. | |
| 4. `start_logits`์ `end_logits`์ ๋ง์ง๋ง ์ฐจ์์ ์ต๋๋ก ๋ง๋๋ ๊ฐ์ ์ฐพ์ ์์ `start_idx`์ `end_idx`๋ฅผ ์ป์ต๋๋ค. | |
| 5. ํ ํฌ๋์ด์ ๋ก ๋ต๋ณ์ ๋์ฝ๋ฉํฉ๋๋ค. | |
| ```py | |
| >>> import torch | |
| >>> from transformers import AutoProcessor | |
| >>> from transformers import AutoModelForDocumentQuestionAnswering | |
| >>> processor = AutoProcessor.from_pretrained("MariaK/layoutlmv2-base-uncased_finetuned_docvqa") | |
| >>> model = AutoModelForDocumentQuestionAnswering.from_pretrained("MariaK/layoutlmv2-base-uncased_finetuned_docvqa") | |
| >>> with torch.no_grad(): | |
| ... encoding = processor(image.convert("RGB"), question, return_tensors="pt") | |
| ... outputs = model(**encoding) | |
| ... start_logits = outputs.start_logits | |
| ... end_logits = outputs.end_logits | |
| ... predicted_start_idx = start_logits.argmax(-1).item() | |
| ... predicted_end_idx = end_logits.argmax(-1).item() | |
| >>> processor.tokenizer.decode(encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1]) | |
| 'lee a. waller' | |
| ``` | |