""" Transformer 데모 — 숫자 시퀀스 뒤집기 (digit reversal) 구성: 1) 작은 Transformer를 부팅 시 즉석에서 학습 (~30초) 2) Gradio UI에서 사용자가 입력한 숫자열을 뒤집어 출력 3) 디코더의 cross-attention을 시각화 → 멋진 anti-diagonal 패턴 이 태스크는 단순하지만 Transformer가 위치 간 상호작용을 어떻게 학습하는지 가장 직관적으로 보여줍니다. """ import os import math import numpy as np import torch import torch.nn as nn import torch.optim as optim import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import gradio as gr from transformer import Transformer # ───────────────────────────────────────────────────────────── # 토큰 정의 # ───────────────────────────────────────────────────────────── PAD, BOS, EOS = 0, 1, 2 DIGIT_OFFSET = 3 # 숫자 d → 토큰 id (d + 3) VOCAB = DIGIT_OFFSET + 10 # 0~9 까지 → 총 13개 토큰 ID2STR = {PAD: "

", BOS: "", EOS: ""} for d in range(10): ID2STR[d + DIGIT_OFFSET] = str(d) MAX_INPUT_LEN = 10 # 사용자 입력 자릿수 상한 MAX_DECODE_LEN = MAX_INPUT_LEN + 2 # ───────────────────────────────────────────────────────────── # 모델 하이퍼파라미터 (작게) # ───────────────────────────────────────────────────────────── D_MODEL = 64 N_LAYERS = 2 N_HEADS = 4 D_FF = 128 DROPOUT = 0.1 CKPT_PATH = "model.pt" DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # ───────────────────────────────────────────────────────────── # 토큰화 유틸 # ───────────────────────────────────────────────────────────── def digits_to_ids(digits, add_bos_eos=True): ids = [d + DIGIT_OFFSET for d in digits] if add_bos_eos: ids = [BOS] + ids + [EOS] return ids def ids_to_digits(ids, stop_at_eos=True): out = [] for i in ids: if stop_at_eos and i == EOS: break if i >= DIGIT_OFFSET: out.append(i - DIGIT_OFFSET) return out def parse_user_input(text): """사용자 입력 문자열에서 숫자 추출. 공백·콤마 등 모두 허용.""" digits = [] for ch in text: if ch.isdigit(): digits.append(int(ch)) return digits[:MAX_INPUT_LEN] # ───────────────────────────────────────────────────────────── # 마스크 만들기 # ───────────────────────────────────────────────────────────── def make_src_mask(src): # (B, S) → (B, 1, 1, S) return (src != PAD).unsqueeze(1).unsqueeze(2) def make_tgt_mask(tgt): """패딩 + causal 마스크 결합""" B, T = tgt.shape pad_mask = (tgt != PAD).unsqueeze(1).unsqueeze(2) # (B, 1, 1, T) causal = torch.tril(torch.ones(T, T, device=tgt.device)).bool() causal = causal.unsqueeze(0).unsqueeze(0) # (1, 1, T, T) return pad_mask & causal # ───────────────────────────────────────────────────────────── # 학습 데이터 생성기 # ───────────────────────────────────────────────────────────── def make_batch(batch_size=128, min_len=3, max_len=8): src_list, tgt_list = [], [] for _ in range(batch_size): L = np.random.randint(min_len, max_len + 1) digits = np.random.randint(0, 10, size=L).tolist() src_list.append(digits_to_ids(digits)) tgt_list.append(digits_to_ids(digits[::-1])) s_max = max(len(s) for s in src_list) t_max = max(len(t) for t in tgt_list) src = torch.tensor([s + [PAD] * (s_max - len(s)) for s in src_list]) tgt = torch.tensor([t + [PAD] * (t_max - len(t)) for t in tgt_list]) return src, tgt # ───────────────────────────────────────────────────────────── # 학습 루프 # ───────────────────────────────────────────────────────────── def train(model, steps=2000, batch_size=128, lr=5e-4, log_every=200): model.train() opt = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-9) loss_fn = nn.CrossEntropyLoss(ignore_index=PAD, label_smoothing=0.1) print(f"[train] device={DEVICE}, steps={steps}, batch={batch_size}") for step in range(1, steps + 1): src, tgt = make_batch(batch_size, min_len=3, max_len=MAX_INPUT_LEN) src, tgt = src.to(DEVICE), tgt.to(DEVICE) tgt_in, tgt_out = tgt[:, :-1], tgt[:, 1:] src_mask = make_src_mask(src) tgt_mask = make_tgt_mask(tgt_in) logits = model(src, tgt_in, src_mask, tgt_mask) loss = loss_fn(logits.reshape(-1, VOCAB), tgt_out.reshape(-1)) opt.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) opt.step() if step % log_every == 0 or step == 1: with torch.no_grad(): pred = logits.argmax(-1) mask = (tgt_out != PAD) acc = ((pred == tgt_out) & mask).sum().item() / mask.sum().item() print(f" step {step:4d} loss={loss.item():.4f} token_acc={acc:.3f}") return model # ───────────────────────────────────────────────────────────── # 추론 (Greedy decoding) # ───────────────────────────────────────────────────────────── @torch.no_grad() def greedy_decode(model, src_ids): """src_ids: list[int] (BOS·EOS 포함)""" model.eval() src = torch.tensor([src_ids], device=DEVICE) src_mask = make_src_mask(src) enc_out = model.encode(src, src_mask) ys = torch.tensor([[BOS]], device=DEVICE) for _ in range(MAX_DECODE_LEN): tgt_mask = make_tgt_mask(ys) dec_out = model.decode(ys, enc_out, src_mask, tgt_mask) logits = model.out(dec_out) next_tok = logits[:, -1].argmax(-1, keepdim=True) ys = torch.cat([ys, next_tok], dim=1) if next_tok.item() == EOS: break return ys[0].tolist() # ───────────────────────────────────────────────────────────── # 어텐션 시각화 # ───────────────────────────────────────────────────────────── def plot_cross_attention(model, src_ids, tgt_ids, layer_idx=-1): """디코더 cross-attention을 헤드 평균으로 그림.""" attn = model.get_decoder_cross_attn(layer_idx) # (1, h, T, S) if attn is None: return None attn_avg = attn.mean(dim=1)[0].cpu().numpy() # (T, S) src_labels = [ID2STR[i] for i in src_ids] tgt_labels = [ID2STR[i] for i in tgt_ids] fig, ax = plt.subplots(figsize=(7, 6)) im = ax.imshow(attn_avg, cmap="viridis", aspect="auto", vmin=0, vmax=attn_avg.max()) ax.set_xticks(range(len(src_labels))) ax.set_xticklabels(src_labels) ax.set_yticks(range(len(tgt_labels))) ax.set_yticklabels(tgt_labels) ax.set_xlabel("Encoder positions (입력)") ax.set_ylabel("Decoder positions (출력)") ax.set_title(f"Decoder ↔ Encoder Cross-Attention\n(layer {layer_idx}, heads averaged)") plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04) plt.tight_layout() return fig def plot_positional_encoding(model, length=20): """positional encoding을 직접 시각화.""" pe = model.pe.pe[0, :length].cpu().numpy() # (L, d_model) fig, ax = plt.subplots(figsize=(7, 4)) im = ax.imshow(pe, cmap="RdBu", aspect="auto", vmin=-1, vmax=1) ax.set_xlabel("Embedding dimension") ax.set_ylabel("Position") ax.set_title("Positional Encoding (sin/cos pattern)") plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04) plt.tight_layout() return fig # ───────────────────────────────────────────────────────────── # 모델 준비 (학습 또는 로드) # ───────────────────────────────────────────────────────────── def build_model(): model = Transformer( src_vocab=VOCAB, tgt_vocab=VOCAB, d_model=D_MODEL, N=N_LAYERS, h=N_HEADS, d_ff=D_FF, dropout=DROPOUT, max_len=64, ).to(DEVICE) return model def load_or_train(): model = build_model() if os.path.exists(CKPT_PATH): print(f"[init] loading checkpoint: {CKPT_PATH}") state = torch.load(CKPT_PATH, map_location=DEVICE) model.load_state_dict(state) else: print("[init] no checkpoint found → training from scratch") train(model, steps=2000, batch_size=128, lr=5e-4) torch.save(model.state_dict(), CKPT_PATH) print(f"[init] saved checkpoint: {CKPT_PATH}") model.eval() return model print("=" * 60) print("Transformer Demo — initializing") print("=" * 60) MODEL = load_or_train() PE_FIG = plot_positional_encoding(MODEL, length=20) print("[init] ready ✔") # ───────────────────────────────────────────────────────────── # Gradio 콜백 # ───────────────────────────────────────────────────────────── def run_inference(user_text): digits = parse_user_input(user_text) if len(digits) == 0: return "❗ 숫자를 입력해 주세요.", None, "(어텐션 없음)" if len(digits) > MAX_INPUT_LEN: digits = digits[:MAX_INPUT_LEN] src_ids = digits_to_ids(digits) out_ids = greedy_decode(MODEL, src_ids) pred_digits = ids_to_digits(out_ids[1:]) # BOS 제외, EOS까지 expected = digits[::-1] correct = pred_digits == expected pred_str = " ".join(str(d) for d in pred_digits) if pred_digits else "(빈 출력)" expected_str = " ".join(str(d) for d in expected) input_str = " ".join(str(d) for d in digits) msg = ( f"**입력** : {input_str}\n\n" f"**예측 출력** : {pred_str}\n\n" f"**정답** : {expected_str}\n\n" f"**일치 여부** : {'✅ 정답!' if correct else '❌ 오답'}" ) # 시각화를 위해 다시 forward (cross-attention 갱신용) src = torch.tensor([src_ids], device=DEVICE) tgt = torch.tensor([out_ids], device=DEVICE) src_mask = make_src_mask(src) tgt_mask = make_tgt_mask(tgt) with torch.no_grad(): MODEL(src, tgt, src_mask, tgt_mask) fig = plot_cross_attention(MODEL, src_ids, out_ids, layer_idx=-1) info = ( "이 히트맵은 디코더의 마지막 층 cross-attention 가중치(헤드 평균)입니다.\n" "각 행(출력 위치)이 어떤 입력 위치를 가장 많이 보고 있는지 나타냅니다.\n" "뒤집기 태스크에서는 **반대각선(anti-diagonal)** 패턴이 보이면 학습 성공!" ) return msg, fig, info # ───────────────────────────────────────────────────────────── # UI 구성 # ───────────────────────────────────────────────────────────── with gr.Blocks(title="Transformer Demo — Digit Reversal") as demo: gr.Markdown( """ # 🤖 Transformer 데모 — 숫자 시퀀스 뒤집기 Vaswani et al. (2017) **"Attention Is All You Need"** 논문을 처음부터 재현한 Transformer로 입력 숫자열을 뒤집어 봅니다. - 모델: d_model=64, N=2층, h=4헤드 (총 ~80K 파라미터) - 학습 데이터: 길이 3~10의 무작위 숫자열, 매 step 새로 생성 - 학습 시간: 부팅 시 ~30초 (CPU 기준) 하단 헤도맵에서 **반대각선 패턴**이 보인다면, 모델이 "출력 i번째 = 입력의 반대편 위치"를 학습했다는 증거예요. """ ) with gr.Tab("뒤집기"): with gr.Row(): with gr.Column(scale=1): inp = gr.Textbox( label="숫자열 입력 (최대 10자리)", placeholder="예: 1 2 3 4 5 또는 12345", value="1 2 3 4 5 6 7", ) btn = gr.Button("뒤집기 실행", variant="primary") out_text = gr.Markdown() gr.Examples( examples=[ ["1 2 3"], ["1 2 3 4 5"], ["9 8 7 6 5 4 3"], ["1 1 2 2 3 3"], ["0 1 2 3 4 5 6 7 8 9"], ], inputs=inp, ) with gr.Column(scale=2): attn_plot = gr.Plot(label="Cross-Attention Heatmap") attn_info = gr.Markdown() btn.click(run_inference, inputs=inp, outputs=[out_text, attn_plot, attn_info]) with gr.Tab("Positional Encoding"): gr.Markdown( """ ### Positional Encoding 시각화 논문 §3.5의 sin/cos 위치 인코딩을 직접 그린 것입니다. 가로축이 임베딩 차원, 세로축이 위치(시간 순서)예요. 짝수 차원은 sin, 홀수 차원은 cos로 채워지며, 차원이 클수록 주기가 길어집니다. 덕분에 모델이 **상대 위치**를 선형 변환으로 표현할 수 있게 됩니다. """ ) gr.Plot(value=PE_FIG, label="PE matrix") with gr.Tab("이 데모에 대해"): gr.Markdown( """ ### 왜 "숫자 뒤집기"인가요? 번역 같은 진짜 태스크는 거대한 데이터·연산을 요구해서 무료 Space에선 부적합합니다. 대신 **숫자 뒤집기**는: 1. 작은 모델(8만 파라미터)이 1~2분 내 학습 가능 2. 입출력이 명확해서 정답 여부를 즉시 판단 가능 3. **cross-attention이 anti-diagonal 패턴**을 그려 시각화 효과가 큼 4. 외부 데이터 불필요 (런타임 생성) ### 모델 구조 ``` VOCAB(13) → Embedding(64) + PE → 2× EncoderLayer (h=4, d_ff=128) → 2× DecoderLayer (h=4, d_ff=128) → Linear → 13개 토큰 logits ``` 논문은 d_model=512, N=6, h=8, d_ff=2048이지만, 이 데모는 그 크기의 1/8 수준입니다. 구조는 완전히 동일하고, 크기만 줄였어요. ### 추론 방식 **Greedy decoding**: 매 시점 가장 확률 높은 토큰을 선택. Beam search 같은 고급 디코딩은 생략했습니다. ### 한계 - 자릿수가 길어질수록(>10) 정확도 하락 - 학습 시 보지 못한 패턴(반복, 매우 긴 시퀀스)에 취약 - 진짜 NMT가 아니므로 일반 자연어는 처리 불가 ### 참고 - 논문: [Attention Is All You Need](https://arxiv.org/abs/1706.03762) - The Annotated Transformer: [http://nlp.seas.harvard.edu/annotated-transformer/](http://nlp.seas.harvard.edu/annotated-transformer/) """ ) if __name__ == "__main__": demo.launch()