# Transformer / app.py
# JangTaeng's picture
# Upload 4 files
# 0465ac4 verified
"""
Transformer 데모 — 숫자 시퀀스 뒤집기 (digit reversal)
구성:
1) 작은 Transformer를 부팅 시 즉석에서 학습 (~30초)
2) Gradio UI에서 사용자가 입력한 숫자열을 뒤집어 출력
3) 디코더의 cross-attention을 시각화 → 멋진 anti-diagonal 패턴
이 태스크는 단순하지만 Transformer가 위치 간 상호작용을 어떻게 학습하는지
가장 직관적으로 보여줍니다.
"""
import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import gradio as gr
from transformer import Transformer
# ─────────────────────────────────────────────────────────────
# 토큰 정의 — token definitions
# ─────────────────────────────────────────────────────────────
# Special token ids. Digit d (0-9) is encoded as token id d + DIGIT_OFFSET.
PAD, BOS, EOS = 0, 1, 2
DIGIT_OFFSET = 3  # digit d -> token id (d + 3)
VOCAB = DIGIT_OFFSET + 10  # 3 special tokens + digits 0-9 -> 13 tokens in total
# Human-readable label for every token id (used for the attention-plot tick labels).
# Built with a comprehension so no loop variable leaks into the module namespace.
ID2STR = {
    PAD: "<P>",
    BOS: "<S>",
    EOS: "<E>",
    **{digit + DIGIT_OFFSET: str(digit) for digit in range(10)},
}
MAX_INPUT_LEN = 10  # upper bound on the number of digits accepted from the user
# Cap on tokens generated by greedy decoding; covers MAX_INPUT_LEN digits plus EOS.
MAX_DECODE_LEN = MAX_INPUT_LEN + 2
# ─────────────────────────────────────────────────────────────
# λͺ¨λΈ ν•˜μ΄νΌνŒŒλΌλ―Έν„° (μž‘κ²Œ)
# ─────────────────────────────────────────────────────────────
D_MODEL = 64  # model / embedding width (passed to Transformer as d_model)
N_LAYERS = 2  # layers per stack (passed to Transformer as N)
N_HEADS = 4  # attention heads (passed to Transformer as h)
D_FF = 128  # feed-forward inner width (passed to Transformer as d_ff)
DROPOUT = 0.1
# Resolve the checkpoint next to this script instead of the current working
# directory, so launching the app from elsewhere still finds (and refreshes)
# the same cached weights.
CKPT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model.pt")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ─────────────────────────────────────────────────────────────
# 토큰화 유틸 — tokenization utilities
# ─────────────────────────────────────────────────────────────
def digits_to_ids(digits, add_bos_eos=True):
    """Encode a sequence of digits (0-9) as token ids.

    Each digit is shifted by DIGIT_OFFSET. When ``add_bos_eos`` is true the
    result is wrapped as ``[BOS, ..., EOS]``.

    Returns:
        list of token ids.
    """
    body = [digit + DIGIT_OFFSET for digit in digits]
    if not add_bos_eos:
        return body
    return [BOS, *body, EOS]
def ids_to_digits(ids, stop_at_eos=True):
    """Decode token ids back into digits (the inverse of digits_to_ids).

    Ids below DIGIT_OFFSET (PAD, BOS, and EOS when not stopping on it) are
    skipped. With ``stop_at_eos`` true, decoding ends at the first EOS.

    Returns:
        list[int] of digits.
    """
    result = []
    for tok in ids:
        if stop_at_eos and tok == EOS:
            break
        if tok < DIGIT_OFFSET:
            continue
        result.append(tok - DIGIT_OFFSET)
    return result
