import torch
import torch.nn as nn


class TransformerWithToken(nn.Module):
    """Transformer encoder that prepends one learnable token to each sequence.

    A single learnable embedding of size ``d_model`` is placed in front of
    every input sequence before the stack of encoder layers is applied, so
    the output has one more sequence position than the input (index 0 is the
    position of the prepended token).

    Args:
        d_model: Embedding size of the inputs and of the learnable token.
        nhead: Number of attention heads in each encoder layer.
        dim_feedforward: Hidden size of each layer's feed-forward sublayer.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, d_model, nhead, dim_feedforward, num_layers):
        super().__init__()

        # One token shared by all batch elements; `forward` expands it along
        # the batch axis (a view, no copy).
        self.token = nn.Parameter(torch.randn(1, 1, d_model))

        # The prepended token is never padding, so its mask entry is False.
        # Registered as a buffer so it follows the module across devices.
        token_mask = torch.zeros(1, 1, dtype=torch.bool)
        self.register_buffer('token_mask', token_mask)

        self.core = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
            ),
            num_layers=num_layers,
        )

    def forward(self, x, src_key_padding_mask=None):
        """Encode ``x`` with the learnable token prepended.

        Args:
            x: Input of shape ``[N, B, E]`` (sequence, batch, embedding).
            src_key_padding_mask: Optional mask of shape ``[B, N]``; ``False``
                marks valid positions and ``True`` marks padded positions.
                ``None`` means that no position is padded.

        Returns:
            Tensor of shape ``[N + 1, B, E]``; index 0 along the first axis
            corresponds to the prepended token.
        """
        batch_size = x.size(1)
        token = self.token.expand(-1, batch_size, -1)
        x = torch.cat([token, x], dim=0)

        if src_key_padding_mask is None:
            # Nothing is padded: hand `None` straight to the encoder rather
            # than trying to concatenate with a missing mask.
            padding_mask = None
        else:
            # Extend the mask by one leading column for the prepended token.
            token_mask = self.token_mask.expand(batch_size, -1)
            padding_mask = torch.cat([token_mask, src_key_padding_mask], dim=1)

        return self.core(x, src_key_padding_mask=padding_mask)