import torch
import torch.nn as nn


class TransformerWithToken(nn.Module):
    """Transformer encoder with a learned token prepended to every sequence."""

    def __init__(self, d_model, nhead, dim_feedforward, num_layers):
        super().__init__()
        # Learned token, shape [1, 1, E]; expanded across the batch in forward().
        self.token = nn.Parameter(torch.randn(1, 1, d_model))
        # The token is always a valid position, so its padding-mask entry is
        # always `False`. Registered as a buffer so it follows the module
        # across devices without being treated as a trainable parameter.
        token_mask = torch.zeros(1, 1, dtype=torch.bool)
        self.register_buffer('token_mask', token_mask)
        self.core = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model, nhead=nhead,
                dim_feedforward=dim_feedforward,
            ),
            num_layers=num_layers,
        )

    def forward(self, x, src_key_padding_mask):
        # x: [N, B, E] (sequence-first, the nn.Transformer default layout)
        # src_key_padding_mask: [B, N]
        #   `False` for valid positions
        #   `True` for padded positions
        B = x.size(1)
        # Prepend one copy of the learned token to every sequence in the batch.
        token = self.token.expand(-1, B, -1)
        x = torch.cat([token, x], dim=0)  # [N + 1, B, E]
        # Extend the padding mask so it also covers the prepended token.
        token_mask = self.token_mask.expand(B, -1)
        padding_mask = torch.cat([token_mask, src_key_padding_mask], dim=1)  # [B, N + 1]
        x = self.core(x, src_key_padding_mask=padding_mask)
        return x
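

# A minimal usage sketch. The hyperparameters, shapes, and the choice to read
# the token's output row as a pooled representation are illustrative
# assumptions, not part of the original module.
if __name__ == '__main__':
    model = TransformerWithToken(d_model=64, nhead=4, dim_feedforward=256, num_layers=2)

    N, B = 10, 3  # sequence length, batch size
    x = torch.randn(N, B, 64)

    # Mark the last four positions of the first sequence as padding.
    padding_mask = torch.zeros(B, N, dtype=torch.bool)
    padding_mask[0, 6:] = True

    out = model(x, padding_mask)  # [N + 1, B, 64]
    # Row 0 is the prepended token's output; it attends to every valid
    # position, so it can serve as a fixed-size summary of each sequence.
    pooled = out[0]               # [B, 64]
    print(pooled.shape)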