# Transformer / transformer.py
# JangTaeng's picture
# Upload 4 files
# 0465ac4 verified
"""
Transformer model implementation — Vaswani et al. (2017) "Attention Is All You Need"
Following the same philosophy as the ResNet project, the paper is reproduced from start to finish.
For visualization, each attention module keeps its most recent attention weights.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# ─────────────────────────────────────────────────────────────
# 1) Scaled Dot-Product Attention (paper §3.2.1, Eq. 1)
# ─────────────────────────────────────────────────────────────
def scaled_dot_product_attention(Q, K, V, mask=None, return_attn=False):
    """Compute scaled dot-product attention (paper §3.2.1, Eq. 1).

        Attention(Q, K, V) = softmax(Q Kᵀ / √d_k) V

    Args:
        Q: Query tensor whose last two dimensions are (seq_q, d_k).
        K: Key tensor whose last two dimensions are (seq_k, d_k).
        V: Value tensor whose last two dimensions are (seq_k, d_v).
        mask: Optional tensor broadcastable to the score shape
            (..., seq_q, seq_k). Positions where ``mask == 0`` receive
            (numerically) zero attention weight.
        return_attn: When true, also return the attention weights.

    Returns:
        The attended values with last two dimensions (seq_q, d_v), or the
        tuple ``(out, attn)`` when ``return_attn`` is true.
    """
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        # Fill with the most negative finite value of the score dtype rather
        # than a hard-coded -1e9: -1e9 is not representable in float16 and
        # makes masked_fill raise when running in half precision.
        fill_value = torch.finfo(scores.dtype).min
        scores = scores.masked_fill(mask == 0, fill_value)
    # A row that is masked everywhere has equal scores, so softmax yields a
    # uniform distribution there instead of NaN.
    attn = F.softmax(scores, dim=-1)
    out = attn @ V
    if return_attn:
        return out, attn
    return out
# ─────────────────────────────────────────────────────────────
# 2) Multi-Head Attention (paper §3.2.2, Figure 2 right)
# ─────────────────────────────────────────────────────────────
class MultiHeadAttention(nn.Module):
    """Multi-head attention (paper §3.2.2, Figure 2 right).

    Queries, keys and values are each passed through a learned linear
    projection, split into ``h`` heads of width ``d_model // h``, attended
    per head, merged back together and sent through an output projection.

    Args:
        d_model: Model width; must be divisible by ``h``.
        h: Number of attention heads.
    """

    def __init__(self, d_model=512, h=8):
        super().__init__()
        head_dim, remainder = divmod(d_model, h)
        assert remainder == 0
        self.h = h
        self.d_k = head_dim
        # Projection layers are created in a fixed order (q, k, v, o) so that
        # parameter registration and random initialization stay reproducible.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        # For visualization: attention weights of the most recent forward
        # pass, shaped (B, h, seq_q, seq_k). None until forward is called.
        self.last_attn = None

    def _split_heads(self, projected, batch):
        """Reshape a projected tensor to the per-head layout (B, h, seq, d_k)."""
        per_head = projected.view(batch, -1, self.h, self.d_k)
        return per_head.transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        """Attend from ``Q`` over ``K``/``V`` and return the merged result.

        ``mask`` is handed unchanged to the scaled dot-product attention, so
        it must broadcast against scores of shape (B, h, seq_q, seq_k).
        """
        batch = Q.size(0)
        q_heads = self._split_heads(self.W_q(Q), batch)
        k_heads = self._split_heads(self.W_k(K), batch)
        v_heads = self._split_heads(self.W_v(V), batch)

        context, weights = scaled_dot_product_attention(
            q_heads, k_heads, v_heads, mask, return_attn=True
        )
        # Detached copy: kept for inspection only, never for backprop.
        self.last_attn = weights.detach()

        # (B, h, seq_q, d_k) -> (B, seq_q, h * d_k)
        merged = context.transpose(1, 2).contiguous()
        merged = merged.view(batch, -1, self.h * self.d_k)
        return self.W_o(merged)
# ─────────────────────────────────────────────────────────────
# 3) Positional Encoding (paper §3.5)
# ─────────────────────────────────────────────────────────────
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (paper §3.5).

    Even channels hold ``sin(pos / 10000^(2i / d_model))`` and odd channels
    hold the matching cosine. The table is computed once for ``max_len``
    positions and registered as the buffer ``pe`` with shape
    (1, max_len, d_model).

    Args:
        d_model: Width of the vectors being encoded (even or odd).
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model=512, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() *
                        -(math.log(10000.0) / d_model))
        angles = pos * div  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd channel than even
        # channels, so drop the surplus cosine column to avoid a shape
        # mismatch. For even d_model the slice keeps every column.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Add the encoding of the first ``x.size(1)`` positions to ``x``.

        ``x`` is indexed as (batch, seq, d_model); ``seq`` must not exceed
        the ``max_len`` the table was built with.
        """
        return x + self.pe[:, :x.size(1)]
