# BERT architecture for the Masked Bidirectional Encoder Transformer
import torch
from torch import nn


class PreNorm(nn.Module):
    # Pre-normalization wrapper: LayerNorm is applied before calling the wrapped module fn
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    # Two-layer MLP with GELU activation and dropout, used inside each transformer block
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim, bias=True),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim, bias=True),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    # Multi-head self-attention; returns both the attention output and the attention weights
    def __init__(self, embed_dim, num_heads, dropout=0.):
        super().__init__()
        self.dim = embed_dim
        self.mha = nn.MultiheadAttention(embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True, bias=True)

    def forward(self, x):
        attention_value, attention_weight = self.mha(x, x, x)
        return attention_value, attention_weight


class TransformerEncoder(nn.Module):
    # Stack of pre-norm attention and feed-forward blocks with residual connections
    def __init__(self, dim, depth, heads, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))

    def forward(self, x):
        l_attn = []
        for attn, ff in self.layers:
            attention_value, attention_weight = attn(x)
            x = attention_value + x
            x = ff(x) + x
            l_attn.append(attention_weight)  # keep the attention weights of every layer
        return x, l_attn


class MaskTransformer(nn.Module):
    def __init__(self, img_size=256, hidden_dim=768, codebook_size=1024, depth=24, heads=8, mlp_dim=3072, dropout=0.1, nclass=1000):
        super().__init__()
        self.nclass = nclass
        self.patch_size = img_size // 16  # number of visual tokens per side
        self.codebook_size = codebook_size
        self.tok_emb = nn.Embedding(codebook_size+1+nclass+1, hidden_dim)  # +1 for the mask of the visual tokens, +1 for the mask of the class (dropped label)
        # self.msk_emb = nn.Embedding(2, hidden_dim)
        self.pos_emb = nn.init.trunc_normal_(nn.Parameter(torch.zeros(1, (self.patch_size*self.patch_size)+1, hidden_dim)), 0., 0.02)

        # First layer before the transformer block
        self.first_layer = nn.Sequential(
            nn.LayerNorm(hidden_dim, eps=1e-12),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim, eps=1e-12),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
        )

        self.transformer = TransformerEncoder(dim=hidden_dim, depth=depth, heads=heads, mlp_dim=mlp_dim, dropout=dropout)

        # Last layer after the transformer block
        self.last_layer = nn.Sequential(
            nn.LayerNorm(hidden_dim, eps=1e-12),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim, eps=1e-12),
        )

        # Bias of the output projection (weights are tied to the token embedding)
        self.bias = nn.Parameter(torch.zeros((self.patch_size*self.patch_size)+1, codebook_size+1+nclass+1))

    def forward(self, img_token, y=None, drop_label=None, return_attn=False):  # , masking_flag=None):
        b, w, h = img_token.size()

        # Shift class labels past the visual vocabulary; dropped labels map to the class-mask index
        cls_token = y.view(b, -1) + self.codebook_size + 1
        cls_token[drop_label] = self.codebook_size + 1 + self.nclass

        # Concatenate the flattened visual tokens with the class token and embed
        input = torch.cat([img_token.view(b, -1), cls_token.view(b, -1)], -1)
        tok_embeddings = self.tok_emb(input)

        # Add the learned positional embedding
        pos_embeddings = self.pos_emb
        x = tok_embeddings + pos_embeddings

        # if masking_flag is not None:
        #     flag = torch.cat([masking_flag.view(b, -1), torch.zeros_like(cls_token.view(b, -1))], -1)
        #     x += self.msk_emb(flag)

        x = self.first_layer(x)
        x, attn = self.transformer(x)
        x = self.last_layer(x)

        # Project back onto the token vocabulary (weights tied with tok_emb)
        logit = torch.matmul(x, self.tok_emb.weight.T) + self.bias

        if return_attn:  # also return the per-layer attention weights
            return logit[:, :self.patch_size * self.patch_size, :self.codebook_size + 1], attn

        return logit[:, :self.patch_size*self.patch_size, :self.codebook_size+1]
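

# Minimal usage sketch (illustrative addition, not part of the original Space): the small
# sizes below are assumptions chosen so a quick forward pass is cheap to run and check.
if __name__ == "__main__":
    model = MaskTransformer(img_size=64, hidden_dim=64, codebook_size=32,
                            depth=2, heads=4, mlp_dim=128, dropout=0., nclass=10)
    b = 2
    p = 64 // 16                                   # tokens per side = img_size // 16
    img_token = torch.randint(0, 32, (b, p, p))    # random visual token indices
    y = torch.randint(0, 10, (b,))                 # random class labels
    drop_label = torch.zeros(b, dtype=torch.bool)  # keep every label (no classifier-free drop)
    logit = model(img_token, y=y, drop_label=drop_label)
    print(logit.shape)                             # torch.Size([2, 16, 33]) -> (batch, p*p, codebook_size + 1)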