Upload 4 files
- README.md +188 -6
- app.py +406 -0
- requirements.txt +4 -0
- transformer.py +183 -0
README.md CHANGED
@@ -1,13 +1,195 @@
 ---
-title: Transformer
+title: Transformer Demo
-emoji:
+emoji: 🤖
 colorFrom: blue
-colorTo:
+colorTo: indigo
 sdk: gradio
-sdk_version:
+sdk_version: 5.29.0
-python_version:
+python_version: "3.10"
 app_file: app.py
 pinned: false
+license: mit
 ---
 
-
+# Transformer: Paper Reproduction Demo
+
+**Paper**: [Attention Is All You Need](https://arxiv.org/abs/1706.03762) (Vaswani et al., NIPS 2017)
+
+> This Space reproduces from scratch the Transformer paper, which discards RNNs and CNNs
+> and builds the encoder-decoder **from attention alone**, and lets you try the trained model yourself.
+
+---
+
+## What can it do?
+
+Enter a sequence of digits and **the Transformer reverses it**.
+
+```
+Input  : 1 2 3 4 5
+Output : 5 4 3 2 1
+```
+
+Even more interesting, the decoder's **cross-attention weights** are visualized,
+so you can see directly which part of the input the model looked at when producing output position i.
+On the reversal task, a distinct **anti-diagonal pattern** appears.
+
+---
+
+## Why digit reversal instead of translation?
+
+The paper validated the model on English→German translation, but that takes 12 hours of training on 8× P100 GPUs.
+That is not possible on a free Space, so we chose a **toy task that finishes training within 30 seconds at boot**.
+
+Advantages of digit reversal:
+
+- Small vocabulary (digits 0-9 + special tokens = 13 tokens)
+- Input and output have the same length, and the correct answer is unambiguous
+- Forces **long-range dependencies**: the first output must look at the last input
+- Dramatic visualization (anti-diagonal pattern)
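The 13-token vocabulary listed above can be sketched in a few lines. This is a minimal illustration, assuming three special tokens (padding, start, end) with ids 0 to 2 and digits shifted by 3, which matches the constants in `app.py`; the helper name `encode_digits` is hypothetical and not part of the repository.

```python
# Special tokens take ids 0..2; digit d maps to id d + 3 (same constants as app.py).
PAD, BOS, EOS = 0, 1, 2
DIGIT_OFFSET = 3
VOCAB_SIZE = DIGIT_OFFSET + 10  # 3 special tokens + digits 0-9


def encode_digits(digits):
    """Hypothetical helper: wrap a list of digits in BOS/EOS token ids."""
    return [BOS] + [d + DIGIT_OFFSET for d in digits] + [EOS]


source = encode_digits([1, 2, 3])
target = encode_digits([1, 2, 3][::-1])  # the reversal task's expected output
print(VOCAB_SIZE, source, target)
```

Because the target is just the source digits in reverse order, the first output token depends on the last input token, which is the long-range dependency the list above refers to.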
+
+---
+
+## Project structure
+
+```
+├── app.py              # Gradio demo (training + inference + visualization)
+├── transformer.py      # Transformer implementation reproducing the paper as-is
+├── requirements.txt    # Package list
+└── README.md           # This file
+```
+
+---
+
+## Model configuration
+
+This demo is **1/8 the size** of the paper's base model in width (d_model 512 → 64). The architecture is identical; only the size is reduced.
+
+| Item | Paper base | This demo |
+|------|-----------|---------|
+| d_model | 512 | **64** |
+| Layers N | 6 | **2** |
+| Heads h | 8 | **4** |
+| d_ff | 2048 | **128** |
+| Vocabulary size | 37K (BPE) | **13** |
+| Parameters | 65M | **~80K** |
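The parameter figure in the table can be sanity-checked with plain arithmetic. The sketch below is a back-of-the-envelope count, assuming the layer layout visible in `transformer.py` (four `d_model × d_model` projections with biases per attention block, a two-layer feed-forward network, and one LayerNorm per sub-layer). Embeddings and the output projection are left out because the `Transformer` class itself is not shown in this excerpt. All function names here are hypothetical.

```python
def linear(n_in, n_out):
    """Weights plus biases of a dense layer."""
    return n_in * n_out + n_out


def attention(d_model):
    """Four projections: W_q, W_k, W_v, W_o."""
    return 4 * linear(d_model, d_model)


def feed_forward(d_model, d_ff):
    return linear(d_model, d_ff) + linear(d_ff, d_model)


def layer_norm(d_model):
    """One scale vector and one shift vector."""
    return 2 * d_model


def encoder_layer(d_model, d_ff):
    return attention(d_model) + feed_forward(d_model, d_ff) + 2 * layer_norm(d_model)


def decoder_layer(d_model, d_ff):
    # Self-attention + cross-attention + feed-forward, each followed by a LayerNorm.
    return 2 * attention(d_model) + feed_forward(d_model, d_ff) + 3 * layer_norm(d_model)


def layer_stack_params(d_model, n_layers, d_ff):
    """Encoder and decoder stacks only; embeddings and output layer excluded."""
    return n_layers * (encoder_layer(d_model, d_ff) + decoder_layer(d_model, d_ff))


print(layer_stack_params(d_model=64, n_layers=2, d_ff=128))
```

Under these assumptions, one encoder layer plus one decoder layer comes to 83,712 parameters, and the two-layer stacks to 167,424. The `~80K` in the table therefore looks closer to a single layer pair than to the full N=2 model, so treat it as approximate.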
+
+---
+
+## Training setup
+
+```python
+optimizer = Adam(lr=5e-4, betas=(0.9, 0.98), eps=1e-9)   # paper §5.3
+loss      = CrossEntropy(ignore_index=PAD, label_smoothing=0.1)
+steps     = 2000
+batch     = 128
+```
+
+- Fresh random digit sequences of length 3-10 are generated at every step (saves memory)
+- Gradient clipping = 1.0
+- Inference uses greedy decoding
+
+Training runs automatically at boot, and the finished model is cached as `model.pt`.
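One training example under the settings above might be assembled as follows. This is a minimal sketch, assuming the BOS/EOS ids from `app.py` and the usual teacher-forcing shift, where the decoder input drops the last token and the prediction target drops the first. `make_example` is a hypothetical name, and the real code builds padded batches of tensors rather than single lists.

```python
import random

PAD, BOS, EOS = 0, 1, 2
DIGIT_OFFSET = 3


def make_example(min_len=3, max_len=10, rng=random):
    """Hypothetical helper: one (source, decoder input, decoder target) triple."""
    length = rng.randint(min_len, max_len)               # inclusive on both ends
    digits = [rng.randint(0, 9) for _ in range(length)]
    src = [BOS] + [d + DIGIT_OFFSET for d in digits] + [EOS]
    tgt = [BOS] + [d + DIGIT_OFFSET for d in reversed(digits)] + [EOS]
    # Teacher forcing: the decoder is fed tgt[:-1] and must predict tgt[1:].
    return src, tgt[:-1], tgt[1:]


src, dec_in, dec_out = make_example()
print(src, dec_in, dec_out)
```

Generating examples on the fly like this means no dataset has to be stored, which is the memory saving the list above mentions.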
+
+---
+
+## Mapping key parts of the paper to code
+
+| Paper location | Code location |
+|-----------|-----------|
+| Eq. (1) `softmax(QKᵀ/√d_k)V` | `transformer.py :: scaled_dot_product_attention` |
+| §3.2.2 Multi-Head | `MultiHeadAttention` |
+| §3.5 Positional Encoding | `PositionalEncoding` |
+| Eq. (2) FFN | `FeedForward` |
+| §3.1 One encoder layer | `EncoderLayer` (Post-LN) |
+| §3.1 One decoder layer | `DecoderLayer` (Post-LN) |
+| §3.4 Embedding × √d_model | inside `Transformer.encode` |
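As an illustration of the §3.5 row in the table above, the sinusoidal encoding for a single position can be written in plain Python. This is a sketch of the formula only, not the repository's tensor implementation; the function name `positional_encoding` is an assumption for this example.

```python
import math


def positional_encoding(pos, d_model):
    """Sinusoidal encoding of one position (paper section 3.5)."""
    pe = []
    for i in range(0, d_model, 2):               # i is the even dimension index
        angle = pos / (10000 ** (i / d_model))
        pe.append(math.sin(angle))               # even dimension
        pe.append(math.cos(angle))               # odd dimension
    return pe[:d_model]


print(positional_encoding(0, 4))
print(positional_encoding(1, 4))
```

The divisor grows with the dimension index, so higher dimensions oscillate more slowly as the position increases.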
+
+---
+
+## How should I read it? (Interpreting the visualization)
+
+**Cross-attention heatmap**:
+
+- Horizontal axis: encoder positions (input tokens; left is the start of the sequence)
+- Vertical axis: decoder positions (output tokens; top is generated first)
+- The brighter the color, the stronger the attention
+
+On the reversal task, a well-trained model behaves as follows:
+
+```
+Output position 0 (right after BOS, the first output token) → looks at the last input token
+Output position 1                                           → looks at the second-to-last input token
+...
+```
+
+So instead of the **top-left → bottom-right diagonal**, you should see the opposite direction:
+an **anti-diagonal running from top-right to bottom-left** means success.
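The pattern described above can be made concrete with an idealized attention map. The sketch below assumes the source is laid out as `[BOS, d_1, ..., d_n, EOS]` and that each decoder step puts all of its attention on exactly one source position. Real weights are soft and averaged over heads, so this is only the limiting case. `ideal_reversal_attention` is a hypothetical helper.

```python
def ideal_reversal_attention(n_digits):
    """Hypothetical one-hot attention map for a perfectly reversing model.

    Rows are the decoder steps that emit a digit; columns are encoder
    positions for the source [BOS, d_1, ..., d_n, EOS].
    """
    n_cols = n_digits + 2                       # BOS + digits + EOS
    rows = []
    for step in range(n_digits):
        row = [0.0] * n_cols
        row[n_digits - step] = 1.0              # step 0 looks at the last digit
        rows.append(row)
    return rows


for row in ideal_reversal_attention(3):
    print(row)
```

Each successive row moves its single non-zero entry one column to the left, which is the top-right to bottom-left line to look for in the heatmap.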
+
+---
+
+## Notes on deploying to Hugging Face Spaces
+
+The problems encountered when deploying the ResNet demo can occur here as well:
+
+### 1. YAML front matter is required
+
+Without the `--- ... ---` block at the very top of this README.md, the Space will not build.
+
+### 2. `colorFrom`/`colorTo` accept only 8 fixed colors
+
+Allowed colors: `red, yellow, green, blue, indigo, purple, pink, gray`
+
+### 3. Avoid Python 3.13
+
+The `audioop` standard-library module was removed in 3.13, so some packages fail to build. **3.10** is recommended.
+
+### 4. PyTorch CPU build
+
+By default, free Spaces provide CPU only. If the CUDA build gets pulled in when installing `torch`,
+it can exceed the disk quota, so specify `torch --index-url https://download.pytorch.org/whl/cpu`
+explicitly if needed.
+
+---
+
+## Running locally
+
+```bash
+# 1) Install dependencies
+pip install -r requirements.txt
+
+# 2) Run the demo (trains automatically on first run)
+python app.py
+```
+
+By default it opens at `http://127.0.0.1:7860`.
+
+---
+
+## If training does not go well
+
+Checklist:
+
+- [ ] Is the PyTorch version 2.0 or later?
+- [ ] Does training run for at least 2000 steps? (check the step 200, 400, ... logs in the console)
+- [ ] Is `token_acc` at least 0.95 by around step 1000?
+- [ ] If the output always repeats the same token → the model has barely trained. Increase the steps or adjust lr
+- [ ] If the cross-attention is uniform → more training is needed
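The `token_acc` value in the checklist above is the fraction of non-padding target tokens predicted correctly. The sketch below shows that calculation on plain lists, assuming padding id 0 as in `app.py`. The function name `token_accuracy` is hypothetical; the training loop computes the same quantity on tensors.

```python
PAD = 0


def token_accuracy(predicted, target, pad_id=PAD):
    """Fraction of non-padding target tokens that were predicted correctly."""
    correct = total = 0
    for pred_row, tgt_row in zip(predicted, target):
        for p, t in zip(pred_row, tgt_row):
            if t == pad_id:
                continue                         # padding does not count
            total += 1
            correct += int(p == t)
    return correct / total if total else 0.0


predicted = [[6, 5, 4, 2], [9, 3, 7, 7]]
target = [[6, 5, 4, 2], [9, 2, 0, 0]]            # second row is padded
print(token_accuracy(predicted, target))
```

Excluding padding matters because batches mix sequence lengths; counting pad positions would distort the score for short sequences.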
+
+---
+
+## References
+
+```bibtex
+@inproceedings{vaswani2017attention,
+  title     = {Attention Is All You Need},
+  author    = {Vaswani, Ashish and Shazeer, Noam and Parmar, Niki
+               and Uszkoreit, Jakob and Jones, Llion and Gomez, Aidan N
+               and Kaiser, {\L}ukasz and Polosukhin, Illia},
+  booktitle = {Advances in Neural Information Processing Systems},
+  year      = {2017}
+}
+```
+
+- Paper: [arXiv:1706.03762](https://arxiv.org/abs/1706.03762)
+- The Annotated Transformer: <http://nlp.seas.harvard.edu/annotated-transformer/>
+- The Illustrated Transformer (Jay Alammar): <https://jalammar.github.io/illustrated-transformer/>
app.py ADDED
@@ -0,0 +1,406 @@
+"""
+Transformer demo: digit sequence reversal
+
+Components:
+  1) Train a small Transformer on the spot at boot (~30 s)
+  2) Reverse the digit sequence the user enters in the Gradio UI
+  3) Visualize the decoder's cross-attention → a nice anti-diagonal pattern
+
+This task is simple, but it shows most intuitively how a Transformer
+learns interactions between positions.
+"""
+
+import os
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+import gradio as gr
+
+from transformer import Transformer
+
+
+# ─────────────────────────────────────────────────────────────
+# Token definitions
+# ─────────────────────────────────────────────────────────────
+PAD, BOS, EOS = 0, 1, 2
+DIGIT_OFFSET = 3                 # digit d → token id (d + 3)
+VOCAB = DIGIT_OFFSET + 10        # digits 0-9 → 13 tokens in total
+ID2STR = {PAD: "<P>", BOS: "<S>", EOS: "<E>"}
+for d in range(10):
+    ID2STR[d + DIGIT_OFFSET] = str(d)
+
+MAX_INPUT_LEN = 10               # upper limit on the number of user input digits
+MAX_DECODE_LEN = MAX_INPUT_LEN + 2
+
+# ─────────────────────────────────────────────────────────────
+# Model hyperparameters (small)
+# ─────────────────────────────────────────────────────────────
+D_MODEL = 64
+N_LAYERS = 2
+N_HEADS = 4
+D_FF = 128
+DROPOUT = 0.1
+
+CKPT_PATH = "model.pt"
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+# ─────────────────────────────────────────────────────────────
+# Tokenization utilities
+# ─────────────────────────────────────────────────────────────
+def digits_to_ids(digits, add_bos_eos=True):
+    ids = [d + DIGIT_OFFSET for d in digits]
+    if add_bos_eos:
+        ids = [BOS] + ids + [EOS]
+    return ids
+
+
+def ids_to_digits(ids, stop_at_eos=True):
+    out = []
+    for i in ids:
+        if stop_at_eos and i == EOS:
+            break
+        if i >= DIGIT_OFFSET:
+            out.append(i - DIGIT_OFFSET)
+    return out
+
+
+def parse_user_input(text):
+    """Extract digits from the user's input string. Spaces, commas, etc. are all allowed."""
+    digits = []
+    for ch in text:
+        if ch.isdigit():
+            digits.append(int(ch))
+    return digits[:MAX_INPUT_LEN]
+
+
+# ─────────────────────────────────────────────────────────────
+# Mask construction
+# ─────────────────────────────────────────────────────────────
+def make_src_mask(src):
+    # (B, S) → (B, 1, 1, S)
+    return (src != PAD).unsqueeze(1).unsqueeze(2)
+
+
+def make_tgt_mask(tgt):
+    """Combine the padding mask with the causal mask."""
+    B, T = tgt.shape
+    pad_mask = (tgt != PAD).unsqueeze(1).unsqueeze(2)            # (B, 1, 1, T)
+    causal = torch.tril(torch.ones(T, T, device=tgt.device)).bool()
+    causal = causal.unsqueeze(0).unsqueeze(0)                    # (1, 1, T, T)
+    return pad_mask & causal
+
+
+# ─────────────────────────────────────────────────────────────
+# Training data generator
+# ─────────────────────────────────────────────────────────────
+def make_batch(batch_size=128, min_len=3, max_len=8):
+    src_list, tgt_list = [], []
+    for _ in range(batch_size):
+        L = np.random.randint(min_len, max_len + 1)
+        digits = np.random.randint(0, 10, size=L).tolist()
+        src_list.append(digits_to_ids(digits))
+        tgt_list.append(digits_to_ids(digits[::-1]))
+
+    s_max = max(len(s) for s in src_list)
+    t_max = max(len(t) for t in tgt_list)
+    src = torch.tensor([s + [PAD] * (s_max - len(s)) for s in src_list])
+    tgt = torch.tensor([t + [PAD] * (t_max - len(t)) for t in tgt_list])
+    return src, tgt
+
+
+# ─────────────────────────────────────────────────────────────
+# Training loop
+# ─────────────────────────────────────────────────────────────
+def train(model, steps=2000, batch_size=128, lr=5e-4, log_every=200):
+    model.train()
+    opt = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-9)
+    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD, label_smoothing=0.1)
+
+    print(f"[train] device={DEVICE}, steps={steps}, batch={batch_size}")
+    for step in range(1, steps + 1):
+        src, tgt = make_batch(batch_size, min_len=3, max_len=MAX_INPUT_LEN)
+        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
+        tgt_in, tgt_out = tgt[:, :-1], tgt[:, 1:]
+
+        src_mask = make_src_mask(src)
+        tgt_mask = make_tgt_mask(tgt_in)
+
+        logits = model(src, tgt_in, src_mask, tgt_mask)
+        loss = loss_fn(logits.reshape(-1, VOCAB), tgt_out.reshape(-1))
+
+        opt.zero_grad()
+        loss.backward()
+        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+        opt.step()
+
+        if step % log_every == 0 or step == 1:
+            with torch.no_grad():
+                pred = logits.argmax(-1)
+                mask = (tgt_out != PAD)
+                acc = ((pred == tgt_out) & mask).sum().item() / mask.sum().item()
+            print(f"  step {step:4d}  loss={loss.item():.4f}  token_acc={acc:.3f}")
+    return model
+
+
+# ─────────────────────────────────────────────────────────────
+# Inference (greedy decoding)
+# ─────────────────────────────────────────────────────────────
+@torch.no_grad()
+def greedy_decode(model, src_ids):
+    """src_ids: list[int] (including BOS and EOS)"""
+    model.eval()
+    src = torch.tensor([src_ids], device=DEVICE)
+    src_mask = make_src_mask(src)
+    enc_out = model.encode(src, src_mask)
+
+    ys = torch.tensor([[BOS]], device=DEVICE)
+    for _ in range(MAX_DECODE_LEN):
+        tgt_mask = make_tgt_mask(ys)
+        dec_out = model.decode(ys, enc_out, src_mask, tgt_mask)
+        logits = model.out(dec_out)
+        next_tok = logits[:, -1].argmax(-1, keepdim=True)
+        ys = torch.cat([ys, next_tok], dim=1)
+        if next_tok.item() == EOS:
+            break
+    return ys[0].tolist()
+
+
+# ─────────────────────────────────────────────────────────────
+# Attention visualization
+# ─────────────────────────────────────────────────────────────
+def plot_cross_attention(model, src_ids, tgt_ids, layer_idx=-1):
+    """Plot the decoder cross-attention averaged over heads."""
+    attn = model.get_decoder_cross_attn(layer_idx)   # (1, h, T, S)
+    if attn is None:
+        return None
+    attn_avg = attn.mean(dim=1)[0].cpu().numpy()     # (T, S)
+
+    src_labels = [ID2STR[i] for i in src_ids]
+    tgt_labels = [ID2STR[i] for i in tgt_ids]
+
+    fig, ax = plt.subplots(figsize=(7, 6))
+    im = ax.imshow(attn_avg, cmap="viridis", aspect="auto", vmin=0, vmax=attn_avg.max())
+    ax.set_xticks(range(len(src_labels)))
+    ax.set_xticklabels(src_labels)
+    ax.set_yticks(range(len(tgt_labels)))
+    ax.set_yticklabels(tgt_labels)
+    ax.set_xlabel("Encoder positions (input)")
+    ax.set_ylabel("Decoder positions (output)")
+    ax.set_title(f"Decoder → Encoder Cross-Attention\n(layer {layer_idx}, heads averaged)")
+    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
+    plt.tight_layout()
+    return fig
+
+
+def plot_positional_encoding(model, length=20):
+    """Visualize the positional encoding directly."""
+    pe = model.pe.pe[0, :length].cpu().numpy()       # (L, d_model)
+    fig, ax = plt.subplots(figsize=(7, 4))
+    im = ax.imshow(pe, cmap="RdBu", aspect="auto", vmin=-1, vmax=1)
+    ax.set_xlabel("Embedding dimension")
+    ax.set_ylabel("Position")
+    ax.set_title("Positional Encoding (sin/cos pattern)")
+    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
+    plt.tight_layout()
+    return fig
+
+
+# ─────────────────────────────────────────────────────────────
+# Model setup (train or load)
+# ─────────────────────────────────────────────────────────────
+def build_model():
+    model = Transformer(
+        src_vocab=VOCAB,
+        tgt_vocab=VOCAB,
+        d_model=D_MODEL,
+        N=N_LAYERS,
+        h=N_HEADS,
+        d_ff=D_FF,
+        dropout=DROPOUT,
+        max_len=64,
+    ).to(DEVICE)
+    return model
+
+
+def load_or_train():
+    model = build_model()
+    if os.path.exists(CKPT_PATH):
+        print(f"[init] loading checkpoint: {CKPT_PATH}")
+        state = torch.load(CKPT_PATH, map_location=DEVICE)
+        model.load_state_dict(state)
+    else:
+        print("[init] no checkpoint found, training from scratch")
+        train(model, steps=2000, batch_size=128, lr=5e-4)
+        torch.save(model.state_dict(), CKPT_PATH)
+        print(f"[init] saved checkpoint: {CKPT_PATH}")
+    model.eval()
+    return model
+
+
+print("=" * 60)
+print("Transformer Demo: initializing")
+print("=" * 60)
+MODEL = load_or_train()
+PE_FIG = plot_positional_encoding(MODEL, length=20)
+print("[init] ready ✓")
+
+
+# ─────────────────────────────────────────────────────────────
+# Gradio callbacks
+# ─────────────────────────────────────────────────────────────
+def run_inference(user_text):
+    digits = parse_user_input(user_text)
+    if len(digits) == 0:
+        return "Please enter some digits.", None, "(no attention)"
+    if len(digits) > MAX_INPUT_LEN:
+        digits = digits[:MAX_INPUT_LEN]
+
+    src_ids = digits_to_ids(digits)
+    out_ids = greedy_decode(MODEL, src_ids)
+    pred_digits = ids_to_digits(out_ids[1:])   # skip BOS, stop at EOS
+
+    expected = digits[::-1]
+    correct = pred_digits == expected
+
+    pred_str = " ".join(str(d) for d in pred_digits) if pred_digits else "(empty output)"
+    expected_str = " ".join(str(d) for d in expected)
+    input_str = " ".join(str(d) for d in digits)
+
+    msg = (
+        f"**Input** : {input_str}\n\n"
+        f"**Predicted output** : {pred_str}\n\n"
+        f"**Expected** : {expected_str}\n\n"
+        f"**Match** : {'✅ Correct!' if correct else '❌ Incorrect'}"
+    )
+
+    # Run a forward pass again for visualization (refreshes the stored cross-attention)
+    src = torch.tensor([src_ids], device=DEVICE)
+    tgt = torch.tensor([out_ids], device=DEVICE)
+    src_mask = make_src_mask(src)
+    tgt_mask = make_tgt_mask(tgt)
+    with torch.no_grad():
+        MODEL(src, tgt, src_mask, tgt_mask)
+
+    fig = plot_cross_attention(MODEL, src_ids, out_ids, layer_idx=-1)
+    info = (
+        "This heatmap shows the cross-attention weights of the decoder's last layer (averaged over heads).\n"
+        "Each row (output position) shows which input positions it attends to most.\n"
+        "On the reversal task, an **anti-diagonal** pattern means training succeeded!"
+    )
+    return msg, fig, info
+
+
+# ─────────────────────────────────────────────────────────────
+# UI layout
+# ─────────────────────────────────────────────────────────────
+with gr.Blocks(title="Transformer Demo: Digit Reversal") as demo:
+    gr.Markdown(
+        """
+        # 🤖 Transformer Demo: Digit Sequence Reversal
+
+        This demo reverses an input digit sequence with a Transformer that reproduces
+        Vaswani et al. (2017) **"Attention Is All You Need"** from scratch.
+
+        - Model: d_model=64, N=2 layers, h=4 heads (~80K parameters in total)
+        - Training data: random digit sequences of length 3-10, freshly generated every step
+        - Training time: ~30 s at boot (on CPU)
+
+        If you see an **anti-diagonal pattern** in the heatmap below, that is evidence the model
+        has learned "output position i = the mirrored position in the input".
+        """
+    )
+
+    with gr.Tab("Reverse"):
+        with gr.Row():
+            with gr.Column(scale=1):
+                inp = gr.Textbox(
+                    label="Digit sequence input (up to 10 digits)",
+                    placeholder="e.g. 1 2 3 4 5 or 12345",
+                    value="1 2 3 4 5 6 7",
+                )
+                btn = gr.Button("Run reversal", variant="primary")
+                out_text = gr.Markdown()
+
+                gr.Examples(
+                    examples=[
+                        ["1 2 3"],
+                        ["1 2 3 4 5"],
+                        ["9 8 7 6 5 4 3"],
+                        ["1 1 2 2 3 3"],
+                        ["0 1 2 3 4 5 6 7 8 9"],
+                    ],
+                    inputs=inp,
+                )
+
+            with gr.Column(scale=2):
+                attn_plot = gr.Plot(label="Cross-Attention Heatmap")
+                attn_info = gr.Markdown()
+
+        btn.click(run_inference, inputs=inp, outputs=[out_text, attn_plot, attn_info])
+
+    with gr.Tab("Positional Encoding"):
+        gr.Markdown(
+            """
+            ### Positional Encoding visualization
+
+            This is a direct plot of the sin/cos positional encoding from §3.5 of the paper.
+            The horizontal axis is the embedding dimension and the vertical axis is the position (time order).
+
+            Even dimensions are filled with sin and odd dimensions with cos, and the higher the dimension, the longer the period.
+            This lets the model express **relative positions** as a linear transformation.
+            """
+        )
+        gr.Plot(value=PE_FIG, label="PE matrix")
+
+    with gr.Tab("About this demo"):
+        gr.Markdown(
+            """
+            ### Why "digit reversal"?
+
+            Real tasks such as translation demand huge amounts of data and compute, which makes them unsuitable for a free Space.
+            **Digit reversal**, by contrast:
+
+            1. Can be learned by a small model (80K parameters) within 1-2 minutes
+            2. Has unambiguous inputs and outputs, so correctness can be judged immediately
+            3. Makes the **cross-attention draw an anti-diagonal pattern**, which visualizes well
+            4. Needs no external data (generated at runtime)
+
+            ### Model architecture
+
+            ```
+            VOCAB(13) → Embedding(64) + PE
+              → 2× EncoderLayer (h=4, d_ff=128)
+              → 2× DecoderLayer (h=4, d_ff=128)
+              → Linear → logits over 13 tokens
+            ```
+
+            The paper uses d_model=512, N=6, h=8, d_ff=2048; this demo is roughly 1/8 of that size.
+            The architecture is identical; only the size is reduced.
+
+            ### Inference
+
+            **Greedy decoding**: at each step, pick the token with the highest probability.
+            Advanced decoding such as beam search is omitted.
+
+            ### Limitations
+
+            - Accuracy drops as the number of digits grows (>10)
+            - Vulnerable to patterns not seen during training (repetition, very long sequences)
+            - It is not a real NMT system, so it cannot handle general natural language
+
+            ### References
+
+            - Paper: [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
+            - The Annotated Transformer: [http://nlp.seas.harvard.edu/annotated-transformer/](http://nlp.seas.harvard.edu/annotated-transformer/)
+            """
+        )
+
+
+if __name__ == "__main__":
+    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
+torch>=2.0
+gradio==5.29.0
+matplotlib>=3.7
+numpy>=1.24
transformer.py ADDED
@@ -0,0 +1,183 @@
+"""
+Transformer model implementation: Vaswani et al. (2017) "Attention Is All You Need"
+
+Following the same philosophy as the ResNet project, this reproduces the paper from start to finish.
+For visualization, each attention module keeps its most recent attention weights.
+"""
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# ─────────────────────────────────────────────────────────────
+# 1) Scaled Dot-Product Attention (paper §3.2.1, Eq. 1)
+# ─────────────────────────────────────────────────────────────
+def scaled_dot_product_attention(Q, K, V, mask=None, return_attn=False):
+    """
+    Attention(Q, K, V) = softmax(QKᵀ / √d_k) V
+    """
+    d_k = Q.size(-1)
+    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
+    if mask is not None:
+        scores = scores.masked_fill(mask == 0, -1e9)
+    attn = F.softmax(scores, dim=-1)
+    out = attn @ V
+    if return_attn:
+        return out, attn
+    return out
+
+
+# ─────────────────────────────────────────────────────────────
+# 2) Multi-Head Attention (paper §3.2.2, Figure 2 right)
+# ─────────────────────────────────────────────────────────────
+class MultiHeadAttention(nn.Module):
+    def __init__(self, d_model=512, h=8):
+        super().__init__()
+        assert d_model % h == 0
+        self.h = h
+        self.d_k = d_model // h
+        self.W_q = nn.Linear(d_model, d_model)
+        self.W_k = nn.Linear(d_model, d_model)
+        self.W_v = nn.Linear(d_model, d_model)
+        self.W_o = nn.Linear(d_model, d_model)
+        # For visualization: attention weights of the last forward pass (B, h, seq_q, seq_k)
+        self.last_attn = None
+
+    def forward(self, Q, K, V, mask=None):
+        B = Q.size(0)
+        Q = self.W_q(Q).view(B, -1, self.h, self.d_k).transpose(1, 2)
+        K = self.W_k(K).view(B, -1, self.h, self.d_k).transpose(1, 2)
+        V = self.W_v(V).view(B, -1, self.h, self.d_k).transpose(1, 2)
+
+        out, attn = scaled_dot_product_attention(Q, K, V, mask, return_attn=True)
+        self.last_attn = attn.detach()
+
+        out = out.transpose(1, 2).contiguous().view(B, -1, self.h * self.d_k)
+        return self.W_o(out)
+
+
| 61 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 62 |
+
# 3) Positional Encoding (λ
Όλ¬Έ Β§3.5)
|
| 63 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 64 |
+
class PositionalEncoding(nn.Module):
|
| 65 |
+
def __init__(self, d_model=512, max_len=5000):
|
| 66 |
+
super().__init__()
|
| 67 |
+
pe = torch.zeros(max_len, d_model)
|
| 68 |
+
pos = torch.arange(0, max_len).unsqueeze(1).float()
|
| 69 |
+
div = torch.exp(torch.arange(0, d_model, 2).float() *
|
| 70 |
+
-(math.log(10000.0) / d_model))
|
| 71 |
+
pe[:, 0::2] = torch.sin(pos * div)
|
| 72 |
+
pe[:, 1::2] = torch.cos(pos * div)
|
| 73 |
+
self.register_buffer("pe", pe.unsqueeze(0))
|
| 74 |
+
|
| 75 |
+
def forward(self, x):
|
| 76 |
+
return x + self.pe[:, :x.size(1)]
|
| 77 |
+
|
| 78 |
+
|
# ─────────────────────────────────────────────────────────────
# 4) Position-wise Feed-Forward (paper §3.3, Eq. 2)
# ─────────────────────────────────────────────────────────────
class FeedForward(nn.Module):
    def __init__(self, d_model=512, d_ff=2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x):
        return self.net(x)


# ─────────────────────────────────────────────────────────────
# 5) Encoder Layer (paper §3.1)
# ─────────────────────────────────────────────────────────────
class EncoderLayer(nn.Module):
    def __init__(self, d_model=512, h=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, h)
        self.ffn = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Post-norm residual blocks: LayerNorm(x + Dropout(Sublayer(x)))
        x = self.norm1(x + self.dropout(self.attn(x, x, x, mask)))
        x = self.norm2(x + self.dropout(self.ffn(x)))
        return x


# ─────────────────────────────────────────────────────────────
# 6) Decoder Layer (paper §3.1)
# ─────────────────────────────────────────────────────────────
class DecoderLayer(nn.Module):
    def __init__(self, d_model=512, h=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, h)   # masked
        self.cross_attn = MultiHeadAttention(d_model, h)  # enc-dec
        self.ffn = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask=None, tgt_mask=None):
        x = self.norm1(x + self.dropout(self.self_attn(x, x, x, tgt_mask)))
        x = self.norm2(x + self.dropout(self.cross_attn(x, enc_out, enc_out, src_mask)))
        x = self.norm3(x + self.dropout(self.ffn(x)))
        return x


# ─────────────────────────────────────────────────────────────
# 7) Full Transformer
# ─────────────────────────────────────────────────────────────
class Transformer(nn.Module):
    def __init__(self, src_vocab, tgt_vocab,
                 d_model=512, N=6, h=8, d_ff=2048, dropout=0.1, max_len=5000):
        super().__init__()
        self.d_model = d_model
        self.src_embed = nn.Embedding(src_vocab, d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab, d_model)
        self.pe = PositionalEncoding(d_model, max_len)
        self.encoder = nn.ModuleList([
            EncoderLayer(d_model, h, d_ff, dropout) for _ in range(N)
        ])
        self.decoder = nn.ModuleList([
            DecoderLayer(d_model, h, d_ff, dropout) for _ in range(N)
        ])
        self.out = nn.Linear(d_model, tgt_vocab)

    def encode(self, src, src_mask=None):
        # Scale embeddings by sqrt(d_model) before adding positional encodings
        e = self.pe(self.src_embed(src) * math.sqrt(self.d_model))
        for layer in self.encoder:
            e = layer(e, src_mask)
        return e

    def decode(self, tgt, enc_out, src_mask=None, tgt_mask=None):
        d = self.pe(self.tgt_embed(tgt) * math.sqrt(self.d_model))
        for layer in self.decoder:
            d = layer(d, enc_out, src_mask, tgt_mask)
        return d

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        enc_out = self.encode(src, src_mask)
        dec_out = self.decode(tgt, enc_out, src_mask, tgt_mask)
        return self.out(dec_out)

    # ── Visualization helpers ────────────────────────────────────
    def get_decoder_cross_attn(self, layer_idx=-1):
        """Return the decoder cross-attention weights from the last forward pass.

        Returns: (B, h, tgt_len, src_len)
        """
        return self.decoder[layer_idx].cross_attn.last_attn

    def get_encoder_self_attn(self, layer_idx=-1):
        """Return the encoder self-attention weights from the last forward pass.

        Returns: (B, h, src_len, src_len)
        """
        return self.encoder[layer_idx].attn.last_attn
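
As a rough cross-check of the sinusoidal table that `PositionalEncoding` builds, the sketch below recomputes the same values in plain Python, without PyTorch. The function name `sinusoidal_pe` and the tiny sizes are illustrative assumptions and are not part of `transformer.py`; like the class above, it assumes an even `d_model`.

```python
import math


def sinusoidal_pe(max_len, d_model):
    """Return a max_len x d_model table of sinusoidal position encodings.

    Hypothetical helper for illustration only. It mirrors the formula used
    by PositionalEncoding: even columns hold sin, odd columns hold cos.
    """
    if d_model % 2 != 0:
        raise ValueError("d_model must be even")
    # One frequency per (sin, cos) column pair: 10000^(-k / d_model), k = 0, 2, 4, ...
    div = [math.exp(-math.log(10000.0) * k / d_model)
           for k in range(0, d_model, 2)]
    table = []
    for pos in range(max_len):
        row = [0.0] * d_model
        for j, freq in enumerate(div):
            row[2 * j] = math.sin(pos * freq)      # even column
            row[2 * j + 1] = math.cos(pos * freq)  # odd column
        table.append(row)
    return table


pe = sinusoidal_pe(max_len=4, d_model=6)
print(len(pe), len(pe[0]))  # table dimensions: max_len rows, d_model columns
print(pe[0])                # position 0: sin(0) and cos(0) alternate
```

Position 0 always encodes to alternating 0 and 1, which makes it a convenient sanity check when comparing against the registered `pe` buffer.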