# ─────────────────────────────────────────────────────────────
# 4) Position-wise Feed-Forward (paper §3.3, Eq. 2)
# ─────────────────────────────────────────────────────────────
class FeedForward(nn.Module):
    """Position-wise feed-forward network (paper §3.3, Eq. 2).

    Applies ``Linear(d_model -> d_ff)``, ReLU, then
    ``Linear(d_ff -> d_model)`` to the last dimension of the input.

    Args:
        d_model: Input and output width.
        d_ff: Width of the hidden layer.
    """

    def __init__(self, d_model=512, d_ff=2048):
        super().__init__()
        expand = nn.Linear(d_model, d_ff)
        activation = nn.ReLU()
        contract = nn.Linear(d_ff, d_model)
        # Kept as a Sequential named ``net`` so parameter names remain
        # net.0.* and net.2.*.
        self.net = nn.Sequential(expand, activation, contract)

    def forward(self, x):
        """Return the feed-forward transform of ``x`` (same shape as ``x``)."""
        transformed = self.net(x)
        return transformed
# ─────────────────────────────────────────────────────────────
# 5) Encoder Layer (paper §3.1)
# ─────────────────────────────────────────────────────────────
class EncoderLayer(nn.Module):
    """One encoder layer (paper §3.1): self-attention, then feed-forward.

    Each sub-layer's output goes through dropout, is added to the sub-layer
    input (residual connection) and the sum is layer-normalized.

    Args:
        d_model: Model width.
        h: Number of attention heads.
        d_ff: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model=512, h=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model=d_model, h=h)
        self.ffn = FeedForward(d_model=d_model, d_ff=d_ff)
        self.norm1 = nn.LayerNorm(normalized_shape=d_model)
        self.norm2 = nn.LayerNorm(normalized_shape=d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask=None):
        """Run the layer on ``x``; ``mask`` is forwarded to self-attention."""
        # Sub-layer 1: self-attention (queries, keys and values are all x).
        attended = self.attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attended))

        # Sub-layer 2: position-wise feed-forward.
        transformed = self.ffn(x)
        return self.norm2(x + self.dropout(transformed))
# ─────────────────────────────────────────────────────────────
# 6) Decoder Layer (paper §3.1)
# ─────────────────────────────────────────────────────────────
class DecoderLayer(nn.Module):
    """One decoder layer (paper §3.1).

    Three sub-layers run in order: self-attention over the decoder input,
    cross-attention over the encoder output, and a feed-forward network.
    Each sub-layer's output goes through dropout, is added to the sub-layer
    input and the sum is layer-normalized.

    Args:
        d_model: Model width.
        h: Number of attention heads.
        d_ff: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model=512, h=8, d_ff=2048, dropout=0.1):
        super().__init__()
        # Self-attention; receives tgt_mask in forward.
        self.self_attn = MultiHeadAttention(d_model=d_model, h=h)
        # Encoder-decoder attention; receives src_mask in forward.
        self.cross_attn = MultiHeadAttention(d_model=d_model, h=h)
        self.ffn = FeedForward(d_model=d_model, d_ff=d_ff)
        self.norm1 = nn.LayerNorm(normalized_shape=d_model)
        self.norm2 = nn.LayerNorm(normalized_shape=d_model)
        self.norm3 = nn.LayerNorm(normalized_shape=d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, enc_out, src_mask=None, tgt_mask=None):
        """Run the layer on decoder input ``x`` given encoder output ``enc_out``.

        ``tgt_mask`` is forwarded to self-attention and ``src_mask`` to
        cross-attention.
        """
        # Sub-layer 1: self-attention over the decoder sequence.
        self_ctx = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_ctx))

        # Sub-layer 2: queries from the decoder, keys/values from the encoder.
        cross_ctx = self.cross_attn(x, enc_out, enc_out, src_mask)
        x = self.norm2(x + self.dropout(cross_ctx))

        # Sub-layer 3: position-wise feed-forward.
        transformed = self.ffn(x)
        return self.norm3(x + self.dropout(transformed))
