# vit-image-captioning / transformer_decoder.py
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to token embeddings."""

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # Precompute the (1, max_len, d_model) sinusoidal table once
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (batch_size, seq_len, d_model)
        """
        return x + self.pe[:, :x.size(1), :]
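
# Illustrative usage (not part of the original training pipeline): the module
# only adds a fixed sinusoidal table to its input, so shapes are preserved.
#   pos = PositionalEncoding(d_model=512, max_len=25)
#   out = pos(torch.zeros(2, 10, 512))   # -> (2, 10, 512)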

class DecoderBlock(nn.Module):
    """Single decoder layer: masked self-attention, cross-attention over image features, and a feed-forward network."""

    def __init__(self, d_model, num_heads, dim_ff, dropout=0.2):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        self.cross_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, memory, tgt_mask, tgt_key_padding_mask):
        # x: (B, L, D) caption embeddings
        # memory: (B, N, D) image features from the encoder
        # 1) Masked self-attention over the caption tokens
        attn_out, _ = self.self_attn(
            x, x, x, attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask
        )
        x = self.norm1(x + self.dropout(attn_out))
        # 2) Cross-attention: captions attend to the image features
        attn_out, _ = self.cross_attn(
            x, memory, memory
        )
        x = self.norm2(x + self.dropout(attn_out))
        # 3) Position-wise feed-forward network
        ffn_out = self.ffn(x)
        x = self.norm3(x + self.dropout(ffn_out))
        return x
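
# Note: each DecoderBlock uses a post-norm residual layout (sublayer -> dropout ->
# add -> LayerNorm), and cross-attention applies no key-padding mask, i.e. every
# position in `memory` is treated as a valid image feature.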

class TransformerDecoder(nn.Module):
    """Transformer decoder that produces caption logits conditioned on image features."""

    def __init__(
        self,
        vocab_size,
        pad_id,
        d_model=512,
        num_layers=6,
        num_heads=8,
        dim_ff=2048,
        max_len=25,
        dropout=0.1
    ):
        super().__init__()
        self.pad_id = pad_id
        self.d_model = d_model
        self.max_len = max_len
        # Text embedding & positional encoding for the caption tokens
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len=self.max_len)
        # Stack of decoder blocks
        self.layers = nn.ModuleList([
            DecoderBlock(d_model, num_heads, dim_ff, dropout)
            for _ in range(num_layers)
        ])
        # Project decoder states to vocabulary logits
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)
        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize embedding and output projection weights."""
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc_out.bias.data.zero_()
        self.fc_out.weight.data.uniform_(-initrange, initrange)

    def generate_square_subsequent_mask(self, sz):
        """Generate a boolean causal mask for the decoder (True = blocked)."""
        return torch.triu(torch.ones(sz, sz), diagonal=1).bool()
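
    # For example, generate_square_subsequent_mask(3) returns (True = blocked):
    #   [[False,  True,  True],
    #    [False, False,  True],
    #    [False, False, False]]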

    def forward(self, captions, img_features, tgt_mask=None, tgt_padding_mask=None):
        """
        Args:
            captions: (B, L) token ids
            img_features: (B, N, D) image features used as cross-attention memory
        """
        B, L = captions.shape
        device = captions.device
        src = img_features
        # Embed the captions, scale, and add positional encodings
        tgt = self.dropout(self.pos_encoder(self.embedding(captions) * math.sqrt(self.d_model)))
        # Generate a causal mask if none is provided (blocks attention to future tokens)
        if tgt_mask is None:
            tgt_mask = self.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
        # Use the provided padding mask, or derive one from the pad token id
        if tgt_padding_mask is None:
            tgt_padding_mask = (captions == self.pad_id)
        for layer in self.layers:
            tgt = layer(tgt, src, tgt_mask, tgt_padding_mask)
        logits = self.fc_out(tgt)
        return logits
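

if __name__ == "__main__":
    # Minimal smoke test, not part of the original file. The sizes below are
    # illustrative assumptions: a 100-token vocabulary, pad_id = 0, and 196
    # image patch features already projected to d_model = 512.
    decoder = TransformerDecoder(vocab_size=100, pad_id=0)
    captions = torch.randint(1, 100, (2, 12))      # (B, L) token ids
    img_features = torch.randn(2, 196, 512)        # (B, N, D) encoder memory
    logits = decoder(captions, img_features)
    print(logits.shape)                            # expected: torch.Size([2, 12, 100])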