Encoder-Decoder Models[[Encoder Decoder Models]]
Overview[[Overview]]
[EncoderDecoderModel] can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the encoder and any pretrained autoregressive model as the decoder.
The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation tasks was demonstrated in Leveraging Pre-trained Checkpoints for Sequence Generation Tasks by Sascha Rothe, Shashi Narayan, and Aliaksei Severyn.
After such an [EncoderDecoderModel] has been trained or fine-tuned, it can be saved and loaded just like any other model. See the examples for details.
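For example, a trained model can be saved and reloaded like this (a minimal sketch; it assumes `model` is a trained [EncoderDecoderModel], and "./my-bert2bert" is a hypothetical local path):
>>> # save the trained model and load it back, like any other Transformers model
>>> model.save_pretrained("./my-bert2bert")
>>> model = EncoderDecoderModel.from_pretrained("./my-bert2bert")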
One application of this architecture is a summarization model built from two pretrained [BertModel]s used as the encoder and the decoder, as proposed in Text Summarization with Pretrained Encoders by Yang Liu and Mirella Lapata.
Randomly initializing an EncoderDecoderModel from model configurations[[Randomly initializing EncoderDecoderModel from model configurations.]]
[EncoderDecoderModel] can be randomly initialized from an encoder config and a decoder config. The following example shows how to do this using the default [BertModel] configuration for the encoder and the default [BertForCausalLM] configuration for the decoder.
>>> from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel
>>> config_encoder = BertConfig()
>>> config_decoder = BertConfig()
>>> config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
>>> model = EncoderDecoderModel(config=config)
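The individual configurations can also be customized before building the model. As a sketch (the hyperparameter value below is illustrative only), reusing `config_encoder` from above:
>>> # e.g. use a shallower decoder; from_encoder_decoder_configs marks it as a
>>> # decoder and enables cross-attention automatically
>>> config_decoder = BertConfig(num_hidden_layers=6)
>>> config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
>>> model = EncoderDecoderModel(config=config)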
Initializing EncoderDecoderModel from a pretrained encoder and a pretrained decoder[[Initialising EncoderDecoderModel from a pretrained encoder and a pretrained decoder.]]
[EncoderDecoderModel] can be initialized from a pretrained encoder checkpoint and a pretrained decoder checkpoint. Any pretrained auto-encoding model, such as BERT, can serve as the encoder, while either an autoregressive model such as GPT2 or the pretrained decoder of a sequence-to-sequence model, such as BART's decoder, can serve as the decoder. Depending on which architecture you choose as the decoder, the cross-attention layers may be randomly initialized. Initializing [EncoderDecoderModel] from pretrained encoder and decoder checkpoints requires the model to be fine-tuned on a downstream task, as explained in the Warm-starting-encoder-decoder blog post. For this purpose, the EncoderDecoderModel class provides the [EncoderDecoderModel.from_encoder_decoder_pretrained] method.
>>> from transformers import EncoderDecoderModel, BertTokenizer
>>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-uncased", "google-bert/bert-base-uncased")
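The encoder and decoder do not have to share the same architecture. For instance, a BERT encoder can be paired with a GPT2 decoder (a sketch; in this pairing the cross-attention layers are newly initialized, so the model must be fine-tuned before use):
>>> # combine a BERT encoder with a GPT2 decoder; cross-attention weights are randomly initialized
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-uncased", "openai-community/gpt2")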
Loading an existing EncoderDecoderModel checkpoint and performing inference[[Loading an existing EncoderDecoderModel checkpoint and perform inference.]]
To load fine-tuned checkpoints of the EncoderDecoderModel class, [EncoderDecoderModel] provides the from_pretrained(...) method, just like any other model architecture in Transformers.
To perform inference, use the [generate] method, which generates text autoregressively. This method supports various decoding strategies, such as greedy decoding, beam search, and multinomial sampling.
>>> from transformers import AutoTokenizer, EncoderDecoderModel
>>> # load a fine-tuned seq2seq model and the corresponding tokenizer
>>> model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert_cnn_daily_mail")
>>> tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/bert2bert_cnn_daily_mail")
>>> # let's perform inference on a long piece of text
>>> ARTICLE_TO_SUMMARIZE = (
... "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
... )
>>> input_ids = tokenizer(ARTICLE_TO_SUMMARIZE, return_tensors="pt").input_ids
>>> # autoregressively generate the summary (uses greedy decoding by default)
>>> generated_ids = model.generate(input_ids)
>>> generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> print(generated_text)
nearly 800 thousand customers were affected by the shutoffs. the aim is to reduce the risk of wildfires. nearly 800, 000 customers were expected to be affected by high winds amid dry conditions. pg & e said it scheduled the blackouts to last through at least midday tomorrow.
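Other decoding strategies can be selected through [generate]'s arguments. For example, a beam-search sketch (the parameter values are illustrative only):
>>> # switch from greedy decoding to beam search
>>> generated_ids = model.generate(input_ids, num_beams=4, max_length=64, early_stopping=True)
>>> print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])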
Loading a PyTorch checkpoint into TFEncoderDecoderModel[[Loading a PyTorch checkpoint into TFEncoderDecoderModel.]]
[TFEncoderDecoderModel.from_pretrained] currently does not support initializing a model from a PyTorch checkpoint. Passing from_pt=True to this method will raise an exception. If only PyTorch checkpoints exist for a particular encoder-decoder model, you can use the following workaround:
>>> # a workaround for loading from a PyTorch checkpoint
>>> from transformers import EncoderDecoderModel, TFEncoderDecoderModel
>>> _model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
>>> _model.encoder.save_pretrained("./encoder")
>>> _model.decoder.save_pretrained("./decoder")
>>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
... "./encoder", "./decoder", encoder_from_pt=True, decoder_from_pt=True
... )
>>> # this is only for copying some specific attributes of this particular model
>>> model.config = _model.config
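After this one-time conversion, you can save the TensorFlow weights so that later loads work directly through [TFEncoderDecoderModel.from_pretrained] (a sketch; "./bert2bert-tf" is a hypothetical local path):
>>> # save the converted TF model; it can then be reloaded without the workaround
>>> model.save_pretrained("./bert2bert-tf")
>>> model = TFEncoderDecoderModel.from_pretrained("./bert2bert-tf")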
Training[[Training]]
Once the model is created, it can be fine-tuned similarly to BART, T5, or any other encoder-decoder model.
As you can see, only two inputs are required to compute a loss: input_ids (the input_ids of the encoded input sequence) and labels (the input_ids of the encoded target sequence).
>>> from transformers import BertTokenizer, EncoderDecoderModel
>>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-uncased", "google-bert/bert-base-uncased")
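>>> # BERT has no dedicated decoder start or pad token for generation, so reuse the tokenizer's CLS and PAD tokens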
>>> model.config.decoder_start_token_id = tokenizer.cls_token_id
>>> model.config.pad_token_id = tokenizer.pad_token_id
>>> input_ids = tokenizer(
... "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side.During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft).Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct.",
... return_tensors="pt",
... ).input_ids
>>> labels = tokenizer(
... "the eiffel tower surpassed the washington monument to become the tallest structure in the world. it was the first structure to reach a height of 300 metres in paris in 1930. it is now taller than the chrysler building by 5. 2 metres ( 17 ft ) and is the second tallest free - standing structure in paris.",
... return_tensors="pt",
... ).input_ids
>>> # the forward function automatically creates the correct decoder_input_ids
>>> loss = model(input_ids=input_ids, labels=labels).loss
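The returned loss can then drive a standard PyTorch training step, for example (a minimal sketch; the optimizer choice and learning rate are illustrative only):
>>> import torch
>>> optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
>>> loss.backward()  # backpropagates through both the decoder and the encoder
>>> optimizer.step()
>>> optimizer.zero_grad()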
For a detailed example of training, please refer to the colab notebook.
This model was contributed by thomwolf. The TensorFlow and Flax versions of this model were contributed by ydshieh.
EncoderDecoderConfig
[[autodoc]] EncoderDecoderConfig
EncoderDecoderModel
[[autodoc]] EncoderDecoderModel
    - forward
    - from_encoder_decoder_pretrained
TFEncoderDecoderModel
[[autodoc]] TFEncoderDecoderModel
    - call
    - from_encoder_decoder_pretrained
FlaxEncoderDecoderModel
[[autodoc]] FlaxEncoderDecoderModel
    - __call__
    - from_encoder_decoder_pretrained