import torch
import torch.nn as nn


class GPTDecoder(nn.Module):
    """Single-block, decoder-only (GPT-style) causal model.

    Pipeline: token + learned positional embeddings -> a bank of
    ``nhead_emb`` learned projection matrices whose outputs are concatenated
    back to ``d_model`` (plus a residual) -> pre-LayerNorm causal multi-head
    self-attention (plus a residual) -> pre-LayerNorm feed-forward network
    (plus a residual) -> final LayerNorm -> linear layer producing ``nclass``
    logits per position.

    Attribute names are kept unchanged because they form the ``state_dict``
    keys of saved checkpoints.

    Args:
        nclass: Output width of the final ``Linear`` layer.
        d_model: Hidden size. Must be divisible by ``nhead_emb`` (checked
            here) and by ``nhead`` (required by ``nn.MultiheadAttention``).
        max_pos: Number of learned positions, i.e. the longest sequence
            ``forward`` accepts.
        nhead: Number of attention heads in ``GPT_MHA``.
        dim_feedforward: Hidden width of the feed-forward network.
        vocab_size: Number of distinct token ids accepted by ``embedding``.
        nhead_emb: Number of projection matrices in ``Heads``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``nhead_emb``.
    """

    def __init__(self, nclass=16135, d_model=1024, max_pos=512, nhead=16,
                 dim_feedforward=2048, vocab_size=16135, nhead_emb=8):
        super().__init__()
        # Each projection yields d_model // nhead_emb features; their
        # concatenation must be exactly d_model wide for the residual add in
        # forward(). Fail at construction instead of with a shape error later.
        if d_model % nhead_emb != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by "
                f"nhead_emb ({nhead_emb})"
            )

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_pos, d_model)

        # NOTE(review): torch.randn gives std-1 entries, so each projection
        # scales activation magnitudes by roughly sqrt(d_model), letting the
        # projected features dominate the residual sum. A 1/sqrt(d_model)
        # scaled init may train better; left unchanged to preserve the
        # existing initialisation.
        head_dim = d_model // nhead_emb
        self.Heads = nn.ParameterList(
            [nn.Parameter(torch.randn(d_model, head_dim))
             for _ in range(nhead_emb)]
        )
        self.LayerNorm_heads = nn.LayerNorm(d_model)

        self.GPT_MHA = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.LayerNorm_mha = nn.LayerNorm(d_model)

        self.FFN_Layer1 = nn.Linear(d_model, dim_feedforward)
        self.LeakyReLU = nn.LeakyReLU(0.1)
        self.FFN_Layer2 = nn.Linear(dim_feedforward, d_model)
        self.LayerNorm_ffn = nn.LayerNorm(d_model)

        self.Linear = nn.Linear(d_model, nclass)

    def forward(self, x):
        """Return per-position logits for a batch of token ids.

        Args:
            x: Integer token ids of shape (batch, seq_len); ``seq_len`` must
                not exceed ``max_pos``.

        Returns:
            Tensor of shape (batch, seq_len, nclass).

        Raises:
            IndexError: If ``seq_len`` exceeds ``max_pos``.
        """
        seq_len = x.size(1)
        max_pos = self.pos_embedding.num_embeddings
        if seq_len > max_pos:
            raise IndexError(
                f"sequence length {seq_len} exceeds max_pos ({max_pos})"
            )

        h = self.embedding(x)
        # Create position ids on the activations' device; a default (CPU)
        # arange breaks as soon as the model lives on an accelerator.
        positions = torch.arange(seq_len, device=h.device).unsqueeze(0)
        h = h + self.pos_embedding(positions)

        # Projection bank: each matrix maps d_model -> d_model // nhead_emb;
        # concatenating restores d_model, then add the residual.
        h = torch.cat([h @ w for w in self.Heads], dim=2) + h

        # Causal self-attention (pre-norm + residual). In a boolean attn_mask,
        # True marks disallowed positions, so the strict upper triangle hides
        # future tokens -- equivalent to the -inf float mask from
        # nn.Transformer.generate_square_subsequent_mask, but created on
        # h.device and valid for any activation dtype.
        causal_mask = torch.ones(
            seq_len, seq_len, dtype=torch.bool, device=h.device
        ).triu(1)
        residual = h
        normed = self.LayerNorm_heads(h)
        # The attention weights were never used, so skip computing them.
        attn_out, _ = self.GPT_MHA(
            normed, normed, normed, attn_mask=causal_mask, need_weights=False
        )
        h = attn_out + residual

        # Feed-forward network (pre-norm + residual), then final LayerNorm.
        ffn_out = self.FFN_Layer2(
            self.LeakyReLU(self.FFN_Layer1(self.LayerNorm_mha(h)))
        )
        h = self.LayerNorm_ffn(ffn_out + h)
        return self.Linear(h)