import math

import torch
import torch.nn as nn

from transformer_decoder import PositionalEncoding


class EncoderBlock(nn.Module):
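    """A single post-norm Transformer encoder layer.

    Multi-head self-attention followed by a position-wise feed-forward
    network, each wrapped in a residual connection and LayerNorm
    (residual first, then norm).
    """
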
    def __init__(self, d_model, num_heads, dim_ff, dropout=0.1):
        super().__init__()

        self.self_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )

        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_ff, d_model),
        )

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_key_padding_mask=None):
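        """Run one encoder layer.

        x: (batch, seq_len, d_model), since the attention uses batch_first=True.
        src_key_padding_mask: optional (batch, seq_len) boolean mask where True
        marks padded positions that attention should ignore.
        """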
        attn_out, _ = self.self_attn(
            x, x, x, key_padding_mask=src_key_padding_mask
        )
        x = self.norm1(x + self.dropout(attn_out))

        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))

        return x


class TransformerEncoder(nn.Module):
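    """Stack of EncoderBlock layers over a sequence of image features.

    When use_vit is True the PositionalEncoding step is skipped, the incoming
    features presumably carrying positional information already.
    """
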
    def __init__(
        self,
        d_model=512,
        num_layers=6,
        num_heads=8,
        dim_ff=2048,
        max_len=200,
        dropout=0.1,
        use_vit=False,
    ):
        super().__init__()
        self.d_model = d_model
        self.use_vit = use_vit

        # Positional encoding table sized to the configured maximum sequence length.
        self.pos_encoder = PositionalEncoding(d_model, max_len=max_len)
        self.layers = nn.ModuleList(
            [
                EncoderBlock(d_model, num_heads, dim_ff, dropout)
                for _ in range(num_layers)
            ]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, img_features, src_key_padding_mask=None):
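        """Encode image features.

        img_features: (batch, num_patches, d_model) token sequence from the
        visual backbone. Features are scaled by sqrt(d_model); positional
        encoding is applied only when use_vit is False.
        """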
        if self.use_vit:
            src = self.dropout(img_features * math.sqrt(self.d_model))
        else:
            src = self.dropout(
                self.pos_encoder(img_features * math.sqrt(self.d_model))
            )

        for layer in self.layers:
            src = layer(src, src_key_padding_mask=src_key_padding_mask)
        return src
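
# Minimal usage sketch (illustrative assumption, not part of the model code):
# push a random (batch, 49, d_model) feature map through the encoder. use_vit=True
# keeps the imported PositionalEncoding out of the forward pass for this check.
if __name__ == "__main__":
    encoder = TransformerEncoder(d_model=512, num_layers=2, use_vit=True)
    feats = torch.randn(2, 49, 512)  # (batch, num_patches, d_model)
    out = encoder(feats)
    print(out.shape)  # expected: torch.Size([2, 49, 512])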