Spaces:
Sleeping
Sleeping
| """ | |
| Transformer model implementation — Vaswani et al. (2017) "Attention Is All You Need" | |
| In the same spirit as the ResNet project, this reproduces the paper from start to finish. | |
| For visualization, each attention module keeps the attention weights from its most recent forward pass. | |
| """ | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| # ───────────────────────────────────────────────────────────── | |
| # 1) Scaled Dot-Product Attention (paper §3.2.1, Eq. 1) | |
| # ───────────────────────────────────────────────────────────── | |
def scaled_dot_product_attention(Q, K, V, mask=None, return_attn=False):
    """Compute scaled dot-product attention.

    Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V, where d_k is the size
    of the last dimension of ``Q``.

    Args:
        Q: Query tensor; its last dimension is d_k.
        K: Key tensor; its last two dimensions are transposed before being
            multiplied with ``Q``.
        V: Value tensor, multiplied by the attention weights.
        mask: Optional tensor that must broadcast against the score matrix.
            Positions where ``mask == 0`` are excluded from attention.
        return_attn: When true, the attention weights are returned as well.

    Returns:
        ``out`` alone, or the tuple ``(out, attn)`` when ``return_attn`` is
        true.
    """
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        # Fill with the most negative finite value of the score dtype instead
        # of a hard-coded -1e9, which cannot be represented in float16.
        # A fully masked row still softmaxes to a uniform distribution.
        fill_value = torch.finfo(scores.dtype).min
        scores = scores.masked_fill(mask == 0, fill_value)
    attn = F.softmax(scores, dim=-1)
    out = attn @ V
    if return_attn:
        return out, attn
    return out
