import torch
import torch.nn as nn
import math


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to token embeddings."""

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # Precompute the (max_len, d_model) table of sin/cos positional encodings.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model) so it broadcasts over the batch
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (batch_size, seq_len, d_model)
        """
        return x + self.pe[:, :x.size(1), :]
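
# Shape note (illustrative, not part of the original file): with d_model=512 the buffer
# `pe` has shape (1, 5000, 512), so PositionalEncoding(512)(torch.zeros(2, 10, 512))
# returns a (2, 10, 512) tensor with the first 10 rows of the encoding table added to
# every sequence in the batch.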


class DecoderBlock(nn.Module):
    """Single Transformer decoder layer: masked self-attention, cross-attention, FFN."""

    def __init__(self, d_model, num_heads, dim_ff, dropout=0.2):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        self.cross_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, memory, tgt_mask, tgt_key_padding_mask):
        # x:      (B, L, D) target-side hidden states
        # memory: (B, N, D) encoder features (e.g. image regions)
        # 1) Masked self-attention over the target sequence
        attn_out, _ = self.self_attn(
            x, x, x, attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask
        )
        x = self.norm1(x + self.dropout(attn_out))
        # 2) Cross-attention over the memory (no padding mask is applied to the memory here)
        attn_out, _ = self.cross_attn(
            x, memory, memory
        )
        x = self.norm2(x + self.dropout(attn_out))
        # 3) Position-wise feed-forward network
        ffn_out = self.ffn(x)
        x = self.norm3(x + self.dropout(ffn_out))
        return x
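
# Shape note (illustrative): one DecoderBlock maps x: (B, L, d_model) and
# memory: (B, N, d_model) to an output of shape (B, L, d_model); stacking blocks
# repeats this without changing the sequence length or model width.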


class TransformerDecoder(nn.Module):
    """Transformer decoder that generates caption logits conditioned on image features."""

    def __init__(
        self,
        vocab_size,
        pad_id,
        d_model=512,
        num_layers=6,
        num_heads=8,
        dim_ff=2048,
        max_len=25,
        dropout=0.1
    ):
        super().__init__()
        self.pad_id = pad_id
        self.d_model = d_model
        self.max_len = max_len
        # Text embedding & positional encoding
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len=self.max_len)
        self.layers = nn.ModuleList([
            DecoderBlock(d_model, num_heads, dim_ff, dropout)
            for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)
        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize the embedding and output projection weights."""
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc_out.bias.data.zero_()
        self.fc_out.weight.data.uniform_(-initrange, initrange)

    def generate_square_subsequent_mask(self, sz):
        """Generate a boolean causal mask (True = masked) for the decoder."""
        return torch.triu(torch.ones(sz, sz), diagonal=1).bool()
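
    # Illustration (not in the original file): generate_square_subsequent_mask(3) returns
    #   [[False,  True,  True],
    #    [False, False,  True],
    #    [False, False, False]]
    # where True marks a masked position, so each query position can attend only to
    # itself and earlier positions.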

    def forward(self, captions, img_features, tgt_mask=None, tgt_padding_mask=None):
        """
        captions:     (B, L) token ids
        img_features: (B, N, D) encoder memory
        """
        memory = img_features
        # Prepare caption embeddings (target): scale, add positions, apply dropout
        tgt = self.dropout(self.pos_encoder(self.embedding(captions) * math.sqrt(self.d_model)))
        # Generate the causal mask if not provided (masks future tokens)
        if tgt_mask is None:
            tgt_mask = self.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
        # Derive the padding mask from the pad token id if the caller did not pass one
        if tgt_padding_mask is None:
            tgt_padding_mask = (captions == self.pad_id)
        for layer in self.layers:
            tgt = layer(tgt, memory, tgt_mask, tgt_padding_mask)
        logits = self.fc_out(tgt)
        return logits
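

# Minimal smoke test (an illustrative sketch, not part of the original training code).
# The vocabulary size, pad id, and the (B, N, D) image-feature shape below are made-up
# values; in practice img_features would come from an image encoder (e.g. a grid of N
# regions projected to d_model).
if __name__ == "__main__":
    vocab_size, pad_id = 1000, 0
    decoder = TransformerDecoder(vocab_size=vocab_size, pad_id=pad_id, d_model=512,
                                 num_layers=2, num_heads=8, dim_ff=2048, max_len=25)
    captions = torch.randint(1, vocab_size, (4, 20))   # (B, L) token ids
    captions[:, -3:] = pad_id                          # pretend the last tokens are padding
    img_features = torch.randn(4, 49, 512)             # (B, N, D) dummy image features
    logits = decoder(captions, img_features)
    print(logits.shape)  # expected: torch.Size([4, 20, 1000])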