BERT[[BERT]]
๊ฐ์[[Overview]]
BERT ๋ชจ๋ธ์ Jacob Devlin. Ming-Wei Chang, Kenton Lee, Kristina Touranova๊ฐ ์ ์ํ ๋ ผ๋ฌธ BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding์์ ์๊ฐ๋์์ต๋๋ค. BERT๋ ์ฌ์ ํ์ต๋ ์๋ฐฉํฅ ํธ๋์คํฌ๋จธ๋ก, Toronto Book Corpus์ Wikipedia๋ก ๊ตฌ์ฑ๋ ๋๊ท๋ชจ ์ฝํผ์ค์์ ๋ง์คํน๋ ์ธ์ด ๋ชจ๋ธ๋ง๊ณผ ๋ค์ ๋ฌธ์ฅ ์์ธก(Next Sentence Prediction) ๋ชฉํ๋ฅผ ๊ฒฐํฉํด ํ์ต๋์์ต๋๋ค.
ํด๋น ๋ ผ๋ฌธ์ ์ด๋ก์ ๋๋ค:
์ฐ๋ฆฌ๋ BERT(Bidirectional Encoder Representations from Transformers)๋ผ๋ ์๋ก์ด ์ธ์ด ํํ ๋ชจ๋ธ์ ์๊ฐํฉ๋๋ค. ์ต๊ทผ์ ๋ค๋ฅธ ์ธ์ด ํํ ๋ชจ๋ธ๋ค๊ณผ ๋ฌ๋ฆฌ, BERT๋ ๋ชจ๋ ๊ณ์ธต์์ ์๋ฐฉํฅ์ผ๋ก ์์ชฝ ๋ฌธ๋งฅ์ ์กฐ๊ฑด์ผ๋ก ์ฌ์ฉํ์ฌ ๋น์ง๋ ํ์ต๋ ํ ์คํธ์์ ๊น์ด ์๋ ์๋ฐฉํฅ ํํ์ ์ฌ์ ํ์ตํ๋๋ก ์ค๊ณ๋์์ต๋๋ค. ๊ทธ ๊ฒฐ๊ณผ, ์ฌ์ ํ์ต๋ BERT ๋ชจ๋ธ์ ์ถ๊ฐ์ ์ธ ์ถ๋ ฅ ๊ณ์ธต ํ๋๋ง์ผ๋ก ์ง๋ฌธ ์๋ต, ์ธ์ด ์ถ๋ก ๊ณผ ๊ฐ์ ๋ค์ํ ์์ ์์ ๋ฏธ์ธ ์กฐ์ ๋ ์ ์์ผ๋ฏ๋ก, ํน์ ์์ ์ ์ํด ์ํคํ ์ฒ๋ฅผ ์์ ํ ํ์๊ฐ ์์ต๋๋ค.
BERT๋ ๊ฐ๋ ์ ์ผ๋ก ๋จ์ํ๋ฉด์๋ ์ค์ฆ์ ์ผ๋ก ๊ฐ๋ ฅํ ๋ชจ๋ธ์ ๋๋ค. BERT๋ 11๊ฐ์ ์์ฐ์ด ์ฒ๋ฆฌ ๊ณผ์ ์์ ์๋ก์ด ์ต๊ณ ์ฑ๋ฅ์ ๋ฌ์ฑํ์ผ๋ฉฐ, GLUE ์ ์๋ฅผ 80.5% (7.7% ํฌ์ธํธ ์ ๋ ๊ฐ์ )๋ก, MultiNLI ์ ํ๋๋ฅผ 86.7% (4.6% ํฌ์ธํธ ์ ๋ ๊ฐ์ ), SQuAD v1.1 ์ง๋ฌธ ์๋ต ํ ์คํธ์์ F1 ์ ์๋ฅผ 93.2 (1.5% ํฌ์ธํธ ์ ๋ ๊ฐ์ )๋ก, SQuAD v2.0์์ F1 ์ ์๋ฅผ 83.1 (5.1% ํฌ์ธํธ ์ ๋ ๊ฐ์ )๋ก ํฅ์์์ผฐ์ต๋๋ค.
์ด ๋ชจ๋ธ์ thomwolf๊ฐ ๊ธฐ์ฌํ์์ต๋๋ค. ์๋ณธ ์ฝ๋๋ ์ฌ๊ธฐ์์ ํ์ธํ ์ ์์ต๋๋ค.
์ฌ์ฉ ํ[[Usage tips]]
BERT๋ ์ ๋ ์์น ์๋ฒ ๋ฉ์ ์ฌ์ฉํ๋ ๋ชจ๋ธ์ด๋ฏ๋ก ์ ๋ ฅ์ ์ผ์ชฝ์ด ์๋๋ผ ์ค๋ฅธ์ชฝ์์ ํจ๋ฉํ๋ ๊ฒ์ด ์ผ๋ฐ์ ์ผ๋ก ๊ถ์ฅ๋ฉ๋๋ค.
BERT๋ ๋ง์คํน๋ ์ธ์ด ๋ชจ๋ธ(MLM)๊ณผ Next Sentence Prediction(NSP) ๋ชฉํ๋ก ํ์ต๋์์ต๋๋ค. ์ด๋ ๋ง์คํน๋ ํ ํฐ ์์ธก๊ณผ ์ ๋ฐ์ ์ธ ์์ฐ์ด ์ดํด(NLU)์ ๋ฐ์ด๋์ง๋ง, ํ ์คํธ ์์ฑ์๋ ์ต์ ํ๋์ด์์ง ์์ต๋๋ค.
BERT์ ์ฌ์ ํ์ต ๊ณผ์ ์์๋ ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ๋ฌด์์๋ก ๋ง์คํนํ์ฌ ์ผ๋ถ ํ ํฐ์ ๋ง์คํนํฉ๋๋ค. ์ ์ฒด ํ ํฐ ์ค ์ฝ 15%๊ฐ ๋ค์๊ณผ ๊ฐ์ ๋ฐฉ์์ผ๋ก ๋ง์คํน๋ฉ๋๋ค:
- 80% ํ๋ฅ ๋ก ๋ง์คํฌ ํ ํฐ์ผ๋ก ๋์ฒด
- 10% ํ๋ฅ ๋ก ์์์ ๋ค๋ฅธ ํ ํฐ์ผ๋ก ๋์ฒด
- 10% ํ๋ฅ ๋ก ์๋ ํ ํฐ ๊ทธ๋๋ก ์ ์ง
๋ชจ๋ธ์ ์ฃผ์ ๋ชฉํ๋ ์๋ณธ ๋ฌธ์ฅ์ ์์ธกํ๋ ๊ฒ์ด์ง๋ง, ๋ ๋ฒ์งธ ๋ชฉํ๊ฐ ์์ต๋๋ค: ์ ๋ ฅ์ผ๋ก ๋ฌธ์ฅ A์ B (์ฌ์ด์๋ ๊ตฌ๋ถ ํ ํฐ์ด ์์)๊ฐ ์ฃผ์ด์ง๋๋ค. ์ด ๋ฌธ์ฅ ์์ด ์ฐ์๋ ํ๋ฅ ์ 50%์ด๋ฉฐ, ๋๋จธ์ง 50%๋ ์๋ก ๋ฌด๊ดํ ๋ฌธ์ฅ๋ค์ ๋๋ค. ๋ชจ๋ธ์ ์ด ๋ ๋ฌธ์ฅ์ด ์๋์ง๋ฅผ ์์ธกํด์ผ ํฉ๋๋ค.
Scaled Dot Product Attention(SDPA) ์ฌ์ฉํ๊ธฐ [[Using Scaled Dot Product Attention (SDPA)]]
Pytorch๋ torch.nn.functional์ ์ผ๋ถ๋ก Scaled Dot Product Attention(SDPA) ์ฐ์ฐ์๋ฅผ ๊ธฐ๋ณธ์ ์ผ๋ก ์ ๊ณตํฉ๋๋ค. ์ด ํจ์๋ ์
๋ ฅ๊ณผ ํ๋์จ์ด์ ๋ฐ๋ผ ์ฌ๋ฌ ๊ตฌํ ๋ฐฉ์์ ์ฌ์ฉํ ์ ์์ต๋๋ค. ์์ธํ ๋ด์ฉ์ ๊ณต์ ๋ฌธ์๋ GPU Inference์์ ํ์ธํ ์ ์์ต๋๋ค.
torch>=2.1.1์์๋ ๊ตฌํ์ด ๊ฐ๋ฅํ ๊ฒฝ์ฐ SDPA๊ฐ ๊ธฐ๋ณธ์ ์ผ๋ก ์ฌ์ฉ๋์ง๋ง, from_pretrained()ํจ์์์ attn_implementation="sdpa"๋ฅผ ์ค์ ํ์ฌ SDPA๋ฅผ ๋ช
์์ ์ผ๋ก ์ฌ์ฉํ๋๋ก ์ง์ ํ ์๋ ์์ต๋๋ค.
from transformers import BertModel
model = BertModel.from_pretrained("bert-base-uncased", torch_dtype=torch.float16, attn_implementation="sdpa")
...
์ต์ ์ฑ๋ฅ ํฅ์์ ์ํด ๋ชจ๋ธ์ ๋ฐ์ ๋ฐ๋(์: torch.float16 ๋๋ torch.bfloat16)๋ก ๋ถ๋ฌ์ค๋ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค.
๋ก์ปฌ ๋ฒค์น๋งํฌ (A100-80GB, CPUx12, RAM 96.6GB, PyTorch 2.2.0, OS Ubuntu 22.04)์์ float16์ ์ฌ์ฉํด ํ์ต ๋ฐ ์ถ๋ก ์ ์ํํ ๊ฒฐ๊ณผ, ๋ค์๊ณผ ๊ฐ์ ์๋ ํฅ์์ด ๊ด์ฐฐ๋์์ต๋๋ค.
ํ์ต [[Training]]
| batch_size | seq_len | Time per batch (eager - s) | Time per batch (sdpa - s) | Speedup (%) | Eager peak mem (MB) | sdpa peak mem (MB) | Mem saving (%) |
|---|---|---|---|---|---|---|---|
| 4 | 256 | 0.023 | 0.017 | 35.472 | 939.213 | 764.834 | 22.800 |
| 4 | 512 | 0.023 | 0.018 | 23.687 | 1970.447 | 1227.162 | 60.569 |
| 8 | 256 | 0.023 | 0.018 | 23.491 | 1594.295 | 1226.114 | 30.028 |
| 8 | 512 | 0.035 | 0.025 | 43.058 | 3629.401 | 2134.262 | 70.054 |
| 16 | 256 | 0.030 | 0.024 | 25.583 | 2874.426 | 2134.262 | 34.680 |
| 16 | 512 | 0.064 | 0.044 | 46.223 | 6964.659 | 3961.013 | 75.830 |
์ถ๋ก [[Inference]]
| batch_size | seq_len | Per token latency eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem BT (MB) | Mem saved (%) |
|---|---|---|---|---|---|---|---|
| 1 | 128 | 5.736 | 4.987 | 15.022 | 282.661 | 282.924 | -0.093 |
| 1 | 256 | 5.689 | 4.945 | 15.055 | 298.686 | 298.948 | -0.088 |
| 2 | 128 | 6.154 | 4.982 | 23.521 | 314.523 | 314.785 | -0.083 |
| 2 | 256 | 6.201 | 4.949 | 25.303 | 347.546 | 347.033 | 0.148 |
| 4 | 128 | 6.049 | 4.987 | 21.305 | 378.895 | 379.301 | -0.107 |
| 4 | 256 | 6.285 | 5.364 | 17.166 | 443.209 | 444.382 | -0.264 |
์๋ฃ[[Resources]]
BERT๋ฅผ ์์ํ๋ ๋ฐ ๋์์ด ๋๋ Hugging Face์ community ์๋ฃ ๋ชฉ๋ก(๐๋ก ํ์๋จ) ์ ๋๋ค. ์ฌ๊ธฐ์ ํฌํจ๋ ์๋ฃ๋ฅผ ์ ์ถํ๊ณ ์ถ๋ค๋ฉด PR(Pull Request)๋ฅผ ์ด์ด์ฃผ์ธ์. ๋ฆฌ๋ทฐ ํด๋๋ฆฌ๊ฒ ์ต๋๋ค! ์๋ฃ๋ ๊ธฐ์กด ์๋ฃ๋ฅผ ๋ณต์ ํ๋ ๋์ ์๋ก์ด ๋ด์ฉ์ ๋ด๊ณ ์์ด์ผ ํฉ๋๋ค.
- BERT ํ ์คํธ ๋ถ๋ฅ (๋ค๋ฅธ ์ธ์ด๋ก)์ ๋ํ ๋ธ๋ก๊ทธ ํฌ์คํธ.
- ๋ค์ค ๋ ์ด๋ธ ํ ์คํธ ๋ถ๋ฅ๋ฅผ ์ํ BERT (๋ฐ ๊ด๋ จ ๋ชจ๋ธ) ๋ฏธ์ธ ์กฐ์ ์ ๋ํ ๋ ธํธ๋ถ.
- PyTorch๋ฅผ ์ด์ฉํด BERT๋ฅผ ๋ค์ค ๋ ์ด๋ธ ๋ถ๋ฅ๋ฅผ ์ํด ๋ฏธ์ธ ์กฐ์ ํ๋ ๋ฐฉ๋ฒ์ ๋ํ ๋ ธํธ๋ถ. ๐
- BERT๋ก EncoderDecoder ๋ชจ๋ธ์ warm-startํ์ฌ ์์ฝํ๊ธฐ์ ๋ํ ๋ ธํธ๋ถ.
- [
BertForSequenceClassification]์ด ์์ ์คํฌ๋ฆฝํธ์ ๋ ธํธ๋ถ์์ ์ง์๋ฉ๋๋ค. - [
TFBertForSequenceClassification]์ด ์์ ์คํฌ๋ฆฝํธ์ ๋ ธํธ๋ถ์์ ์ง์๋ฉ๋๋ค. - [
FlaxBertForSequenceClassification]์ด ์์ ์คํฌ๋ฆฝํธ์ ๋ ธํธ๋ถ์์ ์ง์๋ฉ๋๋ค. - ํ ์คํธ ๋ถ๋ฅ ์์ ๊ฐ์ด๋
- Keras์ ํจ๊ป Hugging Face Transformers๋ฅผ ์ฌ์ฉํ์ฌ ๋น์๋ฆฌ BERT๋ฅผ ๊ฐ์ฒด๋ช ์ธ์(NER)์ฉ์ผ๋ก ๋ฏธ์ธ ์กฐ์ ํ๋ ๋ฐฉ๋ฒ์ ๋ํ ๋ธ๋ก๊ทธ ํฌ์คํธ.
- BERT๋ฅผ ๊ฐ์ฒด๋ช ์ธ์์ ์ํด ๋ฏธ์ธ ์กฐ์ ํ๊ธฐ์ ๋ํ ๋ ธํธ๋ถ. ๊ฐ ๋จ์ด์ ์ฒซ ๋ฒ์งธ wordpiece์๋ง ๋ ์ด๋ธ์ ์ง์ ํ์ฌ ํ์ตํ๋ ๋ฐฉ๋ฒ์ ์ค๋ช ํฉ๋๋ค. ๋ชจ๋ wordpiece์ ๋ ์ด๋ธ์ ์ ํํ๋ ๋ฐฉ๋ฒ์ ์ด ๋ฒ์ ์์ ํ์ธํ ์ ์์ต๋๋ค.
- [
BertForTokenClassification]์ด ์์ ์คํฌ๋ฆฝํธ์ ๋ ธํธ๋ถ์์ ์ง์๋ฉ๋๋ค. - [
TFBertForTokenClassification]์ด ์์ ์คํฌ๋ฆฝํธ์ ๋ ธํธ๋ถ์์ ์ง์๋ฉ๋๋ค. - [
FlaxBertForTokenClassification]์ด ์์ ์คํฌ๋ฆฝํธ์์ ์ง์๋ฉ๋๋ค. - ๐ค Hugging Face ์ฝ์ค์ ํ ํฐ ๋ถ๋ฅ ์ฑํฐ.
- ํ ํฐ ๋ถ๋ฅ ์์ ๊ฐ์ด๋
- [
BertForMaskedLM]์ด ์์ ์คํฌ๋ฆฝํธ์ ๋ ธํธ๋ถ์์ ์ง์๋ฉ๋๋ค. - [
TFBertForMaskedLM]์ด ์์ ์คํฌ๋ฆฝํธ ์ ๋ ธํธ๋ถ์์ ์ง์๋ฉ๋๋ค. - [
FlaxBertForMaskedLM]์ด ์์ ์คํฌ๋ฆฝํธ์ ๋ ธํธ๋ถ์์ ์ง์๋ฉ๋๋ค. - ๐ค Hugging Face ์ฝ์ค์ ๋ง์คํน๋ ์ธ์ด ๋ชจ๋ธ๋ง ์ฑํฐ.
- ๋ง์คํน๋ ์ธ์ด ๋ชจ๋ธ๋ง ์์ ๊ฐ์ด๋
- [
BertForQuestionAnswering]์ด ์์ ์คํฌ๋ฆฝํธ์ ๋ ธํธ๋ถ์์ ์ง์๋ฉ๋๋ค. - [
TFBertForQuestionAnswering]์ด ์์ ์คํฌ๋ฆฝํธ ์ ๋ ธํธ๋ถ์์ ์ง์๋ฉ๋๋ค. - [
FlaxBertForQuestionAnswering]์ด ์์ ์คํฌ๋ฆฝํธ์์ ์ง์๋ฉ๋๋ค. - ๐ค Hugging Face ์ฝ์ค์ ์ง๋ฌธ ๋ต๋ณ ์ฑํฐ.
- ์ง๋ฌธ ๋ต๋ณ ์์ ๊ฐ์ด๋
๋ค์ค ์ ํ
- [
BertForMultipleChoice]์ด ์์ ์คํฌ๋ฆฝํธ์ ๋ ธํธ๋ถ์์ ์ง์๋ฉ๋๋ค. - [
TFBertForMultipleChoice]์ด ์์ ์คํฌ๋ฆฝํธ์ ๋ ธํธ๋ถ์์ ์ง์๋ฉ๋๋ค. - ๋ค์ค ์ ํ ์์ ๊ฐ์ด๋
โก๏ธ ์ถ๋ก
- Hugging Face Transformers์ AWS Inferentia๋ฅผ ์ฌ์ฉํ์ฌ BERT ์ถ๋ก ์ ๊ฐ์ํํ๋ ๋ฐฉ๋ฒ์ ๋ํ ๋ธ๋ก๊ทธ ํฌ์คํธ.
- GPU์์ DeepSpeed-Inference๋ก BERT ์ถ๋ก ์ ๊ฐ์ํํ๋ ๋ฐฉ๋ฒ์ ๋ํ ๋ธ๋ก๊ทธ ํฌ์คํธ.
โ๏ธ ์ฌ์ ํ์ต
- Hugging Face Optimum์ผ๋ก Transformers๋ฅผ ONMX๋ก ๋ณํํ๋ ๋ฐฉ๋ฒ์ ๋ํ ๋ธ๋ก๊ทธ ํฌ์คํธ.
๐ ๋ฐฐํฌ
- Hugging Face Optimum์ผ๋ก Transformers๋ฅผ ONMX๋ก ๋ณํํ๋ ๋ฐฉ๋ฒ์ ๋ํ ๋ธ๋ก๊ทธ ํฌ์คํธ.
- AWS์์ Hugging Face Transformers๋ฅผ ์ํ Habana Gaudi ๋ฅ๋ฌ๋ ํ๊ฒฝ ์ค์ ๋ฐฉ๋ฒ์ ๋ํ ๋ธ๋ก๊ทธ ํฌ์คํธ.
- Hugging Face Transformers, Amazon SageMaker ๋ฐ Terraform ๋ชจ๋์ ์ด์ฉํ BERT ์๋ ํ์ฅ์ ๋ํ ๋ธ๋ก๊ทธ ํฌ์คํธ.
- Hugging Face, AWS Lambda, Docker๋ฅผ ํ์ฉํ์ฌ ์๋ฒ๋ฆฌ์ค BERT ์ค์ ํ๋ ๋ฐฉ๋ฒ์ ๋ํ ๋ธ๋ก๊ทธ ํฌ์คํธ.
- Amazon SageMaker์ Training Compiler๋ฅผ ์ฌ์ฉํ์ฌ Hugging Face Transformers์์ BERT ๋ฏธ์ธ ์กฐ์ ํ๋ ๋ฐฉ๋ฒ์ ๋ํ ๋ธ๋ก๊ทธ.
- Amazon SageMaker๋ฅผ ์ฌ์ฉํ Transformers์ BERT์ ์์ ๋ณ ์ง์ ์ฆ๋ฅ์ ๋ํ ๋ธ๋ก๊ทธ ํฌ์คํธ.
BertConfig
[[autodoc]] BertConfig - all
BertTokenizer
[[autodoc]] BertTokenizer - build_inputs_with_special_tokens - get_special_tokens_mask - create_token_type_ids_from_sequences - save_vocabulary
BertTokenizerFast
[[autodoc]] BertTokenizerFast
TFBertTokenizer
[[autodoc]] TFBertTokenizer
Bert specific outputs
[[autodoc]] models.bert.modeling_bert.BertForPreTrainingOutput
[[autodoc]] models.bert.modeling_tf_bert.TFBertForPreTrainingOutput
[[autodoc]] models.bert.modeling_flax_bert.FlaxBertForPreTrainingOutput
BertModel
[[autodoc]] BertModel - forward
BertForPreTraining
[[autodoc]] BertForPreTraining - forward
BertLMHeadModel
[[autodoc]] BertLMHeadModel - forward
BertForMaskedLM
[[autodoc]] BertForMaskedLM - forward
BertForNextSentencePrediction
[[autodoc]] BertForNextSentencePrediction - forward
BertForSequenceClassification
[[autodoc]] BertForSequenceClassification - forward
BertForMultipleChoice
[[autodoc]] BertForMultipleChoice - forward
BertForTokenClassification
[[autodoc]] BertForTokenClassification - forward
BertForQuestionAnswering
[[autodoc]] BertForQuestionAnswering - forward
TFBertModel
[[autodoc]] TFBertModel - call
TFBertForPreTraining
[[autodoc]] TFBertForPreTraining - call
TFBertModelLMHeadModel
[[autodoc]] TFBertLMHeadModel - call
TFBertForMaskedLM
[[autodoc]] TFBertForMaskedLM - call
TFBertForNextSentencePrediction
[[autodoc]] TFBertForNextSentencePrediction - call
TFBertForSequenceClassification
[[autodoc]] TFBertForSequenceClassification - call
TFBertForMultipleChoice
[[autodoc]] TFBertForMultipleChoice - call
TFBertForTokenClassification
[[autodoc]] TFBertForTokenClassification - call
TFBertForQuestionAnswering
[[autodoc]] TFBertForQuestionAnswering - call
FlaxBertModel
[[autodoc]] FlaxBertModel - call
FlaxBertForPreTraining
[[autodoc]] FlaxBertForPreTraining - call
FlaxBertForCausalLM
[[autodoc]] FlaxBertForCausalLM - call
FlaxBertForMaskedLM
[[autodoc]] FlaxBertForMaskedLM - call
FlaxBertForNextSentencePrediction
[[autodoc]] FlaxBertForNextSentencePrediction - call
FlaxBertForSequenceClassification
[[autodoc]] FlaxBertForSequenceClassification - call
FlaxBertForMultipleChoice
[[autodoc]] FlaxBertForMultipleChoice - call
FlaxBertForTokenClassification
[[autodoc]] FlaxBertForTokenClassification - call
FlaxBertForQuestionAnswering
[[autodoc]] FlaxBertForQuestionAnswering - call