# ─────────────────────────────────────────────────────────────
# 7) Full Transformer
# ─────────────────────────────────────────────────────────────
class Transformer(nn.Module):
    """Encoder-decoder Transformer built from N encoder and N decoder layers.

    Token ids are embedded, scaled by ``sqrt(d_model)``, combined with the
    positional encoding, run through the layer stacks and finally projected
    to target-vocabulary logits (no softmax is applied).

    Args:
        src_vocab: Source vocabulary size.
        tgt_vocab: Target vocabulary size.
        d_model: Model width.
        N: Number of layers in each of the encoder and decoder stacks.
        h: Number of attention heads.
        d_ff: Hidden width of the feed-forward sub-layers.
        dropout: Dropout probability used inside every layer.
        max_len: Number of positions precomputed by the positional encoding.
    """

    def __init__(self, src_vocab, tgt_vocab,
                 d_model=512, N=6, h=8, d_ff=2048, dropout=0.1, max_len=5000):
        super().__init__()
        self.d_model = d_model
        self.src_embed = nn.Embedding(src_vocab, d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab, d_model)
        # One positional-encoding module is shared by encoder and decoder.
        self.pe = PositionalEncoding(d_model, max_len)

        layer_args = (d_model, h, d_ff, dropout)
        encoder_layers = [EncoderLayer(*layer_args) for _ in range(N)]
        decoder_layers = [DecoderLayer(*layer_args) for _ in range(N)]
        self.encoder = nn.ModuleList(encoder_layers)
        self.decoder = nn.ModuleList(decoder_layers)

        self.out = nn.Linear(d_model, tgt_vocab)

    def _embed(self, table, tokens):
        """Embed ``tokens`` with ``table``, scale by sqrt(d_model), add positions."""
        scaled = table(tokens) * math.sqrt(self.d_model)
        return self.pe(scaled)

    def encode(self, src, src_mask=None):
        """Return the encoder stack's output for source token ids ``src``."""
        hidden = self._embed(self.src_embed, src)
        for enc_layer in self.encoder:
            hidden = enc_layer(hidden, src_mask)
        return hidden

    def decode(self, tgt, enc_out, src_mask=None, tgt_mask=None):
        """Return the decoder stack's output for target token ids ``tgt``.

        ``enc_out`` is the encoder output attended to by every decoder layer.
        """
        hidden = self._embed(self.tgt_embed, tgt)
        for dec_layer in self.decoder:
            hidden = dec_layer(hidden, enc_out, src_mask, tgt_mask)
        return hidden

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Encode ``src``, decode ``tgt`` against it, and return vocabulary logits."""
        memory = self.encode(src, src_mask)
        decoded = self.decode(tgt, memory, src_mask, tgt_mask)
        return self.out(decoded)

    # ── Visualization helpers ─────────────────────────────────
    def get_decoder_cross_attn(self, layer_idx=-1):
        """Return the decoder cross-attention weights of the last forward pass.

        Args:
            layer_idx: Index of the decoder layer to read (default: last).

        Returns: (B, h, tgt_len, src_len)
        """
        chosen_layer = self.decoder[layer_idx]
        return chosen_layer.cross_attn.last_attn

    def get_encoder_self_attn(self, layer_idx=-1):
        """Return the encoder self-attention weights of the last forward pass.

        Args:
            layer_idx: Index of the encoder layer to read (default: last).

        Returns: (B, h, src_len, src_len)
        """
        chosen_layer = self.encoder[layer_idx]
        return chosen_layer.attn.last_attn