Spaces:
Sleeping
Sleeping
"""
Transformer demo — digit sequence reversal.

Overview:
1) Train a small Transformer on the spot at boot time (~30 s).
2) In the Gradio UI, reverse the digit string the user enters and show it.
3) Visualize the decoder's cross-attention — a neat anti-diagonal pattern.

The task is simple, but it shows most intuitively how a Transformer learns
interactions between positions.
"""
| import os | |
| import math | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import gradio as gr | |
| from transformer import Transformer | |
# ─────────────────────────────────────────────────────────────
# Token definitions
# ─────────────────────────────────────────────────────────────
# Ids of the special tokens: padding, begin-of-sequence, end-of-sequence.
PAD = 0
BOS = 1
EOS = 2
DIGIT_OFFSET = 3  # digit d -> token id (d + 3)
VOCAB = DIGIT_OFFSET + 10  # digits 0-9 on top of the offset -> 13 tokens total
# Token id -> printable label, for the specials and for every digit token.
ID2STR = {
    PAD: "<P>",
    BOS: "<S>",
    EOS: "<E>",
    **{digit + DIGIT_OFFSET: str(digit) for digit in range(10)},
}
MAX_INPUT_LEN = 10  # upper bound on the number of digits taken from the user
MAX_DECODE_LEN = MAX_INPUT_LEN + 2
# ─────────────────────────────────────────────────────────────
# Model hyperparameters (kept small)
# ─────────────────────────────────────────────────────────────
D_MODEL = 64  # passed to Transformer as d_model in build_model()
N_LAYERS = 2  # passed to Transformer as N
N_HEADS = 4  # passed to Transformer as h
D_FF = 128  # passed to Transformer as d_ff
DROPOUT = 0.1  # passed to Transformer as dropout
CKPT_PATH = "model.pt"  # checkpoint file read/written by load_or_train()
# Use the GPU when torch reports one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ─────────────────────────────────────────────────────────────
# Tokenization utilities
# ─────────────────────────────────────────────────────────────
def digits_to_ids(digits, add_bos_eos=True):
    """Convert an iterable of digits into token ids.

    Each digit is shifted by DIGIT_OFFSET. When ``add_bos_eos`` is true the
    result is wrapped as ``[BOS, ..., EOS]``; otherwise only the shifted
    digits are returned. A new list is returned in both cases.
    """
    body = [DIGIT_OFFSET + digit for digit in digits]
    return [BOS, *body, EOS] if add_bos_eos else body
def ids_to_digits(ids, stop_at_eos=True):
    """Convert token ids back into digits.

    Ids below DIGIT_OFFSET (the special tokens) are skipped. When
    ``stop_at_eos`` is true, conversion stops at the first EOS id.
    """
    digits = []
    for token in ids:
        if stop_at_eos and token == EOS:
            break
        if token < DIGIT_OFFSET:
            continue  # special token: contributes no digit
        digits.append(token - DIGIT_OFFSET)
    return digits
def parse_user_input(text):
    """Extract the digits from a user-supplied string.

    Every character that is not a decimal digit (spaces, commas, letters,
    ...) is ignored, so "1 2 3", "1,2,3" and "123" all yield ``[1, 2, 3]``.

    Args:
        text: Raw string typed by the user.

    Returns:
        A list of ints, truncated to at most MAX_INPUT_LEN entries.
    """
    # str.isdecimal() instead of str.isdigit(): isdigit() is also true for
    # characters such as superscript "²", which int() rejects with
    # ValueError. Every character accepted by isdecimal() converts via int().
    digits = [int(ch) for ch in text if ch.isdecimal()]
    return digits[:MAX_INPUT_LEN]
# ─────────────────────────────────────────────────────────────
# Mask construction
# ─────────────────────────────────────────────────────────────
def make_src_mask(src):
    """Return a boolean mask that is True wherever ``src`` is not PAD.

    Two singleton axes are inserted at positions 1 and 2, i.e.
    (B, S) -> (B, 1, 1, S) for a 2-D batch of token ids.
    """
    not_pad = src != PAD
    with_axis_1 = not_pad.unsqueeze(1)
    return with_axis_1.unsqueeze(2)
def make_tgt_mask(tgt):
    """Combine the padding mask with a causal (lower-triangular) mask.

    ``tgt`` must be 2-D (B, T). The padding part has shape (B, 1, 1, T) and
    the causal part (1, 1, T, T); their logical AND is returned.
    """
    _, seq_len = tgt.shape
    not_pad = (tgt != PAD).unsqueeze(1).unsqueeze(2)  # (B, 1, 1, T)
    # Lower triangle: position i may only look at positions <= i.
    lower_tri = torch.ones(seq_len, seq_len, device=tgt.device).tril().bool()
    return not_pad & lower_tri[None, None]  # (1, 1, T, T) broadcast over B