def parse_user_input(text, max_len=None):
    """Extract decimal digits from free-form user text.

    Every non-digit character (spaces, commas, letters, ...) is ignored.

    Args:
        text: Raw textbox contents; ``None`` is treated as an empty string.
        max_len: Maximum number of digits to keep; defaults to MAX_INPUT_LEN.

    Returns:
        list[int]: the first ``max_len`` digits, in input order.
    """
    limit = MAX_INPUT_LEN if max_len is None else max_len
    # str.isdigit() is also True for characters such as "²" or "①" that int()
    # rejects with ValueError (crashing the Gradio callback). isdecimal() accepts
    # exactly the single characters int() can parse (ASCII and other Unicode
    # decimal digits, e.g. full-width "３").
    digits = [int(ch) for ch in (text or "") if ch.isdecimal()]
    return digits[:limit]
# ─────────────────────────────────────────────────────────────
# 마스크 만들기 — attention masks
# ─────────────────────────────────────────────────────────────
def make_src_mask(src):
    """Build the encoder padding mask.

    Returns a bool tensor that is True at non-PAD positions, with two singleton
    axes inserted after the batch axis: (B, S) -> (B, 1, 1, S). Presumably it
    broadcasts against (B, h, T, S) attention scores inside the model.
    """
    not_pad = src != PAD
    return not_pad[:, None, None]
def make_tgt_mask(tgt):
    """Build the decoder self-attention mask (padding AND causal).

    Query position i may attend to key position j only when j <= i and
    ``tgt[:, j]`` is not PAD.

    Args:
        tgt: (B, T) tensor of token ids.

    Returns:
        Bool tensor of shape (B, 1, T, T).
    """
    _, seq_len = tgt.shape
    key_ok = (tgt != PAD)[:, None, None]  # (B, 1, 1, T)
    lower_tri = torch.ones(seq_len, seq_len, device=tgt.device).tril().bool()
    return key_ok & lower_tri[None, None]  # broadcast to (B, 1, T, T)
# ─────────────────────────────────────────────────────────────
# 학습 데이터 생성기 — training data generator
# ─────────────────────────────────────────────────────────────
def make_batch(batch_size=128, min_len=3, max_len=8):
    """Sample a random digit-reversal batch.

    Each example has a length drawn uniformly from [min_len, max_len]
    (inclusive) and random digits 0-9. The source is the digits; the target is
    the same digits reversed. Both are wrapped in BOS/EOS and right-padded with
    PAD to the longest sequence in the batch.

    Returns:
        (src, tgt): integer tensors of shape (batch_size, S) and (batch_size, T).
    """
    pairs = []
    for _ in range(batch_size):
        length = np.random.randint(min_len, max_len + 1)
        seq = np.random.randint(0, 10, size=length).tolist()
        pairs.append((digits_to_ids(seq), digits_to_ids(seq[::-1])))

    def pad_to_tensor(rows):
        # Right-pad every row with PAD up to the longest row.
        width = max(map(len, rows))
        return torch.tensor([row + [PAD] * (width - len(row)) for row in rows])

    sources = [s for s, _ in pairs]
    targets = [t for _, t in pairs]
    return pad_to_tensor(sources), pad_to_tensor(targets)
