JangTaeng commited on
Commit
0465ac4
Β·
verified Β·
1 Parent(s): c5f0ba9

Upload 4 files

Browse files
Files changed (4) hide show
  1. README.md +188 -6
  2. app.py +406 -0
  3. requirements.txt +4 -0
  4. transformer.py +183 -0
README.md CHANGED
@@ -1,13 +1,195 @@
1
  ---
2
- title: Transformer
3
- emoji: 🏒
4
  colorFrom: blue
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 6.14.0
8
- python_version: '3.13'
9
  app_file: app.py
10
  pinned: false
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Transformer Demo
3
+ emoji: πŸ€–
4
  colorFrom: blue
5
+ colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 5.29.0
8
+ python_version: "3.10"
9
  app_file: app.py
10
  pinned: false
11
+ license: mit
12
  ---
13
 
14
+ # Transformer β€” λ…Όλ¬Έ μž¬ν˜„ 데λͺ¨
15
+
16
+ **λ…Όλ¬Έ**: [Attention Is All You Need](https://arxiv.org/abs/1706.03762) (Vaswani et al., NIPS 2017)
17
+
18
+ > RNNκ³Ό CNN을 λͺ¨λ‘ 버리고 **였직 attention만으둜** 인코더-디코더λ₯Ό κ΅¬μ„±ν•œ
19
+ > Transformer 논문을 μ²˜μŒλΆ€ν„° μž¬ν˜„ν•˜κ³ , ν•™μŠ΅λœ λͺ¨λΈμ„ 직접 μ²΄ν—˜ν•  수 μžˆλŠ” Spaceμž…λ‹ˆλ‹€.
20
+
21
+ ---
22
+
23
+ ## 무엇을 ν•  수 μžˆλ‚˜μš”?
24
+
25
+ 숫자 μ‹œν€€μŠ€λ₯Ό μž…λ ₯ν•˜λ©΄ **Transformerκ°€ λ’€μ§‘μ–΄** μ€λ‹ˆλ‹€.
26
+
27
+ ```
28
+ μž…λ ₯ : 1 2 3 4 5
29
+ 좜λ ₯ : 5 4 3 2 1
30
+ ```
31
+
32
+ 그리고 더 ν₯미둜운 건 β€” λ””μ½”λ”μ˜ **cross-attention κ°€μ€‘μΉ˜**λ₯Ό μ‹œκ°ν™”ν•΄μ„œ
33
+ λͺ¨λΈμ΄ "좜λ ₯ i번째 μœ„μΉ˜λ₯Ό λ§Œλ“€ λ•Œ μž…λ ₯ μ–΄λ””λ₯Ό λ΄€λŠ”μ§€"λ₯Ό 직접 λ³Ό 수 μžˆλ‹€λŠ” κ±°μ˜ˆμš”.
34
+ λ’€μ§‘κΈ° νƒœμŠ€ν¬μ—μ„œλŠ” **λ°˜λŒ€κ°μ„ (anti-diagonal) νŒ¨ν„΄**이 또렷이 λ‚˜νƒ€λ‚©λ‹ˆλ‹€.
35
+
36
+ ---
37
+
38
+ ## μ™œ λ²ˆμ—­μ΄ μ•„λ‹ˆλΌ 숫자 λ’€μ§‘κΈ°μΈκ°€μš”?
39
+
40
+ 논문은 μ˜μ–΄β†’λ…μΌμ–΄ λ²ˆμ—­μœΌλ‘œ κ²€μ¦ν–ˆμ§€λ§Œ, 그건 8Γ— P100 GPU둜 12μ‹œκ°„ ν•™μŠ΅μ΄ ν•„μš”ν•΄μš”.
41
+ 무료 Spaceμ—μ„œ 그게 μ•ˆ λ˜λ‹ˆκΉŒ, **λΆ€νŒ… μ‹œ 30초 μ•ˆμ— ν•™μŠ΅ λλ‚˜λŠ” toy task**λ₯Ό κ³¨λžμŠ΅λ‹ˆλ‹€.
42
+
43
+ 숫자 λ’€μ§‘κΈ°μ˜ μž₯점:
44
+
45
+ - μ–΄νœ˜κ°€ μž‘μŒ (0~9 + 특수 토큰 = 13개)
46
+ - μž…μΆœλ ₯ 길이가 κ°™κ³  정닡이 λͺ…ν™•
47
+ - **μž₯거리 μ˜μ‘΄μ„±**을 κ°•μ œ β€” 좜λ ₯ 1λ²ˆμ§ΈλŠ” μž…λ ₯ λ§ˆμ§€λ§‰μ„ 봐야 함
48
+ - μ‹œκ°ν™”κ°€ 극적 (λ°˜λŒ€κ°μ„  νŒ¨ν„΄)
49
+
50
+ ---
51
+
52
+ ## ν”„λ‘œμ νŠΈ ꡬ쑰
53
+
54
+ ```
55
+ β”œβ”€β”€ app.py # Gradio 데λͺ¨ (ν•™μŠ΅ + μΆ”λ‘  + μ‹œκ°ν™”)
56
+ β”œβ”€β”€ transformer.py # 논문을 κ·ΈλŒ€λ‘œ μž¬ν˜„ν•œ Transformer 본체
57
+ β”œβ”€β”€ requirements.txt # νŒ¨ν‚€μ§€ λͺ©λ‘
58
+ └── README.md # 이 파일
59
+ ```
60
+
61
+ ---
62
+
63
+ ## λͺ¨λΈ ꡬ성
64
+
65
+ 이 데λͺ¨λŠ” λ…Όλ¬Έ base λͺ¨λΈμ˜ **1/8 크기**μž…λ‹ˆλ‹€. κ΅¬μ‘°λŠ” μ™„μ „νžˆ λ™μΌν•˜κ³  크기만 μ€„μ˜€μ–΄μš”.
66
+
67
+ | ν•­λͺ© | λ…Όλ¬Έ base | 이 데λͺ¨ |
68
+ |------|-----------|---------|
69
+ | d_model | 512 | **64** |
70
+ | 측 수 N | 6 | **2** |
71
+ | ν—€λ“œ 수 h | 8 | **4** |
72
+ | d_ff | 2048 | **128** |
73
+ | μ–΄νœ˜ 크기 | 37K (BPE) | **13** |
74
+ | νŒŒλΌλ―Έν„° | 65M | **~80K** |
75
+
76
+ ---
77
+
78
+ ## ν•™μŠ΅ μ„€μ •
79
+
80
+ ```python
81
+ optimizer = Adam(lr=5e-4, betas=(0.9, 0.98), eps=1e-9) # λ…Όλ¬Έ Β§5.3
82
+ loss = CrossEntropy(ignore_index=PAD, label_smoothing=0.1)
83
+ steps = 2000
84
+ batch = 128
85
+ ```
86
+
87
+ - λ§€ stepλ§ˆλ‹€ 길이 3~10의 λ¬΄μž‘μœ„ μˆ«μžμ—΄μ„ μƒˆλ‘œ 생성 (λ©”λͺ¨λ¦¬ μ ˆμ•½)
88
+ - Gradient clipping = 1.0
89
+ - Greedy decoding으둜 μΆ”λ‘ 
90
+
91
+ ν•™μŠ΅μ€ λΆ€νŒ…ν•  λ•Œ μžλ™μœΌλ‘œ μ§„ν–‰λ˜λ©°, λλ‚œ λͺ¨λΈμ€ `model.pt`둜 μΊμ‹±λ©λ‹ˆλ‹€.
92
+
93
+ ---
94
+
95
+ ## λ…Όλ¬Έ 핡심 λΆ€λΆ„ μ½”λ“œ λ§€ν•‘
96
+
97
+ | λ…Όλ¬Έ μœ„μΉ˜ | μ½”λ“œ μœ„μΉ˜ |
98
+ |-----------|-----------|
99
+ | 식 (1) `softmax(QKα΅€/√d_k)V` | `transformer.py :: scaled_dot_product_attention` |
100
+ | Β§3.2.2 Multi-Head | `MultiHeadAttention` |
101
+ | Β§3.5 Positional Encoding | `PositionalEncoding` |
102
+ | 식 (2) FFN | `FeedForward` |
103
+ | Β§3.1 인코더 1μΈ΅ | `EncoderLayer` (Post-LN) |
104
+ | Β§3.1 디코더 1μΈ΅ | `DecoderLayer` (Post-LN) |
105
+ | Β§3.4 μž„λ² λ”© Γ— √d_model | `Transformer.encode` λ‚΄λΆ€ |
106
+
107
+ ---
108
+
109
+ ## μ–΄λ–»κ²Œ 봐야 ν•˜λ‚˜μš”? (μ‹œκ°ν™” 해석)
110
+
111
+ **Cross-Attention 히트맡**:
112
+
113
+ - κ°€λ‘œμΆ•: 인코더 μœ„μΉ˜ (μž…λ ₯ 토큰듀, μ™Όμͺ½μ΄ μ‹œν€€μŠ€ μ•žμͺ½)
114
+ - μ„Έλ‘œμΆ•: 디코더 μœ„μΉ˜ (좜λ ₯ 토큰듀, μœ„μͺ½μ΄ λ¨Όμ € 생성)
115
+ - 색이 λ°μ„μˆ˜λ‘ κ°•ν•œ attention
116
+
117
+ λ’€μ§‘κΈ° νƒœμŠ€ν¬μ—μ„œ 잘 ν•™μŠ΅λœ λͺ¨λΈμ€:
118
+
119
+ ```
120
+ 좜λ ₯ μœ„μΉ˜ 0 (BOS λ‹€μŒ, 첫 좜λ ₯ 토큰) β†’ μž…λ ₯ λ§ˆμ§€λ§‰ 토큰을 λ΄„
121
+ 좜λ ₯ μœ„μΉ˜ 1 β†’ μž…λ ₯ λμ—μ„œ 두 번째λ₯Ό λ΄„
122
+ ...
123
+ ```
124
+
125
+ λ”°λΌμ„œ **μ™Όμͺ½ μœ„ β†’ 였λ₯Έμͺ½ μ•„λž˜ λŒ€κ°μ„ **의 λ°˜λŒ€ λ°©ν–₯, 즉
126
+ **였λ₯Έμͺ½ μœ„ β†’ μ™Όμͺ½ μ•„λž˜λ‘œ 흐λ₯΄λŠ” anti-diagonal**이 보이면 μ„±κ³΅μž…λ‹ˆλ‹€.
127
+
128
+ ---
129
+
130
+ ## Hugging Face Spaces 배포 μ‹œ μ£Όμ˜μ‚¬ν•­
131
+
132
+ ResNet 데λͺ¨λ₯Ό 배포할 λ•Œ κ²ͺμ—ˆλ˜ λ¬Έμ œλ“€μ΄ μ—¬κΈ°μ„œλ„ λ™μΌν•˜κ²Œ λ°œμƒν•  수 μžˆμ–΄μš”:
133
+
134
+ ### 1. YAML ν”„λ‘ νŠΈλ§€ν„° ν•„μˆ˜
135
+
136
+ 이 README.md μ΅œμƒλ‹¨μ˜ `--- ... ---` 블둝이 μ—†μœΌλ©΄ Spaceκ°€ λΉŒλ“œλ˜μ§€ μ•ŠμŠ΅λ‹ˆλ‹€.
137
+
138
+ ### 2. `colorFrom`/`colorTo`λŠ” μ •ν•΄μ§„ 8μƒ‰λ§Œ
139
+
140
+ ν—ˆμš©λ˜λŠ” 색: `red, yellow, green, blue, indigo, purple, pink, gray`
141
+
142
+ ### 3. Python 3.13 νšŒν”Ό
143
+
144
+ `audioop` ν‘œμ€€ λΌμ΄λΈŒλŸ¬λ¦¬κ°€ 3.13μ—μ„œ μ œκ±°λ˜μ–΄ 일뢀 νŒ¨ν‚€μ§€ λΉŒλ“œ μ‹€νŒ¨. **3.10** ꢌμž₯.
145
+
146
+ ### 4. PyTorch CPU λΉŒλ“œ
147
+
148
+ 기본적으둜 무료 SpaceλŠ” CPU만 μ œκ³΅λ©λ‹ˆλ‹€. `torch` μ„€μΉ˜ μ‹œ CUDA 버전이 λ“€μ–΄κ°€λ©΄
149
+ λ””μŠ€ν¬ μš©λŸ‰μ„ μ΄ˆκ³Όν•  수 μžˆμœΌλ‹ˆ ν•„μš”μ‹œ `torch --index-url https://download.pytorch.org/whl/cpu`둜
150
+ λͺ…μ‹œν•˜μ„Έμš”.
151
+
152
+ ---
153
+
154
+ ## 둜컬 μ‹€ν–‰
155
+
156
+ ```bash
157
+ # 1) μ˜μ‘΄μ„± μ„€μΉ˜
158
+ pip install -r requirements.txt
159
+
160
+ # 2) 데λͺ¨ μ‹€ν–‰ (첫 μ‹€ν–‰ μ‹œ μžλ™ ν•™μŠ΅)
161
+ python app.py
162
+ ```
163
+
164
+ 기본적으둜 `http://127.0.0.1:7860` μ—μ„œ μ—΄λ¦½λ‹ˆλ‹€.
165
+
166
+ ---
167
+
168
+ ## ν•™μŠ΅μ΄ 잘 μ•ˆ 되면
169
+
170
+ 체크리슀트:
171
+
172
+ - [ ] PyTorch 버전이 2.0 이상인가
173
+ - [ ] ν•™μŠ΅ step이 2000번 이상 λ„λŠ”κ°€ (μ½˜μ†”μ— step 200, 400, ... 둜그 확인)
174
+ - [ ] step 1000μ―€ 되면 `token_acc`κ°€ 0.95 이상인가
175
+ - [ ] 좜λ ₯이 항상 같은 ν† ν°λ§Œ λ°˜λ³΅ν•œλ‹€λ©΄ β†’ ν•™μŠ΅μ΄ 거의 μ•ˆ 된 것. step λŠ˜λ¦¬κ±°λ‚˜ lr μ‘°μ •
176
+ - [ ] cross-attention이 균일(uniform)ν•˜λ‹€λ©΄ β†’ 더 ν•™μŠ΅ ν•„μš”
177
+
178
+ ---
179
+
180
+ ## μ°Έκ³ 
181
+
182
+ ```bibtex
183
+ @inproceedings{vaswani2017attention,
184
+ title = {Attention Is All You Need},
185
+ author = {Vaswani, Ashish and Shazeer, Noam and Parmar, Niki
186
+ and Uszkoreit, Jakob and Jones, Llion and Gomez, Aidan N
187
+ and Kaiser, {\L}ukasz and Polosukhin, Illia},
188
+ booktitle = {Advances in Neural Information Processing Systems},
189
+ year = {2017}
190
+ }
191
+ ```
192
+
193
+ - πŸ“„ λ…Όλ¬Έ: [arXiv:1706.03762](https://arxiv.org/abs/1706.03762)
194
+ - πŸ“ The Annotated Transformer: <http://nlp.seas.harvard.edu/annotated-transformer/>
195
+ - πŸŽ₯ The Illustrated Transformer (Jay Alammar): <https://jalammar.github.io/illustrated-transformer/>
app.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Transformer 데λͺ¨ β€” 숫자 μ‹œν€€μŠ€ λ’€μ§‘κΈ° (digit reversal)
3
+
4
+ ꡬ성:
5
+ 1) μž‘μ€ Transformerλ₯Ό λΆ€νŒ… μ‹œ μ¦‰μ„μ—μ„œ ν•™μŠ΅ (~30초)
6
+ 2) Gradio UIμ—μ„œ μ‚¬μš©μžκ°€ μž…λ ₯ν•œ μˆ«μžμ—΄μ„ λ’€μ§‘μ–΄ 좜λ ₯
7
+ 3) λ””μ½”λ”μ˜ cross-attention을 μ‹œκ°ν™” β†’ λ©‹μ§„ anti-diagonal νŒ¨ν„΄
8
+
9
+ 이 νƒœμŠ€ν¬λŠ” λ‹¨μˆœν•˜μ§€λ§Œ Transformerκ°€ μœ„μΉ˜ κ°„ μƒν˜Έμž‘μš©μ„ μ–΄λ–»κ²Œ ν•™μŠ΅ν•˜λŠ”μ§€
10
+ κ°€μž₯ μ§κ΄€μ μœΌλ‘œ λ³΄μ—¬μ€λ‹ˆλ‹€.
11
+ """
12
+
13
+ import os
14
+ import math
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.optim as optim
19
+ import matplotlib
20
+ matplotlib.use("Agg")
21
+ import matplotlib.pyplot as plt
22
+ import gradio as gr
23
+
24
+ from transformer import Transformer
25
+
26
+
27
+ # ─────────────────────────────────────────────────────────────
28
+ # 토큰 μ •μ˜
29
+ # ─────────────────────────────────────────────────────────────
30
+ PAD, BOS, EOS = 0, 1, 2
31
+ DIGIT_OFFSET = 3 # 숫자 d β†’ 토큰 id (d + 3)
32
+ VOCAB = DIGIT_OFFSET + 10 # 0~9 κΉŒμ§€ β†’ 총 13개 토큰
33
+ ID2STR = {PAD: "<P>", BOS: "<S>", EOS: "<E>"}
34
+ for d in range(10):
35
+ ID2STR[d + DIGIT_OFFSET] = str(d)
36
+
37
+ MAX_INPUT_LEN = 10 # μ‚¬μš©μž μž…λ ₯ 자릿수 μƒν•œ
38
+ MAX_DECODE_LEN = MAX_INPUT_LEN + 2
39
+
40
+ # ─────────────────────────────────────────────────────────────
41
+ # λͺ¨λΈ ν•˜μ΄νΌνŒŒλΌλ―Έν„° (μž‘κ²Œ)
42
+ # ─────────────────────────────────────────────────────────────
43
+ D_MODEL = 64
44
+ N_LAYERS = 2
45
+ N_HEADS = 4
46
+ D_FF = 128
47
+ DROPOUT = 0.1
48
+
49
+ CKPT_PATH = "model.pt"
50
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
51
+
52
+
53
+ # ─────────────────────────────────────────────────────────────
54
+ # 토큰화 μœ ν‹Έ
55
+ # ─────────────────────────────────────────────────────────────
56
+ def digits_to_ids(digits, add_bos_eos=True):
57
+ ids = [d + DIGIT_OFFSET for d in digits]
58
+ if add_bos_eos:
59
+ ids = [BOS] + ids + [EOS]
60
+ return ids
61
+
62
+
63
+ def ids_to_digits(ids, stop_at_eos=True):
64
+ out = []
65
+ for i in ids:
66
+ if stop_at_eos and i == EOS:
67
+ break
68
+ if i >= DIGIT_OFFSET:
69
+ out.append(i - DIGIT_OFFSET)
70
+ return out
71
+
72
+
73
+ def parse_user_input(text):
74
+ """μ‚¬μš©μž μž…λ ₯ λ¬Έμžμ—΄μ—μ„œ 숫자 μΆ”μΆœ. 곡백·콀마 λ“± λͺ¨λ‘ ν—ˆμš©."""
75
+ digits = []
76
+ for ch in text:
77
+ if ch.isdigit():
78
+ digits.append(int(ch))
79
+ return digits[:MAX_INPUT_LEN]
80
+
81
+
82
+ # ─────────────────────────────────────────────────────────────
83
+ # 마슀크 λ§Œλ“€κΈ°
84
+ # ─────────────────────────────────────────────────────────────
85
+ def make_src_mask(src):
86
+ # (B, S) β†’ (B, 1, 1, S)
87
+ return (src != PAD).unsqueeze(1).unsqueeze(2)
88
+
89
+
90
+ def make_tgt_mask(tgt):
91
+ """νŒ¨λ”© + causal 마슀크 κ²°ν•©"""
92
+ B, T = tgt.shape
93
+ pad_mask = (tgt != PAD).unsqueeze(1).unsqueeze(2) # (B, 1, 1, T)
94
+ causal = torch.tril(torch.ones(T, T, device=tgt.device)).bool()
95
+ causal = causal.unsqueeze(0).unsqueeze(0) # (1, 1, T, T)
96
+ return pad_mask & causal
97
+
98
+
99
+ # ─────────────────────────────────────────────────────────────
100
+ # ν•™μŠ΅ 데이터 생성기
101
+ # ─────────────────────────────────────────────────────────────
102
+ def make_batch(batch_size=128, min_len=3, max_len=8):
103
+ src_list, tgt_list = [], []
104
+ for _ in range(batch_size):
105
+ L = np.random.randint(min_len, max_len + 1)
106
+ digits = np.random.randint(0, 10, size=L).tolist()
107
+ src_list.append(digits_to_ids(digits))
108
+ tgt_list.append(digits_to_ids(digits[::-1]))
109
+
110
+ s_max = max(len(s) for s in src_list)
111
+ t_max = max(len(t) for t in tgt_list)
112
+ src = torch.tensor([s + [PAD] * (s_max - len(s)) for s in src_list])
113
+ tgt = torch.tensor([t + [PAD] * (t_max - len(t)) for t in tgt_list])
114
+ return src, tgt
115
+
116
+
117
+ # ────────────────────────────────────────���────────────────────
118
+ # ν•™μŠ΅ 루프
119
+ # ─────────────────────────────────────────────────────────────
120
+ def train(model, steps=2000, batch_size=128, lr=5e-4, log_every=200):
121
+ model.train()
122
+ opt = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-9)
123
+ loss_fn = nn.CrossEntropyLoss(ignore_index=PAD, label_smoothing=0.1)
124
+
125
+ print(f"[train] device={DEVICE}, steps={steps}, batch={batch_size}")
126
+ for step in range(1, steps + 1):
127
+ src, tgt = make_batch(batch_size, min_len=3, max_len=MAX_INPUT_LEN)
128
+ src, tgt = src.to(DEVICE), tgt.to(DEVICE)
129
+ tgt_in, tgt_out = tgt[:, :-1], tgt[:, 1:]
130
+
131
+ src_mask = make_src_mask(src)
132
+ tgt_mask = make_tgt_mask(tgt_in)
133
+
134
+ logits = model(src, tgt_in, src_mask, tgt_mask)
135
+ loss = loss_fn(logits.reshape(-1, VOCAB), tgt_out.reshape(-1))
136
+
137
+ opt.zero_grad()
138
+ loss.backward()
139
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
140
+ opt.step()
141
+
142
+ if step % log_every == 0 or step == 1:
143
+ with torch.no_grad():
144
+ pred = logits.argmax(-1)
145
+ mask = (tgt_out != PAD)
146
+ acc = ((pred == tgt_out) & mask).sum().item() / mask.sum().item()
147
+ print(f" step {step:4d} loss={loss.item():.4f} token_acc={acc:.3f}")
148
+ return model
149
+
150
+
151
+ # ─────────────────────────────────────────────────────────────
152
+ # μΆ”λ‘  (Greedy decoding)
153
+ # ─────────────────────────────────────────────────────────────
154
+ @torch.no_grad()
155
+ def greedy_decode(model, src_ids):
156
+ """src_ids: list[int] (BOSΒ·EOS 포함)"""
157
+ model.eval()
158
+ src = torch.tensor([src_ids], device=DEVICE)
159
+ src_mask = make_src_mask(src)
160
+ enc_out = model.encode(src, src_mask)
161
+
162
+ ys = torch.tensor([[BOS]], device=DEVICE)
163
+ for _ in range(MAX_DECODE_LEN):
164
+ tgt_mask = make_tgt_mask(ys)
165
+ dec_out = model.decode(ys, enc_out, src_mask, tgt_mask)
166
+ logits = model.out(dec_out)
167
+ next_tok = logits[:, -1].argmax(-1, keepdim=True)
168
+ ys = torch.cat([ys, next_tok], dim=1)
169
+ if next_tok.item() == EOS:
170
+ break
171
+ return ys[0].tolist()
172
+
173
+
174
+ # ─────────────────────────────────────────────────────────────
175
+ # μ–΄ν…μ…˜ μ‹œκ°ν™”
176
+ # ─────────────────────────────────────────────────────────────
177
+ def plot_cross_attention(model, src_ids, tgt_ids, layer_idx=-1):
178
+ """디코더 cross-attention을 ν—€λ“œ ν‰κ· μœΌλ‘œ κ·Έλ¦Ό."""
179
+ attn = model.get_decoder_cross_attn(layer_idx) # (1, h, T, S)
180
+ if attn is None:
181
+ return None
182
+ attn_avg = attn.mean(dim=1)[0].cpu().numpy() # (T, S)
183
+
184
+ src_labels = [ID2STR[i] for i in src_ids]
185
+ tgt_labels = [ID2STR[i] for i in tgt_ids]
186
+
187
+ fig, ax = plt.subplots(figsize=(7, 6))
188
+ im = ax.imshow(attn_avg, cmap="viridis", aspect="auto", vmin=0, vmax=attn_avg.max())
189
+ ax.set_xticks(range(len(src_labels)))
190
+ ax.set_xticklabels(src_labels)
191
+ ax.set_yticks(range(len(tgt_labels)))
192
+ ax.set_yticklabels(tgt_labels)
193
+ ax.set_xlabel("Encoder positions (μž…λ ₯)")
194
+ ax.set_ylabel("Decoder positions (좜λ ₯)")
195
+ ax.set_title(f"Decoder ↔ Encoder Cross-Attention\n(layer {layer_idx}, heads averaged)")
196
+ plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
197
+ plt.tight_layout()
198
+ return fig
199
+
200
+
201
+ def plot_positional_encoding(model, length=20):
202
+ """positional encoding을 직접 μ‹œκ°ν™”."""
203
+ pe = model.pe.pe[0, :length].cpu().numpy() # (L, d_model)
204
+ fig, ax = plt.subplots(figsize=(7, 4))
205
+ im = ax.imshow(pe, cmap="RdBu", aspect="auto", vmin=-1, vmax=1)
206
+ ax.set_xlabel("Embedding dimension")
207
+ ax.set_ylabel("Position")
208
+ ax.set_title("Positional Encoding (sin/cos pattern)")
209
+ plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
210
+ plt.tight_layout()
211
+ return fig
212
+
213
+
214
+ # ─────────────────────────────────────────────────────────────
215
+ # λͺ¨λΈ μ€€λΉ„ (ν•™μŠ΅ λ˜λŠ” λ‘œλ“œ)
216
+ # ─────────────────────────────────────────────────────────────
217
+ def build_model():
218
+ model = Transformer(
219
+ src_vocab=VOCAB,
220
+ tgt_vocab=VOCAB,
221
+ d_model=D_MODEL,
222
+ N=N_LAYERS,
223
+ h=N_HEADS,
224
+ d_ff=D_FF,
225
+ dropout=DROPOUT,
226
+ max_len=64,
227
+ ).to(DEVICE)
228
+ return model
229
+
230
+
231
+ def load_or_train():
232
+ model = build_model()
233
+ if os.path.exists(CKPT_PATH):
234
+ print(f"[init] loading checkpoint: {CKPT_PATH}")
235
+ state = torch.load(CKPT_PATH, map_location=DEVICE)
236
+ model.load_state_dict(state)
237
+ else:
238
+ print("[init] no checkpoint found β†’ training from scratch")
239
+ train(model, steps=2000, batch_size=128, lr=5e-4)
240
+ torch.save(model.state_dict(), CKPT_PATH)
241
+ print(f"[init] saved checkpoint: {CKPT_PATH}")
242
+ model.eval()
243
+ return model
244
+
245
+
246
+ print("=" * 60)
247
+ print("Transformer Demo β€” initializing")
248
+ print("=" * 60)
249
+ MODEL = load_or_train()
250
+ PE_FIG = plot_positional_encoding(MODEL, length=20)
251
+ print("[init] ready βœ”")
252
+
253
+
254
+ # ─────────────────────────────────────────────────────────────
255
+ # Gradio 콜백
256
+ # ─────────────────────────────────────────────────────────────
257
+ def run_inference(user_text):
258
+ digits = parse_user_input(user_text)
259
+ if len(digits) == 0:
260
+ return "❗ 숫자λ₯Ό μž…λ ₯ν•΄ μ£Όμ„Έμš”.", None, "(μ–΄ν…μ…˜ μ—†μŒ)"
261
+ if len(digits) > MAX_INPUT_LEN:
262
+ digits = digits[:MAX_INPUT_LEN]
263
+
264
+ src_ids = digits_to_ids(digits)
265
+ out_ids = greedy_decode(MODEL, src_ids)
266
+ pred_digits = ids_to_digits(out_ids[1:]) # BOS μ œμ™Έ, EOSκΉŒμ§€
267
+
268
+ expected = digits[::-1]
269
+ correct = pred_digits == expected
270
+
271
+ pred_str = " ".join(str(d) for d in pred_digits) if pred_digits else "(빈 좜λ ₯)"
272
+ expected_str = " ".join(str(d) for d in expected)
273
+ input_str = " ".join(str(d) for d in digits)
274
+
275
+ msg = (
276
+ f"**μž…λ ₯** : {input_str}\n\n"
277
+ f"**예츑 좜λ ₯** : {pred_str}\n\n"
278
+ f"**μ •λ‹΅** : {expected_str}\n\n"
279
+ f"**일치 μ—¬λΆ€** : {'βœ… μ •λ‹΅!' if correct else '❌ μ˜€λ‹΅'}"
280
+ )
281
+
282
+ # μ‹œκ°ν™”λ₯Ό μœ„ν•΄ λ‹€μ‹œ forward (cross-attention κ°±μ‹ μš©)
283
+ src = torch.tensor([src_ids], device=DEVICE)
284
+ tgt = torch.tensor([out_ids], device=DEVICE)
285
+ src_mask = make_src_mask(src)
286
+ tgt_mask = make_tgt_mask(tgt)
287
+ with torch.no_grad():
288
+ MODEL(src, tgt, src_mask, tgt_mask)
289
+
290
+ fig = plot_cross_attention(MODEL, src_ids, out_ids, layer_idx=-1)
291
+ info = (
292
+ "이 νžˆνŠΈλ§΅μ€ λ””μ½”λ”μ˜ λ§ˆμ§€λ§‰ μΈ΅ cross-attention κ°€μ€‘μΉ˜(ν—€λ“œ 평균)μž…λ‹ˆλ‹€.\n"
293
+ "각 ν–‰(좜λ ₯ μœ„μΉ˜)이 μ–΄λ–€ μž…λ ₯ μœ„μΉ˜λ₯Ό κ°€μž₯ 많이 보고 μžˆλŠ”μ§€ λ‚˜νƒ€λƒ…λ‹ˆλ‹€.\n"
294
+ "λ’€μ§‘κΈ° νƒœμŠ€ν¬μ—μ„œλŠ” **λ°˜λŒ€κ°μ„ (anti-diagonal)** νŒ¨ν„΄μ΄ 보이면 ν•™μŠ΅ 성곡!"
295
+ )
296
+ return msg, fig, info
297
+
298
+
299
+ # ─────────────────────────────────────────────────────────────
300
+ # UI ꡬ성
301
+ # ─────────────────────────────────────────────────────────────
302
+ with gr.Blocks(title="Transformer Demo β€” Digit Reversal") as demo:
303
+ gr.Markdown(
304
+ """
305
+ # πŸ€– Transformer 데λͺ¨ β€” 숫자 μ‹œν€€μŠ€ λ’€μ§‘κΈ°
306
+
307
+ Vaswani et al. (2017) **"Attention Is All You Need"** 논문을 μ²˜μŒλΆ€ν„° μž¬ν˜„ν•œ
308
+ Transformer둜 μž…λ ₯ μˆ«μžμ—΄μ„ λ’€μ§‘μ–΄ λ΄…λ‹ˆλ‹€.
309
+
310
+ - λͺ¨λΈ: d_model=64, N=2μΈ΅, h=4ν—€λ“œ (총 ~80K νŒŒλΌλ―Έν„°)
311
+ - ν•™μŠ΅ 데이터: 길이 3~10의 λ¬΄μž‘μœ„ μˆ«μžμ—΄, λ§€ step μƒˆλ‘œ 생성
312
+ - ν•™μŠ΅ μ‹œκ°„: λΆ€νŒ… μ‹œ ~30초 (CPU κΈ°μ€€)
313
+
314
+ ν•˜λ‹¨ ν—€λ„λ§΅μ—μ„œ **λ°˜λŒ€κ°μ„  νŒ¨ν„΄**이 보인닀면, λͺ¨λΈμ΄ "좜λ ₯ i번째 = μž…λ ₯의
315
+ λ°˜λŒ€νŽΈ μœ„μΉ˜"λ₯Ό ν•™μŠ΅ν–ˆλ‹€λŠ” μ¦κ±°μ˜ˆμš”.
316
+ """
317
+ )
318
+
319
+ with gr.Tab("λ’€μ§‘κΈ°"):
320
+ with gr.Row():
321
+ with gr.Column(scale=1):
322
+ inp = gr.Textbox(
323
+ label="μˆ«μžμ—΄ μž…λ ₯ (μ΅œλŒ€ 10자리)",
324
+ placeholder="예: 1 2 3 4 5 λ˜λŠ” 12345",
325
+ value="1 2 3 4 5 6 7",
326
+ )
327
+ btn = gr.Button("λ’€μ§‘κΈ° μ‹€ν–‰", variant="primary")
328
+ out_text = gr.Markdown()
329
+
330
+ gr.Examples(
331
+ examples=[
332
+ ["1 2 3"],
333
+ ["1 2 3 4 5"],
334
+ ["9 8 7 6 5 4 3"],
335
+ ["1 1 2 2 3 3"],
336
+ ["0 1 2 3 4 5 6 7 8 9"],
337
+ ],
338
+ inputs=inp,
339
+ )
340
+
341
+ with gr.Column(scale=2):
342
+ attn_plot = gr.Plot(label="Cross-Attention Heatmap")
343
+ attn_info = gr.Markdown()
344
+
345
+ btn.click(run_inference, inputs=inp, outputs=[out_text, attn_plot, attn_info])
346
+
347
+ with gr.Tab("Positional Encoding"):
348
+ gr.Markdown(
349
+ """
350
+ ### Positional Encoding ��각화
351
+
352
+ λ…Όλ¬Έ Β§3.5의 sin/cos μœ„μΉ˜ 인코딩을 직접 κ·Έλ¦° κ²ƒμž…λ‹ˆλ‹€.
353
+ κ°€λ‘œμΆ•μ΄ μž„λ² λ”© 차원, μ„Έλ‘œμΆ•μ΄ μœ„μΉ˜(μ‹œκ°„ μˆœμ„œ)μ˜ˆμš”.
354
+
355
+ 짝수 차원은 sin, ν™€μˆ˜ 차원은 cos둜 μ±„μ›Œμ§€λ©°, 차원이 클수둝 μ£ΌκΈ°κ°€ κΈΈμ–΄μ§‘λ‹ˆλ‹€.
356
+ 덕뢄에 λͺ¨λΈμ΄ **μƒλŒ€ μœ„μΉ˜**λ₯Ό μ„ ν˜• λ³€ν™˜μœΌλ‘œ ν‘œν˜„ν•  수 있게 λ©λ‹ˆλ‹€.
357
+ """
358
+ )
359
+ gr.Plot(value=PE_FIG, label="PE matrix")
360
+
361
+ with gr.Tab("이 데λͺ¨μ— λŒ€ν•΄"):
362
+ gr.Markdown(
363
+ """
364
+ ### μ™œ "숫자 λ’€μ§‘κΈ°"μΈκ°€μš”?
365
+
366
+ λ²ˆμ—­ 같은 μ§„μ§œ νƒœμŠ€ν¬λŠ” κ±°λŒ€ν•œ 데이터·연산을 μš”κ΅¬ν•΄μ„œ 무료 Space에선 λΆ€μ ν•©ν•©λ‹ˆλ‹€.
367
+ λŒ€μ‹  **숫자 λ’€μ§‘κΈ°**λŠ”:
368
+
369
+ 1. μž‘μ€ λͺ¨λΈ(8만 νŒŒλΌλ―Έν„°)이 1~2λΆ„ λ‚΄ ν•™μŠ΅ κ°€λŠ₯
370
+ 2. μž…μΆœλ ₯이 λͺ…ν™•ν•΄μ„œ μ •λ‹΅ μ—¬λΆ€λ₯Ό μ¦‰μ‹œ νŒλ‹¨ κ°€λŠ₯
371
+ 3. **cross-attention이 anti-diagonal νŒ¨ν„΄**을 κ·Έλ € μ‹œκ°ν™” νš¨κ³Όκ°€ 큼
372
+ 4. μ™ΈλΆ€ 데이터 λΆˆν•„μš” (λŸ°νƒ€μž„ 생성)
373
+
374
+ ### λͺ¨λΈ ꡬ쑰
375
+
376
+ ```
377
+ VOCAB(13) β†’ Embedding(64) + PE
378
+ β†’ 2Γ— EncoderLayer (h=4, d_ff=128)
379
+ β†’ 2Γ— DecoderLayer (h=4, d_ff=128)
380
+ β†’ Linear β†’ 13개 토큰 logits
381
+ ```
382
+
383
+ 논문은 d_model=512, N=6, h=8, d_ff=2048μ΄μ§€λ§Œ, 이 데λͺ¨λŠ” κ·Έ 크기의 1/8 μˆ˜μ€€μž…λ‹ˆλ‹€.
384
+ κ΅¬μ‘°λŠ” μ™„μ „νžˆ λ™μΌν•˜κ³ , 크기만 μ€„μ˜€μ–΄μš”.
385
+
386
+ ### μΆ”λ‘  방식
387
+
388
+ **Greedy decoding**: λ§€ μ‹œμ  κ°€μž₯ ν™•λ₯  높은 토큰을 선택. Beam search 같은
389
+ κ³ κΈ‰ 디코딩은 μƒλž΅ν–ˆμŠ΅λ‹ˆλ‹€.
390
+
391
+ ### ν•œκ³„
392
+
393
+ - μžλ¦Ώμˆ˜κ°€ κΈΈμ–΄μ§ˆμˆ˜λ‘(>10) 정확도 ν•˜λ½
394
+ - ν•™μŠ΅ μ‹œ 보지 λͺ»ν•œ νŒ¨ν„΄(반볡, 맀우 κΈ΄ μ‹œν€€μŠ€)에 μ·¨μ•½
395
+ - μ§„μ§œ NMTκ°€ μ•„λ‹ˆλ―€λ‘œ 일반 μžμ—°μ–΄λŠ” 처리 λΆˆκ°€
396
+
397
+ ### μ°Έκ³ 
398
+
399
+ - λ…Όλ¬Έ: [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
400
+ - The Annotated Transformer: [http://nlp.seas.harvard.edu/annotated-transformer/](http://nlp.seas.harvard.edu/annotated-transformer/)
401
+ """
402
+ )
403
+
404
+
405
+ if __name__ == "__main__":
406
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch>=2.0
2
+ gradio==5.29.0
3
+ matplotlib>=3.7
4
+ numpy>=1.24
transformer.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Transformer λͺ¨λΈ κ΅¬ν˜„ β€” Vaswani et al. (2017) "Attention Is All You Need"
3
+
4
+ ResNet ν”„λ‘œμ νŠΈμ™€ λ™μΌν•œ μ² ν•™μœΌλ‘œ, 논문을 μ²˜μŒλΆ€ν„° λκΉŒμ§€ μž¬ν˜„ν•©λ‹ˆλ‹€.
5
+ μ‹œκ°ν™”λ₯Ό μœ„ν•΄ 각 attention λͺ¨λ“ˆμ΄ λ§ˆμ§€λ§‰ attention κ°€μ€‘μΉ˜λ₯Ό λ³΄κ΄€ν•˜λ„λ‘ ν–ˆμŠ΅λ‹ˆλ‹€.
6
+ """
7
+
8
+ import math
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ # ─────────────────────────────────────────────────────────────
15
+ # 1) Scaled Dot-Product Attention (λ…Όλ¬Έ Β§3.2.1, 식 1)
16
+ # ─────────────────────────────────────────────────────────────
17
+ def scaled_dot_product_attention(Q, K, V, mask=None, return_attn=False):
18
+ """
19
+ Attention(Q, K, V) = softmax(QKα΅€ / √d_k) V
20
+ """
21
+ d_k = Q.size(-1)
22
+ scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
23
+ if mask is not None:
24
+ scores = scores.masked_fill(mask == 0, -1e9)
25
+ attn = F.softmax(scores, dim=-1)
26
+ out = attn @ V
27
+ if return_attn:
28
+ return out, attn
29
+ return out
30
+
31
+
32
+ # ─────────────────────────────────────────────────────────────
33
+ # 2) Multi-Head Attention (λ…Όλ¬Έ Β§3.2.2, Figure 2 였λ₯Έμͺ½)
34
+ # ─────────────────────────────────────────────────────────────
35
+ class MultiHeadAttention(nn.Module):
36
+ def __init__(self, d_model=512, h=8):
37
+ super().__init__()
38
+ assert d_model % h == 0
39
+ self.h = h
40
+ self.d_k = d_model // h
41
+ self.W_q = nn.Linear(d_model, d_model)
42
+ self.W_k = nn.Linear(d_model, d_model)
43
+ self.W_v = nn.Linear(d_model, d_model)
44
+ self.W_o = nn.Linear(d_model, d_model)
45
+ # μ‹œκ°ν™”μš©: λ§ˆμ§€λ§‰ forward의 attention κ°€μ€‘μΉ˜ (B, h, seq_q, seq_k)
46
+ self.last_attn = None
47
+
48
+ def forward(self, Q, K, V, mask=None):
49
+ B = Q.size(0)
50
+ Q = self.W_q(Q).view(B, -1, self.h, self.d_k).transpose(1, 2)
51
+ K = self.W_k(K).view(B, -1, self.h, self.d_k).transpose(1, 2)
52
+ V = self.W_v(V).view(B, -1, self.h, self.d_k).transpose(1, 2)
53
+
54
+ out, attn = scaled_dot_product_attention(Q, K, V, mask, return_attn=True)
55
+ self.last_attn = attn.detach()
56
+
57
+ out = out.transpose(1, 2).contiguous().view(B, -1, self.h * self.d_k)
58
+ return self.W_o(out)
59
+
60
+
61
+ # ─────────────────────────────────────────────────────────────
62
+ # 3) Positional Encoding (λ…Όλ¬Έ Β§3.5)
63
+ # ─────────────────────────────────────────────────────────────
64
+ class PositionalEncoding(nn.Module):
65
+ def __init__(self, d_model=512, max_len=5000):
66
+ super().__init__()
67
+ pe = torch.zeros(max_len, d_model)
68
+ pos = torch.arange(0, max_len).unsqueeze(1).float()
69
+ div = torch.exp(torch.arange(0, d_model, 2).float() *
70
+ -(math.log(10000.0) / d_model))
71
+ pe[:, 0::2] = torch.sin(pos * div)
72
+ pe[:, 1::2] = torch.cos(pos * div)
73
+ self.register_buffer("pe", pe.unsqueeze(0))
74
+
75
+ def forward(self, x):
76
+ return x + self.pe[:, :x.size(1)]
77
+
78
+
79
+ # ─────────────────────────────────────────────────────────────
80
+ # 4) Position-wise Feed-Forward (λ…Όλ¬Έ Β§3.3, 식 2)
81
+ # ─────────────────────────────────────────────────────────────
82
+ class FeedForward(nn.Module):
83
+ def __init__(self, d_model=512, d_ff=2048):
84
+ super().__init__()
85
+ self.net = nn.Sequential(
86
+ nn.Linear(d_model, d_ff),
87
+ nn.ReLU(),
88
+ nn.Linear(d_ff, d_model),
89
+ )
90
+
91
+ def forward(self, x):
92
+ return self.net(x)
93
+
94
+
95
+ # ─────────────────────────────────────────────────────────────
96
+ # 5) Encoder Layer (λ…Όλ¬Έ Β§3.1)
97
+ # ─────────────────────────────────────────────────────────────
98
+ class EncoderLayer(nn.Module):
99
+ def __init__(self, d_model=512, h=8, d_ff=2048, dropout=0.1):
100
+ super().__init__()
101
+ self.attn = MultiHeadAttention(d_model, h)
102
+ self.ffn = FeedForward(d_model, d_ff)
103
+ self.norm1 = nn.LayerNorm(d_model)
104
+ self.norm2 = nn.LayerNorm(d_model)
105
+ self.dropout = nn.Dropout(dropout)
106
+
107
+ def forward(self, x, mask=None):
108
+ x = self.norm1(x + self.dropout(self.attn(x, x, x, mask)))
109
+ x = self.norm2(x + self.dropout(self.ffn(x)))
110
+ return x
111
+
112
+
113
+ # ─────────────────────────────────────────────────────────────
114
+ # 6) Decoder Layer (λ…Όλ¬Έ Β§3.1)
115
+ # ─────────────────────────────────────────────────────────────
116
+ class DecoderLayer(nn.Module):
117
+ def __init__(self, d_model=512, h=8, d_ff=2048, dropout=0.1):
118
+ super().__init__()
119
+ self.self_attn = MultiHeadAttention(d_model, h) # masked
120
+ self.cross_attn = MultiHeadAttention(d_model, h) # enc-dec
121
+ self.ffn = FeedForward(d_model, d_ff)
122
+ self.norm1 = nn.LayerNorm(d_model)
123
+ self.norm2 = nn.LayerNorm(d_model)
124
+ self.norm3 = nn.LayerNorm(d_model)
125
+ self.dropout = nn.Dropout(dropout)
126
+
127
+ def forward(self, x, enc_out, src_mask=None, tgt_mask=None):
128
+ x = self.norm1(x + self.dropout(self.self_attn(x, x, x, tgt_mask)))
129
+ x = self.norm2(x + self.dropout(self.cross_attn(x, enc_out, enc_out, src_mask)))
130
+ x = self.norm3(x + self.dropout(self.ffn(x)))
131
+ return x
132
+
133
+
134
+ # ─────────────────────────────────────────────────────────────
135
+ # 7) 전체 Transformer
136
+ # ─────────────────────────────────────────────────────────────
137
+ class Transformer(nn.Module):
138
+ def __init__(self, src_vocab, tgt_vocab,
139
+ d_model=512, N=6, h=8, d_ff=2048, dropout=0.1, max_len=5000):
140
+ super().__init__()
141
+ self.d_model = d_model
142
+ self.src_embed = nn.Embedding(src_vocab, d_model)
143
+ self.tgt_embed = nn.Embedding(tgt_vocab, d_model)
144
+ self.pe = PositionalEncoding(d_model, max_len)
145
+ self.encoder = nn.ModuleList([
146
+ EncoderLayer(d_model, h, d_ff, dropout) for _ in range(N)
147
+ ])
148
+ self.decoder = nn.ModuleList([
149
+ DecoderLayer(d_model, h, d_ff, dropout) for _ in range(N)
150
+ ])
151
+ self.out = nn.Linear(d_model, tgt_vocab)
152
+
153
+ def encode(self, src, src_mask=None):
154
+ e = self.pe(self.src_embed(src) * math.sqrt(self.d_model))
155
+ for layer in self.encoder:
156
+ e = layer(e, src_mask)
157
+ return e
158
+
159
+ def decode(self, tgt, enc_out, src_mask=None, tgt_mask=None):
160
+ d = self.pe(self.tgt_embed(tgt) * math.sqrt(self.d_model))
161
+ for layer in self.decoder:
162
+ d = layer(d, enc_out, src_mask, tgt_mask)
163
+ return d
164
+
165
+ def forward(self, src, tgt, src_mask=None, tgt_mask=None):
166
+ enc_out = self.encode(src, src_mask)
167
+ dec_out = self.decode(tgt, enc_out, src_mask, tgt_mask)
168
+ return self.out(dec_out)
169
+
170
+ # ── μ‹œκ°ν™”μš© 헬퍼 ─────────────────────────────────────────
171
+ def get_decoder_cross_attn(self, layer_idx=-1):
172
+ """λ§ˆμ§€λ§‰ forward의 디코더 cross-attention κ°€μ€‘μΉ˜λ₯Ό λ°˜ν™˜.
173
+
174
+ Returns: (B, h, tgt_len, src_len)
175
+ """
176
+ return self.decoder[layer_idx].cross_attn.last_attn
177
+
178
+ def get_encoder_self_attn(self, layer_idx=-1):
179
+ """λ§ˆμ§€λ§‰ forward의 인코더 self-attention κ°€μ€‘μΉ˜λ₯Ό λ°˜ν™˜.
180
+
181
+ Returns: (B, h, src_len, src_len)
182
+ """
183
+ return self.encoder[layer_idx].attn.last_attn