| # ───────────────────────────────────────────────────────────── | |
| # 2) Multi-Head Attention (paper §3.2.2, Figure 2 right) | |
| # ───────────────────────────────────────────────────────────── | |
class MultiHeadAttention(nn.Module):
    """Multi-head attention with learned query/key/value/output projections.

    The model dimension is split evenly across ``h`` heads, each of size
    ``d_k = d_model // h``. The attention weights of the most recent forward
    pass are kept (detached) in ``last_attn`` for visualization.
    """

    def __init__(self, d_model=512, h=8):
        """Create the projection layers.

        Args:
            d_model: Feature size of the inputs and outputs.
            h: Number of attention heads; must be positive and divide
                ``d_model``.

        Raises:
            ValueError: If ``h`` is not positive or does not divide
                ``d_model``.
        """
        super().__init__()
        # Raise explicitly instead of asserting: asserts vanish under
        # ``python -O`` and h == 0 would otherwise be a ZeroDivisionError.
        if h <= 0 or d_model % h != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive h ({h})"
            )
        self.h = h
        self.d_k = d_model // h
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        # For visualization: attention weights of the last forward pass,
        # shaped (B, h, seq_q, seq_k). None until forward() has run.
        self.last_attn = None

    def forward(self, Q, K, V, mask=None):
        """Project the inputs, attend per head, and merge the heads.

        Args:
            Q: Query input whose first dimension is the batch size and whose
                last dimension is ``d_model``.
            K: Key input, laid out like ``Q``.
            V: Value input, laid out like ``Q``.
            mask: Optional mask forwarded unchanged to
                ``scaled_dot_product_attention``.

        Returns:
            Tensor of shape (B, seq_q, d_model).
        """
        B = Q.size(0)
        # (B, seq, d_model) -> (B, h, seq, d_k) so every head attends
        # independently.
        Q = self.W_q(Q).view(B, -1, self.h, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(B, -1, self.h, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(B, -1, self.h, self.d_k).transpose(1, 2)
        out, attn = scaled_dot_product_attention(Q, K, V, mask, return_attn=True)
        # Detach so the stored copy does not keep the autograd graph alive.
        self.last_attn = attn.detach()
        # (B, h, seq, d_k) -> (B, seq, h * d_k); contiguous() is required
        # before view() after the transpose.
        out = out.transpose(1, 2).contiguous().view(B, -1, self.h * self.d_k)
        return self.W_o(out)
| # ───────────────────────────────────────────────────────────── | |
| # 3) Positional Encoding (paper §3.5) | |
| # ───────────────────────────────────────────────────────────── | |
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence of embeddings.

    Even feature indices hold sines and odd indices hold cosines, at
    frequencies ``10000 ** (-2i / d_model)``. The table is precomputed for
    ``max_len`` positions and stored as the buffer ``pe`` with shape
    (1, max_len, d_model).
    """

    def __init__(self, d_model=512, max_len=5000):
        """Precompute the encoding table.

        Args:
            d_model: Feature size of the embeddings (even or odd).
            max_len: Number of positions to precompute.
        """
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() *
                        -(math.log(10000.0) / d_model))
        angles = pos * div
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so trim the angles to the number of odd columns before assigning.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dimension 1 indexes positions and whose last
                dimension is ``d_model``.
        """
        return x + self.pe[:, : x.size(1)]
| # ───────────────────────────────────────────────────────────── | |
| # 4) Position-wise Feed-Forward (paper §3.3, Eq. 2) | |
| # ───────────────────────────────────────────────────────────── | |
class FeedForward(nn.Module):
    """Two-layer feed-forward network applied along the last dimension.

    Expands ``d_model`` features to ``d_ff``, applies ReLU, then projects
    back to ``d_model``.
    """

    def __init__(self, d_model=512, d_ff=2048):
        """Build the Linear -> ReLU -> Linear stack.

        Args:
            d_model: Input and output feature size.
            d_ff: Hidden feature size between the two linear layers.
        """
        super().__init__()
        expand = nn.Linear(d_model, d_ff)
        activation = nn.ReLU()
        project = nn.Linear(d_ff, d_model)
        # Positional registration keeps the state_dict keys net.0.* / net.2.*.
        self.net = nn.Sequential(expand, activation, project)

    def forward(self, x):
        """Apply the network to ``x`` and return the result."""
        transformed = self.net(x)
        return transformed
| # ───────────────────────────────────────────────────────────── | |
| # 5) Encoder Layer (paper §3.1) | |
| # ───────────────────────────────────────────────────────────── | |
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward.

    Each sub-layer output goes through dropout, is added to its input, and
    the sum is layer-normalized.
    """

    def __init__(self, d_model=512, h=8, d_ff=2048, dropout=0.1):
        """Create the sub-layers.

        Args:
            d_model: Feature size of the layer input and output.
            h: Number of attention heads.
            d_ff: Hidden size of the feed-forward sub-layer.
            dropout: Dropout probability applied to each sub-layer output.
        """
        super().__init__()
        self.attn = MultiHeadAttention(d_model=d_model, h=h)
        self.ffn = FeedForward(d_model=d_model, d_ff=d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        # One dropout module is shared by both sub-layers.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run the block on ``x``; ``mask`` is passed to the self-attention."""
        # Sub-layer 1: self-attention with residual connection and norm.
        attended = self.attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attended))
        # Sub-layer 2: feed-forward with residual connection and norm.
        transformed = self.ffn(x)
        return self.norm2(x + self.dropout(transformed))
| # ───────────────────────────────────────────────────────────── | |
| # 6) Decoder Layer (paper §3.1) | |
| # ───────────────────────────────────────────────────────────── | |
class DecoderLayer(nn.Module):
    """One decoder block: self-attention, cross-attention, feed-forward.

    Each sub-layer output goes through dropout, is added to its input, and
    the sum is layer-normalized.
    """

    def __init__(self, d_model=512, h=8, d_ff=2048, dropout=0.1):
        """Create the sub-layers.

        Args:
            d_model: Feature size of the layer input and output.
            h: Number of attention heads.
            d_ff: Hidden size of the feed-forward sub-layer.
            dropout: Dropout probability applied to each sub-layer output.
        """
        super().__init__()
        # Self-attention over the decoder input (receives tgt_mask).
        self.self_attn = MultiHeadAttention(d_model=d_model, h=h)
        # Encoder-decoder attention over enc_out (receives src_mask).
        self.cross_attn = MultiHeadAttention(d_model=d_model, h=h)
        self.ffn = FeedForward(d_model=d_model, d_ff=d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        # One dropout module is shared by all three sub-layers.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask=None, tgt_mask=None):
        """Run the block.

        Args:
            x: Decoder-side input.
            enc_out: Encoder output, used as keys and values of the
                cross-attention.
            src_mask: Optional mask passed to the cross-attention.
            tgt_mask: Optional mask passed to the self-attention.

        Returns:
            The output of the final layer normalization.
        """
        self_attended = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_attended))
        cross_attended = self.cross_attn(x, enc_out, enc_out, src_mask)
        x = self.norm2(x + self.dropout(cross_attended))
        transformed = self.ffn(x)
        return self.norm3(x + self.dropout(transformed))