# ─────────────────────────────────────────────────────────────
# 학습 루프 — training loop
# ─────────────────────────────────────────────────────────────
def train(model, steps=2000, batch_size=128, lr=5e-4, log_every=200):
    """Train ``model`` in place on freshly generated digit-reversal batches.

    Every step samples a new batch (sequence lengths 3..MAX_INPUT_LEN) and uses
    teacher forcing: the decoder sees the target minus its last token and is
    scored against the target shifted left by one. Optimisation is Adam with
    gradient-norm clipping at 1.0. Loss and token accuracy (over the non-PAD
    labels of the current batch) are printed at step 1 and every ``log_every``
    steps.

    Returns:
        The same ``model`` object.
    """
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-9)
    # PAD positions carry no label; 0.1 label smoothing regularises the targets.
    criterion = nn.CrossEntropyLoss(ignore_index=PAD, label_smoothing=0.1)
    print(f"[train] device={DEVICE}, steps={steps}, batch={batch_size}")

    for step in range(1, steps + 1):
        src, tgt = make_batch(batch_size, min_len=3, max_len=MAX_INPUT_LEN)
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)
        decoder_in, labels = tgt[:, :-1], tgt[:, 1:]

        logits = model(src, decoder_in, make_src_mask(src), make_tgt_mask(decoder_in))
        loss = criterion(logits.reshape(-1, VOCAB), labels.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        if step % log_every == 0 or step == 1:
            with torch.no_grad():
                valid = labels != PAD
                hits = ((logits.argmax(-1) == labels) & valid).sum().item()
                acc = hits / valid.sum().item()
            print(f" step {step:4d} loss={loss.item():.4f} token_acc={acc:.3f}")
    return model
# ─────────────────────────────────────────────────────────────
# 추론 (Greedy decoding) — inference
# ─────────────────────────────────────────────────────────────
@torch.no_grad()
def greedy_decode(model, src_ids):
    """Decode one sequence greedily (argmax token at every step).

    Args:
        model: Transformer exposing ``encode``, ``decode`` and ``out``; it is
            switched to eval mode.
        src_ids: list[int] of source token ids, including BOS and EOS.

    Returns:
        list[int]: generated ids starting with BOS. It ends with EOS if the
        model emitted one within MAX_DECODE_LEN steps.
    """
    model.eval()
    src = torch.tensor([src_ids], device=DEVICE)  # (1, S)
    src_mask = make_src_mask(src)
    memory = model.encode(src, src_mask)  # encoded once, reused at every step
    tokens = torch.tensor([[BOS]], device=DEVICE)  # (1, 1)
    for _ in range(MAX_DECODE_LEN):
        hidden = model.decode(tokens, memory, src_mask, make_tgt_mask(tokens))
        # Only the newest position's logits matter for the next token.
        best = model.out(hidden)[:, -1].argmax(-1, keepdim=True)  # (1, 1)
        tokens = torch.cat([tokens, best], dim=1)
        if best.item() == EOS:
            break
    return tokens[0].tolist()
# ─────────────────────────────────────────────────────────────
# μ–΄ν…μ…˜ μ‹œκ°ν™”
# ─────────────────────────────────────────────────────────────
def plot_cross_attention(model, src_ids, tgt_ids, layer_idx=-1):
    """Plot one decoder layer's cross-attention, averaged over heads.

    Reads the weights returned by ``model.get_decoder_cross_attn(layer_idx)``.
    Their shape is assumed to be (1, h, T, S), i.e. batch size 1 — TODO confirm
    in transformer.py.

    Args:
        model: Transformer whose most recent forward pass used these ids.
        src_ids: encoder token ids (x-axis tick labels).
        tgt_ids: decoder input token ids (y-axis tick labels).
        layer_idx: forwarded to ``get_decoder_cross_attn``; the default -1
            presumably selects the last decoder layer.

    Returns:
        A ``matplotlib.figure.Figure``, or None when the model returns no
        attention.
    """
    # Build the figure without pyplot. pyplot keeps every figure it creates in
    # a process-wide registry until plt.close() is called. This function runs
    # once per Gradio request, so pyplot figures would accumulate forever.
    from matplotlib.figure import Figure

    attn = model.get_decoder_cross_attn(layer_idx)  # (1, h, T, S)
    if attn is None:
        return None
    # detach() so .numpy() also works if this is called outside torch.no_grad().
    attn_avg = attn.detach().mean(dim=1)[0].cpu().numpy()  # (T, S)

    src_labels = [ID2STR[i] for i in src_ids]
    tgt_labels = [ID2STR[i] for i in tgt_ids]

    fig = Figure(figsize=(7, 6))
    ax = fig.subplots()
    im = ax.imshow(attn_avg, cmap="viridis", aspect="auto", vmin=0, vmax=attn_avg.max())
    ax.set_xticks(range(len(src_labels)))
    ax.set_xticklabels(src_labels)
    ax.set_yticks(range(len(tgt_labels)))
    ax.set_yticklabels(tgt_labels)
    ax.set_xlabel("Encoder positions (μž…λ ₯)")
    ax.set_ylabel("Decoder positions (좜λ ₯)")
    ax.set_title(f"Decoder ↔ Encoder Cross-Attention\n(layer {layer_idx}, heads averaged)")
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    fig.tight_layout()
    return fig
def plot_positional_encoding(model, length=20):
    """Render the first ``length`` rows of the model's positional-encoding table.

    Reads ``model.pe.pe[0, :length]``. That buffer is assumed to be the fixed
    sin/cos table of shape (1, max_len, d_model) — TODO confirm in
    transformer.py. It is shown as a heatmap with positions as rows and
    embedding dimensions as columns.

    Returns:
        matplotlib Figure.
    """
    table = model.pe.pe[0, :length]
    values = table.cpu().numpy()  # (length, d_model)
    fig, ax = plt.subplots(figsize=(7, 4))
    # sin/cos values lie in [-1, 1], hence the fixed symmetric colour range.
    heat = ax.imshow(values, cmap="RdBu", aspect="auto", vmin=-1, vmax=1)
    ax.set_xlabel("Embedding dimension")
    ax.set_ylabel("Position")
    ax.set_title("Positional Encoding (sin/cos pattern)")
    fig.colorbar(heat, ax=ax, fraction=0.046, pad=0.04)
    fig.tight_layout()
    return fig
# ─────────────────────────────────────────────────────────────
# λͺ¨λΈ μ€€λΉ„ (ν•™μŠ΅ λ˜λŠ” λ‘œλ“œ)
# ─────────────────────────────────────────────────────────────
def build_model():
    """Instantiate a fresh, untrained demo-sized Transformer on DEVICE.

    Source and target share the same VOCAB-sized vocabulary. ``max_len=64`` is
    passed through to the model (presumably the positional-encoding table
    length — confirm in transformer.py).
    """
    hparams = dict(
        src_vocab=VOCAB,
        tgt_vocab=VOCAB,
        d_model=D_MODEL,
        N=N_LAYERS,
        h=N_HEADS,
        d_ff=D_FF,
        dropout=DROPOUT,
        max_len=64,
    )
    return Transformer(**hparams).to(DEVICE)
def load_or_train():
    """Return the demo model in eval mode, reusing cached weights when possible.

    If CKPT_PATH exists, its state dict is loaded into a freshly built model.
    A new model is trained from scratch, and its weights written to CKPT_PATH,
    in either of these cases:
      * the file is missing;
      * the file is unreadable;
      * the file does not match the current architecture (e.g. after editing
        D_MODEL or N_LAYERS).
    Failing to write the checkpoint is reported but not fatal, because the
    file is only a cache.
    """
    import pickle

    model = build_model()
    if os.path.exists(CKPT_PATH):
        print(f"[init] loading checkpoint: {CKPT_PATH}")
        try:
            # NOTE: torch.load unpickles the file; only load checkpoints you trust.
            state = torch.load(CKPT_PATH, map_location=DEVICE)
            model.load_state_dict(state)
        except (RuntimeError, EOFError, pickle.UnpicklingError) as err:
            # load_state_dict raises RuntimeError on missing/unexpected keys or
            # shape mismatches, after copying whatever did match. torch.load
            # raises on a corrupt file. Retrain from a clean model instead of
            # failing at boot.
            print(f"[init] checkpoint unusable ({err}); retraining")
            return _train_and_cache(build_model())
        model.eval()
        return model
    print("[init] no checkpoint found β†’ training from scratch")
    return _train_and_cache(model)


def _train_and_cache(model):
    """Train ``model`` from scratch, try to save it to CKPT_PATH, return it in eval mode."""
    train(model, steps=2000, batch_size=128, lr=5e-4)
    try:
        torch.save(model.state_dict(), CKPT_PATH)
        print(f"[init] saved checkpoint: {CKPT_PATH}")
    except OSError as err:
        # The checkpoint is only a cache; an unwritable location must not stop the demo.
        print(f"[init] could not save checkpoint ({err}); continuing without it")
    model.eval()
    return model
# Build (or load) the model once at import time so every request reuses it.
print("=" * 60, "Transformer Demo β€” initializing", "=" * 60, sep="\n")
MODEL = load_or_train()
# Static figure displayed on the "Positional Encoding" tab.
PE_FIG = plot_positional_encoding(MODEL, length=20)
print("[init] ready βœ”")
# ─────────────────────────────────────────────────────────────
# Gradio 콜백
# ─────────────────────────────────────────────────────────────
def run_inference(user_text):
    """Gradio callback for the reversal tab.

    Parses digits from ``user_text`` and greedily decodes their reversal. It
    then re-runs one teacher-forced forward pass so the cross-attention read by
    plot_cross_attention reflects the full decoded sequence, and plots it.

    Args:
        user_text: raw contents of the input textbox.

    Returns:
        ``(result_markdown, figure_or_None, info_markdown)``, matching the
        ``outputs=[out_text, attn_plot, attn_info]`` wiring in the UI.
    """
    digits = parse_user_input(user_text)
    if not digits:
        return "❗ 숫자λ₯Ό μž…λ ₯ν•΄ μ£Όμ„Έμš”.", None, "(μ–΄ν…μ…˜ μ—†μŒ)"
    # parse_user_input already caps the length; slicing again is a cheap guard.
    digits = digits[:MAX_INPUT_LEN]

    def spaced(values):
        return " ".join(map(str, values))

    src_ids = digits_to_ids(digits)
    out_ids = greedy_decode(MODEL, src_ids)
    predicted = ids_to_digits(out_ids[1:])  # skip the leading BOS; stops at EOS
    expected = digits[::-1]
    correct = predicted == expected

    input_str = spaced(digits)
    pred_str = spaced(predicted) if predicted else "(빈 좜λ ₯)"
    expected_str = spaced(expected)
    msg = (
        f"**μž…λ ₯** : {input_str}\n\n"
        f"**예츑 좜λ ₯** : {pred_str}\n\n"
        f"**μ •λ‹΅** : {expected_str}\n\n"
        f"**일치 μ—¬λΆ€** : {'βœ… μ •λ‹΅!' if correct else '❌ μ˜€λ‹΅'}"
    )

    # One more full forward pass over (source, decoded output) refreshes the
    # cross-attention that plot_cross_attention reads from the model. Heatmap
    # rows are labelled by decoder *input* tokens (out_ids).
    src = torch.tensor([src_ids], device=DEVICE)
    tgt = torch.tensor([out_ids], device=DEVICE)
    with torch.no_grad():
        MODEL(src, tgt, make_src_mask(src), make_tgt_mask(tgt))
    fig = plot_cross_attention(MODEL, src_ids, out_ids, layer_idx=-1)

    info = (
        "이 νžˆνŠΈλ§΅μ€ λ””μ½”λ”μ˜ λ§ˆμ§€λ§‰ μΈ΅ cross-attention κ°€μ€‘μΉ˜(ν—€λ“œ 평균)μž…λ‹ˆλ‹€.\n"
        "각 ν–‰(좜λ ₯ μœ„μΉ˜)이 μ–΄λ–€ μž…λ ₯ μœ„μΉ˜λ₯Ό κ°€μž₯ 많이 보고 μžˆλŠ”μ§€ λ‚˜νƒ€λƒ…λ‹ˆλ‹€.\n"
        "λ’€μ§‘κΈ° νƒœμŠ€ν¬μ—μ„œλŠ” **λ°˜λŒ€κ°μ„ (anti-diagonal)** νŒ¨ν„΄μ΄ 보이면 ν•™μŠ΅ 성곡!"
    )
    return msg, fig, info
# ─────────────────────────────────────────────────────────────
# UI 구성 — UI layout
# ─────────────────────────────────────────────────────────────
# Gradio interface: a header, then three tabs (digit-reversal demo,
# positional-encoding plot, about page). Component creation order inside the
# `with` contexts determines the on-screen layout.
with gr.Blocks(title="Transformer Demo β€” Digit Reversal") as demo:
    # NOTE(review): the "~80K parameters" claim below may be stale. A standard
    # encoder-decoder with d_model=64, N=2, h=4, d_ff=128 and a 13-token vocab
    # has roughly 170K parameters. Verify with
    # sum(p.numel() for p in MODEL.parameters()) and update the text (also the
    # "8만" figure on the about tab) if needed.
    gr.Markdown(
        """
# πŸ€– Transformer 데λͺ¨ β€” 숫자 μ‹œν€€μŠ€ λ’€μ§‘κΈ°
Vaswani et al. (2017) **"Attention Is All You Need"** 논문을 μ²˜μŒλΆ€ν„° μž¬ν˜„ν•œ
Transformer둜 μž…λ ₯ μˆ«μžμ—΄μ„ λ’€μ§‘μ–΄ λ΄…λ‹ˆλ‹€.
- λͺ¨λΈ: d_model=64, N=2μΈ΅, h=4ν—€λ“œ (총 ~80K νŒŒλΌλ―Έν„°)
- ν•™μŠ΅ 데이터: 길이 3~10의 λ¬΄μž‘μœ„ μˆ«μžμ—΄, λ§€ step μƒˆλ‘œ 생성
- ν•™μŠ΅ μ‹œκ°„: λΆ€νŒ… μ‹œ ~30초 (CPU κΈ°μ€€)
ν•˜λ‹¨ ν—€λ„λ§΅μ—μ„œ **λ°˜λŒ€κ°μ„  νŒ¨ν„΄**이 보인닀면, λͺ¨λΈμ΄ "좜λ ₯ i번째 = μž…λ ₯의
λ°˜λŒ€νŽΈ μœ„μΉ˜"λ₯Ό ν•™μŠ΅ν–ˆλ‹€λŠ” μ¦κ±°μ˜ˆμš”.
"""
    )
    # Tab 1: input box, button, result and examples on the left; attention
    # heatmap and its explanation on the right.
    with gr.Tab("λ’€μ§‘κΈ°"):
        with gr.Row():
            with gr.Column(scale=1):
                inp = gr.Textbox(
                    label="μˆ«μžμ—΄ μž…λ ₯ (μ΅œλŒ€ 10자리)",
                    placeholder="예: 1 2 3 4 5 λ˜λŠ” 12345",
                    value="1 2 3 4 5 6 7",
                )
                btn = gr.Button("λ’€μ§‘κΈ° μ‹€ν–‰", variant="primary")
                out_text = gr.Markdown()
                # Clicking an example only fills the textbox; the user still presses the button.
                gr.Examples(
                    examples=[
                        ["1 2 3"],
                        ["1 2 3 4 5"],
                        ["9 8 7 6 5 4 3"],
                        ["1 1 2 2 3 3"],
                        ["0 1 2 3 4 5 6 7 8 9"],
                    ],
                    inputs=inp,
                )
            with gr.Column(scale=2):
                attn_plot = gr.Plot(label="Cross-Attention Heatmap")
                attn_info = gr.Markdown()
        # run_inference returns (result markdown, figure, info markdown) in this order.
        btn.click(run_inference, inputs=inp, outputs=[out_text, attn_plot, attn_info])
    # Tab 2: the static positional-encoding figure computed once at startup (PE_FIG).
    with gr.Tab("Positional Encoding"):
        gr.Markdown(
            """
### Positional Encoding μ‹œκ°ν™”
λ…Όλ¬Έ Β§3.5의 sin/cos μœ„μΉ˜ 인코딩을 직접 κ·Έλ¦° κ²ƒμž…λ‹ˆλ‹€.
κ°€λ‘œμΆ•μ΄ μž„λ² λ”© 차원, μ„Έλ‘œμΆ•μ΄ μœ„μΉ˜(μ‹œκ°„ μˆœμ„œ)μ˜ˆμš”.
짝수 차원은 sin, ν™€μˆ˜ 차원은 cos둜 μ±„μ›Œμ§€λ©°, 차원이 클수둝 μ£ΌκΈ°κ°€ κΈΈμ–΄μ§‘λ‹ˆλ‹€.
덕뢄에 λͺ¨λΈμ΄ **μƒλŒ€ μœ„μΉ˜**λ₯Ό μ„ ν˜• λ³€ν™˜μœΌλ‘œ ν‘œν˜„ν•  수 있게 λ©λ‹ˆλ‹€.
"""
        )
        gr.Plot(value=PE_FIG, label="PE matrix")
    # Tab 3: static "about" page.
    with gr.Tab("이 데λͺ¨μ— λŒ€ν•΄"):
        gr.Markdown(
            """
### μ™œ "숫자 λ’€μ§‘κΈ°"μΈκ°€μš”?
λ²ˆμ—­ 같은 μ§„μ§œ νƒœμŠ€ν¬λŠ” κ±°λŒ€ν•œ 데이터·연산을 μš”κ΅¬ν•΄μ„œ 무료 Space에선 λΆ€μ ν•©ν•©λ‹ˆλ‹€.
λŒ€μ‹  **숫자 λ’€μ§‘κΈ°**λŠ”:
1. μž‘μ€ λͺ¨λΈ(8만 νŒŒλΌλ―Έν„°)이 1~2λΆ„ λ‚΄ ν•™μŠ΅ κ°€λŠ₯
2. μž…μΆœλ ₯이 λͺ…ν™•ν•΄μ„œ μ •λ‹΅ μ—¬λΆ€λ₯Ό μ¦‰μ‹œ νŒλ‹¨ κ°€λŠ₯
3. **cross-attention이 anti-diagonal νŒ¨ν„΄**을 κ·Έλ € μ‹œκ°ν™” νš¨κ³Όκ°€ 큼
4. μ™ΈλΆ€ 데이터 λΆˆν•„μš” (λŸ°νƒ€μž„ 생성)
### λͺ¨λΈ ꡬ쑰
```
VOCAB(13) β†’ Embedding(64) + PE
β†’ 2Γ— EncoderLayer (h=4, d_ff=128)
β†’ 2Γ— DecoderLayer (h=4, d_ff=128)
β†’ Linear β†’ 13개 토큰 logits
```
논문은 d_model=512, N=6, h=8, d_ff=2048μ΄μ§€λ§Œ, 이 데λͺ¨λŠ” κ·Έ 크기의 1/8 μˆ˜μ€€μž…λ‹ˆλ‹€.
κ΅¬μ‘°λŠ” μ™„μ „νžˆ λ™μΌν•˜κ³ , 크기만 μ€„μ˜€μ–΄μš”.
### μΆ”λ‘  방식
**Greedy decoding**: λ§€ μ‹œμ  κ°€μž₯ ν™•λ₯  높은 토큰을 선택. Beam search 같은
κ³ κΈ‰ 디코딩은 μƒλž΅ν–ˆμŠ΅λ‹ˆλ‹€.
### ν•œκ³„
- μžλ¦Ώμˆ˜κ°€ κΈΈμ–΄μ§ˆμˆ˜λ‘(>10) 정확도 ν•˜λ½
- ν•™μŠ΅ μ‹œ 보지 λͺ»ν•œ νŒ¨ν„΄(반볡, 맀우 κΈ΄ μ‹œν€€μŠ€)에 μ·¨μ•½
- μ§„μ§œ NMTκ°€ μ•„λ‹ˆλ―€λ‘œ 일반 μžμ—°μ–΄λŠ” 처리 λΆˆκ°€
### μ°Έκ³ 
- λ…Όλ¬Έ: [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
- The Annotated Transformer: [http://nlp.seas.harvard.edu/annotated-transformer/](http://nlp.seas.harvard.edu/annotated-transformer/)
"""
        )

# Start the web server only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()