# vit-image-captioning / transformer_encoder.py
# Uploaded by mostafahagali ("Upload 9 files", commit 601cad6, verified).
import math
import torch
import torch.nn as nn
from transformer_decoder import PositionalEncoding
class EncoderBlock(nn.Module):
    """Post-norm Transformer encoder layer.

    Runs multi-head self-attention, then a position-wise feed-forward network
    (Linear -> GELU -> Dropout -> Linear). Each sub-layer output goes through
    dropout, is added back to that sub-layer's input (residual connection),
    and is normalised with LayerNorm, i.e. ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        d_model: Feature size of every token.
        num_heads: Number of attention heads (``nn.MultiheadAttention``
            requires it to divide ``d_model``).
        dim_ff: Hidden width of the feed-forward network.
        dropout: Drop probability used inside attention, inside the
            feed-forward network, and on both residual branches.
    """

    def __init__(self, d_model, num_heads, dim_ff, dropout=0.1):
        super().__init__()
        # NOTE: attribute names and the Sequential layout define the
        # state_dict keys (e.g. ``ffn.0.weight``, ``ffn.3.bias``). Keep them
        # stable so previously saved checkpoints still load.
        self.self_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        ffn_layers = [
            nn.Linear(d_model, dim_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_ff, d_model),
        ]
        self.ffn = nn.Sequential(*ffn_layers)
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_key_padding_mask=None):
        """Encode ``x`` with shape (batch, seq, d_model); attention is ``batch_first``.

        ``src_key_padding_mask`` is passed to ``nn.MultiheadAttention`` as
        ``key_padding_mask``. For a boolean mask, ``True`` marks keys to ignore.
        Returns a tensor with the same shape as ``x``.
        """
        # Attention sub-layer. Only the attended values are needed, not the weights.
        context = self.self_attn(x, x, x, key_padding_mask=src_key_padding_mask)[0]
        hidden = self.norm1(x + self.dropout(context))
        # Feed-forward sub-layer, with the same residual + post-norm pattern.
        return self.norm2(hidden + self.dropout(self.ffn(hidden)))
class TransformerEncoder(nn.Module):
    """Stack of ``EncoderBlock`` layers applied to image feature tokens.

    Inputs are scaled by ``sqrt(d_model)``. On the non-ViT path they then go
    through ``PositionalEncoding``. Dropout is applied before the result is
    fed through ``num_layers`` encoder blocks.

    Args:
        d_model: Token feature size expected by every layer.
        num_layers: Number of stacked ``EncoderBlock`` layers.
        num_heads: Attention heads per layer.
        dim_ff: Hidden width of each layer's feed-forward network.
        max_len: Accepted for interface compatibility but NOT used. The
            positional encoder is always constructed with ``max_len=49``.
            NOTE(review): confirm whether this is intentional.
        dropout: Drop probability for the input dropout and inside each layer.
        use_vit: When True, features are assumed to already carry positional
            information (ViT embeddings), so ``pos_encoder`` is skipped.
    """

    def __init__(
        self,
        d_model=512,
        num_layers=6,
        num_heads=8,
        dim_ff=2048,
        max_len=200,
        dropout=0.1,
        use_vit=False
    ):
        super().__init__()
        self.d_model = d_model
        self.use_vit = use_vit
        # Sized for the CNN path: per the original author, the ResNet backbone
        # yields a 7x7 grid, i.e. 49 tokens. `max_len` is deliberately not used here.
        num_cnn_patches = 7 * 7
        self.pos_encoder = PositionalEncoding(d_model, max_len=num_cnn_patches)
        self.layers = nn.ModuleList(
            EncoderBlock(d_model, num_heads, dim_ff, dropout)
            for _ in range(num_layers)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, img_features, src_key_padding_mask=None):
        """Encode image feature tokens.

        Args:
            img_features: Feature tokens. Assumed (batch, seq, d_model) because
                the layers use batch-first attention. TODO confirm against
                the caller. On the non-ViT path, sequences presumably must not
                exceed the 49 positions ``pos_encoder`` was built for.
            src_key_padding_mask: Optional mask forwarded unchanged to every
                layer's self-attention.

        Returns:
            Encoded tokens, with the same shape as ``img_features``.
        """
        # NOTE(review): the sqrt(d_model) scaling is applied on the ViT path
        # too, even though those features are not freshly initialised
        # embeddings. Confirm this is intended.
        tokens = img_features * math.sqrt(self.d_model)
        if not self.use_vit:
            tokens = self.pos_encoder(tokens)
        hidden = self.dropout(tokens)
        for block in self.layers:
            hidden = block(hidden, src_key_padding_mask)
        return hidden