import torch
import torch.nn as nn
import math


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding module."""

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # Create positional encodings once and store them as a (1, max_len, d_model) buffer
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (batch_size, seq_len, d_model)
        """
        return x + self.pe[:, :x.size(1), :]


class DecoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, dim_ff, dropout=0.2):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        self.cross_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, memory, tgt_mask, tgt_key_padding_mask):
        # x: (B, L, D) caption embeddings
        # memory: (B, N, D) encoder (image) features
        # 1) Masked self-attention over the caption tokens
        attn_out, _ = self.self_attn(
            x, x, x,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask
        )
        x = self.norm1(x + self.dropout(attn_out))

        # 2) Cross-attention: captions attend to the image features
        #    (no key_padding_mask here, since the memory is assumed unpadded)
        attn_out, _ = self.cross_attn(x, memory, memory)
        x = self.norm2(x + self.dropout(attn_out))

        # 3) Position-wise feed-forward network
        ffn_out = self.ffn(x)
        x = self.norm3(x + self.dropout(ffn_out))
        return x


class TransformerDecoder(nn.Module):
    def __init__(
        self,
        vocab_size,
        pad_id,
        d_model=512,
        num_layers=6,
        num_heads=8,
        dim_ff=2048,
        max_len=25,
        dropout=0.1,
    ):
        super().__init__()
        self.pad_id = pad_id
        self.d_model = d_model
        self.max_len = max_len

        # Text embedding & positional encoding
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len=self.max_len)

        self.layers = nn.ModuleList([
            DecoderBlock(d_model, num_heads, dim_ff, dropout)
            for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize embedding and output projection weights."""
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc_out.bias.data.zero_()
        self.fc_out.weight.data.uniform_(-initrange, initrange)

    def generate_square_subsequent_mask(self, sz):
        """Generate a causal mask: True above the diagonal marks positions that may not be attended to."""
        return torch.triu(torch.ones(sz, sz), diagonal=1).bool()

    def forward(self, captions, img_features, tgt_mask=None, tgt_padding_mask=None):
        """
        Args:
            captions: (B, L) token ids
            img_features: (B, N, D) encoder memory
        Returns:
            logits: (B, L, vocab_size)
        """
        src = img_features

        # Embed captions, scale by sqrt(d_model), and add positional encoding
        tgt = self.dropout(self.pos_encoder(self.embedding(captions) * math.sqrt(self.d_model)))

        # Generate causal mask if not provided (masks future tokens)
        if tgt_mask is None:
            tgt_mask = self.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)

        # Derive the padding mask from pad_id if the caller did not supply one
        if tgt_padding_mask is None:
            tgt_padding_mask = (captions == self.pad_id)

        for layer in self.layers:
            tgt = layer(tgt, src, tgt_mask, tgt_padding_mask)

        logits = self.fc_out(tgt)
        return logits
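

# Minimal usage sketch (not part of the original module): it runs the decoder on
# randomly generated "image features" and caption token ids to show the expected
# shapes. The concrete values (batch of 4, 49 memory positions, captions of
# length 20, vocab of 10000) are illustrative assumptions; in practice the image
# features would come from an external encoder projected to d_model.
if __name__ == "__main__":
    vocab_size, pad_id = 10000, 0
    decoder = TransformerDecoder(vocab_size=vocab_size, pad_id=pad_id, d_model=512, max_len=25)

    # Fake encoder output: (B, N, D) = (4, 49, 512), e.g. a 7x7 feature map flattened to 49 positions.
    img_features = torch.randn(4, 49, 512)

    # Fake caption ids: (B, L) = (4, 20), with the last three positions padded.
    captions = torch.randint(1, vocab_size, (4, 20))
    captions[:, -3:] = pad_id

    logits = decoder(captions, img_features)
    print(logits.shape)  # torch.Size([4, 20, 10000])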