| <!--Copyright 2022 The HuggingFace Team. All rights reserved. | |
| Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | |
| the License. You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | |
| an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | |
| specific language governing permissions and limitations under the License. | |
| โ ๏ธ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be | |
| rendered properly in your Markdown viewer. | |
| --> | |
| # ํ ํฐ ๋ถ๋ฅ[[token-classification]] | |
| [[open-in-colab]] | |
| <Youtube id="wVHdVlPScxA"/> | |
| ํ ํฐ ๋ถ๋ฅ๋ ๋ฌธ์ฅ์ ๊ฐ๋ณ ํ ํฐ์ ๋ ์ด๋ธ์ ํ ๋นํฉ๋๋ค. ๊ฐ์ฅ ์ผ๋ฐ์ ์ธ ํ ํฐ ๋ถ๋ฅ ์์ ์ค ํ๋๋ ๊ฐ์ฒด๋ช ์ธ์(Named Entity Recognition, NER)์ ๋๋ค. ๊ฐ์ฒด๋ช ์ธ์์ ๋ฌธ์ฅ์์ ์ฌ๋, ์์น ๋๋ ์กฐ์ง๊ณผ ๊ฐ์ ๊ฐ ๊ฐ์ฒด์ ๋ ์ด๋ธ์ ์ฐพ์ผ๋ ค๊ณ ์๋ํฉ๋๋ค. | |
| ์ด ๊ฐ์ด๋์์ ํ์ตํ ๋ด์ฉ์: | |
| 1. [WNUT 17](https://huggingface.co/datasets/wnut_17) ๋ฐ์ดํฐ ์ธํธ์์ [DistilBERT](https://huggingface.co/distilbert/distilbert-base-uncased)๋ฅผ ํ์ธ ํ๋ํ์ฌ ์๋ก์ด ๊ฐ์ฒด๋ฅผ ํ์งํฉ๋๋ค. | |
| 2. ์ถ๋ก ์ ์ํด ํ์ธ ํ๋ ๋ชจ๋ธ์ ์ฌ์ฉํฉ๋๋ค. | |
| <Tip> | |
| ์ด ํํ ๋ฆฌ์ผ์์ ์ค๋ช ํ๋ ์์ ์ ๋ค์ ๋ชจ๋ธ ์ํคํ ์ฒ์ ์ํด ์ง์๋ฉ๋๋ค: | |
| <!--This tip is automatically generated by `make fix-copies`, do not fill manually!--> | |
| [ALBERT](../model_doc/albert), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [Nezha](../model_doc/nezha), [Nystrรถmformer](../model_doc/nystromformer), [QDQBert](../model_doc/qdqbert), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso) | |
| <!--End of the generated tip--> | |
| </Tip> | |
| ์์ํ๊ธฐ ์ ์, ํ์ํ ๋ชจ๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์ค์น๋์ด ์๋์ง ํ์ธํ์ธ์: | |
| ```bash | |
| pip install transformers datasets evaluate seqeval | |
| ``` | |
| Hugging Face ๊ณ์ ์ ๋ก๊ทธ์ธํ์ฌ ๋ชจ๋ธ์ ์ ๋ก๋ํ๊ณ ์ปค๋ฎค๋ํฐ์ ๊ณต์ ํ๋ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค. ๋ฉ์์ง๊ฐ ํ์๋๋ฉด, ํ ํฐ์ ์ ๋ ฅํ์ฌ ๋ก๊ทธ์ธํ์ธ์: | |
| ```py | |
| >>> from huggingface_hub import notebook_login | |
| >>> notebook_login() | |
| ``` | |
| ## WNUT 17 ๋ฐ์ดํฐ ์ธํธ ๊ฐ์ ธ์ค๊ธฐ[[load-wnut-17-dataset]] | |
| ๋จผ์ ๐ค Datasets ๋ผ์ด๋ธ๋ฌ๋ฆฌ์์ WNUT 17 ๋ฐ์ดํฐ ์ธํธ๋ฅผ ๊ฐ์ ธ์ต๋๋ค: | |
| ```py | |
| >>> from datasets import load_dataset | |
| >>> wnut = load_dataset("wnut_17") | |
| ``` | |
| ๋ค์ ์์ ๋ฅผ ์ดํด๋ณด์ธ์: | |
| ```py | |
| >>> wnut["train"][0] | |
| {'id': '0', | |
| 'ner_tags': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 8, 8, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0], | |
| 'tokens': ['@paulwalk', 'It', "'s", 'the', 'view', 'from', 'where', 'I', "'m", 'living', 'for', 'two', 'weeks', '.', 'Empire', 'State', 'Building', '=', 'ESB', '.', 'Pretty', 'bad', 'storm', 'here', 'last', 'evening', '.'] | |
| } | |
| ``` | |
| `ner_tags`์ ๊ฐ ์ซ์๋ ๊ฐ์ฒด๋ฅผ ๋ํ๋ ๋๋ค. ์ซ์๋ฅผ ๋ ์ด๋ธ ์ด๋ฆ์ผ๋ก ๋ณํํ์ฌ ๊ฐ์ฒด๊ฐ ๋ฌด์์ธ์ง ํ์ธํฉ๋๋ค: | |
| ```py | |
| >>> label_list = wnut["train"].features[f"ner_tags"].feature.names | |
| >>> label_list | |
| [ | |
| "O", | |
| "B-corporation", | |
| "I-corporation", | |
| "B-creative-work", | |
| "I-creative-work", | |
| "B-group", | |
| "I-group", | |
| "B-location", | |
| "I-location", | |
| "B-person", | |
| "I-person", | |
| "B-product", | |
| "I-product", | |
| ] | |
| ``` | |
| ๊ฐ `ner_tag`์ ์์ ๋ถ์ ๋ฌธ์๋ ๊ฐ์ฒด์ ํ ํฐ ์์น๋ฅผ ๋ํ๋ ๋๋ค: | |
| - `B-`๋ ๊ฐ์ฒด์ ์์์ ๋ํ๋ ๋๋ค. | |
| - `I-`๋ ํ ํฐ์ด ๋์ผํ ๊ฐ์ฒด ๋ด๋ถ์ ํฌํจ๋์ด ์์์ ๋ํ๋ ๋๋ค(์๋ฅผ ๋ค์ด `State` ํ ํฐ์ `Empire State Building`์ ๊ฐ์ ๊ฐ์ฒด์ ์ผ๋ถ์ ๋๋ค). | |
| - `0`๋ ํ ํฐ์ด ์ด๋ค ๊ฐ์ฒด์๋ ํด๋นํ์ง ์์์ ๋ํ๋ ๋๋ค. | |
| ## ์ ์ฒ๋ฆฌ[[preprocess]] | |
| <Youtube id="iY2AZYdZAr0"/> | |
| ๋ค์์ผ๋ก `tokens` ํ๋๋ฅผ ์ ์ฒ๋ฆฌํ๊ธฐ ์ํด DistilBERT ํ ํฌ๋์ด์ ๋ฅผ ๊ฐ์ ธ์ต๋๋ค: | |
| ```py | |
| >>> from transformers import AutoTokenizer | |
| >>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased") | |
| ``` | |
| ์์ ์์ `tokens` ํ๋๋ฅผ ๋ณด๋ฉด ์ ๋ ฅ์ด ์ด๋ฏธ ํ ํฐํ๋ ๊ฒ์ฒ๋ผ ๋ณด์ ๋๋ค. ๊ทธ๋ฌ๋ ์ค์ ๋ก ์ ๋ ฅ์ ์์ง ํ ํฐํ๋์ง ์์์ผ๋ฏ๋ก ๋จ์ด๋ฅผ ํ์ ๋จ์ด๋ก ํ ํฐํํ๊ธฐ ์ํด `is_split_into_words=True`๋ฅผ ์ค์ ํด์ผ ํฉ๋๋ค. ์์ ๋ก ํ์ธํฉ๋๋ค: | |
| ```py | |
| >>> example = wnut["train"][0] | |
| >>> tokenized_input = tokenizer(example["tokens"], is_split_into_words=True) | |
| >>> tokens = tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"]) | |
| >>> tokens | |
| ['[CLS]', '@', 'paul', '##walk', 'it', "'", 's', 'the', 'view', 'from', 'where', 'i', "'", 'm', 'living', 'for', 'two', 'weeks', '.', 'empire', 'state', 'building', '=', 'es', '##b', '.', 'pretty', 'bad', 'storm', 'here', 'last', 'evening', '.', '[SEP]'] | |
| ``` | |
| ๊ทธ๋ฌ๋ ์ด๋ก ์ธํด `[CLS]`๊ณผ `[SEP]`๋ผ๋ ํน์ ํ ํฐ์ด ์ถ๊ฐ๋๊ณ , ํ์ ๋จ์ด ํ ํฐํ๋ก ์ธํด ์ ๋ ฅ๊ณผ ๋ ์ด๋ธ ๊ฐ์ ๋ถ์ผ์น๊ฐ ๋ฐ์ํฉ๋๋ค. ํ๋์ ๋ ์ด๋ธ์ ํด๋นํ๋ ๋จ์ผ ๋จ์ด๋ ์ด์ ๋ ๊ฐ์ ํ์ ๋จ์ด๋ก ๋ถํ ๋ ์ ์์ต๋๋ค. ํ ํฐ๊ณผ ๋ ์ด๋ธ์ ๋ค์๊ณผ ๊ฐ์ด ์ฌ์ ๋ ฌํด์ผ ํฉ๋๋ค: | |
| 1. [`word_ids`](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.BatchEncoding.word_ids) ๋ฉ์๋๋ก ๋ชจ๋ ํ ํฐ์ ํด๋น ๋จ์ด์ ๋งคํํฉ๋๋ค. | |
| 2. ํน์ ํ ํฐ `[CLS]`์ `[SEP]`์ `-100` ๋ ์ด๋ธ์ ํ ๋นํ์ฌ, PyTorch ์์ค ํจ์๊ฐ ํด๋น ํ ํฐ์ ๋ฌด์ํ๋๋ก ํฉ๋๋ค. | |
| 3. ์ฃผ์ด์ง ๋จ์ด์ ์ฒซ ๋ฒ์งธ ํ ํฐ์๋ง ๋ ์ด๋ธ์ ์ง์ ํฉ๋๋ค. ๊ฐ์ ๋จ์ด์ ๋ค๋ฅธ ํ์ ํ ํฐ์ `-100`์ ํ ๋นํฉ๋๋ค. | |
| ๋ค์์ ํ ํฐ๊ณผ ๋ ์ด๋ธ์ ์ฌ์ ๋ ฌํ๊ณ DistilBERT์ ์ต๋ ์ ๋ ฅ ๊ธธ์ด๋ณด๋ค ๊ธธ์ง ์๋๋ก ์ํ์ค๋ฅผ ์๋ผ๋ด๋ ํจ์๋ฅผ ๋ง๋๋ ๋ฐฉ๋ฒ์ ๋๋ค: | |
| ```py | |
| >>> def tokenize_and_align_labels(examples): | |
| ... tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True) | |
| ... labels = [] | |
| ... for i, label in enumerate(examples[f"ner_tags"]): | |
| ... word_ids = tokenized_inputs.word_ids(batch_index=i) # Map tokens to their respective word. | |
| ... previous_word_idx = None | |
| ... label_ids = [] | |
| ... for word_idx in word_ids: # Set the special tokens to -100. | |
| ... if word_idx is None: | |
| ... label_ids.append(-100) | |
| ... elif word_idx != previous_word_idx: # Only label the first token of a given word. | |
| ... label_ids.append(label[word_idx]) | |
| ... else: | |
| ... label_ids.append(-100) | |
| ... previous_word_idx = word_idx | |
| ... labels.append(label_ids) | |
| ... tokenized_inputs["labels"] = labels | |
| ... return tokenized_inputs | |
| ``` | |
| ์ ์ฒด ๋ฐ์ดํฐ ์ธํธ์ ์ ์ฒ๋ฆฌ ํจ์๋ฅผ ์ ์ฉํ๋ ค๋ฉด, ๐ค Datasets [`~datasets.Dataset.map`] ํจ์๋ฅผ ์ฌ์ฉํ์ธ์. `batched=True`๋ก ์ค์ ํ์ฌ ๋ฐ์ดํฐ ์ธํธ์ ์ฌ๋ฌ ์์๋ฅผ ํ ๋ฒ์ ์ฒ๋ฆฌํ๋ฉด `map` ํจ์์ ์๋๋ฅผ ๋์ผ ์ ์์ต๋๋ค: | |
| ```py | |
| >>> tokenized_wnut = wnut.map(tokenize_and_align_labels, batched=True) | |
| ``` | |
| ์ด์ [`DataCollatorWithPadding`]๋ฅผ ์ฌ์ฉํ์ฌ ์์ ๋ฐฐ์น๋ฅผ ๋ง๋ค์ด๋ด ์๋ค. ๋ฐ์ดํฐ ์ธํธ ์ ์ฒด๋ฅผ ์ต๋ ๊ธธ์ด๋ก ํจ๋ฉํ๋ ๋์ , *๋์ ํจ๋ฉ*์ ์ฌ์ฉํ์ฌ ๋ฐฐ์น์์ ๊ฐ์ฅ ๊ธด ๊ธธ์ด์ ๋ง๊ฒ ๋ฌธ์ฅ์ ํจ๋ฉํ๋ ๊ฒ์ด ํจ์จ์ ์ ๋๋ค. | |
| <frameworkcontent> | |
| <pt> | |
| ```py | |
| >>> from transformers import DataCollatorForTokenClassification | |
| >>> data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer) | |
| ``` | |
| </pt> | |
| <tf> | |
| ```py | |
| >>> from transformers import DataCollatorForTokenClassification | |
| >>> data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, return_tensors="tf") | |
| ``` | |
| </tf> | |
| </frameworkcontent> | |
| ## ํ๊ฐ[[evaluation]] | |
| ํ๋ จ ์ค ๋ชจ๋ธ์ ์ฑ๋ฅ์ ํ๊ฐํ๊ธฐ ์ํด ํ๊ฐ ์งํ๋ฅผ ํฌํจํ๋ ๊ฒ์ด ์ ์ฉํฉ๋๋ค. ๐ค [Evaluate](https://huggingface.co/docs/evaluate/index) ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํ์ฌ ๋น ๋ฅด๊ฒ ํ๊ฐ ๋ฐฉ๋ฒ์ ๊ฐ์ ธ์ฌ ์ ์์ต๋๋ค. ์ด ์์ ์์๋ [seqeval](https://huggingface.co/spaces/evaluate-metric/seqeval) ํ๊ฐ ์งํ๋ฅผ ๊ฐ์ ธ์ต๋๋ค. (ํ๊ฐ ์งํ๋ฅผ ๊ฐ์ ธ์ค๊ณ ๊ณ์ฐํ๋ ๋ฐฉ๋ฒ์ ๋ํด์๋ ๐ค Evaluate [๋น ๋ฅธ ๋๋ฌ๋ณด๊ธฐ](https://huggingface.co/docs/evaluate/a_quick_tour)๋ฅผ ์ฐธ์กฐํ์ธ์). Seqeval์ ์ค์ ๋ก ์ ๋ฐ๋, ์ฌํ๋ฅ , F1 ๋ฐ ์ ํ๋์ ๊ฐ์ ์ฌ๋ฌ ์ ์๋ฅผ ์ฐ์ถํฉ๋๋ค. | |
| ```py | |
| >>> import evaluate | |
| >>> seqeval = evaluate.load("seqeval") | |
| ``` | |
| ๋จผ์ NER ๋ ์ด๋ธ์ ๊ฐ์ ธ์จ ๋ค์, [`~evaluate.EvaluationModule.compute`]์ ์ค์ ์์ธก๊ณผ ์ค์ ๋ ์ด๋ธ์ ์ ๋ฌํ์ฌ ์ ์๋ฅผ ๊ณ์ฐํ๋ ํจ์๋ฅผ ๋ง๋ญ๋๋ค: | |
| ```py | |
| >>> import numpy as np | |
| >>> labels = [label_list[i] for i in example[f"ner_tags"]] | |
| >>> def compute_metrics(p): | |
| ... predictions, labels = p | |
| ... predictions = np.argmax(predictions, axis=2) | |
| ... true_predictions = [ | |
| ... [label_list[p] for (p, l) in zip(prediction, label) if l != -100] | |
| ... for prediction, label in zip(predictions, labels) | |
| ... ] | |
| ... true_labels = [ | |
| ... [label_list[l] for (p, l) in zip(prediction, label) if l != -100] | |
| ... for prediction, label in zip(predictions, labels) | |
| ... ] | |
| ... results = seqeval.compute(predictions=true_predictions, references=true_labels) | |
| ... return { | |
| ... "precision": results["overall_precision"], | |
| ... "recall": results["overall_recall"], | |
| ... "f1": results["overall_f1"], | |
| ... "accuracy": results["overall_accuracy"], | |
| ... } | |
| ``` | |
| ์ด์ `compute_metrics` ํจ์๋ฅผ ์ฌ์ฉํ ์ค๋น๊ฐ ๋์์ผ๋ฉฐ, ํ๋ จ์ ์ค์ ํ๋ฉด ์ด ํจ์๋ก ๋๋์์ฌ ๊ฒ์ ๋๋ค. | |
| ## ํ๋ จ[[train]] | |
| ๋ชจ๋ธ์ ํ๋ จํ๊ธฐ ์ ์, `id2label`์ `label2id`๋ฅผ ์ฌ์ฉํ์ฌ ์์๋๋ id์ ๋ ์ด๋ธ์ ๋งต์ ์์ฑํ์ธ์: | |
| ```py | |
| >>> id2label = { | |
| ... 0: "O", | |
| ... 1: "B-corporation", | |
| ... 2: "I-corporation", | |
| ... 3: "B-creative-work", | |
| ... 4: "I-creative-work", | |
| ... 5: "B-group", | |
| ... 6: "I-group", | |
| ... 7: "B-location", | |
| ... 8: "I-location", | |
| ... 9: "B-person", | |
| ... 10: "I-person", | |
| ... 11: "B-product", | |
| ... 12: "I-product", | |
| ... } | |
| >>> label2id = { | |
| ... "O": 0, | |
| ... "B-corporation": 1, | |
| ... "I-corporation": 2, | |
| ... "B-creative-work": 3, | |
| ... "I-creative-work": 4, | |
| ... "B-group": 5, | |
| ... "I-group": 6, | |
| ... "B-location": 7, | |
| ... "I-location": 8, | |
| ... "B-person": 9, | |
| ... "I-person": 10, | |
| ... "B-product": 11, | |
| ... "I-product": 12, | |
| ... } | |
| ``` | |
| <frameworkcontent> | |
| <pt> | |
| <Tip> | |
| [`Trainer`]๋ฅผ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ํ์ธ ํ๋ํ๋ ๋ฐฉ๋ฒ์ ์ต์ํ์ง ์์ ๊ฒฝ์ฐ, [์ฌ๊ธฐ](../training#train-with-pytorch-trainer)์์ ๊ธฐ๋ณธ ํํ ๋ฆฌ์ผ์ ํ์ธํ์ธ์! | |
| </Tip> | |
| ์ด์ ๋ชจ๋ธ์ ํ๋ จ์ํฌ ์ค๋น๊ฐ ๋์์ต๋๋ค! [`AutoModelForSequenceClassification`]๋ก DistilBERT๋ฅผ ๊ฐ์ ธ์ค๊ณ ์์๋๋ ๋ ์ด๋ธ ์์ ๋ ์ด๋ธ ๋งคํ์ ์ง์ ํ์ธ์: | |
| ```py | |
| >>> from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer | |
| >>> model = AutoModelForTokenClassification.from_pretrained( | |
| ... "distilbert/distilbert-base-uncased", num_labels=13, id2label=id2label, label2id=label2id | |
| ... ) | |
| ``` | |
| ์ด์ ์ธ ๋จ๊ณ๋ง ๊ฑฐ์น๋ฉด ๋์ ๋๋ค: | |
| 1. [`TrainingArguments`]์์ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ์ ์ํ์ธ์. `output_dir`๋ ๋ชจ๋ธ์ ์ ์ฅํ ์์น๋ฅผ ์ง์ ํ๋ ์ ์ผํ ๋งค๊ฐ๋ณ์์ ๋๋ค. ์ด ๋ชจ๋ธ์ ํ๋ธ์ ์ ๋ก๋ํ๊ธฐ ์ํด `push_to_hub=True`๋ฅผ ์ค์ ํฉ๋๋ค(๋ชจ๋ธ์ ์ ๋ก๋ํ๊ธฐ ์ํด Hugging Face์ ๋ก๊ทธ์ธํด์ผํฉ๋๋ค.) ๊ฐ ์ํญ์ด ๋๋ ๋๋ง๋ค, [`Trainer`]๋ seqeval ์ ์๋ฅผ ํ๊ฐํ๊ณ ํ๋ จ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ ์ฅํฉ๋๋ค. | |
| 2. [`Trainer`]์ ํ๋ จ ์ธ์์ ๋ชจ๋ธ, ๋ฐ์ดํฐ ์ธํธ, ํ ํฌ๋์ด์ , ๋ฐ์ดํฐ ์ฝ๋ ์ดํฐ ๋ฐ `compute_metrics` ํจ์๋ฅผ ์ ๋ฌํ์ธ์. | |
| 3. [`~Trainer.train`]๋ฅผ ํธ์ถํ์ฌ ๋ชจ๋ธ์ ํ์ธ ํ๋ํ์ธ์. | |
| ```py | |
| >>> training_args = TrainingArguments( | |
| ... output_dir="my_awesome_wnut_model", | |
| ... learning_rate=2e-5, | |
| ... per_device_train_batch_size=16, | |
| ... per_device_eval_batch_size=16, | |
| ... num_train_epochs=2, | |
| ... weight_decay=0.01, | |
| ... evaluation_strategy="epoch", | |
| ... save_strategy="epoch", | |
| ... load_best_model_at_end=True, | |
| ... push_to_hub=True, | |
| ... ) | |
| >>> trainer = Trainer( | |
| ... model=model, | |
| ... args=training_args, | |
| ... train_dataset=tokenized_wnut["train"], | |
| ... eval_dataset=tokenized_wnut["test"], | |
| ... tokenizer=tokenizer, | |
| ... data_collator=data_collator, | |
| ... compute_metrics=compute_metrics, | |
| ... ) | |
| >>> trainer.train() | |
| ``` | |
| ํ๋ จ์ด ์๋ฃ๋๋ฉด, [`~transformers.Trainer.push_to_hub`] ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ํ๋ธ์ ๊ณต์ ํ ์ ์์ต๋๋ค. | |
| ```py | |
| >>> trainer.push_to_hub() | |
| ``` | |
| </pt> | |
| <tf> | |
| <Tip> | |
| Keras๋ฅผ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ํ์ธ ํ๋ํ๋ ๋ฐฉ๋ฒ์ ์ต์ํ์ง ์์ ๊ฒฝ์ฐ, [์ฌ๊ธฐ](../training#train-a-tensorflow-model-with-keras)์ ๊ธฐ๋ณธ ํํ ๋ฆฌ์ผ์ ํ์ธํ์ธ์! | |
| </Tip> | |
| TensorFlow์์ ๋ชจ๋ธ์ ํ์ธ ํ๋ํ๋ ค๋ฉด, ๋จผ์ ์ตํฐ๋ง์ด์ ํจ์์ ํ์ต๋ฅ ์ค์ผ์ฅด, ๊ทธ๋ฆฌ๊ณ ์ผ๋ถ ํ๋ จ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ์ค์ ํด์ผ ํฉ๋๋ค: | |
| ```py | |
| >>> from transformers import create_optimizer | |
| >>> batch_size = 16 | |
| >>> num_train_epochs = 3 | |
| >>> num_train_steps = (len(tokenized_wnut["train"]) // batch_size) * num_train_epochs | |
| >>> optimizer, lr_schedule = create_optimizer( | |
| ... init_lr=2e-5, | |
| ... num_train_steps=num_train_steps, | |
| ... weight_decay_rate=0.01, | |
| ... num_warmup_steps=0, | |
| ... ) | |
| ``` | |
| ๊ทธ๋ฐ ๋ค์ [`TFAutoModelForSequenceClassification`]์ ์ฌ์ฉํ์ฌ DistilBERT๋ฅผ ๊ฐ์ ธ์ค๊ณ , ์์๋๋ ๋ ์ด๋ธ ์์ ๋ ์ด๋ธ ๋งคํ์ ์ง์ ํฉ๋๋ค: | |
| ```py | |
| >>> from transformers import TFAutoModelForTokenClassification | |
| >>> model = TFAutoModelForTokenClassification.from_pretrained( | |
| ... "distilbert/distilbert-base-uncased", num_labels=13, id2label=id2label, label2id=label2id | |
| ... ) | |
| ``` | |
| [`~transformers.TFPreTrainedModel.prepare_tf_dataset`]์ ์ฌ์ฉํ์ฌ ๋ฐ์ดํฐ ์ธํธ๋ฅผ `tf.data.Dataset` ํ์์ผ๋ก ๋ณํํฉ๋๋ค: | |
| ```py | |
| >>> tf_train_set = model.prepare_tf_dataset( | |
| ... tokenized_wnut["train"], | |
| ... shuffle=True, | |
| ... batch_size=16, | |
| ... collate_fn=data_collator, | |
| ... ) | |
| >>> tf_validation_set = model.prepare_tf_dataset( | |
| ... tokenized_wnut["validation"], | |
| ... shuffle=False, | |
| ... batch_size=16, | |
| ... collate_fn=data_collator, | |
| ... ) | |
| ``` | |
| [`compile`](https://keras.io/api/models/model_training_apis/#compile-method)๋ฅผ ์ฌ์ฉํ์ฌ ํ๋ จํ ๋ชจ๋ธ์ ๊ตฌ์ฑํฉ๋๋ค: | |
| ```py | |
| >>> import tensorflow as tf | |
| >>> model.compile(optimizer=optimizer) | |
| ``` | |
| ํ๋ จ์ ์์ํ๊ธฐ ์ ์ ์ค์ ํด์ผํ ๋ง์ง๋ง ๋ ๊ฐ์ง๋ ์์ธก์์ seqeval ์ ์๋ฅผ ๊ณ์ฐํ๊ณ , ๋ชจ๋ธ์ ํ๋ธ์ ์ ๋ก๋ํ ๋ฐฉ๋ฒ์ ์ ๊ณตํ๋ ๊ฒ์ ๋๋ค. ๋ชจ๋ [Keras callbacks](../main_classes/keras_callbacks)๋ฅผ ์ฌ์ฉํ์ฌ ์ํ๋ฉ๋๋ค. | |
| [`~transformers.KerasMetricCallback`]์ `compute_metrics` ํจ์๋ฅผ ์ ๋ฌํ์ธ์: | |
| ```py | |
| >>> from transformers.keras_callbacks import KerasMetricCallback | |
| >>> metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_validation_set) | |
| ``` | |
| [`~transformers.PushToHubCallback`]์์ ๋ชจ๋ธ๊ณผ ํ ํฌ๋์ด์ ๋ฅผ ์ ๋ก๋ํ ์์น๋ฅผ ์ง์ ํฉ๋๋ค: | |
| ```py | |
| >>> from transformers.keras_callbacks import PushToHubCallback | |
| >>> push_to_hub_callback = PushToHubCallback( | |
| ... output_dir="my_awesome_wnut_model", | |
| ... tokenizer=tokenizer, | |
| ... ) | |
| ``` | |
| ๊ทธ๋ฐ ๋ค์ ์ฝ๋ฐฑ์ ํจ๊ป ๋ฌถ์ต๋๋ค: | |
| ```py | |
| >>> callbacks = [metric_callback, push_to_hub_callback] | |
| ``` | |
| ๋๋์ด, ๋ชจ๋ธ ํ๋ จ์ ์์ํ ์ค๋น๊ฐ ๋์์ต๋๋ค! [`fit`](https://keras.io/api/models/model_training_apis/#fit-method)์ ํ๋ จ ๋ฐ์ดํฐ ์ธํธ, ๊ฒ์ฆ ๋ฐ์ดํฐ ์ธํธ, ์ํญ์ ์ ๋ฐ ์ฝ๋ฐฑ์ ์ ๋ฌํ์ฌ ํ์ธ ํ๋ํฉ๋๋ค: | |
| ```py | |
| >>> model.fit(x=tf_train_set, validation_data=tf_validation_set, epochs=3, callbacks=callbacks) | |
| ``` | |
| ํ๋ จ์ด ์๋ฃ๋๋ฉด, ๋ชจ๋ธ์ด ์๋์ผ๋ก ํ๋ธ์ ์ ๋ก๋๋์ด ๋๊ตฌ๋ ์ฌ์ฉํ ์ ์์ต๋๋ค! | |
| </tf> | |
| </frameworkcontent> | |
| <Tip> | |
| ํ ํฐ ๋ถ๋ฅ๋ฅผ ์ํ ๋ชจ๋ธ์ ํ์ธ ํ๋ํ๋ ์์ธํ ์์ ๋ ๋ค์ | |
| [PyTorch notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/token_classification.ipynb) | |
| ๋๋ [TensorFlow notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/token_classification-tf.ipynb)๋ฅผ ์ฐธ์กฐํ์ธ์. | |
| </Tip> | |
| ## ์ถ๋ก [[inference]] | |
| ์ข์์, ์ด์ ๋ชจ๋ธ์ ํ์ธ ํ๋ํ์ผ๋ ์ถ๋ก ์ ์ฌ์ฉํ ์ ์์ต๋๋ค! | |
| ์ถ๋ก ์ ์ํํ๊ณ ์ ํ๋ ํ ์คํธ๋ฅผ ๊ฐ์ ธ์๋ด ์๋ค: | |
| ```py | |
| >>> text = "The Golden State Warriors are an American professional basketball team based in San Francisco." | |
| ``` | |
| ํ์ธ ํ๋๋ ๋ชจ๋ธ๋ก ์ถ๋ก ์ ์๋ํ๋ ๊ฐ์ฅ ๊ฐ๋จํ ๋ฐฉ๋ฒ์ [`pipeline`]๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ ๋๋ค. ๋ชจ๋ธ๋ก NER์ `pipeline`์ ์ธ์คํด์คํํ๊ณ , ํ ์คํธ๋ฅผ ์ ๋ฌํด๋ณด์ธ์: | |
| ```py | |
| >>> from transformers import pipeline | |
| >>> classifier = pipeline("ner", model="stevhliu/my_awesome_wnut_model") | |
| >>> classifier(text) | |
| [{'entity': 'B-location', | |
| 'score': 0.42658573, | |
| 'index': 2, | |
| 'word': 'golden', | |
| 'start': 4, | |
| 'end': 10}, | |
| {'entity': 'I-location', | |
| 'score': 0.35856336, | |
| 'index': 3, | |
| 'word': 'state', | |
| 'start': 11, | |
| 'end': 16}, | |
| {'entity': 'B-group', | |
| 'score': 0.3064001, | |
| 'index': 4, | |
| 'word': 'warriors', | |
| 'start': 17, | |
| 'end': 25}, | |
| {'entity': 'B-location', | |
| 'score': 0.65523505, | |
| 'index': 13, | |
| 'word': 'san', | |
| 'start': 80, | |
| 'end': 83}, | |
| {'entity': 'B-location', | |
| 'score': 0.4668663, | |
| 'index': 14, | |
| 'word': 'francisco', | |
| 'start': 84, | |
| 'end': 93}] | |
| ``` | |
| ์ํ๋ค๋ฉด, `pipeline`์ ๊ฒฐ๊ณผ๋ฅผ ์๋์ผ๋ก ๋ณต์ ํ ์๋ ์์ต๋๋ค: | |
| <frameworkcontent> | |
| <pt> | |
| ํ ์คํธ๋ฅผ ํ ํฐํํ๊ณ PyTorch ํ ์๋ฅผ ๋ฐํํฉ๋๋ค: | |
| ```py | |
| >>> from transformers import AutoTokenizer | |
| >>> tokenizer = AutoTokenizer.from_pretrained("stevhliu/my_awesome_wnut_model") | |
| >>> inputs = tokenizer(text, return_tensors="pt") | |
| ``` | |
| ์ ๋ ฅ์ ๋ชจ๋ธ์ ์ ๋ฌํ๊ณ `logits`์ ๋ฐํํฉ๋๋ค: | |
| ```py | |
| >>> from transformers import AutoModelForTokenClassification | |
| >>> model = AutoModelForTokenClassification.from_pretrained("stevhliu/my_awesome_wnut_model") | |
| >>> with torch.no_grad(): | |
| ... logits = model(**inputs).logits | |
| ``` | |
| ๊ฐ์ฅ ๋์ ํ๋ฅ ์ ๊ฐ์ง ํด๋์ค๋ฅผ ๋ชจ๋ธ์ `id2label` ๋งคํ์ ์ฌ์ฉํ์ฌ ํ ์คํธ ๋ ์ด๋ธ๋ก ๋ณํํฉ๋๋ค: | |
| ```py | |
| >>> predictions = torch.argmax(logits, dim=2) | |
| >>> predicted_token_class = [model.config.id2label[t.item()] for t in predictions[0]] | |
| >>> predicted_token_class | |
| ['O', | |
| 'O', | |
| 'B-location', | |
| 'I-location', | |
| 'B-group', | |
| 'O', | |
| 'O', | |
| 'O', | |
| 'O', | |
| 'O', | |
| 'O', | |
| 'O', | |
| 'O', | |
| 'B-location', | |
| 'B-location', | |
| 'O', | |
| 'O'] | |
| ``` | |
| </pt> | |
| <tf> | |
| ํ ์คํธ๋ฅผ ํ ํฐํํ๊ณ TensorFlow ํ ์๋ฅผ ๋ฐํํฉ๋๋ค: | |
| ```py | |
| >>> from transformers import AutoTokenizer | |
| >>> tokenizer = AutoTokenizer.from_pretrained("stevhliu/my_awesome_wnut_model") | |
| >>> inputs = tokenizer(text, return_tensors="tf") | |
| ``` | |
| ์ ๋ ฅ๊ฐ์ ๋ชจ๋ธ์ ์ ๋ฌํ๊ณ `logits`์ ๋ฐํํฉ๋๋ค: | |
| ```py | |
| >>> from transformers import TFAutoModelForTokenClassification | |
| >>> model = TFAutoModelForTokenClassification.from_pretrained("stevhliu/my_awesome_wnut_model") | |
| >>> logits = model(**inputs).logits | |
| ``` | |
| ๊ฐ์ฅ ๋์ ํ๋ฅ ์ ๊ฐ์ง ํด๋์ค๋ฅผ ๋ชจ๋ธ์ `id2label` ๋งคํ์ ์ฌ์ฉํ์ฌ ํ ์คํธ ๋ ์ด๋ธ๋ก ๋ณํํฉ๋๋ค: | |
| ```py | |
| >>> predicted_token_class_ids = tf.math.argmax(logits, axis=-1) | |
| >>> predicted_token_class = [model.config.id2label[t] for t in predicted_token_class_ids[0].numpy().tolist()] | |
| >>> predicted_token_class | |
| ['O', | |
| 'O', | |
| 'B-location', | |
| 'I-location', | |
| 'B-group', | |
| 'O', | |
| 'O', | |
| 'O', | |
| 'O', | |
| 'O', | |
| 'O', | |
| 'O', | |
| 'O', | |
| 'B-location', | |
| 'B-location', | |
| 'O', | |
| 'O'] | |
| ``` | |
| </tf> | |
| </frameworkcontent> | |