# koesan's picture
# Initial commit: Manga Layout Generator with model
# 66a1d29
import torch
import torch.nn as nn
class TransformerWithToken(nn.Module):
    """Transformer encoder that prepends one learned token to every sequence.

    The learned token is concatenated in front of the input along the
    sequence dimension, so position 0 of the output corresponds to the token
    and positions 1..N correspond to the original input positions.

    Args:
        d_model: Embedding size, forwarded to ``nn.TransformerEncoderLayer``.
        nhead: Number of attention heads, forwarded to the encoder layer.
        dim_feedforward: Hidden size of the feed-forward sub-layer.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, d_model, nhead, dim_feedforward, num_layers):
        super().__init__()

        # Learned token, stored once and expanded over the batch in forward().
        self.token = nn.Parameter(torch.randn(1, 1, d_model))

        # The token is never padding, so its mask entry is always False.
        # Registered as a buffer so it follows the module across devices.
        token_mask = torch.zeros(1, 1, dtype=torch.bool)
        self.register_buffer('token_mask', token_mask)

        self.core = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
            ),
            num_layers=num_layers,
        )

    def forward(self, x, src_key_padding_mask=None):
        """Encode ``x`` with the learned token prepended.

        Args:
            x: Input of shape ``[N, B, E]`` (sequence, batch, embedding).
            src_key_padding_mask: Optional mask of shape ``[B, N]``;
                ``False`` for valid values, ``True`` for padded values.
                ``None`` means every position is valid.

        Returns:
            Tensor of shape ``[N + 1, B, E]``; index 0 along the first
            dimension is the output for the learned token.
        """
        B = x.size(1)

        # Broadcast the single learned token over the batch and put it first.
        token = self.token.expand(-1, B, -1)
        x = torch.cat([token, x], dim=0)

        if src_key_padding_mask is None:
            # Nothing is padded: the encoder accepts None directly, so there
            # is no need to materialise an all-False mask.
            padding_mask = None
        else:
            # Extend the mask by one always-valid column for the token so it
            # lines up with the lengthened sequence.
            token_mask = self.token_mask.expand(B, -1)
            padding_mask = torch.cat(
                [token_mask, src_key_padding_mask], dim=1)

        return self.core(x, src_key_padding_mask=padding_mask)