MagicText-2.0-BF / archctr.py
BIGAI-models's picture
Update archctr.py
f4d5ec7 verified
Raw
History Blame Contribute Delete
1.76 kB
import torch
import torch.nn as nn
class GPTDecoder(nn.Module):
    """Single-block, decoder-only (causal) transformer model.

    Pipeline (see ``forward``): token embedding + learned absolute position
    embedding -> residual bank of ``nhead_emb`` bias-free projection matrices
    whose outputs are concatenated back to ``d_model`` -> pre-norm causal
    multi-head self-attention with residual -> pre-norm feed-forward
    (LeakyReLU, negative slope 0.1) with residual -> final LayerNorm ->
    linear output layer.

    NOTE: attribute names (``Heads``, ``GPT_MHA``, ``LayerNorm_*``,
    ``FFN_Layer*``, ``Linear``) and the order they are created in determine
    the ``state_dict`` keys and the ``parameters()`` order, so they are kept
    unchanged for checkpoint / optimizer-state compatibility.

    Args:
        nclass: Output size of the final linear layer (last dim of the result).
        d_model: Hidden size; must be divisible by ``nhead`` and ``nhead_emb``.
        max_pos: Number of rows in the position-embedding table, i.e. the
            longest sequence ``forward`` accepts.
        nhead: Number of attention heads in ``GPT_MHA``.
        dim_feedforward: Hidden width of the feed-forward sub-layer.
        vocab_size: Number of rows in the token-embedding table.
        nhead_emb: Number of projection matrices in ``Heads``; each maps
            ``d_model -> d_model // nhead_emb``.

    Raises:
        ValueError: If ``nhead_emb`` is not positive or does not divide
            ``d_model`` (the concatenated projections would not be
            ``d_model`` wide and the residual add in ``forward`` would fail).
    """

    def __init__(self, nclass=16135, d_model=1024, max_pos=512, nhead=16,
                 dim_feedforward=2048, vocab_size=16135, nhead_emb=8):
        super().__init__()
        if nhead_emb <= 0 or d_model % nhead_emb != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"nhead_emb ({nhead_emb})"
            )
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_pos, d_model)
        # One (d_model, d_model // nhead_emb) matrix per head; registered as
        # Heads.0 .. Heads.{nhead_emb-1}, created in the same order as before.
        # NOTE(review): plain randn init (unit variance, unscaled) — kept as-is.
        self.Heads = nn.ParameterList(
            [nn.Parameter(torch.randn(d_model, d_model // nhead_emb))
             for _ in range(nhead_emb)]
        )
        # Normalizes the projection-bank output before attention (pre-norm).
        self.LayerNorm_heads = nn.LayerNorm(d_model)
        self.GPT_MHA = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        # Normalizes the attention output before the feed-forward (pre-norm).
        self.LayerNorm_mha = nn.LayerNorm(d_model)
        self.FFN_Layer1 = nn.Linear(d_model, dim_feedforward)
        self.LeakyReLU = nn.LeakyReLU(0.1)
        self.FFN_Layer2 = nn.Linear(dim_feedforward, d_model)
        # Final normalization before the output layer.
        self.LayerNorm_ffn = nn.LayerNorm(d_model)
        self.Linear = nn.Linear(d_model, nclass)

    def forward(self, x):
        """Run the model on a batch of token ids.

        Args:
            x: Integer token ids of shape ``(batch, seq_len)`` with
                ``seq_len <= max_pos``. Position ids and the causal mask are
                created on ``x.device``, so ``x`` only needs to live on the
                same device as the model's parameters.

        Returns:
            Tensor of shape ``(batch, seq_len, nclass)``: raw outputs of the
            final linear layer (no softmax applied). The output at position
            ``t`` depends only on tokens ``0..t`` (causal attention mask).

        Raises:
            ValueError: If ``seq_len`` exceeds the position-embedding table.
        """
        seq_len = x.size(1)
        max_pos = self.pos_embedding.num_embeddings
        if seq_len > max_pos:
            # Without this check the embedding lookup fails with an opaque
            # IndexError (CPU) or a device-side assert (CUDA).
            raise ValueError(
                f"sequence length {seq_len} exceeds max_pos={max_pos}"
            )
        device = x.device

        # Token + learned absolute position embeddings; the (1, seq_len)
        # position ids broadcast over the batch dimension.
        positions = torch.arange(seq_len, device=device).unsqueeze(0)
        x = self.embedding(x) + self.pos_embedding(positions)

        # Residual projection bank: each head maps d_model -> d_model // nhead_emb
        # and concatenating all heads along the feature axis restores d_model.
        x = x + torch.cat([x @ head for head in self.Heads], dim=-1)

        # Pre-norm causal self-attention. Boolean mask: True (strictly above the
        # diagonal) means "may not attend", so token t only sees tokens 0..t.
        # Built on the input's device so it always matches the attention inputs.
        causal_mask = torch.ones(
            seq_len, seq_len, dtype=torch.bool, device=device
        ).triu(diagonal=1)
        normed = self.LayerNorm_heads(x)
        attn_out, _ = self.GPT_MHA(normed, normed, normed, attn_mask=causal_mask)
        x = x + attn_out

        # Pre-norm position-wise feed-forward with residual.
        hidden = self.LeakyReLU(self.FFN_Layer1(self.LayerNorm_mha(x)))
        x = x + self.FFN_Layer2(hidden)

        return self.Linear(self.LayerNorm_ffn(x))