<!--Copyright 2020 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Encoder Decoder Models

## Overview

[`EncoderDecoderModel`] is used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the encoder and any pretrained autoregressive model as the decoder.

사전 ν•™μŠ΅λœ 체크포인트λ₯Ό ν™œμš©ν•΄ μ‹œν€€μŠ€-투-μ‹œν€€μŠ€ λͺ¨λΈμ„ μ΄ˆκΈ°ν™”ν•˜λŠ” 것이 μ‹œν€€μŠ€ 생성(sequence generation) μž‘μ—…μ— νš¨κ³Όμ μ΄λΌλŠ” 점이 Sascha Rothe, Shashi Narayan, Aliaksei Severyn의 λ…Όλ¬Έ [Leveraging Pre-trained Checkpoints for Sequence Generation Tasks](https://arxiv.org/abs/1907.12461)μ—μ„œ μž…μ¦λ˜μ—ˆμŠ΅λ‹ˆλ‹€.

After an [`EncoderDecoderModel`] has been trained or fine-tuned, it can be saved and loaded just like any other model. See the examples below for details, and the sketch right after this paragraph for the save/load round trip.
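
As a minimal sketch of this save/load round trip (the `./bert2bert` directory name is only illustrative):

```python
>>> from transformers import EncoderDecoderModel

>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained(
...     "google-bert/bert-base-uncased", "google-bert/bert-base-uncased"
... )
>>> # ... train/fine-tune the model ...

>>> # persist and reload it just like any other Transformers model
>>> model.save_pretrained("./bert2bert")
>>> model = EncoderDecoderModel.from_pretrained("./bert2bert")
```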

이 μ•„ν‚€ν…μ²˜μ˜ ν•œ κ°€μ§€ μ‘μš© μ‚¬λ‘€λŠ” 두 개의 사전 ν•™μŠ΅λœ [`BertModel`]을 각각 인코더와 λ””μ½”λ”λ‘œ ν™œμš©ν•˜μ—¬ μš”μ•½ λͺ¨λΈ(summarization model)을 κ΅¬μΆ•ν•˜λŠ” κ²ƒμž…λ‹ˆλ‹€. μ΄λŠ” Yang Liu와 Mirella Lapata의 λ…Όλ¬Έ [Text Summarization with Pretrained Encoders](https://arxiv.org/abs/1908.08345)μ—μ„œ μ œμ‹œλœ λ°” μžˆμŠ΅λ‹ˆλ‹€.

## λͺ¨λΈ μ„€μ •μ—μ„œ `EncoderDecoderModel`을 λ¬΄μž‘μœ„ μ΄ˆκΈ°ν™”ν•˜κΈ°[[Randomly initializing `EncoderDecoderModel` from model configurations.]]

[`EncoderDecoderModel`] can be randomly initialized from an encoder config and a decoder config. The following example shows how to do this using the default [`BertModel`] configuration for the encoder and the default [`BertForCausalLM`] configuration for the decoder.

```python
>>> from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

>>> config_encoder = BertConfig()
>>> config_decoder = BertConfig()

>>> config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
>>> model = EncoderDecoderModel(config=config)
```
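
As a quick sanity check (assuming `from_encoder_decoder_configs` behaves as in recent Transformers versions, where it flags the decoder config), the combined configuration should mark the decoder as a decoder with cross-attention enabled:

```python
>>> # the helper is expected to set these flags so that the decoder
>>> # gets causal masking and cross-attention layers
>>> print(config.decoder.is_decoder, config.decoder.add_cross_attention)
True True
```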

## 사전 ν•™μŠ΅λœ 인코더와 λ””μ½”λ”λ‘œ `EncoderDecoderModel` μ΄ˆκΈ°ν™”ν•˜κΈ°[[Initialising `EncoderDecoderModel` from a pretrained encoder and a pretrained decoder.]]

[`EncoderDecoderModel`] can be initialized from a pretrained encoder checkpoint and a pretrained decoder checkpoint. Any pretrained auto-encoding model, such as BERT, can serve as the encoder, while the decoder can be an autoregressive model such as GPT2 or the pretrained decoder part of a sequence-to-sequence model, such as the decoder of BART. Depending on which architecture you choose as the decoder, the cross-attention layers may be randomly initialized. Initializing [`EncoderDecoderModel`] from pretrained encoder and decoder checkpoints therefore requires fine-tuning the model on a downstream task, as explained in [the *Warm-starting-encoder-decoder blog post*](https://huggingface.co/blog/warm-starting-encoder-decoder). For this purpose, the `EncoderDecoderModel` class provides the [`EncoderDecoderModel.from_encoder_decoder_pretrained`] method.


```python
>>> from transformers import EncoderDecoderModel, BertTokenizer

>>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-uncased", "google-bert/bert-base-uncased")
```
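
The encoder and decoder do not need to share an architecture. As a sketch, a BERT encoder can be paired with a GPT2 decoder; the cross-attention weights added to the decoder are then randomly initialized, which is why fine-tuning is required:

```python
>>> from transformers import EncoderDecoderModel

>>> # BERT as encoder, GPT2 as decoder; the cross-attention layers in the
>>> # decoder are newly (randomly) initialized
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained(
...     "google-bert/bert-base-uncased", "openai-community/gpt2"
... )
```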

## Loading an existing `EncoderDecoderModel` checkpoint and performing inference

To load a fine-tuned checkpoint of the `EncoderDecoderModel` class, use the `from_pretrained(...)` method provided by [`EncoderDecoderModel`], just like any other model architecture in Transformers.

To perform inference, use the [`generate`] method to generate text autoregressively. This method supports various decoding strategies, including greedy decoding, beam search, and multinomial sampling.

```python
>>> from transformers import AutoTokenizer, EncoderDecoderModel

>>> # load a fine-tuned seq2seq model and its corresponding tokenizer
>>> model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert_cnn_daily_mail")
>>> tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/bert2bert_cnn_daily_mail")

>>> # let's perform inference on a long piece of text
>>> ARTICLE_TO_SUMMARIZE = (
...     "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
...     "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
...     "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
... )
>>> input_ids = tokenizer(ARTICLE_TO_SUMMARIZE, return_tensors="pt").input_ids

>>> # μžκΈ°νšŒκ·€μ μœΌλ‘œ μš”μ•½ 생성 (기본적으둜 그리디 λ””μ½”λ”© μ‚¬μš©)
>>> generated_ids = model.generate(input_ids)
>>> generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> print(generated_text)
nearly 800 thousand customers were affected by the shutoffs. the aim is to reduce the risk of wildfires. nearly 800, 000 customers were expected to be affected by high winds amid dry conditions. pg & e said it scheduled the blackouts to last through at least midday tomorrow.
```
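
Decoding strategies other than greedy decoding are selected through keyword arguments of [`generate`]; for example, a beam-search variant of the call above might look like this (`num_beams=4` is an illustrative value):

```python
>>> # beam search instead of the default greedy decoding
>>> beam_ids = model.generate(input_ids, num_beams=4, early_stopping=True)
>>> beam_summary = tokenizer.batch_decode(beam_ids, skip_special_tokens=True)[0]
```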

## Loading a PyTorch checkpoint into `TFEncoderDecoderModel`

The [`TFEncoderDecoderModel.from_pretrained`] method currently does not support initializing a model from a PyTorch checkpoint. Passing `from_pt=True` to this method throws an exception. If only PyTorch checkpoints exist for a particular encoder-decoder model, the following workaround can be used:

```python
>>> # νŒŒμ΄ν† μΉ˜ μ²΄ν¬ν¬μΈνŠΈμ—μ„œ λ‘œλ“œν•˜λŠ” ν•΄κ²° 방법
>>> from transformers import EncoderDecoderModel, TFEncoderDecoderModel

>>> _model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")

>>> _model.encoder.save_pretrained("./encoder")
>>> _model.decoder.save_pretrained("./decoder")

>>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
...     "./encoder", "./decoder", encoder_from_pt=True, decoder_from_pt=True
... )
>>> # This is only for copying some specific attributes of this particular model.
>>> model.config = _model.config
```
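
To avoid repeating this conversion on every load, the resulting TensorFlow model can be saved once and then loaded directly; the `./tf_bert2bert` path below is only illustrative:

```python
>>> # persist the converted TF model so future loads skip the workaround
>>> model.save_pretrained("./tf_bert2bert")
>>> model = TFEncoderDecoderModel.from_pretrained("./tf_bert2bert")
```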

## Training

λͺ¨λΈμ΄ μƒμ„±λœ ν›„μ—λŠ” BART, T5 λ˜λŠ” 기타 인코더-디코더 λͺ¨λΈκ³Ό μœ μ‚¬ν•œ λ°©μ‹μœΌλ‘œ λ―Έμ„Έ μ‘°μ •(fine-tuning)ν•  수 μžˆμŠ΅λ‹ˆλ‹€.
λ³΄μ‹œλ‹€μ‹œν”Ό, 손싀(loss)을 κ³„μ‚°ν•˜λ €λ©΄ 단 2개의 μž…λ ₯만 ν•„μš”ν•©λ‹ˆλ‹€: `input_ids`(μž…λ ₯ μ‹œν€€μŠ€λ₯Ό μΈμ½”λ”©ν•œ `input_ids`)와 `labels`(λͺ©ν‘œ μ‹œν€€μŠ€λ₯Ό μΈμ½”λ”©ν•œ `input_ids`).

```python
>>> from transformers import BertTokenizer, EncoderDecoderModel

>>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-uncased", "google-bert/bert-base-uncased")

>>> model.config.decoder_start_token_id = tokenizer.cls_token_id
>>> model.config.pad_token_id = tokenizer.pad_token_id

>>> input_ids = tokenizer(
...     "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side.During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was  finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft).Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct.",
...     return_tensors="pt",
... ).input_ids

>>> labels = tokenizer(
...     "the eiffel tower surpassed the washington monument to become the tallest structure in the world. it was the first structure to reach a height of 300 metres in paris in 1930. it is now taller than the chrysler building by 5. 2 metres ( 17 ft ) and is the second tallest free - standing structure in paris.",
...     return_tensors="pt",
... ).input_ids

>>> # the forward function automatically creates the correct decoder_input_ids
>>> loss = model(input_ids=input_ids, labels=labels).loss
```
ν›ˆλ ¨μ— λŒ€ν•œ μžμ„Έν•œ λ‚΄μš©μ€ [colab](https://colab.research.google.com/drive/1WIk2bxglElfZewOHboPFNj8H44_VAyKE?usp=sharing#scrollTo=ZwQIEhKOrJpl) λ…ΈνŠΈλΆμ„ μ°Έμ‘°ν•˜μ„Έμš”. 

이 λͺ¨λΈμ€ [thomwolf](https://github.com/thomwolf)κ°€ κΈ°μ—¬ν–ˆμœΌλ©°, 이 λͺ¨λΈμ— λŒ€ν•œ TensorFlow 및 Flax 버전은 [ydshieh](https://github.com/ydshieh)κ°€ κΈ°μ—¬ν–ˆμŠ΅λ‹ˆλ‹€.


## EncoderDecoderConfig

[[autodoc]] EncoderDecoderConfig

<frameworkcontent>
<pt>

## EncoderDecoderModel

[[autodoc]] EncoderDecoderModel
    - forward
    - from_encoder_decoder_pretrained

</pt>
<tf>

## TFEncoderDecoderModel

[[autodoc]] TFEncoderDecoderModel
    - call
    - from_encoder_decoder_pretrained

</tf>
<jax>

## FlaxEncoderDecoderModel

[[autodoc]] FlaxEncoderDecoderModel
    - __call__
    - from_encoder_decoder_pretrained

</jax>
</frameworkcontent>