| # ───────────────────────────────────────────────────────────── | |
| # 7) Full Transformer | |
| # ───────────────────────────────────────────────────────────── | |
class Transformer(nn.Module):
    """Encoder-decoder Transformer producing logits over the target vocabulary.

    Source and target tokens have separate embedding tables and share one
    positional-encoding module. ``N`` encoder layers and ``N`` decoder layers
    are stacked, followed by a linear projection to ``tgt_vocab``.
    """

    def __init__(self, src_vocab, tgt_vocab,
                 d_model=512, N=6, h=8, d_ff=2048, dropout=0.1, max_len=5000):
        """Build embeddings, positional encoding, layer stacks and output head.

        Args:
            src_vocab: Size of the source vocabulary.
            tgt_vocab: Size of the target vocabulary.
            d_model: Embedding / model feature size.
            N: Number of encoder layers and of decoder layers.
            h: Number of attention heads per attention module.
            d_ff: Hidden size of each feed-forward sub-layer.
            dropout: Dropout probability passed to every layer.
            max_len: Number of positions precomputed by the positional
                encoding.
        """
        super().__init__()
        self.d_model = d_model
        self.src_embed = nn.Embedding(src_vocab, d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab, d_model)
        self.pe = PositionalEncoding(d_model, max_len)
        encoder_layers = [EncoderLayer(d_model, h, d_ff, dropout) for _ in range(N)]
        self.encoder = nn.ModuleList(encoder_layers)
        decoder_layers = [DecoderLayer(d_model, h, d_ff, dropout) for _ in range(N)]
        self.decoder = nn.ModuleList(decoder_layers)
        self.out = nn.Linear(d_model, tgt_vocab)

    def encode(self, src, src_mask=None):
        """Embed ``src`` and run it through every encoder layer."""
        # Embeddings are scaled by sqrt(d_model) before adding positions.
        scale = math.sqrt(self.d_model)
        hidden = self.pe(self.src_embed(src) * scale)
        for block in self.encoder:
            hidden = block(hidden, src_mask)
        return hidden

    def decode(self, tgt, enc_out, src_mask=None, tgt_mask=None):
        """Embed ``tgt`` and run it through every decoder layer with ``enc_out``."""
        scale = math.sqrt(self.d_model)
        hidden = self.pe(self.tgt_embed(tgt) * scale)
        for block in self.decoder:
            hidden = block(hidden, enc_out, src_mask, tgt_mask)
        return hidden

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Encode ``src``, decode ``tgt`` against it, and project to logits."""
        memory = self.encode(src, src_mask)
        decoded = self.decode(tgt, memory, src_mask, tgt_mask)
        return self.out(decoded)

    # ── Visualization helpers ─────────────────────────────────────────
    def get_decoder_cross_attn(self, layer_idx=-1):
        """Return the decoder cross-attention weights of the last forward pass.

        Args:
            layer_idx: Index into the decoder stack (default: last layer).

        Returns: (B, h, tgt_len, src_len)
        """
        layer = self.decoder[layer_idx]
        return layer.cross_attn.last_attn

    def get_encoder_self_attn(self, layer_idx=-1):
        """Return the encoder self-attention weights of the last forward pass.

        Args:
            layer_idx: Index into the encoder stack (default: last layer).

        Returns: (B, h, src_len, src_len)
        """
        layer = self.encoder[layer_idx]
        return layer.attn.last_attn