Vision Transformer (ViT) [[vision-transformer-vit]]
๊ฐ์ [[overview]]
Vision Transformer (ViT) ๋ชจ๋ธ์ Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby๊ฐ ์ ์ํ ๋ ผ๋ฌธ An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale์์ ์๊ฐ๋์์ต๋๋ค. ์ด๋ Transformer ์ธ์ฝ๋๋ฅผ ImageNet์์ ์ฑ๊ณต์ ์ผ๋ก ํ๋ จ์ํจ ์ฒซ ๋ฒ์งธ ๋ ผ๋ฌธ์ผ๋ก, ๊ธฐ์กด์ ์ ์๋ ค์ง ํฉ์ฑ๊ณฑ ์ ๊ฒฝ๋ง(CNN) ๊ตฌ์กฐ์ ๋น๊ตํด ๋งค์ฐ ์ฐ์ํ ๊ฒฐ๊ณผ๋ฅผ ๋ฌ์ฑํ์ต๋๋ค.
๋ ผ๋ฌธ์ ์ด๋ก์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
Transformer ์ํคํ ์ฒ๋ ์์ฐ์ด ์ฒ๋ฆฌ ์์ ์์ ์ฌ์ค์ ํ์ค์ผ๋ก ์๋ฆฌ ์ก์์ผ๋, ์ปดํจํฐ ๋น์ ๋ถ์ผ์์์ ์ ์ฉ์ ์ฌ์ ํ ์ ํ์ ์ ๋๋ค. ๋น์ ์์ ์ดํ ์ ๋ฉ์ปค๋์ฆ์ ์ข ์ข ํฉ์ฑ๊ณฑ ์ ๊ฒฝ๋ง(CNN)๊ณผ ๊ฒฐํฉํ์ฌ ์ฌ์ฉ๋๊ฑฐ๋, ์ ์ฒด ๊ตฌ์กฐ๋ฅผ ์ ์งํ๋ฉด์ ํฉ์ฑ๊ณฑ ์ ๊ฒฝ๋ง์ ํน์ ๊ตฌ์ฑ ์์๋ฅผ ๋์ฒดํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค. ์ฐ๋ฆฌ๋ ์ด๋ฌํ CNN ์์กด์ฑ์ด ํ์ํ์ง ์์ผ๋ฉฐ, ์ด๋ฏธ์ง ํจ์น๋ฅผ ์์ฐจ์ ์ผ๋ก ์ ๋ ฅ๋ฐ๋ ์์ํ Transformer๊ฐ ์ด๋ฏธ์ง ๋ถ๋ฅ ์์ ์์ ๋งค์ฐ ์ฐ์ํ ์ฑ๋ฅ์ ๋ฐํํ ์ ์์์ ๋ณด์ฌ์ค๋๋ค. ๋๊ท๋ชจ ๋ฐ์ดํฐ๋ก ์ฌ์ ํ์ต๋ ํ, ImageNet, CIFAR-100, VTAB ๋ฑ ๋ค์ํ ์ค์ํ ์ด๋ฏธ์ง ์ธ์ ๋ฒค์น๋งํฌ์ ์ ์ฉํ๋ฉด Vision Transformer(ViT)๋ ์ต์ ํฉ์ฑ๊ณฑ ์ ๊ฒฝ๋ง๊ณผ ๋น๊ตํด ๋งค์ฐ ์ฐ์ํ ์ฑ๋ฅ์ ๋ฐํํ๋ฉด์๋ ํ๋ จ์ ํ์ํ ๊ณ์ฐ ์์์ ์๋นํ ์ค์ผ ์ ์์ต๋๋ค.

