"""
Transformer model implementation — Vaswani et al. (2017), "Attention Is All You Need".

Following the same philosophy as the ResNet project, this reproduces the paper
from start to finish. For visualization, every attention module keeps the
attention weights from its most recent forward pass.
"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


# ─────────────────────────────────────────────────────────────
# 1) Scaled Dot-Product Attention (paper §3.2.1, Eq. 1)
# ─────────────────────────────────────────────────────────────
def scaled_dot_product_attention(Q, K, V, mask=None, return_attn=False):
    """Compute Attention(Q, K, V) = softmax(QKᵀ / √d_k) V.

    Args:
        Q: query tensor (..., seq_q, d_k).
        K: key tensor (..., seq_k, d_k).
        V: value tensor (..., seq_k, d_v).
        mask: optional tensor broadcastable to the (..., seq_q, seq_k) score
            tensor; positions where ``mask == 0`` receive zero attention.
        return_attn: if True, also return the attention weights.

    Returns:
        ``out`` of shape (..., seq_q, d_v), or ``(out, attn)`` when
        ``return_attn`` is True, where ``attn`` is (..., seq_q, seq_k).
    """
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        # Fill with the most negative finite value of the score dtype rather
        # than a hard-coded -1e9: -1e9 does not fit in float16, so masked_fill
        # raises an overflow error under half precision / autocast.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    attn = F.softmax(scores, dim=-1)
    out = attn @ V
    if return_attn:
        return out, attn
    return out


# ─────────────────────────────────────────────────────────────
# 2) Multi-Head Attention (paper §3.2.2, Figure 2 right)
# ─────────────────────────────────────────────────────────────
class MultiHeadAttention(nn.Module):
    """Multi-head attention.

    Projects queries/keys/values into ``h`` heads of width ``d_model // h``,
    runs scaled dot-product attention per head, concatenates the heads and
    applies the output projection ``W_o``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``h``.
    """

    def __init__(self, d_model=512, h=8):
        super().__init__()
        if d_model % h != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by the number of heads h ({h})"
            )
        self.h = h
        self.d_k = d_model // h
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        # For visualization: attention weights of the last forward pass,
        # shape (B, h, seq_q, seq_k), detached from the autograd graph.
        self.last_attn = None

    def forward(self, Q, K, V, mask=None):
        """Attend from ``Q`` (B, seq_q, d_model) over ``K``/``V`` (B, seq_k, d_model).

        ``mask`` must broadcast against the (B, h, seq_q, seq_k) score tensor;
        zeros mark blocked positions. Returns a (B, seq_q, d_model) tensor.
        """
        B = Q.size(0)
        # (B, seq, d_model) -> (B, h, seq, d_k)
        Q = self.W_q(Q).view(B, -1, self.h, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(B, -1, self.h, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(B, -1, self.h, self.d_k).transpose(1, 2)
        out, attn = scaled_dot_product_attention(Q, K, V, mask, return_attn=True)
        self.last_attn = attn.detach()
        # (B, h, seq_q, d_k) -> (B, seq_q, h * d_k): concatenate the heads.
        out = out.transpose(1, 2).contiguous().view(B, -1, self.h * self.d_k)
        return self.W_o(out)


# ─────────────────────────────────────────────────────────────
# 3) Positional Encoding (paper §3.5)
# ─────────────────────────────────────────────────────────────
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input.

    Even columns hold sin(pos / 10000^(2i/d_model)) and odd columns the
    matching cos. The table is precomputed for ``max_len`` positions and
    stored as a buffer of shape (1, max_len, d_model).
    """

    def __init__(self, d_model=512, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(pos * div)
        # When d_model is odd there is one fewer odd column than frequencies;
        # slice so the shapes match (a no-op for even d_model).
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Add the encoding to ``x`` of shape (B, seq, d_model); seq must be <= max_len."""
        return x + self.pe[:, :x.size(1)]


# ─────────────────────────────────────────────────────────────
# 4) Position-wise Feed-Forward (paper §3.3, Eq. 2)
# ─────────────────────────────────────────────────────────────
class FeedForward(nn.Module):
    """FFN(x) = max(0, x W1 + b1) W2 + b2, applied to each position independently."""

    def __init__(self, d_model=512, d_ff=2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x):
        return self.net(x)


# ─────────────────────────────────────────────────────────────
# 5) Encoder Layer (paper §3.1)
# ─────────────────────────────────────────────────────────────
class EncoderLayer(nn.Module):
    """Self-attention + FFN sub-layers, each wrapped as LayerNorm(x + Dropout(Sublayer(x)))."""

    def __init__(self, d_model=512, h=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, h)
        self.ffn = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Post-LN residual blocks, as in the original paper.
        x = self.norm1(x + self.dropout(self.attn(x, x, x, mask)))
        x = self.norm2(x + self.dropout(self.ffn(x)))
        return x


# ─────────────────────────────────────────────────────────────
# 6) Decoder Layer (paper §3.1)
# ─────────────────────────────────────────────────────────────
class DecoderLayer(nn.Module):
    """Masked self-attention, encoder-decoder attention and FFN, each post-LN residual."""

    def __init__(self, d_model=512, h=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, h)   # masked self-attention
        self.cross_attn = MultiHeadAttention(d_model, h)  # encoder-decoder attention
        self.ffn = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask=None, tgt_mask=None):
        # tgt_mask restricts self-attention (e.g. causal); src_mask restricts
        # which encoder positions the cross-attention may look at.
        x = self.norm1(x + self.dropout(self.self_attn(x, x, x, tgt_mask)))
        x = self.norm2(x + self.dropout(self.cross_attn(x, enc_out, enc_out, src_mask)))
        x = self.norm3(x + self.dropout(self.ffn(x)))
        return x


# ─────────────────────────────────────────────────────────────
# 7) Full Transformer
# ─────────────────────────────────────────────────────────────
class Transformer(nn.Module):
    """Encoder-decoder Transformer producing unnormalized target-vocab logits.

    NOTE: the paper (§3.4) shares one weight matrix between the two embeddings
    and the pre-softmax projection; this implementation keeps them separate.
    """

    def __init__(self, src_vocab, tgt_vocab, d_model=512, N=6, h=8,
                 d_ff=2048, dropout=0.1, max_len=5000):
        super().__init__()
        self.d_model = d_model
        self.src_embed = nn.Embedding(src_vocab, d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab, d_model)
        self.pe = PositionalEncoding(d_model, max_len)
        # §5.4: dropout on (embedding + positional encoding) in both stacks.
        # nn.Dropout has no parameters, so the state_dict is unchanged.
        self.emb_dropout = nn.Dropout(dropout)
        self.encoder = nn.ModuleList([
            EncoderLayer(d_model, h, d_ff, dropout) for _ in range(N)
        ])
        self.decoder = nn.ModuleList([
            DecoderLayer(d_model, h, d_ff, dropout) for _ in range(N)
        ])
        self.out = nn.Linear(d_model, tgt_vocab)

    def encode(self, src, src_mask=None):
        """Encode token ids ``src`` (B, src_len) into (B, src_len, d_model)."""
        # Embeddings are scaled by √d_model (§3.4) before adding positions.
        e = self.emb_dropout(self.pe(self.src_embed(src) * math.sqrt(self.d_model)))
        for layer in self.encoder:
            e = layer(e, src_mask)
        return e

    def decode(self, tgt, enc_out, src_mask=None, tgt_mask=None):
        """Decode token ids ``tgt`` (B, tgt_len) against ``enc_out`` into (B, tgt_len, d_model)."""
        d = self.emb_dropout(self.pe(self.tgt_embed(tgt) * math.sqrt(self.d_model)))
        for layer in self.decoder:
            d = layer(d, enc_out, src_mask, tgt_mask)
        return d

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Return logits of shape (B, tgt_len, tgt_vocab)."""
        enc_out = self.encode(src, src_mask)
        dec_out = self.decode(tgt, enc_out, src_mask, tgt_mask)
        return self.out(dec_out)

    # ── Visualization helpers ─────────────────────────────────
    def get_decoder_cross_attn(self, layer_idx=-1):
        """Return the decoder cross-attention weights from the last forward.

        Returns:
            (B, h, tgt_len, src_len), or None before any forward pass.
        """
        return self.decoder[layer_idx].cross_attn.last_attn

    def get_encoder_self_attn(self, layer_idx=-1):
        """Return the encoder self-attention weights from the last forward.

        Returns:
            (B, h, src_len, src_len), or None before any forward pass.
        """
        return self.encoder[layer_idx].attn.last_attn