BART[[bart]]
๊ฐ์ [[overview]]
Bart ๋ชจ๋ธ์ 2019๋ 10์ 29์ผ Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov, Luke Zettlemoyer๊ฐ ๋ฐํํ BART: ์์ฐ์ด ์์ฑ, ๋ฒ์ญ, ์ดํด๋ฅผ ์ํ ์ก์ ์ ๊ฑฐ seq2seq ์ฌ์ ํ๋ จ์ด๋ผ๋ ๋ ผ๋ฌธ์์ ์๊ฐ๋์์ต๋๋ค.
๋ ผ๋ฌธ์ ์ด๋ก์ ๋ฐ๋ฅด๋ฉด,
- Bart๋ ์๋ฐฉํฅ ์ธ์ฝ๋(BERT์ ์ ์ฌ)์ ์ผ์ชฝ์์ ์ค๋ฅธ์ชฝ์ผ๋ก ๋์ฝ๋ฉํ๋ ๋์ฝ๋(GPT์ ์ ์ฌ)๋ฅผ ์ฌ์ฉํ๋ ํ์ค seq2seq/๊ธฐ๊ณ ๋ฒ์ญ ์ํคํ ์ฒ๋ฅผ ์ฌ์ฉํฉ๋๋ค.
- ์ฌ์ ํ๋ จ ์์ ์ ์๋ ๋ฌธ์ฅ์ ์์๋ฅผ ๋ฌด์์๋ก ์๊ณ , ํ ์คํธ์ ์ผ๋ถ ๊ตฌ๊ฐ์ ๋จ์ผ ๋ง์คํฌ ํ ํฐ์ผ๋ก ๋์ฒดํ๋ ์๋ก์ด ์ธํ๋ง(in-filling) ๋ฐฉ์์ ํฌํจํฉ๋๋ค.
- BART๋ ํนํ ํ ์คํธ ์์ฑ์ ์ํ ๋ฏธ์ธ ์กฐ์ ์ ํจ๊ณผ์ ์ด์ง๋ง ์ดํด ์์ ์๋ ์ ์๋ํฉ๋๋ค. GLUE์ SQuAD์์ ๋น์ทํ ํ๋ จ ๋ฆฌ์์ค๋ก RoBERTa์ ์ฑ๋ฅ๊ณผ ์ผ์นํ๋ฉฐ, ์ถ์์ ๋ํ, ์ง์์๋ต, ์์ฝ ์์ ๋ฑ์์ ์ต๋ 6 ROUGE ์ ์์ ํฅ์์ ๋ณด์ด๋ฉฐ ์๋ก์ด ์ต๊ณ ์ฑ๋ฅ์ ๋ฌ์ฑํ์ต๋๋ค.
์ด ๋ชจ๋ธ์ sshleifer์ ์ํด ๊ธฐ์ฌ ๋์์ต๋๋ค. ์ ์์ ์ฝ๋๋ ์ด๊ณณ์์ ํ์ธํ ์ ์์ต๋๋ค.
์ฌ์ฉ ํ:[[usage-tips]]
- BART๋ ์ ๋ ์์น ์๋ฒ ๋ฉ์ ์ฌ์ฉํ๋ ๋ชจ๋ธ์ด๋ฏ๋ก ์ผ๋ฐ์ ์ผ๋ก ์ ๋ ฅ์ ์ผ์ชฝ๋ณด๋ค๋ ์ค๋ฅธ์ชฝ์ ํจ๋ฉํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
- ์ธ์ฝ๋์ ๋์ฝ๋๊ฐ ์๋ seq2seq ๋ชจ๋ธ์ ๋๋ค. ์ธ์ฝ๋์๋ ์์๋ ํ ํฐ์ด(corrupted tokens) ์ ๋ ฅ๋๊ณ , ๋์ฝ๋์๋ ์๋ ํ ํฐ์ด ์ ๋ ฅ๋ฉ๋๋ค(๋จ, ์ผ๋ฐ์ ์ธ ํธ๋์คํฌ๋จธ ๋์ฝ๋์ฒ๋ผ ๋ฏธ๋ ๋จ์ด๋ฅผ ์จ๊ธฐ๋ ๋ง์คํฌ๊ฐ ์์ต๋๋ค). ์ฌ์ ํ๋ จ ์์ ์์ ์ธ์ฝ๋์ ์ ์ฉ๋๋ ๋ณํ๋ค์ ๊ตฌ์ฑ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
์ฌ์ ํ๋ จ ์์ ์์ ์ธ์ฝ๋์ ์ ์ฉ๋๋ ๋ณํ๋ค์ ๊ตฌ์ฑ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
- ๋ฌด์์ ํ ํฐ ๋ง์คํน (BERT ์ฒ๋ผ)
- ๋ฌด์์ ํ ํฐ ์ญ์
- k๊ฐ ํ ํฐ์ ๋ฒ์๋ฅผ ๋จ์ผ ๋ง์คํฌ ํ ํฐ์ผ๋ก ๋ง์คํน (0๊ฐ ํ ํฐ์ ๋ฒ์๋ ๋ง์คํฌ ํ ํฐ์ ์ฝ์ ์ ์๋ฏธ)
- ๋ฌธ์ฅ ์์ ๋ค์๊ธฐ
- ํน์ ํ ํฐ์์ ์์ํ๋๋ก ๋ฌธ์ ํ์
๊ตฌํ ๋ ธํธ[[implementation-notes]]
- Bart๋ ์ํ์ค ๋ถ๋ฅ์
token_type_ids๋ฅผ ์ฌ์ฉํ์ง ์์ต๋๋ค. ์ ์ ํ๊ฒ ๋๋๊ธฐ ์ํด์ [BartTokenizer]๋ [~BartTokenizer.encode]๋ฅผ ์ฌ์ฉํฉ๋๋ค. - [
BartModel]์ ์ ๋ฐฉํฅ ์ ๋ฌ์decoder_input_ids๊ฐ ์ ๋ฌ๋์ง ์์ผ๋ฉดdecoder_input_ids๋ฅผ ์๋์ผ๋ก ์์ฑํ ๊ฒ์ ๋๋ค. ์ด๋ ๋ค๋ฅธ ์ผ๋ถ ๋ชจ๋ธ๋ง API์ ๋ค๋ฅธ ์ ์ ๋๋ค. ์ด ๊ธฐ๋ฅ์ ์ผ๋ฐ์ ์ธ ์ฌ์ฉ ์ฌ๋ก๋ ๋ง์คํฌ ์ฑ์ฐ๊ธฐ(mask filling)์ ๋๋ค. - ๋ชจ๋ธ ์์ธก์
forced_bos_token_id=0์ผ ๋ ๊ธฐ์กด ๊ตฌํ๊ณผ ๋์ผํ๊ฒ ์๋ํ๋๋ก ์๋๋์์ต๋๋ค. ํ์ง๋ง, [fairseq.encode]์ ์ ๋ฌํ๋ ๋ฌธ์์ด์ด ๊ณต๋ฐฑ์ผ๋ก ์์ํ ๋๋ง ์ด ๊ธฐ๋ฅ์ด ์๋ํฉ๋๋ค. - [
~generation.GenerationMixin.generate]๋ ์์ฝ๊ณผ ๊ฐ์ ์กฐ๊ฑด๋ถ ์์ฑ ์์ ์ ์ฌ์ฉ๋์ด์ผ ํฉ๋๋ค. ์์ธํ ๋ด์ฉ์ ํด๋น ๋ฌธ์์ ์์ ๋ฅผ ์ฐธ์กฐํ์ธ์. - facebook/bart-large-cnn ๊ฐ์ค์น๋ฅผ ๋ก๋ํ๋ ๋ชจ๋ธ์
mask_token_id๊ฐ ์๊ฑฐ๋, ๋ง์คํฌ ์ฑ์ฐ๊ธฐ ์์ ์ ์ํํ ์ ์์ต๋๋ค.
๋ง์คํฌ ์ฑ์ฐ๊ธฐ[[mask-filling]]
facebook/bart-base์ facebook/bart-large ์ฒดํฌํฌ์ธํธ๋ ๋ฉํฐ ํ ํฐ ๋ง์คํฌ๋ฅผ ์ฑ์ฐ๋๋ฐ ์ฌ์ฉ๋ ์ ์์ต๋๋ค.
from transformers import BartForConditionalGeneration, BartTokenizer
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large", forced_bos_token_id=0)
tok = BartTokenizer.from_pretrained("facebook/bart-large")
example_english_phrase = "UN Chief Says There Is No <mask> in Syria"
batch = tok(example_english_phrase, return_tensors="pt")
generated_ids = model.generate(batch["input_ids"])
assert tok.batch_decode(generated_ids, skip_special_tokens=True) == [
"UN Chief Says There Is No Plan to Stop Chemical Weapons in Syria"
]
์๋ฃ[[resources]]
BART๋ฅผ ์์ํ๋ ๋ฐ ๋์์ด ๋๋ Hugging Face์ community ์๋ฃ ๋ชฉ๋ก(๐๋ก ํ์๋จ) ์ ๋๋ค. ์ฌ๊ธฐ์ ํฌํจ๋ ์๋ฃ๋ฅผ ์ ์ถํ๊ณ ์ถ์ผ์๋ค๋ฉด PR(Pull Request)๋ฅผ ์ด์ด์ฃผ์ธ์. ๋ฆฌ๋ทฐ ํด๋๋ฆฌ๊ฒ ์ต๋๋ค! ์๋ฃ๋ ๊ธฐ์กด ์๋ฃ๋ฅผ ๋ณต์ ํ๋ ๋์ ์๋ก์ด ๋ด์ฉ์ ๋ด๊ณ ์์ด์ผ ํฉ๋๋ค.
- ๋ถ์ฐํ ํ์ต: ๐ค Transformers์ Amazon SageMaker๋ฅผ ์ด์ฉํ์ฌ ์์ฝํ๊ธฐ ์ํ BART/T5 ํ์ต์ ๋ํ ๋ธ๋ก๊ทธ ํฌ์คํธ.
- blurr๋ฅผ ์ด์ฉํ์ฌ fastai๋ก ์์ฝํ๊ธฐ ์ํ BART๋ฅผ ๋ฏธ์ธ ์กฐ์ ํ๋ ๋ฐฉ๋ฒ์ ๋ํ ๋ ธํธ๋ถ. ๐
- Trainer ํด๋์ค๋ฅผ ์ฌ์ฉํ์ฌ ๋ ๊ฐ์ง ์ธ์ด๋ก ์์ฝํ๊ธฐ ์ํ BART ๋ฏธ์ธ ์กฐ์ ํ๋ ๋ฐฉ๋ฒ์ ๋ํ ๋ ธํธ๋ถ. ๐
- ์ด ์์ ์คํฌ๋ฆฝํธ์ ๋
ธํธ๋ถ์์๋ [
BartForConditionalGeneration]์ด ์ง์๋ฉ๋๋ค. - [
TFBartForConditionalGeneration]๋ ์ด ์์ ์คํฌ๋ฆฝํธ์ ๋ ธํธ๋ถ์์ ์ง์๋ฉ๋๋ค. - ์ด ์์ ์คํฌ๋ฆฝํธ์์๋[
FlaxBartForConditionalGeneration]์ด ์ง์๋ฉ๋๋ค. - Hugging Face
datasets๊ฐ์ฒด๋ฅผ ํ์ฉํ์ฌ [BartForConditionalGeneration]์ ํ์ต์ํค๋ ๋ฐฉ๋ฒ์ ์๋ ์ด ํฌ๋ผ ํ ๋ก ์์ ์ฐพ์ ์ ์์ต๋๋ค. - ๐ค Hugging Face ์ฝ์ค์ ์์ฝ์ฅ.
- ์์ฝ ์์ ๊ฐ์ด๋
- [
BartForConditionalGeneration]๋ ์ด ์์ ์คํฌ๋ฆฝํธ์ ๋ ธํธ๋ถ์ ์ฐธ๊ณ ํ์ธ์. - [
TFBartForConditionalGeneration]๋ ์ด ์์ ์คํฌ๋ฆฝํธ์ ๋ ธํธ๋ถ์ ์ฐธ๊ณ ํ์ธ์. - [
FlaxBartForConditionalGeneration]๋ ์ด ์์ ์คํฌ๋ฆฝํธ์ ๋ ธํธ๋ถ์ ์ฐธ๊ณ ํ์ธ์. - ๐ค Hugging Face ์ฝ์ค์ ๋ง์คํฌ ์ธ์ด ๋ชจ๋ธ๋ง ์ฑํฐ.
- ๋ง์คํฌ ์ธ์ด ๋ชจ๋ธ๋ง ์์ ๊ฐ์ด๋
- Seq2SeqTrainer๋ฅผ ์ด์ฉํ์ฌ ์ธ๋์ด๋ฅผ ์์ด๋ก ๋ฒ์ญํ๋ mBART๋ฅผ ๋ฏธ์ธ ์กฐ์ ํ๋ ๋ฐฉ๋ฒ์ ๋ํ ๊ฐ์ด๋. ๐
- [
BartForConditionalGeneration]๋ ์ด ์์ ์คํฌ๋ฆฝํธ์ ๋ ธํธ๋ถ์ ์ฐธ๊ณ ํ์ธ์. - [
TFBartForConditionalGeneration]๋ ์ด ์์ ์คํฌ๋ฆฝํธ์ ๋ ธํธ๋ถ์ ์ฐธ๊ณ ํ์ธ์. - ๋ฒ์ญ ์์ ๊ฐ์ด๋
์ถ๊ฐ์ ์ผ๋ก ๋ณผ ๊ฒ๋ค:
- ํ ์คํธ ๋ถ๋ฅ ์์ ๊ฐ์ด๋
- ์ง๋ฌธ ๋ต๋ณ ์์ ๊ฐ์ด๋
- ์ธ๊ณผ์ ์ธ์ด ๋ชจ๋ธ๋ง ์์ ๊ฐ์ด๋
- ์ด ๋ ผ๋ฌธ์ ์ฆ๋ฅ๋ ์ฒดํฌํฌ์ธํธ์ ๋ํด ์ค๋ช ํฉ๋๋ค.
BartConfig[[transformers.BartConfig]]
[[autodoc]] BartConfig - all
BartTokenizer[[transformers.BartTokenizer]]
[[autodoc]] BartTokenizer - all
BartTokenizerFast[[transformers.BartTokenizerFast]]
[[autodoc]] BartTokenizerFast - all
BartModel[[transformers.BartModel]]
[[autodoc]] BartModel - forward
BartForConditionalGeneration[[transformers.BartForConditionalGeneration]]
[[autodoc]] BartForConditionalGeneration - forward
BartForSequenceClassification[[transformers.BartForSequenceClassification]]
[[autodoc]] BartForSequenceClassification - forward
BartForQuestionAnswering[[transformers.BartForQuestionAnswering]]
[[autodoc]] BartForQuestionAnswering - forward
BartForCausalLM[[transformers.BartForCausalLM]]
[[autodoc]] BartForCausalLM - forward
TFBartModel[[transformers.TFBartModel]]
[[autodoc]] TFBartModel - call
TFBartForConditionalGeneration[[transformers.TFBartForConditionalGeneration]]
[[autodoc]] TFBartForConditionalGeneration - call
TFBartForSequenceClassification[[transformers.TFBartForSequenceClassification]]
[[autodoc]] TFBartForSequenceClassification - call
FlaxBartModel[[transformers.FlaxBartModel]]
[[autodoc]] FlaxBartModel - call - encode - decode
FlaxBartForConditionalGeneration[[transformers.FlaxBartForConditionalGeneration]]
[[autodoc]] FlaxBartForConditionalGeneration - call - encode - decode
FlaxBartForSequenceClassification[[transformers.FlaxBartForSequenceClassification]]
[[autodoc]] FlaxBartForSequenceClassification - call - encode - decode
FlaxBartForQuestionAnswering[[transformers.FlaxBartForQuestionAnswering]]
[[autodoc]] FlaxBartForQuestionAnswering - call - encode - decode
FlaxBartForCausalLM[[transformers.FlaxBartForCausalLM]]
[[autodoc]] FlaxBartForCausalLM - call