ViT ์ํคํ ์ฒ. ์๋ณธ ๋ ผ๋ฌธ์์ ๋ฐ์ท.
์๋์ Vision Transformer์ ์ด์ด, ์ฌ๋ฌ ํ์ ์ฐ๊ตฌ๋ค์ด ์งํ๋์์ต๋๋ค:
DeiT (Data-efficient Image Transformers) (Facebook AI ๊ฐ๋ฐ). DeiT ๋ชจ๋ธ์ distilled vision transformers์ ๋๋ค. DeiT์ ์ ์๋ค์ ๋ ํจ์จ์ ์ผ๋ก ํ๋ จ๋ ViT ๋ชจ๋ธ๋ ๊ณต๊ฐํ์ผ๋ฉฐ, ์ด๋ [
ViTModel] ๋๋ [ViTForImageClassification]์ ๋ฐ๋ก ์ฌ์ฉํ ์ ์์ต๋๋ค. ์ฌ๊ธฐ์๋ 3๊ฐ์ง ํฌ๊ธฐ๋ก 4๊ฐ์ ๋ณํ์ด ์ ๊ณต๋ฉ๋๋ค: facebook/deit-tiny-patch16-224, facebook/deit-small-patch16-224, facebook/deit-base-patch16-224 and facebook/deit-base-patch16-384. ๊ทธ๋ฆฌ๊ณ ๋ชจ๋ธ์ ์ด๋ฏธ์ง๋ฅผ ์ค๋นํ๋ ค๋ฉด [DeiTImageProcessor]๋ฅผ ์ฌ์ฉํด์ผ ํ๋ค๋ ์ ์ ์ ์ํ์ญ์์ค.BEiT (BERT pre-training of Image Transformers) (Microsoft Research ๊ฐ๋ฐ). BEiT ๋ชจ๋ธ์ BERT (masked image modeling)์ ์๊ฐ์ ๋ฐ๊ณ VQ-VAE์ ๊ธฐ๋ฐํ self-supervised ๋ฐฉ๋ฒ์ ์ด์ฉํ์ฌ supervised pre-trained vision transformers๋ณด๋ค ๋ ์ฐ์ํ ์ฑ๋ฅ์ ๋ณด์ ๋๋ค.
DINO (Vision Transformers์ self-supervised ํ๋ จ์ ์ํ ๋ฐฉ๋ฒ) (Facebook AI ๊ฐ๋ฐ). DINO ๋ฐฉ๋ฒ์ผ๋ก ํ๋ จ๋ Vision Transformer๋ ํ์ต๋์ง ์์ ์ํ์์๋ ๊ฐ์ฒด๋ฅผ ๋ถํ ํ ์ ์๋ ํฉ์ฑ๊ณฑ ์ ๊ฒฝ๋ง์์๋ ๋ณผ ์ ์๋ ๋งค์ฐ ํฅ๋ฏธ๋ก์ด ๋ฅ๋ ฅ์ ๋ณด์ฌ์ค๋๋ค. DINO ์ฒดํฌํฌ์ธํธ๋ hub์์ ์ฐพ์ ์ ์์ต๋๋ค.
MAE (Masked Autoencoders) (Facebook AI ๊ฐ๋ฐ). Vision Transformer๋ฅผ ๋น๋์นญ ์ธ์ฝ๋-๋์ฝ๋ ์ํคํ ์ฒ๋ฅผ ์ฌ์ฉํ์ฌ ๋ง์คํฌ๋ ํจ์น์ ๋์ ๋น์จ(75%)์์ ํฝ์ ๊ฐ์ ์ฌ๊ตฌ์ฑํ๋๋ก ์ฌ์ ํ์ตํจ์ผ๋ก์จ, ์ ์๋ค์ ์ด ๊ฐ๋จํ ๋ฐฉ๋ฒ์ด ๋ฏธ์ธ ์กฐ์ ํ supervised ๋ฐฉ์์ ์ฌ์ ํ์ต์ ๋ฅ๊ฐํ๋ค๋ ๊ฒ์ ๋ณด์ฌ์ฃผ์์ต๋๋ค.
์ด ๋ชจ๋ธ์ nielsr์ ์ํด ๊ธฐ์ฌ๋์์ต๋๋ค. ์๋ณธ ์ฝ๋(JAX๋ก ์์ฑ๋จ)์ ์ฌ๊ธฐ์์ ํ์ธํ ์ ์์ต๋๋ค.
์ฐธ๊ณ ๋ก, ์ฐ๋ฆฌ๋ Ross Wightman์ timm ๋ผ์ด๋ธ๋ฌ๋ฆฌ์์ JAX์์ PyTorch๋ก ๋ณํ๋ ๊ฐ์ค์น๋ฅผ ๋ค์ ๋ณํํ์ต๋๋ค. ๋ชจ๋ ๊ณต๋ก๋ ๊ทธ์๊ฒ ๋๋ฆฝ๋๋ค!
์ฌ์ฉ ํ [[usage-tips]]
- Transformer ์ธ์ฝ๋์ ์ด๋ฏธ์ง๋ฅผ ์ ๋ ฅํ๊ธฐ ์ํด, ๊ฐ ์ด๋ฏธ์ง๋ ๊ณ ์ ํฌ๊ธฐ์ ๊ฒน์น์ง ์๋ ํจ์น๋ค๋ก ๋ถํ ๋ ํ ์ ํ ์๋ฒ ๋ฉ๋ฉ๋๋ค. ์ ์ฒด ์ด๋ฏธ์ง๋ฅผ ๋ํํ๋ [CLS] ํ ํฐ์ด ์ถ๊ฐ๋์ด, ๋ถ๋ฅ์ ์ฌ์ฉํ ์ ์์ต๋๋ค. ์ ์๋ค์ ๋ํ ์ ๋ ์์น ์๋ฒ ๋ฉ์ ์ถ๊ฐํ์ฌ, ๊ฒฐ๊ณผ์ ์ผ๋ก ์์ฑ๋ ๋ฒกํฐ ์ํ์ค๋ฅผ ํ์ค Transformer ์ธ์ฝ๋์ ์ ๋ ฅํฉ๋๋ค.
- Vision Transformer๋ ๋ชจ๋ ์ด๋ฏธ์ง๊ฐ ๋์ผํ ํฌ๊ธฐ(ํด์๋)์ฌ์ผ ํ๋ฏ๋ก, [ViTImageProcessor]๋ฅผ ์ฌ์ฉํ์ฌ ์ด๋ฏธ์ง๋ฅผ ๋ชจ๋ธ์ ๋ง๊ฒ ๋ฆฌ์ฌ์ด์ฆ(๋๋ ๋ฆฌ์ค์ผ์ผ)ํ๊ณ ์ ๊ทํํ ์ ์์ต๋๋ค.
- ์ฌ์ ํ์ต์ด๋ ๋ฏธ์ธ ์กฐ์ ์ ์ฌ์ฉ๋ ํจ์น ํด์๋์ ์ด๋ฏธ์ง ํด์๋๋ ๊ฐ ์ฒดํฌํฌ์ธํธ์ ์ด๋ฆ์ ๋ฐ์๋ฉ๋๋ค. ์๋ฅผ ๋ค์ด,
google/vit-base-patch16-224๋ ํจ์น ํด์๋๊ฐ 16x16์ด๊ณ ๋ฏธ์ธ ์กฐ์ ํด์๋๊ฐ 224x224์ธ ๊ธฐ๋ณธ ํฌ๊ธฐ ์ํคํ ์ฒ๋ฅผ ๋ํ๋ ๋๋ค. ๋ชจ๋ ์ฒดํฌํฌ์ธํธ๋ hub์์ ํ์ธํ ์ ์์ต๋๋ค. - ์ฌ์ฉํ ์ ์๋ ์ฒดํฌํฌ์ธํธ๋ (1) ImageNet-21k (1,400๋ง ๊ฐ์ ์ด๋ฏธ์ง์ 21,000๊ฐ์ ํด๋์ค)์์๋ง ์ฌ์ ํ์ต๋์๊ฑฐ๋, ๋๋ (2) ImageNet (ILSVRC 2012, 130๋ง ๊ฐ์ ์ด๋ฏธ์ง์ 1,000๊ฐ์ ํด๋์ค)์์ ์ถ๊ฐ๋ก ๋ฏธ์ธ ์กฐ์ ๋ ๊ฒฝ์ฐ์ ๋๋ค.
- Vision Transformer๋ 224x224 ํด์๋๋ก ์ฌ์ ํ์ต๋์์ต๋๋ค. ๋ฏธ์ธ ์กฐ์ ์, ์ฌ์ ํ์ต๋ณด๋ค ๋ ๋์ ํด์๋๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ์ ๋ฆฌํ ๊ฒฝ์ฐ๊ฐ ๋ง์ต๋๋ค ((Touvron et al., 2019), (Kolesnikovet al., 2020). ๋ ๋์ ํด์๋๋ก ๋ฏธ์ธ ์กฐ์ ํ๊ธฐ ์ํด, ์ ์๋ค์ ์๋ณธ ์ด๋ฏธ์ง์์์ ์์น์ ๋ฐ๋ผ ์ฌ์ ํ์ต๋ ์์น ์๋ฒ ๋ฉ์ 2D ๋ณด๊ฐ(interpolation)์ ์ํํฉ๋๋ค.
- ์ต๊ณ ์ ๊ฒฐ๊ณผ๋ supervised ๋ฐฉ์์ ์ฌ์ ํ์ต์์ ์ป์ด์ก์ผ๋ฉฐ, ์ด๋ NLP์์๋ ํด๋น๋์ง ์๋ ๊ฒฝ์ฐ๊ฐ ๋ง์ต๋๋ค. ์ ์๋ค์ ๋ง์คํฌ๋ ํจ์น ์์ธก(๋ง์คํฌ๋ ์ธ์ด ๋ชจ๋ธ๋ง์์ ์๊ฐ์ ๋ฐ์ self-supervised ์ฌ์ ํ์ต ๋ชฉํ)์ ์ฌ์ฉํ ์คํ๋ ์ํํ์ต๋๋ค. ์ด ์ ๊ทผ ๋ฐฉ์์ผ๋ก ๋ ์์ ViT-B/16 ๋ชจ๋ธ์ ImageNet์์ 79.9%์ ์ ํ๋๋ฅผ ๋ฌ์ฑํ์์ผ๋ฉฐ, ์ด๋ ์ฒ์๋ถํฐ ํ์ตํ ๊ฒ๋ณด๋ค 2% ๊ฐ์ ๋ ๊ฒฐ๊ณผ์ด์ง๋ง, ์ฌ์ ํ supervised ์ฌ์ ํ์ต๋ณด๋ค 4% ๋ฎ์ต๋๋ค.
Scaled Dot Product Attention (SDPA) ์ฌ์ฉํ๊ธฐ [[using-scaled-dot-product-attention-sdpa]]
PyTorch๋ torch.nn.functional์ ์ผ๋ถ๋ก์ native scaled dot-product attention (SDPA) ์ฐ์ฐ์๋ฅผ ํฌํจํ๊ณ ์์ต๋๋ค. ์ด ํจ์๋ ์
๋ ฅ ๋ฐ ์ฌ์ฉ ์ค์ธ ํ๋์จ์ด์ ๋ฐ๋ผ ์ฌ๋ฌ ๊ตฌํ ๋ฐฉ์์ ์ ์ฉํ ์ ์์ต๋๋ค.์์ธํ ๋ด์ฉ์ ๊ณต์ ๋ฌธ์๋ GPU ์ถ๋ก ํ์ด์ง๋ฅผ ์ฐธ์กฐํ์ญ์์ค.
SDPA๋ torch>=2.1.1์์ ๊ตฌํ์ด ๊ฐ๋ฅํ ๊ฒฝ์ฐ ๊ธฐ๋ณธ์ ์ผ๋ก ์ฌ์ฉ๋์ง๋ง, from_pretrained()์์ attn_implementation="sdpa"๋ก ์ค์ ํ์ฌ SDPA๋ฅผ ๋ช
์์ ์ผ๋ก ์์ฒญํ ์๋ ์์ต๋๋ค.
from transformers import ViTForImageClassification
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", attn_implementation="sdpa", torch_dtype=torch.float16)
...
์ต์ ์ ์๋ ํฅ์์ ์ํด ๋ชจ๋ธ์ ๋ฐ์ ๋ฐ๋(์: torch.float16 ๋๋ torch.bfloat16)๋ก ๋ก๋ํ๋ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค.
๋ก์ปฌ ๋ฒค์น๋งํฌ(A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04)์์ float32์ google/vit-base-patch16-224 ๋ชจ๋ธ์ ์ฌ์ฉํ ์ถ๋ก ์, ๋ค์๊ณผ ๊ฐ์ ์๋ ํฅ์์ ํ์ธํ์ต๋๋ค.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|---|---|---|---|
| 1 | 7 | 6 | 1.17 |
| 2 | 8 | 6 | 1.33 |
| 4 | 8 | 6 | 1.33 |
| 8 | 8 | 6 | 1.33 |
๋ฆฌ์์ค [[resources]]
ViT์ ์ถ๋ก ๋ฐ ์ปค์คํ ๋ฐ์ดํฐ์ ๋ํ ๋ฏธ์ธ ์กฐ์ ๊ณผ ๊ด๋ จ๋ ๋ฐ๋ชจ ๋ ธํธ๋ถ์ ์ฌ๊ธฐ์์ ํ์ธํ ์ ์์ต๋๋ค. Hugging Face์์ ๊ณต์์ ์ผ๋ก ์ ๊ณตํ๋ ์๋ฃ์ ์ปค๋ฎค๋ํฐ(๐๋ก ํ์๋) ์๋ฃ ๋ชฉ๋ก์ ViT๋ฅผ ์์ํ๋ ๋ฐ ๋์์ด ๋ ๊ฒ์ ๋๋ค. ์ด ๋ชฉ๋ก์ ํฌํจ๋ ์๋ฃ๋ฅผ ์ ์ถํ๊ณ ์ถ๋ค๋ฉด Pull Request๋ฅผ ์ด์ด ์ฃผ์๋ฉด ๊ฒํ ํ๊ฒ ์ต๋๋ค. ์๋ก์ด ๋ด์ฉ์ ์ค๋ช ํ๋ ์๋ฃ๊ฐ ๊ฐ์ฅ ์ด์์ ์ด๋ฉฐ, ๊ธฐ์กด ์๋ฃ๋ฅผ ์ค๋ณตํ์ง ์๋๋ก ํด์ฃผ์ญ์์ค.
ViTForImageClassification ์ ๋ค์์์ ์ง์๋ฉ๋๋ค:
- Hugging Face Transformers๋ก ViT๋ฅผ ์ด๋ฏธ์ง ๋ถ๋ฅ์ ๋ง๊ฒ ๋ฏธ์ธ ์กฐ์ ํ๋ ๋ฐฉ๋ฒ์ ๋ํ ๋ธ๋ก๊ทธ ํฌ์คํธ
- Hugging Face Transformers์
Keras๋ฅผ ์ฌ์ฉํ ์ด๋ฏธ์ง ๋ถ๋ฅ์ ๋ํ ๋ธ๋ก๊ทธ ํฌ์คํธ - Hugging Face Transformers๋ฅผ ์ฌ์ฉํ ์ด๋ฏธ์ง ๋ถ๋ฅ ๋ฏธ์ธ ์กฐ์ ์ ๋ํ ๋ ธํธ๋ถ
- Hugging Face Trainer๋ก CIFAR-10์์ Vision Transformer ๋ฏธ์ธ ์กฐ์ ์ ๋ํ ๋ ธํธ๋ถ
- PyTorch Lightning์ผ๋ก CIFAR-10์์ Vision Transformer ๋ฏธ์ธ ์กฐ์ ์ ๋ํ ๋ ธํธ๋ถ
โ๏ธ ์ต์ ํ
- Optimum์ ์ฌ์ฉํ ์์ํ๋ฅผ ํตํด Vision Transformer(ViT) ๊ฐ์์ ๋ํ ๋ธ๋ก๊ทธ ํฌ์คํธ
โก๏ธ ์ถ๋ก
- Google Brain์ Vision Transformer(ViT) ๋น ๋ฅธ ๋ฐ๋ชจ์ ๋ํ ๋ ธํธ๋ถ
๐ ๋ฐฐํฌ
- TF Serving์ผ๋ก Hugging Face์์ Tensorflow Vision ๋ชจ๋ธ ๋ฐฐํฌ์ ๋ํ ๋ธ๋ก๊ทธ ํฌ์คํธ
- Vertex AI์์ Hugging Face ViT ๋ฐฐํฌ์ ๋ํ ๋ธ๋ก๊ทธ ํฌ์คํธ
- TF Serving์ ์ฌ์ฉํ์ฌ Kubernetes์์ Hugging Face ViT ๋ฐฐํฌ์ ๋ํ ๋ธ๋ก๊ทธ ํฌ์คํธ
ViTConfig [[transformers.ViTConfig]]
[[autodoc]] ViTConfig
ViTFeatureExtractor [[transformers.ViTFeatureExtractor]]
[[autodoc]] ViTFeatureExtractor - call
ViTImageProcessor [[transformers.ViTImageProcessor]]
[[autodoc]] ViTImageProcessor - preprocess
ViTImageProcessorFast [[transformers.ViTImageProcessorFast]]
[[autodoc]] ViTImageProcessorFast - preprocess
ViTModel [[transformers.ViTModel]]
[[autodoc]] ViTModel - forward
ViTForMaskedImageModeling [[transformers.ViTForMaskedImageModeling]]
[[autodoc]] ViTForMaskedImageModeling - forward
ViTForImageClassification [[transformers.ViTForImageClassification]]
[[autodoc]] ViTForImageClassification - forward
TFViTModel [[transformers.TFViTModel]]
[[autodoc]] TFViTModel - call
TFViTForImageClassification [[transformers.TFViTForImageClassification]]
[[autodoc]] TFViTForImageClassification - call
FlaxVitModel [[transformers.FlaxViTModel]]
[[autodoc]] FlaxViTModel - call
FlaxViTForImageClassification [[transformers.FlaxViTForImageClassification]]
[[autodoc]] FlaxViTForImageClassification - call