# ─────────────────────────────────────────────────────────────
# Training data generator
# ─────────────────────────────────────────────────────────────
def make_batch(batch_size=128, min_len=3, max_len=8):
    """Generate a random batch of (source, reversed-source) token tensors.

    For each sample a length in [min_len, max_len] and that many digits in
    [0, 9] are drawn from numpy's global RNG. Source and target are both
    wrapped in BOS/EOS and right-padded with PAD to the longest row.

    Returns:
        (src, tgt): two integer tensors, one row per sample.
    """
    sources, targets = [], []
    for _ in range(batch_size):
        n_digits = np.random.randint(min_len, max_len + 1)
        sample = np.random.randint(0, 10, size=n_digits).tolist()
        sources.append(digits_to_ids(sample))
        targets.append(digits_to_ids(list(reversed(sample))))

    def _to_padded_tensor(rows):
        # Right-pad every row with PAD up to the longest row in the batch.
        width = max(map(len, rows))
        return torch.tensor([row + [PAD] * (width - len(row)) for row in rows])

    return _to_padded_tensor(sources), _to_padded_tensor(targets)
# ─────────────────────────────────────────────────────────────
# Training loop
# ─────────────────────────────────────────────────────────────
def train(model, steps=2000, batch_size=128, lr=5e-4, log_every=200):
    """Train ``model`` on freshly generated reversal batches.

    Uses Adam and label-smoothed cross-entropy that ignores PAD targets,
    with gradient-norm clipping at 1.0. Loss and token accuracy are printed
    at step 1 and every ``log_every`` steps.

    Returns:
        The same ``model`` object (left in training mode).
    """
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-9)
    criterion = nn.CrossEntropyLoss(ignore_index=PAD, label_smoothing=0.1)
    print(f"[train] device={DEVICE}, steps={steps}, batch={batch_size}")

    for step in range(1, steps + 1):
        src, tgt = make_batch(batch_size, min_len=3, max_len=MAX_INPUT_LEN)
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        # Teacher forcing: the decoder sees tgt[:-1] and must predict tgt[1:].
        decoder_in = tgt[:, :-1]
        labels = tgt[:, 1:]
        logits = model(src, decoder_in, make_src_mask(src), make_tgt_mask(decoder_in))
        loss = criterion(logits.reshape(-1, VOCAB), labels.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        if step % log_every == 0 or step == 1:
            with torch.no_grad():
                valid = labels != PAD
                hits = (logits.argmax(-1) == labels) & valid
                acc = hits.sum().item() / valid.sum().item()
            print(f" step {step:4d} loss={loss.item():.4f} token_acc={acc:.3f}")
    return model
# ─────────────────────────────────────────────────────────────
# Inference (greedy decoding)
# ─────────────────────────────────────────────────────────────
def greedy_decode(model, src_ids):
    """Greedily decode an output sequence for one source sequence.

    Args:
        model: Object exposing ``encode``, ``decode`` and ``out``, which are
            called here exactly as in the loop below.
        src_ids: list[int] of source token ids (BOS and EOS included).

    Returns:
        list[int] of generated token ids, starting with BOS. Decoding stops
        after EOS is produced or after MAX_DECODE_LEN steps, whichever
        comes first.

    Side effects:
        Puts ``model`` into eval mode.
    """
    model.eval()
    # Inference only: without no_grad() every step would record an autograd
    # graph that is never used, wasting memory and time.
    with torch.no_grad():
        src = torch.tensor([src_ids], device=DEVICE)
        src_mask = make_src_mask(src)
        enc_out = model.encode(src, src_mask)
        ys = torch.tensor([[BOS]], device=DEVICE)
        for _ in range(MAX_DECODE_LEN):
            tgt_mask = make_tgt_mask(ys)
            dec_out = model.decode(ys, enc_out, src_mask, tgt_mask)
            logits = model.out(dec_out)
            # Only the last position's distribution decides the next token.
            next_tok = logits[:, -1].argmax(-1, keepdim=True)
            ys = torch.cat([ys, next_tok], dim=1)
            if next_tok.item() == EOS:
                break
    return ys[0].tolist()
# ─────────────────────────────────────────────────────────────
# Attention visualization
# ─────────────────────────────────────────────────────────────
def plot_cross_attention(model, src_ids, tgt_ids, layer_idx=-1):
    """Draw the decoder cross-attention, averaged over heads, as a heatmap.

    Args:
        model: Object exposing ``get_decoder_cross_attn(layer_idx)``.
        src_ids: Source token ids, used for the x-axis tick labels.
        tgt_ids: Target token ids, used for the y-axis tick labels.
        layer_idx: Decoder layer index forwarded to the model.

    Returns:
        The matplotlib figure, or None when the model reports no attention.
    """
    attn = model.get_decoder_cross_attn(layer_idx)  # (1, h, T, S) per the original note
    if attn is None:
        return None
    attn_avg = attn.mean(dim=1)[0].cpu().numpy()  # head average of sample 0 -> (T, S)
    src_labels = [ID2STR[i] for i in src_ids]
    tgt_labels = [ID2STR[i] for i in tgt_ids]
    fig, ax = plt.subplots(figsize=(7, 6))
    im = ax.imshow(attn_avg, cmap="viridis", aspect="auto", vmin=0, vmax=attn_avg.max())
    ax.set_xticks(range(len(src_labels)))
    ax.set_xticklabels(src_labels)
    ax.set_yticks(range(len(tgt_labels)))
    ax.set_yticklabels(tgt_labels)
    ax.set_xlabel("Encoder positions (μ λ ₯)")
    ax.set_ylabel("Decoder positions (μΆλ ₯)")
    ax.set_title(f"Decoder β Encoder Cross-Attention\n(layer {layer_idx}, heads averaged)")
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    plt.tight_layout()
    # pyplot keeps every figure it creates alive until it is closed. This
    # function runs once per request, so unregister the figure here to stop
    # figures piling up; the returned Figure object can still be rendered.
    plt.close(fig)
    return fig
def plot_positional_encoding(model, length=20):
    """Render the first ``length`` rows of the positional-encoding table.

    Reads ``model.pe.pe[0, :length]`` and draws it as a heatmap clipped to
    [-1, 1] (rows are positions, columns are embedding dimensions).

    Returns:
        The matplotlib figure.
    """
    table = model.pe.pe[0, :length].cpu().numpy()  # (L, d_model)
    fig, ax = plt.subplots(figsize=(7, 4))
    heatmap = ax.imshow(table, cmap="RdBu", aspect="auto", vmin=-1, vmax=1)
    ax.set_title("Positional Encoding (sin/cos pattern)")
    ax.set_xlabel("Embedding dimension")
    ax.set_ylabel("Position")
    # fig is the current pyplot figure here, so the Figure methods are
    # equivalent to the pyplot-level calls.
    fig.colorbar(heatmap, ax=ax, fraction=0.046, pad=0.04)
    fig.tight_layout()
    return fig
# ─────────────────────────────────────────────────────────────
# Model preparation (train or load)
# ─────────────────────────────────────────────────────────────
def build_model():
    """Instantiate the Transformer from the module-level hyperparameters.

    Returns:
        A freshly initialized model already moved to DEVICE.
    """
    hparams = {
        "src_vocab": VOCAB,
        "tgt_vocab": VOCAB,
        "d_model": D_MODEL,
        "N": N_LAYERS,
        "h": N_HEADS,
        "d_ff": D_FF,
        "dropout": DROPOUT,
        "max_len": 64,
    }
    return Transformer(**hparams).to(DEVICE)
def load_or_train():
    """Return a ready-to-use model in eval mode.

    If CKPT_PATH does not exist, a new model is trained and its state dict
    is saved there; otherwise the saved state dict is loaded into a freshly
    built model.
    """
    model = build_model()
    if not os.path.exists(CKPT_PATH):
        print("[init] no checkpoint found β training from scratch")
        train(model, steps=2000, batch_size=128, lr=5e-4)
        torch.save(model.state_dict(), CKPT_PATH)
        print(f"[init] saved checkpoint: {CKPT_PATH}")
    else:
        print(f"[init] loading checkpoint: {CKPT_PATH}")
        # NOTE(review): torch.load unpickles the file; if the checkpoint can
        # ever come from an untrusted source, consider weights_only=True.
        model.load_state_dict(torch.load(CKPT_PATH, map_location=DEVICE))
    model.eval()
    return model
# Startup: print a banner, obtain the model (load or train), and pre-render
# the positional-encoding figure once so the UI can reuse it.
print("=" * 60, "Transformer Demo β initializing", "=" * 60, sep="\n")
MODEL = load_or_train()
PE_FIG = plot_positional_encoding(MODEL, length=20)
print("[init] ready β")
# ─────────────────────────────────────────────────────────────
# Gradio callback
# ─────────────────────────────────────────────────────────────
def run_inference(user_text):
    """Reverse the digits in ``user_text`` with the model and build the UI outputs.

    Args:
        user_text: Raw textbox content; None is treated like an empty string.

    Returns:
        (message_markdown, figure_or_None, info_markdown). When no digit is
        found, a prompt message is returned with no figure.
    """
    # parse_user_input() already caps the result at MAX_INPUT_LEN, so no
    # further truncation is needed here.
    digits = parse_user_input(user_text or "")
    if not digits:
        return "β μ«μλ₯Ό μ λ ₯ν΄ μ£ΌμΈμ.", None, "(μ΄ν μ μμ)"

    src_ids = digits_to_ids(digits)
    out_ids = greedy_decode(MODEL, src_ids)
    pred_digits = ids_to_digits(out_ids[1:])  # drop BOS, stop at EOS
    expected = digits[::-1]
    correct = pred_digits == expected

    pred_str = " ".join(str(d) for d in pred_digits) if pred_digits else "(λΉ μΆλ ₯)"
    expected_str = " ".join(str(d) for d in expected)
    input_str = " ".join(str(d) for d in digits)
    msg = (
        f"**μ λ ₯** : {input_str}\n\n"
        f"**μμΈ‘ μΆλ ₯** : {pred_str}\n\n"
        f"**μ λ΅** : {expected_str}\n\n"
        f"**μΌμΉ μ¬λΆ** : {'β μ λ΅!' if correct else 'β μ€λ΅'}"
    )

    # Run one more full forward pass over the decoded sequence so that the
    # cross-attention read by plot_cross_attention() covers every output
    # position, not just the last decoding step.
    src = torch.tensor([src_ids], device=DEVICE)
    tgt = torch.tensor([out_ids], device=DEVICE)
    src_mask = make_src_mask(src)
    tgt_mask = make_tgt_mask(tgt)
    with torch.no_grad():
        MODEL(src, tgt, src_mask, tgt_mask)
    fig = plot_cross_attention(MODEL, src_ids, out_ids, layer_idx=-1)

    info = (
        "μ΄ ννΈλ§΅μ λμ½λμ λ§μ§λ§ μΈ΅ cross-attention κ°μ€μΉ(ν€λ νκ· )μ λλ€.\n"
        "κ° ν(μΆλ ₯ μμΉ)μ΄ μ΄λ€ μ λ ₯ μμΉλ₯Ό κ°μ₯ λ§μ΄ λ³΄κ³ μλμ§ λνλ λλ€.\n"
        "λ€μ§κΈ° νμ€ν¬μμλ **λ°λκ°μ (anti-diagonal)** ν¨ν΄μ΄ 보μ΄λ©΄ νμ΅ μ±κ³΅!"
    )
    return msg, fig, info
# ─────────────────────────────────────────────────────────────
# UI layout
# ─────────────────────────────────────────────────────────────
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation and blank lines had been stripped — confirm it against the
# intended layout. String contents are kept exactly as found.
with gr.Blocks(title="Transformer Demo β Digit Reversal") as demo:
    # Page header.
    gr.Markdown(
"""
# π€ Transformer λ°λͺ¨ β μ«μ μνμ€ λ€μ§κΈ°
Vaswani et al. (2017) **"Attention Is All You Need"** λ Όλ¬Έμ μ²μλΆν° μ¬νν
Transformerλ‘ μ λ ₯ μ«μμ΄μ λ€μ§μ΄ λ΄ λλ€.
- λͺ¨λΈ: d_model=64, N=2μΈ΅, h=4ν€λ (μ΄ ~80K νλΌλ―Έν°)
- νμ΅ λ°μ΄ν°: κΈΈμ΄ 3~10μ 무μμ μ«μμ΄, λ§€ step μλ‘ μμ±
- νμ΅ μκ°: λΆν μ ~30μ΄ (CPU κΈ°μ€)
νλ¨ ν€λλ§΅μμ **λ°λκ°μ ν¨ν΄**μ΄ λ³΄μΈλ€λ©΄, λͺ¨λΈμ΄ "μΆλ ₯ iλ²μ§Έ = μ λ ₯μ
λ°λνΈ μμΉ"λ₯Ό νμ΅νλ€λ μ¦κ±°μμ.
"""
    )
    # Tab 1: interactive reversal — input controls on the left, heatmap on the right.
    with gr.Tab("λ€μ§κΈ°"):
        with gr.Row():
            with gr.Column(scale=1):
                inp = gr.Textbox(
                    label="μ«μμ΄ μ λ ₯ (μ΅λ 10μ리)",
                    placeholder="μ: 1 2 3 4 5 λλ 12345",
                    value="1 2 3 4 5 6 7",
                )
                btn = gr.Button("λ€μ§κΈ° μ€ν", variant="primary")
                out_text = gr.Markdown()
                # Clickable sample inputs that fill the textbox.
                gr.Examples(
                    examples=[
                        ["1 2 3"],
                        ["1 2 3 4 5"],
                        ["9 8 7 6 5 4 3"],
                        ["1 1 2 2 3 3"],
                        ["0 1 2 3 4 5 6 7 8 9"],
                    ],
                    inputs=inp,
                )
            with gr.Column(scale=2):
                attn_plot = gr.Plot(label="Cross-Attention Heatmap")
                attn_info = gr.Markdown()
        # run_inference returns (message, figure, info), matching these outputs in order.
        btn.click(run_inference, inputs=inp, outputs=[out_text, attn_plot, attn_info])
    # Tab 2: static positional-encoding figure rendered once at startup.
    with gr.Tab("Positional Encoding"):
        gr.Markdown(
"""
### Positional Encoding μκ°ν
λ Όλ¬Έ Β§3.5μ sin/cos μμΉ μΈμ½λ©μ μ§μ κ·Έλ¦° κ²μ λλ€.
κ°λ‘μΆμ΄ μλ² λ© μ°¨μ, μΈλ‘μΆμ΄ μμΉ(μκ° μμ)μμ.
μ§μ μ°¨μμ sin, νμ μ°¨μμ cosλ‘ μ±μμ§λ©°, μ°¨μμ΄ ν΄μλ‘ μ£ΌκΈ°κ° κΈΈμ΄μ§λλ€.
λλΆμ λͺ¨λΈμ΄ **μλ μμΉ**λ₯Ό μ ν λ³νμΌλ‘ ννν μ μκ² λ©λλ€.
"""
        )
        gr.Plot(value=PE_FIG, label="PE matrix")
    # Tab 3: background notes about the demo.
    with gr.Tab("μ΄ λ°λͺ¨μ λν΄"):
        gr.Markdown(
"""
### μ "μ«μ λ€μ§κΈ°"μΈκ°μ?
λ²μ κ°μ μ§μ§ νμ€ν¬λ κ±°λν λ°μ΄ν°Β·μ°μ°μ μꡬν΄μ λ¬΄λ£ Spaceμμ λΆμ ν©ν©λλ€.
λμ **μ«μ λ€μ§κΈ°**λ:
1. μμ λͺ¨λΈ(8λ§ νλΌλ―Έν°)μ΄ 1~2λΆ λ΄ νμ΅ κ°λ₯
2. μ μΆλ ₯μ΄ λͺ νν΄μ μ λ΅ μ¬λΆλ₯Ό μ¦μ νλ¨ κ°λ₯
3. **cross-attentionμ΄ anti-diagonal ν¨ν΄**μ κ·Έλ € μκ°ν ν¨κ³Όκ° νΌ
4. μΈλΆ λ°μ΄ν° λΆνμ (λ°νμ μμ±)
### λͺ¨λΈ ꡬ쑰
```
VOCAB(13) β Embedding(64) + PE
β 2Γ EncoderLayer (h=4, d_ff=128)
β 2Γ DecoderLayer (h=4, d_ff=128)
β Linear β 13κ° ν ν° logits
```
λ Όλ¬Έμ d_model=512, N=6, h=8, d_ff=2048μ΄μ§λ§, μ΄ λ°λͺ¨λ κ·Έ ν¬κΈ°μ 1/8 μμ€μ λλ€.
ꡬ쑰λ μμ ν λμΌνκ³ , ν¬κΈ°λ§ μ€μμ΄μ.
### μΆλ‘ λ°©μ
**Greedy decoding**: λ§€ μμ κ°μ₯ νλ₯ λμ ν ν°μ μ ν. Beam search κ°μ
κ³ κΈ λμ½λ©μ μλ΅νμ΅λλ€.
### νκ³
- μλ¦Ώμκ° κΈΈμ΄μ§μλ‘(>10) μ νλ νλ½
- νμ΅ μ λ³΄μ§ λͺ»ν ν¨ν΄(λ°λ³΅, λ§€μ° κΈ΄ μνμ€)μ μ·¨μ½
- μ§μ§ NMTκ° μλλ―λ‘ μΌλ° μμ°μ΄λ μ²λ¦¬ λΆκ°
### μ°Έκ³
- λ Όλ¬Έ: [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
- The Annotated Transformer: [http://nlp.seas.harvard.edu/annotated-transformer/](http://nlp.seas.harvard.edu/annotated-transformer/)
"""
        )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()