import math
from typing import Optional

import torch
import torch.nn as nn

from transformer_decoder import PositionalEncoding


class EncoderBlock(nn.Module):
    """Single post-norm Transformer encoder layer.

    Applies multi-head self-attention followed by a position-wise
    feed-forward network. Each sub-layer is wrapped as
    ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        d_model: Feature size of the input and output tokens.
        num_heads: Number of attention heads passed to
            ``nn.MultiheadAttention``.
        dim_ff: Hidden width of the feed-forward network.
        dropout: Dropout probability used inside attention, inside the
            feed-forward network, and on both residual branches.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dim_ff: int,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        # batch_first=True: attention consumes (batch, seq, d_model) tensors.
        self.self_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        src_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the layer on a batch of token sequences.

        Args:
            x: Input of shape ``(batch, seq, d_model)`` (batch-first, as
                configured on the attention module).
            src_key_padding_mask: Optional mask forwarded unchanged as
                ``key_padding_mask`` to ``nn.MultiheadAttention``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Self-attention: query, key and value are all the same sequence.
        # The attention weights (second return value) are not needed.
        attn_out, _ = self.self_attn(
            x, x, x, key_padding_mask=src_key_padding_mask
        )
        x = self.norm1(x + self.dropout(attn_out))

        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))
        return x


class TransformerEncoder(nn.Module):
    """Stack of ``EncoderBlock`` layers applied to image features.

    The input features are scaled by ``sqrt(d_model)``. Positional
    encodings are then added unless ``use_vit`` is set, in which case the
    features are assumed to already carry positional information.

    Args:
        d_model: Feature size of every token.
        num_layers: Number of stacked ``EncoderBlock`` layers.
        num_heads: Attention heads per layer.
        dim_ff: Hidden width of each layer's feed-forward network.
        max_len: Maximum sequence length handed to ``PositionalEncoding``.
            It must be at least the number of input tokens (for example 49
            for a 7x7 ResNet feature map).
        dropout: Dropout probability for the input and for every layer.
        use_vit: If true, skip adding positional encodings in ``forward``.

    Note:
        ``max_len`` used to be ignored in favour of a hard-coded 49-entry
        table. If ``PositionalEncoding`` stores its table in the
        ``state_dict`` (not visible from this file -- confirm), checkpoints
        produced with the old behaviour must be loaded with ``max_len=49``.
    """

    def __init__(
        self,
        d_model: int = 512,
        num_layers: int = 6,
        num_heads: int = 8,
        dim_ff: int = 2048,
        max_len: int = 200,
        dropout: float = 0.1,
        use_vit: bool = False,
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.use_vit = use_vit
        # Honour the caller-supplied max_len instead of a fixed 49 so that
        # inputs longer than a 7x7 ResNet grid are supported. The module is
        # created even when use_vit is set so the set of submodules stays the
        # same in both modes.
        self.pos_encoder = PositionalEncoding(d_model, max_len=max_len)
        self.layers = nn.ModuleList(
            [
                EncoderBlock(d_model, num_heads, dim_ff, dropout)
                for _ in range(num_layers)
            ]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        img_features: torch.Tensor,
        src_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode a batch of image feature sequences.

        Args:
            img_features: Feature tokens; each ``EncoderBlock`` consumes them
                batch-first, i.e. ``(batch, seq, d_model)``.
            src_key_padding_mask: Optional padding mask passed to every
                layer.

        Returns:
            The output of the last encoder layer.
        """
        scaled = img_features * math.sqrt(self.d_model)
        if not self.use_vit:
            scaled = self.pos_encoder(scaled)
        # else: ViT already has positional embeddings, so we skip adding
        # them again.
        src = self.dropout(scaled)

        for layer in self.layers:
            src = layer(src, src_key_padding_mask)
        return src